wrpc_transport/frame/conn/
server.rs

1use core::fmt::{Debug, Display};
2use core::marker::PhantomData;
3
4use std::collections::{hash_map, HashMap};
5use std::sync::Arc;
6
7use anyhow::bail;
8use futures::{Stream, StreamExt as _};
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
10use tokio::sync::{mpsc, Mutex};
11use tokio_stream::wrappers::ReceiverStream;
12use tracing::{instrument, trace};
13use wasm_tokio::AsyncReadCore as _;
14
15use crate::frame::conn::Accept;
16use crate::frame::{Conn, ConnHandler, Incoming, Outgoing};
17use crate::Serve;
18
19/// wRPC server for framed transports
20pub struct Server<C, I, O, H = ()> {
21    handlers: Mutex<HashMap<String, HashMap<String, mpsc::Sender<(C, I, O)>>>>,
22    conn_handler: PhantomData<H>,
23}
24
25impl<C, I, O, H> Server<C, I, O, H> {
26    /// Constructs a new [Server]
27    pub fn new() -> Self {
28        Self {
29            handlers: Mutex::default(),
30            conn_handler: PhantomData,
31        }
32    }
33}
34
35impl<C, I, O> Default for Server<C, I, O> {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
/// Error returned by [`Server::accept`]
///
/// `C`, `I` and `O` are the context, incoming and outgoing types of the accepted
/// connection. They only appear in the [`AcceptError::Send`] variant, which wraps the
/// channel send error for the `(C, I, O)` tuple that could not be delivered.
pub enum AcceptError<C, I, O> {
    /// I/O error, either from accepting the connection or from reading the invocation header
    IO(std::io::Error),
    /// Protocol version is not supported (the first header byte was not `0x00`)
    UnsupportedVersion(u8),
    /// Function was not handled: no handler is registered for `instance#name`
    UnhandledFunction {
        /// Instance name read from the invocation header
        instance: String,
        /// Function name read from the invocation header
        name: String,
    },
    /// Message sending failed: the connection could not be forwarded to the registered handler
    Send(mpsc::error::SendError<(C, I, O)>),
}
57
58impl<C, I, O> Debug for AcceptError<C, I, O> {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        match self {
61            AcceptError::IO(err) => Debug::fmt(err, f),
62            AcceptError::UnsupportedVersion(v) => write!(f, "unsupported version byte: {v}"),
63            AcceptError::UnhandledFunction { instance, name } => {
64                write!(f, "`{instance}#{name}` does not have a handler registered")
65            }
66            AcceptError::Send(err) => Debug::fmt(err, f),
67        }
68    }
69}
70
71impl<C, I, O> Display for AcceptError<C, I, O> {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        match self {
74            AcceptError::IO(err) => Display::fmt(err, f),
75            AcceptError::UnsupportedVersion(v) => write!(f, "unsupported version byte: {v}"),
76            AcceptError::UnhandledFunction { instance, name } => {
77                write!(f, "`{instance}#{name}` does not have a handler registered")
78            }
79            AcceptError::Send(err) => Display::fmt(err, f),
80        }
81    }
82}
83
84impl<C, I, O> std::error::Error for AcceptError<C, I, O> {}
85
86impl<C, I, O, H> Server<C, I, O, H>
87where
88    I: AsyncRead + Unpin,
89    H: ConnHandler<I, O>,
90{
91    /// Accept a connection on an [Accept].
92    ///
93    /// # Errors
94    ///
95    /// Returns an error if accepting the connection has failed
96    #[instrument(level = "trace", skip_all, ret(level = "trace"))]
97    pub async fn accept(
98        &self,
99        listener: impl Accept<Context = C, Incoming = I, Outgoing = O>,
100    ) -> Result<(), AcceptError<C, I, O>> {
101        let (cx, tx, mut rx) = listener.accept().await.map_err(AcceptError::IO)?;
102        let mut instance = String::default();
103        let mut name = String::default();
104        match rx.read_u8().await.map_err(AcceptError::IO)? {
105            0x00 => {
106                rx.read_core_name(&mut instance)
107                    .await
108                    .map_err(AcceptError::IO)?;
109                rx.read_core_name(&mut name)
110                    .await
111                    .map_err(AcceptError::IO)?;
112            }
113            v => return Err(AcceptError::UnsupportedVersion(v)),
114        }
115        let h = self.handlers.lock().await;
116        let h = h
117            .get(&instance)
118            .and_then(|h| h.get(&name))
119            .ok_or_else(|| AcceptError::UnhandledFunction { instance, name })?;
120        h.send((cx, rx, tx)).await.map_err(AcceptError::Send)?;
121        Ok(())
122    }
123}
124
125#[instrument(level = "trace", skip(srv, paths))]
126async fn serve<C, I, O, H>(
127    srv: &Server<C, I, O, H>,
128    instance: &str,
129    func: &str,
130    paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
131) -> anyhow::Result<impl Stream<Item = anyhow::Result<(C, Outgoing, Incoming)>> + 'static>
132where
133    C: Send + Sync + 'static,
134    I: AsyncRead + Send + Sync + Unpin + 'static,
135    O: AsyncWrite + Send + Sync + Unpin + 'static,
136    H: ConnHandler<I, O>,
137{
138    let (tx, rx) = mpsc::channel(1024);
139    let mut handlers = srv.handlers.lock().await;
140    match handlers
141        .entry(instance.to_string())
142        .or_default()
143        .entry(func.to_string())
144    {
145        hash_map::Entry::Occupied(_) => {
146            bail!("handler for `{instance}#{func}` already exists")
147        }
148        hash_map::Entry::Vacant(entry) => {
149            entry.insert(tx);
150        }
151    }
152    let paths = paths.into();
153    Ok(ReceiverStream::new(rx).map(move |(cx, rx, tx)| {
154        trace!("received invocation");
155        let Conn { tx, rx } = Conn::new::<H, _, _, _>(rx, tx, paths.iter());
156        Ok((cx, tx, rx))
157    }))
158}
159
160impl<C, I, O, H> Serve for Server<C, I, O, H>
161where
162    C: Send + Sync + 'static,
163    I: AsyncRead + Send + Sync + Unpin + 'static,
164    O: AsyncWrite + Send + Sync + Unpin + 'static,
165    H: ConnHandler<I, O> + Send + Sync,
166{
167    type Context = C;
168    type Outgoing = Outgoing;
169    type Incoming = Incoming;
170
171    async fn serve(
172        &self,
173        instance: &str,
174        func: &str,
175        paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
176    ) -> anyhow::Result<
177        impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>> + 'static,
178    > {
179        serve(self, instance, func, paths).await
180    }
181}
182
183impl<C, I, O, H> Serve for &Server<C, I, O, H>
184where
185    C: Send + Sync + 'static,
186    I: AsyncRead + Send + Sync + Unpin + 'static,
187    O: AsyncWrite + Send + Sync + Unpin + 'static,
188    H: ConnHandler<I, O> + Send + Sync,
189{
190    type Context = C;
191    type Outgoing = Outgoing;
192    type Incoming = Incoming;
193
194    async fn serve(
195        &self,
196        instance: &str,
197        func: &str,
198        paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
199    ) -> anyhow::Result<
200        impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>> + 'static,
201    > {
202        serve(self, instance, func, paths).await
203    }
204}