Skip to main content

wreq_proto/body/
incoming.rs

1use std::{
2    fmt,
3    future::Future,
4    pin::Pin,
5    task::{ready, Context, Poll},
6};
7
8use bytes::Bytes;
9use http::HeaderMap;
10use http_body::{Body, Frame, SizeHint};
11use tokio::sync::{mpsc, oneshot};
12use tokio_util::sync::PollSender;
13
14use super::{watch, DecodedLength};
15use crate::{proto::http2::ping, Error, Result};
16
17/// A stream of [`Bytes`], used when receiving bodies from the network.
18///
19/// Note that Users should not instantiate this struct directly. When working with the client,
20/// [`Incoming`] is returned to you in responses.
21#[must_use = "streams do nothing unless polled"]
22pub struct Incoming {
23    kind: Kind,
24}
25
26enum Kind {
27    H1 {
28        want_tx: watch::Sender,
29        data_rx: mpsc::Receiver<Result<Bytes, Error>>,
30        trailers_rx: oneshot::Receiver<HeaderMap>,
31        content_length: DecodedLength,
32        data_done: bool,
33    },
34    H2 {
35        ping: ping::Recorder,
36        recv: http2::RecvStream,
37        content_length: DecodedLength,
38        data_done: bool,
39    },
40    Empty,
41}
42
43/// A sender half created through [`Body::channel()`].
44///
45/// Useful when wanting to stream chunks from another thread.
46///
47/// ## Body Closing
48///
49/// Note that the request body will always be closed normally when the sender is dropped (meaning
50/// that the empty terminating chunk will be sent to the remote). If you desire to close the
51/// connection with an incomplete response (e.g. in the case of an error during asynchronous
52/// processing), call the [`Sender::abort()`] method to abort the body in an abnormal fashion.
53///
54/// [`Body::channel()`]: struct.Body.html#method.channel
55/// [`Sender::abort()`]: struct.Sender.html#method.abort
56#[must_use = "Sender does nothing unless sent on"]
57pub(crate) struct Sender {
58    want_rx: watch::Receiver,
59    data_tx: PollSender<Result<Bytes, Error>>,
60    trailers_tx: Option<oneshot::Sender<HeaderMap>>,
61}
62
63// ===== impl Incoming =====
64
65impl Incoming {
66    #[inline]
67    pub(crate) fn empty() -> Incoming {
68        Incoming { kind: Kind::Empty }
69    }
70
71    pub(crate) fn h1(content_length: DecodedLength, wanter: bool) -> (Sender, Incoming) {
72        let (data_tx, data_rx) = mpsc::channel(2);
73        let (trailers_tx, trailers_rx) = oneshot::channel();
74        // If wanter is true, `Sender::poll_ready()` won't becoming ready
75        // until the `Body` has been polled for data once.
76        let (want_tx, want_rx) = watch::channel(wanter);
77
78        (
79            Sender {
80                want_rx,
81                data_tx: PollSender::new(data_tx),
82                trailers_tx: Some(trailers_tx),
83            },
84            Incoming {
85                kind: Kind::H1 {
86                    want_tx,
87                    data_rx,
88                    trailers_rx,
89                    content_length,
90                    data_done: false,
91                },
92            },
93        )
94    }
95
96    pub(crate) fn h2(
97        recv: http2::RecvStream,
98        mut content_length: DecodedLength,
99        ping: ping::Recorder,
100    ) -> Self {
101        // If the stream is already EOS, then the "unknown length" is clearly
102        // actually ZERO.
103        if !content_length.is_exact() && recv.is_end_stream() {
104            content_length = DecodedLength::ZERO;
105        }
106
107        Incoming {
108            kind: Kind::H2 {
109                ping,
110                recv,
111                content_length,
112                data_done: false,
113            },
114        }
115    }
116}
117
118impl Body for Incoming {
119    type Data = Bytes;
120    type Error = Error;
121
122    fn poll_frame(
123        mut self: Pin<&mut Self>,
124        cx: &mut Context<'_>,
125    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
126        match self.kind {
127            Kind::H1 {
128                ref want_tx,
129                ref mut data_rx,
130                ref mut trailers_rx,
131                ref mut content_length,
132                ref mut data_done,
133            } => {
134                want_tx.ready();
135
136                if !*data_done {
137                    match ready!(data_rx.poll_recv(cx)) {
138                        Some(Ok(chunk)) => {
139                            content_length.sub_if(chunk.len() as u64);
140                            return Poll::Ready(Some(Ok(Frame::data(chunk))));
141                        }
142                        Some(Err(err)) => return Poll::Ready(Some(Err(err))),
143                        None => {
144                            // fall through to trailers
145                            *data_done = true;
146                        }
147                    }
148                }
149
150                // check trailers after data is terminated
151                if !trailers_rx.is_terminated() {
152                    if let Ok(trailers) = ready!(Pin::new(trailers_rx).poll(cx)) {
153                        return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
154                    }
155                }
156
157                Poll::Ready(None)
158            }
159            Kind::H2 {
160                ref ping,
161                ref mut recv,
162                ref mut content_length,
163                ref mut data_done,
164            } => {
165                if !*data_done {
166                    match ready!(recv.poll_data(cx)) {
167                        Some(Ok(bytes)) => {
168                            let _ = recv.flow_control().release_capacity(bytes.len());
169                            content_length.sub_if(bytes.len() as u64);
170                            ping.record_data(bytes.len());
171                            return Poll::Ready(Some(Ok(Frame::data(bytes))));
172                        }
173                        Some(Err(e)) => {
174                            if let Some(http2::Reason::NO_ERROR) = e.reason() {
175                                // As mentioned in RFC 7540 Section 8.1, a RST_STREAM with NO_ERROR
176                                // indicates an early response, and should cause the body reading
177                                // to stop, but not fail it:
178                                return Poll::Ready(None);
179                            } else {
180                                return Poll::Ready(Some(Err(Error::new_body(e))));
181                            }
182                        }
183                        None => {
184                            // fall through to trailers
185                            *data_done = true;
186                        }
187                    }
188                }
189
190                // after data, check trailers
191                match ready!(recv.poll_trailers(cx)) {
192                    Ok(t) => {
193                        ping.record_non_data();
194                        Poll::Ready(Ok(t.map(Frame::trailers)).transpose())
195                    }
196                    Err(e) => {
197                        if let Some(http2::Reason::NO_ERROR) = e.reason() {
198                            // Same as above, a RST_STREAM with NO_ERROR indicates an early
199                            // response, and should cause reading the trailers to stop, but
200                            // not fail it:
201                            Poll::Ready(None)
202                        } else {
203                            Poll::Ready(Some(Err(Error::new_h2(e))))
204                        }
205                    }
206                }
207            }
208            Kind::Empty => Poll::Ready(None),
209        }
210    }
211
212    #[inline]
213    fn is_end_stream(&self) -> bool {
214        match self.kind {
215            Kind::H1 { content_length, .. } => content_length == DecodedLength::ZERO,
216            Kind::H2 { recv: ref h2, .. } => h2.is_end_stream(),
217            Kind::Empty => true,
218        }
219    }
220
221    #[inline]
222    fn size_hint(&self) -> SizeHint {
223        match self.kind {
224            Kind::H1 { content_length, .. } | Kind::H2 { content_length, .. } => content_length
225                .into_opt()
226                .map_or_else(SizeHint::default, SizeHint::with_exact),
227            Kind::Empty => SizeHint::with_exact(0),
228        }
229    }
230}
231
232impl fmt::Debug for Incoming {
233    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234        let mut builder = f.debug_tuple(stringify!(Incoming));
235        match self.kind {
236            Kind::Empty => builder.field(&stringify!(Empty)),
237            _ => builder.field(&stringify!(Streaming)),
238        };
239        builder.finish()
240    }
241}
242
243// ===== impl Sender =====
244
245impl Sender {
246    /// Check to see if this `Sender` can send more data.
247    #[inline]
248    pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
249        // Check if the receiver end has tried polling for the body yet
250        ready!(self.want_rx.poll_ready(cx)?);
251        self.data_tx
252            .poll_reserve(cx)
253            .map_err(|_| Error::new_closed())
254    }
255
256    /// Send data on this channel.
257    ///
258    /// # Errors
259    ///
260    /// Returns `Err(Bytes)` if the channel could not (currently) accept
261    /// another `Bytes`.
262    ///
263    /// # Panics
264    ///
265    /// If `poll_ready` was not successfully called prior to calling `send_data`, then this method
266    /// will panic.
267    #[inline]
268    pub(crate) fn send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> {
269        self.data_tx.send_item(Ok(chunk)).map_err(|err| {
270            err.into_inner()
271                .expect("value returned")
272                .expect("just sent Ok")
273        })
274    }
275
276    /// Send trailers on this channel.
277    ///
278    /// # Errors
279    ///
280    /// Returns `Err(HeaderMap)` if the channel could not (currently) accept
281    /// another `HeaderMap`.
282    #[inline]
283    pub(crate) fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Option<HeaderMap>> {
284        self.trailers_tx
285            .take()
286            .ok_or(None)?
287            .send(trailers)
288            .map_err(Some)
289    }
290
291    /// Send an error on this channel, which will cause the body stream to end with an error.
292    #[inline]
293    pub(crate) fn send_error(&mut self, err: Error) {
294        self.data_tx
295            .get_ref()
296            .map(|sender| sender.try_send(Err(err)));
297    }
298}
299
#[cfg(test)]
mod tests {
    use std::{mem, task::Poll};

