scrappy_utils/
order.rs

1use std::collections::VecDeque;
2use std::convert::Infallible;
3use std::fmt;
4use std::future::Future;
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::rc::Rc;
8use std::task::{Context, Poll};
9
10use scrappy_service::{IntoService, Service, Transform};
11use futures::future::{ok, Ready};
12
13use crate::oneshot;
14use crate::task::LocalWaker;
15
16struct Record<I, E> {
17    rx: oneshot::Receiver<Result<I, E>>,
18    tx: oneshot::Sender<Result<I, E>>,
19}
20
/// Error produced by `InOrderService` and by the response futures it returns.
pub enum InOrderError<E> {
    /// The wrapped service reported an error for this call.
    Service(E),
    /// The channel carrying this call's result was dropped before a response
    /// was delivered.
    Disconnected,
}

/// Lets `?` lift an inner-service error straight into `InOrderError::Service`.
impl<E> From<E> for InOrderError<E> {
    fn from(err: E) -> Self {
        InOrderError::Service(err)
    }
}

impl<E: fmt::Debug> fmt::Debug for InOrderError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            InOrderError::Service(e) => write!(f, "InOrderError::Service({:?})", e),
            InOrderError::Disconnected => write!(f, "InOrderError::Disconnected"),
        }
    }
}

impl<E: fmt::Display> fmt::Display for InOrderError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Forward the formatter unchanged so width/precision flags still
            // apply to the inner error's text.
            InOrderError::Service(e) => e.fmt(f),
            InOrderError::Disconnected => write!(f, "InOrder service disconnected"),
        }
    }
}

/// Allows `InOrderError` to be used wherever a standard error is expected
/// (e.g. boxed as `Box<dyn std::error::Error>`).
impl<E: fmt::Debug + fmt::Display> std::error::Error for InOrderError<E> {}
52
/// Middleware factory (a `Transform`) that wraps a service in `InOrderService`.
///
/// The wrapped service's calls are allowed to complete concurrently, but their
/// responses are handed back in the same order the requests were submitted.
pub struct InOrder<S> {
    _t: PhantomData<S>,
}
58
59impl<S> InOrder<S>
60where
61    S: Service,
62    S::Response: 'static,
63    S::Future: 'static,
64    S::Error: 'static,
65{
66    pub fn new() -> Self {
67        Self { _t: PhantomData }
68    }
69
70    pub fn service(service: S) -> InOrderService<S> {
71        InOrderService::new(service)
72    }
73}
74
75impl<S> Default for InOrder<S>
76where
77    S: Service,
78    S::Response: 'static,
79    S::Future: 'static,
80    S::Error: 'static,
81{
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl<S> Transform<S> for InOrder<S>
88where
89    S: Service,
90    S::Response: 'static,
91    S::Future: 'static,
92    S::Error: 'static,
93{
94    type Request = S::Request;
95    type Response = S::Response;
96    type Error = InOrderError<S::Error>;
97    type InitError = Infallible;
98    type Transform = InOrderService<S>;
99    type Future = Ready<Result<Self::Transform, Self::InitError>>;
100
101    fn new_transform(&self, service: S) -> Self::Future {
102        ok(InOrderService::new(service))
103    }
104}
105
106pub struct InOrderService<S: Service> {
107    service: S,
108    waker: Rc<LocalWaker>,
109    acks: VecDeque<Record<S::Response, S::Error>>,
110}
111
112impl<S> InOrderService<S>
113where
114    S: Service,
115    S::Response: 'static,
116    S::Future: 'static,
117    S::Error: 'static,
118{
119    pub fn new<U>(service: U) -> Self
120    where
121        U: IntoService<S>,
122    {
123        Self {
124            service: service.into_service(),
125            acks: VecDeque::new(),
126            waker: Rc::new(LocalWaker::new()),
127        }
128    }
129}
130
131impl<S> Service for InOrderService<S>
132where
133    S: Service,
134    S::Response: 'static,
135    S::Future: 'static,
136    S::Error: 'static,
137{
138    type Request = S::Request;
139    type Response = S::Response;
140    type Error = InOrderError<S::Error>;
141    type Future = InOrderServiceResponse<S>;
142
143    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
144        // poll_ready could be called from different task
145        self.waker.register(cx.waker());
146
147        // check acks
148        while !self.acks.is_empty() {
149            let rec = self.acks.front_mut().unwrap();
150            match Pin::new(&mut rec.rx).poll(cx) {
151                Poll::Ready(Ok(res)) => {
152                    let rec = self.acks.pop_front().unwrap();
153                    let _ = rec.tx.send(res);
154                }
155                Poll::Pending => break,
156                Poll::Ready(Err(oneshot::Canceled)) => {
157                    return Poll::Ready(Err(InOrderError::Disconnected))
158                }
159            }
160        }
161
162        // check nested service
163        if let Poll::Pending = self.service.poll_ready(cx).map_err(InOrderError::Service)? {
164            Poll::Pending
165        } else {
166            Poll::Ready(Ok(()))
167        }
168    }
169
170    fn call(&mut self, request: S::Request) -> Self::Future {
171        let (tx1, rx1) = oneshot::channel();
172        let (tx2, rx2) = oneshot::channel();
173        self.acks.push_back(Record { rx: rx1, tx: tx2 });
174
175        let waker = self.waker.clone();
176        let fut = self.service.call(request);
177        scrappy_rt::spawn(async move {
178            let res = fut.await;
179            waker.wake();
180            let _ = tx1.send(res);
181        });
182
183        InOrderServiceResponse { rx: rx2 }
184    }
185}
186
187#[doc(hidden)]
188pub struct InOrderServiceResponse<S: Service> {
189    rx: oneshot::Receiver<Result<S::Response, S::Error>>,
190}
191
192impl<S: Service> Future for InOrderServiceResponse<S> {
193    type Output = Result<S::Response, InOrderError<S::Error>>;
194
195    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
196        match Pin::new(&mut self.rx).poll(cx) {
197            Poll::Pending => Poll::Pending,
198            Poll::Ready(Ok(Ok(res))) => Poll::Ready(Ok(res)),
199            Poll::Ready(Ok(Err(e))) => Poll::Ready(Err(e.into())),
200            Poll::Ready(Err(_)) => Poll::Ready(Err(InOrderError::Disconnected)),
201        }
202    }
203}
204
#[cfg(test)]
mod tests {

