s_zip/
writer.rs

1//! Streaming ZIP writer that compresses data on-the-fly without temp files
2//!
3//! This eliminates:
4//! - Temp file disk I/O
5//! - File read buffers
6//! - Intermediate storage
7//!
8//! Expected RAM savings: 5-8 MB per file
9
10use crate::error::{Result, SZipError};
11use crc32fast::Hasher as Crc32;
12use flate2::write::DeflateEncoder;
13use flate2::Compression;
14use std::fs::File;
15use std::io::{Seek, Write};
16use std::path::Path;
17
/// Selects how the payload of each ZIP entry is encoded.
///
/// The numeric method id stored in the archive headers is obtained from
/// `to_zip_method`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionMethod {
    /// Payload is kept as-is, without compression.
    Stored,
    /// Payload is DEFLATE-compressed (the most widely supported method).
    Deflate,
    /// Payload is Zstandard-compressed; only available when the crate is
    /// built with the `zstd-support` feature.
    #[cfg(feature = "zstd-support")]
    Zstd,
}
29
30impl CompressionMethod {
31    fn to_zip_method(self) -> u16 {
32        match self {
33            CompressionMethod::Stored => 0,
34            CompressionMethod::Deflate => 8,
35            #[cfg(feature = "zstd-support")]
36            CompressionMethod::Zstd => 93,
37        }
38    }
39}
40
/// Bookkeeping for one finished entry, retained until the central directory
/// is written.
struct ZipEntry {
    /// Entry name exactly as written into the headers.
    name: String,
    /// Byte offset of the entry's local file header within the archive.
    local_header_offset: u64,
    /// CRC-32 recorded for the entry.
    crc32: u32,
    /// Number of payload bytes written to the archive for this entry.
    compressed_size: u64,
    /// Number of bytes the caller supplied for this entry.
    uncompressed_size: u64,
    /// Numeric ZIP compression method id of the entry.
    compression_method: u16,
}
50
51/// Streaming ZIP writer that compresses data on-the-fly
52pub struct StreamingZipWriter {
53    output: File,
54    entries: Vec<ZipEntry>,
55    current_entry: Option<CurrentEntry>,
56    compression_level: u32,
57    compression_method: CompressionMethod,
58}
59
60struct CurrentEntry {
61    name: String,
62    local_header_offset: u64,
63    encoder: Box<dyn CompressorWrite>,
64    compression_method: u16,
65}
66
67trait CompressorWrite: Write {
68    fn finish_compression(self: Box<Self>) -> Result<CrcCountingWriter>;
69    fn get_crc_writer_mut(&mut self) -> &mut CrcCountingWriter;
70}
71
72struct DeflateCompressor {
73    encoder: DeflateEncoder<CrcCountingWriter>,
74}
75
76impl Write for DeflateCompressor {
77    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
78        self.encoder.write(buf)
79    }
80
81    fn flush(&mut self) -> std::io::Result<()> {
82        self.encoder.flush()
83    }
84}
85
86impl CompressorWrite for DeflateCompressor {
87    fn finish_compression(self: Box<Self>) -> Result<CrcCountingWriter> {
88        Ok(self.encoder.finish()?)
89    }
90
91    fn get_crc_writer_mut(&mut self) -> &mut CrcCountingWriter {
92        self.encoder.get_mut()
93    }
94}
95
/// Zstandard implementation of `CompressorWrite` (only with `zstd-support`).
#[cfg(feature = "zstd-support")]
struct ZstdCompressor {
    /// zstd encoder whose output is the counting writer.
    encoder: zstd::Encoder<'static, CrcCountingWriter>,
}
100
/// Forwards `Write` calls unchanged to the inner zstd encoder.
#[cfg(feature = "zstd-support")]
impl Write for ZstdCompressor {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        Write::write(&mut self.encoder, buf)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        Write::flush(&mut self.encoder)
    }
}
111
#[cfg(feature = "zstd-support")]
impl CompressorWrite for ZstdCompressor {
    /// Finishes the zstd stream and returns the counting writer it wrapped.
    fn finish_compression(self: Box<Self>) -> Result<CrcCountingWriter> {
        let ZstdCompressor { encoder } = *self;
        let counting_writer = encoder.finish()?;
        Ok(counting_writer)
    }

