zstd_framed/async_writer.rs
1use crate::encoder::{ZstdFramedEncoder, ZstdFramedEncoderSeekTableConfig};
2
3pin_project_lite::pin_project! {
4 /// A writer that writes a compressed zstd stream to the underlying writer. Works as either a `tokio` or `futures` writer if
5 /// their respective features are enabled.
6 ///
7 /// The underlying writer `W` should implement the following traits:
8 ///
9 /// - `tokio`
10 /// - [`tokio::io::AsyncWrite`] (required for [`tokio::io::AsyncWrite`] impl)
11 /// - `futures`
12 /// - [`futures::AsyncWrite`] (required for [`futures::AsyncWrite`] impl)
13 ///
14 /// For sync I/O support, see [`crate::ZstdWriter`].
15 ///
16 /// ## Construction
17 ///
18 /// Create a builder using [`AsyncZstdWriter::builder`]. See [`ZstdWriterBuilder`]
19 /// for builder options. Call [`ZstdWriterBuilder::build`] to build the
20 /// [`AsyncZstdWriter`] instance.
21 ///
22 /// ```
23 /// # #[cfg(feature = "tokio")]
24 /// # #[tokio::main]
25 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
26 /// # use tokio::io::AsyncWriteExt as _;
27 /// # let compressed_file = vec![];
28 /// // Tokio example
29 /// let mut writer = zstd_framed::AsyncZstdWriter::builder(compressed_file)
30 /// .with_compression_level(3) // Set custom compression level
31 /// .with_seek_table(1024 * 1024) // Write zstd seekable format table
32 /// .build()?;
33 ///
34 /// // ...
35 ///
36 /// writer.shutdown().await?; // Shut down the writer
37 /// # Ok(())
38 /// # }
39 /// # #[cfg(not(feature = "tokio"))]
40 /// # fn main() { }
41 /// ```
42 ///
43 /// ```
44 /// # #[cfg(feature = "futures")]
45 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
46 /// # use futures::io::AsyncWriteExt as _;
47 /// # futures::executor::block_on(async {
48 /// # let compressed_file = vec![];
49 /// // futures example
50 /// let mut writer = zstd_framed::AsyncZstdWriter::builder(compressed_file)
51 /// .with_compression_level(3) // Set custom compression level
52 /// .with_seek_table(1024 * 1024) // Write zstd seekable format table
53 /// .build()?;
54 ///
55 /// // ...
56 ///
57 /// writer.close().await?; // Close the writer
58 /// # Ok(())
59 /// # })
60 /// # }
61 /// # #[cfg(not(feature = "futures"))]
62 /// # fn main() { }
63 /// ```
64 ///
65 /// ## Writing multiple frames
66 ///
67 /// To allow for efficient seeking (e.g. when using [`ZstdReaderBuilder::with_seek_table`](crate::async_reader::ZstdReaderBuilder::with_seek_table)),
68 /// you can write multiple zstd frames to the underlying writer. If the
69 /// [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table) option is
70 /// given during construction, multiple frames will be created automatically
71 /// to fit within the given `max_frame_size`.
72 ///
73 /// Alternatively, you can use [`AsyncZstdWriter::finish_frame()`] to explicitly
74 /// split the underlying stream into multiple frames. [`.finish_frame()`](ZstdWriter::finish_frame)
75 /// can be used even when not using the [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table)
76 /// option (but note the seek table will only be written when using
77 /// [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table)).
78 ///
79 /// ## Clean shutdown
80 ///
81 /// To ensure the writer shuts down cleanly (including flushing any in-memory
82 /// buffers and writing the seek table if enabled with [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table)),
83 /// make sure to call the Tokio [`.shutdown()`](tokio::io::AsyncWriteExt::shutdown)
84 /// method or the or futures [`.close()`](`futures::io::AsyncWriteExt::close`) method!
85 pub struct AsyncZstdWriter<'dict, W> {
86 #[pin]
87 writer: W,
88 encoder: ZstdFramedEncoder<'dict>,
89 buffer: crate::buffer::FixedBuffer<Vec<u8>> ,
90 }
91}
92
93impl<W> AsyncZstdWriter<'_, W> {
94 pub fn builder(writer: W) -> ZstdWriterBuilder<W> {
95 ZstdWriterBuilder::new(writer)
96 }
97
98 pub fn finish_frame(&mut self) -> std::io::Result<()> {
99 self.encoder.finish_frame(&mut self.buffer)?;
100
101 Ok(())
102 }
103
104 /// Write all uncommitted buffered data to the underlying writer. After
105 /// returning `Ok(_)``, `self.buffer` will be empty.
106 #[cfg(feature = "tokio")]
107 fn flush_uncommitted_tokio(
108 self: std::pin::Pin<&mut Self>,
109 cx: &mut std::task::Context<'_>,
110 ) -> std::task::Poll<Result<(), std::io::Error>>
111 where
112 W: tokio::io::AsyncWrite,
113 {
114 use crate::buffer::Buffer as _;
115
116 let mut this = self.project();
117
118 loop {
119 // Get the uncommitted data to write
120 let uncommitted = this.buffer.uncommitted();
121 if uncommitted.is_empty() {
122 // If there's no uncommitted data, we're done
123 return std::task::Poll::Ready(Ok(()));
124 }
125
126 // Write the data to the underlying writer, and record it
127 // as committed
128 let committed = ready!(this.writer.as_mut().poll_write(cx, uncommitted))?;
129 this.buffer.commit(committed);
130
131 if committed == 0 {
132 // The underlying reader didn't accept any more of our data
133
134 return std::task::Poll::Ready(Err(std::io::Error::new(
135 std::io::ErrorKind::WriteZero,
136 "failed to write buffered data",
137 )));
138 }
139 }
140 }
141
142 /// Write all uncommitted buffered data to the underlying writer. After
143 /// returning `Ok(_)`, `self.buffer` will be empty.
144 #[cfg(feature = "futures")]
145 fn flush_uncommitted_futures(
146 self: std::pin::Pin<&mut Self>,
147 cx: &mut std::task::Context<'_>,
148 ) -> std::task::Poll<Result<(), std::io::Error>>
149 where
150 W: futures::AsyncWrite,
151 {
152 use crate::buffer::Buffer as _;
153
154 let mut this = self.project();
155
156 loop {
157 // Get the uncommitted data to write
158 let uncommitted = this.buffer.uncommitted();
159 if uncommitted.is_empty() {
160 // If there's no uncommitted data, we're done
161 return std::task::Poll::Ready(Ok(()));
162 }
163
164 // Write the data to the underlying writer, and record it
165 // as committed
166 let committed = ready!(this.writer.as_mut().poll_write(cx, uncommitted))?;
167 this.buffer.commit(committed);
168
169 if committed == 0 {
170 // The underlying reader didn't accept any more of our data
171
172 return std::task::Poll::Ready(Err(std::io::Error::new(
173 std::io::ErrorKind::WriteZero,
174 "failed to write buffered data",
175 )));
176 }
177 }
178 }
179}
180
#[cfg(feature = "tokio")]
impl<W> tokio::io::AsyncWrite for AsyncZstdWriter<'_, W>
where
    W: tokio::io::AsyncWrite + Unpin,
{
    /// Compress `data` into the internal buffer.
    ///
    /// Each pass first writes out whatever is already buffered and then
    /// hands `data` to the encoder. This repeats until the encoder reports
    /// the call as complete, at which point the number of input bytes it
    /// consumed is returned.
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        data: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        use std::task::Poll;

        loop {
            // Drain the internal buffer so the encoder has room to work
            ready!(self.as_mut().flush_uncommitted_tokio(cx))?;

            let parts = self.as_mut().project();

            // Feed the caller's data to the encoder
            match parts.encoder.encode(data, parts.buffer)? {
                // Some data has been encoded into the buffer; report how
                // much of the input was consumed
                crate::ZstdOutcome::Complete(consumed) => return Poll::Ready(Ok(consumed)),
                // The encoder isn't done yet, so go around again
                crate::ZstdOutcome::HasMore { .. } => {}
            }
        }
    }

    /// Flush the encoder's pending output into the underlying writer, then
    /// flush the underlying writer itself.
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        loop {
            // Drain the internal buffer first
            ready!(self.as_mut().flush_uncommitted_tokio(cx))?;

            let parts = self.as_mut().project();

            // Move pending data from the encoder into the internal buffer
            match parts.encoder.flush(parts.buffer)? {
                // The encoder has nothing further to flush
                crate::ZstdOutcome::Complete(_) => break,
                // zstd still holds data, so drain and flush again
                crate::ZstdOutcome::HasMore { .. } => {}
            }
        }

        // Write out whatever the last encoder flush produced
        ready!(self.as_mut().flush_uncommitted_tokio(cx))?;

        // Finally, flush the underlying writer
        let parts = self.project();
        parts.writer.poll_flush(cx)
    }

    /// Shut down the encoder, write out its remaining output, and then
    /// shut down the underlying writer.
    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        loop {
            // Drain the internal buffer first
            ready!(self.as_mut().flush_uncommitted_tokio(cx))?;

            let parts = self.as_mut().project();

            // Let the encoder emit its closing output
            match parts.encoder.shutdown(parts.buffer)? {
                // The encoder has nothing left to write
                crate::ZstdOutcome::Complete(_) => break,
                // The encoder still has output pending, so keep going
                crate::ZstdOutcome::HasMore { .. } => {}
            }
        }

        // Write out the encoder's final bytes
        ready!(self.as_mut().flush_uncommitted_tokio(cx))?;

        // Finally, shut down the underlying writer
        let parts = self.project();
        parts.writer.poll_shutdown(cx)
    }
}
276
#[cfg(feature = "futures")]
impl<W> futures::AsyncWrite for AsyncZstdWriter<'_, W>
where
    W: futures::AsyncWrite + Unpin,
{
    /// Compress `data` into the internal buffer.
    ///
    /// Each pass first writes out whatever is already buffered and then
    /// hands `data` to the encoder. This repeats until the encoder reports
    /// the call as complete, at which point the number of input bytes it
    /// consumed is returned.
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        data: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        use std::task::Poll;

        loop {
            // Drain the internal buffer so the encoder has room to work
            ready!(self.as_mut().flush_uncommitted_futures(cx))?;

            let parts = self.as_mut().project();

            // Feed the caller's data to the encoder
            match parts.encoder.encode(data, parts.buffer)? {
                // Some data has been encoded into the buffer; report how
                // much of the input was consumed
                crate::ZstdOutcome::Complete(consumed) => return Poll::Ready(Ok(consumed)),
                // The encoder isn't done yet, so go around again
                crate::ZstdOutcome::HasMore { .. } => {}
            }
        }
    }

    /// Flush the encoder's pending output into the underlying writer, then
    /// flush the underlying writer itself.
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        loop {
            // Drain the internal buffer first
            ready!(self.as_mut().flush_uncommitted_futures(cx))?;

            let parts = self.as_mut().project();

            // Move pending data from the encoder into the internal buffer
            match parts.encoder.flush(parts.buffer)? {
                // The encoder has nothing further to flush
                crate::ZstdOutcome::Complete(_) => break,
                // zstd still holds data, so drain and flush again
                crate::ZstdOutcome::HasMore { .. } => {}
            }
        }

        // Write out whatever the last encoder flush produced
        ready!(self.as_mut().flush_uncommitted_futures(cx))?;

        // Finally, flush the underlying writer
        let parts = self.project();
        parts.writer.poll_flush(cx)
    }

    /// Shut down the encoder, write out its remaining output, and then
    /// close the underlying writer.
    fn poll_close(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        loop {
            // Drain the internal buffer first
            ready!(self.as_mut().flush_uncommitted_futures(cx))?;

            let parts = self.as_mut().project();

            // Let the encoder emit its closing output
            match parts.encoder.shutdown(parts.buffer)? {
                // The encoder has nothing left to write
                crate::ZstdOutcome::Complete(_) => break,
                // The encoder still has output pending, so keep going
                crate::ZstdOutcome::HasMore { .. } => {}
            }
        }

        // Write out the encoder's final bytes
        ready!(self.as_mut().flush_uncommitted_futures(cx))?;

        // Finally, close the underlying writer
        let parts = self.project();
        parts.writer.poll_close(cx)
    }
}
372
373/// A builder that builds an [`AsyncZstdWriter`] from the provided writer.
374pub struct ZstdWriterBuilder<W> {
375 writer: W,
376 compression_level: i32,
377 seek_table_config: Option<ZstdFramedEncoderSeekTableConfig>,
378}
379
380impl<W> ZstdWriterBuilder<W> {
381 fn new(writer: W) -> Self {
382 Self {
383 writer,
384 compression_level: 0,
385 seek_table_config: None,
386 }
387 }
388
389 /// Set the zstd compression level.
390 pub fn with_compression_level(mut self, level: i32) -> Self {
391 self.compression_level = level;
392 self
393 }
394
395 /// Write the stream using the [zstd seekable format].
396 ///
397 /// Once the current zstd frame reaches a decompressed size of
398 /// `max_frame_size`, a new frame will automatically be started. When
399 /// the writer is cleanly shut down, a final frame containing a seek
400 /// table will be written to the end of the writer. This seek table can
401 /// be used to efficiently seek through the file, such as by using
402 /// [crate::table::read_seek_table] (or async equivalent) along with
403 /// [`ZstdReaderBuilder::with_seek_table`](crate::async_reader::ZstdReaderBuilder::with_seek_table).
404 ///
405 /// [zstd seekable format]: https://github.com/facebook/zstd/tree/51eb7daf39c8e8a7c338ba214a9d4e2a6a086826/contrib/seekable_format
406 pub fn with_seek_table(mut self, max_frame_size: u32) -> Self {
407 assert!(max_frame_size > 0, "max frame size must be greater than 0");
408
409 self.seek_table_config = Some(ZstdFramedEncoderSeekTableConfig { max_frame_size });
410 self
411 }
412
413 /// Build the writer.
414 pub fn build(self) -> std::io::Result<AsyncZstdWriter<'static, W>> {
415 let zstd_encoder = zstd::stream::raw::Encoder::new(self.compression_level)?;
416 let buffer = crate::buffer::FixedBuffer::new(vec![0; zstd::zstd_safe::CCtx::out_size()]);
417 let encoder = ZstdFramedEncoder::new(zstd_encoder, self.seek_table_config);
418
419 Ok(AsyncZstdWriter {
420 writer: self.writer,
421 encoder,
422 buffer,
423 })
424 }
425}