scrappy_utils/
framed.rs

1//! Framed dispatcher service and related utilities
2#![allow(type_alias_bounds)]
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::{fmt, mem};
6
7use scrappy_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed};
8use scrappy_service::{IntoService, Service};
9use futures::{Future, FutureExt, Stream};
10use log::debug;
11
12use crate::mpsc;
13
/// Item type produced by the codec's `Decoder` side; the dispatcher requires
/// the service's `Request` type to be this type.
type Request<U> = <U as Decoder>::Item;
/// Item type accepted by the codec's `Encoder` side; the dispatcher requires
/// the service's `Response` type to be this type.
type Response<U> = <U as Encoder>::Item;
16
/// Framed transport errors
///
/// `E` is the error type of the service and `U` is the codec, which supplies
/// the encoder and decoder error types.
pub enum DispatcherError<E, U: Encoder + Decoder> {
    /// Error produced by the service.
    Service(E),
    /// Error of the codec's `Encoder` implementation.
    Encoder(<U as Encoder>::Error),
    /// Error of the codec's `Decoder` implementation.
    Decoder(<U as Decoder>::Error),
}
23
24impl<E, U: Encoder + Decoder> From<E> for DispatcherError<E, U> {
25    fn from(err: E) -> Self {
26        DispatcherError::Service(err)
27    }
28}
29
30impl<E, U: Encoder + Decoder> fmt::Debug for DispatcherError<E, U>
31where
32    E: fmt::Debug,
33    <U as Encoder>::Error: fmt::Debug,
34    <U as Decoder>::Error: fmt::Debug,
35{
36    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
37        match *self {
38            DispatcherError::Service(ref e) => write!(fmt, "DispatcherError::Service({:?})", e),
39            DispatcherError::Encoder(ref e) => write!(fmt, "DispatcherError::Encoder({:?})", e),
40            DispatcherError::Decoder(ref e) => write!(fmt, "DispatcherError::Decoder({:?})", e),
41        }
42    }
43}
44
45impl<E, U: Encoder + Decoder> fmt::Display for DispatcherError<E, U>
46where
47    E: fmt::Display,
48    <U as Encoder>::Error: fmt::Debug,
49    <U as Decoder>::Error: fmt::Debug,
50{
51    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
52        match *self {
53            DispatcherError::Service(ref e) => write!(fmt, "{}", e),
54            DispatcherError::Encoder(ref e) => write!(fmt, "{:?}", e),
55            DispatcherError::Decoder(ref e) => write!(fmt, "{:?}", e),
56        }
57    }
58}
59
/// Message delivered to the dispatcher through its channel.
pub enum Message<T> {
    /// An item the dispatcher writes to the framed transport.
    Item(T),
    /// Asks the dispatcher to flush its write buffer and then stop.
    Close,
}
64
/// Dispatcher - is a future that reads frames from a `Framed` object
/// and passes them to the service.
///
/// Results of the service calls are delivered back through an internal
/// channel and written to the same `Framed` object.
#[pin_project::pin_project]
pub struct Dispatcher<S, T, U>
where
    S: Service<Request = Request<U>, Response = Response<U>>,
    S::Error: 'static,
    S::Future: 'static,
    T: AsyncRead + AsyncWrite,
    U: Encoder + Decoder,
    <U as Encoder>::Item: 'static,
    <U as Encoder>::Error: std::fmt::Debug,
{
    /// Service that every decoded frame is passed to.
    service: S,
    /// Current state of the dispatcher state machine.
    state: State<S, U>,
    /// Framed transport that frames are read from and responses are written to.
    framed: Framed<T, U>,
    /// Receiving end of the channel carrying messages to write, or service errors.
    rx: mpsc::Receiver<Result<Message<<U as Encoder>::Item>, S::Error>>,
    /// Sender paired with `rx`; cloned for every spawned service call and by `get_sink`.
    tx: mpsc::Sender<Result<Message<<U as Encoder>::Item>, S::Error>>,
}
84
/// Internal state machine of the `Dispatcher` future.
enum State<S: Service, U: Encoder + Decoder> {
    /// Normal operation: frames are read and dispatched, responses are written.
    Processing,
    /// A service error occurred; the dispatcher tries to flush the write
    /// buffer and then resolves with this error.
    Error(DispatcherError<S::Error, U>),
    /// Decoding, encoding or flushing failed; the dispatcher resolves with
    /// this error without flushing.
    FramedError(DispatcherError<S::Error, U>),
    /// `Message::Close` was received; the dispatcher flushes the write buffer
    /// and then resolves with `Ok(())`.
    FlushAndStop,
    /// The framed stream ended; the dispatcher resolves with `Ok(())`.
    Stopping,
}
92
93impl<S: Service, U: Encoder + Decoder> State<S, U> {
94    fn take_error(&mut self) -> DispatcherError<S::Error, U> {
95        match mem::replace(self, State::Processing) {
96            State::Error(err) => err,
97            _ => panic!(),
98        }
99    }
100
101    fn take_framed_error(&mut self) -> DispatcherError<S::Error, U> {
102        match mem::replace(self, State::Processing) {
103            State::FramedError(err) => err,
104            _ => panic!(),
105        }
106    }
107}
108
109impl<S, T, U> Dispatcher<S, T, U>
110where
111    S: Service<Request = Request<U>, Response = Response<U>>,
112    S::Error: 'static,
113    S::Future: 'static,
114    T: AsyncRead + AsyncWrite,
115    U: Decoder + Encoder,
116    <U as Encoder>::Item: 'static,
117    <U as Encoder>::Error: std::fmt::Debug,
118{
119    pub fn new<F: IntoService<S>>(framed: Framed<T, U>, service: F) -> Self {
120        let (tx, rx) = mpsc::channel();
121        Dispatcher {
122            framed,
123            rx,
124            tx,
125            service: service.into_service(),
126            state: State::Processing,
127        }
128    }
129
130    /// Construct new `Dispatcher` instance with customer `mpsc::Receiver`
131    pub fn with_rx<F: IntoService<S>>(
132        framed: Framed<T, U>,
133        service: F,
134        rx: mpsc::Receiver<Result<Message<<U as Encoder>::Item>, S::Error>>,
135    ) -> Self {
136        let tx = rx.sender();
137        Dispatcher {
138            framed,
139            rx,
140            tx,
141            service: service.into_service(),
142            state: State::Processing,
143        }
144    }
145
146    /// Get sink
147    pub fn get_sink(&self) -> mpsc::Sender<Result<Message<<U as Encoder>::Item>, S::Error>> {
148        self.tx.clone()
149    }
150
151    /// Get reference to a service wrapped by `Dispatcher` instance.
152    pub fn get_ref(&self) -> &S {
153        &self.service
154    }
155
156    /// Get mutable reference to a service wrapped by `Dispatcher` instance.
157    pub fn get_mut(&mut self) -> &mut S {
158        &mut self.service
159    }
160
161    /// Get reference to a framed instance wrapped by `Dispatcher`
162    /// instance.
163    pub fn get_framed(&self) -> &Framed<T, U> {
164        &self.framed
165    }
166
167    /// Get mutable reference to a framed instance wrapped by `Dispatcher` instance.
168    pub fn get_framed_mut(&mut self) -> &mut Framed<T, U> {
169        &mut self.framed
170    }
171
172    fn poll_read(&mut self, cx: &mut Context<'_>) -> bool
173    where
174        S: Service<Request = Request<U>, Response = Response<U>>,
175        S::Error: 'static,
176        S::Future: 'static,
177        T: AsyncRead + AsyncWrite,
178        U: Decoder + Encoder,
179        <U as Encoder>::Item: 'static,
180        <U as Encoder>::Error: std::fmt::Debug,
181    {
182        loop {
183            match self.service.poll_ready(cx) {
184                Poll::Ready(Ok(_)) => {
185                    let item = match self.framed.next_item(cx) {
186                        Poll::Ready(Some(Ok(el))) => el,
187                        Poll::Ready(Some(Err(err))) => {
188                            self.state = State::FramedError(DispatcherError::Decoder(err));
189                            return true;
190                        }
191                        Poll::Pending => return false,
192                        Poll::Ready(None) => {
193                            self.state = State::Stopping;
194                            return true;
195                        }
196                    };
197
198                    let tx = self.tx.clone();
199                    scrappy_rt::spawn(self.service.call(item).map(move |item| {
200                        let _ = tx.send(item.map(Message::Item));
201                    }));
202                }
203                Poll::Pending => return false,
204                Poll::Ready(Err(err)) => {
205                    self.state = State::Error(DispatcherError::Service(err));
206                    return true;
207                }
208            }
209        }
210    }
211
212    /// write to framed object
213    fn poll_write(&mut self, cx: &mut Context<'_>) -> bool
214    where
215        S: Service<Request = Request<U>, Response = Response<U>>,
216        S::Error: 'static,
217        S::Future: 'static,
218        T: AsyncRead + AsyncWrite,
219        U: Decoder + Encoder,
220        <U as Encoder>::Item: 'static,
221        <U as Encoder>::Error: std::fmt::Debug,
222    {
223        loop {
224            while !self.framed.is_write_buf_full() {
225                match Pin::new(&mut self.rx).poll_next(cx) {
226                    Poll::Ready(Some(Ok(Message::Item(msg)))) => {
227                        if let Err(err) = self.framed.write(msg) {
228                            self.state = State::FramedError(DispatcherError::Encoder(err));
229                            return true;
230                        }
231                    }
232                    Poll::Ready(Some(Ok(Message::Close))) => {
233                        self.state = State::FlushAndStop;
234                        return true;
235                    }
236                    Poll::Ready(Some(Err(err))) => {
237                        self.state = State::Error(DispatcherError::Service(err));
238                        return true;
239                    }
240                    Poll::Ready(None) | Poll::Pending => break,
241                }
242            }
243
244            if !self.framed.is_write_buf_empty() {
245                match self.framed.flush(cx) {
246                    Poll::Pending => break,
247                    Poll::Ready(Ok(_)) => (),
248                    Poll::Ready(Err(err)) => {
249                        debug!("Error sending data: {:?}", err);
250                        self.state = State::FramedError(DispatcherError::Encoder(err));
251                        return true;
252                    }
253                }
254            } else {
255                break;
256            }
257        }
258
259        false
260    }
261}
262
263impl<S, T, U> Future for Dispatcher<S, T, U>
264where
265    S: Service<Request = Request<U>, Response = Response<U>>,
266    S::Error: 'static,
267    S::Future: 'static,
268    T: AsyncRead + AsyncWrite,
269    U: Decoder + Encoder,
270    <U as Encoder>::Item: 'static,
271    <U as Encoder>::Error: std::fmt::Debug,
272    <U as Decoder>::Error: std::fmt::Debug,
273{
274    type Output = Result<(), DispatcherError<S::Error, U>>;
275
276    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
277        loop {
278            let this = self.as_mut().project();
279
280            return match this.state {
281                State::Processing => {
282                    if self.poll_read(cx) || self.poll_write(cx) {
283                        continue;
284                    } else {
285                        Poll::Pending
286                    }
287                }
288                State::Error(_) => {
289                    // flush write buffer
290                    if !self.framed.is_write_buf_empty() {
291                        if let Poll::Pending = self.framed.flush(cx) {
292                            return Poll::Pending;
293                        }
294                    }
295                    Poll::Ready(Err(self.state.take_error()))
296                }
297                State::FlushAndStop => {
298                    if !this.framed.is_write_buf_empty() {
299                        match this.framed.flush(cx) {
300                            Poll::Ready(Err(err)) => {
301                                debug!("Error sending data: {:?}", err);
302                                Poll::Ready(Ok(()))
303                            }
304                            Poll::Pending => Poll::Pending,
305                            Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
306                        }
307                    } else {
308                        Poll::Ready(Ok(()))
309                    }
310                }
311                State::FramedError(_) => Poll::Ready(Err(this.state.take_framed_error())),
312                State::Stopping => Poll::Ready(Ok(())),
313            };
314        }
315    }
316}