tokio_into_sink/
lib.rs

//! Use an [`AsyncWrite`] as a [`Sink`]`<Item: AsRef<[u8]>>`.
2//!
3//! This adapter produces a sink that will write each value passed to it into the underlying writer.
4//! Note that this function consumes the given writer, returning a wrapped version.
5//!
6//! ```
7//! use tokio_into_sink::IntoSinkExt as _;
8//! use futures::{stream, StreamExt as _};
9//! use std::io;
10//!
11//! # tokio::runtime::Builder::new_current_thread().build().unwrap().block_on(async {
12//! let stream = stream::iter(["hello", "world"]).map(io::Result::Ok);
13//! let write = tokio::fs::File::create("/dev/null").await.unwrap();
14//! let sink = write.into_sink();
15//! stream.forward(sink).await.unwrap();
16//! # } ) // block_on
17//! ```
18//!
19//! Ported from [`futures::io::AsyncWriteExt::into_sink`](https://docs.rs/futures/0.3.28/futures/io/trait.AsyncWriteExt.html#method.into_sink).
20
21use std::{
22    io,
23    pin::Pin,
24    task::{ready, Context, Poll},
25};
26
27use futures_sink::Sink;
28use pin_project_lite::pin_project;
29use tokio::io::AsyncWrite;
30
31pub trait IntoSinkExt: AsyncWrite {
32    /// See the [module documentation](mod@self).
33    fn into_sink<Item>(self) -> IntoSink<Self, Item>
34    where
35        Self: Sized;
36}
37
38impl<W> IntoSinkExt for W
39where
40    W: AsyncWrite,
41{
42    fn into_sink<Item>(self) -> IntoSink<Self, Item>
43    where
44        Self: Sized,
45    {
46        IntoSink {
47            writer: self,
48            buffer: None,
49        }
50    }
51}
52
/// A value paired with a position inside its byte representation.
#[derive(Debug)]
struct Cursor<T> {
    /// Byte index into `inner` from which the next write resumes.
    offset: usize,
    /// The wrapped value.
    inner: T,
}
58
59pin_project! {
60    /// See the [module documentation](mod@self).
61    #[derive(Debug)]
62    pub struct IntoSink<W, Item> {
63        #[pin]
64        writer: W,
65        buffer: Option<Cursor<Item>>,
66    }
67}
68
69impl<W, Item> IntoSink<W, Item>
70where
71    W: AsyncWrite,
72    Item: AsRef<[u8]>,
73{
74    /// If we have an outstanding block in `buffer` attempt to push it into the writer, does _not_
75    /// flush the writer after it succeeds in pushing the block into it.
76    fn poll_flush_buffer(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
77        let mut this = self.project();
78        if let Some(cursor) = this.buffer {
79            loop {
80                let bytes = cursor.inner.as_ref();
81                let written = ready!(this.writer.as_mut().poll_write(cx, &bytes[cursor.offset..]))?;
82                cursor.offset += written;
83                if cursor.offset == bytes.len() {
84                    break;
85                }
86            }
87        }
88        *this.buffer = None;
89        Poll::Ready(Ok(()))
90    }
91}
92
93impl<W, Item> Sink<Item> for IntoSink<W, Item>
94where
95    W: AsyncWrite,
96    Item: AsRef<[u8]>,
97{
98    type Error = io::Error;
99
100    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
101        ready!(self.poll_flush_buffer(cx))?;
102        Poll::Ready(Ok(()))
103    }
104
105    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
106        debug_assert!(self.buffer.is_none());
107        *self.project().buffer = Some(Cursor {
108            offset: 0,
109            inner: item,
110        });
111        Ok(())
112    }
113
114    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115        ready!(self.as_mut().poll_flush_buffer(cx))?;
116        ready!(self.project().writer.poll_flush(cx))?;
117        Poll::Ready(Ok(()))
118    }
119
120    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
121        ready!(self.as_mut().poll_flush_buffer(cx))?;
122        ready!(self.project().writer.poll_shutdown(cx))?;
123        Poll::Ready(Ok(()))
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    use futures::{executor::block_on, stream, StreamExt as _};
    use std::io;

    /// Fails when `cargo rdme --check` exits unsuccessfully.
    #[test]
    fn readme() {
        let output = std::process::Command::new("cargo")
            .args(["rdme", "--check"])
            .output()
            .expect("couldn't run `cargo rdme`");
        assert!(
            output.status.success(),
            "README.md is out of date - bless the new version by running `cargo rdme`"
        )
    }

    /// Forwards two string items into a `Vec`-backed sink and checks the concatenated bytes.
    #[test]
    fn test() {
        block_on(async {
            let mut written = vec![];
            let items = stream::iter(["hello", "world"]).map(io::Result::Ok);
            let sink = (&mut written).into_sink();
            items.forward(sink).await.unwrap();
            assert_eq!(written, b"helloworld");
        })
    }
}
157}