1use crate::error::{Result, SZipError};
11use crc32fast::Hasher as Crc32;
12use flate2::write::DeflateEncoder;
13use flate2::Compression;
14use std::fs::File;
15use std::io::{Seek, Write};
16use std::path::Path;
17
/// Compression algorithm applied to every entry written by a [`StreamingZipWriter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionMethod {
    /// Store entry data uncompressed (ZIP method 0).
    ///
    /// `StreamingZipWriter::start_entry` currently rejects this method.
    Stored,
    /// DEFLATE compression (ZIP method 8).
    Deflate,
    /// Zstandard compression (ZIP method 93); only compiled in with the
    /// `zstd-support` feature.
    #[cfg(feature = "zstd-support")]
    Zstd,
}
29
30impl CompressionMethod {
31 fn to_zip_method(self) -> u16 {
32 match self {
33 CompressionMethod::Stored => 0,
34 CompressionMethod::Deflate => 8,
35 #[cfg(feature = "zstd-support")]
36 CompressionMethod::Zstd => 93,
37 }
38 }
39}
40
/// Metadata kept for a finished entry so its central directory header can be
/// written when the archive is finished.
struct ZipEntry {
    /// Entry name as stored in the archive.
    name: String,
    /// Byte offset of the entry's local file header within the output file.
    local_header_offset: u64,
    /// CRC-32 of the entry's uncompressed data.
    crc32: u32,
    /// Number of compressed bytes written for the entry.
    compressed_size: u64,
    /// Number of uncompressed bytes supplied for the entry.
    uncompressed_size: u64,
    /// ZIP compression method code written in the headers.
    compression_method: u16,
}
50
51pub struct StreamingZipWriter {
53 output: File,
54 entries: Vec<ZipEntry>,
55 current_entry: Option<CurrentEntry>,
56 compression_level: u32,
57 compression_method: CompressionMethod,
58}
59
60struct CurrentEntry {
61 name: String,
62 local_header_offset: u64,
63 encoder: Box<dyn CompressorWrite>,
64 compression_method: u16,
65}
66
67trait CompressorWrite: Write {
68 fn finish_compression(self: Box<Self>) -> Result<CrcCountingWriter>;
69 fn get_crc_writer_mut(&mut self) -> &mut CrcCountingWriter;
70}
71
72struct DeflateCompressor {
73 encoder: DeflateEncoder<CrcCountingWriter>,
74}
75
76impl Write for DeflateCompressor {
77 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
78 self.encoder.write(buf)
79 }
80
81 fn flush(&mut self) -> std::io::Result<()> {
82 self.encoder.flush()
83 }
84}
85
86impl CompressorWrite for DeflateCompressor {
87 fn finish_compression(self: Box<Self>) -> Result<CrcCountingWriter> {
88 Ok(self.encoder.finish()?)
89 }
90
91 fn get_crc_writer_mut(&mut self) -> &mut CrcCountingWriter {
92 self.encoder.get_mut()
93 }
94}
95
/// Zstandard-backed [`CompressorWrite`] whose compressed output goes to a
/// [`CrcCountingWriter`].
#[cfg(feature = "zstd-support")]
struct ZstdCompressor {
    /// zstd streaming encoder wrapping the counting writer.
    encoder: zstd::Encoder<'static, CrcCountingWriter>,
}
100
#[cfg(feature = "zstd-support")]
impl Write for ZstdCompressor {
    /// Passes `buf` to the zstd encoder unchanged.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        Write::write(&mut self.encoder, buf)
    }

    /// Forwards a flush request to the zstd encoder.
    fn flush(&mut self) -> std::io::Result<()> {
        Write::flush(&mut self.encoder)
    }
}
111
#[cfg(feature = "zstd-support")]
impl CompressorWrite for ZstdCompressor {
    /// Ends the zstd frame and returns the counting writer it wrapped.
    fn finish_compression(self: Box<Self>) -> Result<CrcCountingWriter> {
        let ZstdCompressor { encoder } = *self;
        let sink = encoder.finish()?;
        Ok(sink)
    }

