s_zip/
writer.rs

1//! Streaming ZIP writer that compresses data on-the-fly without temp files
2//!
3//! This eliminates:
4//! - Temp file disk I/O
5//! - File read buffers
6//! - Intermediate storage
7//!
8//! Expected RAM savings: 5-8 MB per file
9//!
//! Supports any writer that implements `Write + Seek` (e.g. `File` or
//! `std::io::Cursor<Vec<u8>>`); writers without `Seek` are not accepted.
11
12use crate::error::{Result, SZipError};
13use crc32fast::Hasher as Crc32;
14use flate2::write::DeflateEncoder;
15use flate2::Compression;
16use std::fs::File;
17use std::io::{Seek, Write};
18use std::path::Path;
19
/// Compression algorithm applied to the entries of an archive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionMethod {
    /// Entry data is stored as-is, without compression.
    Stored,
    /// Entry data is compressed with DEFLATE.
    Deflate,
    /// Entry data is compressed with Zstandard
    /// (only available with the `zstd-support` feature).
    #[cfg(feature = "zstd-support")]
    Zstd,
}

impl CompressionMethod {
    /// Returns the numeric method identifier written into ZIP headers:
    /// `0` for stored, `8` for DEFLATE and `93` for Zstandard.
    pub(crate) fn to_zip_method(self) -> u16 {
        match self {
            Self::Stored => 0,
            Self::Deflate => 8,
            #[cfg(feature = "zstd-support")]
            Self::Zstd => 93,
        }
    }
}
42
/// Summary of a finished entry, retained until the central directory is written.
struct ZipEntry {
    /// Entry name exactly as it was written into the local file header.
    name: String,
    /// Position in the output at which the entry's local file header begins.
    local_header_offset: u64,
    /// CRC-32 of the entry's uncompressed data.
    crc32: u32,
    /// Number of compressed bytes written to the output for this entry.
    compressed_size: u64,
    /// Number of uncompressed bytes supplied for this entry.
    uncompressed_size: u64,
    /// Numeric ZIP compression method identifier recorded for this entry.
    compression_method: u16,
}
52
53/// Streaming ZIP writer that compresses data on-the-fly
54pub struct StreamingZipWriter<W: Write + Seek> {
55    output: W,
56    entries: Vec<ZipEntry>,
57    current_entry: Option<CurrentEntry>,
58    compression_level: u32,
59    compression_method: CompressionMethod,
60}
61
62struct CurrentEntry {
63    name: String,
64    local_header_offset: u64,
65    encoder: Box<dyn CompressorWrite>,
66    counter: CrcCounter,
67    compression_method: u16,
68}
69
70trait CompressorWrite: Write {
71    fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer>;
72    fn get_buffer_mut(&mut self) -> &mut CompressedBuffer;
73}
74
75struct DeflateCompressor {
76    encoder: DeflateEncoder<CompressedBuffer>,
77}
78
79impl Write for DeflateCompressor {
80    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
81        self.encoder.write(buf)
82    }
83
84    fn flush(&mut self) -> std::io::Result<()> {
85        self.encoder.flush()
86    }
87}
88
89impl CompressorWrite for DeflateCompressor {
90    fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer> {
91        Ok(self.encoder.finish()?)
92    }
93
94    fn get_buffer_mut(&mut self) -> &mut CompressedBuffer {
95        self.encoder.get_mut()
96    }
97}
98
/// [`CompressorWrite`] implementation backed by a Zstandard encoder.
#[cfg(feature = "zstd-support")]
struct ZstdCompressor {
    /// Encoder whose compressed output is collected in a [`CompressedBuffer`].
    encoder: zstd::Encoder<'static, CompressedBuffer>,
}

#[cfg(feature = "zstd-support")]
impl Write for ZstdCompressor {
    /// Forwards `buf` to the wrapped Zstandard encoder.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let inner = &mut self.encoder;
        inner.write(buf)
    }

    /// Forwards the flush request to the wrapped Zstandard encoder.
    fn flush(&mut self) -> std::io::Result<()> {
        let inner = &mut self.encoder;
        inner.flush()
    }
}

#[cfg(feature = "zstd-support")]
impl CompressorWrite for ZstdCompressor {
    /// Finishes the Zstandard stream and hands back the underlying buffer.
    fn finish_compression(self: Box<Self>) -> Result<CompressedBuffer> {
        let buffer = self.encoder.finish()?;
        Ok(buffer)
    }

