1use crate::error::{Result, SZipError};
13use crc32fast::Hasher as Crc32;
14use flate2::write::DeflateEncoder;
15use flate2::Compression;
16use std::fs::File;
17use std::io::{Seek, Write};
18use std::path::Path;
19
20#[cfg(feature = "encryption")]
21use crate::encryption::{AesEncryptor, AesStrength};
22
/// Compression algorithm applied to the data of every entry written by a
/// `StreamingZipWriter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionMethod {
    /// No compression (ZIP method id 0). `start_entry` currently rejects this
    /// method with an "not yet implemented" error.
    Stored,
    /// DEFLATE (ZIP method id 8).
    Deflate,
    /// Zstandard (ZIP method id 93); only present with the `zstd-support`
    /// feature.
    #[cfg(feature = "zstd-support")]
    Zstd,
}
34
35impl CompressionMethod {
36 pub(crate) fn to_zip_method(self) -> u16 {
37 match self {
38 CompressionMethod::Stored => 0,
39 CompressionMethod::Deflate => 8,
40 #[cfg(feature = "zstd-support")]
41 CompressionMethod::Zstd => 93,
42 }
43 }
44}
45
/// Metadata remembered for each finished entry so that its central
/// directory record can be written when the archive is finished.
struct ZipEntry {
    /// Entry name exactly as written after the local file header.
    name: String,
    /// Stream position at which the entry's local file header starts.
    local_header_offset: u64,
    /// CRC-32 of the uncompressed data passed to `write_data`.
    crc32: u32,
    /// Number of bytes recorded as the entry's compressed size.
    compressed_size: u64,
    /// Number of uncompressed bytes passed to `write_data`.
    uncompressed_size: u64,
    /// Numeric ZIP compression-method id of the entry.
    compression_method: u16,
    /// WinZip AES strength code when the entry was encrypted, `None` otherwise.
    // Stored for every entry but not read by any code visible in this file,
    // hence the lint exemption.
    #[cfg(feature = "encryption")]
    #[allow(dead_code)]
    encryption_strength: Option<u16>,
}
58
59pub struct StreamingZipWriter<W: Write + Seek> {
61 output: W,
62 entries: Vec<ZipEntry>,
63 current_entry: Option<CurrentEntry>,
64 compression_level: u32,
65 compression_method: CompressionMethod,
66 #[cfg(feature = "encryption")]
67 password: Option<String>,
68 #[cfg(feature = "encryption")]
69 encryption_strength: AesStrength,
70}
71
72struct CurrentEntry {
73 name: String,
74 local_header_offset: u64,
75 encoder: Box<dyn CompressorWrite>,
76 counter: CrcCounter,
77 compression_method: u16,
78 #[cfg(feature = "encryption")]
79 encryptor: Option<AesEncryptor>,
80}
81
82trait CompressorWrite: Write {
83 fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer>;
84 fn get_buffer_mut(&mut self) -> &mut CompressedBuffer;
85}
86
87struct DeflateCompressor {
88 encoder: DeflateEncoder<CompressedBuffer>,
89}
90
91impl Write for DeflateCompressor {
92 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
93 self.encoder.write(buf)
94 }
95
96 fn flush(&mut self) -> std::io::Result<()> {
97 self.encoder.flush()
98 }
99}
100
101impl CompressorWrite for DeflateCompressor {
102 fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer> {
103 Ok(self.encoder.finish()?)
104 }
105
106 fn get_buffer_mut(&mut self) -> &mut CompressedBuffer {
107 self.encoder.get_mut()
108 }
109}
110
/// `CompressorWrite` implementation backed by `zstd::Encoder`.
#[cfg(feature = "zstd-support")]
struct ZstdCompressor {
    /// Encoder whose inner writer is the buffer collecting compressed bytes.
    encoder: zstd::Encoder<'static, CompressedBuffer>,
}
115
/// Forwards writes and flushes straight to the wrapped Zstandard encoder.
#[cfg(feature = "zstd-support")]
impl Write for ZstdCompressor {
    /// Feeds `buf` to the encoder and returns how many bytes it accepted.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.encoder.write(buf)
    }

    /// Flushes the encoder.
    fn flush(&mut self) -> std::io::Result<()> {
        self.encoder.flush()
    }
}
126
#[cfg(feature = "zstd-support")]
impl CompressorWrite for ZstdCompressor {
    /// Finishes the Zstandard stream and returns the buffer the encoder was
    /// writing into.
    fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer> {
        Ok(self.encoder.finish()?)
    }

