1use crate::error::{Result, SZipError};
13use crc32fast::Hasher as Crc32;
14use flate2::write::DeflateEncoder;
15use flate2::Compression;
16use std::fs::File;
17use std::io::{Seek, Write};
18use std::path::Path;
19
20#[cfg(feature = "encryption")]
21use crate::encryption::{AesEncryptor, AesStrength};
22
/// Compression algorithm applied to the entries of an archive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionMethod {
    /// Entry data is written without compression (ZIP method 0).
    Stored,
    /// Entry data is Deflate-compressed (ZIP method 8).
    Deflate,
    /// Entry data is Zstandard-compressed (ZIP method 93).
    #[cfg(feature = "zstd-support")]
    Zstd,
}

impl CompressionMethod {
    /// Returns the numeric code written into the "compression method" header field.
    pub(crate) fn to_zip_method(self) -> u16 {
        match self {
            Self::Stored => 0,
            Self::Deflate => 8,
            #[cfg(feature = "zstd-support")]
            Self::Zstd => 93,
        }
    }
}
45
/// Bookkeeping for one finished entry.
///
/// A value is recorded when an entry is closed and read back when the central
/// directory is written.
struct ZipEntry {
    /// Entry name; its bytes are written verbatim into the central directory.
    name: String,
    /// Stream position at which the entry's local file header starts.
    local_header_offset: u64,
    /// CRC-32 of the entry's uncompressed data.
    crc32: u32,
    /// Number of bytes the entry's data occupies in the archive.
    compressed_size: u64,
    /// Number of bytes passed in for this entry before compression.
    uncompressed_size: u64,
    /// Numeric ZIP compression-method code used for this entry.
    compression_method: u16,
    /// WinZip AES strength code when the entry was encrypted, `None` otherwise.
    #[cfg(feature = "encryption")]
    #[allow(dead_code)]
    encryption_strength: Option<u16>,
}
58
/// Writes a ZIP archive entry by entry to any seekable writer.
///
/// Entries are started with `start_entry`, fed with `write_data`, and the
/// archive is completed with `finish`.
pub struct StreamingZipWriter<W: Write + Seek> {
    /// Destination that receives the archive bytes.
    output: W,
    /// Entries that have been completed so far, in the order they were written.
    entries: Vec<ZipEntry>,
    /// The entry currently accepting data, if one has been started.
    current_entry: Option<CurrentEntry>,
    /// Compression level handed to the encoder of each new entry.
    compression_level: u32,
    /// Compression method applied to each new entry.
    compression_method: CompressionMethod,
    /// When set, entries started afterwards are AES-encrypted with this password.
    #[cfg(feature = "encryption")]
    password: Option<String>,
    /// AES strength used for entries that are encrypted.
    #[cfg(feature = "encryption")]
    encryption_strength: AesStrength,
}
71
/// State of the entry that is currently being written.
struct CurrentEntry {
    /// Entry name as given by the caller.
    name: String,
    /// Stream position at which this entry's local file header starts.
    local_header_offset: u64,
    /// Compressor that receives the caller's data and buffers its output.
    encoder: Box<dyn CompressorWrite>,
    /// Running CRC-32 and byte counts for this entry.
    counter: CrcCounter,
    /// Numeric ZIP compression-method code of this entry.
    compression_method: u16,
    /// Encryption state, present only when the entry is being encrypted.
    #[cfg(feature = "encryption")]
    encryptor: Option<AesEncryptor>,
}
81
/// A compressor that writes its output into a `CompressedBuffer`.
trait CompressorWrite: Write {
    /// Consumes the compressor and returns the buffer holding whatever output
    /// has not been drained yet.
    fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer>;
    /// Gives mutable access to the output buffer so the caller can inspect and
    /// drain it while compression is still in progress.
    fn get_buffer_mut(&mut self) -> &mut CompressedBuffer;
}
86
/// `CompressorWrite` implementation backed by flate2's `DeflateEncoder`.
struct DeflateCompressor {
    /// Deflate encoder whose compressed output lands in a `CompressedBuffer`.
    encoder: DeflateEncoder<CompressedBuffer>,
}

impl Write for DeflateCompressor {
    /// Forwards `buf` to the Deflate encoder.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.encoder.write(buf)
    }

    /// Forwards the flush request to the Deflate encoder.
    fn flush(&mut self) -> std::io::Result<()> {
        self.encoder.flush()
    }
}