    use std::task::{Context, Poll};
    use std::time::Duration;

    use super::*;
    use scrappy_service::Service;
    use futures::channel::oneshot;
    use futures::future::{lazy, poll_fn, FutureExt, LocalBoxFuture};

    /// Test service whose request is itself a oneshot receiver. Each call
    /// completes when the test sends on the paired sender, so the test decides
    /// the completion order.
    struct Srv;

    impl Service for Srv {
        type Request = oneshot::Receiver<usize>;
        type Response = usize;
        type Error = ();
        type Future = LocalBoxFuture<'static, Result<usize, ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: oneshot::Receiver<usize>) -> Self::Future {
            req.map(|res| res.map_err(|_| ())).boxed_local()
        }
    }

    #[scrappy_rt::test]
    async fn test_inorder() {
        let (tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();
        let (tx3, rx3) = oneshot::channel();
        let (tx_stop, rx_stop) = oneshot::channel();

        // Run the service on its own system in a worker thread, so this thread
        // can complete the inner calls out of order.
        let h = std::thread::spawn(move || {
            let _ = scrappy_rt::System::new("test").block_on(async {
                let mut srv = InOrderService::new(Srv);

                let _ = lazy(|cx| srv.poll_ready(cx)).await;
                let res1 = srv.call(rx1);
                let res2 = srv.call(rx2);
                let res3 = srv.call(rx3);

                // Responses are only forwarded from inside `poll_ready`, so a
                // background task keeps polling it.
                scrappy_rt::spawn(async move {
                    let _ = poll_fn(|cx| {
                        let _ = srv.poll_ready(cx);
                        Poll::<()>::Pending
                    })
                    .await;
                });

                assert_eq!(res1.await.unwrap(), 1);
                assert_eq!(res2.await.unwrap(), 2);
                assert_eq!(res3.await.unwrap(), 3);

                let _ = tx_stop.send(());
                scrappy_rt::System::current().stop();
            });
        });

        // Finish the last request first (and give it time to complete) so that
        // ordering is actually exercised.
        let _ = tx3.send(3);
        std::thread::sleep(Duration::from_millis(50));
        let _ = tx2.send(2);
        let _ = tx1.send(1);

        // If the worker panicked (e.g. a failed assertion), `tx_stop` was
        // dropped and this resolves with an error. The join below re-raises
        // the worker's panic so the test fails instead of silently passing.
        let _ = rx_stop.await;
        if let Err(panic) = h.join() {
            std::panic::resume_unwind(panic);
        }
    }
}