    use http_body_util::BodyExt;

    use super::{Body, DecodedLength, Error, Incoming, Result, Sender, SizeHint};

    // Test-only conveniences on the types under test.

    impl Incoming {
        /// Create a chunked `Incoming` together with its `Sender`, with no "wanter" gating.
        pub(crate) fn channel() -> (Sender, Incoming) {
            Self::h1(DecodedLength::CHUNKED, /* wanter = */ false)
        }
    }

    impl Sender {
        /// Async wrapper around `poll_ready`.
        async fn ready(&mut self) -> Result<()> {
            std::future::poll_fn(|cx| self.poll_ready(cx)).await
        }

        /// Queue a "body write aborted" error, then drop the sender.
        fn abort(mut self) {
            self.send_error(Error::new_body_write_aborted());
        }
    }

    /// Guards against *accidentally* growing `Incoming` or `Sender` by too much.
    #[test]
    fn test_size_of() {
        let body_size = mem::size_of::<Incoming>();
        let body_expected_size = 6 * mem::size_of::<u64>();
        assert!(
            body_size <= body_expected_size,
            "Body size = {body_size} <= {body_expected_size}",
        );

        let sender_size = mem::size_of::<Sender>();
        assert_eq!(sender_size, 8 * mem::size_of::<usize>(), "Sender");
        assert_eq!(
            sender_size,
            mem::size_of::<Option<Sender>>(),
            "Option<Sender>"
        );
    }

