tower_scope_spawn/
service.rs1use std::future::Future;
6use std::pin::Pin;
7use std::task::Context;
8use std::task::Poll;
9
10use pin_project::pin_project;
11use pin_project::pinned_drop;
12use tower::Service;
13
14use scope_spawn::scope::Scope;
15
/// A request paired with the [`Scope`] that [`ScopeSpawnService`] created for it.
///
/// The inner service receives this wrapper instead of the bare request. A clone
/// of the same scope is held by the [`ScopeFuture`] returned to the caller,
/// which cancels it when dropped.
#[derive(Debug)]
pub struct WithScope<Req> {
    /// The request exactly as it was passed to [`ScopeSpawnService`].
    pub request: Req,
    /// Scope created for this single request.
    pub scope: Scope,
}
24
/// Tower middleware that creates a fresh [`Scope`] for every request.
///
/// Each call forwards the request to the inner service wrapped in
/// [`WithScope`], and returns a [`ScopeFuture`] that cancels the scope when the
/// future is dropped.
#[derive(Clone, Debug)]
pub struct ScopeSpawnService<S> {
    /// The wrapped service; it is called with [`WithScope`] requests.
    inner: S,
}
30
31impl<S, Req> Service<Req> for ScopeSpawnService<S>
32where
33 S: Service<WithScope<Req>>, Req: Send + 'static,
35 S::Error: 'static, {
37 type Response = S::Response;
38 type Error = S::Error;
39 type Future = ScopeFuture<S::Future>;
40
41 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
42 self.inner.poll_ready(cx)
43 }
44
45 fn call(&mut self, req: Req) -> Self::Future {
46 let scope = Scope::new();
47 let inner_req_with_scope = WithScope {
49 request: req,
50 scope: scope.clone(),
51 };
52 let inner_future = self.inner.call(inner_req_with_scope);
53 ScopeFuture::new(inner_future, scope) }
55}
56
57impl<S> ScopeSpawnService<S> {
58 pub fn new(inner: S) -> Self {
60 Self { inner }
61 }
62}
63
/// Response future returned by [`ScopeSpawnService`].
///
/// Resolves to the wrapped future's output. When this value is dropped
/// (whether or not it completed), its `PinnedDrop` hook cancels `scope`.
///
/// NOTE(review): `Clone` is derived, so a clone carries a clone of the same
/// `scope`. Clones of `Scope` appear to share cancellation (this file's tests
/// rely on that). Dropping *any* clone therefore runs the drop hook and cancels
/// the scope for every other copy as well. Confirm this is intended before
/// cloning a `ScopeFuture`.
// `PinnedDrop` is required because pin-project does not allow a plain `Drop`
// impl on a projected type.
#[pin_project(PinnedDrop)]
#[derive(Clone, Debug)]
pub struct ScopeFuture<F> {
    /// The wrapped future; structurally pinned so it can be polled in place.
    #[pin]
    inner: F,
    /// Scope cancelled by the `PinnedDrop` implementation.
    scope: Scope,
}
72
73impl<F> ScopeFuture<F> {
74 pub fn new(inner: F, scope: Scope) -> Self {
76 Self { inner, scope }
77 }
78
79 pub fn scope(&self) -> &Scope {
81 &self.scope
82 }
83}
84
85impl<F: Future> Future for ScopeFuture<F> {
86 type Output = F::Output;
87 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
88 self.project().inner.poll(cx)
89 }
90}
91
92#[pinned_drop]
93impl<F> PinnedDrop for ScopeFuture<F> {
94 fn drop(self: Pin<&mut Self>) {
95 self.project().scope.cancel();
96 }
97}
98
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use axum::http::Request;
    use bytes::Bytes;
    use http_body_util::Empty;
    use tokio::sync::oneshot::channel;
    use tokio::time::sleep;

    use super::*;

    type TestReq = Request<Empty<Bytes>>;
    type TestRes = ();

    /// Dropping the response future must cancel tasks spawned on its scope.
    #[tokio::test]
    async fn test_cancellation_on_drop() {
        let (mut mock_service, mut mock_handle) = tower_test::mock::spawn_with(
            |svc: tower_test::mock::Mock<WithScope<TestReq>, TestRes>| ScopeSpawnService::new(svc),
        );

        mock_handle.allow(1);

        let req = Request::new(Empty::new());
        tokio_test::assert_ready_ok!(mock_service.poll_ready());
        let fut = mock_service.call(req);

        // Keep the response sender alive so the mock's response channel stays open.
        let (with_scope_req, _send_response) = mock_handle.next_request().await.unwrap();
        let _inner_req = with_scope_req.request;
        let _inner_service_scope = with_scope_req.scope;
        let scope_from_fut = fut.scope();

        // The task owns `tx`; if it is cancelled, `tx` is dropped and `rx`
        // resolves with an error long before the task's own sleep would end.
        let (tx, rx) = channel::<()>();
        scope_from_fut.spawn(async move {
            let _guard = tx;
            sleep(Duration::from_millis(200)).await;
        });

        drop(fut);

        tokio::select! {
            resp = rx => assert!(resp.is_err()),
            _ = sleep(Duration::from_millis(100)) => {
                panic!("Task should have been cancelled!");
            }
        }
    }

    /// While the response future is alive, tasks on its scope must keep running.
    ///
    /// This is asserted directly instead of via `#[should_panic]`, which would
    /// also pass on any unrelated panic (e.g. a failed `unwrap`).
    #[tokio::test]
    async fn test_no_cancellation_on_no_drop() {
        let (mut mock_service, mut mock_handle) = tower_test::mock::spawn_with(
            |svc: tower_test::mock::Mock<WithScope<TestReq>, TestRes>| ScopeSpawnService::new(svc),
        );

        mock_handle.allow(1);

        let req = Request::new(Empty::new());
        tokio_test::assert_ready_ok!(mock_service.poll_ready());
        let fut = mock_service.call(req);

        let (with_scope_req, _send_response) = mock_handle.next_request().await.unwrap();
        let _inner_req = with_scope_req.request;
        let _inner_service_scope = with_scope_req.scope;
        let scope_from_fut = fut.scope();

        // `rx` resolves early only if the task is cancelled (sender dropped);
        // otherwise it stays pending until the task's 200ms sleep finishes.
        let (tx, rx) = channel::<()>();
        scope_from_fut.spawn(async move {
            sleep(Duration::from_millis(200)).await;
            let _ = tx.send(());
        });

        tokio::select! {
            resp = rx => {
                panic!(
                    "Task finished early or was cancelled while its future was alive: {:?}",
                    resp
                );
            }
            _ = sleep(Duration::from_millis(100)) => {}
        }

        // The response future must stay alive for the whole check above.
        drop(fut);
    }

    /// Awaiting the response future to completion drops it, which must cancel
    /// tasks spawned on the scope handed to the inner service.
    #[tokio::test]
    async fn test_cancellation_on_completion() {
        let (mut mock_service, mut mock_handle) = tower_test::mock::spawn_with(
            |svc: tower_test::mock::Mock<WithScope<TestReq>, TestRes>| ScopeSpawnService::new(svc),
        );

        mock_handle.allow(1);

        let req = Request::new(Empty::new());
        tokio_test::assert_ready_ok!(mock_service.poll_ready());
        let fut = mock_service.call(req);

        let (with_scope_req, send_response) = mock_handle.next_request().await.unwrap();
        let scope_from_service = with_scope_req.scope;

        let (tx, rx) = channel::<()>();
        scope_from_service.spawn(async move {
            let _guard = tx;
            sleep(Duration::from_millis(200)).await;
        });

        send_response.send_response(());

        let _res = fut.await;

        tokio::select! {
            resp = rx => assert!(resp.is_err()),
            _ = sleep(Duration::from_millis(100)) => {
                panic!("Task should have been cancelled after completion!");
            }
        }
    }
}