1use std::{
2 fmt,
3 future::Future,
4 pin::Pin,
5 task::{ready, Context, Poll},
6};
7
8use bytes::Bytes;
9use http::HeaderMap;
10use http_body::{Body, Frame, SizeHint};
11use tokio::sync::{mpsc, oneshot};
12use tokio_util::sync::PollSender;
13
14use super::{watch, DecodedLength};
15use crate::{proto::http2::ping, Error, Result};
16
/// An incoming message body, exposed as an `http_body::Body` of `Bytes` frames.
///
/// The body is backed by one of three sources (see `Kind`): an HTTP/1 channel
/// fed by a [`Sender`], an HTTP/2 receive stream, or nothing at all.
#[must_use = "streams do nothing unless polled"]
pub struct Incoming {
    kind: Kind,
}
25
/// The backing source of an [`Incoming`] body.
enum Kind {
    /// HTTP/1 body fed through channels by a [`Sender`] (built by `Incoming::h1`).
    H1 {
        /// Signalled on every `poll_frame`; a sender created with `wanter = true`
        /// waits on the paired receiver in `poll_ready` until this happens.
        want_tx: watch::Sender,
        /// Data chunks, or an error pushed by `Sender::send_error`.
        data_rx: mpsc::Receiver<Result<Bytes, Error>>,
        /// Trailers, read at most once after the data channel has ended.
        trailers_rx: oneshot::Receiver<HeaderMap>,
        /// Remaining expected length; reduced by the size of each received chunk.
        content_length: DecodedLength,
        /// Set once `data_rx` has returned `None`.
        data_done: bool,
    },
    /// HTTP/2 body read directly from the stream (built by `Incoming::h2`).
    H2 {
        /// Told the size of each data frame and about each trailers frame.
        ping: ping::Recorder,
        /// The HTTP/2 stream the data and trailers are read from.
        recv: http2::RecvStream,
        /// Remaining expected length; reduced by the size of each data frame.
        content_length: DecodedLength,
        /// Set once `poll_data` has returned `None`.
        data_done: bool,
    },
    /// A body with no frames.
    Empty,
}
42
/// The producing half of an HTTP/1 [`Incoming`] body, returned by `Incoming::h1`.
///
/// Data goes through `poll_ready` followed by `send_data`, trailers through
/// `send_trailers`, and failures through `send_error`.
#[must_use = "Sender does nothing unless sent on"]
pub(crate) struct Sender {
    /// Polled first by `poll_ready`; paired with the body's `want_tx`.
    want_rx: watch::Receiver,
    /// Bounded data channel (capacity 2, set in `Incoming::h1`).
    data_tx: PollSender<Result<Bytes, Error>>,
    /// `Some` until trailers have been sent.
    trailers_tx: Option<oneshot::Sender<HeaderMap>>,
}
62
63impl Incoming {
66 #[inline]
67 pub(crate) fn empty() -> Incoming {
68 Incoming { kind: Kind::Empty }
69 }
70
71 pub(crate) fn h1(content_length: DecodedLength, wanter: bool) -> (Sender, Incoming) {
72 let (data_tx, data_rx) = mpsc::channel(2);
73 let (trailers_tx, trailers_rx) = oneshot::channel();
74 let (want_tx, want_rx) = watch::channel(wanter);
77
78 (
79 Sender {
80 want_rx,
81 data_tx: PollSender::new(data_tx),
82 trailers_tx: Some(trailers_tx),
83 },
84 Incoming {
85 kind: Kind::H1 {
86 want_tx,
87 data_rx,
88 trailers_rx,
89 content_length,
90 data_done: false,
91 },
92 },
93 )
94 }
95
96 pub(crate) fn h2(
97 recv: http2::RecvStream,
98 mut content_length: DecodedLength,
99 ping: ping::Recorder,
100 ) -> Self {
101 if !content_length.is_exact() && recv.is_end_stream() {
104 content_length = DecodedLength::ZERO;
105 }
106
107 Incoming {
108 kind: Kind::H2 {
109 ping,
110 recv,
111 content_length,
112 data_done: false,
113 },
114 }
115 }
116}
117
118impl Body for Incoming {
119 type Data = Bytes;
120 type Error = Error;
121
122 fn poll_frame(
123 mut self: Pin<&mut Self>,
124 cx: &mut Context<'_>,
125 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
126 match self.kind {
127 Kind::H1 {
128 ref want_tx,
129 ref mut data_rx,
130 ref mut trailers_rx,
131 ref mut content_length,
132 ref mut data_done,
133 } => {
134 want_tx.ready();
135
136 if !*data_done {
137 match ready!(data_rx.poll_recv(cx)) {
138 Some(Ok(chunk)) => {
139 content_length.sub_if(chunk.len() as u64);
140 return Poll::Ready(Some(Ok(Frame::data(chunk))));
141 }
142 Some(Err(err)) => return Poll::Ready(Some(Err(err))),
143 None => {
144 *data_done = true;
146 }
147 }
148 }
149
150 if !trailers_rx.is_terminated() {
152 if let Ok(trailers) = ready!(Pin::new(trailers_rx).poll(cx)) {
153 return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
154 }
155 }
156
157 Poll::Ready(None)
158 }
159 Kind::H2 {
160 ref ping,
161 ref mut recv,
162 ref mut content_length,
163 ref mut data_done,
164 } => {
165 if !*data_done {
166 match ready!(recv.poll_data(cx)) {
167 Some(Ok(bytes)) => {
168 let _ = recv.flow_control().release_capacity(bytes.len());
169 content_length.sub_if(bytes.len() as u64);
170 ping.record_data(bytes.len());
171 return Poll::Ready(Some(Ok(Frame::data(bytes))));
172 }
173 Some(Err(e)) => {
174 if let Some(http2::Reason::NO_ERROR) = e.reason() {
175 return Poll::Ready(None);
179 } else {
180 return Poll::Ready(Some(Err(Error::new_body(e))));
181 }
182 }
183 None => {
184 *data_done = true;
186 }
187 }
188 }
189
190 match ready!(recv.poll_trailers(cx)) {
192 Ok(t) => {
193 ping.record_non_data();
194 Poll::Ready(Ok(t.map(Frame::trailers)).transpose())
195 }
196 Err(e) => {
197 if let Some(http2::Reason::NO_ERROR) = e.reason() {
198 Poll::Ready(None)
202 } else {
203 Poll::Ready(Some(Err(Error::new_h2(e))))
204 }
205 }
206 }
207 }
208 Kind::Empty => Poll::Ready(None),
209 }
210 }
211
212 #[inline]
213 fn is_end_stream(&self) -> bool {
214 match self.kind {
215 Kind::H1 { content_length, .. } => content_length == DecodedLength::ZERO,
216 Kind::H2 { recv: ref h2, .. } => h2.is_end_stream(),
217 Kind::Empty => true,
218 }
219 }
220
221 #[inline]
222 fn size_hint(&self) -> SizeHint {
223 match self.kind {
224 Kind::H1 { content_length, .. } | Kind::H2 { content_length, .. } => content_length
225 .into_opt()
226 .map_or_else(SizeHint::default, SizeHint::with_exact),
227 Kind::Empty => SizeHint::with_exact(0),
228 }
229 }
230}
231
232impl fmt::Debug for Incoming {
233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234 let mut builder = f.debug_tuple(stringify!(Incoming));
235 match self.kind {
236 Kind::Empty => builder.field(&stringify!(Empty)),
237 _ => builder.field(&stringify!(Streaming)),
238 };
239 builder.finish()
240 }
241}
242
243impl Sender {
246 #[inline]
248 pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
249 ready!(self.want_rx.poll_ready(cx)?);
251 self.data_tx
252 .poll_reserve(cx)
253 .map_err(|_| Error::new_closed())
254 }
255
256 #[inline]
268 pub(crate) fn send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> {
269 self.data_tx.send_item(Ok(chunk)).map_err(|err| {
270 err.into_inner()
271 .expect("value returned")
272 .expect("just sent Ok")
273 })
274 }
275
276 #[inline]
283 pub(crate) fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Option<HeaderMap>> {
284 self.trailers_tx
285 .take()
286 .ok_or(None)?
287 .send(trailers)
288 .map_err(Some)
289 }
290
291 #[inline]
293 pub(crate) fn send_error(&mut self, err: Error) {
294 self.data_tx
295 .get_ref()
296 .map(|sender| sender.try_send(Err(err)));
297 }
298}
299
#[cfg(test)]
mod tests {
    use std::{mem, task::Poll};

