1use std::{
2 fmt,
3 future::Future,
4 pin::Pin,
5 task::{ready, Context, Poll},
6};
7
8use bytes::Bytes;
9use futures_channel::{mpsc, oneshot};
10use futures_util::{stream::FusedStream, Stream};
11use http::HeaderMap;
12use http_body::{Body, Frame, SizeHint};
13
14use super::{watch, DecodedLength};
15use crate::{proto::http2::ping, Error, Result};
16
/// A body received from a connection.
///
/// The backing source is one of the [`Kind`] variants: an HTTP/1 channel fed
/// by a [`Sender`], an HTTP/2 receive stream, or nothing at all (empty body).
#[must_use = "streams do nothing unless polled"]
pub struct Incoming {
    // Which source frames are read from.
    kind: Kind,
}
25
/// The source an [`Incoming`] body reads its frames from.
enum Kind {
    /// HTTP/1 body fed through channels by a [`Sender`] (see `Incoming::h1`).
    H1 {
        // Raised on every `poll_frame`; `Sender::poll_ready` waits on the
        // matching receiver.
        want_tx: watch::Sender,
        // Data chunks, or an error injected by `Sender::send_error`.
        data_rx: mpsc::Receiver<Result<Bytes, Error>>,
        // Trailers delivered once via `Sender::send_trailers`.
        trailers_rx: oneshot::Receiver<HeaderMap>,
        // Remaining decoded length; reduced by each chunk's size via `sub_if`.
        content_length: DecodedLength,
    },
    /// HTTP/2 body backed by an h2 receive stream (see `Incoming::h2`).
    H2 {
        // Notified of received data (`record_data`) and trailers
        // (`record_non_data`).
        ping: ping::Recorder,
        recv: http2::RecvStream,
        // Remaining decoded length; reduced by each data frame's size.
        content_length: DecodedLength,
        // Set once `poll_data` returns `None`; trailers are polled afterwards.
        data_done: bool,
    },
    /// A body with no frames.
    Empty,
}
41
/// The producing half of an HTTP/1 [`Incoming`] body, returned by
/// `Incoming::h1`.
#[must_use = "Sender does nothing unless sent on"]
pub(crate) struct Sender {
    // Observes the want signal the body raises whenever it is polled.
    want_rx: watch::Receiver,
    // Carries data chunks (or a terminal error) to the body.
    data_tx: mpsc::Sender<Result<Bytes, Error>>,
    // Trailers channel; taken (left `None`) by the first `send_trailers` call.
    trailers_tx: Option<oneshot::Sender<HeaderMap>>,
}
61
62impl Incoming {
65 #[inline]
66 pub(crate) fn empty() -> Incoming {
67 Incoming { kind: Kind::Empty }
68 }
69
70 pub(crate) fn h1(content_length: DecodedLength, wanter: bool) -> (Sender, Incoming) {
71 let (data_tx, data_rx) = mpsc::channel(0);
72 let (trailers_tx, trailers_rx) = oneshot::channel();
73 let (want_tx, want_rx) = watch::channel(wanter);
76
77 (
78 Sender {
79 want_rx,
80 data_tx,
81 trailers_tx: Some(trailers_tx),
82 },
83 Incoming {
84 kind: Kind::H1 {
85 want_tx,
86 data_rx,
87 trailers_rx,
88 content_length,
89 },
90 },
91 )
92 }
93
94 pub(crate) fn h2(
95 recv: http2::RecvStream,
96 mut content_length: DecodedLength,
97 ping: ping::Recorder,
98 ) -> Self {
99 if !content_length.is_exact() && recv.is_end_stream() {
102 content_length = DecodedLength::ZERO;
103 }
104
105 Incoming {
106 kind: Kind::H2 {
107 ping,
108 recv,
109 content_length,
110 data_done: false,
111 },
112 }
113 }
114}
115
impl Body for Incoming {
    type Data = Bytes;
    type Error = Error;

