Skip to main content

tower_http/on_early_drop/
early_drops_as_failures.rs

1//! Adapter that bridges early-drop events to [`trace::OnFailure`].
2//!
3//! [`trace::OnFailure`]: crate::trace::OnFailure
4
5use crate::on_early_drop::failure::{BodyDropped, DroppedFailure, FutureDropped};
6use crate::on_early_drop::traits::{OnBodyDrop, OnDropCallback, OnFutureDrop};
7use crate::trace::OnFailure;
8use http::{response, Request, StatusCode};
9use std::time::Instant;
10use tracing::Span;
11
/// Bridges early-drop events to [`trace::OnFailure`](crate::trace::OnFailure).
///
/// Each event is reported by invoking the wrapped hook with a
/// [`DroppedFailure`]: `Future` for future drops, `Body` for body drops
/// (carrying the emitted response status).
///
/// The start instant used for latency and the span handed to the hook are
/// both captured when the hooks are created at `Service::call`
/// ([`OnFutureDrop::make`] / [`OnBodyDrop::make_at_call`]).
/// [`OnBodyDrop::make_at_response`] only adds the response status to that
/// state. Body-drop latency therefore still counts from `Service::call`, not
/// from the moment the response head became ready.
///
/// The captured span is [`Span::current()`] at call time. To report events
/// against the request span used by [`TraceLayer`](crate::trace::TraceLayer),
/// place [`OnEarlyDropLayer`] inside `TraceLayer`.
///
/// The wrapped hook is cloned for every request, which is why the trait impls
/// require `F: Clone`.
///
/// See the [module docs](super) for the example.
///
/// [`OnEarlyDropLayer`]: super::OnEarlyDropLayer
#[derive(Debug, Clone, Copy)]
pub struct EarlyDropsAsFailures<F> {
    /// Hook invoked (through a per-request clone) for every early drop.
    on_failure: F,
}
31
32impl<F> EarlyDropsAsFailures<F> {
33    /// Wrap an [`OnFailure`] implementation.
34    pub fn new(on_failure: F) -> Self {
35        Self { on_failure }
36    }
37}
38
39/// Future-drop callback produced by [`EarlyDropsAsFailures`].
40pub struct FutureDropFailureCallback<F> {
41    start: Instant,
42    on_failure: F,
43    span: Span,
44}
45
46impl<F> OnDropCallback for FutureDropFailureCallback<F>
47where
48    F: OnFailure<DroppedFailure> + Send + 'static,
49{
50    fn on_drop(mut self) {
51        let latency = self.start.elapsed();
52        let _entered = self.span.enter();
53        self.on_failure
54            .on_failure(DroppedFailure::Future(FutureDropped), latency, &self.span);
55    }
56}
57
58/// Intermediate produced by [`OnBodyDrop::make_at_call`] for
59/// [`EarlyDropsAsFailures`], carrying state forward to
60/// [`OnBodyDrop::make_at_response`].
61pub struct PreResponseBodyDropCallback<F> {
62    start: Instant,
63    on_failure: F,
64    span: Span,
65}
66
67/// Body-drop callback produced by [`EarlyDropsAsFailures`].
68pub struct BodyDropFailureCallback<F> {
69    start: Instant,
70    on_failure: F,
71    span: Span,
72    status: StatusCode,
73}
74
75impl<F> OnDropCallback for BodyDropFailureCallback<F>
76where
77    F: OnFailure<DroppedFailure> + Send + 'static,
78{
79    fn on_drop(mut self) {
80        let latency = self.start.elapsed();
81        let _entered = self.span.enter();
82        self.on_failure.on_failure(
83            DroppedFailure::Body(BodyDropped {
84                status: self.status,
85            }),
86            latency,
87            &self.span,
88        );
89    }
90}
91
92impl<F, ReqB> OnFutureDrop<ReqB> for EarlyDropsAsFailures<F>
93where
94    F: OnFailure<DroppedFailure> + Clone + Send + 'static,
95{
96    type Callback = FutureDropFailureCallback<F>;
97
98    fn make(&mut self, _request: &Request<ReqB>) -> Self::Callback {
99        FutureDropFailureCallback {
100            start: Instant::now(),
101            on_failure: self.on_failure.clone(),
102            span: Span::current(),
103        }
104    }
105}
106
107impl<F, ReqB> OnBodyDrop<ReqB> for EarlyDropsAsFailures<F>
108where
109    F: OnFailure<DroppedFailure> + Clone + Send + 'static,
110{
111    type Intermediate = PreResponseBodyDropCallback<F>;
112    type Callback = BodyDropFailureCallback<F>;
113
114    fn make_at_call(&mut self, _request: &Request<ReqB>) -> Self::Intermediate {
115        PreResponseBodyDropCallback {
116            start: Instant::now(),
117            on_failure: self.on_failure.clone(),
118            span: Span::current(),
119        }
120    }
121
122    fn make_at_response(
123        &mut self,
124        intermediate: Self::Intermediate,
125        response_parts: &response::Parts,
126    ) -> Self::Callback {
127        BodyDropFailureCallback {
128            start: intermediate.start,
129            on_failure: intermediate.on_failure,
130            span: intermediate.span,
131            status: response_parts.status,
132        }
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::on_early_drop::OnEarlyDropLayer;
    use bytes::Bytes;
    use http::{Request, Response, StatusCode};
    use http_body_util::{BodyExt, Full};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tokio::time::{sleep, timeout};
    use tower::{service_fn, Layer, ServiceExt};
    use tracing::Span;

    /// `OnFailure` hook that appends every reported classification to a shared log.
    #[derive(Clone, Default)]
    struct RecordingOnFailure {
        events: Arc<Mutex<Vec<DroppedFailure>>>,
    }

    impl OnFailure<DroppedFailure> for RecordingOnFailure {
        fn on_failure(&mut self, class: DroppedFailure, _latency: Duration, _span: &Span) {
            let mut log = self.events.lock().unwrap();
            log.push(class);
        }
    }

    /// Minimal `GET /` request used by every test.
    fn root_request() -> Request<()> {
        Request::builder().uri("/").body(()).unwrap()
    }

    #[tokio::test]
    async fn future_drop_reports_future_failure() {
        let hook = RecordingOnFailure::default();
        let log = Arc::clone(&hook.events);

        let never_in_time = service_fn(|_req: Request<()>| async move {
            sleep(Duration::from_secs(60)).await;
            let response = Response::builder()
                .status(StatusCode::OK)
                .body(Full::new(Bytes::new()))
                .unwrap();
            Ok::<_, std::convert::Infallible>(response)
        });

        let svc = OnEarlyDropLayer::new(EarlyDropsAsFailures::new(hook)).layer(never_in_time);
        // Give up long before the handler can answer, dropping the in-flight future.
        let _ = timeout(Duration::from_millis(50), svc.oneshot(root_request())).await;

        sleep(Duration::from_millis(10)).await;
        let recorded = log.lock().unwrap();
        assert_eq!(recorded.len(), 1);
        assert!(matches!(recorded[0], DroppedFailure::Future(_)));
    }

    #[tokio::test]
    async fn body_drop_reports_body_failure_with_status() {
        use std::pin::Pin;
        use std::task::{Context, Poll};

        /// Response body that never yields a frame, so it only ends by being dropped.
        struct PendingBody;

        impl http_body::Body for PendingBody {
            type Data = Bytes;
            type Error = std::convert::Infallible;

            fn poll_frame(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
                Poll::Pending
            }

            fn is_end_stream(&self) -> bool {
                false
            }
        }

        let hook = RecordingOnFailure::default();
        let log = Arc::clone(&hook.events);

        let handler = service_fn(|_req: Request<()>| async move {
            let response = Response::builder()
                .status(StatusCode::CREATED)
                .body(PendingBody)
                .unwrap();
            Ok::<_, std::convert::Infallible>(response)
        });

        let svc = OnEarlyDropLayer::new(EarlyDropsAsFailures::new(hook)).layer(handler);
        let response = svc.oneshot(root_request()).await.unwrap();
        // Dropping the response drops its unfinished body, which must fire the hook.
        drop(response);

        let recorded = log.lock().unwrap();
        assert_eq!(recorded.len(), 1);
        match &recorded[0] {
            DroppedFailure::Body(body) => assert_eq!(body.status, StatusCode::CREATED),
            other => panic!("expected Body failure, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn completion_suppresses_both() {
        let hook = RecordingOnFailure::default();
        let log = Arc::clone(&hook.events);

        let handler = service_fn(|_req: Request<()>| async move {
            let response = Response::builder()
                .status(StatusCode::OK)
                .body(Full::new(Bytes::from_static(b"hi")))
                .unwrap();
            Ok::<_, std::convert::Infallible>(response)
        });

        let svc = OnEarlyDropLayer::new(EarlyDropsAsFailures::new(hook)).layer(handler);
        let response = svc.oneshot(root_request()).await.unwrap();
        // Reading the body to the end means neither the future nor the body was abandoned.
        let _body = response.into_body().collect().await.unwrap();

        assert!(log.lock().unwrap().is_empty());
    }
}