1use std::sync::Arc;
2
3use bytes::{BufMut, BytesMut};
4use tokio::sync::Mutex;
5use wasi_common::WasiCtx;
6use wasmrs::{Frame, OperationList, WasmSocket};
7use wasmrs_host::{EngineProvider, GuestExports, HostServer, ProviderCallContext, SharedContext};
8use wasmtime::{AsContextMut, Engine, Linker, Memory, Module, Store, TypedFunc};
9
10use super::Result;
11use crate::errors::Error;
12use crate::memory::write_bytes_to_memory;
13use crate::store::{new_store, ProviderStore};
14use crate::wasmrs_wasmtime::{self};
15
16#[allow(missing_debug_implementations)]
18pub struct WasmtimeEngineProvider {
19 module: Module,
20 engine: Engine,
21 linker: Linker<ProviderStore<HostServer>>,
22 wasi_ctx: Option<WasiCtx>,
23 pub(crate) epoch_deadlines: Option<EpochDeadlines>,
24}
25
/// Pair of epoch deadlines for guest execution.
///
/// NOTE(review): the unit (presumably wasmtime epoch ticks) and where these are applied are
/// not visible here — confirm with callers.
#[derive(Debug, Clone, Copy)]
pub(crate) struct EpochDeadlines {
    /// Deadline presumably used while initialising the guest — confirm.
    #[allow(dead_code)] // kept as crate API; not read by the code in this module
    pub(crate) wasmrs_init: u64,

