zstd_framed/
async_writer.rs

1use crate::encoder::{ZstdFramedEncoder, ZstdFramedEncoderSeekTableConfig};
2
3pin_project_lite::pin_project! {
4    /// A writer that writes a compressed zstd stream to the underlying writer. Works as either a `tokio` or `futures` writer if
5    /// their respective features are enabled.
6    ///
7    /// The underlying writer `W` should implement the following traits:
8    ///
9    /// - `tokio`
10    ///   - [`tokio::io::AsyncWrite`] (required for [`tokio::io::AsyncWrite`] impl)
11    /// - `futures`
12    ///   - [`futures::AsyncWrite`] (required for [`futures::AsyncWrite`] impl)
13    ///
14    /// For sync I/O support, see [`crate::ZstdWriter`].
15    ///
16    /// ## Construction
17    ///
18    /// Create a builder using [`AsyncZstdWriter::builder`]. See [`ZstdWriterBuilder`]
19    /// for builder options. Call [`ZstdWriterBuilder::build`] to build the
20    /// [`AsyncZstdWriter`] instance.
21    ///
22    /// ```
23    /// # #[cfg(feature = "tokio")]
24    /// # #[tokio::main]
25    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
26    /// # use tokio::io::AsyncWriteExt as _;
27    /// # let compressed_file = vec![];
28    /// // Tokio example
29    /// let mut writer = zstd_framed::AsyncZstdWriter::builder(compressed_file)
30    ///     .with_compression_level(3) // Set custom compression level
31    ///     .with_seek_table(1024 * 1024) // Write zstd seekable format table
32    ///     .build()?;
33    ///
34    /// // ...
35    ///
36    /// writer.shutdown().await?; // Shut down the writer
37    /// # Ok(())
38    /// # }
39    /// # #[cfg(not(feature = "tokio"))]
40    /// # fn main() { }
41    /// ```
42    ///
43    /// ```
44    /// # #[cfg(feature = "futures")]
45    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
46    /// # use futures::io::AsyncWriteExt as _;
47    /// # futures::executor::block_on(async {
48    /// # let compressed_file = vec![];
49    /// // futures example
50    /// let mut writer = zstd_framed::AsyncZstdWriter::builder(compressed_file)
51    ///     .with_compression_level(3) // Set custom compression level
52    ///     .with_seek_table(1024 * 1024) // Write zstd seekable format table
53    ///     .build()?;
54    ///
55    /// // ...
56    ///
57    /// writer.close().await?; // Close the writer
58    /// # Ok(())
59    /// # })
60    /// # }
61    /// # #[cfg(not(feature = "futures"))]
62    /// # fn main() { }
63    /// ```
64    ///
65    /// ## Writing multiple frames
66    ///
67    /// To allow for efficient seeking (e.g. when using [`ZstdReaderBuilder::with_seek_table`](crate::async_reader::ZstdReaderBuilder::with_seek_table)),
68    /// you can write multiple zstd frames to the underlying writer. If the
69    /// [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table) option is
70    /// given during construction, multiple frames will be created automatically
71    /// to fit within the given `max_frame_size`.
72    ///
73    /// Alternatively, you can use [`AsyncZstdWriter::finish_frame()`] to explicitly
74    /// split the underlying stream into multiple frames. [`.finish_frame()`](ZstdWriter::finish_frame)
75    /// can be used even when not using the [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table)
76    /// option (but note the seek table will only be written when using
77    /// [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table)).
78    ///
79    /// ## Clean shutdown
80    ///
81    /// To ensure the writer shuts down cleanly (including flushing any in-memory
82    /// buffers and writing the seek table if enabled with [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table)),
83    /// make sure to call the Tokio [`.shutdown()`](tokio::io::AsyncWriteExt::shutdown)
84    /// method or the or futures [`.close()`](`futures::io::AsyncWriteExt::close`) method!
85    pub struct AsyncZstdWriter<'dict, W> {
86        #[pin]
87        writer: W,
88        encoder: ZstdFramedEncoder<'dict>,
89        buffer: crate::buffer::FixedBuffer<Vec<u8>> ,
90    }
91}
92
93impl<W> AsyncZstdWriter<'_, W> {
94    pub fn builder(writer: W) -> ZstdWriterBuilder<W> {
95        ZstdWriterBuilder::new(writer)
96    }
97
98    pub fn finish_frame(&mut self) -> std::io::Result<()> {
99        self.encoder.finish_frame(&mut self.buffer)?;
100
101        Ok(())
102    }
103
104    /// Write all uncommitted buffered data to the underlying writer. After
105    /// returning `Ok(_)``, `self.buffer` will be empty.
106    #[cfg(feature = "tokio")]
107    fn flush_uncommitted_tokio(
108        self: std::pin::Pin<&mut Self>,
109        cx: &mut std::task::Context<'_>,
110    ) -> std::task::Poll<Result<(), std::io::Error>>
111    where
112        W: tokio::io::AsyncWrite,
113    {
114        use crate::buffer::Buffer as _;
115
116        let mut this = self.project();
117
118        loop {
119            // Get the uncommitted data to write
120            let uncommitted = this.buffer.uncommitted();
121            if uncommitted.is_empty() {
122                // If there's no uncommitted data, we're done
123                return std::task::Poll::Ready(Ok(()));
124            }
125
126            // Write the data to the underlying writer, and record it
127            // as committed
128            let committed = ready!(this.writer.as_mut().poll_write(cx, uncommitted))?;
129            this.buffer.commit(committed);
130
131            if committed == 0 {
132                // The underlying reader didn't accept any more of our data
133
134                return std::task::Poll::Ready(Err(std::io::Error::new(
135                    std::io::ErrorKind::WriteZero,
136                    "failed to write buffered data",
137                )));
138            }
139        }
140    }
141
142    /// Write all uncommitted buffered data to the underlying writer. After
143    /// returning `Ok(_)`, `self.buffer` will be empty.
144    #[cfg(feature = "futures")]
145    fn flush_uncommitted_futures(
146        self: std::pin::Pin<&mut Self>,
147        cx: &mut std::task::Context<'_>,
148    ) -> std::task::Poll<Result<(), std::io::Error>>
149    where
150        W: futures::AsyncWrite,
151    {
152        use crate::buffer::Buffer as _;
153
154        let mut this = self.project();
155
156        loop {
157            // Get the uncommitted data to write
158            let uncommitted = this.buffer.uncommitted();
159            if uncommitted.is_empty() {
160                // If there's no uncommitted data, we're done
161                return std::task::Poll::Ready(Ok(()));
162            }
163
164            // Write the data to the underlying writer, and record it
165            // as committed
166            let committed = ready!(this.writer.as_mut().poll_write(cx, uncommitted))?;
167            this.buffer.commit(committed);
168
169            if committed == 0 {
170                // The underlying reader didn't accept any more of our data
171
172                return std::task::Poll::Ready(Err(std::io::Error::new(
173                    std::io::ErrorKind::WriteZero,
174                    "failed to write buffered data",
175                )));
176            }
177        }
178    }
179}
180
// `W` is not required to be `Unpin`: the inner writer is only ever reached
// through the pin projection, so pinning is preserved structurally.
#[cfg(feature = "tokio")]
impl<W> tokio::io::AsyncWrite for AsyncZstdWriter<'_, W>
where
    W: tokio::io::AsyncWrite,
{
    /// Compress `data` into the internal buffer.
    ///
    /// Any previously buffered output is first written to the underlying
    /// writer. Returns the number of bytes of `data` the encoder reported
    /// as consumed; the compressed output stays buffered until a later
    /// write, flush, or shutdown.
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        data: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        loop {
            // Write all buffered data
            ready!(self.as_mut().flush_uncommitted_tokio(cx))?;

            let this = self.as_mut().project();

            // Encode the newly-written data
            let outcome = this.encoder.encode(data, this.buffer)?;

            match outcome {
                crate::ZstdOutcome::HasMore { .. } => {
                    // The encoder has more to do before data can be encoded
                }
                crate::ZstdOutcome::Complete(consumed) => {
                    // We've now encoded some data to the buffer, so we're done
                    return std::task::Poll::Ready(Ok(consumed));
                }
            }
        }
    }

    /// Flush the encoder, write everything buffered to the underlying
    /// writer, then flush the underlying writer itself.
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        loop {
            // Write all buffered data
            ready!(self.as_mut().flush_uncommitted_tokio(cx))?;

            let this = self.as_mut().project();

            // Flush any data from the encoder to the internal buffer
            let outcome = this.encoder.flush(this.buffer)?;

            match outcome {
                crate::ZstdOutcome::HasMore { .. } => {
                    // zstd still has more data to flush, so loop again
                }
                crate::ZstdOutcome::Complete(_) => {
                    // No more data from the encoder
                    break;
                }
            }
        }

        // Write any newly buffered data from the encoder
        ready!(self.as_mut().flush_uncommitted_tokio(cx))?;

        // Flush the underlying writer
        let this = self.project();
        this.writer.poll_flush(cx)
    }

    /// Shut down the encoder, write its final output to the underlying
    /// writer, then shut down the underlying writer.
    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        loop {
            // Flush any uncommitted data
            ready!(self.as_mut().flush_uncommitted_tokio(cx))?;

            let this = self.as_mut().project();

            // Shut down the encoder
            let outcome = this.encoder.shutdown(this.buffer)?;

            match outcome {
                crate::ZstdOutcome::HasMore { .. } => {
                    // Encoder still has more to write, so keep looping
                }
                crate::ZstdOutcome::Complete(_) => {
                    // Encoder has nothing else to do, so we're done
                    break;
                }
            }
        }

        // Flush any final data from the encoder
        ready!(self.as_mut().flush_uncommitted_tokio(cx))?;

        // Shut down the underlying writer
        let this = self.project();
        this.writer.poll_shutdown(cx)
    }
}
276
// `W` is not required to be `Unpin`: the inner writer is only ever reached
// through the pin projection, so pinning is preserved structurally.
#[cfg(feature = "futures")]
impl<W> futures::AsyncWrite for AsyncZstdWriter<'_, W>
where
    W: futures::AsyncWrite,
{
    /// Compress `data` into the internal buffer.
    ///
    /// Any previously buffered output is first written to the underlying
    /// writer. Returns the number of bytes of `data` the encoder reported
    /// as consumed; the compressed output stays buffered until a later
    /// write, flush, or close.
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        data: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        loop {
            // Write all buffered data
            ready!(self.as_mut().flush_uncommitted_futures(cx))?;

            let this = self.as_mut().project();

            // Encode the newly-written data
            let outcome = this.encoder.encode(data, this.buffer)?;

            match outcome {
                crate::ZstdOutcome::HasMore { .. } => {
                    // The encoder has more to do before data can be encoded
                }
                crate::ZstdOutcome::Complete(consumed) => {
                    // We've now encoded some data to the buffer, so we're done
                    return std::task::Poll::Ready(Ok(consumed));
                }
            }
        }
    }

    /// Flush the encoder, write everything buffered to the underlying
    /// writer, then flush the underlying writer itself.
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        loop {
            // Write all buffered data
            ready!(self.as_mut().flush_uncommitted_futures(cx))?;

            let this = self.as_mut().project();

            // Flush any data from the encoder to the internal buffer
            let outcome = this.encoder.flush(this.buffer)?;

            match outcome {
                crate::ZstdOutcome::HasMore { .. } => {
                    // zstd still has more data to flush, so loop again
                }
                crate::ZstdOutcome::Complete(_) => {
                    // No more data from the encoder
                    break;
                }
            }
        }

        // Write any newly buffered data from the encoder
        ready!(self.as_mut().flush_uncommitted_futures(cx))?;

        // Flush the underlying writer
        let this = self.project();
        this.writer.poll_flush(cx)
    }

    /// Shut down the encoder, write its final output to the underlying
    /// writer, then close the underlying writer.
    fn poll_close(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        loop {
            // Flush any uncommitted data
            ready!(self.as_mut().flush_uncommitted_futures(cx))?;

            let this = self.as_mut().project();

            // Shut down the encoder
            let outcome = this.encoder.shutdown(this.buffer)?;

            match outcome {
                crate::ZstdOutcome::HasMore { .. } => {
                    // Encoder still has more to write, so keep looping
                }
                crate::ZstdOutcome::Complete(_) => {
                    // Encoder has nothing else to do, so we're done
                    break;
                }
            }
        }

        // Flush any final data from the encoder
        ready!(self.as_mut().flush_uncommitted_futures(cx))?;

        // Close the underlying writer
        let this = self.project();
        this.writer.poll_close(cx)
    }
}
372
373/// A builder that builds an [`AsyncZstdWriter`] from the provided writer.
374pub struct ZstdWriterBuilder<W> {
375    writer: W,
376    compression_level: i32,
377    seek_table_config: Option<ZstdFramedEncoderSeekTableConfig>,
378}
379
380impl<W> ZstdWriterBuilder<W> {
381    fn new(writer: W) -> Self {
382        Self {
383            writer,
384            compression_level: 0,
385            seek_table_config: None,
386        }
387    }
388
389    /// Set the zstd compression level.
390    pub fn with_compression_level(mut self, level: i32) -> Self {
391        self.compression_level = level;
392        self
393    }
394
395    /// Write the stream using the [zstd seekable format].
396    ///
397    /// Once the current zstd frame reaches a decompressed size of
398    /// `max_frame_size`, a new frame will automatically be started. When
399    /// the writer is cleanly shut down, a final frame containing a seek
400    /// table will be written to the end of the writer. This seek table can
401    /// be used to efficiently seek through the file, such as by using
402    /// [crate::table::read_seek_table] (or async equivalent) along with
403    /// [`ZstdReaderBuilder::with_seek_table`](crate::async_reader::ZstdReaderBuilder::with_seek_table).
404    ///
405    /// [zstd seekable format]: https://github.com/facebook/zstd/tree/51eb7daf39c8e8a7c338ba214a9d4e2a6a086826/contrib/seekable_format
406    pub fn with_seek_table(mut self, max_frame_size: u32) -> Self {
407        assert!(max_frame_size > 0, "max frame size must be greater than 0");
408
409        self.seek_table_config = Some(ZstdFramedEncoderSeekTableConfig { max_frame_size });
410        self
411    }
412
413    /// Build the writer.
414    pub fn build(self) -> std::io::Result<AsyncZstdWriter<'static, W>> {
415        let zstd_encoder = zstd::stream::raw::Encoder::new(self.compression_level)?;
416        let buffer = crate::buffer::FixedBuffer::new(vec![0; zstd::zstd_safe::CCtx::out_size()]);
417        let encoder = ZstdFramedEncoder::new(zstd_encoder, self.seek_table_config);
418
419        Ok(AsyncZstdWriter {
420            writer: self.writer,
421            encoder,
422            buffer,
423        })
424    }
425}