impl CompressorWrite for DeflateCompressor {
    /// Finishes the Deflate stream and returns the buffer it was writing into.
    fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer> {
        Ok(self.encoder.finish()?)
    }

    /// Returns the buffer the encoder writes into.
    fn get_buffer_mut(&mut self) -> &mut CompressedBuffer {
        self.encoder.get_mut()
    }
}
110
/// `CompressorWrite` implementation that performs no compression.
struct StoredCompressor {
    /// Buffer that receives the caller's bytes unchanged.
    buffer: CompressedBuffer,
}

impl Write for StoredCompressor {
    /// Appends `buf` to the buffer as-is.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.buffer.write(buf)
    }

    /// Forwards the flush request to the buffer.
    fn flush(&mut self) -> std::io::Result<()> {
        self.buffer.flush()
    }
}

impl CompressorWrite for StoredCompressor {
    /// Returns the buffer; there is no compressor state to finalize.
    fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer> {
        Ok(self.buffer)
    }

    /// Returns the buffer holding the bytes written so far.
    fn get_buffer_mut(&mut self) -> &mut CompressedBuffer {
        &mut self.buffer
    }
}
135
/// `CompressorWrite` implementation backed by a `zstd::Encoder`.
#[cfg(feature = "zstd-support")]
struct ZstdCompressor {
    /// Zstandard encoder whose compressed output lands in a `CompressedBuffer`.
    encoder: zstd::Encoder<'static, CompressedBuffer>,
}

#[cfg(feature = "zstd-support")]
impl Write for ZstdCompressor {
    /// Forwards `buf` to the Zstandard encoder.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.encoder.write(buf)
    }

    /// Forwards the flush request to the Zstandard encoder.
    fn flush(&mut self) -> std::io::Result<()> {
        self.encoder.flush()
    }
}

#[cfg(feature = "zstd-support")]
impl CompressorWrite for ZstdCompressor {
    /// Finishes the Zstandard stream and returns the buffer it was writing into.
    fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer> {
        Ok(self.encoder.finish()?)
    }

    /// Returns the buffer the encoder writes into.
    fn get_buffer_mut(&mut self) -> &mut CompressedBuffer {
        self.encoder.get_mut()
    }
}
162
/// Running CRC-32 and byte counters for a single entry.
struct CrcCounter {
    /// Incremental CRC-32 state over the uncompressed data.
    crc: Crc32,
    /// Total uncompressed bytes seen so far.
    uncompressed_count: u64,
    /// Total bytes accounted as written to the archive for this entry so far.
    compressed_count: u64,
}

impl CrcCounter {
    /// Creates a counter with a fresh CRC state and both counts at zero.
    fn new() -> Self {
        Self {
            crc: Crc32::new(),
            uncompressed_count: 0,
            compressed_count: 0,
        }
    }

    /// Feeds `data` into the CRC and adds its length to the uncompressed count.
    fn update_uncompressed(&mut self, data: &[u8]) {
        self.crc.update(data);
        self.uncompressed_count += data.len() as u64;
    }

    /// Adds `count` bytes to the compressed count.
    fn add_compressed(&mut self, count: u64) {
        self.compressed_count += count;
    }

    /// Returns the CRC-32 of everything fed in so far.
    ///
    /// The hasher is cloned first, so `self` is left untouched.
    fn finalize(&self) -> u32 {
        self.crc.clone().finalize()
    }
}
192
/// In-memory staging area for compressed bytes on their way to the output.
struct CompressedBuffer {
    /// Bytes that have been produced but not yet drained with `take`.
    buffer: Vec<u8>,
    /// Fill level (in bytes) at which `should_flush` starts returning `true`.
    flush_threshold: usize,
}

impl CompressedBuffer {
    /// Creates a buffer sized as if no size hint were available.
    #[allow(dead_code)]
    fn new() -> Self {
        Self::with_size_hint(None)
    }