    /// Returns the buffer the encoder writes its compressed output into.
    fn get_buffer_mut(&mut self) -> &mut CompressedBuffer {
        let inner = &mut self.encoder;
        inner.get_mut()
    }
}
125
126/// Metadata tracker for CRC and byte counts
127struct CrcCounter {
128    crc: Crc32,
129    uncompressed_count: u64,
130    compressed_count: u64,
131}
132
133impl CrcCounter {
134    fn new() -> Self {
135        Self {
136            crc: Crc32::new(),
137            uncompressed_count: 0,
138            compressed_count: 0,
139        }
140    }
141
142    fn update_uncompressed(&mut self, data: &[u8]) {
143        self.crc.update(data);
144        self.uncompressed_count += data.len() as u64;
145    }
146
147    fn add_compressed(&mut self, count: u64) {
148        self.compressed_count += count;
149    }
150
151    fn finalize(&self) -> u32 {
152        self.crc.clone().finalize()
153    }
154}
155
/// In-memory sink for compressed bytes with an advisory flush threshold.
///
/// The threshold is not enforced by `write`; callers poll [`should_flush`]
/// and drain the contents with [`take`].
///
/// [`should_flush`]: CompressedBuffer::should_flush
/// [`take`]: CompressedBuffer::take
struct CompressedBuffer {
    /// Compressed bytes that have not been drained yet.
    buffer: Vec<u8>,
    /// Length at which `should_flush` starts returning `true`.
    flush_threshold: usize,
}

impl CompressedBuffer {
    /// Creates an empty buffer with 64 KiB reserved and a 1 MiB flush threshold.
    fn new() -> Self {
        const INITIAL_CAPACITY: usize = 64 * 1024;
        const FLUSH_THRESHOLD: usize = 1024 * 1024;

        Self {
            buffer: Vec::with_capacity(INITIAL_CAPACITY),
            flush_threshold: FLUSH_THRESHOLD,
        }
    }

    /// Moves the accumulated bytes out, leaving an empty vector behind.
    fn take(&mut self) -> Vec<u8> {
        std::mem::take(&mut self.buffer)
    }

    /// Reports whether the accumulated length has reached the flush threshold.
    fn should_flush(&self) -> bool {
        self.flush_threshold <= self.buffer.len()
    }
}

impl Write for CompressedBuffer {
    /// Appends all of `buf` and reports every byte as written.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let written = buf.len();
        self.buffer.extend_from_slice(buf);
        Ok(written)
    }

    /// Nothing to flush: the data already lives in memory.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
