Skip to main content

tower_scope_spawn/
layer.rs

1//!
2//! A Scoped Layer
3//!
4
5use tower::Layer;
6
7use crate::service::ScopeSpawnService;
8
/// A Tower [`Layer`] that wraps an inner service in a [`ScopeSpawnService`].
///
/// The layer itself is stateless and cheap to copy. With it applied, the inner
/// service receives each request together with a scope in which background
/// tasks can be spawned. Tasks spawned in the scope owned by the returned
/// response future are cancelled when that future is dropped.
#[derive(Clone, Copy, Default, Debug)]
pub struct ScopeSpawnLayer {}
12
13impl ScopeSpawnLayer {
14    /// Create a ScopeSpawnLayer
15    pub fn new() -> Self {
16        ScopeSpawnLayer {}
17    }
18}
19
20impl<S> Layer<S> for ScopeSpawnLayer {
21    type Service = ScopeSpawnService<S>;
22
23    fn layer(&self, service: S) -> Self::Service {
24        ScopeSpawnService::new(service)
25    }
26}
27
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use axum::http::Request;
    use bytes::Bytes;
    use http_body_util::Empty;
    use tower_test::mock;

    use super::*;
    use crate::service::ScopeFuture;

    // Not needed by the layer tests, but kept as documentation of the request
    // type the mock inner service receives (wrapped in `WithScope`):
    // type TestReq = Request<Empty<Bytes>>;
    type TestRes = ();

    /// How long a background task would run if nothing cancelled it.
    ///
    /// This is far longer than every timeout below, so a task finishing on its
    /// own can never be mistaken for cancellation (or vice versa).
    const TASK_LIFETIME: Duration = Duration::from_secs(60);

    /// Upper bound on how long we wait for a dropped scope to cancel its task.
    ///
    /// The passing path resolves as soon as cancellation happens. Only a failing
    /// run waits this long.
    const CANCEL_TIMEOUT: Duration = Duration::from_secs(1);

    /// How long we watch a still-owned scope to confirm its task stays alive.
    const ALIVE_WINDOW: Duration = Duration::from_millis(100);

    #[tokio::test]
    async fn test_cancellation_on_drop() {
        // The layer wraps a mock inner service, which receives `WithScope<Request<_>>`.
        let (mut mock_service, mut mock_handle) = mock::spawn_layer(ScopeSpawnLayer::new());

        // We only expect one call.
        mock_handle.allow(1);

        // Send the original (unwrapped) request and get the ScopeFuture back.
        let req = Request::new(Empty::<Bytes>::new());
        tokio_test::assert_ready_ok!(mock_service.poll_ready());
        let fut: ScopeFuture<mock::future::ResponseFuture<TestRes>> = mock_service.call(req);

        // The mock inner service receives the request wrapped together with a scope.
        let (with_scope_req, _send_response) = mock_handle
            .next_request()
            .await
            .expect("inner service should receive the request");
        let _inner_req = with_scope_req.request; // The original request
        let _inner_service_scope = with_scope_req.scope; // The scope passed to the inner service

        // The scope held by the ScopeFuture, which is cancelled when `fut` is dropped.
        let scope_from_fut = fut.scope();

        // Spawn a background task that outlives every timeout in this test. The
        // sender is dropped only when the task stops, which closes `rx`.
        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        scope_from_fut.spawn(async move {
            let _guard = tx;
            tokio::time::sleep(TASK_LIFETIME).await;
        });

        // Simulate a client timeout/disconnect by dropping the response future.
        drop(fut);

        // The task must have been cancelled: `rx` errors once the sender is dropped.
        tokio::select! {
            resp = rx => assert!(resp.is_err(), "sender was used instead of being dropped"),
            _ = tokio::time::sleep(CANCEL_TIMEOUT) => {
                panic!("Task should have been cancelled!");
            }
        }
    }

    #[tokio::test]
    async fn test_no_cancellation_on_no_drop() {
        // The layer wraps a mock inner service, which receives `WithScope<Request<_>>`.
        let (mut mock_service, mut mock_handle) = mock::spawn_layer(ScopeSpawnLayer::new());

        // We only expect one call.
        mock_handle.allow(1);

        // Send the original (unwrapped) request and get the ScopeFuture back.
        let req = Request::new(Empty::<Bytes>::new());
        tokio_test::assert_ready_ok!(mock_service.poll_ready());
        let fut: ScopeFuture<mock::future::ResponseFuture<TestRes>> = mock_service.call(req);

        // The mock inner service receives the request wrapped together with a scope.
        let (with_scope_req, _send_response) = mock_handle
            .next_request()
            .await
            .expect("inner service should receive the request");
        let _inner_req = with_scope_req.request; // The original request
        let _inner_service_scope = with_scope_req.scope; // The scope passed to the inner service

        // The scope held by the ScopeFuture, which is cancelled when `fut` is dropped.
        let scope_from_fut = fut.scope();

        // Spawn a background task that outlives the observation window. The
        // sender is dropped only when the task stops, which closes `rx`.
        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        scope_from_fut.spawn(async move {
            let _guard = tx;
            tokio::time::sleep(TASK_LIFETIME).await;
        });

        // No simulated disconnect: `fut` stays alive, so the task must keep running.
        // `rx` resolving within the window means the task stopped prematurely.
        // This is asserted directly rather than via `#[should_panic]`, so unrelated
        // panics in the setup above cannot make the test pass.
        tokio::select! {
            resp = rx => panic!("Task should not have been cancelled (receiver resolved: {resp:?})"),
            _ = tokio::time::sleep(ALIVE_WINDOW) => {}
        }

        // Only now release the response future, after the liveness check.
        drop(fut);
    }
}