stream_body/
body.rs

1use crate::data::StreamData;
2use crate::state::State;
3use async_pipe::{self, PipeReader, PipeWriter};
4use bytes::Bytes;
5use http::{HeaderMap, HeaderValue};
6use http_body::{Body, SizeHint};
7use pin_project_lite::pin_project;
8use std::borrow::Cow;
9use std::marker::Unpin;
10use std::mem::MaybeUninit;
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use std::task::{Context, Poll};
14use tokio::io::{self, AsyncRead};
15
/// Size in bytes (8 KiB) of the read buffer that `StreamBody::channel` allocates.
const DEFAULT_BUF_SIZE: usize = 8_192;
17
18/// An [HttpBody](https://docs.rs/hyper/0.13.4/hyper/body/trait.HttpBody.html) implementation which handles data streaming in an efficient way.
19///
20/// It is similar to [Body](https://docs.rs/hyper/0.13.4/hyper/body/struct.Body.html).
21pub struct StreamBody {
22    inner: Inner,
23}
24
25enum Inner {
26    Once(OnceInner),
27    Channel(ChannelInner),
28}
29
30struct OnceInner {
31    data: Option<Bytes>,
32    reached_eof: bool,
33    state: Arc<Mutex<State>>,
34}
35
36pin_project! {
37    struct ChannelInner {
38        #[pin]
39        reader: PipeReader,
40        buf: Box<[u8]>,
41        len: usize,
42        reached_eof: bool,
43        state: Arc<Mutex<State>>,
44    }
45}
46
47impl StreamBody {
48    /// Creates an empty body.
49    pub fn empty() -> StreamBody {
50        StreamBody {
51            inner: Inner::Once(OnceInner {
52                data: None,
53                reached_eof: true,
54                state: Arc::new(Mutex::new(State {
55                    is_current_stream_data_consumed: true,
56                    waker: None,
57                })),
58            }),
59        }
60    }
61
62    /// Creates a body stream with an associated writer half.
63    ///
64    /// Useful when wanting to stream chunks from another thread.
65    pub fn channel() -> (PipeWriter, StreamBody) {
66        StreamBody::channel_with_capacity(DEFAULT_BUF_SIZE)
67    }
68
69    /// Creates a body stream with an associated writer half having a specific size of internal buffer.
70    ///
71    /// Useful when wanting to stream chunks from another thread.
72    pub fn channel_with_capacity(capacity: usize) -> (PipeWriter, StreamBody) {
73        let (w, r) = async_pipe::pipe();
74
75        let mut buffer = Vec::with_capacity(capacity);
76        unsafe {
77            buffer.set_len(capacity);
78
79            let b = &mut *(&mut buffer[..] as *mut [u8] as *mut [MaybeUninit<u8>]);
80            r.prepare_uninitialized_buffer(b);
81        }
82
83        let body = StreamBody {
84            inner: Inner::Channel(ChannelInner {
85                reader: r,
86                buf: buffer.into_boxed_slice(),
87                len: 0,
88                reached_eof: false,
89                state: Arc::new(Mutex::new(State {
90                    is_current_stream_data_consumed: true,
91                    waker: None,
92                })),
93            }),
94        };
95
96        (w, body)
97    }
98
99    /// A helper method to convert an [AsyncRead](https://docs.rs/tokio/0.2.16/tokio/io/trait.AsyncRead.html) to a `StreamBody`. If there is any error
100    /// thrown during the reading/writing, it will be logged via [log::error!](https://docs.rs/log/0.4.10/log/macro.error.html).
101    pub fn from_reader<R: AsyncRead + Unpin + Send + 'static>(mut r: R) -> StreamBody {
102        let (mut w, body) = StreamBody::channel();
103
104        tokio::spawn(async move {
105            if let Err(err) = io::copy(&mut r, &mut w).await {
106                log::error!(
107                    "{}: StreamBody: Something went wrong while piping the provided reader to the body: {}",
108                    env!("CARGO_PKG_NAME"),
109                    err
110                )
111            }
112        });
113
114        body
115    }
116}
117
118impl Body for StreamBody {
119    type Data = StreamData;
120    type Error = io::Error;
121
122    fn poll_data(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Result<Self::Data, Self::Error>>> {
123        match self.inner {
124            Inner::Once(ref mut inner) => {
125                let mut state;
126                match inner.state.lock() {
127                    Ok(s) => state = s,
128                    Err(err) => {
129                        return Poll::Ready(Some(Err(io::Error::new(
130                            io::ErrorKind::Other,
131                            format!(
132                                "{}: StreamBody [Once Data]: Failed to lock the stream state on poll data: {}",
133                                env!("CARGO_PKG_NAME"),
134                                err
135                            ),
136                        ))));
137                    }
138                }
139
140                if !state.is_current_stream_data_consumed {
141                    state.waker = Some(cx.waker().clone());
142                    return Poll::Pending;
143                }
144
145                if inner.reached_eof {
146                    return Poll::Ready(None);
147                }
148
149                if let Some(ref bytes) = inner.data {
150                    state.is_current_stream_data_consumed = false;
151                    inner.reached_eof = true;
152
153                    let data = StreamData::new(&bytes[..], Arc::clone(&inner.state));
154
155                    return Poll::Ready(Some(Ok(data)));
156                }
157
158                return Poll::Ready(None);
159            }
160            Inner::Channel(ref mut inner) => {
161                let mut inner_me = Pin::new(inner).project();
162
163                let mut state;
164                match inner_me.state.lock() {
165                    Ok(s) => state = s,
166                    Err(err) => {
167                        return Poll::Ready(Some(Err(io::Error::new(
168                            io::ErrorKind::Other,
169                            format!(
170                                "{}: StreamBody [Channel Data]: Failed to lock the stream state on poll data: {}",
171                                env!("CARGO_PKG_NAME"),
172                                err
173                            ),
174                        ))));
175                    }
176                }
177
178                if !state.is_current_stream_data_consumed {
179                    state.waker = Some(cx.waker().clone());
180                    return Poll::Pending;
181                }
182
183                if *inner_me.reached_eof {
184                    return Poll::Ready(None);
185                }
186
187                let buf: &mut Box<[u8]> = &mut inner_me.buf;
188                let poll_status = inner_me.reader.poll_read(cx, &mut buf[..]);
189
190                match poll_status {
191                    Poll::Pending => Poll::Pending,
192                    Poll::Ready(result) => match result {
193                        Ok(read_count) if read_count > 0 => {
194                            state.is_current_stream_data_consumed = false;
195
196                            let data = StreamData::new(&buf[..read_count], Arc::clone(&inner_me.state));
197                            Poll::Ready(Some(Ok(data)))
198                        }
199                        Ok(_) => {
200                            *inner_me.reached_eof = true;
201                            Poll::Ready(None)
202                        }
203                        Err(err) => Poll::Ready(Some(Err(err))),
204                    },
205                }
206            }
207        }
208    }
209
210    fn poll_trailers(
211        self: Pin<&mut Self>,
212        _cx: &mut Context,
213    ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
214        Poll::Ready(Ok(None))
215    }
216
217    fn is_end_stream(&self) -> bool {
218        match self.inner {
219            Inner::Once(ref inner) => inner.reached_eof,
220            Inner::Channel(ref inner) => inner.reached_eof,
221        }
222    }
223
224    fn size_hint(&self) -> SizeHint {
225        match self.inner {
226            Inner::Once(ref inner) => match inner.data {
227                Some(ref data) => SizeHint::with_exact(data.len() as u64),
228                None => SizeHint::with_exact(0),
229            },
230            Inner::Channel(_) => SizeHint::default(),
231        }
232    }
233}
234
235impl From<Bytes> for StreamBody {
236    #[inline]
237    fn from(chunk: Bytes) -> StreamBody {
238        if chunk.is_empty() {
239            StreamBody::empty()
240        } else {
241            StreamBody {
242                inner: Inner::Once(OnceInner {
243                    data: Some(chunk),
244                    reached_eof: false,
245                    state: Arc::new(Mutex::new(State {
246                        is_current_stream_data_consumed: true,
247                        waker: None,
248                    })),
249                }),
250            }
251        }
252    }
253}
254
255impl From<Vec<u8>> for StreamBody {
256    #[inline]
257    fn from(vec: Vec<u8>) -> StreamBody {
258        StreamBody::from(Bytes::from(vec))
259    }
260}
261
262impl From<&'static [u8]> for StreamBody {
263    #[inline]
264    fn from(slice: &'static [u8]) -> StreamBody {
265        StreamBody::from(Bytes::from(slice))
266    }
267}
268
269impl From<Cow<'static, [u8]>> for StreamBody {
270    #[inline]
271    fn from(cow: Cow<'static, [u8]>) -> StreamBody {
272        match cow {
273            Cow::Borrowed(b) => StreamBody::from(b),
274            Cow::Owned(o) => StreamBody::from(o),
275        }
276    }
277}
278
279impl From<String> for StreamBody {
280    #[inline]
281    fn from(s: String) -> StreamBody {
282        StreamBody::from(Bytes::from(s.into_bytes()))
283    }
284}
285
286impl From<&'static str> for StreamBody {
287    #[inline]
288    fn from(slice: &'static str) -> StreamBody {
289        StreamBody::from(Bytes::from(slice.as_bytes()))
290    }
291}
292
293impl From<Cow<'static, str>> for StreamBody {
294    #[inline]
295    fn from(cow: Cow<'static, str>) -> StreamBody {
296        match cow {
297            Cow::Borrowed(b) => StreamBody::from(b),
298            Cow::Owned(o) => StreamBody::from(o),
299        }
300    }
301}