Skip to main content

tower_scope_spawn/
service.rs

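//! A scoped [`tower::Service`] middleware.
//!
//! [`ScopeSpawnService`] creates a fresh `Scope` for every request, passes a clone of it to the
//! inner service inside a [`WithScope`], and returns a [`ScopeFuture`] that cancels the scope
//! when it is dropped.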
4
5use std::future::Future;
6use std::pin::Pin;
7use std::task::Context;
8use std::task::Poll;
9
10use pin_project::pin_project;
11use pin_project::pinned_drop;
12use tower::Service;
13
14use scope_spawn::scope::Scope;
15
16/// Request wrapper
17#[derive(Debug)]
18pub struct WithScope<Req> {
19    /// The original Request wrapped in this scope
20    pub request: Req,
21    /// The Scope of the wrapped request
22    pub scope: Scope,
23}
24
/// Tower middleware that serves each request of the wrapped service inside its own `Scope`.
///
/// Construct it with [`ScopeSpawnService::new`].
#[derive(Debug, Clone)]
pub struct ScopeSpawnService<S> {
    /// The wrapped service; it receives every request as a [`WithScope`].
    inner: S,
}
30
31impl<S, Req> Service<Req> for ScopeSpawnService<S>
32where
33    S: Service<WithScope<Req>>, // Inner service expects WithScope<Req>
34    Req: Send + 'static,
35    S::Error: 'static, // Ensure S::Error is static
36{
37    type Response = S::Response;
38    type Error = S::Error;
39    type Future = ScopeFuture<S::Future>;
40
41    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
42        self.inner.poll_ready(cx)
43    }
44
45    fn call(&mut self, req: Req) -> Self::Future {
46        let scope = Scope::new();
47        // The scope clone is passed to the inner request AND kept by ScopeFuture
48        let inner_req_with_scope = WithScope {
49            request: req,
50            scope: scope.clone(),
51        };
52        let inner_future = self.inner.call(inner_req_with_scope);
53        ScopeFuture::new(inner_future, scope) // ScopeFuture retains its own clone for cancellation
54    }
55}
56
57impl<S> ScopeSpawnService<S> {
58    /// Create a new ScopeSpawnService
59    pub fn new(inner: S) -> Self {
60        Self { inner }
61    }
62}
63
64/// A ScopeFuture. Useful for integrating Scope with [tower](https://docs.rs/tower/latest/tower/), [axum](https://docs.rs/axum/latest/axum), etc..
65#[pin_project(PinnedDrop)]
66#[derive(Clone, Debug)]
67pub struct ScopeFuture<F> {
68    #[pin]
69    inner: F,
70    scope: Scope,
71}
72
73impl<F> ScopeFuture<F> {
74    /// Create a new ScopeFuture
75    pub fn new(inner: F, scope: Scope) -> Self {
76        Self { inner, scope }
77    }
78
79    /// Borrow the scope
80    pub fn scope(&self) -> &Scope {
81        &self.scope
82    }
83}
84
85impl<F: Future> Future for ScopeFuture<F> {
86    type Output = F::Output;
87    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
88        self.project().inner.poll(cx)
89    }
90}
91
92#[pinned_drop]
93impl<F> PinnedDrop for ScopeFuture<F> {
94    fn drop(self: Pin<&mut Self>) {
95        self.project().scope.cancel();
96    }
97}
98
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use axum::http::Request;
    use bytes::Bytes;
    use http_body_util::Empty;
    use tokio::sync::oneshot::channel;
    use tokio::time::sleep;

    use super::*;

    type TestReq = Request<Empty<Bytes>>;
    type TestRes = ();

    /// Dropping the `ScopeFuture` before it completes must cancel tasks spawned in its scope.
    #[tokio::test]
    async fn test_cancellation_on_drop() {
        // The mock inner service receives requests as `WithScope<TestReq>`.
        let (mut mock_service, mut mock_handle) = tower_test::mock::spawn_with(
            |svc: tower_test::mock::Mock<WithScope<TestReq>, TestRes>| ScopeSpawnService::new(svc),
        );

        // Exactly one call is expected.
        mock_handle.allow(1);

        // Send a request and get the ScopeFuture.
        let req = Request::new(Empty::new());
        tokio_test::assert_ready_ok!(mock_service.poll_ready());
        let fut = mock_service.call(req);

        // The inner service sees the original request together with a clone of the scope.
        let (with_scope_req, _send_response) = mock_handle.next_request().await.unwrap();
        let _inner_req = with_scope_req.request;
        let _inner_service_scope = with_scope_req.scope;

        // The scope held by the ScopeFuture, which is cancelled when `fut` is dropped.
        let scope_from_fut = fut.scope();

        // Spawn a task that outlives the 100ms observation window below. Its sender is dropped
        // (so `rx` resolves to `Err`) as soon as the task is cancelled.
        let (tx, rx) = channel::<()>();
        scope_from_fut.spawn(async move {
            let _guard = tx;
            sleep(Duration::from_millis(200)).await;
        });

        // Simulate a client timeout/disconnect by dropping the response future.
        drop(fut);

        // The task must be cancelled well before it would have finished on its own.
        tokio::select! {
            resp = rx => assert!(resp.is_err()),
            _ = sleep(Duration::from_millis(100)) => {
                panic!("Task should have been cancelled!");
            }
        }
    }

    /// While the `ScopeFuture` is alive, tasks spawned in its scope must keep running.
    ///
    /// This asserts directly that the receiver stays pending. A `#[should_panic]` test would
    /// also pass on unrelated panics, such as a failed `poll_ready` or mock setup.
    #[tokio::test]
    async fn test_no_cancellation_on_no_drop() {
        // The mock inner service receives requests as `WithScope<TestReq>`.
        let (mut mock_service, mut mock_handle) = tower_test::mock::spawn_with(
            |svc: tower_test::mock::Mock<WithScope<TestReq>, TestRes>| ScopeSpawnService::new(svc),
        );

        // Exactly one call is expected.
        mock_handle.allow(1);

        // Send a request and get the ScopeFuture.
        let req = Request::new(Empty::new());
        tokio_test::assert_ready_ok!(mock_service.poll_ready());
        let fut = mock_service.call(req);

        // The inner service sees the original request together with a clone of the scope.
        let (with_scope_req, _send_response) = mock_handle.next_request().await.unwrap();
        let _inner_req = with_scope_req.request;
        let _inner_service_scope = with_scope_req.scope;

        // The scope held by the ScopeFuture, which is cancelled only when `fut` is dropped.
        let scope_from_fut = fut.scope();

        // Spawn a task that outlives the 100ms observation window below. If it were cancelled,
        // `tx` would be dropped and `rx` would resolve early with `Err`.
        let (tx, rx) = channel::<()>();
        scope_from_fut.spawn(async move {
            sleep(Duration::from_millis(200)).await;
            let _ = tx.send(());
        });

        // `fut` is still alive, so the receiver must remain pending for the whole window.
        tokio::select! {
            resp = rx => {
                panic!("Task should still be running while the future is alive, got {resp:?}");
            }
            _ = sleep(Duration::from_millis(100)) => {}
        }

        // Only now release the response future.
        drop(fut);
    }

    /// Awaiting the `ScopeFuture` to completion consumes and drops it, which cancels the scope.
    #[tokio::test]
    async fn test_cancellation_on_completion() {
        let (mut mock_service, mut mock_handle) = tower_test::mock::spawn_with(
            |svc: tower_test::mock::Mock<WithScope<TestReq>, TestRes>| ScopeSpawnService::new(svc),
        );

        mock_handle.allow(1);

        let req = Request::new(Empty::new());
        tokio_test::assert_ready_ok!(mock_service.poll_ready());
        let fut = mock_service.call(req);

        // Spawn on the inner service's clone of the scope: cancellation through the
        // ScopeFuture's clone must still reach it.
        let (with_scope_req, send_response) = mock_handle.next_request().await.unwrap();
        let scope_from_service = with_scope_req.scope;

        let (tx, rx) = channel::<()>();
        scope_from_service.spawn(async move {
            let _guard = tx;
            sleep(Duration::from_millis(200)).await;
        });

        // Complete the request.
        send_response.send_response(());

        // The future resolves; `.await` consumes it, so it is dropped here.
        let _res = fut.await;

        // After the future is dropped, the task should be cancelled.
        tokio::select! {
            resp = rx => assert!(resp.is_err()),
            _ = sleep(Duration::from_millis(100)) => {
                panic!("Task should have been cancelled after completion!");
            }
        }
    }
}