    /// Each kind of body reports the expected lower and upper size bounds.
    #[test]
    fn size_hint() {
        let cases = [
            (Incoming::empty(), SizeHint::with_exact(0), "empty"),
            (Incoming::channel().1, SizeHint::new(), "channel"),
            (
                Incoming::h1(DecodedLength::new(4), /* wanter = */ false).1,
                SizeHint::with_exact(4),
                "channel with length",
            ),
        ];

        for (body, expected, note) in cases {
            let actual = body.size_hint();
            assert_eq!(actual.lower(), expected.lower(), "lower for {note:?}");
            assert_eq!(actual.upper(), expected.upper(), "upper for {note:?}");
        }
    }

    /// Aborting right away makes the first frame an "aborted" error.
    #[tokio::test]
    async fn channel_abort() {
        let (sender, mut body) = Incoming::channel();

        sender.abort();

        let first = body.frame().await.unwrap();
        let err = first.unwrap_err();
        assert!(err.is_body_write_aborted(), "{err:?}");
    }

    /// An abort queued behind a chunk is delivered after that chunk.
    #[tokio::test]
    async fn channel_abort_when_buffer_is_full() {
        let (mut sender, mut body) = Incoming::channel();

        sender.ready().await.expect("ready");
        sender.send_data("chunk 1".into()).expect("send 1");
        // one chunk is already queued; the abort error is queued right behind it
        sender.abort();

        let first = body.frame().await.expect("item 1").expect("chunk 1");
        let chunk1 = first.into_data().unwrap();
        assert_eq!(chunk1, "chunk 1");

        let err = body.frame().await.unwrap().unwrap_err();
        assert!(err.is_body_write_aborted(), "{err:?}");
    }