    /// Deadline presumably used for guest function calls — confirm.
    #[allow(dead_code)] // kept as crate API; not read by the code in this module
    pub(crate) wasmrs_func: u64,
}
36
37impl WasmtimeEngineProvider {
38 pub(crate) fn new_with_engine(module: Module, engine: Engine, wasi_ctx: Option<WasiCtx>) -> Result<Self> {
40 let mut linker: Linker<ProviderStore<HostServer>> = Linker::new(&engine);
41
42 if wasi_ctx.is_some() {
43 wasmtime_wasi::add_to_linker(&mut linker, |s| s.wasi_ctx.as_mut().unwrap()).unwrap();
44 }
45
46 Ok(WasmtimeEngineProvider {
47 module,
48 engine,
49 wasi_ctx,
50 linker,
51 epoch_deadlines: None,
52 })
53 }
54}
55
56#[async_trait::async_trait]
57impl EngineProvider for WasmtimeEngineProvider {
58 async fn new_context(
59 &self,
60 socket: Arc<WasmSocket<HostServer>>,
61 ) -> std::result::Result<SharedContext, wasmrs_host::errors::Error> {
62 let store = new_store(self.wasi_ctx.clone(), socket, &self.engine)
63 .map_err(|e| wasmrs_host::errors::Error::NewContext(e.to_string()))?;
64
65 let context = SharedContext::new(
66 WasmtimeCallContext::new(self.linker.clone(), &self.module, store)
67 .await
68 .map_err(|e| wasmrs_host::errors::Error::InitFailed(Box::new(e)))?,
69 );
70
71 Ok(context)
72 }
73}
74
/// wasmrs protocol revision spoken by the instantiated guest.
#[derive(Debug, PartialEq)]
enum Version {
    /// Legacy protocol: outgoing frames get `Frame::make_v0_metadata` applied before encoding.
    V0,
    /// Current protocol; chosen when the guest exports `GuestExports::Version1`.
    V1,
}
80
81struct Imports {
82 start: Option<TypedFunc<(), ()>>,
83 guest_init: TypedFunc<(u32, u32, u32), ()>,
84 op_list: Option<TypedFunc<(), ()>>,
85 guest_send: TypedFunc<i32, ()>,
86 version: Version,
87}
88
89struct WasmtimeCallContext {
90 memory: Memory,
91 store: Mutex<Store<ProviderStore<HostServer>>>,
92 imports: Imports,
93 op_list: parking_lot::Mutex<OperationList>,
94}
95
96impl WasmtimeCallContext {
97 pub(crate) async fn new(
98 mut linker: Linker<ProviderStore<HostServer>>,
99 module: &Module,
100 mut store: Store<ProviderStore<HostServer>>,
101 ) -> Result<Self> {
102 wasmrs_wasmtime::add_to_linker(&mut linker)?;
103 let instance = linker
104 .instantiate_async(&mut store, module)
105 .await
106 .map_err(Error::Linker)?;
107
108 let guest_send = instance
109 .get_typed_func::<i32, ()>(&mut store, GuestExports::Send.as_ref())
110 .map_err(|_| crate::errors::Error::GuestSend)?;
111 let memory = instance.get_memory(&mut store, "memory").unwrap();
112
113 let version = instance
114 .get_typed_func::<(), ()>(&mut store, GuestExports::Version1.as_ref())
115 .map_or(Version::V0, |_| Version::V1);
116
117 let imports = Imports {
118 version,
119 start: instance.get_typed_func(&mut store, GuestExports::Start.as_ref()).ok(),
120 guest_init: instance
121 .get_typed_func(&mut store, GuestExports::Init.as_ref())
122 .map_err(|_e| Error::GuestInit)?,
123 op_list: instance
124 .get_typed_func::<(), ()>(&mut store, GuestExports::OpListRequest.as_ref())
125 .ok(),
126 guest_send,
127 };
128
129 Ok(Self {
130 memory,
131 store: Mutex::new(store),
132 imports,
133 op_list: parking_lot::Mutex::new(OperationList::default()),
134 })
135 }
136}
137
138#[async_trait::async_trait]
139impl wasmrs::ModuleHost for WasmtimeCallContext {
140 async fn write_frame(&self, mut req: Frame) -> std::result::Result<(), wasmrs::Error> {
142 let bytes = if self.imports.version == Version::V0 {
143 req.make_v0_metadata();
144 req.encode()
145 } else {
146 req.encode()
147 };
148 trace!(?bytes, "writing frame");
149
150 let buffer_len_bytes = wasmrs::util::to_u24_bytes(bytes.len() as u32);
151 let mut buffer = BytesMut::with_capacity(buffer_len_bytes.len() + bytes.len());
152 buffer.put(buffer_len_bytes);
153 buffer.put(bytes);
154
155 let mut store = self.store.lock().await;
156
157 let start = store.data().guest_buffer.get_start();
158 let len = store.data().guest_buffer.get_size();
159
160 let written = write_bytes_to_memory(store.as_context_mut(), self.memory, &buffer, start, len);
161
162 self
163 .imports
164 .guest_send
165 .call_async(store.as_context_mut(), written as i32)
166 .await
167 .map_err(|e| wasmrs::Error::GuestCall(e.to_string()))?;
168
169 Ok(())
170 }
171
172 async fn on_error(&self, stream_id: u32) -> std::result::Result<(), wasmrs::Error> {
173 let mut lock = self.store.lock().await;
174 let data = lock.data_mut();
175 if let Err(e) = data.socket.process_once(Frame::new_cancel(stream_id)) {
176 error!("error processing cancel for stream id {}, {}", stream_id, e);
177 };
178 Ok(())
179 }
180
181 fn get_import(&self, namespace: &str, operation: &str) -> Option<u32> {
182 self.op_list.lock().get_import(namespace, operation)
183 }
184
185 fn get_export(&self, namespace: &str, operation: &str) -> Option<u32> {
186 self.op_list.lock().get_export(namespace, operation)
187 }
188
189 fn get_operation_list(&self) -> OperationList {
190 self.op_list.lock().clone()
191 }
192}
193
194#[async_trait::async_trait]
195impl ProviderCallContext for WasmtimeCallContext {
196 async fn init(
197 &self,
198 host_buffer_size: u32,
199 guest_buffer_size: u32,
200 ) -> std::result::Result<(), wasmrs_host::errors::Error> {
201 let mut store = self.store.lock().await;
202
203 if let Some(start) = &self.imports.start {
204 start
205 .call_async(store.as_context_mut(), ())
206 .await
207 .map_err(|e| wasmrs_host::errors::Error::InitFailed(e.into()))?;
208 }
209
210 self
211 .imports
212 .guest_init
213 .call_async(store.as_context_mut(), (host_buffer_size, guest_buffer_size, 128))
214 .await
215 .map_err(|e| wasmrs_host::errors::Error::InitFailed(e.into()))?;
216
217 store.data().guest_buffer.update_size(guest_buffer_size);
218 store.data().host_buffer.update_size(host_buffer_size);
219
220 if let Some(oplist) = self.imports.op_list {
221 trace!("calling operation list");
222 oplist
223 .call_async(store.as_context_mut(), ())
224 .await
225 .map_err(|e| wasmrs_host::errors::Error::OpList(e.to_string()))?;
226
227 *self.op_list.lock() = store.data().op_list.clone();
228 }
229
230 Ok(())
231 }
232}