tower_http/on_early_drop/
service.rs1use crate::on_early_drop::body::OnEarlyDropBody;
4use crate::on_early_drop::future::OnEarlyDropFuture;
5use crate::on_early_drop::traits::{OnBodyDrop, OnFutureDrop};
6use http::{Request, Response};
7use std::task::{Context, Poll};
8use tower_service::Service;
9
/// Middleware service that wraps an inner service together with two hook
/// values: one tied to the response future and one tied to the response body.
///
/// In the `Service` implementation in this file, `on_future_drop` yields a
/// per-request callback via `make`, and `on_body_drop` yields a per-request
/// value via `make_at_call`; both are handed to `OnEarlyDropFuture` along
/// with the inner service's future.
pub struct OnEarlyDropService<S, OFD, OBD> {
    /// The wrapped service that actually handles requests.
    pub(crate) inner: S,
    /// Hook used to build the callback associated with the response future.
    pub(crate) on_future_drop: OFD,
    /// Hook used to build the callback associated with the response body.
    pub(crate) on_body_drop: OBD,
}
20
21impl<S, OFD, OBD> std::fmt::Debug for OnEarlyDropService<S, OFD, OBD>
22where
23 S: std::fmt::Debug,
24{
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("OnEarlyDropService")
27 .field("inner", &self.inner)
28 .field("on_future_drop", &format_args!(".."))
29 .field("on_body_drop", &format_args!(".."))
30 .finish()
31 }
32}
33
34impl<S, OFD, OBD> Clone for OnEarlyDropService<S, OFD, OBD>
35where
36 S: Clone,
37 OFD: Clone,
38 OBD: Clone,
39{
40 fn clone(&self) -> Self {
41 Self {
42 inner: self.inner.clone(),
43 on_future_drop: self.on_future_drop.clone(),
44 on_body_drop: self.on_body_drop.clone(),
45 }
46 }
47}
48
49impl<S, OFD, OBD> OnEarlyDropService<S, OFD, OBD> {
50 pub fn new(inner: S, on_future_drop: OFD, on_body_drop: OBD) -> Self {
53 Self {
54 inner,
55 on_future_drop,
56 on_body_drop,
57 }
58 }
59
60 define_inner_service_accessors!();
61}
62
63impl<S, OFD, OBD, ReqB, ResB> Service<Request<ReqB>> for OnEarlyDropService<S, OFD, OBD>
64where
65 S: Service<Request<ReqB>, Response = Response<ResB>>,
66 OFD: OnFutureDrop<ReqB>,
67 OBD: OnBodyDrop<ReqB> + Clone,
68 ResB: http_body::Body,
69{
70 type Response = Response<OnEarlyDropBody<ResB, OBD::Callback>>;
71 type Error = S::Error;
72 type Future = OnEarlyDropFuture<S::Future, OBD, ReqB, OFD::Callback, OBD::Callback>;
73
74 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75 self.inner.poll_ready(cx)
76 }
77
78 fn call(&mut self, req: Request<ReqB>) -> Self::Future {
79 let future_callback = self.on_future_drop.make(&req);
80 let intermediate = self.on_body_drop.make_at_call(&req);
81 let inner = self.inner.call(req);
82 OnEarlyDropFuture::new(
83 inner,
84 future_callback,
85 self.on_body_drop.clone(),
86 intermediate,
87 )
88 }
89}
90
91#[cfg(test)]
92mod tests {
93 use super::*;
94 use crate::on_early_drop::{OnBodyDropFn, OnEarlyDropLayer};
95 use bytes::Bytes;
96 use http::{Request, Response, StatusCode};
97 use http_body_util::{BodyExt, Full};
98 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
99 use std::sync::Arc;
100 use std::time::Duration;
101 use tokio::time::{sleep, timeout};
102 use tower::{service_fn, Layer, ServiceExt};
103
104 fn ok_service() -> impl Service<
105 Request<()>,
106 Response = Response<Full<Bytes>>,
107 Error = std::convert::Infallible,
108 Future = impl std::future::Future<
109 Output = Result<Response<Full<Bytes>>, std::convert::Infallible>,
110 > + Send,
111 > + Clone {
112 service_fn(|_req: Request<()>| async move {
113 Ok::<_, std::convert::Infallible>(
114 Response::builder()
115 .status(StatusCode::OK)
116 .body(Full::new(Bytes::from_static(b"hello")))
117 .unwrap(),
118 )
119 })
120 }
121
122 fn request() -> Request<()> {
123 Request::builder().uri("http://example/").body(()).unwrap()
124 }
125
126 #[tokio::test]
127 async fn forwards_response() {
128 let layer = OnEarlyDropLayer::builder();
129 let service = layer.layer(ok_service());
130 let response = service.oneshot(request()).await.unwrap();
131 assert_eq!(response.status(), StatusCode::OK);
132 let body = response.into_body().collect().await.unwrap().to_bytes();
133 assert_eq!(body, "hello");
134 }
135
136 #[tokio::test]
137 async fn future_drop_fires_callback() {
138 let fired = Arc::new(AtomicUsize::new(0));
139 let fired_clone = fired.clone();
140
141 let slow_service = service_fn(|_req: Request<()>| async move {
142 sleep(Duration::from_secs(60)).await;
143 Ok::<_, std::convert::Infallible>(
144 Response::builder()
145 .status(StatusCode::OK)
146 .body(Full::new(Bytes::new()))
147 .unwrap(),
148 )
149 });
150
151 let layer = OnEarlyDropLayer::builder().on_future_drop(move |_req: &Request<()>| {
152 let fired = fired_clone.clone();
153 move || {
154 fired.fetch_add(1, Ordering::Relaxed);
155 }
156 });
157 let service = layer.layer(slow_service);
158 let _ = timeout(Duration::from_millis(50), service.oneshot(request())).await;
159
160 sleep(Duration::from_millis(10)).await;
161 assert_eq!(fired.load(Ordering::Relaxed), 1);
162 }
163
164 #[tokio::test]
165 async fn future_drop_suppressed_on_completion() {
166 let fired = Arc::new(AtomicUsize::new(0));
167 let fired_clone = fired.clone();
168
169 let layer = OnEarlyDropLayer::builder().on_future_drop(move |_req: &Request<()>| {
170 let fired = fired_clone.clone();
171 move || {
172 fired.fetch_add(1, Ordering::Relaxed);
173 }
174 });
175 let service = layer.layer(ok_service());
176 let _ = service.oneshot(request()).await.unwrap();
177
178 assert_eq!(fired.load(Ordering::Relaxed), 0);
179 }
180
181 #[tokio::test]
182 async fn body_drop_fires_callback_with_status() {
183 let observed_status = Arc::new(std::sync::Mutex::new(None));
184 let observed_clone = observed_status.clone();
185
186 struct PendingBody;
188 impl http_body::Body for PendingBody {
189 type Data = Bytes;
190 type Error = std::convert::Infallible;
191 fn poll_frame(
192 self: std::pin::Pin<&mut Self>,
193 _cx: &mut std::task::Context<'_>,
194 ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>>
195 {
196 std::task::Poll::Pending
197 }
198 fn is_end_stream(&self) -> bool {
199 false
200 }
201 }
202
203 let pending_service = service_fn(|_req: Request<()>| async move {
204 Ok::<_, std::convert::Infallible>(
205 Response::builder()
206 .status(StatusCode::CREATED)
207 .body(PendingBody)
208 .unwrap(),
209 )
210 });
211
212 let layer = OnEarlyDropLayer::builder().on_body_drop(OnBodyDropFn::new(
213 move |_req: &Request<()>| {
214 let observed = observed_clone.clone();
215 move |parts: &http::response::Parts| {
216 let status = parts.status;
217 move || {
218 *observed.lock().unwrap() = Some(status);
219 }
220 }
221 },
222 ));
223 let service = layer.layer(pending_service);
224 let response = service.oneshot(request()).await.unwrap();
225 assert_eq!(response.status(), StatusCode::CREATED);
226 drop(response);
227
228 assert_eq!(
229 *observed_status.lock().unwrap(),
230 Some(StatusCode::CREATED),
231 "body-drop callback should observe the response status",
232 );
233 }
234
235 #[tokio::test]
236 async fn body_drop_suppressed_when_body_consumed() {
237 let fired = Arc::new(AtomicBool::new(false));
238 let fired_clone = fired.clone();
239
240 let layer = OnEarlyDropLayer::builder().on_body_drop(OnBodyDropFn::new(
241 move |_req: &Request<()>| {
242 let fired = fired_clone.clone();
243 move |_parts: &http::response::Parts| {
244 let fired = fired.clone();
245 move || {
246 fired.store(true, Ordering::Relaxed);
247 }
248 }
249 },
250 ));
251 let service = layer.layer(ok_service());
252 let response = service.oneshot(request()).await.unwrap();
253 let _body = response.into_body().collect().await.unwrap();
254
255 assert!(!fired.load(Ordering::Relaxed));
256 }
257
258 #[tokio::test]
259 async fn inner_error_does_not_fire() {
260 let fired = Arc::new(AtomicBool::new(false));
261 let fired_clone = fired.clone();
262
263 let err_service = service_fn(|_req: Request<()>| async move {
264 Err::<Response<Full<Bytes>>, _>(std::io::Error::other("boom"))
265 });
266
267 let layer = OnEarlyDropLayer::builder().on_future_drop(move |_req: &Request<()>| {
268 let fired = fired_clone.clone();
269 move || {
270 fired.store(true, Ordering::Relaxed);
271 }
272 });
273 let service = layer.layer(err_service);
274 let _ = service.oneshot(request()).await;
275
276 assert!(!fired.load(Ordering::Relaxed));
277 }
278
279 #[tokio::test]
280 async fn body_error_frame_does_not_fire() {
281 let fired = Arc::new(AtomicBool::new(false));
282 let fired_clone = fired.clone();
283
284 struct ErrBody {
286 yielded: bool,
287 }
288 impl http_body::Body for ErrBody {
289 type Data = Bytes;
290 type Error = std::io::Error;
291 fn poll_frame(
292 mut self: std::pin::Pin<&mut Self>,
293 _cx: &mut std::task::Context<'_>,
294 ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>>
295 {
296 if self.yielded {
297 std::task::Poll::Ready(None)
298 } else {
299 self.yielded = true;
300 std::task::Poll::Ready(Some(Err(std::io::Error::other("frame err"))))
301 }
302 }
303 fn is_end_stream(&self) -> bool {
304 false
305 }
306 }
307
308 let err_body_service = service_fn(|_req: Request<()>| async move {
309 Ok::<_, std::convert::Infallible>(
310 Response::builder()
311 .status(StatusCode::OK)
312 .body(ErrBody { yielded: false })
313 .unwrap(),
314 )
315 });
316
317 let layer = OnEarlyDropLayer::builder().on_body_drop(OnBodyDropFn::new(
318 move |_req: &Request<()>| {
319 let fired = fired_clone.clone();
320 move |_parts: &http::response::Parts| {
321 let fired = fired.clone();
322 move || {
323 fired.store(true, Ordering::Relaxed);
324 }
325 }
326 },
327 ));
328 let service = layer.layer(err_body_service);
329 let response = service.oneshot(request()).await.unwrap();
330 let mut body = response.into_body();
332 use http_body::Body as _;
333 let frame = std::future::poll_fn(|cx| std::pin::Pin::new(&mut body).poll_frame(cx)).await;
334 assert!(matches!(frame, Some(Err(_))));
335 drop(body);
336
337 assert!(
338 !fired.load(Ordering::Relaxed),
339 "body-level error must not be reported as a body drop",
340 );
341 }
342
343 #[allow(dead_code)]
346 fn static_property_hooks_without_debug() {
347 fn hook_without_debug<F>(f: F) -> F {
348 f
349 }
350 let _layer = OnEarlyDropLayer::builder()
351 .on_future_drop(hook_without_debug(|_req: &Request<()>| || {}))
352 .on_body_drop(OnBodyDropFn::new(hook_without_debug(
353 |_req: &Request<()>| |_parts: &http::response::Parts| || {},
354 )));
355 }
356
357 #[allow(dead_code)]
360 fn static_property_service_is_send_sync() {
361 fn assert_send<T: Send>(_: &T) {}
362 fn assert_sync<T: Sync>(_: &T) {}
363 fn assert_clone<T: Clone>(_: &T) {}
364
365 let layer = OnEarlyDropLayer::builder();
366 let service = layer.layer(ok_service());
367 assert_send(&service);
368 assert_sync(&service);
369 assert_clone(&service);
370 }
371
372 #[tokio::test]
373 async fn body_drop_suppressed_when_is_end_stream_at_construction() {
374 let fired = Arc::new(AtomicBool::new(false));
375 let fired_clone = fired.clone();
376
377 let empty_service = service_fn(|_req: Request<()>| async move {
380 Ok::<_, std::convert::Infallible>(
381 Response::builder()
382 .status(StatusCode::NO_CONTENT)
383 .body(http_body_util::Empty::<Bytes>::new())
384 .unwrap(),
385 )
386 });
387
388 let layer = OnEarlyDropLayer::builder().on_body_drop(OnBodyDropFn::new(
389 move |_req: &Request<()>| {
390 let fired = fired_clone.clone();
391 move |_parts: &http::response::Parts| {
392 let fired = fired.clone();
393 move || {
394 fired.store(true, Ordering::Relaxed);
395 }
396 }
397 },
398 ));
399 let service = layer.layer(empty_service);
400 let response = service.oneshot(request()).await.unwrap();
401 drop(response);
403
404 assert!(
405 !fired.load(Ordering::Relaxed),
406 "body already at end-of-stream at construction must not fire the callback",
407 );
408 }
409
410 #[tokio::test]
411 async fn body_drop_does_not_fire_on_inner_error() {
412 let fired = Arc::new(AtomicBool::new(false));
413 let fired_clone = fired.clone();
414
415 let err_service = service_fn(|_req: Request<()>| async move {
416 Err::<Response<Full<Bytes>>, _>(std::io::Error::other("boom"))
417 });
418
419 let layer = OnEarlyDropLayer::builder().on_body_drop(OnBodyDropFn::new(
420 move |_req: &Request<()>| {
421 let fired = fired_clone.clone();
422 move |_parts: &http::response::Parts| {
423 let fired = fired.clone();
424 move || {
425 fired.store(true, Ordering::Relaxed);
426 }
427 }
428 },
429 ));
430 let service = layer.layer(err_service);
431 let _ = service.oneshot(request()).await;
432
433 assert!(!fired.load(Ordering::Relaxed));
434 }
435
436 #[tokio::test]
437 async fn noop_slots_do_not_fire() {
438 let layer = OnEarlyDropLayer::builder();
442 let service = layer.layer(ok_service());
443 let response = service.oneshot(request()).await.unwrap();
444 drop(response);
446 }
449}