    /// Polls for the next frame of the body.
    ///
    /// Data frames come first. Once the data source is exhausted, at most one
    /// trailers frame follows, and then `None` marks the end of the body.
    fn poll_frame(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        match self.kind {
            Kind::H1 {
                ref want_tx,
                ref mut data_rx,
                ref mut trailers_rx,
                ref mut content_length,
            } => {
                // Raise the want signal; `Sender::poll_ready` waits on the
                // receiving end of this watch channel.
                want_tx.ready();

                // Once the data channel has reported its end, stop polling it.
                // An `Err` item (e.g. from `Sender::send_error`) is returned to
                // the caller through `?`.
                if !data_rx.is_terminated() {
                    if let Some(chunk) = ready!(Pin::new(data_rx).poll_next(cx)?) {
                        content_length.sub_if(chunk.len() as u64);
                        return Poll::Ready(Some(Ok(Frame::data(chunk))));
                    }
                }

                // Data is finished: yield trailers if they were sent. An `Err`
                // here (no trailers delivered) simply ends the body.
                match ready!(Pin::new(trailers_rx).poll(cx)) {
                    Ok(t) => Poll::Ready(Some(Ok(Frame::trailers(t)))),
                    Err(_) => Poll::Ready(None),
                }
            }
            Kind::H2 {
                ref ping,
                ref mut recv,
                ref mut content_length,
                ref mut data_done,
            } => {
                if !*data_done {
                    match ready!(recv.poll_data(cx)) {
                        Some(Ok(bytes)) => {
                            // Hand the consumed bytes back to h2 flow control;
                            // a failure to do so is deliberately ignored.
                            let _ = recv.flow_control().release_capacity(bytes.len());
                            content_length.sub_if(bytes.len() as u64);
                            ping.record_data(bytes.len());
                            return Poll::Ready(Some(Ok(Frame::data(bytes))));
                        }
                        Some(Err(e)) => {
                            // A NO_ERROR reset ends the body gracefully; any
                            // other error is surfaced as a body error.
                            // NOTE(review): only NO_ERROR is treated as graceful
                            // here; confirm whether CANCEL should be as well.
                            if let Some(http2::Reason::NO_ERROR) = e.reason() {
                                return Poll::Ready(None);
                            } else {
                                return Poll::Ready(Some(Err(Error::new_body(e))));
                            }
                        }
                        None => {
                            // Data exhausted; fall through to the trailers.
                            *data_done = true;
                        }
                    }
                }

                match ready!(recv.poll_trailers(cx)) {
                    Ok(t) => {
                        ping.record_non_data();
                        // `None` trailers become `None` (end of body).
                        Poll::Ready(Ok(t.map(Frame::trailers)).transpose())
                    }
                    Err(e) => {
                        // Same NO_ERROR handling as the data path, but other
                        // errors are wrapped with `Error::new_h2`.
                        if let Some(http2::Reason::NO_ERROR) = e.reason() {
                            Poll::Ready(None)
                        } else {
                            Poll::Ready(Some(Err(Error::new_h2(e))))
                        }
                    }
                }
            }
            Kind::Empty => Poll::Ready(None),
        }
    }

    /// Reports whether the body is known to have no more frames.
    ///
    /// For HTTP/1 this is true only when the remaining length is exactly zero,
    /// so a chunked/unknown-length body never reports its end here.
    #[inline]
    fn is_end_stream(&self) -> bool {
        match self.kind {
            Kind::H1 { content_length, .. } => content_length == DecodedLength::ZERO,
            Kind::H2 { recv: ref h2, .. } => h2.is_end_stream(),
            Kind::Empty => true,
        }
    }

