/// Hidden re-export hub: exposes the crate's `Error`, every item nameable from
/// the parent module, and the `serde` crate under a single path.
// NOTE(review): presumably consumed by macro-generated code — confirm.
#[doc(hidden)]
pub mod re_export {
    // Explicit import; takes precedence over any `Error` the glob below brings in.
    pub use crate::error::Error;
    pub use super::*;

    pub extern crate serde;
}
use std::{collections::HashMap, marker::PhantomData, pin::Pin, sync::Arc};
use futures::{
channel::{mpsc, oneshot},
future::{select, Either},
Future, Sink, SinkExt, Stream, StreamExt,
};
use crate::error::Error;
/// Describes an RPC protocol as a pair of associated payload types.
pub trait Rpc {
    /// Payload type of a request in this protocol.
    type Request;

    /// Payload type of a response in this protocol.
    type Response;
}
/// Server-side request handler used by [`serve`].
///
/// `R` is not referenced by the method itself; it presumably selects which
/// protocol the stub implements (NOTE(review): confirm). `I` is the incoming
/// request frame type and `O` the outgoing response frame type.
pub trait RpcServerStub<R, I, O>
where
    R: Rpc,
    I: RpcFrame,
    O: RpcFrame,
{
    /// Builds the future that handles one request frame.
    ///
    /// The future must reply, if at all, through `rsp_handler`. It is boxed and
    /// `Send` so the caller can hand it to `tokio::spawn`.
    fn make_response(
        self: Arc<Self>,
        req: I,
        rsp_handler: ResponseHandler<O>,
    ) -> Pin<Box<dyn Future<Output = ()> + Send>>;
}
/// Identifier used to correlate a response frame with its request frame.
///
/// Displayed as a bracketed, zero-padded, 16-digit uppercase hexadecimal
/// number, e.g. `RequestId(255)` prints as `[00000000000000FF]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequestId(pub u64);

impl std::fmt::Display for RequestId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let RequestId(raw) = *self;
        write!(f, "[{:016X}]", raw)
    }
}
/// Single-use handle through which a server stub delivers the response frame
/// for one request back to the loop in [`serve`], which writes it to the sink.
#[derive(Debug)]
pub struct ResponseHandler<F: RpcFrame>(mpsc::Sender<F>);

