wasmrs_host/
host.rs

1use std::sync::Arc;
2
3use futures_core::Stream;
4use futures_util::{FutureExt, StreamExt};
5use parking_lot::Mutex;
6use wasmrs::{
7  BoxFlux, BoxMono, Frame, Handlers, IncomingMono, IncomingStream, Metadata, OperationHandler, OutgoingMono,
8  OutgoingStream, Payload, RSocket, RawPayload, WasmSocket,
9};
10use wasmrs_frames::PayloadError;
11use wasmrs_runtime::{spawn, ConditionallySend, UnboundedReceiver};
12use wasmrs_rx::*;
13
14use crate::context::{EngineProvider, SharedContext};
15
16type Result<T> = std::result::Result<T, crate::errors::Error>;
17
18#[must_use]
19#[allow(missing_debug_implementations)]
20/// A wasmRS native Host.
21pub struct Host {
22  engine: Box<dyn EngineProvider + Send + Sync>,
23  mtu: usize,
24  handlers: Arc<Mutex<Handlers>>,
25}
26
27impl Host {
28  /// Create a new [Host] with an [EngineProvider] implementation.
29  pub async fn new<E: EngineProvider + Send + Sync + 'static>(engine: E) -> Result<Self> {
30    let host = Host {
31      engine: Box::new(engine),
32      mtu: 256,
33      handlers: Default::default(),
34    };
35
36    Ok(host)
37  }
38
39  /// Create a new [CallContext], a way to bucket calls together with the same memory and configuration.
40  pub async fn new_context(&self, host_buffer_size: u32, guest_buffer_size: u32) -> Result<CallContext> {
41    let mut socket = WasmSocket::new(
42      HostServer {
43        handlers: self.handlers.clone(),
44      },
45      wasmrs::SocketSide::Host,
46    );
47    let rx = socket.take_rx().unwrap();
48    let socket = Arc::new(socket);
49
50    let context = self.engine.new_context(socket.clone()).await?;
51
52    context.init(host_buffer_size, guest_buffer_size).await?;
53
54    CallContext::new(self.mtu, socket, context, rx)
55  }
56
57  /// Register a Request/Response style handler on the host.
58  pub fn register_request_response(
59    &self,
60    ns: impl AsRef<str>,
61    op: impl AsRef<str>,
62    handler: OperationHandler<IncomingMono, OutgoingMono>,
63  ) -> usize {
64    self.handlers.lock().register_request_response(ns, op, handler)
65  }
66
67  /// Register a Request/Response style handler on the host.
68  pub fn register_request_stream(
69    &self,
70    ns: impl AsRef<str>,
71    op: impl AsRef<str>,
72    handler: OperationHandler<IncomingMono, OutgoingStream>,
73  ) -> usize {
74    self.handlers.lock().register_request_stream(ns, op, handler)
75  }
76
77  /// Register a Request/Response style handler on the host.
78  pub fn register_request_channel(
79    &self,
80    ns: impl AsRef<str>,
81    op: impl AsRef<str>,
82    handler: OperationHandler<IncomingStream, OutgoingStream>,
83  ) -> usize {
84    self.handlers.lock().register_request_channel(ns, op, handler)
85  }
86
87  /// Register a Request/Response style handler on the host.
88  pub fn register_fire_and_forget(
89    &self,
90    ns: impl AsRef<str>,
91    op: impl AsRef<str>,
92    handler: OperationHandler<IncomingMono, ()>,
93  ) -> usize {
94    self.handlers.lock().register_fire_and_forget(ns, op, handler)
95  }
96}
97
98fn spawn_writer(mut rx: UnboundedReceiver<Frame>, context: SharedContext) -> tokio::task::JoinHandle<()> {
99  spawn("host:spawn_writer", async move {
100    while let Some(frame) = rx.recv().await {
101      let _ = context.write_frame(frame).await;
102    }
103  })
104}
105
106#[allow(missing_debug_implementations)]
107#[derive(Clone)]
108/// A wasmRS native Host.
109pub struct HostServer {
110  handlers: Arc<Mutex<Handlers>>,
111}
112
113fn parse_payload(req: RawPayload) -> Payload {
114  if let Some(mut md_bytes) = req.metadata {
115    let md = Metadata::decode(&mut md_bytes).unwrap();
116    Payload::new(md, req.data.unwrap())
117  } else {
118    panic!("No metadata found in payload.");
119  }
120}
121
122impl RSocket for HostServer {
123  fn fire_and_forget(&self, req: RawPayload) -> BoxMono<(), PayloadError> {
124    let payload = parse_payload(req);
125    let handler = self
126      .handlers
127      .lock()
128      .get_fnf_handler(payload.metadata.index.unwrap())
129      .unwrap();
130    handler(futures_util::future::ready(Ok(payload)).boxed()).unwrap();
131    futures_util::future::ready(Ok(())).boxed()
132  }
133
134  fn request_response(&self, req: RawPayload) -> BoxMono<RawPayload, PayloadError> {
135    let payload = parse_payload(req);
136    let handler = self
137      .handlers
138      .lock()
139      .get_request_response_handler(payload.metadata.index.unwrap())
140      .unwrap();
141
142    handler(futures_util::future::ready(Ok(payload)).boxed()).unwrap()
143  }
144
145  fn request_stream(&self, req: RawPayload) -> BoxFlux<RawPayload, PayloadError> {
146    let payload = parse_payload(req);
147    let handler = self
148      .handlers
149      .lock()
150      .get_request_stream_handler(payload.metadata.index.unwrap())
151      .unwrap();
152    handler(futures_util::future::ready(Ok(payload)).boxed()).unwrap()
153  }
154
155  fn request_channel<
156    T: Stream<Item = std::result::Result<RawPayload, PayloadError>> + ConditionallySend + Unpin + 'static,
157  >(
158    &self,
159    mut reqs: T,
160  ) -> BoxFlux<RawPayload, PayloadError> {
161    let (out_tx, out_rx) = FluxChannel::<RawPayload, PayloadError>::new_parts();
162    let handlers = self.handlers.clone();
163    tokio::spawn(async move {
164      let (inner_tx, inner_rx) = FluxChannel::new_parts();
165      let first = match reqs.next().await {
166        None => {
167          let _ = out_tx.send_result(Err(PayloadError::application_error("No first payload.", None)));
168          return;
169        }
170        Some(Err(e)) => {
171          let _ = out_tx.send_result(Err(e));
172          return;
173        }
174        Some(Ok(p)) => p,
175      };
176
177      let payload = parse_payload(first);
178      let handler = handlers
179        .lock()
180        .get_request_channel_handler(payload.metadata.index.unwrap())
181        .unwrap();
182      let _ = inner_tx.send(payload);
183      let mut out = handler(inner_rx.boxed()).unwrap();
184      tokio::spawn(async move {
185        while let Some(p) = out.next().await {
186          let _ = out_tx.send_result(p);
187        }
188        out_tx.complete();
189      });
190      tokio::spawn(async move {
191        while let Some(p) = reqs.next().await {
192          let _ = inner_tx.send_result(p.map(parse_payload));
193        }
194        inner_tx.complete();
195      });
196    });
197    out_rx.boxed()
198  }
199}
200
201/// A [CallContext] is a way to bucket calls together with the same memory and configuration.
202pub struct CallContext {
203  socket: Arc<WasmSocket<HostServer>>,
204  context: SharedContext,
205  writer: tokio::task::JoinHandle<()>,
206}
207
208impl std::fmt::Debug for CallContext {
209  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210    f.debug_struct("WasmRsCallContext")
211      .field("state", &self.socket)
212      .finish()
213  }
214}
215
216impl CallContext {
217  fn new(
218    _mtu: usize,
219    socket: Arc<WasmSocket<HostServer>>,
220    context: SharedContext,
221    rx: UnboundedReceiver<Frame>,
222  ) -> Result<Self> {
223    let writer = spawn_writer(rx, context.clone());
224
225    Ok(Self {
226      socket,
227      context,
228      writer,
229    })
230  }
231
232  /// Get the import id for a given namespace and operation.
233  pub fn get_import(&self, namespace: &str, operation: &str) -> Option<u32> {
234    self.context.get_import(namespace, operation)
235  }
236
237  /// Get the export id for a given namespace and operation.
238  pub fn get_export(&self, namespace: &str, operation: &str) -> Option<u32> {
239    self.context.get_export(namespace, operation)
240  }
241
242  /// Get a list of the exports for this context.
243  #[must_use]
244  pub fn get_exports(&self) -> Vec<String> {
245    self.context.get_operation_list().get_exports()
246  }
247
248  /// A utility function to dump the operation list.
249  pub fn dump_operations(&self) {
250    println!("{:#?}", self.context.get_operation_list());
251  }
252
253  /// Query if the frame writer is still running.
254  pub fn is_alive(&self) -> bool {
255    !self.writer.is_finished()
256  }
257}
258
259impl RSocket for CallContext {
260  fn fire_and_forget(&self, payload: RawPayload) -> BoxMono<(), PayloadError> {
261    self.socket.fire_and_forget(payload)
262  }
263
264  fn request_response(&self, payload: RawPayload) -> BoxMono<RawPayload, PayloadError> {
265    self.socket.request_response(payload)
266  }
267
268  fn request_stream(&self, payload: RawPayload) -> BoxFlux<RawPayload, PayloadError> {
269    self.socket.request_stream(payload)
270  }
271
272  fn request_channel<
273    T: Stream<Item = std::result::Result<RawPayload, PayloadError>> + ConditionallySend + Unpin + 'static,
274  >(
275    &self,
276    stream: T,
277  ) -> BoxFlux<RawPayload, PayloadError> {
278    self.socket.request_channel(stream)
279  }
280}