    /// Borrows the counting writer sitting underneath the encoder.
    fn get_crc_writer_mut(&mut self) -> &mut CrcCountingWriter {
        let encoder = &mut self.encoder;
        encoder.get_mut()
    }
}
122
123/// Writer that counts bytes and computes CRC32 while writing to output
124struct CrcCountingWriter {
125    output: File,
126    crc: Crc32,
127    uncompressed_count: u64,
128    compressed_count: u64,
129}
130
131impl CrcCountingWriter {
132    fn new(output: File) -> Self {
133        Self {
134            output,
135            crc: Crc32::new(),
136            uncompressed_count: 0,
137            compressed_count: 0,
138        }
139    }
140}
141
142impl Write for CrcCountingWriter {
143    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
144        // This is the compressed data being written
145        let n = self.output.write(buf)?;
146        self.compressed_count += n as u64;
147        Ok(n)
148    }
149
150    fn flush(&mut self) -> std::io::Result<()> {
151        self.output.flush()
152    }
153}
154
155impl StreamingZipWriter {
156    /// Create a new ZIP writer with default compression level (6) using DEFLATE
157    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
158        Self::with_compression(path, 6)
159    }
160
161    /// Create a new ZIP writer with custom compression level (0-9) using DEFLATE
162    pub fn with_compression<P: AsRef<Path>>(path: P, compression_level: u32) -> Result<Self> {
163        Self::with_method(path, CompressionMethod::Deflate, compression_level)
164    }
165
166    /// Create a new ZIP writer with specified compression method and level
167    ///
168    /// # Arguments
169    /// * `path` - Path to the output ZIP file
170    /// * `method` - Compression method to use (Deflate, Zstd, or Stored)
171    /// * `compression_level` - Compression level (0-9 for DEFLATE, 1-21 for Zstd)
172    pub fn with_method<P: AsRef<Path>>(
173        path: P,
174        method: CompressionMethod,
175        compression_level: u32,
176    ) -> Result<Self> {
177        let output = File::create(path)?;
178        Ok(Self {
179            output,
180            entries: Vec::new(),
181            current_entry: None,
182            compression_level,
183            compression_method: method,
184        })
185    }
186
187    /// Create a new ZIP writer with Zstd compression (requires zstd-support feature)
188    #[cfg(feature = "zstd-support")]
189    pub fn with_zstd<P: AsRef<Path>>(path: P, compression_level: i32) -> Result<Self> {
190        let output = File::create(path)?;
191        Ok(Self {
192            output,
193            entries: Vec::new(),
194            current_entry: None,
195            compression_level: compression_level as u32,
196            compression_method: CompressionMethod::Zstd,
197        })
198    }
199
200    /// Start a new entry (file) in the ZIP
201    pub fn start_entry(&mut self, name: &str) -> Result<()> {
202        // Finish previous entry if any
203        self.finish_current_entry()?;
204
205        let local_header_offset = self.output.stream_position()?;
206        let compression_method = self.compression_method.to_zip_method();
207
208        // Write local file header with data descriptor flag (bit 3)
209        self.output.write_all(&[0x50, 0x4b, 0x03, 0x04])?; // signature
210        self.output.write_all(&[20, 0])?; // version needed
211        self.output.write_all(&[8, 0])?; // general purpose bit flag (bit 3 set)
212        self.output.write_all(&compression_method.to_le_bytes())?; // compression method
213        self.output.write_all(&[0, 0, 0, 0])?; // mod time/date
214        self.output.write_all(&0u32.to_le_bytes())?; // crc32 placeholder
215        self.output.write_all(&0u32.to_le_bytes())?; // compressed size placeholder
216        self.output.write_all(&0u32.to_le_bytes())?; // uncompressed size placeholder
217        self.output.write_all(&(name.len() as u16).to_le_bytes())?;
218        self.output.write_all(&0u16.to_le_bytes())?; // extra len
219        self.output.write_all(name.as_bytes())?;
220
221        // Create encoder for this entry based on compression method
222        let counting_writer = CrcCountingWriter::new(self.output.try_clone()?);
223        let encoder: Box<dyn CompressorWrite> = match self.compression_method {
224            CompressionMethod::Deflate => Box::new(DeflateCompressor {
225                encoder: DeflateEncoder::new(
226                    counting_writer,
227                    Compression::new(self.compression_level),
228                ),
229            }),
230            #[cfg(feature = "zstd-support")]
231            CompressionMethod::Zstd => {
232                let mut encoder =
233                    zstd::Encoder::new(counting_writer, self.compression_level as i32)?;
234                encoder.include_checksum(false)?; // ZIP uses CRC32, not zstd checksum
235                Box::new(ZstdCompressor { encoder })
236            }
237            CompressionMethod::Stored => {
238                // For stored, we don't compress
239                return Err(SZipError::InvalidFormat(
240                    "Stored method not yet implemented".to_string(),
241                ));
242            }
243        };
244
245        self.current_entry = Some(CurrentEntry {
246            name: name.to_string(),
247            local_header_offset,
248            encoder,
249            compression_method,
250        });
251
252        Ok(())
253    }
254
255    /// Write uncompressed data to current entry (will be compressed on-the-fly)
256    pub fn write_data(&mut self, data: &[u8]) -> Result<()> {
257        if let Some(ref mut entry) = self.current_entry {
258            // Update CRC with uncompressed data
259            let crc_writer = entry.encoder.get_crc_writer_mut();
260            crc_writer.crc.update(data);
261            crc_writer.uncompressed_count += data.len() as u64;
262
263            // Write to encoder (compresses and writes to output)
264            entry.encoder.write_all(data)?;
265            Ok(())
266        } else {
267            Err(SZipError::InvalidFormat("No entry started".to_string()))
268        }
269    }
270
271    /// Finish current entry and write data descriptor
272    fn finish_current_entry(&mut self) -> Result<()> {
273        if let Some(entry) = self.current_entry.take() {
274            // Finish compression
275            let counting_writer = entry.encoder.finish_compression()?;
276
277            let crc = counting_writer.crc.finalize();
278            let compressed_size = counting_writer.compressed_count;
279            let uncompressed_size = counting_writer.uncompressed_count;
280
281            // Write data descriptor
282            // signature
283            self.output.write_all(&[0x50, 0x4b, 0x07, 0x08])?;
284            self.output.write_all(&crc.to_le_bytes())?;
285            // If sizes exceed 32-bit, write 64-bit sizes (ZIP64 data descriptor)
286            if compressed_size > u32::MAX as u64 || uncompressed_size > u32::MAX as u64 {
287                self.output.write_all(&compressed_size.to_le_bytes())?;
288                self.output.write_all(&uncompressed_size.to_le_bytes())?;
289            } else {
290                self.output
291                    .write_all(&(compressed_size as u32).to_le_bytes())?;
292                self.output
293                    .write_all(&(uncompressed_size as u32).to_le_bytes())?;
294            }
295
296            // Save entry info for central directory
297            self.entries.push(ZipEntry {
298                name: entry.name,
299                local_header_offset: entry.local_header_offset,
300                crc32: crc,
301                compressed_size,
302                uncompressed_size,
303                compression_method: entry.compression_method,
304            });
305        }
306        Ok(())
307    }
308
309    /// Finish ZIP file (write central directory and close)
310    pub fn finish(mut self) -> Result<()> {
311        // Finish last entry
312        self.finish_current_entry()?;
313
314        let central_dir_offset = self.output.stream_position()?;
315
316        // Write central directory
317        for entry in &self.entries {
318            self.output.write_all(&[0x50, 0x4b, 0x01, 0x02])?; // central dir sig
319            self.output.write_all(&[20, 0])?; // version made by
320            self.output.write_all(&[20, 0])?; // version needed
321            self.output.write_all(&[8, 0])?; // general purpose bit flag (bit 3 set)
322            self.output
323                .write_all(&entry.compression_method.to_le_bytes())?; // compression method
324            self.output.write_all(&[0, 0, 0, 0])?; // mod time/date
325            self.output.write_all(&entry.crc32.to_le_bytes())?;
326
327            // Write sizes (32-bit placeholders or actual values)
328            if entry.compressed_size > u32::MAX as u64 {
329                self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
330            } else {
331                self.output
332                    .write_all(&(entry.compressed_size as u32).to_le_bytes())?;
333            }
334
335            if entry.uncompressed_size > u32::MAX as u64 {
336                self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
337            } else {
338                self.output
339                    .write_all(&(entry.uncompressed_size as u32).to_le_bytes())?;
340            }
341
342            self.output
343                .write_all(&(entry.name.len() as u16).to_le_bytes())?;
344
345            // Prepare ZIP64 extra field if needed
346            let mut extra_field: Vec<u8> = Vec::new();
347            if entry.uncompressed_size > u32::MAX as u64
348                || entry.compressed_size > u32::MAX as u64
349                || entry.local_header_offset > u32::MAX as u64
350            {
351                // ZIP64 extra header ID 0x0001
352                extra_field.extend_from_slice(&0x0001u16.to_le_bytes());
353                // data size: we'll include uncompressed (8) if needed, compressed (8) if needed, and offset (8) if needed
354                let mut data: Vec<u8> = Vec::new();
355                if entry.uncompressed_size > u32::MAX as u64 {
356                    data.extend_from_slice(&entry.uncompressed_size.to_le_bytes());
357                }
358                if entry.compressed_size > u32::MAX as u64 {
359                    data.extend_from_slice(&entry.compressed_size.to_le_bytes());
360                }
361                if entry.local_header_offset > u32::MAX as u64 {
362                    data.extend_from_slice(&entry.local_header_offset.to_le_bytes());
363                }
364                extra_field.extend_from_slice(&(data.len() as u16).to_le_bytes());
365                extra_field.extend_from_slice(&data);
366            }
367
368            self.output
369                .write_all(&(extra_field.len() as u16).to_le_bytes())?; // extra len
370            self.output.write_all(&0u16.to_le_bytes())?; // file comment len
371            self.output.write_all(&0u16.to_le_bytes())?; // disk number start
372            self.output.write_all(&0u16.to_le_bytes())?; // internal attrs
373            self.output.write_all(&0u32.to_le_bytes())?; // external attrs
374
375            // local header offset (32-bit or 0xFFFFFFFF)
376            if entry.local_header_offset > u32::MAX as u64 {
377                self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
378            } else {
379                self.output
380                    .write_all(&(entry.local_header_offset as u32).to_le_bytes())?;
381            }
382
383            self.output.write_all(entry.name.as_bytes())?;
384            if !extra_field.is_empty() {
385                self.output.write_all(&extra_field)?;
386            }
387        }
388
389        let central_dir_size = self.output.stream_position()? - central_dir_offset;
390
391        // Determine if we need ZIP64 EOCD
392        let need_zip64 = self.entries.len() > u16::MAX as usize
393            || central_dir_size > u32::MAX as u64
394            || central_dir_offset > u32::MAX as u64;
395
396        if need_zip64 {
397            // Write ZIP64 End of Central Directory Record
398            // signature
399            self.output.write_all(&[0x50, 0x4b, 0x06, 0x06])?; // 0x06064b50
400                                                               // size of zip64 eocd record (size of remaining fields)
401                                                               // We'll write fixed-size fields: version made by(2)+version needed(2)+disk numbers(4+4)+entries on disk(8)+total entries(8)+cd size(8)+cd offset(8)
402            let zip64_eocd_size: u64 = 44;
403            self.output.write_all(&zip64_eocd_size.to_le_bytes())?;
404            // version made by, version needed
405            self.output.write_all(&[20, 0])?;
406            self.output.write_all(&[20, 0])?;
407            // disk number, disk where central dir starts
408            self.output.write_all(&0u32.to_le_bytes())?;
409            self.output.write_all(&0u32.to_le_bytes())?;
410            // entries on this disk (8)
411            self.output
412                .write_all(&(self.entries.len() as u64).to_le_bytes())?;
413            // total entries (8)
414            self.output
415                .write_all(&(self.entries.len() as u64).to_le_bytes())?;
416            // central directory size (8)
417            self.output.write_all(&central_dir_size.to_le_bytes())?;
418            // central directory offset (8)
419            self.output.write_all(&central_dir_offset.to_le_bytes())?;
420
421            // Write ZIP64 EOCD locator
422            // signature
423            self.output.write_all(&[0x50, 0x4b, 0x06, 0x07])?; // 0x07064b50
424                                                               // disk with ZIP64 EOCD (4)
425            self.output.write_all(&0u32.to_le_bytes())?;
426            // relative offset of ZIP64 EOCD (8)
427            let zip64_eocd_pos = central_dir_offset + central_dir_size; // directly after central dir
428            self.output.write_all(&zip64_eocd_pos.to_le_bytes())?;
429            // total number of disks
430            self.output.write_all(&0u32.to_le_bytes())?;
431        }
432
433        // Write end of central directory (classic)
434        self.output.write_all(&[0x50, 0x4b, 0x05, 0x06])?;
435        self.output.write_all(&0u16.to_le_bytes())?; // disk number
436        self.output.write_all(&0u16.to_le_bytes())?; // disk with central dir
437
438        // number of entries (16-bit or 0xFFFF if ZIP64 used)
439        if self.entries.len() > u16::MAX as usize {
440            self.output.write_all(&0xFFFFu16.to_le_bytes())?;
441            self.output.write_all(&0xFFFFu16.to_le_bytes())?;
442        } else {
443            self.output
444                .write_all(&(self.entries.len() as u16).to_le_bytes())?;
445            self.output
446                .write_all(&(self.entries.len() as u16).to_le_bytes())?;
447        }
448
449        // central dir size and offset (32-bit or 0xFFFFFFFF)
450        if central_dir_size > u32::MAX as u64 {
451            self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
452        } else {
453            self.output
454                .write_all(&(central_dir_size as u32).to_le_bytes())?;
455        }
456
457        if central_dir_offset > u32::MAX as u64 {
458            self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
459        } else {
460            self.output
461                .write_all(&(central_dir_offset as u32).to_le_bytes())?;
462        }
463
464        self.output.write_all(&0u16.to_le_bytes())?; // comment len
465
466        self.output.flush()?;
467        Ok(())
468    }
469}