    use http_body_util::BodyExt;

    use super::{Body, DecodedLength, Error, Incoming, Result, Sender, SizeHint};

    impl Incoming {
        /// Test shorthand: a chunked-length HTTP/1 body whose sender is ready
        /// without waiting for the body to be polled.
        pub(crate) fn channel() -> (Sender, Incoming) {
            Self::h1(DecodedLength::CHUNKED, false)
        }
    }

    impl Sender {
        /// Async wrapper around `poll_ready`.
        async fn ready(&mut self) -> Result<()> {
            std::future::poll_fn(|cx| self.poll_ready(cx)).await
        }

        /// Consumes the sender after pushing a "body write aborted" error.
        fn abort(mut self) {
            self.send_error(Error::new_body_write_aborted());
        }
    }

    #[test]
    fn test_size_of() {
        // Upper bound for the body: six 64-bit words.
        let body_expected_size = 6 * mem::size_of::<u64>();
        let body_size = mem::size_of::<Incoming>();
        assert!(
            body_size <= body_expected_size,
            "Body size = {body_size} <= {body_expected_size}",
        );

        let sender_size = mem::size_of::<Sender>();
        assert_eq!(sender_size, 8 * mem::size_of::<usize>(), "Sender");
        assert_eq!(
            sender_size,
            mem::size_of::<Option<Sender>>(),
            "Option<Sender>"
        );
    }

