1use std::collections::VecDeque;
2use std::convert::Infallible;
3use std::fmt;
4use std::future::Future;
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::rc::Rc;
8use std::task::{Context, Poll};
9
10use scrappy_service::{IntoService, Service, Transform};
11use futures::future::{ok, Ready};
12
13use crate::oneshot;
14use crate::task::LocalWaker;
15
/// Queue entry linking one in-flight inner call to the caller waiting on it.
///
/// `rx` receives the raw result from the task spawned in `call`; `tx` hands
/// that result on to the caller's `InOrderServiceResponse` once this entry is
/// at the front of the queue.
struct Record<I, E> {
    // Fed by the spawned task that awaits the inner service's future.
    rx: oneshot::Receiver<Result<I, E>>,
    // Its receiving half is held by the `InOrderServiceResponse` returned from `call`.
    tx: oneshot::Sender<Result<I, E>>,
}
20
/// Error produced by the in-order service wrapper.
pub enum InOrderError<E> {
    /// The wrapped service reported an error.
    Service(E),
    /// A result channel closed before a result was delivered through it.
    Disconnected,
}

impl<E> From<E> for InOrderError<E> {
    fn from(err: E) -> Self {
        InOrderError::Service(err)
    }
}

impl<E: fmt::Debug> fmt::Debug for InOrderError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            InOrderError::Service(e) => write!(f, "InOrderError::Service({:?})", e),
            InOrderError::Disconnected => write!(f, "InOrderError::Disconnected"),
        }
    }
}

impl<E: fmt::Display> fmt::Display for InOrderError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Delegate explicitly to `Display` so the inner message is shown verbatim.
            InOrderError::Service(e) => fmt::Display::fmt(e, f),
            InOrderError::Disconnected => write!(f, "InOrder service disconnected"),
        }
    }
}

/// Makes `InOrderError` usable wherever a standard error is expected
/// (e.g. `Box<dyn std::error::Error>`).
impl<E: fmt::Debug + fmt::Display> std::error::Error for InOrderError<E> {}
52
53pub struct InOrder<S> {
56 _t: PhantomData<S>,
57}
58
59impl<S> InOrder<S>
60where
61 S: Service,
62 S::Response: 'static,
63 S::Future: 'static,
64 S::Error: 'static,
65{
66 pub fn new() -> Self {
67 Self { _t: PhantomData }
68 }
69
70 pub fn service(service: S) -> InOrderService<S> {
71 InOrderService::new(service)
72 }
73}
74
75impl<S> Default for InOrder<S>
76where
77 S: Service,
78 S::Response: 'static,
79 S::Future: 'static,
80 S::Error: 'static,
81{
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl<S> Transform<S> for InOrder<S>
88where
89 S: Service,
90 S::Response: 'static,
91 S::Future: 'static,
92 S::Error: 'static,
93{
94 type Request = S::Request;
95 type Response = S::Response;
96 type Error = InOrderError<S::Error>;
97 type InitError = Infallible;
98 type Transform = InOrderService<S>;
99 type Future = Ready<Result<Self::Transform, Self::InitError>>;
100
101 fn new_transform(&self, service: S) -> Self::Future {
102 ok(InOrderService::new(service))
103 }
104}
105
106pub struct InOrderService<S: Service> {
107 service: S,
108 waker: Rc<LocalWaker>,
109 acks: VecDeque<Record<S::Response, S::Error>>,
110}
111
112impl<S> InOrderService<S>
113where
114 S: Service,
115 S::Response: 'static,
116 S::Future: 'static,
117 S::Error: 'static,
118{
119 pub fn new<U>(service: U) -> Self
120 where
121 U: IntoService<S>,
122 {
123 Self {
124 service: service.into_service(),
125 acks: VecDeque::new(),
126 waker: Rc::new(LocalWaker::new()),
127 }
128 }
129}
130
131impl<S> Service for InOrderService<S>
132where
133 S: Service,
134 S::Response: 'static,
135 S::Future: 'static,
136 S::Error: 'static,
137{
138 type Request = S::Request;
139 type Response = S::Response;
140 type Error = InOrderError<S::Error>;
141 type Future = InOrderServiceResponse<S>;
142
143 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
144 self.waker.register(cx.waker());
146
147 while !self.acks.is_empty() {
149 let rec = self.acks.front_mut().unwrap();
150 match Pin::new(&mut rec.rx).poll(cx) {
151 Poll::Ready(Ok(res)) => {
152 let rec = self.acks.pop_front().unwrap();
153 let _ = rec.tx.send(res);
154 }
155 Poll::Pending => break,
156 Poll::Ready(Err(oneshot::Canceled)) => {
157 return Poll::Ready(Err(InOrderError::Disconnected))
158 }
159 }
160 }
161
162 if let Poll::Pending = self.service.poll_ready(cx).map_err(InOrderError::Service)? {
164 Poll::Pending
165 } else {
166 Poll::Ready(Ok(()))
167 }
168 }
169
170 fn call(&mut self, request: S::Request) -> Self::Future {
171 let (tx1, rx1) = oneshot::channel();
172 let (tx2, rx2) = oneshot::channel();
173 self.acks.push_back(Record { rx: rx1, tx: tx2 });
174
175 let waker = self.waker.clone();
176 let fut = self.service.call(request);
177 scrappy_rt::spawn(async move {
178 let res = fut.await;
179 waker.wake();
180 let _ = tx1.send(res);
181 });
182
183 InOrderServiceResponse { rx: rx2 }
184 }
185}
186
187#[doc(hidden)]
188pub struct InOrderServiceResponse<S: Service> {
189 rx: oneshot::Receiver<Result<S::Response, S::Error>>,
190}
191
192impl<S: Service> Future for InOrderServiceResponse<S> {
193 type Output = Result<S::Response, InOrderError<S::Error>>;
194
195 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
196 match Pin::new(&mut self.rx).poll(cx) {
197 Poll::Pending => Poll::Pending,
198 Poll::Ready(Ok(Ok(res))) => Poll::Ready(Ok(res)),
199 Poll::Ready(Ok(Err(e))) => Poll::Ready(Err(e.into())),
200 Poll::Ready(Err(_)) => Poll::Ready(Err(InOrderError::Disconnected)),
201 }
202 }
203}
204
#[cfg(test)]
mod tests {

