1use std::{
4 collections::{HashMap, HashSet},
5 sync::{Arc, RwLock},
6};
7
8use selium_abi::EntrypointInvocation;
9use selium_abi::{
10 self, AbiParam, AbiScalarType, AbiScalarValue, AbiSignature, AbiValue, CallPlan, CallPlanError,
11 hostcalls,
12};
13use selium_kernel::{
14 KernelError,
15 drivers::{Capability, module_store::ModuleStoreError, process::EntrypointInvocationExt},
16 futures::FutureSharedState,
17 guest_async::GuestAsync,
18 guest_data::{GuestError, GuestInt, GuestUint, write_poll_result},
19 mailbox,
20 operation::LinkableOperation,
21 registry::{InstanceRegistry, ProcessIdentity, Registry, ResourceId},
22};
23use thiserror::Error;
24use tracing::{debug, warn};
25use wasmtime::{Caller, Config, Engine, Func, Linker, Memory, Module, Store, Val, ValType};
26
27mod driver;
28pub use driver::WasmtimeDriver;
29
/// Wasmtime-backed runtime that links capability operations into guest modules and launches
/// each guest entrypoint as a Tokio task.
pub struct WasmRuntime {
    /// Engine shared by every linker and store this runtime creates.
    engine: Engine,
    /// Operations implementing each capability; read by `run`, appended to by
    /// `extend_capability`.
    available_caps: RwLock<HashMap<Capability, Vec<Arc<dyn LinkableOperation>>>>,
    /// Async support linked into every guest alongside its capability operations.
    guest_async: Arc<GuestAsync>,
}
35
/// Minimum size, in Wasm pages, that guest linear memory is grown to before the entrypoint is
/// invoked (see `preallocate_memory`). With 64 KiB pages this is 16 MiB.
const PREALLOC_PAGES: u64 = 256;
37
/// Errors produced while configuring the runtime or launching a guest process.
#[derive(Error, Debug)]
pub enum Error {
    /// A requested capability has no entry, or only an empty list of operations, in this runtime.
    #[error("The requested capability ({0}) is not part of this kernel")]
    CapabilityUnavailable(Capability),
    /// Failure reported by the Selium kernel (registry, drivers, memory access, ...).
    #[error("Selium kernel error: {0}")]
    Kernel(#[from] KernelError),
    /// Failure reported by the module store.
    #[error("Module store error: {0}")]
    ModuleStore(#[from] ModuleStoreError),
    /// Failure reported by Wasmtime (engine creation, linking, instantiation, ...).
    #[error("Wasmtime error: {0}")]
    Wasmtime(#[from] wasmtime::Error),
    /// The `RwLock` around the capability map was poisoned by a panicking writer or reader.
    #[error("The lock guarding the Capability registry has been poisoned")]
    CapabilityRegistryPoisoned,
}
51
52impl From<CallPlanError> for Error {
53 fn from(value: CallPlanError) -> Self {
54 Self::Kernel(KernelError::Driver(value.to_string()))
55 }
56}
57
58impl WasmRuntime {
59 pub fn new(
60 available_caps: HashMap<Capability, Vec<Arc<dyn LinkableOperation>>>,
61 guest_async: Arc<GuestAsync>,
62 ) -> Result<Self, Error> {
63 let mut config = Config::new();
64 config.memory_may_move(false);
65
66 Ok(Self {
67 engine: Engine::new(&config)?,
68 available_caps: RwLock::new(available_caps),
69 guest_async,
70 })
71 }
72
73 pub fn extend_capability(
74 &self,
75 capability: Capability,
76 operations: impl IntoIterator<Item = Arc<dyn LinkableOperation>>,
77 ) -> Result<(), Error> {
78 let mut map = self
79 .available_caps
80 .write()
81 .map_err(|_| Error::CapabilityRegistryPoisoned)?;
82 let entry = map.entry(capability).or_default();
83 entry.extend(operations);
84 Ok(())
85 }
86
87 pub async fn run(
88 &self,
89 registry: &Arc<Registry>,
90 process_id: ResourceId,
91 module: Module,
92 name: &str,
93 capabilities: &[Capability],
94 entrypoint: EntrypointInvocation,
95 ) -> Result<(), Error> {
96 let mut linker = Linker::new(&self.engine);
97 let operations_to_link = {
98 let map = self
99 .available_caps
100 .read()
101 .map_err(|_| Error::CapabilityRegistryPoisoned)?;
102 let mut ops = Vec::new();
103 let requested: HashSet<Capability> = capabilities.iter().copied().collect();
104 for capability in &requested {
105 let operations = map
106 .get(capability)
107 .ok_or(Error::CapabilityUnavailable(*capability))?;
108
109 if operations.is_empty() {
110 return Err(Error::CapabilityUnavailable(*capability));
111 }
112
113 ops.extend(operations.iter().cloned());
114 }
115 ops.extend(stub_operations_for_missing(&requested));
116 ops
117 };
118
119 for op in operations_to_link {
120 op.link(&mut linker)?;
121 }
122
123 self.guest_async.link(&mut linker)?;
124
125 let instance_registry = registry.instance().map_err(KernelError::from)?;
126 let mut store = Store::new(&self.engine, instance_registry);
127 store
128 .data_mut()
129 .set_process_id(process_id)
130 .map_err(KernelError::from)?;
131 let identity = ProcessIdentity::new(process_id);
132 store
133 .data_mut()
134 .insert_extension(identity)
135 .map_err(KernelError::from)?;
136 let instance = linker.instantiate_async(&mut store, &module).await?;
141
142 let memory = instance.get_memory(&mut store, "memory").ok_or_else(|| {
144 Error::Kernel(KernelError::Driver("guest memory missing".to_string()))
145 })?;
146 preallocate_memory(&memory, &mut store);
147 let mb = unsafe { mailbox::create_guest_mailbox(&memory, &mut store) };
148 store
149 .data_mut()
150 .load_mailbox(mb)
151 .map_err(KernelError::from)?;
152
153 let signature = entrypoint.signature().clone();
154 let call_values = {
155 let registry = store.data_mut();
156 entrypoint.materialise_values(registry)?
157 };
158 let plan = CallPlan::new(&signature, &call_values)?;
159 materialise_plan(&memory, &mut store, &plan)?;
160
161 let func = instance.get_func(&mut store, name).ok_or_else(|| {
162 Error::Wasmtime(wasmtime::Error::msg(format!(
163 "entrypoint `{name}` not found"
164 )))
165 })?;
166 let func_ty = func.ty(&store);
167 let param_types: Vec<ValType> = func_ty.params().collect();
168 let result_types: Vec<ValType> = func_ty.results().collect();
169 let expected_params = flatten_signature_types(signature.params());
170 let expected_results = flatten_signature_types(signature.results());
171
172 let params_match = param_types.len() == expected_params.len()
173 && param_types
174 .iter()
175 .zip(expected_params.iter())
176 .all(|(actual, expected)| valtype_eq(actual, expected));
177
178 if !params_match {
179 return Err(Error::Kernel(KernelError::Driver(format!(
180 "entrypoint `{name}` expects params {:?}, got {:?}",
181 expected_params, param_types
182 ))));
183 }
184
185 let results_match = result_types.len() == expected_results.len()
186 && result_types
187 .iter()
188 .zip(expected_results.iter())
189 .all(|(actual, expected)| valtype_eq(actual, expected));
190
191 if !results_match {
192 return Err(Error::Kernel(KernelError::Driver(format!(
193 "entrypoint expects results {:?}, got {:?}",
194 expected_results, result_types
195 ))));
196 }
197
198 let params = prepare_params(¶m_types, plan.params())
199 .map_err(|err| Error::Kernel(KernelError::Driver(err)))?;
200 let result_template = prepare_results(&result_types)
201 .map_err(|err| Error::Kernel(KernelError::Driver(err)))?;
202 let signature_clone = signature.clone();
203 let (start_tx, start_rx) = tokio::sync::oneshot::channel();
204 let handle = tokio::spawn(async move {
205 if start_rx.await.is_err() {
208 return Err(wasmtime::Error::msg("process start cancelled"));
209 }
210 invoke_entrypoint(
211 func,
212 store,
213 memory,
214 params,
215 result_template,
216 signature_clone,
217 )
218 .await
219 });
220
221 registry
222 .initialise(process_id, handle)
223 .map_err(|err| Error::Kernel(KernelError::from(err)))?;
224
225 start_tx.send(()).map_err(|_| {
227 Error::Kernel(KernelError::Driver("process start cancelled".to_string()))
228 })?;
229
230 Ok(())
231 }
232}
233
234fn materialise_plan(
235 memory: &Memory,
236 store: &mut Store<InstanceRegistry>,
237 plan: &CallPlan,
238) -> Result<(), Error> {
239 for write in plan.memory_writes() {
240 if write.bytes.is_empty() {
241 continue;
242 }
243
244 let start = usize::try_from(write.offset)
245 .map_err(|err| Error::Kernel(KernelError::IntConvert(err)))?;
246 let end = start
247 .checked_add(write.bytes.len())
248 .ok_or_else(|| Error::Kernel(KernelError::MemoryCapacity))?;
249 let data = memory
250 .data_mut(&mut *store)
251 .get_mut(start..end)
252 .ok_or(Error::Kernel(KernelError::MemoryCapacity))?;
253 data.copy_from_slice(&write.bytes);
254 }
255
256 Ok(())
257}
258
259fn preallocate_memory(memory: &Memory, store: &mut Store<InstanceRegistry>) {
260 let mut current = memory.size(&mut *store);
261 if current < PREALLOC_PAGES {
262 let delta = PREALLOC_PAGES - current;
263 if let Err(err) = memory.grow(&mut *store, delta) {
264 warn!("failed to preallocate guest memory to {PREALLOC_PAGES} pages: {err:?}");
265 }
266 current = memory.size(&mut *store);
267 }
268 let bytes = memory.data_size(&*store);
269 debug!(pages = current, bytes, "prepared guest linear memory");
270}
271
272fn prepare_params(param_types: &[ValType], scalars: &[AbiScalarValue]) -> Result<Vec<Val>, String> {
273 if param_types.len() != scalars.len() {
274 return Err(format!(
275 "entrypoint expects {} params, got {}",
276 param_types.len(),
277 scalars.len()
278 ));
279 }
280
281 scalars
282 .iter()
283 .zip(param_types.iter())
284 .map(|(scalar, ty)| scalar_to_val(scalar, ty))
285 .collect()
286}
287
288fn prepare_results(result_types: &[ValType]) -> Result<Vec<Val>, String> {
289 Ok(result_types
290 .iter()
291 .map(|ty| default_val(ty.clone()))
292 .collect())
293}
294
295fn stub_operations_for_missing(requested: &HashSet<Capability>) -> Vec<Arc<dyn LinkableOperation>> {
296 let hostcalls_by_capability = hostcalls::by_capability();
297
298 selium_abi::Capability::ALL
299 .iter()
300 .copied()
301 .filter(|capability| !requested.contains(capability))
302 .flat_map(|capability| {
303 hostcalls_by_capability
304 .get(&capability)
305 .into_iter()
306 .flatten()
307 .map(move |meta| {
308 StubOperation::new(meta.name, capability) as Arc<dyn LinkableOperation>
309 })
310 })
311 .collect()
312}
313
/// Placeholder binding for one hostcall module whose capability was not granted to the process.
///
/// Its `create` returns a future already resolved with `GuestError::PermissionDenied`, so the
/// guest's imports still link but every call is denied.
struct StubOperation {
    /// Import module name under which `create`, `poll` and `drop` are registered.
    module: &'static str,
    /// Capability the stubbed hostcall belongs to; only used in debug logging.
    capability: Capability,
}
318
319impl StubOperation {
320 fn new(module: &'static str, capability: Capability) -> Arc<Self> {
321 Arc::new(Self { module, capability })
322 }
323
324 fn create_stub_future(
325 mut caller: Caller<'_, InstanceRegistry>,
326 module: &'static str,
327 capability: Capability,
328 ) -> Result<GuestUint, KernelError> {
329 debug!(%module, ?capability, "invoking stub capability binding");
330
331 let state = FutureSharedState::new();
332 state.resolve(Err(GuestError::PermissionDenied));
333 let handle = caller.data_mut().insert_future(state)?;
334
335 GuestUint::try_from(handle).map_err(KernelError::IntConvert)
336 }
337
338 fn poll_stub_future(
339 mut caller: Caller<'_, InstanceRegistry>,
340 state_id: GuestUint,
341 _task_id: GuestUint,
342 result_ptr: GuestInt,
343 result_capacity: GuestUint,
344 module: &'static str,
345 capability: Capability,
346 ) -> Result<GuestUint, KernelError> {
347 debug!(%module, ?capability, "polling stub capability binding");
348
349 let state_id = usize::try_from(state_id).map_err(KernelError::IntConvert)?;
350 let result = match caller.data_mut().remove_future(state_id) {
351 Some(state) => state
352 .take_result()
353 .unwrap_or(Err(GuestError::PermissionDenied)),
354 None => Err(GuestError::NotFound),
355 };
356
357 write_poll_result(&mut caller, result_ptr, result_capacity, result)
358 }
359
360 fn drop_stub_future(
361 mut caller: Caller<'_, InstanceRegistry>,
362 state_id: GuestUint,
363 result_ptr: GuestInt,
364 result_capacity: GuestUint,
365 module: &'static str,
366 capability: Capability,
367 ) -> Result<GuestUint, KernelError> {
368 debug!(%module, ?capability, "dropping stub capability binding");
369
370 let state_id = usize::try_from(state_id).map_err(KernelError::IntConvert)?;
371 let result = if let Some(state) = caller.data_mut().remove_future(state_id) {
372 state.abandon();
373 Ok(Vec::new())
374 } else {
375 Err(GuestError::NotFound)
376 };
377
378 write_poll_result(&mut caller, result_ptr, result_capacity, result)
379 }
380}
381
382impl LinkableOperation for StubOperation {
383 fn link(&self, linker: &mut Linker<InstanceRegistry>) -> Result<(), KernelError> {
384 let module = self.module;
385 let capability = self.capability;
386 linker.func_wrap(
387 module,
388 "create",
389 move |caller: Caller<'_, InstanceRegistry>,
390 _args_ptr: GuestInt,
391 _args_len: GuestUint| {
392 StubOperation::create_stub_future(caller, module, capability).map_err(Into::into)
393 },
394 )?;
395
396 let module = self.module;
397 let capability = self.capability;
398 linker.func_wrap(
399 module,
400 "poll",
401 move |caller: Caller<'_, InstanceRegistry>,
402 state_id: GuestUint,
403 task_id: GuestUint,
404 result_ptr: GuestInt,
405 result_capacity: GuestUint| {
406 StubOperation::poll_stub_future(
407 caller,
408 state_id,
409 task_id,
410 result_ptr,
411 result_capacity,
412 module,
413 capability,
414 )
415 .map_err(Into::into)
416 },
417 )?;
418
419 let module = self.module;
420 let capability = self.capability;
421 linker.func_wrap(
422 module,
423 "drop",
424 move |caller: Caller<'_, InstanceRegistry>,
425 state_id: GuestUint,
426 result_ptr: GuestInt,
427 result_capacity: GuestUint| {
428 StubOperation::drop_stub_future(
429 caller,
430 state_id,
431 result_ptr,
432 result_capacity,
433 module,
434 capability,
435 )
436 .map_err(Into::into)
437 },
438 )?;
439
440 Ok(())
441 }
442}
443
444async fn invoke_entrypoint(
445 func: Func,
446 mut store: Store<InstanceRegistry>,
447 memory: Memory,
448 params: Vec<Val>,
449 mut results: Vec<Val>,
450 signature: AbiSignature,
451) -> Result<Vec<AbiValue>, wasmtime::Error> {
452 func.call_async(&mut store, ¶ms, &mut results).await?;
453 decode_results(&memory, &store, &results, &signature)
454}
455
456fn decode_results(
457 memory: &Memory,
458 store: &Store<InstanceRegistry>,
459 raw: &[Val],
460 signature: &AbiSignature,
461) -> Result<Vec<AbiValue>, wasmtime::Error> {
462 let mut iter = raw.iter();
463 let mut values = Vec::new();
464
465 for param in signature.results() {
466 match param {
467 AbiParam::Scalar(kind) => {
468 let scalar = decode_scalar(&mut iter, *kind)?;
469 values.push(AbiValue::Scalar(scalar));
470 }
471 AbiParam::Buffer => {
472 let ptr_val = iter
473 .next()
474 .ok_or_else(|| wasmtime::Error::msg("missing buffer pointer"))?;
475 let len_val = iter
476 .next()
477 .ok_or_else(|| wasmtime::Error::msg("missing buffer length"))?;
478 let ptr = match ptr_val {
479 Val::I32(v) if *v >= 0 => *v as usize,
480 _ => return Err(wasmtime::Error::msg("buffer pointer must be i32")),
481 };
482 let len = match len_val {
483 Val::I32(v) if *v >= 0 => *v as usize,
484 _ => return Err(wasmtime::Error::msg("buffer length must be i32")),
485 };
486
487 if len == 0 {
488 values.push(AbiValue::Buffer(Vec::new()));
489 continue;
490 }
491
492 let data = memory
493 .data(store)
494 .get(ptr..ptr + len)
495 .ok_or_else(|| wasmtime::Error::msg("buffer result out of bounds"))?;
496 values.push(AbiValue::Buffer(data.to_vec()));
497 }
498 }
499 }
500
501 if iter.next().is_some() {
502 return Err(wasmtime::Error::msg("extra values returned by entrypoint"));
503 }
504
505 Ok(values)
506}
507
508fn scalar_to_val(value: &AbiScalarValue, ty: &ValType) -> Result<Val, String> {
509 match (value, ty) {
510 (AbiScalarValue::I32(v), ValType::I32) => Ok(Val::I32(*v)),
511 (AbiScalarValue::U32(v), ValType::I32) => {
512 let bits = i32::from_ne_bytes(v.to_ne_bytes());
513 Ok(Val::I32(bits))
514 }
515 (AbiScalarValue::I16(v), ValType::I32) => Ok(Val::I32(i32::from(*v))),
516 (AbiScalarValue::U16(v), ValType::I32) => Ok(Val::I32(i32::from(*v))),
517 (AbiScalarValue::I8(v), ValType::I32) => Ok(Val::I32(i32::from(*v))),
518 (AbiScalarValue::U8(v), ValType::I32) => Ok(Val::I32(i32::from(*v))),
519 (AbiScalarValue::I64(v), ValType::I64) => Ok(Val::I64(*v)),
520 (AbiScalarValue::F32(v), ValType::F32) => Ok(Val::F32(v.to_bits())),
521 (AbiScalarValue::F64(v), ValType::F64) => Ok(Val::F64(v.to_bits())),
522 _ => Err(format!(
523 "type mismatch: value {:?} cannot be passed as {:?}",
524 value, ty
525 )),
526 }
527}
528
529fn decode_scalar(
530 iter: &mut std::slice::Iter<Val>,
531 expected: AbiScalarType,
532) -> Result<AbiScalarValue, wasmtime::Error> {
533 match expected {
534 AbiScalarType::I8 => {
535 let raw = take_i32(iter, "missing i8 result")?;
536 i8::try_from(raw)
537 .map(AbiScalarValue::I8)
538 .map_err(|_| wasmtime::Error::msg("i8 result out of range"))
539 }
540 AbiScalarType::U8 => {
541 let raw = take_u32(iter, "missing u8 result")?;
542 u8::try_from(raw)
543 .map(AbiScalarValue::U8)
544 .map_err(|_| wasmtime::Error::msg("u8 result out of range"))
545 }
546 AbiScalarType::I16 => {
547 let raw = take_i32(iter, "missing i16 result")?;
548 i16::try_from(raw)
549 .map(AbiScalarValue::I16)
550 .map_err(|_| wasmtime::Error::msg("i16 result out of range"))
551 }
552 AbiScalarType::U16 => {
553 let raw = take_u32(iter, "missing u16 result")?;
554 u16::try_from(raw)
555 .map(AbiScalarValue::U16)
556 .map_err(|_| wasmtime::Error::msg("u16 result out of range"))
557 }
558 AbiScalarType::I32 => {
559 let raw = take_i32(iter, "missing i32 result")?;
560 Ok(AbiScalarValue::I32(raw))
561 }
562 AbiScalarType::U32 => {
563 let raw = take_u32(iter, "missing u32 result")?;
564 Ok(AbiScalarValue::U32(raw))
565 }
566 AbiScalarType::I64 => {
567 let lo = take_u32(iter, "missing low i64 result")?;
568 let hi = take_u32(iter, "missing high i64 result")?;
569 let combined = (u64::from(hi) << 32) | u64::from(lo);
570 Ok(AbiScalarValue::I64(i64::from_le_bytes(
571 combined.to_le_bytes(),
572 )))
573 }
574 AbiScalarType::U64 => {
575 let lo = take_u32(iter, "missing low u64 result")?;
576 let hi = take_u32(iter, "missing high u64 result")?;
577 let combined = (u64::from(hi) << 32) | u64::from(lo);
578 Ok(AbiScalarValue::U64(combined))
579 }
580 AbiScalarType::F32 => {
581 let val = iter
582 .next()
583 .ok_or_else(|| wasmtime::Error::msg("missing f32 result"))?;
584 match val {
585 Val::F32(bits) => Ok(AbiScalarValue::F32(f32::from_bits(*bits))),
586 _ => Err(wasmtime::Error::msg("f32 result must be f32")),
587 }
588 }
589 AbiScalarType::F64 => {
590 let val = iter
591 .next()
592 .ok_or_else(|| wasmtime::Error::msg("missing f64 result"))?;
593 match val {
594 Val::F64(bits) => Ok(AbiScalarValue::F64(f64::from_bits(*bits))),
595 _ => Err(wasmtime::Error::msg("f64 result must be f64")),
596 }
597 }
598 }
599}
600
601fn default_val(ty: ValType) -> Val {
602 match ty {
603 ValType::I32 => Val::I32(0),
604 ValType::I64 => Val::I64(0),
605 ValType::F32 => Val::F32(0u32),
606 ValType::F64 => Val::F64(0u64),
607 other => panic!("unsupported Wasm value type in entrypoint: {other:?}"),
608 }
609}
610
611fn flatten_signature_types(spec: &[AbiParam]) -> Vec<ValType> {
612 let mut types = Vec::new();
613 for param in spec {
614 match param {
615 AbiParam::Scalar(kind) => push_scalar_types(*kind, &mut types),
616 AbiParam::Buffer => {
617 types.push(ValType::I32);
618 types.push(ValType::I32);
619 }
620 }
621 }
622 types
623}
624
625fn push_scalar_types(kind: AbiScalarType, types: &mut Vec<ValType>) {
626 match kind {
627 AbiScalarType::F32 => types.push(ValType::F32),
628 AbiScalarType::F64 => types.push(ValType::F64),
629 AbiScalarType::I64 | AbiScalarType::U64 => {
630 types.push(ValType::I32);
631 types.push(ValType::I32);
632 }
633 AbiScalarType::I8
634 | AbiScalarType::U8
635 | AbiScalarType::I16
636 | AbiScalarType::U16
637 | AbiScalarType::I32
638 | AbiScalarType::U32 => types.push(ValType::I32),
639 }
640}
641
642fn valtype_eq(a: &ValType, b: &ValType) -> bool {
643 matches!(
644 (a, b),
645 (ValType::I32, ValType::I32)
646 | (ValType::I64, ValType::I64)
647 | (ValType::F32, ValType::F32)
648 | (ValType::F64, ValType::F64)
649 )
650}
651
652fn take_i32(iter: &mut std::slice::Iter<Val>, msg: &str) -> Result<i32, wasmtime::Error> {
653 let Some(val) = iter.next() else {
654 return Err(wasmtime::Error::msg(msg.to_owned()));
655 };
656
657 match val {
658 Val::I32(v) => Ok(*v),
659 _ => Err(wasmtime::Error::msg(msg.to_owned())),
660 }
661}
662
663fn take_u32(iter: &mut std::slice::Iter<Val>, msg: &str) -> Result<u32, wasmtime::Error> {
664 let raw = take_i32(iter, msg)?;
665 Ok(u32::from_ne_bytes(raw.to_ne_bytes()))
666}