    /// Borrows the counting writer underneath the zstd encoder.
    fn get_crc_writer_mut(&mut self) -> &mut CrcCountingWriter {
        let Self { encoder } = self;
        encoder.get_mut()
    }
}
122
123struct CrcCountingWriter {
125 output: File,
126 crc: Crc32,
127 uncompressed_count: u64,
128 compressed_count: u64,
129}
130
131impl CrcCountingWriter {
132 fn new(output: File) -> Self {
133 Self {
134 output,
135 crc: Crc32::new(),
136 uncompressed_count: 0,
137 compressed_count: 0,
138 }
139 }
140}
141
142impl Write for CrcCountingWriter {
143 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
144 let n = self.output.write(buf)?;
146 self.compressed_count += n as u64;
147 Ok(n)
148 }
149
150 fn flush(&mut self) -> std::io::Result<()> {
151 self.output.flush()
152 }
153}
154
155impl StreamingZipWriter {
156 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
158 Self::with_compression(path, 6)
159 }
160
161 pub fn with_compression<P: AsRef<Path>>(path: P, compression_level: u32) -> Result<Self> {
163 Self::with_method(path, CompressionMethod::Deflate, compression_level)
164 }
165
166 pub fn with_method<P: AsRef<Path>>(
173 path: P,
174 method: CompressionMethod,
175 compression_level: u32,
176 ) -> Result<Self> {
177 let output = File::create(path)?;
178 Ok(Self {
179 output,
180 entries: Vec::new(),
181 current_entry: None,
182 compression_level,
183 compression_method: method,
184 })
185 }
186
187 #[cfg(feature = "zstd-support")]
189 pub fn with_zstd<P: AsRef<Path>>(path: P, compression_level: i32) -> Result<Self> {
190 let output = File::create(path)?;
191 Ok(Self {
192 output,
193 entries: Vec::new(),
194 current_entry: None,
195 compression_level: compression_level as u32,
196 compression_method: CompressionMethod::Zstd,
197 })
198 }
199
200 pub fn start_entry(&mut self, name: &str) -> Result<()> {
202 self.finish_current_entry()?;
204
205 let local_header_offset = self.output.stream_position()?;
206 let compression_method = self.compression_method.to_zip_method();
207
208 self.output.write_all(&[0x50, 0x4b, 0x03, 0x04])?; self.output.write_all(&[20, 0])?; self.output.write_all(&[8, 0])?; self.output.write_all(&compression_method.to_le_bytes())?; self.output.write_all(&[0, 0, 0, 0])?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; self.output.write_all(&(name.len() as u16).to_le_bytes())?;
218 self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(name.as_bytes())?;
220
221 let counting_writer = CrcCountingWriter::new(self.output.try_clone()?);
223 let encoder: Box<dyn CompressorWrite> = match self.compression_method {
224 CompressionMethod::Deflate => Box::new(DeflateCompressor {
225 encoder: DeflateEncoder::new(
226 counting_writer,
227 Compression::new(self.compression_level),
228 ),
229 }),
230 #[cfg(feature = "zstd-support")]
231 CompressionMethod::Zstd => {
232 let mut encoder =
233 zstd::Encoder::new(counting_writer, self.compression_level as i32)?;
234 encoder.include_checksum(false)?; Box::new(ZstdCompressor { encoder })
236 }
237 CompressionMethod::Stored => {
238 return Err(SZipError::InvalidFormat(
240 "Stored method not yet implemented".to_string(),
241 ));
242 }
243 };
244
245 self.current_entry = Some(CurrentEntry {
246 name: name.to_string(),
247 local_header_offset,
248 encoder,
249 compression_method,
250 });
251
252 Ok(())
253 }
254
255 pub fn write_data(&mut self, data: &[u8]) -> Result<()> {
257 if let Some(ref mut entry) = self.current_entry {
258 let crc_writer = entry.encoder.get_crc_writer_mut();
260 crc_writer.crc.update(data);
261 crc_writer.uncompressed_count += data.len() as u64;
262
263 entry.encoder.write_all(data)?;
265 Ok(())
266 } else {
267 Err(SZipError::InvalidFormat("No entry started".to_string()))
268 }
269 }
270
271 fn finish_current_entry(&mut self) -> Result<()> {
273 if let Some(entry) = self.current_entry.take() {
274 let counting_writer = entry.encoder.finish_compression()?;
276
277 let crc = counting_writer.crc.finalize();
278 let compressed_size = counting_writer.compressed_count;
279 let uncompressed_size = counting_writer.uncompressed_count;
280
281 self.output.write_all(&[0x50, 0x4b, 0x07, 0x08])?;
284 self.output.write_all(&crc.to_le_bytes())?;
285 if compressed_size > u32::MAX as u64 || uncompressed_size > u32::MAX as u64 {
287 self.output.write_all(&compressed_size.to_le_bytes())?;
288 self.output.write_all(&uncompressed_size.to_le_bytes())?;
289 } else {
290 self.output
291 .write_all(&(compressed_size as u32).to_le_bytes())?;
292 self.output
293 .write_all(&(uncompressed_size as u32).to_le_bytes())?;
294 }
295
296 self.entries.push(ZipEntry {
298 name: entry.name,
299 local_header_offset: entry.local_header_offset,
300 crc32: crc,
301 compressed_size,
302 uncompressed_size,
303 compression_method: entry.compression_method,
304 });
305 }
306 Ok(())
307 }
308
309 pub fn finish(mut self) -> Result<()> {
311 self.finish_current_entry()?;
313
314 let central_dir_offset = self.output.stream_position()?;
315
316 for entry in &self.entries {
318 self.output.write_all(&[0x50, 0x4b, 0x01, 0x02])?; self.output.write_all(&[20, 0])?; self.output.write_all(&[20, 0])?; self.output.write_all(&[8, 0])?; self.output
323 .write_all(&entry.compression_method.to_le_bytes())?; self.output.write_all(&[0, 0, 0, 0])?; self.output.write_all(&entry.crc32.to_le_bytes())?;
326
327 if entry.compressed_size > u32::MAX as u64 {
329 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
330 } else {
331 self.output
332 .write_all(&(entry.compressed_size as u32).to_le_bytes())?;
333 }
334
335 if entry.uncompressed_size > u32::MAX as u64 {
336 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
337 } else {
338 self.output
339 .write_all(&(entry.uncompressed_size as u32).to_le_bytes())?;
340 }
341
342 self.output
343 .write_all(&(entry.name.len() as u16).to_le_bytes())?;
344
345 let mut extra_field: Vec<u8> = Vec::new();
347 if entry.uncompressed_size > u32::MAX as u64
348 || entry.compressed_size > u32::MAX as u64
349 || entry.local_header_offset > u32::MAX as u64
350 {
351 extra_field.extend_from_slice(&0x0001u16.to_le_bytes());
353 let mut data: Vec<u8> = Vec::new();
355 if entry.uncompressed_size > u32::MAX as u64 {
356 data.extend_from_slice(&entry.uncompressed_size.to_le_bytes());
357 }
358 if entry.compressed_size > u32::MAX as u64 {
359 data.extend_from_slice(&entry.compressed_size.to_le_bytes());
360 }
361 if entry.local_header_offset > u32::MAX as u64 {
362 data.extend_from_slice(&entry.local_header_offset.to_le_bytes());
363 }
364 extra_field.extend_from_slice(&(data.len() as u16).to_le_bytes());
365 extra_field.extend_from_slice(&data);
366 }
367
368 self.output
369 .write_all(&(extra_field.len() as u16).to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u32.to_le_bytes())?; if entry.local_header_offset > u32::MAX as u64 {
377 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
378 } else {
379 self.output
380 .write_all(&(entry.local_header_offset as u32).to_le_bytes())?;
381 }
382
383 self.output.write_all(entry.name.as_bytes())?;
384 if !extra_field.is_empty() {
385 self.output.write_all(&extra_field)?;
386 }
387 }
388
389 let central_dir_size = self.output.stream_position()? - central_dir_offset;
390
391 let need_zip64 = self.entries.len() > u16::MAX as usize
393 || central_dir_size > u32::MAX as u64
394 || central_dir_offset > u32::MAX as u64;
395
396 if need_zip64 {
397 self.output.write_all(&[0x50, 0x4b, 0x06, 0x06])?; let zip64_eocd_size: u64 = 44;
403 self.output.write_all(&zip64_eocd_size.to_le_bytes())?;
404 self.output.write_all(&[20, 0])?;
406 self.output.write_all(&[20, 0])?;
407 self.output.write_all(&0u32.to_le_bytes())?;
409 self.output.write_all(&0u32.to_le_bytes())?;
410 self.output
412 .write_all(&(self.entries.len() as u64).to_le_bytes())?;
413 self.output
415 .write_all(&(self.entries.len() as u64).to_le_bytes())?;
416 self.output.write_all(¢ral_dir_size.to_le_bytes())?;
418 self.output.write_all(¢ral_dir_offset.to_le_bytes())?;
420
421 self.output.write_all(&[0x50, 0x4b, 0x06, 0x07])?; self.output.write_all(&0u32.to_le_bytes())?;
426 let zip64_eocd_pos = central_dir_offset + central_dir_size; self.output.write_all(&zip64_eocd_pos.to_le_bytes())?;
429 self.output.write_all(&0u32.to_le_bytes())?;
431 }
432
433 self.output.write_all(&[0x50, 0x4b, 0x05, 0x06])?;
435 self.output.write_all(&0u16.to_le_bytes())?; self.output.write_all(&0u16.to_le_bytes())?; if self.entries.len() > u16::MAX as usize {
440 self.output.write_all(&0xFFFFu16.to_le_bytes())?;
441 self.output.write_all(&0xFFFFu16.to_le_bytes())?;
442 } else {
443 self.output
444 .write_all(&(self.entries.len() as u16).to_le_bytes())?;
445 self.output
446 .write_all(&(self.entries.len() as u16).to_le_bytes())?;
447 }
448
449 if central_dir_size > u32::MAX as u64 {
451 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
452 } else {
453 self.output
454 .write_all(&(central_dir_size as u32).to_le_bytes())?;
455 }
456
457 if central_dir_offset > u32::MAX as u64 {
458 self.output.write_all(&0xFFFFFFFFu32.to_le_bytes())?;
459 } else {
460 self.output
461 .write_all(&(central_dir_offset as u32).to_le_bytes())?;
462 }
463
464 self.output.write_all(&0u16.to_le_bytes())?; self.output.flush()?;
467 Ok(())
468 }
469}