1use crate::error::{Result, SZipError};
13use crc32fast::Hasher as Crc32;
14use flate2::write::DeflateEncoder;
15use flate2::Compression;
16use std::fs::File;
17use std::io::{Seek, Write};
18use std::path::Path;
19
20#[cfg(feature = "encryption")]
21use crate::encryption::{AesEncryptor, AesStrength};
22
/// Compression algorithm selectable for the entries of an archive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionMethod {
    /// Entries are kept uncompressed (ZIP method id 0).
    Stored,
    /// Entries are compressed with DEFLATE (ZIP method id 8).
    Deflate,
    /// Entries are compressed with Zstandard (ZIP method id 93).
    ///
    /// Only available when the `zstd-support` feature is enabled.
    #[cfg(feature = "zstd-support")]
    Zstd,
}

impl CompressionMethod {
    /// Returns the numeric method id that is written into the ZIP headers
    /// for this compression method.
    pub(crate) fn to_zip_method(self) -> u16 {
        match self {
            Self::Stored => 0,
            Self::Deflate => 8,
            #[cfg(feature = "zstd-support")]
            Self::Zstd => 93,
        }
    }
}
45
/// Bookkeeping for one completed entry, kept until the central directory
/// is written.
struct ZipEntry {
    /// Entry name exactly as it was written to the local header.
    name: String,
    /// Stream position at which the entry's local header starts.
    local_header_offset: u64,
    /// CRC-32 of the uncompressed data.
    crc32: u32,
    /// Number of payload bytes written for the entry.
    compressed_size: u64,
    /// Number of bytes supplied by the caller for the entry.
    uncompressed_size: u64,
    /// Numeric ZIP compression method id.
    compression_method: u16,
    /// WinZip strength code when the entry was encrypted, `None` otherwise.
    /// Recorded but not read anywhere in this file, hence the lint allowance.
    #[cfg(feature = "encryption")]
    #[allow(dead_code)]
    encryption_strength: Option<u16>,
}
58
59pub struct StreamingZipWriter<W: Write + Seek> {
61 output: W,
62 entries: Vec<ZipEntry>,
63 current_entry: Option<CurrentEntry>,
64 compression_level: u32,
65 compression_method: CompressionMethod,
66 #[cfg(feature = "encryption")]
67 password: Option<String>,
68 #[cfg(feature = "encryption")]
69 encryption_strength: AesStrength,
70}
71
72struct CurrentEntry {
73 name: String,
74 local_header_offset: u64,
75 encoder: Box<dyn CompressorWrite>,
76 counter: CrcCounter,
77 compression_method: u16,
78 #[cfg(feature = "encryption")]
79 encryptor: Option<AesEncryptor>,
80}
81
82trait CompressorWrite: Write {
83 fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer>;
84 fn get_buffer_mut(&mut self) -> &mut CompressedBuffer;
85}
86
87struct DeflateCompressor {
88 encoder: DeflateEncoder<CompressedBuffer>,
89}
90
91impl Write for DeflateCompressor {
92 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
93 self.encoder.write(buf)
94 }
95
96 fn flush(&mut self) -> std::io::Result<()> {
97 self.encoder.flush()
98 }
99}
100
101impl CompressorWrite for DeflateCompressor {
102 fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer> {
103 Ok(self.encoder.finish()?)
104 }
105
106 fn get_buffer_mut(&mut self) -> &mut CompressedBuffer {
107 self.encoder.get_mut()
108 }
109}
110
/// Zstandard compressor whose output is collected in a `CompressedBuffer`.
#[cfg(feature = "zstd-support")]
struct ZstdCompressor {
    encoder: zstd::Encoder<'static, CompressedBuffer>,
}

#[cfg(feature = "zstd-support")]
impl Write for ZstdCompressor {
    /// Forwards the bytes to the wrapped Zstandard encoder.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let encoder = &mut self.encoder;
        encoder.write(buf)
    }

    /// Flushes the wrapped Zstandard encoder.
    fn flush(&mut self) -> std::io::Result<()> {
        let encoder = &mut self.encoder;
        encoder.flush()
    }
}

#[cfg(feature = "zstd-support")]
impl CompressorWrite for ZstdCompressor {
    /// Finishes the Zstandard stream and hands back the underlying buffer.
    fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer> {
        let this = *self;
        let buffer = this.encoder.finish()?;
        Ok(buffer)
    }