    #[test]
    fn size_hint() {
        fn check(body: Incoming, expected: SizeHint, note: &str) {
            let actual = body.size_hint();
            assert_eq!(actual.lower(), expected.lower(), "lower for {note:?}");
            assert_eq!(actual.upper(), expected.upper(), "upper for {note:?}");
        }

        check(Incoming::empty(), SizeHint::with_exact(0), "empty");
        check(Incoming::channel().1, SizeHint::new(), "channel");
        check(
            Incoming::h1(DecodedLength::new(4), false).1,
            SizeHint::with_exact(4),
            "channel with length",
        );
    }

    #[tokio::test]
    async fn channel_abort() {
        let (sender, mut body) = Incoming::channel();
        sender.abort();

        let err = body.frame().await.unwrap().unwrap_err();
        assert!(err.is_body_write_aborted(), "{err:?}");
    }

    #[tokio::test]
    async fn channel_abort_when_buffer_is_full() {
        let (mut sender, mut body) = Incoming::channel();

        sender.ready().await.expect("ready");
        sender.send_data("chunk 1".into()).expect("send 1");
        sender.abort();

        // The chunk sent before the abort is still delivered first.
        let first = body.frame().await.expect("item 1").expect("chunk 1");
        let chunk1 = first.into_data().unwrap();
        assert_eq!(chunk1, "chunk 1");

        let err = body.frame().await.unwrap().unwrap_err();
        assert!(err.is_body_write_aborted(), "{err:?}");
    }

    #[tokio::test]
    async fn channel_buffers_two() {
        // `_body` stays bound so the receiver is alive for the whole test.
        let (mut sender, _body) = Incoming::channel();

        for (chunk, label) in [("chunk 1", "send 1"), ("chunk 2", "send 2")] {
            sender.ready().await.expect("ready");
            sender.send_data(chunk.into()).expect(label);
        }

        // Both slots are occupied, so a third reservation must not complete.
        let third = tokio::time::timeout(
            std::time::Duration::from_millis(100),
            std::future::poll_fn(|cx| sender.poll_ready(cx)),
        )
        .await;
        assert!(third.is_err(), "poll_ready unexpectedly became ready");
    }

    #[tokio::test]
    async fn channel_empty() {
        // The sender is dropped immediately, so the body ends without frames.
        let (_, mut body) = Incoming::channel();
        assert!(body.frame().await.is_none());
    }

    #[test]
    fn channel_ready() {
        let (mut sender, _body) = Incoming::h1(DecodedLength::CHUNKED, false);

        let mut ready_task = tokio_test::task::spawn(sender.ready());
        assert!(ready_task.poll().is_ready(), "tx is ready immediately");
    }

    #[test]
    fn channel_wanter() {
        let (mut sender, mut body) = Incoming::h1(DecodedLength::CHUNKED, true);

        let mut ready_task = tokio_test::task::spawn(sender.ready());
        let mut frame_task = tokio_test::task::spawn(body.frame());

        assert!(
            ready_task.poll().is_pending(),
            "tx isn't ready before rx has been polled"
        );
        assert!(frame_task.poll().is_pending(), "poll rx.data");
        assert!(ready_task.is_woken(), "rx poll wakes tx");
        assert!(
            ready_task.poll().is_ready(),
            "tx is ready after rx has been polled"
        );
    }

    #[test]
    fn channel_notices_closure() {
        let (mut sender, body) = Incoming::h1(DecodedLength::CHUNKED, true);

        let mut ready_task = tokio_test::task::spawn(sender.ready());
        assert!(
            ready_task.poll().is_pending(),
            "tx isn't ready before rx has been polled"
        );

        drop(body);
        assert!(ready_task.is_woken(), "dropping rx wakes tx");

        match ready_task.poll() {
            Poll::Ready(Err(ref closed)) if closed.is_closed() => (),
            unexpected => panic!("tx poll ready unexpected: {unexpected:?}"),
        }
    }
}