    /// Returns an exact hint when the remaining length is known, otherwise the
    /// default (unbounded) hint. An empty body is exactly zero bytes.
    #[inline]
    fn size_hint(&self) -> SizeHint {
        match self.kind {
            Kind::H1 { content_length, .. } | Kind::H2 { content_length, .. } => content_length
                .into_opt()
                .map_or_else(SizeHint::default, SizeHint::with_exact),
            Kind::Empty => SizeHint::with_exact(0),
        }
    }
}
218
219impl fmt::Debug for Incoming {
220 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221 let mut builder = f.debug_tuple(stringify!(Incoming));
222 match self.kind {
223 Kind::Empty => builder.field(&stringify!(Empty)),
224 _ => builder.field(&stringify!(Streaming)),
225 };
226 builder.finish()
227 }
228}
229
230impl Sender {
233 #[inline]
235 pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
236 ready!(self.want_rx.poll_ready(cx)?);
238 self.data_tx.poll_ready(cx).map_err(|_| Error::new_closed())
239 }
240
241 #[inline]
248 pub(crate) fn send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> {
249 self.data_tx
250 .try_send(Ok(chunk))
251 .map_err(|err| err.into_inner().expect("just sent Ok"))
252 }
253
254 #[inline]
261 pub(crate) fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Option<HeaderMap>> {
262 self.trailers_tx
263 .take()
264 .ok_or(None)?
265 .send(trailers)
266 .map_err(Some)
267 }
268
269 #[inline]
271 pub(crate) fn send_error(&mut self, err: Error) {
272 let _ = self.data_tx.clone().try_send(Err(err));
274 }
275}
276
#[cfg(test)]
mod tests {
    use std::{mem, task::Poll};

    use http_body_util::BodyExt;

    use super::{Body, DecodedLength, Error, Incoming, Result, Sender, SizeHint};

    impl Incoming {
        /// Test helper: a chunked (unknown-length) H1 body whose sender is
        /// ready without waiting for the body to be polled.
        pub(crate) fn channel() -> (Sender, Incoming) {
            Self::h1(DecodedLength::CHUNKED, false)
        }
    }

    impl Sender {
        /// Resolves once `poll_ready` does.
        async fn ready(&mut self) -> Result<()> {
            std::future::poll_fn(|cx| self.poll_ready(cx)).await
        }

        /// Aborts the body by injecting a "body write aborted" error.
        pub(crate) fn abort(mut self) {
            self.send_error(Error::new_body_write_aborted());
        }
    }

    #[test]
    fn test_size_of() {
        let body_size = mem::size_of::<Incoming>();
        let body_expected_size = mem::size_of::<u64>() * 5;
        assert!(
            body_size <= body_expected_size,
            "Body size = {body_size}, expected at most {body_expected_size}",
        );

        assert_eq!(
            mem::size_of::<Sender>(),
            mem::size_of::<usize>() * 5,
            "Sender"
        );

        // `Option<Sender>` must not need extra space for a discriminant.
        assert_eq!(
            mem::size_of::<Sender>(),
            mem::size_of::<Option<Sender>>(),
            "Option<Sender>"
        );
    }

    #[test]
    fn size_hint() {
        fn eq(body: Incoming, b: SizeHint, note: &str) {
            let a = body.size_hint();
            assert_eq!(a.lower(), b.lower(), "lower for {note:?}");
            assert_eq!(a.upper(), b.upper(), "upper for {note:?}");
        }

        eq(Incoming::empty(), SizeHint::with_exact(0), "empty");

        eq(Incoming::channel().1, SizeHint::new(), "channel");

        eq(
            Incoming::h1(DecodedLength::new(4), false).1,
            SizeHint::with_exact(4),
            "channel with length",
        );
    }

    #[tokio::test]
    async fn empty_yields_nothing() {
        let mut body = Incoming::empty();
        assert!(body.is_end_stream());
        assert!(body.frame().await.is_none());
    }

    #[tokio::test]
    async fn channel_abort() {
        let (tx, mut rx) = Incoming::channel();

        tx.abort();

        let err = rx.frame().await.unwrap().unwrap_err();
        assert!(err.is_body_write_aborted(), "{err:?}");
    }