    /// Creates a buffer whose initial capacity and flush threshold are chosen
    /// from the expected entry size.
    ///
    /// Smaller hints get smaller allocations; a missing hint or one of
    /// 10,000,000 bytes and above gets the largest tier.
    fn with_size_hint(size_hint: Option<u64>) -> Self {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;
        // (exclusive upper bound on the hint, initial capacity, flush threshold)
        const TIERS: [(u64, usize, usize); 4] = [
            (10_000, 8 * KIB, 256 * KIB),
            (100_000, 32 * KIB, 512 * KIB),
            (1_000_000, 128 * KIB, 2 * MIB),
            (10_000_000, 256 * KIB, 4 * MIB),
        ];

        let (initial_capacity, flush_threshold) = size_hint
            .and_then(|size| {
                TIERS
                    .iter()
                    .find(|&&(limit, _, _)| size < limit)
                    .map(|&(_, capacity, threshold)| (capacity, threshold))
            })
            .unwrap_or((512 * KIB, 8 * MIB));

        Self {
            buffer: Vec::with_capacity(initial_capacity),
            flush_threshold,
        }
    }

    /// Removes and returns everything buffered so far, leaving the buffer empty.
    fn take(&mut self) -> Vec<u8> {
        std::mem::take(&mut self.buffer)
    }

    /// Reports whether the buffered bytes have reached the flush threshold.
    fn should_flush(&self) -> bool {
        self.buffer.len() >= self.flush_threshold
    }
}

impl Write for CompressedBuffer {
    /// Appends all of `buf`; this never fails and never writes partially.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let written = buf.len();
        self.buffer.extend_from_slice(buf);
        Ok(written)
    }

    /// Does nothing: the data already lives in memory.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
