zstd_framed/writer.rs
1use crate::{
2 buffer::Buffer as _,
3 encoder::{ZstdFramedEncoder, ZstdFramedEncoderSeekTableConfig},
4 ZstdOutcome,
5};
6
/// A writer that compresses everything written to it into a zstd stream
/// and writes that stream to an underlying writer.
///
/// The underlying writer `W` must implement the following traits:
///
/// - [`std::io::Write`]
///
/// For async support, see [`crate::AsyncZstdWriter`].
///
/// ## Construction
///
/// Create a builder using [`ZstdWriter::builder`]. See [`ZstdWriterBuilder`]
/// for builder options. Call [`ZstdWriterBuilder::build`] to build the
/// [`ZstdWriter`] instance.
///
/// ```
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let compressed_file = vec![];
/// let mut writer = zstd_framed::ZstdWriter::builder(compressed_file)
///     .with_compression_level(3) // Set custom compression level
///     .with_seek_table(1024 * 1024) // Write zstd seekable format table
///     .build()?;
///
/// // ...
///
/// writer.shutdown()?; // Optional, will shut down automatically on drop
/// # Ok(())
/// # }
/// ```
///
/// ## Writing multiple frames
///
/// To allow for efficient seeking (e.g. when using [`ZstdReaderBuilder::with_seek_table`](crate::reader::ZstdReaderBuilder::with_seek_table)),
/// you can split the output into multiple zstd frames. If the
/// [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table) option is
/// given during construction, a new frame is started automatically whenever
/// the current frame reaches `max_frame_size` bytes of decompressed data.
///
/// Alternatively, you can call [`ZstdWriter::finish_frame()`] to end the
/// current frame explicitly. [`.finish_frame()`](ZstdWriter::finish_frame)
/// works even without the [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table)
/// option, but a seek table is only written when
/// [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table) is used.
///
/// ## Clean shutdown
///
/// Call [`ZstdWriter::shutdown`] to finish the stream cleanly. This flushes
/// all in-memory buffers and, if enabled with
/// [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table), writes the
/// seek table. The writer also shuts itself down on drop, but any error at
/// that point is silently discarded. Call `shutdown` explicitly if you need
/// to detect failures.
///
/// The `'dict` lifetime parameter is forwarded to the inner encoder
/// (presumably for a borrowed compression dictionary — TODO confirm);
/// [`ZstdWriterBuilder::build`] always produces `ZstdWriter<'static, W>`.
pub struct ZstdWriter<'dict, W>
where
    W: std::io::Write,
{
    // Destination for the compressed stream.
    writer: W,
    // Frame-aware zstd encoder. Built in `ZstdWriterBuilder::build` from the
    // compression level and the optional seek-table config.
    encoder: ZstdFramedEncoder<'dict>,
    // Compressed output staged by `encoder` that has not yet been written to
    // `writer`; drained by `flush_uncommitted`. It is allocated in `build`
    // with zstd's recommended stream output size.
    buffer: crate::buffer::FixedBuffer<Vec<u8>>,
}
65
66impl<W> ZstdWriter<'_, W>
67where
68 W: std::io::Write,
69{
70 /// Create a new zstd writer that writes a compressed zstd stream
71 /// to the underlying writer.
72 pub fn builder(writer: W) -> ZstdWriterBuilder<W> {
73 ZstdWriterBuilder::new(writer)
74 }
75
76 /// Explicitly finish the current zstd frame. If more data is written,
77 /// a new frame will be started.
78 ///
79 /// When using [`ZstdWriterBuilder::with_seek_table`], the just-finished
80 /// frame will be reflected in the resulting seek table.
81 pub fn finish_frame(&mut self) -> std::io::Result<()> {
82 self.encoder.finish_frame(&mut self.buffer)?;
83
84 Ok(())
85 }
86
87 /// Cleanly shut down the zstd stream. This will flush internal buffers,
88 /// finish writing any partially-written frames, and write the
89 /// seek table when using [`ZstdWriterBuilder::with_seek_table`].
90 ///
91 /// This method will be called automatically on drop, although
92 /// any errors will be ignored.
93 pub fn shutdown(&mut self) -> std::io::Result<()> {
94 loop {
95 // Flush any uncommitted data
96 self.flush_uncommitted()?;
97
98 // Shut down the encoder
99 let outcome = self.encoder.shutdown(&mut self.buffer)?;
100
101 match outcome {
102 ZstdOutcome::HasMore { .. } => {
103 // Encoder still has more to write, so keep looping
104 }
105 ZstdOutcome::Complete(_) => {
106 // Encoder has nothing else to do, so we're done
107 break;
108 }
109 }
110 }
111
112 // Flush any final data from the encoder
113 self.flush_uncommitted()?;
114
115 // Flush the underlying writer for good measure
116 self.writer.flush()?;
117
118 Ok(())
119 }
120
121 /// Write all uncommitted buffered data to the underlying writer. After
122 /// returning `Ok(_)`, `self.buffer` will be empty.
123 fn flush_uncommitted(&mut self) -> std::io::Result<()> {
124 loop {
125 // Get the uncommitted data to write
126 let uncommitted = self.buffer.uncommitted();
127 if uncommitted.is_empty() {
128 // If there's no uncommitted data, we're done
129 return Ok(());
130 }
131
132 // Write the data to the underlying writer, and record it
133 // as committed
134 let committed = self.writer.write(uncommitted)?;
135 self.buffer.commit(committed);
136
137 if committed == 0 {
138 // The underlying reader didn't accept any more of our data
139
140 return Err(std::io::Error::new(
141 std::io::ErrorKind::WriteZero,
142 "failed to write buffered data",
143 ));
144 }
145 }
146 }
147}
148
149impl<W> std::io::Write for ZstdWriter<'_, W>
150where
151 W: std::io::Write,
152{
153 fn write(&mut self, data: &[u8]) -> Result<usize, std::io::Error> {
154 loop {
155 // Write all buffered data
156 self.flush_uncommitted()?;
157
158 // Encode the newly-written data
159 let outcome = self.encoder.encode(data, &mut self.buffer)?;
160
161 match outcome {
162 ZstdOutcome::HasMore { .. } => {
163 // The encoder has more to do before data can be encoded
164 }
165 ZstdOutcome::Complete(consumed) => {
166 // We've now encoded some data to the buffer, so we're done
167 return Ok(consumed);
168 }
169 }
170 }
171 }
172
173 fn flush(&mut self) -> std::io::Result<()> {
174 loop {
175 // Write all buffered data
176 self.flush_uncommitted()?;
177
178 // Flush any data from the encoder to the interal buffer
179 let outcome = self.encoder.flush(&mut self.buffer)?;
180
181 match outcome {
182 ZstdOutcome::HasMore { .. } => {
183 // zstd still has more data to flush, so loop again
184 }
185 ZstdOutcome::Complete(_) => {
186 // No more data from the encoder
187 break;
188 }
189 }
190 }
191
192 // Write any newly buffered data from the encoder
193 self.flush_uncommitted()?;
194
195 // Flush the underlying writer
196 self.writer.flush()
197 }
198}
199
200impl<W> Drop for ZstdWriter<'_, W>
201where
202 W: std::io::Write,
203{
204 fn drop(&mut self) {
205 // Try to shut down the writer
206 let _ = self.shutdown();
207 }
208}
209
/// A builder that builds a [`ZstdWriter`] from the provided writer.
///
/// Obtain one with [`ZstdWriter::builder`], adjust settings with the
/// `with_*` methods, then call [`ZstdWriterBuilder::build`].
pub struct ZstdWriterBuilder<W> {
    // Destination the built writer will send the compressed stream to.
    writer: W,
    // zstd compression level handed to the encoder; `0` unless changed via
    // `with_compression_level`.
    compression_level: i32,
    // Seek-table settings; `Some` only after `with_seek_table` is called.
    seek_table_config: Option<ZstdFramedEncoderSeekTableConfig>,
}
216
217impl<W> ZstdWriterBuilder<W> {
218 fn new(writer: W) -> Self {
219 Self {
220 writer,
221 compression_level: 0,
222 seek_table_config: None,
223 }
224 }
225
226 /// Set the zstd compression level.
227 pub fn with_compression_level(mut self, level: i32) -> Self {
228 self.compression_level = level;
229 self
230 }
231
232 /// Write the stream using the [zstd seekable format].
233 ///
234 /// Once the current zstd frame reaches a decompressed size of
235 /// `max_frame_size`, a new frame will automatically be started. When
236 /// the writer is [shut down](ZstdWriter::shutdown), a final frame
237 /// containing a seek table will be written to the end of the writer.
238 /// This seek table can be used to efficiently seek through the file, such
239 /// as by using [crate::table::read_seek_table] along with
240 /// [`ZstdReaderBuilder::with_seek_table`](crate::reader::ZstdReaderBuilder::with_seek_table).
241 ///
242 /// [zstd seekable format]: https://github.com/facebook/zstd/tree/51eb7daf39c8e8a7c338ba214a9d4e2a6a086826/contrib/seekable_format
243 pub fn with_seek_table(mut self, max_frame_size: u32) -> Self {
244 assert!(max_frame_size > 0, "max frame size must be greater than 0");
245
246 self.seek_table_config = Some(ZstdFramedEncoderSeekTableConfig { max_frame_size });
247 self
248 }
249
250 /// Build the writer.
251 pub fn build(self) -> std::io::Result<ZstdWriter<'static, W>>
252 where
253 W: std::io::Write,
254 {
255 let zstd_encoder = zstd::stream::raw::Encoder::new(self.compression_level)?;
256 let buffer = crate::buffer::FixedBuffer::new(vec![0; zstd::zstd_safe::CCtx::out_size()]);
257 let encoder = ZstdFramedEncoder::new(zstd_encoder, self.seek_table_config);
258
259 Ok(ZstdWriter {
260 writer: self.writer,
261 encoder,
262 buffer,
263 })
264 }
265}