    #[tokio::test]
    async fn channel_abort_when_buffer_is_full() {
        let (mut tx, mut rx) = Incoming::channel();

        tx.send_data("chunk 1".into()).expect("send 1");
        // The buffer is now full; abort must still get through.
        tx.abort();

        let chunk1 = rx
            .frame()
            .await
            .expect("item 1")
            .expect("chunk 1")
            .into_data()
            .unwrap();
        assert_eq!(chunk1, "chunk 1");

        let err = rx.frame().await.unwrap().unwrap_err();
        assert!(err.is_body_write_aborted(), "{err:?}");
    }

    #[tokio::test]
    async fn channel_trailers_after_data() {
        let (mut tx, mut rx) = Incoming::channel();

        tx.send_data("chunk 1".into()).expect("send data");

        let mut trailers = http::HeaderMap::new();
        trailers.insert("x-trailer", http::HeaderValue::from_static("done"));
        tx.send_trailers(trailers).expect("send trailers");
        // The trailers channel is single-use.
        assert!(matches!(
            tx.send_trailers(http::HeaderMap::new()),
            Err(None)
        ));
        // Dropping the sender ends the data stream so trailers get polled.
        drop(tx);

        let chunk = rx
            .frame()
            .await
            .expect("item 1")
            .expect("chunk 1")
            .into_data()
            .unwrap();
        assert_eq!(chunk, "chunk 1");

        let received = rx
            .frame()
            .await
            .expect("item 2")
            .expect("trailers")
            .into_trailers()
            .unwrap();
        assert_eq!(received["x-trailer"], "done");

        assert!(rx.frame().await.is_none());
    }

    #[test]
    fn channel_buffers_one() {
        let (mut tx, _rx) = Incoming::channel();

        tx.send_data("chunk 1".into()).expect("send 1");

        // The buffer holds only one chunk, so the second is handed back.
        let chunk2 = tx.send_data("chunk 2".into()).expect_err("send 2");
        assert_eq!(chunk2, "chunk 2");
    }

    #[tokio::test]
    async fn channel_empty() {
        let (_, mut rx) = Incoming::channel();
        assert!(rx.frame().await.is_none());
    }

    #[test]
    fn channel_ready() {
        let (mut tx, _rx) = Incoming::h1(DecodedLength::CHUNKED, false);

        let mut tx_ready = tokio_test::task::spawn(tx.ready());

        assert!(tx_ready.poll().is_ready(), "tx is ready immediately");
    }

    #[test]
    fn channel_wanter() {
        let (mut tx, mut rx) = Incoming::h1(DecodedLength::CHUNKED, true);

        let mut tx_ready = tokio_test::task::spawn(tx.ready());
        let mut rx_data = tokio_test::task::spawn(rx.frame());

        assert!(
            tx_ready.poll().is_pending(),
            "tx isn't ready before rx has been polled"
        );

        assert!(rx_data.poll().is_pending(), "poll rx.data");
        assert!(tx_ready.is_woken(), "rx poll wakes tx");

        assert!(
            tx_ready.poll().is_ready(),
            "tx is ready after rx has been polled"
        );
    }

    #[test]
    fn channel_notices_closure() {
        let (mut tx, rx) = Incoming::h1(DecodedLength::CHUNKED, true);

        let mut tx_ready = tokio_test::task::spawn(tx.ready());

        assert!(
            tx_ready.poll().is_pending(),
            "tx isn't ready before rx has been polled"
        );

        drop(rx);
        assert!(tx_ready.is_woken(), "dropping rx wakes tx");

        match tx_ready.poll() {
            Poll::Ready(Err(ref e)) if e.is_closed() => (),
            unexpected => panic!("tx poll ready unexpected: {unexpected:?}"),
        }
    }
}
447}