    /// Returns the buffer the encoder writes its output into.
    fn get_buffer_mut(&mut self) -> &mut CompressedBuffer {
        let encoder = &mut self.encoder;
        encoder.get_mut()
    }
}
137
138struct CrcCounter {
140 crc: Crc32,
141 uncompressed_count: u64,
142 compressed_count: u64,
143}
144
145impl CrcCounter {
146 fn new() -> Self {
147 Self {
148 crc: Crc32::new(),
149 uncompressed_count: 0,
150 compressed_count: 0,
151 }
152 }
153
154 fn update_uncompressed(&mut self, data: &[u8]) {
155 self.crc.update(data);
156 self.uncompressed_count += data.len() as u64;
157 }
158
159 fn add_compressed(&mut self, count: u64) {
160 self.compressed_count += count;
161 }
162
163 fn finalize(&self) -> u32 {
164 self.crc.clone().finalize()
165 }
166}
167
/// In-memory sink for compressed bytes that reports when it has grown
/// large enough to be drained.
struct CompressedBuffer {
    /// Compressed bytes accumulated since the last `take`.
    buffer: Vec<u8>,
    /// Length at which `should_flush` starts returning `true`.
    flush_threshold: usize,
}

impl CompressedBuffer {
    /// Creates a buffer with the sizing used when no size hint is known.
    #[allow(dead_code)]
    fn new() -> Self {
        Self::with_size_hint(None)
    }

    /// Creates a buffer whose initial capacity and flush threshold are
    /// picked from the expected entry size in bytes. Larger hints, and a
    /// missing hint, select larger values.
    fn with_size_hint(size_hint: Option<u64>) -> Self {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;

        let (initial_capacity, flush_threshold) = match size_hint {
            Some(0..=9_999) => (8 * KIB, 256 * KIB),
            Some(10_000..=99_999) => (32 * KIB, 512 * KIB),
            Some(100_000..=999_999) => (128 * KIB, 2 * MIB),
            Some(1_000_000..=9_999_999) => (256 * KIB, 4 * MIB),
            _ => (512 * KIB, 8 * MIB),
        };

        let buffer = Vec::with_capacity(initial_capacity);
        Self {
            buffer,
            flush_threshold,
        }
    }

    /// Removes and returns the accumulated bytes, leaving an empty vector
    /// behind.
    fn take(&mut self) -> Vec<u8> {
        std::mem::replace(&mut self.buffer, Vec::new())
    }

    /// Reports whether the accumulated bytes have reached the threshold.
    fn should_flush(&self) -> bool {
        self.flush_threshold <= self.buffer.len()
    }
}

impl Write for CompressedBuffer {
    /// Appends all of `buf` to the buffer; never fails.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let written = buf.len();
        self.buffer.extend_from_slice(buf);
        Ok(written)
    }

    /// Nothing to flush: the bytes already live in memory.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
