wasmrs_wasmtime/
engine_provider.rs

1use std::sync::Arc;
2
3use bytes::{BufMut, BytesMut};
4use tokio::sync::Mutex;
5use wasi_common::WasiCtx;
6use wasmrs::{Frame, OperationList, WasmSocket};
7use wasmrs_host::{EngineProvider, GuestExports, HostServer, ProviderCallContext, SharedContext};
8use wasmtime::{AsContextMut, Engine, Linker, Memory, Module, Store, TypedFunc};
9
10use super::Result;
11use crate::errors::Error;
12use crate::memory::write_bytes_to_memory;
13use crate::store::{new_store, ProviderStore};
14use crate::wasmrs_wasmtime::{self};
15
16/// A wasmRS engine provider that encapsulates the Wasmtime WebAssembly runtime
17#[allow(missing_debug_implementations)]
18pub struct WasmtimeEngineProvider {
19  module: Module,
20  engine: Engine,
21  linker: Linker<ProviderStore<HostServer>>,
22  wasi_ctx: Option<WasiCtx>,
23  pub(crate) epoch_deadlines: Option<EpochDeadlines>,
24}
25
/// Epoch-based execution limits, each expressed as a number of epoch ticks.
#[derive(Debug, Clone, Copy)]
pub(crate) struct EpochDeadlines {
  /// Ticks granted to wasmRS initialization code.
  #[allow(dead_code)] // not read by the code visible in this file
  pub(crate) wasmrs_init: u64,

