tiny_rpc/
rpc.rs

1#[doc(hidden)]
2pub mod re_export {
3    pub extern crate serde;
4
5    pub use std::{
6        clone::Clone,
7        convert::Into,
8        marker::{PhantomData, Send, Sync},
9        stringify,
10        sync::Arc,
11        unreachable,
12    };
13
14    pub use futures::{
15        future::{BoxFuture, FutureExt},
16        stream::BoxStream,
17    };
18    pub use serde_derive::{Deserialize, Serialize};
19    pub use tracing::Instrument;
20
21    pub use crate::{
22        error::{Error, ProtocolError, Result},
23        io::{IdGenerator, RpcFrame, Transport},
24        rpc::{Client, ClientDriverHandle, Server},
25    };
26
27    #[derive(Serialize, Deserialize)]
28    pub enum Never {}
29}
30
31use std::collections::HashMap;
32
33use futures::{
34    channel::{mpsc, oneshot},
35    future::{select, BoxFuture, Either},
36    stream::BoxStream,
37    FutureExt, SinkExt, StreamExt,
38};
39use tracing::Instrument;
40
41use crate::{
42    error::{Error, ProtocolError, Result},
43    io::{Id, RpcFrame, Transport},
44};
45
46pub trait Server: Clone + Send + Sync + 'static {
47    fn make_response(self, req: RpcFrame) -> BoxFuture<'static, Result<RpcFrame>>;
48    fn serve(self, transport: Transport) -> BoxStream<'static, BoxFuture<'static, ()>> {
49        trace!("server start");
50
51        let (mut recv, mut send) = transport.split();
52        let (spawner_tx, spawner_rx) = mpsc::unbounded::<BoxFuture<'static, ()>>();
53        let spawner = spawner_tx.clone();
54        let serve_fut = async move {
55            let (tx, mut rx) = mpsc::unbounded::<RpcFrame>();
56            let mut fut = select(recv.next(), rx.next());
57            loop {
58                match fut.await {
59                    Either::Left((Some(req_frame), r)) => {
60                        let id = req_frame
61                            .as_ref()
62                            .map(|f| f.id().unwrap_or(Id::NULL))
63                            .unwrap_or(Id::NULL);
64                        trace!("accept request {}", id);
65
66                        let span = info_span!("server", %id);
67                        let tx = tx.clone();
68                        let this = self.clone();
69                        let rsp_fut = async move {
70                            let req_frame = req_frame?;
71                            let rsp_frame = this.make_response(req_frame).await?; // TODO send server fail
72                            tx.unbounded_send(rsp_frame).map_err(|_| Error::Driver)?;
73                            Ok::<_, Error>(())
74                        }
75                        .map(|r: Result<()>| log_error("responser", r))
76                        .instrument(span)
77                        .boxed();
78                        spawner
79                            .unbounded_send(rsp_fut)
80                            .map_err(|_| Error::Spawner)?;
81                        fut = select(recv.next(), r);
82                    }
83                    Either::Right((Some(rsp_frame), r)) => {
84                        let id = rsp_frame.id().expect("misformed outgoing frame");
85                        trace!("finish request {}", id);
86
87                        send.send(rsp_frame).await?;
88                        fut = select(r, rx.next());
89                    }
90                    _ => {
91                        // None is returned from client or remote. Stop driver.
92                        trace!("server stop");
93                        break Ok(());
94                    }
95                }
96            }
97        };
98        spawner_tx
99            .unbounded_send(Box::pin(
100                serve_fut
101                    .map(|r: Result<()>| log_error("server driver", r))
102                    .boxed(),
103            ))
104            .expect("infallible: unbounded mpsc");
105        Box::pin(spawner_rx)
106    }
107}
108
109#[derive(Clone)]
110pub struct ClientDriverHandle {
111    sender: mpsc::UnboundedSender<(RpcFrame, oneshot::Sender<Result<RpcFrame>>)>,
112}
113
114pub trait Client: Sized {
115    fn from_handle(handle: ClientDriverHandle) -> Self;
116    fn handle(&self) -> &ClientDriverHandle;
117    fn new(transport: Transport) -> (Self, BoxFuture<'static, ()>) {
118        let (mut recv, mut send) = transport.split();
119        let (tx, mut rx) = mpsc::unbounded::<(RpcFrame, oneshot::Sender<Result<RpcFrame>>)>();
120        let dispatcher_fut = async move {
121            trace!("dispatcher start");
122
123            let mut fut = select(recv.next(), rx.next());
124            let mut req_map = HashMap::<Id, oneshot::Sender<Result<RpcFrame>>>::new();
125            loop {
126                match fut.await {
127                    Either::Left((Some(rsp_frame), r)) => {
128                        // recv response from server
129
130                        let rsp_frame = rsp_frame?;
131                        let id = rsp_frame.id()?;
132                        trace!("finish request {}", id);
133
134                        if let Some(handler) = req_map.remove(&id) {
135                            if handler.send(Ok(rsp_frame)).is_err() {
136                                debug!("respond to canceled request: {}", id);
137                            }
138                        } else {
139                            break Err(ProtocolError::MartianResponse.into());
140                        }
141                        fut = select(recv.next(), r);
142                    }
143                    Either::Right((Some((req_frame, rsp_handler)), r)) => {
144                        let id = req_frame.id().expect("misformed outgoing frame");
145                        trace!("begin request {}", id);
146
147                        if req_map.insert(id, rsp_handler).is_some() {
148                            panic!("id duplication: {}", id);
149                        }
150                        send.send(req_frame).await?;
151                        fut = select(r, rx.next());
152                    }
153                    _ => {
154                        // None is returned from client or remote. Stop driver.
155                        trace!("dispatcher stop");
156                        break Ok(());
157                    }
158                }
159            }
160        };
161        let handle = ClientDriverHandle { sender: tx };
162        let client = Self::from_handle(handle);
163        (
164            client,
165            dispatcher_fut
166                .map(|r: Result<()>| log_error("client driver", r))
167                .boxed(),
168        )
169    }
170
171    fn make_request(&self, req: RpcFrame) -> BoxFuture<'static, Result<RpcFrame>> {
172        let sender = self.handle().sender.clone();
173        let fut = async move {
174            let (handler_tx, handler_rx) = oneshot::channel();
175            sender
176                .unbounded_send((req, handler_tx))
177                .map_err(|_| Error::Driver)?;
178            handler_rx.await.map_err(|_| Error::Driver)?
179        };
180        fut.boxed()
181    }
182}
183
184fn log_error(context: &'static str, r: Result<()>) {
185    if let Err(e) = r {
186        error!("{}: {}", context, e);
187    }
188}