tower_async_http/trace/
body.rs

1use super::{OnBodyChunk, OnEos, OnFailure};
2use crate::classify::ClassifyEos;
3use futures_core::ready;
4use http_body::{Body, Frame};
5use pin_project_lite::pin_project;
6use std::{
7    fmt,
8    pin::Pin,
9    task::{Context, Poll},
10    time::Instant,
11};
12use tracing::Span;
13
14pin_project! {
15    /// Response body for [`Trace`].
16    ///
17    /// [`Trace`]: super::Trace
18    pub struct ResponseBody<B, C, OnBodyChunk, OnEos, OnFailure> {
19        #[pin]
20        pub(crate) inner: B,
21        pub(crate) classify_eos: Option<C>,
22        pub(crate) on_eos: Option<(OnEos, Instant)>,
23        pub(crate) on_body_chunk: OnBodyChunk,
24        pub(crate) on_failure: Option<OnFailure>,
25        pub(crate) start: Instant,
26        pub(crate) span: Span,
27    }
28}
29
30impl<B, C, OnBodyChunkT, OnEosT, OnFailureT> Body
31    for ResponseBody<B, C, OnBodyChunkT, OnEosT, OnFailureT>
32where
33    B: Body,
34    B::Error: fmt::Display,
35    C: ClassifyEos,
36    OnEosT: OnEos,
37    OnBodyChunkT: OnBodyChunk<B::Data>,
38    OnFailureT: OnFailure<C::FailureClass>,
39{
40    type Data = B::Data;
41    type Error = B::Error;
42
43    fn poll_frame(
44        self: Pin<&mut Self>,
45        cx: &mut Context<'_>,
46    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
47        let this = self.project();
48        let _guard = this.span.enter();
49        let result = ready!(this.inner.poll_frame(cx));
50
51        let latency = this.start.elapsed();
52        *this.start = Instant::now();
53
54        match result {
55            Some(Ok(frame)) => {
56                let frame = match frame.into_data() {
57                    Ok(chunk) => {
58                        this.on_body_chunk.on_body_chunk(&chunk, latency, this.span);
59                        Frame::data(chunk)
60                    }
61                    Err(frame) => frame,
62                };
63
64                let frame = match frame.into_trailers() {
65                    Ok(trailers) => {
66                        if let Some((on_eos, stream_start)) = this.on_eos.take() {
67                            on_eos.on_eos(Some(&trailers), stream_start.elapsed(), this.span);
68                        }
69                        Frame::trailers(trailers)
70                    }
71                    Err(frame) => frame,
72                };
73
74                Poll::Ready(Some(Ok(frame)))
75            }
76            Some(Err(err)) => {
77                if let Some((classify_eos, on_failure)) =
78                    this.classify_eos.take().zip(this.on_failure.take())
79                {
80                    let failure_class = classify_eos.classify_error(&err);
81                    on_failure.on_failure(failure_class, latency, this.span);
82                }
83
84                Poll::Ready(Some(Err(err)))
85            }
86            None => {
87                if let Some((on_eos, stream_start)) = this.on_eos.take() {
88                    on_eos.on_eos(None, stream_start.elapsed(), this.span);
89                }
90
91                Poll::Ready(None)
92            }
93        }
94    }
95
96    fn is_end_stream(&self) -> bool {
97        self.inner.is_end_stream()
98    }
99
100    fn size_hint(&self) -> http_body::SizeHint {
101        self.inner.size_hint()
102    }
103}