impl<F: RpcFrame> ResponseHandler<F> {
    /// Queues `rsp` for transmission, consuming the handler.
    ///
    /// If the receiving side has already shut down, `rsp` is silently
    /// discarded. `serve` does this when the request stream ends or writing a
    /// response fails. There is no longer a connection to deliver it on, and
    /// panicking here would only crash the handler's spawned task.
    pub async fn response_with(mut self, rsp: F) {
        // An error only means the receiver was dropped; the response is undeliverable.
        let _ = self.0.send(rsp).await;
    }
}
/// A transport frame carrying a [`RequestId`] together with a payload.
///
/// Frames must be `Send + 'static` because they are moved through channels and
/// into spawned tasks.
pub trait RpcFrame: 'static + Send {
    /// Payload type carried by the frame.
    type Data;

    /// Assembles a frame from an id and a payload.
    fn new(id: RequestId, data: Self::Data) -> Self;

    /// Reads the frame's id without consuming the frame.
    fn get_id(&self) -> RequestId;

    /// Consumes the frame and returns its payload.
    fn get_data(self) -> Self::Data;
}
/// The simplest frame: a plain `(id, payload)` tuple.
impl<T> RpcFrame for (RequestId, T)
where
    T: Send + 'static,
{
    type Data = T;

    fn new(id: RequestId, data: T) -> Self {
        (id, data)
    }

    fn get_id(&self) -> RequestId {
        let (id, _) = self;
        *id
    }

    fn get_data(self) -> T {
        let (_, data) = self;
        data
    }
}
/// Serves requests arriving on `recv` with `stub`, writing responses to `send`.
///
/// Each incoming frame is passed to [`RpcServerStub::make_response`] and the
/// resulting future is started with `tokio::spawn`. Requests are therefore
/// handled concurrently, and responses may be written in a different order than
/// the requests arrived. Handlers reply through the [`ResponseHandler`] they
/// receive. Replies travel over an internal bounded channel
/// (`mpsc::channel(128)`) and are written to `send` by this loop.
///
/// Returns `Ok(())` once `recv` is exhausted. At that point the internal
/// channel is dropped, so responses from handlers that have not finished are
/// not written.
///
/// # Errors
///
/// Returns the sink's error if writing a response to `send` fails.
pub async fn serve<R, S, I, O, T, U>(
    stub: impl Into<Arc<S>>,
    mut recv: T,
    mut send: U,
) -> Result<(), Error>
where
    R: Rpc,
    S: RpcServerStub<R, I, O>,
    I: RpcFrame,
    O: RpcFrame,
    T: Stream<Item = I> + Unpin,
    U: Sink<O, Error = Error> + Unpin,
{
    let stub: Arc<S> = stub.into();
    // `tx` lives for the whole loop (handlers get clones), so `rx` never
    // yields `None` here; only the end of `recv` stops the loop.
    let (tx, mut rx) = mpsc::channel::<O>(128);
    // Race the next request against the next finished response. `select`
    // returns the still-pending other future, which is carried into the next
    // iteration rather than dropped.
    let mut fut = select(recv.next(), rx.next());
    loop {
        match fut.await {
            // New request: dispatch it on its own task, keep the pending
            // `rx.next()` (`r`) and start reading the following request.
            Either::Left((Some(req), r)) => {
                let stub = stub.clone();
                tokio::spawn(stub.make_response(req, ResponseHandler(tx.clone())));
                fut = select(recv.next(), r);
            }
            // A handler produced a response: write it out, then resume the
            // pending `recv.next()` (`r`) alongside a fresh `rx.next()`.
            // No new requests are read while the write is in progress.
            Either::Right((Some(rsp), r)) => {
                send.send(rsp).await?;
                fut = select(r, rx.next());
            }
            // `recv` ended (the `rx` side cannot end while `tx` is held).
            _ => {
                break Ok(());
            }
        }
    }
}
/// Handle for issuing requests to an RPC server.
///
/// The handle does not own the transport. It only queues
/// `(reply channel, request frame)` pairs for the driver future created
/// alongside it by [`RpcClient::new_with_driver`]. `O` is the outgoing request
/// frame type and `I` is the incoming response frame type.
///
/// The `'a` parameter carries no data; it is held only by the `PhantomData`
/// field. It mirrors the lifetime bound of the driver future returned by
/// `new_with_driver`.
#[derive(Debug)]
pub struct RpcClient<'a, I: RpcFrame, O: RpcFrame>(
    // Queue into the driver: each entry pairs a request frame with the
    // oneshot sender on which its response (or send error) is delivered.
    mpsc::Sender<(oneshot::Sender<Result<I, Error>>, O)>,
    PhantomData<&'a ()>,
);
impl<I: RpcFrame, O: RpcFrame> RpcClient<'static, I, O> {
    /// Creates a client over `recv`/`send` and starts its driver in the background.
    ///
    /// `recv` yields response frames from the server and `send` accepts request
    /// frames for it. This is [`RpcClient::new_with_driver`] with the driver
    /// handed to `tokio::spawn`, so it must be called from within a Tokio
    /// runtime context. The transport halves must therefore be
    /// `Send + 'static`.
    pub fn new<T, U>(recv: T, send: U) -> Self
    where
        T: Stream<Item = I> + Unpin + Send + 'static,
        U: Sink<O, Error = Error> + Unpin + Send + 'static,
    {
        let (driver, client) = Self::new_with_driver(recv, send);
        tokio::spawn(driver);
        client
    }
}
impl<'a, I: RpcFrame, O: RpcFrame> RpcClient<'a, I, O> {
    /// Creates a client together with the driver future that services it.
    ///
    /// `recv` yields response frames from the server and `send` accepts request
    /// frames destined for it. The driver does two jobs:
    /// - it writes requests queued by every clone of the client to `send`;
    /// - it routes each frame read from `recv` to the request with the same
    ///   [`RequestId`].
    ///
    /// The driver must be polled (awaited or spawned) for any request to make
    /// progress. It completes once every client handle has been dropped or
    /// `recv` ends.
    ///
    /// A caller that stops waiting, by dropping its `make_request` future, no
    /// longer disturbs the driver: the late response (or send error) is
    /// discarded.
    ///
    /// # Panics
    ///
    /// The driver panics if two requests in flight carry the same id.
    pub fn new_with_driver<T, U>(recv: T, send: U) -> (impl Future<Output = ()> + 'a, Self)
    where
        T: Stream<Item = I> + Unpin + 'a,
        U: Sink<O, Error = Error> + Unpin + 'a,
    {
        // Event loop behind `RpcClient`: forwards queued requests to `send` and
        // dispatches frames from `recv` to the caller waiting on the same id.
        async fn driver<'a, I, O, T, U>(
            mut rx: mpsc::Receiver<(oneshot::Sender<Result<I, Error>>, O)>,
            mut recv: T,
            mut send: U,
        ) where
            I: RpcFrame,
            O: RpcFrame,
            T: Stream<Item = I> + Unpin + 'a,
            U: Sink<O, Error = Error> + Unpin + 'a,
        {
            // `select` hands back the still-pending other future, which is
            // re-armed on the next iteration so no in-progress `next()` is lost.
            let mut fut = select(rx.next(), recv.next());
            // Requests written to `send` that still await a response, keyed by id.
            // NOTE(review): entries for abandoned requests stay here until the
            // server answers them.
            let mut req_map = HashMap::with_capacity(128);
            loop {
                match fut.await {
                    Either::Left((Some((callback, req)), r)) => {
                        let id = req.get_id();
                        if let Err(e) = send.send(req).await {
                            // The requester may already have given up; that is
                            // not the driver's problem.
                            let _ = callback.send(Err(e));
                        } else if req_map.insert(id, callback).is_some() {
                            panic!("request id is not unique")
                        }
                        fut = select(rx.next(), r);
                    }
                    Either::Right((Some(rsp), r)) => {
                        let id = rsp.get_id();
                        if let Some(callback) = req_map.remove(&id) {
                            // A failed send means the requester was cancelled
                            // (e.g. by a timeout). Drop the late response instead
                            // of taking down the driver shared by every client.
                            let _ = callback.send(Ok(rsp));
                        } else {
                            warn!("Server responded to nonexistent request: {}", id);
                        }
                        fut = select(r, recv.next());
                    }
                    // Either every client handle is gone (`rx` ended) or the
                    // server closed its response stream (`recv` ended).
                    _ => {
                        break;
                    }
                }
            }
        }
        let (tx, rx) = mpsc::channel::<(oneshot::Sender<Result<I, Error>>, O)>(128);
        (driver(rx, recv, send), Self(tx, PhantomData))
    }

    /// Sends `req` through the driver and waits for the frame with the same id.
    ///
    /// # Errors
    ///
    /// Returns the sink's error if writing the request to the transport failed.
    ///
    /// # Panics
    ///
    /// Panics if the driver has stopped before the response is delivered. This
    /// happens when the driver is dropped, when its response stream ends, or
    /// when it panics.
    pub async fn make_request(&mut self, req: O) -> Result<I, Error> {
        let (tx, rx) = oneshot::channel();
        self.0
            .send((tx, req))
            .await
            .expect("driver closed unexpectedly");
        rx.await.expect("driver closed unexpectedly")
    }
}
impl<'a, I: RpcFrame, O: RpcFrame> Clone for RpcClient<'a, I, O> {
#[inline]
fn clone(&self) -> Self {
Self(self.0.clone(), PhantomData)
}
}