tower_http/on_early_drop/
early_drops_as_failures.rs1use crate::on_early_drop::failure::{BodyDropped, DroppedFailure, FutureDropped};
6use crate::on_early_drop::traits::{OnBodyDrop, OnDropCallback, OnFutureDrop};
7use crate::trace::OnFailure;
8use http::{response, Request, StatusCode};
9use std::time::Instant;
10use tracing::Span;
11
/// Adapter that reports early drops through an `OnFailure` callback.
///
/// The wrapped `on_failure` value is cloned into every drop callback this type
/// builds through its `OnFutureDrop` and `OnBodyDrop` implementations; those
/// callbacks hand it a `DroppedFailure` describing what was dropped.
#[derive(Debug, Clone, Copy)]
pub struct EarlyDropsAsFailures<F> {
    /// Failure callback that is cloned into each drop callback.
    on_failure: F,
}
31
32impl<F> EarlyDropsAsFailures<F> {
33 pub fn new(on_failure: F) -> Self {
35 Self { on_failure }
36 }
37}
38
39pub struct FutureDropFailureCallback<F> {
41 start: Instant,
42 on_failure: F,
43 span: Span,
44}
45
46impl<F> OnDropCallback for FutureDropFailureCallback<F>
47where
48 F: OnFailure<DroppedFailure> + Send + 'static,
49{
50 fn on_drop(mut self) {
51 let latency = self.start.elapsed();
52 let _entered = self.span.enter();
53 self.on_failure
54 .on_failure(DroppedFailure::Future(FutureDropped), latency, &self.span);
55 }
56}
57
58pub struct PreResponseBodyDropCallback<F> {
62 start: Instant,
63 on_failure: F,
64 span: Span,
65}
66
67pub struct BodyDropFailureCallback<F> {
69 start: Instant,
70 on_failure: F,
71 span: Span,
72 status: StatusCode,
73}
74
75impl<F> OnDropCallback for BodyDropFailureCallback<F>
76where
77 F: OnFailure<DroppedFailure> + Send + 'static,
78{
79 fn on_drop(mut self) {
80 let latency = self.start.elapsed();
81 let _entered = self.span.enter();
82 self.on_failure.on_failure(
83 DroppedFailure::Body(BodyDropped {
84 status: self.status,
85 }),
86 latency,
87 &self.span,
88 );
89 }
90}
91
92impl<F, ReqB> OnFutureDrop<ReqB> for EarlyDropsAsFailures<F>
93where
94 F: OnFailure<DroppedFailure> + Clone + Send + 'static,
95{
96 type Callback = FutureDropFailureCallback<F>;
97
98 fn make(&mut self, _request: &Request<ReqB>) -> Self::Callback {
99 FutureDropFailureCallback {
100 start: Instant::now(),
101 on_failure: self.on_failure.clone(),
102 span: Span::current(),
103 }
104 }
105}
106
107impl<F, ReqB> OnBodyDrop<ReqB> for EarlyDropsAsFailures<F>
108where
109 F: OnFailure<DroppedFailure> + Clone + Send + 'static,
110{
111 type Intermediate = PreResponseBodyDropCallback<F>;
112 type Callback = BodyDropFailureCallback<F>;
113
114 fn make_at_call(&mut self, _request: &Request<ReqB>) -> Self::Intermediate {
115 PreResponseBodyDropCallback {
116 start: Instant::now(),
117 on_failure: self.on_failure.clone(),
118 span: Span::current(),
119 }
120 }
121
122 fn make_at_response(
123 &mut self,
124 intermediate: Self::Intermediate,
125 response_parts: &response::Parts,
126 ) -> Self::Callback {
127 BodyDropFailureCallback {
128 start: intermediate.start,
129 on_failure: intermediate.on_failure,
130 span: intermediate.span,
131 status: response_parts.status,
132 }
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::on_early_drop::OnEarlyDropLayer;
    use bytes::Bytes;
    use http::{Request, Response, StatusCode};
    use http_body_util::{BodyExt, Full};
    use std::convert::Infallible;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};
    use std::task::{Context, Poll};
    use std::time::Duration;
    use tokio::time::{sleep, timeout};
    use tower::{service_fn, Layer, ServiceExt};
    use tracing::Span;

    /// Shared log of the failures a `RecordingOnFailure` has been handed.
    type Events = Arc<Mutex<Vec<DroppedFailure>>>;

    /// `OnFailure` implementation that records every classification it receives.
    #[derive(Clone, Default)]
    struct RecordingOnFailure {
        events: Events,
    }

    impl OnFailure<DroppedFailure> for RecordingOnFailure {
        fn on_failure(&mut self, class: DroppedFailure, _latency: Duration, _span: &Span) {
            self.events.lock().unwrap().push(class);
        }
    }

    /// Response body that never yields a frame and never reports end-of-stream.
    struct PendingBody;

    impl http_body::Body for PendingBody {
        type Data = Bytes;
        type Error = Infallible;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            Poll::Pending
        }

        fn is_end_stream(&self) -> bool {
            false
        }
    }

    /// Returns a fresh recorder together with a handle to its event log.
    fn recorder_with_events() -> (RecordingOnFailure, Events) {
        let recorder = RecordingOnFailure::default();
        let events = recorder.events.clone();
        (recorder, events)
    }

    /// Builds the empty-bodied request to `/` used by every test.
    fn root_request() -> Request<()> {
        Request::builder().uri("/").body(()).unwrap()
    }

    #[tokio::test]
    async fn future_drop_reports_future_failure() {
        let (recorder, events) = recorder_with_events();

        // The service sleeps far longer than the timeout below, so its future
        // is dropped before it can produce a response.
        let slow_service = service_fn(|_req: Request<()>| async move {
            sleep(Duration::from_secs(60)).await;
            let response = Response::builder()
                .status(StatusCode::OK)
                .body(Full::new(Bytes::new()))
                .unwrap();
            Ok::<_, Infallible>(response)
        });

        let layer = OnEarlyDropLayer::new(EarlyDropsAsFailures::new(recorder));
        let service = layer.layer(slow_service);
        let call = service.oneshot(root_request());
        let _ = timeout(Duration::from_millis(50), call).await;

        sleep(Duration::from_millis(10)).await;
        let captured = events.lock().unwrap();
        assert_eq!(captured.len(), 1);
        assert!(matches!(captured[0], DroppedFailure::Future(_)));
    }

    #[tokio::test]
    async fn body_drop_reports_body_failure_with_status() {
        let (recorder, events) = recorder_with_events();

        let service = service_fn(|_req: Request<()>| async move {
            let response = Response::builder()
                .status(StatusCode::CREATED)
                .body(PendingBody)
                .unwrap();
            Ok::<_, Infallible>(response)
        });

        let layer = OnEarlyDropLayer::new(EarlyDropsAsFailures::new(recorder));
        let service = layer.layer(service);
        let response = service.oneshot(root_request()).await.unwrap();
        // Dropping the response discards a body that never reached end-of-stream.
        drop(response);

        let captured = events.lock().unwrap();
        assert_eq!(captured.len(), 1);
        match &captured[0] {
            DroppedFailure::Body(body) => assert_eq!(body.status, StatusCode::CREATED),
            other => panic!("expected Body failure, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn completion_suppresses_both() {
        let (recorder, events) = recorder_with_events();

        let ok_service = service_fn(|_req: Request<()>| async move {
            let response = Response::builder()
                .status(StatusCode::OK)
                .body(Full::new(Bytes::from_static(b"hi")))
                .unwrap();
            Ok::<_, Infallible>(response)
        });

        let layer = OnEarlyDropLayer::new(EarlyDropsAsFailures::new(recorder));
        let service = layer.layer(ok_service);
        let response = service.oneshot(root_request()).await.unwrap();
        // Reading the body to completion must not register any failure.
        let _body = response.into_body().collect().await.unwrap();

        assert!(events.lock().unwrap().is_empty());
    }
}