Skip to main content

tower_http/trace/
body.rs

1use super::{DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, OnBodyChunk, OnEos, OnFailure};
2use crate::classify::ClassifyEos;
3use http_body::{Body, Frame};
4use pin_project_lite::pin_project;
5use std::{
6    fmt,
7    pin::Pin,
8    task::{ready, Context, Poll},
9    time::Instant,
10};
11use tracing::Span;
12
13pin_project! {
14    /// Response body for [`Trace`].
15    ///
16    /// [`Trace`]: super::Trace
17    pub struct ResponseBody<B, C, OnBodyChunk = DefaultOnBodyChunk, OnEos = DefaultOnEos, OnFailure = DefaultOnFailure> {
18        #[pin]
19        pub(crate) inner: B,
20        pub(crate) classify_eos: Option<C>,
21        pub(crate) on_eos: Option<(OnEos, Instant)>,
22        pub(crate) on_body_chunk: OnBodyChunk,
23        pub(crate) on_failure: Option<OnFailure>,
24        pub(crate) start: Instant,
25        pub(crate) span: Span,
26    }
27}
28
29impl<B, C, OnBodyChunkT, OnEosT, OnFailureT> Body
30    for ResponseBody<B, C, OnBodyChunkT, OnEosT, OnFailureT>
31where
32    B: Body,
33    B::Error: fmt::Display + 'static,
34    C: ClassifyEos,
35    OnEosT: OnEos,
36    OnBodyChunkT: OnBodyChunk<B::Data>,
37    OnFailureT: OnFailure<C::FailureClass>,
38{
39    type Data = B::Data;
40    type Error = B::Error;
41
42    fn poll_frame(
43        self: Pin<&mut Self>,
44        cx: &mut Context<'_>,
45    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
46        let this = self.project();
47        let _guard = this.span.enter();
48        let result = ready!(this.inner.poll_frame(cx));
49
50        let latency = this.start.elapsed();
51        *this.start = Instant::now();
52
53        match result {
54            Some(Ok(frame)) => {
55                let frame = match frame.into_data() {
56                    Ok(chunk) => {
57                        this.on_body_chunk.on_body_chunk(&chunk, latency, this.span);
58                        Frame::data(chunk)
59                    }
60                    Err(frame) => frame,
61                };
62
63                let frame = match frame.into_trailers() {
64                    Ok(trailers) => {
65                        if let Some((classify_eos, mut on_failure)) =
66                            this.classify_eos.take().zip(this.on_failure.take())
67                        {
68                            if let Err(failure_class) = classify_eos.classify_eos(Some(&trailers)) {
69                                on_failure.on_failure(failure_class, latency, this.span);
70                            }
71                        }
72                        if let Some((on_eos, stream_start)) = this.on_eos.take() {
73                            on_eos.on_eos(Some(&trailers), stream_start.elapsed(), this.span);
74                        }
75                        Frame::trailers(trailers)
76                    }
77                    Err(frame) => frame,
78                };
79
80                Poll::Ready(Some(Ok(frame)))
81            }
82            Some(Err(err)) => {
83                if let Some((classify_eos, mut on_failure)) =
84                    this.classify_eos.take().zip(this.on_failure.take())
85                {
86                    let failure_class = classify_eos.classify_error(&err);
87                    on_failure.on_failure(failure_class, latency, this.span);
88                }
89
90                Poll::Ready(Some(Err(err)))
91            }
92            None => {
93                if let Some((classify_eos, mut on_failure)) =
94                    this.classify_eos.take().zip(this.on_failure.take())
95                {
96                    if let Err(failure_class) = classify_eos.classify_eos(None) {
97                        on_failure.on_failure(failure_class, latency, this.span);
98                    }
99                }
100                if let Some((on_eos, stream_start)) = this.on_eos.take() {
101                    on_eos.on_eos(None, stream_start.elapsed(), this.span);
102                }
103
104                Poll::Ready(None)
105            }
106        }
107    }
108
109    fn is_end_stream(&self) -> bool {
110        self.inner.is_end_stream()
111    }
112
113    fn size_hint(&self) -> http_body::SizeHint {
114        self.inner.size_hint()
115    }
116}