225
226impl StreamingZipWriter<File> {
227 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
229 Self::with_compression(path, 6)
230 }
231
232 pub fn with_compression<P: AsRef<Path>>(path: P, compression_level: u32) -> Result<Self> {
234 Self::with_method(path, CompressionMethod::Deflate, compression_level)
235 }
236
237 pub fn with_method<P: AsRef<Path>>(
244 path: P,
245 method: CompressionMethod,
246 compression_level: u32,
247 ) -> Result<Self> {
248 let output = File::create(path)?;
249 Ok(Self {
250 output,
251 entries: Vec::new(),
252 current_entry: None,
253 compression_level,
254 compression_method: method,
255 #[cfg(feature = "encryption")]
256 password: None,
257 #[cfg(feature = "encryption")]
258 encryption_strength: AesStrength::Aes256,
259 })
260 }
261
262 #[cfg(feature = "zstd-support")]
264 pub fn with_zstd<P: AsRef<Path>>(path: P, compression_level: i32) -> Result<Self> {
265 let output = File::create(path)?;
266 Ok(Self {
267 output,
268 entries: Vec::new(),
269 current_entry: None,
270 compression_level: compression_level as u32,
271 compression_method: CompressionMethod::Zstd,
272 #[cfg(feature = "encryption")]
273 password: None,
274 #[cfg(feature = "encryption")]
275 encryption_strength: AesStrength::Aes256,
276 })
277 }
278}
279
280impl<W: Write + Seek> StreamingZipWriter<W> {
281 pub fn from_writer(writer: W) -> Result<Self> {
283 Self::from_writer_with_compression(writer, 6)
284 }
285
286 pub fn from_writer_with_compression(writer: W, compression_level: u32) -> Result<Self> {
288 Self::from_writer_with_method(writer, CompressionMethod::Deflate, compression_level)
289 }
290
291 pub fn from_writer_with_method(
298 writer: W,
299 method: CompressionMethod,
300 compression_level: u32,
301 ) -> Result<Self> {
302 Ok(Self {
303 output: writer,
304 entries: Vec::new(),
305 current_entry: None,
306 compression_level,
307 compression_method: method,
308 #[cfg(feature = "encryption")]
309 password: None,
310 #[cfg(feature = "encryption")]
311 encryption_strength: AesStrength::Aes256,
312 })
313 }
314
315 #[cfg(feature = "encryption")]
336 pub fn set_password(&mut self, password: impl Into<String>) -> &mut Self {
337 self.password = Some(password.into());
338 self
339 }
340
341 #[cfg(feature = "encryption")]
346 pub fn set_encryption_strength(&mut self, strength: AesStrength) -> &mut Self {
347 self.encryption_strength = strength;
348 self
349 }
350
351 #[cfg(feature = "encryption")]
353 pub fn clear_password(&mut self) -> &mut Self {
354 self.password = None;
355 self
356 }
357
358 pub fn start_entry(&mut self, name: &str) -> Result<()> {
360 self.start_entry_with_hint(name, None)
361 }
362
363 pub fn start_entry_with_hint(&mut self, name: &str, size_hint: Option<u64>) -> Result<()> {
384 self.finish_current_entry()?;
386
387 let local_header_offset = self.output.stream_position()?;
388 let compression_method = self.compression_method.to_zip_method();
389
390 #[cfg(feature = "encryption")]
392 let (encryptor, encryption_flag) = if let Some(ref password) = self.password {
393 let enc = AesEncryptor::new(password, self.encryption_strength)?;
394 (Some(enc), 0x01) } else {
396 (None, 0x00)
397 };
398
399 #[cfg(not(feature = "encryption"))]
400 let encryption_flag = 0x00;
401
402 self.output.write_all(&[0x50, 0x4b, 0x03, 0x04])?; self.output.write_all(&[51, 0])?; self.output.write_all(&[8 | encryption_flag, 0])?; self.output.write_all(&compression_method.to_le_bytes())?; self.output.write_all(&[0, 0, 0, 0])?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&(name.len() as u16).to_le_bytes())?;
412
413 #[cfg(feature = "encryption")]
415 let extra_len = if encryptor.is_some() { 11 } else { 0 };
416 #[cfg(not(feature = "encryption"))]
417 let extra_len = 0;
418
419 self.output.write_all(&(extra_len as u16).to_le_bytes())?; self.output.write_all(name.as_bytes())?;
421
422 #[cfg(feature = "encryption")]
424 if let Some(ref enc) = encryptor {
425 self.output.write_all(&[0x01, 0x99])?; self.output.write_all(&[7, 0])?; self.output.write_all(&[2, 0])?; self.output.write_all(&[0x41, 0x45])?; self.output
431 .write_all(&enc.strength().to_winzip_code().to_le_bytes())?; self.output.write_all(&compression_method.to_le_bytes())?; self.output.write_all(enc.salt())?;
436 self.output.write_all(enc.password_verify())?;
437 }
438
439 let encoder: Box<dyn CompressorWrite> = match self.compression_method {
442 CompressionMethod::Deflate => Box::new(DeflateCompressor {
443 encoder: DeflateEncoder::new(
444 CompressedBuffer::with_size_hint(size_hint),
445 Compression::new(self.compression_level),
446 ),
447 }),
448 #[cfg(feature = "zstd-support")]
449 CompressionMethod::Zstd => {
450 let mut encoder = zstd::Encoder::new(
451 CompressedBuffer::with_size_hint(size_hint),
452 self.compression_level as i32,
453 )?;
454 encoder.include_checksum(false)?; Box::new(ZstdCompressor { encoder })
456 }
457 CompressionMethod::Stored => {
458 return Err(SZipError::InvalidFormat(
460 "Stored method not yet implemented".to_string(),
461 ));
462 }
463 };
464
465 self.current_entry = Some(CurrentEntry {
466 name: name.to_string(),
467 local_header_offset,
468 encoder,
469 counter: CrcCounter::new(),
470 compression_method,
471 #[cfg(feature = "encryption")]
472 encryptor,
473 });
474
475 Ok(())
476 }
477
478 pub fn write_data(&mut self, data: &[u8]) -> Result<()> {
480 let entry = self
481 .current_entry
482 .as_mut()
483 .ok_or_else(|| SZipError::InvalidFormat("No entry started".to_string()))?;
484
485 entry.counter.update_uncompressed(data);
487
488 #[cfg(feature = "encryption")]
491 let data_to_compress = if let Some(ref mut encryptor) = entry.encryptor {
492 let mut encrypted = data.to_vec();
493 encryptor.encrypt(&mut encrypted)?;
494 encrypted
495 } else {
496 data.to_vec()
497 };
498
499 #[cfg(not(feature = "encryption"))]
500 let data_to_compress = data.to_vec();
501
502 entry.encoder.write_all(&data_to_compress)?;
504
505 entry.encoder.flush()?;
507
508 let buffer = entry.encoder.get_buffer_mut();
510 if buffer.should_flush() {
511 let compressed_data = buffer.take();
513 self.output.write_all(&compressed_data)?;
514 entry.counter.add_compressed(compressed_data.len() as u64);
515 }
516
517 Ok(())
518 }
519
520 fn finish_current_entry(&mut self) -> Result<()> {
522 if let Some(mut entry) = self.current_entry.take() {
523 let mut buffer = entry.encoder.finish_compression()?;
525
526 let remaining_data = buffer.take();
528 if !remaining_data.is_empty() {
529 self.output.write_all(&remaining_data)?;
530 entry.counter.add_compressed(remaining_data.len() as u64);
531 }
532
533 #[cfg(feature = "encryption")]
535 let (encryption_strength_code, auth_code_size) =
536 if let Some(encryptor) = entry.encryptor {
537 let strength_code = encryptor.strength().to_winzip_code();
538 let auth_code = encryptor.finalize();
539 self.output.write_all(&auth_code)?;
540 (Some(strength_code), auth_code.len() as u64)
541 } else {
542 (None, 0)
543 };
544
545 #[cfg(not(feature = "encryption"))]
546 let auth_code_size = 0u64;
547
548 let crc = entry.counter.finalize();
549 let compressed_size = entry.counter.compressed_count + auth_code_size;
550 let uncompressed_size = entry.counter.uncompressed_count;
551
552 self.output.write_all(&[0x50, 0x4b, 0x07, 0x08])?;
555 self.output.write_all(&crc.to_le_bytes())?;
556 if compressed_size > u32::MAX as u64 || uncompressed_size > u32::MAX as u64 {
558 self.output.write_all(&compressed_size.to_le_bytes())?;
559 self.output.write_all(&uncompressed_size.to_le_bytes())?;
560 } else {
561 self.output
562 .write_all(&(compressed_size as u32).to_le_bytes())?;
563 self.output
564 .write_all(&(uncompressed_size as u32).to_le_bytes())?;
565 }
566
567 self.entries.push(ZipEntry {
569 name: entry.name,
570 local_header_offset: entry.local_header_offset,
571 crc32: crc,
572 compressed_size,
573 uncompressed_size,
574 compression_method: entry.compression_method,
575 #[cfg(feature = "encryption")]
576 encryption_strength: encryption_strength_code,
577 });
578 }
579 Ok(())
580 }
581
582 pub fn finish(mut self) -> Result<W> {
584 self.finish_current_entry()?;
586
587 let central_dir_offset = self.output.stream_position()?;
588
589 for entry in &self.entries {
591 self.output.write_all(&[0x50, 0x4b, 0x01, 0x02])?; self.output.write_all(&[20, 0])?; self.output.write_all(&[20, 0])?; self.output.write_all(&[8, 0])?; self.output
596 .write_all(&entry.compression_method.to_le_bytes())?; self.output.write_all(&[0, 0, 0, 0])?; self.output.write_all(&entry.crc32.to_le_bytes())?;
599
600 if entry.compressed_size > u32::MAX as u64 {
602 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
603 } else {
604 self.output
605 .write_all(&(entry.compressed_size as u32).to_le_bytes())?;
606 }
607
608 if entry.uncompressed_size > u32::MAX as u64 {
609 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
610 } else {
611 self.output
612 .write_all(&(entry.uncompressed_size as u32).to_le_bytes())?;
613 }
614
615 self.output
616 .write_all(&(entry.name.len() as u16).to_le_bytes())?;
617
618 let mut extra_field: Vec<u8> = Vec::new();
620 if entry.uncompressed_size > u32::MAX as u64
621 || entry.compressed_size > u32::MAX as u64
622 || entry.local_header_offset > u32::MAX as u64
623 {
624 extra_field.extend_from_slice(&0x0001u16.to_le_bytes());
626 let mut data: Vec<u8> = Vec::new();
628 if entry.uncompressed_size > u32::MAX as u64 {
629 data.extend_from_slice(&entry.uncompressed_size.to_le_bytes());
630 }
631 if entry.compressed_size > u32::MAX as u64 {
632 data.extend_from_slice(&entry.compressed_size.to_le_bytes());
633 }
634 if entry.local_header_offset > u32::MAX as u64 {
635 data.extend_from_slice(&entry.local_header_offset.to_le_bytes());
636 }
637 extra_field.extend_from_slice(&(data.len() as u16).to_le_bytes());
638 extra_field.extend_from_slice(&data);
639 }
640
641 self.output
642 .write_all(&(extra_field.len() as u16).to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; if entry.local_header_offset > u32::MAX as u64 {
650 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
651 } else {
652 self.output
653 .write_all(&(entry.local_header_offset as u32).to_le_bytes())?;
654 }
655
656 self.output.write_all(entry.name.as_bytes())?;
657 if !extra_field.is_empty() {
658 self.output.write_all(&extra_field)?;
659 }
660 }
661
662 let central_dir_size = self.output.stream_position()? - central_dir_offset;
663
664 let need_zip64 = self.entries.len() > u16::MAX as usize
666 || central_dir_size > u32::MAX as u64
667 || central_dir_offset > u32::MAX as u64;
668
669 if need_zip64 {
670 self.output.write_all(&[0x50, 0x4b, 0x06, 0x06])?; let zip64_eocd_size: u64 = 44;
676 self.output.write_all(&zip64_eocd_size.to_le_bytes())?;
677 self.output.write_all(&[20, 0])?;
679 self.output.write_all(&[20, 0])?;
680 self.output.write_all(&0u32.to_le_bytes())?;
682 self.output.write_all(&0u32.to_le_bytes())?;
683 self.output
685 .write_all(&(self.entries.len() as u64).to_le_bytes())?;
686 self.output
688 .write_all(&(self.entries.len() as u64).to_le_bytes())?;
689 self.output.write_all(¢ral_dir_size.to_le_bytes())?;
691 self.output.write_all(¢ral_dir_offset.to_le_bytes())?;
693
694 self.output.write_all(&[0x50, 0x4b, 0x06, 0x07])?; self.output.write_all(&0u32.to_le_bytes())?;
699 let zip64_eocd_pos = central_dir_offset + central_dir_size; self.output.write_all(&zip64_eocd_pos.to_le_bytes())?;
702 self.output.write_all(&0u32.to_le_bytes())?;
704 }
705
706 self.output.write_all(&[0x50, 0x4b, 0x05, 0x06])?;
708 self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; if self.entries.len() > u16::MAX as usize {
713 self.output.write_all(&0xFFFFu16.to_le_bytes())?;
714 self.output.write_all(&0xFFFFu16.to_le_bytes())?;
715 } else {
716 self.output
717 .write_all(&(self.entries.len() as u16).to_le_bytes())?;
718 self.output
719 .write_all(&(self.entries.len() as u16).to_le_bytes())?;
720 }
721
722 if central_dir_size > u32::MAX as u64 {
724 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
725 } else {
726 self.output
727 .write_all(&(central_dir_size as u32).to_le_bytes())?;
728 }
729
730 if central_dir_offset > u32::MAX as u64 {
731 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
732 } else {
733 self.output
734 .write_all(&(central_dir_offset as u32).to_le_bytes())?;
735 }
736
737 self.output.write_all(&0u16.to_le_bytes())?; self.output.flush()?;
740 Ok(self.output)
741 }
742}