    use std::task::{Context, Poll};
    use std::time::Duration;

    use super::*;
    use scrappy_service::Service;
    // Takes precedence over the `oneshot` brought in by `super::*`: the test
    // service's requests are `futures` oneshot receivers.
    use futures::channel::oneshot;
    use futures::future::{lazy, poll_fn, FutureExt, LocalBoxFuture};

    /// Test service whose response to a request is the value later sent on
    /// that request's oneshot channel (error `()` if the sender is dropped).
    struct Srv;

    impl Service for Srv {
        type Request = oneshot::Receiver<usize>;
        type Response = usize;
        type Error = ();
        type Future = LocalBoxFuture<'static, Result<usize, ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: oneshot::Receiver<usize>) -> Self::Future {
            req.map(|res| res.map_err(|_| ())).boxed_local()
        }
    }

    /// Completes the inner calls out of order (3 first, then 2 and 1) and
    /// checks that the wrapped responses still resolve as 1, 2, 3.
    #[scrappy_rt::test]
    async fn test_inorder() {
        let (tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();
        let (tx3, rx3) = oneshot::channel();
        // Lets the test body know the worker thread finished its assertions.
        let (tx_stop, rx_stop) = oneshot::channel();

        // Run the service inside its own System on a separate thread, so this
        // task can complete the request channels independently.
        let h = std::thread::spawn(move || {
            let rx1 = rx1;
            let rx2 = rx2;
            let rx3 = rx3;
            let tx_stop = tx_stop;
            let _ = scrappy_rt::System::new("test").block_on(async {
                let mut srv = InOrderService::new(Srv);

                let _ = lazy(|cx| srv.poll_ready(cx)).await;
                let res1 = srv.call(rx1);
                let res2 = srv.call(rx2);
                let res3 = srv.call(rx3);

                // Results are only forwarded from inside `poll_ready`, so keep
                // polling it from a background task that never completes.
                scrappy_rt::spawn(async move {
                    let _ = poll_fn(|cx| {
                        let _ = srv.poll_ready(cx);
                        Poll::<()>::Pending
                    })
                    .await;
                });

                assert_eq!(res1.await.unwrap(), 1);
                assert_eq!(res2.await.unwrap(), 2);
                assert_eq!(res3.await.unwrap(), 3);

                let _ = tx_stop.send(());
                scrappy_rt::System::current().stop();
            });
        });

        // Complete the last request first and pause, so that result 3 is very
        // likely ready before 1 and 2 (timing-based, not a hard guarantee).
        let _ = tx3.send(3);
        std::thread::sleep(Duration::from_millis(50));
        let _ = tx2.send(2);
        let _ = tx1.send(1);

        let _ = rx_stop.await;
        let _ = h.join();
    }
}