From 51f2737ca7af016e3dd3ad03673f48a754a1642e Mon Sep 17 00:00:00 2001 From: igorechek06 Date: Sat, 10 Aug 2024 15:11:42 +0900 Subject: Replace Vec with Iterator --- src/zip/driver.rs | 41 ++++++++++++++++++++++++----------------- src/zip/structs.rs | 2 +- tests/zip.rs | 14 +++++++------- 3 files changed, 32 insertions(+), 25 deletions(-) diff --git a/src/zip/driver.rs b/src/zip/driver.rs index 8dc902f..b4217c8 100644 --- a/src/zip/driver.rs +++ b/src/zip/driver.rs @@ -15,21 +15,30 @@ use std::collections::HashMap as Map; use std::fs::File; use std::io::{BufReader, Read, Seek, SeekFrom, Write}; -#[inline] -fn split_fields(bytes: &[u8]) -> Option> { - let mut fields = Vec::new(); - - let mut p = 0; - while p < bytes.len() { - let header: ExtraHeader = deserialize(bytes.get(p..p + 4)?).unwrap(); - p += 4; - let data = bytes.get(p..p + header.size as usize)?; - p += header.size as usize; +struct Fields<'b> { + pointer: usize, + bytes: &'b [u8], +} - fields.push((header.id, data)); +impl<'b> Fields<'b> { + pub fn new(bytes: &'b [u8]) -> Self { + Self { pointer: 0, bytes } } +} - Some(fields) +impl<'b> Iterator for Fields<'b> { + type Item = (u16, &'b [u8]); + + fn next(&mut self) -> Option { + let header: ExtraHeader = + deserialize(self.bytes.get(self.pointer..self.pointer + 4)?).unwrap(); + self.pointer += 4; + let data = self + .bytes + .get(self.pointer..self.pointer + header.size as usize)?; + self.pointer += header.size as usize; + Some((header.id, data)) + } } #[inline] @@ -159,7 +168,7 @@ impl ArchiveRead for Zip { let mut ctime = None; // Parse extensible data fields - for (id, mut data) in split_fields(&extra_fields).ok_or(ZipError::InvalidExtraFields)? { + for (id, mut data) in Fields::new(&extra_fields) { match id { // Zip64 0x0001 => { @@ -175,9 +184,7 @@ impl ArchiveRead for Zip { } // NTFS 0x000a => { - for (id, mut data) in - split_fields(&data).ok_or(ZipError::InvalidExtraFields)? - { + for (id, mut data) in Fields::new(&data[4..]) { match id { 0x0001 => { mtime = ntfs_to_local(u64::from_le_bytes(data.read_arr()?)) @@ -198,7 +205,7 @@ impl ArchiveRead for Zip { // AES 0x9901 => { let aes: AesField = deserialize(&data.read_arr::<7>()?).unwrap(); - if aes.id != 0x4541 { + if aes.id != [0x41, 0x45] { return Err(ZipError::InvalidExtraFields); } encryption_method = match aes.strength { diff --git a/src/zip/structs.rs b/src/zip/structs.rs index 8b25400..4b4524f 100644 --- a/src/zip/structs.rs +++ b/src/zip/structs.rs @@ -67,7 +67,7 @@ pub struct ExtraHeader { #[derive(Serialize, Deserialize)] pub struct AesField { pub version: u16, - pub id: u16, + pub id: [u8; 2], pub strength: u8, pub compression_method: u16, } diff --git a/tests/zip.rs b/tests/zip.rs index 9283df3..e44098c 100644 --- a/tests/zip.rs +++ b/tests/zip.rs @@ -92,6 +92,13 @@ fn test_zip_weak() { #[test] fn test_zip() { + assert_eq!( + Archive::::read_from_file("tests/files/empty.zip") + .unwrap() + .len(), + 0 + ); + let mut archive = Archive::::read_from_file("tests/files/archive.zip").unwrap(); assert_eq!(archive.comment(), "archive comment"); @@ -158,11 +165,4 @@ fn test_zip() { fn test_bad_zip() { assert!(Archive::::read_from_file("tests/files/blank") .is_err_and(|e| e == ZipError::EocdrNotFound)); - - assert_eq!( - Archive::::read_from_file("tests/files/empty.zip") - .unwrap() - .len(), - 0 - ); } -- cgit v1.2.3