1use core::fmt::{Debug, Display};
2use core::marker::PhantomData;
3
4use std::collections::{hash_map, HashMap};
5use std::sync::Arc;
6
7use anyhow::bail;
8use futures::{Stream, StreamExt as _};
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
10use tokio::sync::{mpsc, Mutex};
11use tokio_stream::wrappers::ReceiverStream;
12use tracing::{instrument, trace};
13use wasm_tokio::AsyncReadCore as _;
14
15use crate::frame::conn::Accept;
16use crate::frame::{Conn, ConnHandler, Incoming, Outgoing};
17use crate::Serve;
18
19pub struct Server<C, I, O, H = ()> {
21 handlers: Mutex<HashMap<String, HashMap<String, mpsc::Sender<(C, I, O)>>>>,
22 conn_handler: PhantomData<H>,
23}
24
25impl<C, I, O, H> Server<C, I, O, H> {
26 pub fn new() -> Self {
28 Self {
29 handlers: Mutex::default(),
30 conn_handler: PhantomData,
31 }
32 }
33}
34
35impl<C, I, O> Default for Server<C, I, O> {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41pub enum AcceptError<C, I, O> {
43 IO(std::io::Error),
45 UnsupportedVersion(u8),
47 UnhandledFunction {
49 instance: String,
51 name: String,
53 },
54 Send(mpsc::error::SendError<(C, I, O)>),
56}
57
58impl<C, I, O> Debug for AcceptError<C, I, O> {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 match self {
61 AcceptError::IO(err) => Debug::fmt(err, f),
62 AcceptError::UnsupportedVersion(v) => write!(f, "unsupported version byte: {v}"),
63 AcceptError::UnhandledFunction { instance, name } => {
64 write!(f, "`{instance}#{name}` does not have a handler registered")
65 }
66 AcceptError::Send(err) => Debug::fmt(err, f),
67 }
68 }
69}
70
71impl<C, I, O> Display for AcceptError<C, I, O> {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 match self {
74 AcceptError::IO(err) => Display::fmt(err, f),
75 AcceptError::UnsupportedVersion(v) => write!(f, "unsupported version byte: {v}"),
76 AcceptError::UnhandledFunction { instance, name } => {
77 write!(f, "`{instance}#{name}` does not have a handler registered")
78 }
79 AcceptError::Send(err) => Display::fmt(err, f),
80 }
81 }
82}
83
84impl<C, I, O> std::error::Error for AcceptError<C, I, O> {}
85
86impl<C, I, O, H> Server<C, I, O, H>
87where
88 I: AsyncRead + Unpin,
89 H: ConnHandler<I, O>,
90{
91 #[instrument(level = "trace", skip_all, ret(level = "trace"))]
97 pub async fn accept(
98 &self,
99 listener: impl Accept<Context = C, Incoming = I, Outgoing = O>,
100 ) -> Result<(), AcceptError<C, I, O>> {
101 let (cx, tx, mut rx) = listener.accept().await.map_err(AcceptError::IO)?;
102 let mut instance = String::default();
103 let mut name = String::default();
104 match rx.read_u8().await.map_err(AcceptError::IO)? {
105 0x00 => {
106 rx.read_core_name(&mut instance)
107 .await
108 .map_err(AcceptError::IO)?;
109 rx.read_core_name(&mut name)
110 .await
111 .map_err(AcceptError::IO)?;
112 }
113 v => return Err(AcceptError::UnsupportedVersion(v)),
114 }
115 let h = self.handlers.lock().await;
116 let h = h
117 .get(&instance)
118 .and_then(|h| h.get(&name))
119 .ok_or_else(|| AcceptError::UnhandledFunction { instance, name })?;
120 h.send((cx, rx, tx)).await.map_err(AcceptError::Send)?;
121 Ok(())
122 }
123}
124
125#[instrument(level = "trace", skip(srv, paths))]
126async fn serve<C, I, O, H>(
127 srv: &Server<C, I, O, H>,
128 instance: &str,
129 func: &str,
130 paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
131) -> anyhow::Result<impl Stream<Item = anyhow::Result<(C, Outgoing, Incoming)>> + 'static>
132where
133 C: Send + Sync + 'static,
134 I: AsyncRead + Send + Sync + Unpin + 'static,
135 O: AsyncWrite + Send + Sync + Unpin + 'static,
136 H: ConnHandler<I, O>,
137{
138 let (tx, rx) = mpsc::channel(1024);
139 let mut handlers = srv.handlers.lock().await;
140 match handlers
141 .entry(instance.to_string())
142 .or_default()
143 .entry(func.to_string())
144 {
145 hash_map::Entry::Occupied(_) => {
146 bail!("handler for `{instance}#{func}` already exists")
147 }
148 hash_map::Entry::Vacant(entry) => {
149 entry.insert(tx);
150 }
151 }
152 let paths = paths.into();
153 Ok(ReceiverStream::new(rx).map(move |(cx, rx, tx)| {
154 trace!("received invocation");
155 let Conn { tx, rx } = Conn::new::<H, _, _, _>(rx, tx, paths.iter());
156 Ok((cx, tx, rx))
157 }))
158}
159
160impl<C, I, O, H> Serve for Server<C, I, O, H>
161where
162 C: Send + Sync + 'static,
163 I: AsyncRead + Send + Sync + Unpin + 'static,
164 O: AsyncWrite + Send + Sync + Unpin + 'static,
165 H: ConnHandler<I, O> + Send + Sync,
166{
167 type Context = C;
168 type Outgoing = Outgoing;
169 type Incoming = Incoming;
170
171 async fn serve(
172 &self,
173 instance: &str,
174 func: &str,
175 paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
176 ) -> anyhow::Result<
177 impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>> + 'static,
178 > {
179 serve(self, instance, func, paths).await
180 }
181}
182
183impl<C, I, O, H> Serve for &Server<C, I, O, H>
184where
185 C: Send + Sync + 'static,
186 I: AsyncRead + Send + Sync + Unpin + 'static,
187 O: AsyncWrite + Send + Sync + Unpin + 'static,
188 H: ConnHandler<I, O> + Send + Sync,
189{
190 type Context = C;
191 type Outgoing = Outgoing;
192 type Incoming = Incoming;
193
194 async fn serve(
195 &self,
196 instance: &str,
197 func: &str,
198 paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
199 ) -> anyhow::Result<
200 impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>> + 'static,
201 > {
202 serve(self, instance, func, paths).await
203 }
204}