    /// Two chunks fit in the channel; after that the sender is no longer ready.
    #[tokio::test]
    async fn channel_buffers_two() {
        let (mut sender, _body) = Incoming::channel();

        sender.ready().await.expect("ready");
        sender.send_data("chunk 1".into()).expect("send 1");
        sender.ready().await.expect("ready");
        sender.send_data("chunk 2".into()).expect("send 2");

        // both slots are taken and nothing reads them, so `poll_ready` must stay pending
        let wait = std::time::Duration::from_millis(100);
        let res = tokio::time::timeout(wait, std::future::poll_fn(|cx| sender.poll_ready(cx))).await;

        assert!(res.is_err(), "poll_ready unexpectedly became ready");
    }

    /// A body whose sender is dropped without sending anything just ends.
    #[tokio::test]
    async fn channel_empty() {
        // the `_` pattern drops the sender immediately
        let (_, mut body) = Incoming::channel();
        let first = body.frame().await;
        assert!(first.is_none());
    }

    /// Without the "wanter" gate the sender is ready on its first poll.
    #[test]
    fn channel_ready() {
        let (mut sender, _body) = Incoming::h1(DecodedLength::CHUNKED, /* wanter = */ false);

        let mut tx_ready = tokio_test::task::spawn(sender.ready());

        assert!(tx_ready.poll().is_ready(), "tx is ready immediately");
    }

    /// With the "wanter" gate the sender only becomes ready after the body was polled.
    #[test]
    fn channel_wanter() {
        let (mut sender, mut body) = Incoming::h1(DecodedLength::CHUNKED, /* wanter = */ true);

        let mut tx_ready = tokio_test::task::spawn(sender.ready());
        let mut rx_data = tokio_test::task::spawn(body.frame());

        let before = tx_ready.poll();
        assert!(
            before.is_pending(),
            "tx isn't ready before rx has been polled"
        );

        assert!(rx_data.poll().is_pending(), "poll rx.data");
        assert!(tx_ready.is_woken(), "rx poll wakes tx");

        let after = tx_ready.poll();
        assert!(after.is_ready(), "tx is ready after rx has been polled");
    }

    /// Dropping the body wakes a waiting sender, which then sees a "closed" error.
    #[test]
    fn channel_notices_closure() {
        let (mut sender, body) = Incoming::h1(DecodedLength::CHUNKED, /* wanter = */ true);

        let mut tx_ready = tokio_test::task::spawn(sender.ready());

        let before = tx_ready.poll();
        assert!(
            before.is_pending(),
            "tx isn't ready before rx has been polled"
        );

        drop(body);
        assert!(tx_ready.is_woken(), "dropping rx wakes tx");

        match tx_ready.poll() {
            Poll::Ready(Err(ref e)) if e.is_closed() => (),
            unexpected => panic!("tx poll ready unexpected: {unexpected:?}"),
        }
    }
}