1#![allow(type_alias_bounds)]
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::{fmt, mem};
6
7use scrappy_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed};
8use scrappy_service::{IntoService, Service};
9use futures::{Future, FutureExt, Stream};
10use log::debug;
11
12use crate::mpsc;
13
/// Request type handed to the service: the item produced by the codec's `Decoder`.
type Request<U> = <U as Decoder>::Item;
/// Response type produced by the service: the item consumed by the codec's `Encoder`.
type Response<U> = <U as Encoder>::Item;
16
/// Error produced by a framed `Dispatcher`.
///
/// Distinguishes failures of the user service from failures of the codec
/// while encoding outgoing frames or decoding incoming ones.
pub enum DispatcherError<E, U: Encoder + Decoder> {
    /// Service error: returned by the service's `poll_ready`, or received as an
    /// `Err` through the dispatcher's response channel.
    Service(E),
    /// Encoder error, raised while writing a frame or flushing the transport.
    Encoder(<U as Encoder>::Error),
    /// Decoder error, raised while reading the next frame.
    Decoder(<U as Decoder>::Error),
}
23
24impl<E, U: Encoder + Decoder> From<E> for DispatcherError<E, U> {
25 fn from(err: E) -> Self {
26 DispatcherError::Service(err)
27 }
28}
29
30impl<E, U: Encoder + Decoder> fmt::Debug for DispatcherError<E, U>
31where
32 E: fmt::Debug,
33 <U as Encoder>::Error: fmt::Debug,
34 <U as Decoder>::Error: fmt::Debug,
35{
36 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
37 match *self {
38 DispatcherError::Service(ref e) => write!(fmt, "DispatcherError::Service({:?})", e),
39 DispatcherError::Encoder(ref e) => write!(fmt, "DispatcherError::Encoder({:?})", e),
40 DispatcherError::Decoder(ref e) => write!(fmt, "DispatcherError::Decoder({:?})", e),
41 }
42 }
43}
44
45impl<E, U: Encoder + Decoder> fmt::Display for DispatcherError<E, U>
46where
47 E: fmt::Display,
48 <U as Encoder>::Error: fmt::Debug,
49 <U as Decoder>::Error: fmt::Debug,
50{
51 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
52 match *self {
53 DispatcherError::Service(ref e) => write!(fmt, "{}", e),
54 DispatcherError::Encoder(ref e) => write!(fmt, "{:?}", e),
55 DispatcherError::Decoder(ref e) => write!(fmt, "{:?}", e),
56 }
57 }
58}
59
/// Message delivered to the dispatcher's write side through its response channel.
pub enum Message<T> {
    /// Frame to encode and write to the transport.
    Item(T),
    /// Request to flush any buffered output and then stop the dispatcher.
    Close,
}
64
/// Future that drives a `Service` over a `Framed` transport.
///
/// Decoded frames are passed to the service. Each call's response future is
/// spawned, and its result is sent back through an internal channel. The
/// dispatcher drains that channel, encoding and writing the responses to the
/// transport. Other code can push messages into the same channel through the
/// sender returned by `get_sink`.
#[pin_project::pin_project]
pub struct Dispatcher<S, T, U>
where
    S: Service<Request = Request<U>, Response = Response<U>>,
    S::Error: 'static,
    S::Future: 'static,
    T: AsyncRead + AsyncWrite,
    U: Encoder + Decoder,
    <U as Encoder>::Item: 'static,
    <U as Encoder>::Error: std::fmt::Debug,
{
    /// Service that handles each decoded request.
    service: S,
    /// Current position in the dispatcher's state machine.
    state: State<S, U>,
    /// Codec-wrapped transport, used both to read requests and to write responses.
    framed: Framed<T, U>,
    /// Receiving end of the response channel; drained by `poll_write`.
    rx: mpsc::Receiver<Result<Message<<U as Encoder>::Item>, S::Error>>,
    /// Sending end of the response channel; a clone is moved into every spawned service call.
    tx: mpsc::Sender<Result<Message<<U as Encoder>::Item>, S::Error>>,
}
84
/// Internal state machine of the dispatcher.
enum State<S: Service, U: Encoder + Decoder> {
    /// Normal operation: reading requests and writing responses.
    Processing,
    /// A service error occurred. The dispatcher attempts to flush pending output
    /// before returning the error.
    Error(DispatcherError<S::Error, U>),
    /// A codec error (decode, encode or flush) occurred; it is returned without flushing.
    FramedError(DispatcherError<S::Error, U>),
    /// A `Message::Close` was received. Pending output is flushed, then the
    /// dispatcher finishes with `Ok(())`.
    FlushAndStop,
    /// The transport's input ended; the dispatcher finishes with `Ok(())`.
    Stopping,
}
92
93impl<S: Service, U: Encoder + Decoder> State<S, U> {
94 fn take_error(&mut self) -> DispatcherError<S::Error, U> {
95 match mem::replace(self, State::Processing) {
96 State::Error(err) => err,
97 _ => panic!(),
98 }
99 }
100
101 fn take_framed_error(&mut self) -> DispatcherError<S::Error, U> {
102 match mem::replace(self, State::Processing) {
103 State::FramedError(err) => err,
104 _ => panic!(),
105 }
106 }
107}
108
109impl<S, T, U> Dispatcher<S, T, U>
110where
111 S: Service<Request = Request<U>, Response = Response<U>>,
112 S::Error: 'static,
113 S::Future: 'static,
114 T: AsyncRead + AsyncWrite,
115 U: Decoder + Encoder,
116 <U as Encoder>::Item: 'static,
117 <U as Encoder>::Error: std::fmt::Debug,
118{
119 pub fn new<F: IntoService<S>>(framed: Framed<T, U>, service: F) -> Self {
120 let (tx, rx) = mpsc::channel();
121 Dispatcher {
122 framed,
123 rx,
124 tx,
125 service: service.into_service(),
126 state: State::Processing,
127 }
128 }
129
130 pub fn with_rx<F: IntoService<S>>(
132 framed: Framed<T, U>,
133 service: F,
134 rx: mpsc::Receiver<Result<Message<<U as Encoder>::Item>, S::Error>>,
135 ) -> Self {
136 let tx = rx.sender();
137 Dispatcher {
138 framed,
139 rx,
140 tx,
141 service: service.into_service(),
142 state: State::Processing,
143 }
144 }
145
146 pub fn get_sink(&self) -> mpsc::Sender<Result<Message<<U as Encoder>::Item>, S::Error>> {
148 self.tx.clone()
149 }
150
151 pub fn get_ref(&self) -> &S {
153 &self.service
154 }
155
156 pub fn get_mut(&mut self) -> &mut S {
158 &mut self.service
159 }
160
161 pub fn get_framed(&self) -> &Framed<T, U> {
164 &self.framed
165 }
166
167 pub fn get_framed_mut(&mut self) -> &mut Framed<T, U> {
169 &mut self.framed
170 }
171
172 fn poll_read(&mut self, cx: &mut Context<'_>) -> bool
173 where
174 S: Service<Request = Request<U>, Response = Response<U>>,
175 S::Error: 'static,
176 S::Future: 'static,
177 T: AsyncRead + AsyncWrite,
178 U: Decoder + Encoder,
179 <U as Encoder>::Item: 'static,
180 <U as Encoder>::Error: std::fmt::Debug,
181 {
182 loop {
183 match self.service.poll_ready(cx) {
184 Poll::Ready(Ok(_)) => {
185 let item = match self.framed.next_item(cx) {
186 Poll::Ready(Some(Ok(el))) => el,
187 Poll::Ready(Some(Err(err))) => {
188 self.state = State::FramedError(DispatcherError::Decoder(err));
189 return true;
190 }
191 Poll::Pending => return false,
192 Poll::Ready(None) => {
193 self.state = State::Stopping;
194 return true;
195 }
196 };
197
198 let tx = self.tx.clone();
199 scrappy_rt::spawn(self.service.call(item).map(move |item| {
200 let _ = tx.send(item.map(Message::Item));
201 }));
202 }
203 Poll::Pending => return false,
204 Poll::Ready(Err(err)) => {
205 self.state = State::Error(DispatcherError::Service(err));
206 return true;
207 }
208 }
209 }
210 }
211
212 fn poll_write(&mut self, cx: &mut Context<'_>) -> bool
214 where
215 S: Service<Request = Request<U>, Response = Response<U>>,
216 S::Error: 'static,
217 S::Future: 'static,
218 T: AsyncRead + AsyncWrite,
219 U: Decoder + Encoder,
220 <U as Encoder>::Item: 'static,
221 <U as Encoder>::Error: std::fmt::Debug,
222 {
223 loop {
224 while !self.framed.is_write_buf_full() {
225 match Pin::new(&mut self.rx).poll_next(cx) {
226 Poll::Ready(Some(Ok(Message::Item(msg)))) => {
227 if let Err(err) = self.framed.write(msg) {
228 self.state = State::FramedError(DispatcherError::Encoder(err));
229 return true;
230 }
231 }
232 Poll::Ready(Some(Ok(Message::Close))) => {
233 self.state = State::FlushAndStop;
234 return true;
235 }
236 Poll::Ready(Some(Err(err))) => {
237 self.state = State::Error(DispatcherError::Service(err));
238 return true;
239 }
240 Poll::Ready(None) | Poll::Pending => break,
241 }
242 }
243
244 if !self.framed.is_write_buf_empty() {
245 match self.framed.flush(cx) {
246 Poll::Pending => break,
247 Poll::Ready(Ok(_)) => (),
248 Poll::Ready(Err(err)) => {
249 debug!("Error sending data: {:?}", err);
250 self.state = State::FramedError(DispatcherError::Encoder(err));
251 return true;
252 }
253 }
254 } else {
255 break;
256 }
257 }
258
259 false
260 }
261}
262
263impl<S, T, U> Future for Dispatcher<S, T, U>
264where
265 S: Service<Request = Request<U>, Response = Response<U>>,
266 S::Error: 'static,
267 S::Future: 'static,
268 T: AsyncRead + AsyncWrite,
269 U: Decoder + Encoder,
270 <U as Encoder>::Item: 'static,
271 <U as Encoder>::Error: std::fmt::Debug,
272 <U as Decoder>::Error: std::fmt::Debug,
273{
274 type Output = Result<(), DispatcherError<S::Error, U>>;
275
276 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
277 loop {
278 let this = self.as_mut().project();
279
280 return match this.state {
281 State::Processing => {
282 if self.poll_read(cx) || self.poll_write(cx) {
283 continue;
284 } else {
285 Poll::Pending
286 }
287 }
288 State::Error(_) => {
289 if !self.framed.is_write_buf_empty() {
291 if let Poll::Pending = self.framed.flush(cx) {
292 return Poll::Pending;
293 }
294 }
295 Poll::Ready(Err(self.state.take_error()))
296 }
297 State::FlushAndStop => {
298 if !this.framed.is_write_buf_empty() {
299 match this.framed.flush(cx) {
300 Poll::Ready(Err(err)) => {
301 debug!("Error sending data: {:?}", err);
302 Poll::Ready(Ok(()))
303 }
304 Poll::Pending => Poll::Pending,
305 Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
306 }
307 } else {
308 Poll::Ready(Ok(()))
309 }
310 }
311 State::FramedError(_) => Poll::Ready(Err(this.state.take_framed_error())),
312 State::Stopping => Poll::Ready(Ok(())),
313 };
314 }
315 }
316}