tari_service_framework/tower/
service_ext.rs

1// Copyright 2019 The Tari Project
2//
3// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
4// following conditions are met:
5//
6// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
7// disclaimer.
8//
9// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
10// following disclaimer in the documentation and/or other materials provided with the distribution.
11//
12// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
13// products derived from this software without specific prior written permission.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
16// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
18// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
19// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
20// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
21// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
22
23use std::{pin::Pin, task::Poll};
24
25use futures::{ready, task::Context, Future, FutureExt};
26use tower_service::Service;
27
28impl<T: ?Sized, TRequest> ServiceExt<TRequest> for T where T: Service<TRequest> {}
29
30pub trait ServiceExt<TRequest>: Service<TRequest> {
31    /// The service combinator combines calling `poll_ready` and `call` into a single call.
32    /// It returns a [ServiceCallReady](./struct.ServiceCallReady.html) future that
33    /// calls `poll_ready` on the given service, once the service is ready to
34    /// receive a request, `call` is called and the resulting future is polled.
35    fn call_ready(&mut self, req: TRequest) -> ServiceCallReady<'_, Self, TRequest>
36    where Self::Future: Unpin {
37        ServiceCallReady::new(self, req)
38    }
39}
40
41#[must_use = "futures do nothing unless you `.await` or poll them"]
42pub struct ServiceCallReady<'a, S, TRequest>
43where S: Service<TRequest> + ?Sized
44{
45    service: &'a mut S,
46    request: Option<TRequest>,
47    pending: Option<S::Future>,
48}
49
50impl<S: ?Sized + Service<TRequest> + Unpin, TRequest> Unpin for ServiceCallReady<'_, S, TRequest> {}
51
52impl<'a, S, TRequest> ServiceCallReady<'a, S, TRequest>
53where
54    S: Service<TRequest> + ?Sized,
55    S::Future: Unpin,
56{
57    fn new(service: &'a mut S, request: TRequest) -> Self {
58        Self {
59            service,
60            request: Some(request),
61            pending: None,
62        }
63    }
64}
65
66impl<S, TRequest> Future for ServiceCallReady<'_, S, TRequest>
67where
68    S: Service<TRequest> + ?Sized + Unpin,
69    S::Future: Unpin,
70{
71    type Output = Result<S::Response, S::Error>;
72
73    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
74        let this = &mut *self;
75        loop {
76            match this.pending {
77                Some(ref mut fut) => return fut.poll_unpin(cx),
78                None => {
79                    // Poll the service to check if it's ready. If so, make the call
80                    ready!(this.service.poll_ready(cx))?;
81                    let req = this.request.take().expect("the request cannot be made twice");
82                    this.pending = Some(this.service.call(req));
83                },
84            }
85        }
86    }
87}
88
#[cfg(test)]
mod test {
    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    };

    use futures::future;
    use futures_test::task::panic_context;
    use tower::service_fn;

    use super::*;

    #[test]
    fn service_ready() {
        let mut double_service = service_fn(|req: u32| future::ok::<_, ()>(req + req));

        let mut cx = panic_context();

        match ServiceCallReady::new(&mut double_service, 157).poll_unpin(&mut cx) {
            Poll::Ready(Ok(v)) => assert_eq!(v, 314),
            _ => panic!("Expected future to be ready"),
        }
    }

    #[test]
    fn service_ready_later() {
        struct ReadyLater {
            call_count: u32,
            flag: Arc<AtomicBool>,
        }

        impl Service<u32> for ReadyLater {
            type Error = ();
            type Future = future::Ready<Result<Self::Response, Self::Error>>;
            type Response = u32;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                if self.flag.load(Ordering::SeqCst) {
                    Ok(()).into()
                } else {
                    Poll::Pending
                }
            }

            fn call(&mut self, req: u32) -> Self::Future {
                self.call_count += 1;
                future::ok(req + req)
            }
        }

        let mut cx = panic_context();
        let ready_flag = Arc::new(AtomicBool::new(false));
        let mut service = ReadyLater {
            flag: ready_flag.clone(),
            call_count: 0,
        };

        let mut fut = ServiceCallReady::new(&mut service, 157);

        match fut.poll_unpin(&mut cx) {
            Poll::Pending => {},
            _ => panic!("Expected future to be pending"),
        }

        ready_flag.store(true, Ordering::SeqCst);

        match fut.poll_unpin(&mut cx) {
            Poll::Ready(Ok(v)) => assert_eq!(v, 314),
            _ => panic!("Expected future to be ready"),
        }

        assert_eq!(service.call_count, 1);
    }

    /// An error from `poll_ready` must become the future's output, and the service must never be
    /// called.
    #[test]
    fn service_not_ready_error() {
        struct Failing {
            call_count: u32,
        }

        impl Service<u32> for Failing {
            type Error = &'static str;
            type Future = future::Ready<Result<Self::Response, Self::Error>>;
            type Response = u32;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Err("not ready"))
            }

            fn call(&mut self, req: u32) -> Self::Future {
                self.call_count += 1;
                future::ok(req)
            }
        }

        let mut cx = panic_context();
        let mut service = Failing { call_count: 0 };

        match ServiceCallReady::new(&mut service, 157).poll_unpin(&mut cx) {
            Poll::Ready(Err(err)) => assert_eq!(err, "not ready"),
            _ => panic!("Expected the poll_ready error to be returned"),
        }

        assert_eq!(service.call_count, 0);
    }

    /// Exercises the public `ServiceExt::call_ready` entry point rather than the private constructor.
    #[test]
    fn call_ready_ext() {
        let mut double_service = service_fn(|req: u32| future::ok::<_, ()>(req * 2));

        let mut cx = panic_context();

        match double_service.call_ready(21).poll_unpin(&mut cx) {
            Poll::Ready(Ok(v)) => assert_eq!(v, 42),
            _ => panic!("Expected future to be ready"),
        }
    }
}