1use crate::error::{Result, SZipError};
13use crc32fast::Hasher as Crc32;
14use flate2::write::DeflateEncoder;
15use flate2::Compression;
16use std::fs::File;
17use std::io::{Seek, Write};
18use std::path::Path;
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22pub enum CompressionMethod {
23 Stored,
25 Deflate,
27 #[cfg(feature = "zstd-support")]
29 Zstd,
30}
31
32impl CompressionMethod {
33 pub(crate) fn to_zip_method(self) -> u16 {
34 match self {
35 CompressionMethod::Stored => 0,
36 CompressionMethod::Deflate => 8,
37 #[cfg(feature = "zstd-support")]
38 CompressionMethod::Zstd => 93,
39 }
40 }
41}
42
43struct ZipEntry {
45 name: String,
46 local_header_offset: u64,
47 crc32: u32,
48 compressed_size: u64,
49 uncompressed_size: u64,
50 compression_method: u16,
51}
52
53pub struct StreamingZipWriter<W: Write + Seek> {
55 output: W,
56 entries: Vec<ZipEntry>,
57 current_entry: Option<CurrentEntry>,
58 compression_level: u32,
59 compression_method: CompressionMethod,
60}
61
62struct CurrentEntry {
63 name: String,
64 local_header_offset: u64,
65 encoder: Box<dyn CompressorWrite>,
66 counter: CrcCounter,
67 compression_method: u16,
68}
69
70trait CompressorWrite: Write {
71 fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer>;
72 fn get_buffer_mut(&mut self) -> &mut CompressedBuffer;
73}
74
75struct DeflateCompressor {
76 encoder: DeflateEncoder<CompressedBuffer>,
77}
78
79impl Write for DeflateCompressor {
80 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
81 self.encoder.write(buf)
82 }
83
84 fn flush(&mut self) -> std::io::Result<()> {
85 self.encoder.flush()
86 }
87}
88
89impl CompressorWrite for DeflateCompressor {
90 fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer> {
91 Ok(self.encoder.finish()?)
92 }
93
94 fn get_buffer_mut(&mut self) -> &mut CompressedBuffer {
95 self.encoder.get_mut()
96 }
97}
98
99#[cfg(feature = "zstd-support")]
100struct ZstdCompressor {
101 encoder: zstd::Encoder<'static, CompressedBuffer>,
102}
103
104#[cfg(feature = "zstd-support")]
105impl Write for ZstdCompressor {
106 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
107 self.encoder.write(buf)
108 }
109
110 fn flush(&mut self) -> std::io::Result<()> {
111 self.encoder.flush()
112 }
113}
114
115#[cfg(feature = "zstd-support")]
116impl CompressorWrite for ZstdCompressor {
117 fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer> {
118 Ok(self.encoder.finish()?)
119 }
120
121 fn get_buffer_mut(&mut self) -> &mut CompressedBuffer {
122 self.encoder.get_mut()
123 }
124}
125
126struct CrcCounter {
128 crc: Crc32,
129 uncompressed_count: u64,
130 compressed_count: u64,
131}
132
133impl CrcCounter {
134 fn new() -> Self {
135 Self {
136 crc: Crc32::new(),
137 uncompressed_count: 0,
138 compressed_count: 0,
139 }
140 }
141
142 fn update_uncompressed(&mut self, data: &[u8]) {
143 self.crc.update(data);
144 self.uncompressed_count += data.len() as u64;
145 }
146
147 fn add_compressed(&mut self, count: u64) {
148 self.compressed_count += count;
149 }
150
151 fn finalize(&self) -> u32 {
152 self.crc.clone().finalize()
153 }
154}
155
156struct CompressedBuffer {
158 buffer: Vec<u8>,
159 flush_threshold: usize,
160}
161
162impl CompressedBuffer {
163 fn new() -> Self {
164 Self {
165 buffer: Vec::with_capacity(64 * 1024), flush_threshold: 1024 * 1024, }
168 }
169
170 fn take(&mut self) -> Vec<u8> {
171 std::mem::take(&mut self.buffer)
172 }
173
174 fn should_flush(&self) -> bool {
175 self.buffer.len() >= self.flush_threshold
176 }
177}
178
179impl Write for CompressedBuffer {
180 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
181 self.buffer.extend_from_slice(buf);
182 Ok(buf.len())
183 }
184
185 fn flush(&mut self) -> std::io::Result<()> {
186 Ok(())
187 }
188}
189
190impl StreamingZipWriter<File> {
191 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
193 Self::with_compression(path, 6)
194 }
195
196 pub fn with_compression<P: AsRef<Path>>(path: P, compression_level: u32) -> Result<Self> {
198 Self::with_method(path, CompressionMethod::Deflate, compression_level)
199 }
200
201 pub fn with_method<P: AsRef<Path>>(
208 path: P,
209 method: CompressionMethod,
210 compression_level: u32,
211 ) -> Result<Self> {
212 let output = File::create(path)?;
213 Ok(Self {
214 output,
215 entries: Vec::new(),
216 current_entry: None,
217 compression_level,
218 compression_method: method,
219 })
220 }
221
222 #[cfg(feature = "zstd-support")]
224 pub fn with_zstd<P: AsRef<Path>>(path: P, compression_level: i32) -> Result<Self> {
225 let output = File::create(path)?;
226 Ok(Self {
227 output,
228 entries: Vec::new(),
229 current_entry: None,
230 compression_level: compression_level as u32,
231 compression_method: CompressionMethod::Zstd,
232 })
233 }
234}
235
236impl<W: Write + Seek> StreamingZipWriter<W> {
237 pub fn from_writer(writer: W) -> Result<Self> {
239 Self::from_writer_with_compression(writer, 6)
240 }
241
242 pub fn from_writer_with_compression(writer: W, compression_level: u32) -> Result<Self> {
244 Self::from_writer_with_method(writer, CompressionMethod::Deflate, compression_level)
245 }
246
247 pub fn from_writer_with_method(
254 writer: W,
255 method: CompressionMethod,
256 compression_level: u32,
257 ) -> Result<Self> {
258 Ok(Self {
259 output: writer,
260 entries: Vec::new(),
261 current_entry: None,
262 compression_level,
263 compression_method: method,
264 })
265 }
266
267 pub fn start_entry(&mut self, name: &str) -> Result<()> {
269 self.finish_current_entry()?;
271
272 let local_header_offset = self.output.stream_position()?;
273 let compression_method = self.compression_method.to_zip_method();
274
275 self.output.write_all(&[0x50, 0x4b, 0x03, 0x04])?; self.output.write_all(&[20, 0])?; self.output.write_all(&[8, 0])?; self.output.write_all(&compression_method.to_le_bytes())?; self.output.write_all(&[0, 0, 0, 0])?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&(name.len() as u16).to_le_bytes())?;
285 self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(name.as_bytes())?;
287
288 let encoder: Box<dyn CompressorWrite> = match self.compression_method {
290 CompressionMethod::Deflate => Box::new(DeflateCompressor {
291 encoder: DeflateEncoder::new(
292 CompressedBuffer::new(),
293 Compression::new(self.compression_level),
294 ),
295 }),
296 #[cfg(feature = "zstd-support")]
297 CompressionMethod::Zstd => {
298 let mut encoder =
299 zstd::Encoder::new(CompressedBuffer::new(), self.compression_level as i32)?;
300 encoder.include_checksum(false)?; Box::new(ZstdCompressor { encoder })
302 }
303 CompressionMethod::Stored => {
304 return Err(SZipError::InvalidFormat(
306 "Stored method not yet implemented".to_string(),
307 ));
308 }
309 };
310
311 self.current_entry = Some(CurrentEntry {
312 name: name.to_string(),
313 local_header_offset,
314 encoder,
315 counter: CrcCounter::new(),
316 compression_method,
317 });
318
319 Ok(())
320 }
321
322 pub fn write_data(&mut self, data: &[u8]) -> Result<()> {
324 let entry = self
325 .current_entry
326 .as_mut()
327 .ok_or_else(|| SZipError::InvalidFormat("No entry started".to_string()))?;
328
329 entry.counter.update_uncompressed(data);
331
332 entry.encoder.write_all(data)?;
334
335 entry.encoder.flush()?;
337
338 let buffer = entry.encoder.get_buffer_mut();
340 if buffer.should_flush() {
341 let compressed_data = buffer.take();
343 self.output.write_all(&compressed_data)?;
344 entry.counter.add_compressed(compressed_data.len() as u64);
345 }
346
347 Ok(())
348 }
349
350 fn finish_current_entry(&mut self) -> Result<()> {
352 if let Some(mut entry) = self.current_entry.take() {
353 let mut buffer = entry.encoder.finish_compression()?;
355
356 let remaining_data = buffer.take();
358 if !remaining_data.is_empty() {
359 self.output.write_all(&remaining_data)?;
360 entry.counter.add_compressed(remaining_data.len() as u64);
361 }
362
363 let crc = entry.counter.finalize();
364 let compressed_size = entry.counter.compressed_count;
365 let uncompressed_size = entry.counter.uncompressed_count;
366
367 self.output.write_all(&[0x50, 0x4b, 0x07, 0x08])?;
370 self.output.write_all(&crc.to_le_bytes())?;
371 if compressed_size > u32::MAX as u64 || uncompressed_size > u32::MAX as u64 {
373 self.output.write_all(&compressed_size.to_le_bytes())?;
374 self.output.write_all(&uncompressed_size.to_le_bytes())?;
375 } else {
376 self.output
377 .write_all(&(compressed_size as u32).to_le_bytes())?;
378 self.output
379 .write_all(&(uncompressed_size as u32).to_le_bytes())?;
380 }
381
382 self.entries.push(ZipEntry {
384 name: entry.name,
385 local_header_offset: entry.local_header_offset,
386 crc32: crc,
387 compressed_size,
388 uncompressed_size,
389 compression_method: entry.compression_method,
390 });
391 }
392 Ok(())
393 }
394
395 pub fn finish(mut self) -> Result<W> {
397 self.finish_current_entry()?;
399
400 let central_dir_offset = self.output.stream_position()?;
401
402 for entry in &self.entries {
404 self.output.write_all(&[0x50, 0x4b, 0x01, 0x02])?; self.output.write_all(&[20, 0])?; self.output.write_all(&[20, 0])?; self.output.write_all(&[8, 0])?; self.output
409 .write_all(&entry.compression_method.to_le_bytes())?; self.output.write_all(&[0, 0, 0, 0])?; self.output.write_all(&entry.crc32.to_le_bytes())?;
412
413 if entry.compressed_size > u32::MAX as u64 {
415 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
416 } else {
417 self.output
418 .write_all(&(entry.compressed_size as u32).to_le_bytes())?;
419 }
420
421 if entry.uncompressed_size > u32::MAX as u64 {
422 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
423 } else {
424 self.output
425 .write_all(&(entry.uncompressed_size as u32).to_le_bytes())?;
426 }
427
428 self.output
429 .write_all(&(entry.name.len() as u16).to_le_bytes())?;
430
431 let mut extra_field: Vec<u8> = Vec::new();
433 if entry.uncompressed_size > u32::MAX as u64
434 || entry.compressed_size > u32::MAX as u64
435 || entry.local_header_offset > u32::MAX as u64
436 {
437 extra_field.extend_from_slice(&0x0001u16.to_le_bytes());
439 let mut data: Vec<u8> = Vec::new();
441 if entry.uncompressed_size > u32::MAX as u64 {
442 data.extend_from_slice(&entry.uncompressed_size.to_le_bytes());
443 }
444 if entry.compressed_size > u32::MAX as u64 {
445 data.extend_from_slice(&entry.compressed_size.to_le_bytes());
446 }
447 if entry.local_header_offset > u32::MAX as u64 {
448 data.extend_from_slice(&entry.local_header_offset.to_le_bytes());
449 }
450 extra_field.extend_from_slice(&(data.len() as u16).to_le_bytes());
451 extra_field.extend_from_slice(&data);
452 }
453
454 self.output
455 .write_all(&(extra_field.len() as u16).to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; if entry.local_header_offset > u32::MAX as u64 {
463 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
464 } else {
465 self.output
466 .write_all(&(entry.local_header_offset as u32).to_le_bytes())?;
467 }
468
469 self.output.write_all(entry.name.as_bytes())?;
470 if !extra_field.is_empty() {
471 self.output.write_all(&extra_field)?;
472 }
473 }
474
475 let central_dir_size = self.output.stream_position()? - central_dir_offset;
476
477 let need_zip64 = self.entries.len() > u16::MAX as usize
479 || central_dir_size > u32::MAX as u64
480 || central_dir_offset > u32::MAX as u64;
481
482 if need_zip64 {
483 self.output.write_all(&[0x50, 0x4b, 0x06, 0x06])?; let zip64_eocd_size: u64 = 44;
489 self.output.write_all(&zip64_eocd_size.to_le_bytes())?;
490 self.output.write_all(&[20, 0])?;
492 self.output.write_all(&[20, 0])?;
493 self.output.write_all(&0u32.to_le_bytes())?;
495 self.output.write_all(&0u32.to_le_bytes())?;
496 self.output
498 .write_all(&(self.entries.len() as u64).to_le_bytes())?;
499 self.output
501 .write_all(&(self.entries.len() as u64).to_le_bytes())?;
502 self.output.write_all(¢ral_dir_size.to_le_bytes())?;
504 self.output.write_all(¢ral_dir_offset.to_le_bytes())?;
506
507 self.output.write_all(&[0x50, 0x4b, 0x06, 0x07])?; self.output.write_all(&0u32.to_le_bytes())?;
512 let zip64_eocd_pos = central_dir_offset + central_dir_size; self.output.write_all(&zip64_eocd_pos.to_le_bytes())?;
515 self.output.write_all(&0u32.to_le_bytes())?;
517 }
518
519 self.output.write_all(&[0x50, 0x4b, 0x05, 0x06])?;
521 self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; if self.entries.len() > u16::MAX as usize {
526 self.output.write_all(&0xFFFFu16.to_le_bytes())?;
527 self.output.write_all(&0xFFFFu16.to_le_bytes())?;
528 } else {
529 self.output
530 .write_all(&(self.entries.len() as u16).to_le_bytes())?;
531 self.output
532 .write_all(&(self.entries.len() as u16).to_le_bytes())?;
533 }
534
535 if central_dir_size > u32::MAX as u64 {
537 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
538 } else {
539 self.output
540 .write_all(&(central_dir_size as u32).to_le_bytes())?;
541 }
542
543 if central_dir_offset > u32::MAX as u64 {
544 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
545 } else {
546 self.output
547 .write_all(&(central_dir_offset as u32).to_le_bytes())?;
548 }
549
550 self.output.write_all(&0u16.to_le_bytes())?; self.output.flush()?;
553 Ok(self.output)
554 }
555}