  /// Ticks granted to the computation of a user-defined wasmRS function.
  #[allow(dead_code)] // not read by the code visible in this file
  pub(crate) wasmrs_func: u64,
}
36
37impl WasmtimeEngineProvider {
38  /// Creates a new instance of a [WasmtimeEngineProvider] from a separately created [wasmtime::Engine].
39  pub(crate) fn new_with_engine(module: Module, engine: Engine, wasi_ctx: Option<WasiCtx>) -> Result<Self> {
40    let mut linker: Linker<ProviderStore<HostServer>> = Linker::new(&engine);
41
42    if wasi_ctx.is_some() {
43      wasmtime_wasi::add_to_linker(&mut linker, |s| s.wasi_ctx.as_mut().unwrap()).unwrap();
44    }
45
46    Ok(WasmtimeEngineProvider {
47      module,
48      engine,
49      wasi_ctx,
50      linker,
51      epoch_deadlines: None,
52    })
53  }
54}
55
56#[async_trait::async_trait]
57impl EngineProvider for WasmtimeEngineProvider {
58  async fn new_context(
59    &self,
60    socket: Arc<WasmSocket<HostServer>>,
61  ) -> std::result::Result<SharedContext, wasmrs_host::errors::Error> {
62    let store = new_store(self.wasi_ctx.clone(), socket, &self.engine)
63      .map_err(|e| wasmrs_host::errors::Error::NewContext(e.to_string()))?;
64
65    let context = SharedContext::new(
66      WasmtimeCallContext::new(self.linker.clone(), &self.module, store)
67        .await
68        .map_err(|e| wasmrs_host::errors::Error::InitFailed(Box::new(e)))?,
69    );
70
71    Ok(context)
72  }
73}
74
/// wasmRS protocol revision implemented by a guest module.
///
/// The revision is chosen at instantiation time by probing the guest for its
/// version-1 marker export.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Version {
  /// The guest lacks the version-1 marker; frames get v0 metadata before being encoded.
  V0,
  /// The guest exports the version-1 marker function.
  V1,
}
80
81struct Imports {
82  start: Option<TypedFunc<(), ()>>,
83  guest_init: TypedFunc<(u32, u32, u32), ()>,
84  op_list: Option<TypedFunc<(), ()>>,
85  guest_send: TypedFunc<i32, ()>,
86  version: Version,
87}
88
89struct WasmtimeCallContext {
90  memory: Memory,
91  store: Mutex<Store<ProviderStore<HostServer>>>,
92  imports: Imports,
93  op_list: parking_lot::Mutex<OperationList>,
94}
95
96impl WasmtimeCallContext {
97  pub(crate) async fn new(
98    mut linker: Linker<ProviderStore<HostServer>>,
99    module: &Module,
100    mut store: Store<ProviderStore<HostServer>>,
101  ) -> Result<Self> {
102    wasmrs_wasmtime::add_to_linker(&mut linker)?;
103    let instance = linker
104      .instantiate_async(&mut store, module)
105      .await
106      .map_err(Error::Linker)?;
107
108    let guest_send = instance
109      .get_typed_func::<i32, ()>(&mut store, GuestExports::Send.as_ref())
110      .map_err(|_| crate::errors::Error::GuestSend)?;
111    let memory = instance.get_memory(&mut store, "memory").unwrap();
112
113    let version = instance
114      .get_typed_func::<(), ()>(&mut store, GuestExports::Version1.as_ref())
115      .map_or(Version::V0, |_| Version::V1);
116
117    let imports = Imports {
118      version,
119      start: instance.get_typed_func(&mut store, GuestExports::Start.as_ref()).ok(),
120      guest_init: instance
121        .get_typed_func(&mut store, GuestExports::Init.as_ref())
122        .map_err(|_e| Error::GuestInit)?,
123      op_list: instance
124        .get_typed_func::<(), ()>(&mut store, GuestExports::OpListRequest.as_ref())
125        .ok(),
126      guest_send,
127    };
128
129    Ok(Self {
130      memory,
131      store: Mutex::new(store),
132      imports,
133      op_list: parking_lot::Mutex::new(OperationList::default()),
134    })
135  }
136}
137
138#[async_trait::async_trait]
139impl wasmrs::ModuleHost for WasmtimeCallContext {
140  /// Request-Response interaction model of RSocket.
141  async fn write_frame(&self, mut req: Frame) -> std::result::Result<(), wasmrs::Error> {
142    let bytes = if self.imports.version == Version::V0 {
143      req.make_v0_metadata();
144      req.encode()
145    } else {
146      req.encode()
147    };
148    trace!(?bytes, "writing frame");
149
150    let buffer_len_bytes = wasmrs::util::to_u24_bytes(bytes.len() as u32);
151    let mut buffer = BytesMut::with_capacity(buffer_len_bytes.len() + bytes.len());
152    buffer.put(buffer_len_bytes);
153    buffer.put(bytes);
154
155    let mut store = self.store.lock().await;
156
157    let start = store.data().guest_buffer.get_start();
158    let len = store.data().guest_buffer.get_size();
159
160    let written = write_bytes_to_memory(store.as_context_mut(), self.memory, &buffer, start, len);
161
162    self
163      .imports
164      .guest_send
165      .call_async(store.as_context_mut(), written as i32)
166      .await
167      .map_err(|e| wasmrs::Error::GuestCall(e.to_string()))?;
168
169    Ok(())
170  }
171
172  async fn on_error(&self, stream_id: u32) -> std::result::Result<(), wasmrs::Error> {
173    let mut lock = self.store.lock().await;
174    let data = lock.data_mut();
175    if let Err(e) = data.socket.process_once(Frame::new_cancel(stream_id)) {
176      error!("error processing cancel for stream id {}, {}", stream_id, e);
177    };
178    Ok(())
179  }
180
181  fn get_import(&self, namespace: &str, operation: &str) -> Option<u32> {
182    self.op_list.lock().get_import(namespace, operation)
183  }
184
185  fn get_export(&self, namespace: &str, operation: &str) -> Option<u32> {
186    self.op_list.lock().get_export(namespace, operation)
187  }
188
189  fn get_operation_list(&self) -> OperationList {
190    self.op_list.lock().clone()
191  }
192}
193
194#[async_trait::async_trait]
195impl ProviderCallContext for WasmtimeCallContext {
196  async fn init(
197    &self,
198    host_buffer_size: u32,
199    guest_buffer_size: u32,
200  ) -> std::result::Result<(), wasmrs_host::errors::Error> {
201    let mut store = self.store.lock().await;
202
203    if let Some(start) = &self.imports.start {
204      start
205        .call_async(store.as_context_mut(), ())
206        .await
207        .map_err(|e| wasmrs_host::errors::Error::InitFailed(e.into()))?;
208    }
209
210    self
211      .imports
212      .guest_init
213      .call_async(store.as_context_mut(), (host_buffer_size, guest_buffer_size, 128))
214      .await
215      .map_err(|e| wasmrs_host::errors::Error::InitFailed(e.into()))?;
216
217    store.data().guest_buffer.update_size(guest_buffer_size);
218    store.data().host_buffer.update_size(host_buffer_size);
219
220    if let Some(oplist) = self.imports.op_list {
221      trace!("calling operation list");
222      oplist
223        .call_async(store.as_context_mut(), ())
224        .await
225        .map_err(|e| wasmrs_host::errors::Error::OpList(e.to_string()))?;
226
227      *self.op_list.lock() = store.data().op_list.clone();
228    }
229
230    Ok(())
231  }
232}