    /// Returns the buffer the encoder writes its output into.
    fn get_buffer_mut(&mut self) -> &mut CompressedBuffer {
        self.encoder.get_mut()
    }
}
137
138struct CrcCounter {
140 crc: Crc32,
141 uncompressed_count: u64,
142 compressed_count: u64,
143}
144
145impl CrcCounter {
146 fn new() -> Self {
147 Self {
148 crc: Crc32::new(),
149 uncompressed_count: 0,
150 compressed_count: 0,
151 }
152 }
153
154 fn update_uncompressed(&mut self, data: &[u8]) {
155 self.crc.update(data);
156 self.uncompressed_count += data.len() as u64;
157 }
158
159 fn add_compressed(&mut self, count: u64) {
160 self.compressed_count += count;
161 }
162
163 fn finalize(&self) -> u32 {
164 self.crc.clone().finalize()
165 }
166}
167
/// In-memory sink that collects compressor output until it is drained.
struct CompressedBuffer {
    /// Compressed bytes that have not been drained yet.
    buffer: Vec<u8>,
    /// Buffered length (in bytes) from which `should_flush` reports `true`.
    flush_threshold: usize,
}
173
174impl CompressedBuffer {
175 fn new() -> Self {
176 Self {
177 buffer: Vec::with_capacity(64 * 1024), flush_threshold: 1024 * 1024, }
180 }
181
182 fn take(&mut self) -> Vec<u8> {
183 std::mem::take(&mut self.buffer)
184 }
185
186 fn should_flush(&self) -> bool {
187 self.buffer.len() >= self.flush_threshold
188 }
189}
190
191impl Write for CompressedBuffer {
192 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
193 self.buffer.extend_from_slice(buf);
194 Ok(buf.len())
195 }
196
197 fn flush(&mut self) -> std::io::Result<()> {
198 Ok(())
199 }
200}
201
202impl StreamingZipWriter<File> {
203 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
205 Self::with_compression(path, 6)
206 }
207
208 pub fn with_compression<P: AsRef<Path>>(path: P, compression_level: u32) -> Result<Self> {
210 Self::with_method(path, CompressionMethod::Deflate, compression_level)
211 }
212
213 pub fn with_method<P: AsRef<Path>>(
220 path: P,
221 method: CompressionMethod,
222 compression_level: u32,
223 ) -> Result<Self> {
224 let output = File::create(path)?;
225 Ok(Self {
226 output,
227 entries: Vec::new(),
228 current_entry: None,
229 compression_level,
230 compression_method: method,
231 #[cfg(feature = "encryption")]
232 password: None,
233 #[cfg(feature = "encryption")]
234 encryption_strength: AesStrength::Aes256,
235 })
236 }
237
238 #[cfg(feature = "zstd-support")]
240 pub fn with_zstd<P: AsRef<Path>>(path: P, compression_level: i32) -> Result<Self> {
241 let output = File::create(path)?;
242 Ok(Self {
243 output,
244 entries: Vec::new(),
245 current_entry: None,
246 compression_level: compression_level as u32,
247 compression_method: CompressionMethod::Zstd,
248 #[cfg(feature = "encryption")]
249 password: None,
250 #[cfg(feature = "encryption")]
251 encryption_strength: AesStrength::Aes256,
252 })
253 }
254}
255
256impl<W: Write + Seek> StreamingZipWriter<W> {
257 pub fn from_writer(writer: W) -> Result<Self> {
259 Self::from_writer_with_compression(writer, 6)
260 }
261
262 pub fn from_writer_with_compression(writer: W, compression_level: u32) -> Result<Self> {
264 Self::from_writer_with_method(writer, CompressionMethod::Deflate, compression_level)
265 }
266
267 pub fn from_writer_with_method(
274 writer: W,
275 method: CompressionMethod,
276 compression_level: u32,
277 ) -> Result<Self> {
278 Ok(Self {
279 output: writer,
280 entries: Vec::new(),
281 current_entry: None,
282 compression_level,
283 compression_method: method,
284 #[cfg(feature = "encryption")]
285 password: None,
286 #[cfg(feature = "encryption")]
287 encryption_strength: AesStrength::Aes256,
288 })
289 }
290
291 #[cfg(feature = "encryption")]
312 pub fn set_password(&mut self, password: impl Into<String>) -> &mut Self {
313 self.password = Some(password.into());
314 self
315 }
316
317 #[cfg(feature = "encryption")]
322 pub fn set_encryption_strength(&mut self, strength: AesStrength) -> &mut Self {
323 self.encryption_strength = strength;
324 self
325 }
326
327 #[cfg(feature = "encryption")]
329 pub fn clear_password(&mut self) -> &mut Self {
330 self.password = None;
331 self
332 }
333
334 pub fn start_entry(&mut self, name: &str) -> Result<()> {
336 self.finish_current_entry()?;
338
339 let local_header_offset = self.output.stream_position()?;
340 let compression_method = self.compression_method.to_zip_method();
341
342 #[cfg(feature = "encryption")]
344 let (encryptor, encryption_flag) = if let Some(ref password) = self.password {
345 let enc = AesEncryptor::new(password, self.encryption_strength)?;
346 (Some(enc), 0x01) } else {
348 (None, 0x00)
349 };
350
351 #[cfg(not(feature = "encryption"))]
352 let encryption_flag = 0x00;
353
354 self.output.write_all(&[0x50, 0x4b, 0x03, 0x04])?; self.output.write_all(&[51, 0])?; self.output.write_all(&[8 | encryption_flag, 0])?; self.output.write_all(&compression_method.to_le_bytes())?; self.output.write_all(&[0, 0, 0, 0])?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&(name.len() as u16).to_le_bytes())?;
364
365 #[cfg(feature = "encryption")]
367 let extra_len = if encryptor.is_some() { 11 } else { 0 };
368 #[cfg(not(feature = "encryption"))]
369 let extra_len = 0;
370
371 self.output.write_all(&(extra_len as u16).to_le_bytes())?; self.output.write_all(name.as_bytes())?;
373
374 #[cfg(feature = "encryption")]
376 if let Some(ref enc) = encryptor {
377 self.output.write_all(&[0x01, 0x99])?; self.output.write_all(&[7, 0])?; self.output.write_all(&[2, 0])?; self.output.write_all(&[0x41, 0x45])?; self.output
383 .write_all(&enc.strength().to_winzip_code().to_le_bytes())?; self.output.write_all(&compression_method.to_le_bytes())?; self.output.write_all(enc.salt())?;
388 self.output.write_all(enc.password_verify())?;
389 }
390
391 let encoder: Box<dyn CompressorWrite> = match self.compression_method {
393 CompressionMethod::Deflate => Box::new(DeflateCompressor {
394 encoder: DeflateEncoder::new(
395 CompressedBuffer::new(),
396 Compression::new(self.compression_level),
397 ),
398 }),
399 #[cfg(feature = "zstd-support")]
400 CompressionMethod::Zstd => {
401 let mut encoder =
402 zstd::Encoder::new(CompressedBuffer::new(), self.compression_level as i32)?;
403 encoder.include_checksum(false)?; Box::new(ZstdCompressor { encoder })
405 }
406 CompressionMethod::Stored => {
407 return Err(SZipError::InvalidFormat(
409 "Stored method not yet implemented".to_string(),
410 ));
411 }
412 };
413
414 self.current_entry = Some(CurrentEntry {
415 name: name.to_string(),
416 local_header_offset,
417 encoder,
418 counter: CrcCounter::new(),
419 compression_method,
420 #[cfg(feature = "encryption")]
421 encryptor,
422 });
423
424 Ok(())
425 }
426
427 pub fn write_data(&mut self, data: &[u8]) -> Result<()> {
429 let entry = self
430 .current_entry
431 .as_mut()
432 .ok_or_else(|| SZipError::InvalidFormat("No entry started".to_string()))?;
433
434 entry.counter.update_uncompressed(data);
436
437 #[cfg(feature = "encryption")]
440 let data_to_compress = if let Some(ref mut encryptor) = entry.encryptor {
441 let mut encrypted = data.to_vec();
442 encryptor.encrypt(&mut encrypted)?;
443 encrypted
444 } else {
445 data.to_vec()
446 };
447
448 #[cfg(not(feature = "encryption"))]
449 let data_to_compress = data.to_vec();
450
451 entry.encoder.write_all(&data_to_compress)?;
453
454 entry.encoder.flush()?;
456
457 let buffer = entry.encoder.get_buffer_mut();
459 if buffer.should_flush() {
460 let compressed_data = buffer.take();
462 self.output.write_all(&compressed_data)?;
463 entry.counter.add_compressed(compressed_data.len() as u64);
464 }
465
466 Ok(())
467 }
468
469 fn finish_current_entry(&mut self) -> Result<()> {
471 if let Some(mut entry) = self.current_entry.take() {
472 let mut buffer = entry.encoder.finish_compression()?;
474
475 let remaining_data = buffer.take();
477 if !remaining_data.is_empty() {
478 self.output.write_all(&remaining_data)?;
479 entry.counter.add_compressed(remaining_data.len() as u64);
480 }
481
482 #[cfg(feature = "encryption")]
484 let (encryption_strength_code, auth_code_size) =
485 if let Some(encryptor) = entry.encryptor {
486 let strength_code = encryptor.strength().to_winzip_code();
487 let auth_code = encryptor.finalize();
488 self.output.write_all(&auth_code)?;
489 (Some(strength_code), auth_code.len() as u64)
490 } else {
491 (None, 0)
492 };
493
494 #[cfg(not(feature = "encryption"))]
495 let auth_code_size = 0u64;
496
497 let crc = entry.counter.finalize();
498 let compressed_size = entry.counter.compressed_count + auth_code_size;
499 let uncompressed_size = entry.counter.uncompressed_count;
500
501 self.output.write_all(&[0x50, 0x4b, 0x07, 0x08])?;
504 self.output.write_all(&crc.to_le_bytes())?;
505 if compressed_size > u32::MAX as u64 || uncompressed_size > u32::MAX as u64 {
507 self.output.write_all(&compressed_size.to_le_bytes())?;
508 self.output.write_all(&uncompressed_size.to_le_bytes())?;
509 } else {
510 self.output
511 .write_all(&(compressed_size as u32).to_le_bytes())?;
512 self.output
513 .write_all(&(uncompressed_size as u32).to_le_bytes())?;
514 }
515
516 self.entries.push(ZipEntry {
518 name: entry.name,
519 local_header_offset: entry.local_header_offset,
520 crc32: crc,
521 compressed_size,
522 uncompressed_size,
523 compression_method: entry.compression_method,
524 #[cfg(feature = "encryption")]
525 encryption_strength: encryption_strength_code,
526 });
527 }
528 Ok(())
529 }
530
531 pub fn finish(mut self) -> Result<W> {
533 self.finish_current_entry()?;
535
536 let central_dir_offset = self.output.stream_position()?;
537
538 for entry in &self.entries {
540 self.output.write_all(&[0x50, 0x4b, 0x01, 0x02])?; self.output.write_all(&[20, 0])?; self.output.write_all(&[20, 0])?; self.output.write_all(&[8, 0])?; self.output
545 .write_all(&entry.compression_method.to_le_bytes())?; self.output.write_all(&[0, 0, 0, 0])?; self.output.write_all(&entry.crc32.to_le_bytes())?;
548
549 if entry.compressed_size > u32::MAX as u64 {
551 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
552 } else {
553 self.output
554 .write_all(&(entry.compressed_size as u32).to_le_bytes())?;
555 }
556
557 if entry.uncompressed_size > u32::MAX as u64 {
558 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
559 } else {
560 self.output
561 .write_all(&(entry.uncompressed_size as u32).to_le_bytes())?;
562 }
563
564 self.output
565 .write_all(&(entry.name.len() as u16).to_le_bytes())?;
566
567 let mut extra_field: Vec<u8> = Vec::new();
569 if entry.uncompressed_size > u32::MAX as u64
570 || entry.compressed_size > u32::MAX as u64
571 || entry.local_header_offset > u32::MAX as u64
572 {
573 extra_field.extend_from_slice(&0x0001u16.to_le_bytes());
575 let mut data: Vec<u8> = Vec::new();
577 if entry.uncompressed_size > u32::MAX as u64 {
578 data.extend_from_slice(&entry.uncompressed_size.to_le_bytes());
579 }
580 if entry.compressed_size > u32::MAX as u64 {
581 data.extend_from_slice(&entry.compressed_size.to_le_bytes());
582 }
583 if entry.local_header_offset > u32::MAX as u64 {
584 data.extend_from_slice(&entry.local_header_offset.to_le_bytes());
585 }
586 extra_field.extend_from_slice(&(data.len() as u16).to_le_bytes());
587 extra_field.extend_from_slice(&data);
588 }
589
590 self.output
591 .write_all(&(extra_field.len() as u16).to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; if entry.local_header_offset > u32::MAX as u64 {
599 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
600 } else {
601 self.output
602 .write_all(&(entry.local_header_offset as u32).to_le_bytes())?;
603 }
604
605 self.output.write_all(entry.name.as_bytes())?;
606 if !extra_field.is_empty() {
607 self.output.write_all(&extra_field)?;
608 }
609 }
610
611 let central_dir_size = self.output.stream_position()? - central_dir_offset;
612
613 let need_zip64 = self.entries.len() > u16::MAX as usize
615 || central_dir_size > u32::MAX as u64
616 || central_dir_offset > u32::MAX as u64;
617
618 if need_zip64 {
619 self.output.write_all(&[0x50, 0x4b, 0x06, 0x06])?; let zip64_eocd_size: u64 = 44;
625 self.output.write_all(&zip64_eocd_size.to_le_bytes())?;
626 self.output.write_all(&[20, 0])?;
628 self.output.write_all(&[20, 0])?;
629 self.output.write_all(&0u32.to_le_bytes())?;
631 self.output.write_all(&0u32.to_le_bytes())?;
632 self.output
634 .write_all(&(self.entries.len() as u64).to_le_bytes())?;
635 self.output
637 .write_all(&(self.entries.len() as u64).to_le_bytes())?;
638 self.output.write_all(¢ral_dir_size.to_le_bytes())?;
640 self.output.write_all(¢ral_dir_offset.to_le_bytes())?;
642
643 self.output.write_all(&[0x50, 0x4b, 0x06, 0x07])?; self.output.write_all(&0u32.to_le_bytes())?;
648 let zip64_eocd_pos = central_dir_offset + central_dir_size; self.output.write_all(&zip64_eocd_pos.to_le_bytes())?;
651 self.output.write_all(&0u32.to_le_bytes())?;
653 }
654
655 self.output.write_all(&[0x50, 0x4b, 0x05, 0x06])?;
657 self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; if self.entries.len() > u16::MAX as usize {
662 self.output.write_all(&0xFFFFu16.to_le_bytes())?;
663 self.output.write_all(&0xFFFFu16.to_le_bytes())?;
664 } else {
665 self.output
666 .write_all(&(self.entries.len() as u16).to_le_bytes())?;
667 self.output
668 .write_all(&(self.entries.len() as u16).to_le_bytes())?;
669 }
670
671 if central_dir_size > u32::MAX as u64 {
673 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
674 } else {
675 self.output
676 .write_all(&(central_dir_size as u32).to_le_bytes())?;
677 }
678
679 if central_dir_offset > u32::MAX as u64 {
680 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
681 } else {
682 self.output
683 .write_all(&(central_dir_offset as u32).to_le_bytes())?;
684 }
685
686 self.output.write_all(&0u16.to_le_bytes())?; self.output.flush()?;
689 Ok(self.output)
690 }
691}