Skip to main content

tower_http/on_early_drop/
service.rs

1//! Service implementation for the on-early-drop middleware.
2
3use crate::on_early_drop::body::OnEarlyDropBody;
4use crate::on_early_drop::future::OnEarlyDropFuture;
5use crate::on_early_drop::traits::{OnBodyDrop, OnFutureDrop};
6use http::{Request, Response};
7use std::task::{Context, Poll};
8use tower_service::Service;
9
/// [`Service`] produced by [`OnEarlyDropLayer`].
///
/// See the [module docs](super) for details and examples.
///
/// `Clone` is derived, so the service is cloneable exactly when the inner
/// service and both hook slots are cloneable.
///
/// [`OnEarlyDropLayer`]: super::OnEarlyDropLayer
#[derive(Clone)]
pub struct OnEarlyDropService<S, OFD, OBD> {
    /// The wrapped service that actually handles each request.
    pub(crate) inner: S,
    /// Hook consulted on every call (via `make(&req)`) to build the callback
    /// handed to the response future.
    pub(crate) on_future_drop: OFD,
    /// Hook consulted on every call (via `make_at_call(&req)`); a clone of it
    /// is also moved into each response future.
    pub(crate) on_body_drop: OBD,
}

impl<S, OFD, OBD> std::fmt::Debug for OnEarlyDropService<S, OFD, OBD>
where
    S: std::fmt::Debug,
{
    /// Shows the inner service. Both hooks are printed as `..` because their
    /// types are not required to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("OnEarlyDropService");
        dbg.field("inner", &self.inner);
        // Opaque placeholders: hook types are typically closures without `Debug`.
        dbg.field("on_future_drop", &format_args!(".."));
        dbg.field("on_body_drop", &format_args!(".."));
        dbg.finish()
    }
}
48
49impl<S, OFD, OBD> OnEarlyDropService<S, OFD, OBD> {
50    /// Construct a new service directly. Most uses go through
51    /// [`OnEarlyDropLayer`](super::OnEarlyDropLayer).
52    pub fn new(inner: S, on_future_drop: OFD, on_body_drop: OBD) -> Self {
53        Self {
54            inner,
55            on_future_drop,
56            on_body_drop,
57        }
58    }
59
60    define_inner_service_accessors!();
61}
62
63impl<S, OFD, OBD, ReqB, ResB> Service<Request<ReqB>> for OnEarlyDropService<S, OFD, OBD>
64where
65    S: Service<Request<ReqB>, Response = Response<ResB>>,
66    OFD: OnFutureDrop<ReqB>,
67    OBD: OnBodyDrop<ReqB> + Clone,
68    ResB: http_body::Body,
69{
70    type Response = Response<OnEarlyDropBody<ResB, OBD::Callback>>;
71    type Error = S::Error;
72    type Future = OnEarlyDropFuture<S::Future, OBD, ReqB, OFD::Callback, OBD::Callback>;
73
74    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75        self.inner.poll_ready(cx)
76    }
77
78    fn call(&mut self, req: Request<ReqB>) -> Self::Future {
79        let future_callback = self.on_future_drop.make(&req);
80        let intermediate = self.on_body_drop.make_at_call(&req);
81        let inner = self.inner.call(req);
82        OnEarlyDropFuture::new(
83            inner,
84            future_callback,
85            self.on_body_drop.clone(),
86            intermediate,
87        )
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::on_early_drop::{OnBodyDropFn, OnEarlyDropLayer};
    use bytes::Bytes;
    use http::{Request, Response, StatusCode};
    use http_body_util::{BodyExt, Full};
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::time::{sleep, timeout};
    use tower::{service_fn, Layer, ServiceExt};

    /// Inner service that always answers `200 OK` with the body `hello`.
    fn ok_service() -> impl Service<
        Request<()>,
        Response = Response<Full<Bytes>>,
        Error = std::convert::Infallible,
        Future = impl std::future::Future<
            Output = Result<Response<Full<Bytes>>, std::convert::Infallible>,
        > + Send,
    > + Clone {
        service_fn(|_req: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    .body(Full::new(Bytes::from_static(b"hello")))
                    .unwrap(),
            )
        })
    }

    /// A minimal request to `http://example/` with an empty `()` body.
    fn request() -> Request<()> {
        Request::builder().uri("http://example/").body(()).unwrap()
    }

    #[tokio::test]
    async fn forwards_response() {
        let layer = OnEarlyDropLayer::builder();
        let service = layer.layer(ok_service());
        let response = service.oneshot(request()).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, "hello");
    }

    #[tokio::test]
    async fn future_drop_fires_callback() {
        let fired = Arc::new(AtomicUsize::new(0));
        let fired_clone = fired.clone();

        let slow_service = service_fn(|_req: Request<()>| async move {
            sleep(Duration::from_secs(60)).await;
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    .body(Full::new(Bytes::new()))
                    .unwrap(),
            )
        });

        let layer = OnEarlyDropLayer::builder().on_future_drop(move |_req: &Request<()>| {
            let fired = fired_clone.clone();
            move || {
                fired.fetch_add(1, Ordering::Relaxed);
            }
        });
        let service = layer.layer(slow_service);
        let _ = timeout(Duration::from_millis(50), service.oneshot(request())).await;

        // Give the runtime a moment before checking, in case the callback is
        // observed only after the drop returns.
        sleep(Duration::from_millis(10)).await;
        assert_eq!(fired.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn future_drop_suppressed_on_completion() {
        let fired = Arc::new(AtomicUsize::new(0));
        let fired_clone = fired.clone();

        let layer = OnEarlyDropLayer::builder().on_future_drop(move |_req: &Request<()>| {
            let fired = fired_clone.clone();
            move || {
                fired.fetch_add(1, Ordering::Relaxed);
            }
        });
        let service = layer.layer(ok_service());
        let _ = service.oneshot(request()).await.unwrap();

        assert_eq!(fired.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn body_drop_fires_callback_with_status() {
        let observed_status = Arc::new(std::sync::Mutex::new(None));
        let observed_clone = observed_status.clone();

        // Body that never reaches end-of-stream.
        struct PendingBody;
        impl http_body::Body for PendingBody {
            type Data = Bytes;
            type Error = std::convert::Infallible;
            fn poll_frame(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>>
            {
                std::task::Poll::Pending
            }
            fn is_end_stream(&self) -> bool {
                false
            }
        }

        let pending_service = service_fn(|_req: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::CREATED)
                    .body(PendingBody)
                    .unwrap(),
            )
        });

        let layer = OnEarlyDropLayer::builder().on_body_drop(OnBodyDropFn::new(
            move |_req: &Request<()>| {
                let observed = observed_clone.clone();
                move |parts: &http::response::Parts| {
                    let status = parts.status;
                    move || {
                        *observed.lock().unwrap() = Some(status);
                    }
                }
            },
        ));
        let service = layer.layer(pending_service);
        let response = service.oneshot(request()).await.unwrap();
        assert_eq!(response.status(), StatusCode::CREATED);
        drop(response);

        assert_eq!(
            *observed_status.lock().unwrap(),
            Some(StatusCode::CREATED),
            "body-drop callback should observe the response status",
        );
    }

    #[tokio::test]
    async fn body_drop_suppressed_when_body_consumed() {
        let fired = Arc::new(AtomicBool::new(false));
        let fired_clone = fired.clone();

        let layer = OnEarlyDropLayer::builder().on_body_drop(OnBodyDropFn::new(
            move |_req: &Request<()>| {
                let fired = fired_clone.clone();
                move |_parts: &http::response::Parts| {
                    let fired = fired.clone();
                    move || {
                        fired.store(true, Ordering::Relaxed);
                    }
                }
            },
        ));
        let service = layer.layer(ok_service());
        let response = service.oneshot(request()).await.unwrap();
        let _body = response.into_body().collect().await.unwrap();

        assert!(!fired.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn inner_error_does_not_fire() {
        let fired = Arc::new(AtomicBool::new(false));
        let fired_clone = fired.clone();

        let err_service = service_fn(|_req: Request<()>| async move {
            Err::<Response<Full<Bytes>>, _>(std::io::Error::other("boom"))
        });

        let layer = OnEarlyDropLayer::builder().on_future_drop(move |_req: &Request<()>| {
            let fired = fired_clone.clone();
            move || {
                fired.store(true, Ordering::Relaxed);
            }
        });
        let service = layer.layer(err_service);
        let _ = service.oneshot(request()).await;

        assert!(!fired.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn body_error_frame_does_not_fire() {
        let fired = Arc::new(AtomicBool::new(false));
        let fired_clone = fired.clone();

        // Body that returns Err once, then is dropped.
        struct ErrBody {
            yielded: bool,
        }
        impl http_body::Body for ErrBody {
            type Data = Bytes;
            type Error = std::io::Error;
            fn poll_frame(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>>
            {
                if self.yielded {
                    std::task::Poll::Ready(None)
                } else {
                    self.yielded = true;
                    std::task::Poll::Ready(Some(Err(std::io::Error::other("frame err"))))
                }
            }
            fn is_end_stream(&self) -> bool {
                false
            }
        }

        let err_body_service = service_fn(|_req: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    .body(ErrBody { yielded: false })
                    .unwrap(),
            )
        });

        let layer = OnEarlyDropLayer::builder().on_body_drop(OnBodyDropFn::new(
            move |_req: &Request<()>| {
                let fired = fired_clone.clone();
                move |_parts: &http::response::Parts| {
                    let fired = fired.clone();
                    move || {
                        fired.store(true, Ordering::Relaxed);
                    }
                }
            },
        ));
        let service = layer.layer(err_body_service);
        let response = service.oneshot(request()).await.unwrap();
        // Poll the body until it surfaces the Err frame, then drop.
        let mut body = response.into_body();
        use http_body::Body as _;
        let frame = std::future::poll_fn(|cx| std::pin::Pin::new(&mut body).poll_frame(cx)).await;
        assert!(matches!(frame, Some(Err(_))));
        drop(body);

        assert!(
            !fired.load(Ordering::Relaxed),
            "body-level error must not be reported as a body drop",
        );
    }

    // The service's trait bounds must not require hook types to be `Debug`;
    // non-Debug closures must produce a service that still compiles.
    // Compile-time check only: never called.
    #[allow(dead_code)]
    fn static_property_hooks_without_debug() {
        // Closures never implement `Debug`; the wrapper only makes that intent
        // explicit at the call sites below.
        fn hook_without_debug<F>(f: F) -> F {
            f
        }
        fn assert_service<T: Service<Request<()>>>(_: &T) {}

        let layer = OnEarlyDropLayer::builder()
            .on_future_drop(hook_without_debug(|_req: &Request<()>| || {}))
            .on_body_drop(OnBodyDropFn::new(hook_without_debug(
                |_req: &Request<()>| |_parts: &http::response::Parts| || {},
            )));
        // Building the layer alone never touches the `Service` impl. Applying it
        // and requiring `Service` makes the compiler check that impl's bounds
        // against the non-Debug hooks.
        let service = layer.layer(ok_service());
        assert_service(&service);
    }

    // The service must be Send + Sync + Clone whenever the underlying
    // hooks and inner service are.
    #[allow(dead_code)]
    fn static_property_service_is_send_sync() {
        fn assert_send<T: Send>(_: &T) {}
        fn assert_sync<T: Sync>(_: &T) {}
        fn assert_clone<T: Clone>(_: &T) {}

        let layer = OnEarlyDropLayer::builder();
        let service = layer.layer(ok_service());
        assert_send(&service);
        assert_sync(&service);
        assert_clone(&service);
    }

    #[tokio::test]
    async fn body_drop_suppressed_when_is_end_stream_at_construction() {
        let fired = Arc::new(AtomicBool::new(false));
        let fired_clone = fired.clone();

        // Body already at end-of-stream at construction (HEAD response,
        // 204 No Content, etc).
        let empty_service = service_fn(|_req: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::NO_CONTENT)
                    .body(http_body_util::Empty::<Bytes>::new())
                    .unwrap(),
            )
        });

        let layer = OnEarlyDropLayer::builder().on_body_drop(OnBodyDropFn::new(
            move |_req: &Request<()>| {
                let fired = fired_clone.clone();
                move |_parts: &http::response::Parts| {
                    let fired = fired.clone();
                    move || {
                        fired.store(true, Ordering::Relaxed);
                    }
                }
            },
        ));
        let service = layer.layer(empty_service);
        let response = service.oneshot(request()).await.unwrap();
        // Drop immediately without polling the body.
        drop(response);

        assert!(
            !fired.load(Ordering::Relaxed),
            "body already at end-of-stream at construction must not fire the callback",
        );
    }

    #[tokio::test]
    async fn body_drop_does_not_fire_on_inner_error() {
        let fired = Arc::new(AtomicBool::new(false));
        let fired_clone = fired.clone();

        let err_service = service_fn(|_req: Request<()>| async move {
            Err::<Response<Full<Bytes>>, _>(std::io::Error::other("boom"))
        });

        let layer = OnEarlyDropLayer::builder().on_body_drop(OnBodyDropFn::new(
            move |_req: &Request<()>| {
                let fired = fired_clone.clone();
                move |_parts: &http::response::Parts| {
                    let fired = fired.clone();
                    move || {
                        fired.store(true, Ordering::Relaxed);
                    }
                }
            },
        ));
        let service = layer.layer(err_service);
        let _ = service.oneshot(request()).await;

        assert!(!fired.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn noop_slots_do_not_fire() {
        // Builder with default () slots: no hook is installed. Neither a
        // dropped, never-polled response future nor a dropped, unconsumed
        // body should invoke any user code.
        let layer = OnEarlyDropLayer::builder();

        // Drop a response future before it has ever been polled.
        let mut service = layer.layer(ok_service());
        let pending = ServiceExt::<Request<()>>::ready(&mut service)
            .await
            .unwrap()
            .call(request());
        drop(pending);

        // Drop a response whose body was never consumed.
        let response = layer
            .layer(ok_service())
            .oneshot(request())
            .await
            .unwrap();
        drop(response);

        // Nothing to assert; reaching here without panic confirms the
        // no-op slots are inert on both drop paths.
    }
}