189
190impl StreamingZipWriter<File> {
191    /// Create a new ZIP writer with default compression level (6) using DEFLATE
192    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
193        Self::with_compression(path, 6)
194    }
195
196    /// Create a new ZIP writer with custom compression level (0-9) using DEFLATE
197    pub fn with_compression<P: AsRef<Path>>(path: P, compression_level: u32) -> Result<Self> {
198        Self::with_method(path, CompressionMethod::Deflate, compression_level)
199    }
200
201    /// Create a new ZIP writer with specified compression method and level
202    ///
203    /// # Arguments
204    /// * `path` - Path to the output ZIP file
205    /// * `method` - Compression method to use (Deflate, Zstd, or Stored)
206    /// * `compression_level` - Compression level (0-9 for DEFLATE, 1-21 for Zstd)
207    pub fn with_method<P: AsRef<Path>>(
208        path: P,
209        method: CompressionMethod,
210        compression_level: u32,
211    ) -> Result<Self> {
212        let output = File::create(path)?;
213        Ok(Self {
214            output,
215            entries: Vec::new(),
216            current_entry: None,
217            compression_level,
218            compression_method: method,
219        })
220    }
221
222    /// Create a new ZIP writer with Zstd compression (requires zstd-support feature)
223    #[cfg(feature = "zstd-support")]
224    pub fn with_zstd<P: AsRef<Path>>(path: P, compression_level: i32) -> Result<Self> {
225        let output = File::create(path)?;
226        Ok(Self {
227            output,
228            entries: Vec::new(),
229            current_entry: None,
230            compression_level: compression_level as u32,
231            compression_method: CompressionMethod::Zstd,
232        })
233    }
234}
235
236impl<W: Write + Seek> StreamingZipWriter<W> {
237    /// Create a new ZIP writer from an arbitrary writer with default compression level (6) using DEFLATE
238    pub fn from_writer(writer: W) -> Result<Self> {
239        Self::from_writer_with_compression(writer, 6)
240    }
241
242    /// Create a new ZIP writer from an arbitrary writer with custom compression level
243    pub fn from_writer_with_compression(writer: W, compression_level: u32) -> Result<Self> {
244        Self::from_writer_with_method(writer, CompressionMethod::Deflate, compression_level)
245    }
246
247    /// Create a new ZIP writer from an arbitrary writer with specified compression method and level
248    ///
249    /// # Arguments
250    /// * `writer` - Any writer implementing Write + Seek
251    /// * `method` - Compression method to use (Deflate, Zstd, or Stored)
252    /// * `compression_level` - Compression level (0-9 for DEFLATE, 1-21 for Zstd)
253    pub fn from_writer_with_method(
254        writer: W,
255        method: CompressionMethod,
256        compression_level: u32,
257    ) -> Result<Self> {
258        Ok(Self {
259            output: writer,
260            entries: Vec::new(),
261            current_entry: None,
262            compression_level,
263            compression_method: method,
264        })
265    }
266
267    /// Start a new entry (file) in the ZIP
268    pub fn start_entry(&mut self, name: &str) -> Result<()> {
269        // Finish previous entry if any
270        self.finish_current_entry()?;
271
272        let local_header_offset = self.output.stream_position()?;
273        let compression_method = self.compression_method.to_zip_method();
274
275        // Write local file header with data descriptor flag (bit 3)
276        self.output.write_all(&[0x50, 0x4b, 0x03, 0x04])?; // signature
277        self.output.write_all(&[20, 0])?; // version needed
278        self.output.write_all(&[8, 0])?; // general purpose bit flag (bit 3 set)
279        self.output.write_all(&compression_method.to_le_bytes())?; // compression method
280        self.output.write_all(&[0, 0, 0, 0])?; // mod time/date
281        self.output.write_all(&0u32.to_le_bytes())?; // crc32 placeholder
282        self.output.write_all(&0u32.to_le_bytes())?; // compressed size placeholder
283        self.output.write_all(&0u32.to_le_bytes())?; // uncompressed size placeholder
284        self.output.write_all(&(name.len() as u16).to_le_bytes())?;
285        self.output.write_all(&0u16.to_le_bytes())?; // extra len
286        self.output.write_all(name.as_bytes())?;
287
288        // Create encoder for this entry based on compression method
289        let encoder: Box<dyn CompressorWrite> = match self.compression_method {
290            CompressionMethod::Deflate => Box::new(DeflateCompressor {
291                encoder: DeflateEncoder::new(
292                    CompressedBuffer::new(),
293                    Compression::new(self.compression_level),
294                ),
295            }),
296            #[cfg(feature = "zstd-support")]
297            CompressionMethod::Zstd => {
298                let mut encoder =
299                    zstd::Encoder::new(CompressedBuffer::new(), self.compression_level as i32)?;
300                encoder.include_checksum(false)?; // ZIP uses CRC32, not zstd checksum
301                Box::new(ZstdCompressor { encoder })
302            }
303            CompressionMethod::Stored => {
304                // For stored, we don't compress
305                return Err(SZipError::InvalidFormat(
306                    "Stored method not yet implemented".to_string(),
307                ));
308            }
309        };
310
311        self.current_entry = Some(CurrentEntry {
312            name: name.to_string(),
313            local_header_offset,
314            encoder,
315            counter: CrcCounter::new(),
316            compression_method,
317        });
318
319        Ok(())
320    }
321
322    /// Write uncompressed data to current entry (will be compressed on-the-fly)
323    pub fn write_data(&mut self, data: &[u8]) -> Result<()> {
324        let entry = self
325            .current_entry
326            .as_mut()
327            .ok_or_else(|| SZipError::InvalidFormat("No entry started".to_string()))?;
328
329        // Update CRC and size with uncompressed data
330        entry.counter.update_uncompressed(data);
331
332        // Write to encoder (compresses data into buffer)
333        entry.encoder.write_all(data)?;
334
335        // Flush encoder to ensure all data is in buffer
336        entry.encoder.flush()?;
337
338        // Check if buffer should be flushed to output
339        let buffer = entry.encoder.get_buffer_mut();
340        if buffer.should_flush() {
341            // Flush buffer to output to keep memory usage low
342            let compressed_data = buffer.take();
343            self.output.write_all(&compressed_data)?;
344            entry.counter.add_compressed(compressed_data.len() as u64);
345        }
346
347        Ok(())
348    }
349
350    /// Finish current entry and write data descriptor
351    fn finish_current_entry(&mut self) -> Result<()> {
352        if let Some(mut entry) = self.current_entry.take() {
353            // Finish compression and get remaining buffered data
354            let mut buffer = entry.encoder.finish_compression()?;
355
356            // Flush any remaining data from buffer to output
357            let remaining_data = buffer.take();
358            if !remaining_data.is_empty() {
359                self.output.write_all(&remaining_data)?;
360                entry.counter.add_compressed(remaining_data.len() as u64);
361            }
362
363            let crc = entry.counter.finalize();
364            let compressed_size = entry.counter.compressed_count;
365            let uncompressed_size = entry.counter.uncompressed_count;
366
367            // Write data descriptor
368            // signature
369            self.output.write_all(&[0x50, 0x4b, 0x07, 0x08])?;
370            self.output.write_all(&crc.to_le_bytes())?;
371            // If sizes exceed 32-bit, write 64-bit sizes (ZIP64 data descriptor)
372            if compressed_size > u32::MAX as u64 || uncompressed_size > u32::MAX as u64 {
373                self.output.write_all(&compressed_size.to_le_bytes())?;
374                self.output.write_all(&uncompressed_size.to_le_bytes())?;
375            } else {
376                self.output
377                    .write_all(&(compressed_size as u32).to_le_bytes())?;
378                self.output
379                    .write_all(&(uncompressed_size as u32).to_le_bytes())?;
380            }
381
382            // Save entry info for central directory
383            self.entries.push(ZipEntry {
384                name: entry.name,
385                local_header_offset: entry.local_header_offset,
386                crc32: crc,
387                compressed_size,
388                uncompressed_size,
389                compression_method: entry.compression_method,
390            });
391        }
392        Ok(())
393    }
394
395    /// Finish ZIP file (write central directory and return the writer)
396    pub fn finish(mut self) -> Result<W> {
397        // Finish last entry
398        self.finish_current_entry()?;
399
400        let central_dir_offset = self.output.stream_position()?;
401
402        // Write central directory
403        for entry in &self.entries {
404            self.output.write_all(&[0x50, 0x4b, 0x01, 0x02])?; // central dir sig
405            self.output.write_all(&[20, 0])?; // version made by
406            self.output.write_all(&[20, 0])?; // version needed
407            self.output.write_all(&[8, 0])?; // general purpose bit flag (bit 3 set)
408            self.output
409                .write_all(&entry.compression_method.to_le_bytes())?; // compression method
410            self.output.write_all(&[0, 0, 0, 0])?; // mod time/date
411            self.output.write_all(&entry.crc32.to_le_bytes())?;
412
413            // Write sizes (32-bit placeholders or actual values)
414            if entry.compressed_size > u32::MAX as u64 {
415                self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
416            } else {
417                self.output
418                    .write_all(&(entry.compressed_size as u32).to_le_bytes())?;
419            }
420
421            if entry.uncompressed_size > u32::MAX as u64 {
422                self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
423            } else {
424                self.output
425                    .write_all(&(entry.uncompressed_size as u32).to_le_bytes())?;
426            }
427
428            self.output
429                .write_all(&(entry.name.len() as u16).to_le_bytes())?;
430
431            // Prepare ZIP64 extra field if needed
432            let mut extra_field: Vec<u8> = Vec::new();
433            if entry.uncompressed_size > u32::MAX as u64
434                || entry.compressed_size > u32::MAX as u64
435                || entry.local_header_offset > u32::MAX as u64
436            {
437                // ZIP64 extra header ID 0x0001
438                extra_field.extend_from_slice(&0x0001u16.to_le_bytes());
439                // data size: we'll include uncompressed (8) if needed, compressed (8) if needed, and offset (8) if needed
440                let mut data: Vec<u8> = Vec::new();
441                if entry.uncompressed_size > u32::MAX as u64 {
442                    data.extend_from_slice(&entry.uncompressed_size.to_le_bytes());
443                }
444                if entry.compressed_size > u32::MAX as u64 {
445                    data.extend_from_slice(&entry.compressed_size.to_le_bytes());
446                }
447                if entry.local_header_offset > u32::MAX as u64 {
448                    data.extend_from_slice(&entry.local_header_offset.to_le_bytes());
449                }
450                extra_field.extend_from_slice(&(data.len() as u16).to_le_bytes());
451                extra_field.extend_from_slice(&data);
452            }
453
454            self.output
455                .write_all(&(extra_field.len() as u16).to_le_bytes())?; // extra len
456            self.output.write_all(&0u16.to_le_bytes())?; // file comment len
457            self.output.write_all(&0u16.to_le_bytes())?; // disk number start
458            self.output.write_all(&0u16.to_le_bytes())?; // internal attrs
459            self.output.write_all(&0u32.to_le_bytes())?; // external attrs
460
461            // local header offset (32-bit or 0xFFFFFFFF)
462            if entry.local_header_offset > u32::MAX as u64 {
463                self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
464            } else {
465                self.output
466                    .write_all(&(entry.local_header_offset as u32).to_le_bytes())?;
467            }
468
469            self.output.write_all(entry.name.as_bytes())?;
470            if !extra_field.is_empty() {
471                self.output.write_all(&extra_field)?;
472            }
473        }
474
475        let central_dir_size = self.output.stream_position()? - central_dir_offset;
476
477        // Determine if we need ZIP64 EOCD
478        let need_zip64 = self.entries.len() > u16::MAX as usize
479            || central_dir_size > u32::MAX as u64
480            || central_dir_offset > u32::MAX as u64;
481
482        if need_zip64 {
483            // Write ZIP64 End of Central Directory Record
484            // signature
485            self.output.write_all(&[0x50, 0x4b, 0x06, 0x06])?; // 0x06064b50
486                                                               // size of zip64 eocd record (size of remaining fields)
487                                                               // We'll write fixed-size fields: version made by(2)+version needed(2)+disk numbers(4+4)+entries on disk(8)+total entries(8)+cd size(8)+cd offset(8)
488            let zip64_eocd_size: u64 = 44;
489            self.output.write_all(&zip64_eocd_size.to_le_bytes())?;
490            // version made by, version needed
491            self.output.write_all(&[20, 0])?;
492            self.output.write_all(&[20, 0])?;
493            // disk number, disk where central dir starts
494            self.output.write_all(&0u32.to_le_bytes())?;
495            self.output.write_all(&0u32.to_le_bytes())?;
496            // entries on this disk (8)
497            self.output
498                .write_all(&(self.entries.len() as u64).to_le_bytes())?;
499            // total entries (8)
500            self.output
501                .write_all(&(self.entries.len() as u64).to_le_bytes())?;
502            // central directory size (8)
503            self.output.write_all(&central_dir_size.to_le_bytes())?;
504            // central directory offset (8)
505            self.output.write_all(&central_dir_offset.to_le_bytes())?;
506
507            // Write ZIP64 EOCD locator
508            // signature
509            self.output.write_all(&[0x50, 0x4b, 0x06, 0x07])?; // 0x07064b50
510                                                               // disk with ZIP64 EOCD (4)
511            self.output.write_all(&0u32.to_le_bytes())?;
512            // relative offset of ZIP64 EOCD (8)
513            let zip64_eocd_pos = central_dir_offset + central_dir_size; // directly after central dir
514            self.output.write_all(&zip64_eocd_pos.to_le_bytes())?;
515            // total number of disks
516            self.output.write_all(&0u32.to_le_bytes())?;
517        }
518
519        // Write end of central directory (classic)
520        self.output.write_all(&[0x50, 0x4b, 0x05, 0x06])?;
521        self.output.write_all(&0u16.to_le_bytes())?; // disk number
522        self.output.write_all(&0u16.to_le_bytes())?; // disk with central dir
523
524        // number of entries (16-bit or 0xFFFF if ZIP64 used)
525        if self.entries.len() > u16::MAX as usize {
526            self.output.write_all(&0xFFFFu16.to_le_bytes())?;
527            self.output.write_all(&0xFFFFu16.to_le_bytes())?;
528        } else {
529            self.output
530                .write_all(&(self.entries.len() as u16).to_le_bytes())?;
531            self.output
532                .write_all(&(self.entries.len() as u16).to_le_bytes())?;
533        }
534
535        // central dir size and offset (32-bit or 0xFFFFFFFF)
536        if central_dir_size > u32::MAX as u64 {
537            self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
538        } else {
539            self.output
540                .write_all(&(central_dir_size as u32).to_le_bytes())?;
541        }
542
543        if central_dir_offset > u32::MAX as u64 {
544            self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
545        } else {
546            self.output
547                .write_all(&(central_dir_offset as u32).to_le_bytes())?;
548        }
549
550        self.output.write_all(&0u16.to_le_bytes())?; // comment len
551
552        self.output.flush()?;
553        Ok(self.output)
554    }
555}