250
251impl StreamingZipWriter<File> {
252 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
254 Self::with_compression(path, 6)
255 }
256
257 pub fn with_compression<P: AsRef<Path>>(path: P, compression_level: u32) -> Result<Self> {
259 Self::with_method(path, CompressionMethod::Deflate, compression_level)
260 }
261
262 pub fn with_method<P: AsRef<Path>>(
269 path: P,
270 method: CompressionMethod,
271 compression_level: u32,
272 ) -> Result<Self> {
273 let output = File::create(path)?;
274 Ok(Self {
275 output,
276 entries: Vec::new(),
277 current_entry: None,
278 compression_level,
279 compression_method: method,
280 #[cfg(feature = "encryption")]
281 password: None,
282 #[cfg(feature = "encryption")]
283 encryption_strength: AesStrength::Aes256,
284 })
285 }
286
287 #[cfg(feature = "zstd-support")]
289 pub fn with_zstd<P: AsRef<Path>>(path: P, compression_level: i32) -> Result<Self> {
290 let output = File::create(path)?;
291 Ok(Self {
292 output,
293 entries: Vec::new(),
294 current_entry: None,
295 compression_level: compression_level as u32,
296 compression_method: CompressionMethod::Zstd,
297 #[cfg(feature = "encryption")]
298 password: None,
299 #[cfg(feature = "encryption")]
300 encryption_strength: AesStrength::Aes256,
301 })
302 }
303}
304
305impl<W: Write + Seek> StreamingZipWriter<W> {
306 pub fn from_writer(writer: W) -> Result<Self> {
308 Self::from_writer_with_compression(writer, 6)
309 }
310
311 pub fn from_writer_with_compression(writer: W, compression_level: u32) -> Result<Self> {
313 Self::from_writer_with_method(writer, CompressionMethod::Deflate, compression_level)
314 }
315
316 pub fn from_writer_with_method(
323 writer: W,
324 method: CompressionMethod,
325 compression_level: u32,
326 ) -> Result<Self> {
327 Ok(Self {
328 output: writer,
329 entries: Vec::new(),
330 current_entry: None,
331 compression_level,
332 compression_method: method,
333 #[cfg(feature = "encryption")]
334 password: None,
335 #[cfg(feature = "encryption")]
336 encryption_strength: AesStrength::Aes256,
337 })
338 }
339
340 #[cfg(feature = "encryption")]
361 pub fn set_password(&mut self, password: impl Into<String>) -> &mut Self {
362 self.password = Some(password.into());
363 self
364 }
365
366 #[cfg(feature = "encryption")]
371 pub fn set_encryption_strength(&mut self, strength: AesStrength) -> &mut Self {
372 self.encryption_strength = strength;
373 self
374 }
375
376 #[cfg(feature = "encryption")]
378 pub fn clear_password(&mut self) -> &mut Self {
379 self.password = None;
380 self
381 }
382
383 pub fn start_entry(&mut self, name: &str) -> Result<()> {
385 self.start_entry_with_hint(name, None)
386 }
387
388 pub fn start_entry_with_hint(&mut self, name: &str, size_hint: Option<u64>) -> Result<()> {
409 self.finish_current_entry()?;
411
412 let local_header_offset = self.output.stream_position()?;
413 let compression_method = self.compression_method.to_zip_method();
414
415 #[cfg(feature = "encryption")]
417 let (encryptor, encryption_flag) = if let Some(ref password) = self.password {
418 let enc = AesEncryptor::new(password, self.encryption_strength)?;
419 (Some(enc), 0x01) } else {
421 (None, 0x00)
422 };
423
424 #[cfg(not(feature = "encryption"))]
425 let encryption_flag = 0x00;
426
427 self.output.write_all(&[0x50, 0x4b, 0x03, 0x04])?; self.output.write_all(&[51, 0])?; self.output.write_all(&[8 | encryption_flag, 0])?; self.output.write_all(&compression_method.to_le_bytes())?; self.output.write_all(&[0, 0, 0, 0])?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&(name.len() as u16).to_le_bytes())?;
437
438 #[cfg(feature = "encryption")]
440 let extra_len = if encryptor.is_some() { 11 } else { 0 };
441 #[cfg(not(feature = "encryption"))]
442 let extra_len = 0;
443
444 self.output.write_all(&(extra_len as u16).to_le_bytes())?; self.output.write_all(name.as_bytes())?;
446
447 #[cfg(feature = "encryption")]
449 if let Some(ref enc) = encryptor {
450 self.output.write_all(&[0x01, 0x99])?; self.output.write_all(&[7, 0])?; self.output.write_all(&[2, 0])?; self.output.write_all(&[0x41, 0x45])?; self.output
456 .write_all(&enc.strength().to_winzip_code().to_le_bytes())?; self.output.write_all(&compression_method.to_le_bytes())?; self.output.write_all(enc.salt())?;
461 self.output.write_all(enc.password_verify())?;
462 }
463
464 let encoder: Box<dyn CompressorWrite> = match self.compression_method {
467 CompressionMethod::Deflate => Box::new(DeflateCompressor {
468 encoder: DeflateEncoder::new(
469 CompressedBuffer::with_size_hint(size_hint),
470 Compression::new(self.compression_level),
471 ),
472 }),
473 #[cfg(feature = "zstd-support")]
474 CompressionMethod::Zstd => {
475 let mut encoder = zstd::Encoder::new(
476 CompressedBuffer::with_size_hint(size_hint),
477 self.compression_level as i32,
478 )?;
479 encoder.include_checksum(false)?; Box::new(ZstdCompressor { encoder })
481 }
482 CompressionMethod::Stored => {
483 Box::new(StoredCompressor {
485 buffer: CompressedBuffer::new(),
486 })
487 }
488 };
489
490 #[cfg_attr(not(feature = "encryption"), allow(unused_mut))]
491 let mut counter = CrcCounter::new();
492
493 #[cfg(feature = "encryption")]
495 if let Some(ref enc) = encryptor {
496 let encryption_overhead = (enc.salt().len() + 2) as u64; counter.add_compressed(encryption_overhead);
498 }
499
500 self.current_entry = Some(CurrentEntry {
501 name: name.to_string(),
502 local_header_offset,
503 encoder,
504 counter,
505 compression_method,
506 #[cfg(feature = "encryption")]
507 encryptor,
508 });
509
510 Ok(())
511 }
512
513 pub fn write_data(&mut self, data: &[u8]) -> Result<()> {
515 let entry = self
516 .current_entry
517 .as_mut()
518 .ok_or_else(|| SZipError::InvalidFormat("No entry started".to_string()))?;
519
520 entry.counter.update_uncompressed(data);
522
523 #[cfg(feature = "encryption")]
525 if let Some(ref mut encryptor) = entry.encryptor {
526 encryptor.update_hmac(data);
527 }
528
529 entry.encoder.write_all(data)?;
531
532 entry.encoder.flush()?;
534
535 let buffer = entry.encoder.get_buffer_mut();
537 if buffer.should_flush() {
538 let compressed_data = buffer.take();
540
541 #[cfg(feature = "encryption")]
543 let data_to_write = if let Some(ref mut encryptor) = entry.encryptor {
544 let mut data_to_encrypt = compressed_data;
545 encryptor.encrypt(&mut data_to_encrypt)?;
546 data_to_encrypt
547 } else {
548 compressed_data
549 };
550
551 #[cfg(not(feature = "encryption"))]
552 let data_to_write = compressed_data;
553
554 self.output.write_all(&data_to_write)?;
555 entry.counter.add_compressed(data_to_write.len() as u64);
556 }
557
558 Ok(())
559 }
560
561 fn finish_current_entry(&mut self) -> Result<()> {
563 if let Some(mut entry) = self.current_entry.take() {
564 let mut buffer = entry.encoder.finish_compression()?;
566
567 let remaining_data = buffer.take();
569 if !remaining_data.is_empty() {
570 #[cfg(feature = "encryption")]
572 let data_to_write = if let Some(ref mut encryptor) = entry.encryptor {
573 let mut data_to_encrypt = remaining_data;
574 encryptor.encrypt(&mut data_to_encrypt)?;
575 data_to_encrypt
576 } else {
577 remaining_data
578 };
579
580 #[cfg(not(feature = "encryption"))]
581 let data_to_write = remaining_data;
582
583 self.output.write_all(&data_to_write)?;
584 entry.counter.add_compressed(data_to_write.len() as u64);
585 }
586
587 #[cfg(feature = "encryption")]
589 let (encryption_strength_code, auth_code_size) =
590 if let Some(encryptor) = entry.encryptor {
591 let strength_code = encryptor.strength().to_winzip_code();
592 let auth_code = encryptor.finalize();
593 self.output.write_all(&auth_code)?;
594 (Some(strength_code), auth_code.len() as u64)
595 } else {
596 (None, 0)
597 };
598
599 #[cfg(not(feature = "encryption"))]
600 let auth_code_size = 0u64;
601
602 let crc = entry.counter.finalize();
603 let compressed_size = entry.counter.compressed_count + auth_code_size;
604 let uncompressed_size = entry.counter.uncompressed_count;
605
606 self.output.write_all(&[0x50, 0x4b, 0x07, 0x08])?;
609 self.output.write_all(&crc.to_le_bytes())?;
610 if compressed_size > u32::MAX as u64 || uncompressed_size > u32::MAX as u64 {
612 self.output.write_all(&compressed_size.to_le_bytes())?;
613 self.output.write_all(&uncompressed_size.to_le_bytes())?;
614 } else {
615 self.output
616 .write_all(&(compressed_size as u32).to_le_bytes())?;
617 self.output
618 .write_all(&(uncompressed_size as u32).to_le_bytes())?;
619 }
620
621 self.entries.push(ZipEntry {
623 name: entry.name,
624 local_header_offset: entry.local_header_offset,
625 crc32: crc,
626 compressed_size,
627 uncompressed_size,
628 compression_method: entry.compression_method,
629 #[cfg(feature = "encryption")]
630 encryption_strength: encryption_strength_code,
631 });
632 }
633 Ok(())
634 }
635
636 pub fn finish(mut self) -> Result<W> {
638 self.finish_current_entry()?;
640
641 let central_dir_offset = self.output.stream_position()?;
642
643 for entry in &self.entries {
645 self.output.write_all(&[0x50, 0x4b, 0x01, 0x02])?; self.output.write_all(&[20, 0])?; self.output.write_all(&[20, 0])?; #[cfg(feature = "encryption")]
651 let flags = if entry.encryption_strength.is_some() {
652 0x08 | 0x01 } else {
654 0x08 };
656 #[cfg(not(feature = "encryption"))]
657 let flags = 0x08;
658
659 self.output.write_all(&[flags, 0])?; self.output
661 .write_all(&entry.compression_method.to_le_bytes())?; self.output.write_all(&[0, 0, 0, 0])?; self.output.write_all(&entry.crc32.to_le_bytes())?;
664
665 if entry.compressed_size > u32::MAX as u64 {
667 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
668 } else {
669 self.output
670 .write_all(&(entry.compressed_size as u32).to_le_bytes())?;
671 }
672
673 if entry.uncompressed_size > u32::MAX as u64 {
674 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
675 } else {
676 self.output
677 .write_all(&(entry.uncompressed_size as u32).to_le_bytes())?;
678 }
679
680 self.output
681 .write_all(&(entry.name.len() as u16).to_le_bytes())?;
682
683 let mut extra_field: Vec<u8> = Vec::new();
685
686 #[cfg(feature = "encryption")]
688 if let Some(strength_code) = entry.encryption_strength {
689 extra_field.extend_from_slice(&[0x01, 0x99]); extra_field.extend_from_slice(&[7, 0]); extra_field.extend_from_slice(&[2, 0]); extra_field.extend_from_slice(&[0x41, 0x45]); extra_field.extend_from_slice(&strength_code.to_le_bytes()); extra_field.extend_from_slice(&entry.compression_method.to_le_bytes());
696 }
698
699 if entry.uncompressed_size > u32::MAX as u64
701 || entry.compressed_size > u32::MAX as u64
702 || entry.local_header_offset > u32::MAX as u64
703 {
704 extra_field.extend_from_slice(&0x0001u16.to_le_bytes());
706 let mut data: Vec<u8> = Vec::new();
708 if entry.uncompressed_size > u32::MAX as u64 {
709 data.extend_from_slice(&entry.uncompressed_size.to_le_bytes());
710 }
711 if entry.compressed_size > u32::MAX as u64 {
712 data.extend_from_slice(&entry.compressed_size.to_le_bytes());
713 }
714 if entry.local_header_offset > u32::MAX as u64 {
715 data.extend_from_slice(&entry.local_header_offset.to_le_bytes());
716 }
717 extra_field.extend_from_slice(&(data.len() as u16).to_le_bytes());
718 extra_field.extend_from_slice(&data);
719 }
720
721 self.output
722 .write_all(&(extra_field.len() as u16).to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; if entry.local_header_offset > u32::MAX as u64 {
730 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
731 } else {
732 self.output
733 .write_all(&(entry.local_header_offset as u32).to_le_bytes())?;
734 }
735
736 self.output.write_all(entry.name.as_bytes())?;
737 if !extra_field.is_empty() {
738 self.output.write_all(&extra_field)?;
739 }
740 }
741
742 let central_dir_size = self.output.stream_position()? - central_dir_offset;
743
744 let need_zip64 = self.entries.len() > u16::MAX as usize
746 || central_dir_size > u32::MAX as u64
747 || central_dir_offset > u32::MAX as u64;
748
749 if need_zip64 {
750 self.output.write_all(&[0x50, 0x4b, 0x06, 0x06])?; let zip64_eocd_size: u64 = 44;
756 self.output.write_all(&zip64_eocd_size.to_le_bytes())?;
757 self.output.write_all(&[20, 0])?;
759 self.output.write_all(&[20, 0])?;
760 self.output.write_all(&0u32.to_le_bytes())?;
762 self.output.write_all(&0u32.to_le_bytes())?;
763 self.output
765 .write_all(&(self.entries.len() as u64).to_le_bytes())?;
766 self.output
768 .write_all(&(self.entries.len() as u64).to_le_bytes())?;
769 self.output.write_all(¢ral_dir_size.to_le_bytes())?;
771 self.output.write_all(¢ral_dir_offset.to_le_bytes())?;
773
774 self.output.write_all(&[0x50, 0x4b, 0x06, 0x07])?; self.output.write_all(&0u32.to_le_bytes())?;
779 let zip64_eocd_pos = central_dir_offset + central_dir_size; self.output.write_all(&zip64_eocd_pos.to_le_bytes())?;
782 self.output.write_all(&0u32.to_le_bytes())?;
784 }
785
786 self.output.write_all(&[0x50, 0x4b, 0x05, 0x06])?;
788 self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; if self.entries.len() > u16::MAX as usize {
793 self.output.write_all(&0xFFFFu16.to_le_bytes())?;
794 self.output.write_all(&0xFFFFu16.to_le_bytes())?;
795 } else {
796 self.output
797 .write_all(&(self.entries.len() as u16).to_le_bytes())?;
798 self.output
799 .write_all(&(self.entries.len() as u16).to_le_bytes())?;
800 }
801
802 if central_dir_size > u32::MAX as u64 {
804 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
805 } else {
806 self.output
807 .write_all(&(central_dir_size as u32).to_le_bytes())?;
808 }
809
810 if central_dir_offset > u32::MAX as u64 {
811 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
812 } else {
813 self.output
814 .write_all(&(central_dir_offset as u32).to_le_bytes())?;
815 }
816
817 self.output.write_all(&0u16.to_le_bytes())?; self.output.flush()?;
820 Ok(self.output)
821 }
822}