Skip to main content

selium_wasmtime/
lib.rs

1//! Wasmtime subsystem integration for Selium runtime.
2
3use std::{
4    collections::{HashMap, HashSet},
5    sync::{Arc, RwLock},
6};
7
8use selium_abi::EntrypointInvocation;
9use selium_abi::{
10    self, AbiParam, AbiScalarType, AbiScalarValue, AbiSignature, AbiValue, CallPlan, CallPlanError,
11    hostcalls,
12};
13use selium_kernel::{
14    KernelError,
15    drivers::{Capability, module_store::ModuleStoreError, process::EntrypointInvocationExt},
16    futures::FutureSharedState,
17    guest_async::GuestAsync,
18    guest_data::{GuestError, GuestInt, GuestUint, write_poll_result},
19    mailbox,
20    operation::LinkableOperation,
21    registry::{InstanceRegistry, ProcessIdentity, Registry, ResourceId},
22};
23use thiserror::Error;
24use tracing::{debug, warn};
25use wasmtime::{Caller, Config, Engine, Func, Linker, Memory, Module, Store, Val, ValType};
26
27mod driver;
28pub use driver::WasmtimeDriver;
29
30pub struct WasmRuntime {
31    engine: Engine,
32    available_caps: RwLock<HashMap<Capability, Vec<Arc<dyn LinkableOperation>>>>,
33    guest_async: Arc<GuestAsync>,
34}
35
36const PREALLOC_PAGES: u64 = 256;
37
38#[derive(Error, Debug)]
39pub enum Error {
40    #[error("The requested capability ({0}) is not part of this kernel")]
41    CapabilityUnavailable(Capability),
42    #[error("Selium kernel error: {0}")]
43    Kernel(#[from] KernelError),
44    #[error("Module store error: {0}")]
45    ModuleStore(#[from] ModuleStoreError),
46    #[error("Wasmtime error: {0}")]
47    Wasmtime(#[from] wasmtime::Error),
48    #[error("The lock guarding the Capability registry has been poisoned")]
49    CapabilityRegistryPoisoned,
50}
51
52impl From<CallPlanError> for Error {
53    fn from(value: CallPlanError) -> Self {
54        Self::Kernel(KernelError::Driver(value.to_string()))
55    }
56}
57
58impl WasmRuntime {
59    pub fn new(
60        available_caps: HashMap<Capability, Vec<Arc<dyn LinkableOperation>>>,
61        guest_async: Arc<GuestAsync>,
62    ) -> Result<Self, Error> {
63        let mut config = Config::new();
64        config.memory_may_move(false);
65
66        Ok(Self {
67            engine: Engine::new(&config)?,
68            available_caps: RwLock::new(available_caps),
69            guest_async,
70        })
71    }
72
73    pub fn extend_capability(
74        &self,
75        capability: Capability,
76        operations: impl IntoIterator<Item = Arc<dyn LinkableOperation>>,
77    ) -> Result<(), Error> {
78        let mut map = self
79            .available_caps
80            .write()
81            .map_err(|_| Error::CapabilityRegistryPoisoned)?;
82        let entry = map.entry(capability).or_default();
83        entry.extend(operations);
84        Ok(())
85    }
86
87    pub async fn run(
88        &self,
89        registry: &Arc<Registry>,
90        process_id: ResourceId,
91        module: Module,
92        name: &str,
93        capabilities: &[Capability],
94        entrypoint: EntrypointInvocation,
95    ) -> Result<(), Error> {
96        let mut linker = Linker::new(&self.engine);
97        let operations_to_link = {
98            let map = self
99                .available_caps
100                .read()
101                .map_err(|_| Error::CapabilityRegistryPoisoned)?;
102            let mut ops = Vec::new();
103            let requested: HashSet<Capability> = capabilities.iter().copied().collect();
104            for capability in &requested {
105                let operations = map
106                    .get(capability)
107                    .ok_or(Error::CapabilityUnavailable(*capability))?;
108
109                if operations.is_empty() {
110                    return Err(Error::CapabilityUnavailable(*capability));
111                }
112
113                ops.extend(operations.iter().cloned());
114            }
115            ops.extend(stub_operations_for_missing(&requested));
116            ops
117        };
118
119        for op in operations_to_link {
120            op.link(&mut linker)?;
121        }
122
123        self.guest_async.link(&mut linker)?;
124
125        let instance_registry = registry.instance().map_err(KernelError::from)?;
126        let mut store = Store::new(&self.engine, instance_registry);
127        store
128            .data_mut()
129            .set_process_id(process_id)
130            .map_err(KernelError::from)?;
131        let identity = ProcessIdentity::new(process_id);
132        store
133            .data_mut()
134            .insert_extension(identity)
135            .map_err(KernelError::from)?;
136        // Limit linear memory growth to keep the mailbox pointers stable across the
137        // instance lifetime. We preallocate and then lock the limit to the current
138        // size so guest-initiated growth fails fast instead of moving the base
139        // address out from under host-side wakers.
140        let instance = linker.instantiate_async(&mut store, &module).await?;
141
142        // Initialise waker mailbox
143        let memory = instance.get_memory(&mut store, "memory").ok_or_else(|| {
144            Error::Kernel(KernelError::Driver("guest memory missing".to_string()))
145        })?;
146        preallocate_memory(&memory, &mut store);
147        let mb = unsafe { mailbox::create_guest_mailbox(&memory, &mut store) };
148        store
149            .data_mut()
150            .load_mailbox(mb)
151            .map_err(KernelError::from)?;
152
153        let signature = entrypoint.signature().clone();
154        let call_values = {
155            let registry = store.data_mut();
156            entrypoint.materialise_values(registry)?
157        };
158        let plan = CallPlan::new(&signature, &call_values)?;
159        materialise_plan(&memory, &mut store, &plan)?;
160
161        let func = instance.get_func(&mut store, name).ok_or_else(|| {
162            Error::Wasmtime(wasmtime::Error::msg(format!(
163                "entrypoint `{name}` not found"
164            )))
165        })?;
166        let func_ty = func.ty(&store);
167        let param_types: Vec<ValType> = func_ty.params().collect();
168        let result_types: Vec<ValType> = func_ty.results().collect();
169        let expected_params = flatten_signature_types(signature.params());
170        let expected_results = flatten_signature_types(signature.results());
171
172        let params_match = param_types.len() == expected_params.len()
173            && param_types
174                .iter()
175                .zip(expected_params.iter())
176                .all(|(actual, expected)| valtype_eq(actual, expected));
177
178        if !params_match {
179            return Err(Error::Kernel(KernelError::Driver(format!(
180                "entrypoint `{name}` expects params {:?}, got {:?}",
181                expected_params, param_types
182            ))));
183        }
184
185        let results_match = result_types.len() == expected_results.len()
186            && result_types
187                .iter()
188                .zip(expected_results.iter())
189                .all(|(actual, expected)| valtype_eq(actual, expected));
190
191        if !results_match {
192            return Err(Error::Kernel(KernelError::Driver(format!(
193                "entrypoint expects results {:?}, got {:?}",
194                expected_results, result_types
195            ))));
196        }
197
198        let params = prepare_params(&param_types, plan.params())
199            .map_err(|err| Error::Kernel(KernelError::Driver(err)))?;
200        let result_template = prepare_results(&result_types)
201            .map_err(|err| Error::Kernel(KernelError::Driver(err)))?;
202        let signature_clone = signature.clone();
203        let (start_tx, start_rx) = tokio::sync::oneshot::channel();
204        let handle = tokio::spawn(async move {
205            // Wait for registration before invoking entrypoint. This prevents races between
206            // guests registering resources and the process_id being set on the registry.
207            if start_rx.await.is_err() {
208                return Err(wasmtime::Error::msg("process start cancelled"));
209            }
210            invoke_entrypoint(
211                func,
212                store,
213                memory,
214                params,
215                result_template,
216                signature_clone,
217            )
218            .await
219        });
220
221        registry
222            .initialise(process_id, handle)
223            .map_err(|err| Error::Kernel(KernelError::from(err)))?;
224
225        // Trigger entrypoint exec
226        start_tx.send(()).map_err(|_| {
227            Error::Kernel(KernelError::Driver("process start cancelled".to_string()))
228        })?;
229
230        Ok(())
231    }
232}
233
234fn materialise_plan(
235    memory: &Memory,
236    store: &mut Store<InstanceRegistry>,
237    plan: &CallPlan,
238) -> Result<(), Error> {
239    for write in plan.memory_writes() {
240        if write.bytes.is_empty() {
241            continue;
242        }
243
244        let start = usize::try_from(write.offset)
245            .map_err(|err| Error::Kernel(KernelError::IntConvert(err)))?;
246        let end = start
247            .checked_add(write.bytes.len())
248            .ok_or_else(|| Error::Kernel(KernelError::MemoryCapacity))?;
249        let data = memory
250            .data_mut(&mut *store)
251            .get_mut(start..end)
252            .ok_or(Error::Kernel(KernelError::MemoryCapacity))?;
253        data.copy_from_slice(&write.bytes);
254    }
255
256    Ok(())
257}
258
259fn preallocate_memory(memory: &Memory, store: &mut Store<InstanceRegistry>) {
260    let mut current = memory.size(&mut *store);
261    if current < PREALLOC_PAGES {
262        let delta = PREALLOC_PAGES - current;
263        if let Err(err) = memory.grow(&mut *store, delta) {
264            warn!("failed to preallocate guest memory to {PREALLOC_PAGES} pages: {err:?}");
265        }
266        current = memory.size(&mut *store);
267    }
268    let bytes = memory.data_size(&*store);
269    debug!(pages = current, bytes, "prepared guest linear memory");
270}
271
272fn prepare_params(param_types: &[ValType], scalars: &[AbiScalarValue]) -> Result<Vec<Val>, String> {
273    if param_types.len() != scalars.len() {
274        return Err(format!(
275            "entrypoint expects {} params, got {}",
276            param_types.len(),
277            scalars.len()
278        ));
279    }
280
281    scalars
282        .iter()
283        .zip(param_types.iter())
284        .map(|(scalar, ty)| scalar_to_val(scalar, ty))
285        .collect()
286}
287
288fn prepare_results(result_types: &[ValType]) -> Result<Vec<Val>, String> {
289    Ok(result_types
290        .iter()
291        .map(|ty| default_val(ty.clone()))
292        .collect())
293}
294
295fn stub_operations_for_missing(requested: &HashSet<Capability>) -> Vec<Arc<dyn LinkableOperation>> {
296    let hostcalls_by_capability = hostcalls::by_capability();
297
298    selium_abi::Capability::ALL
299        .iter()
300        .copied()
301        .filter(|capability| !requested.contains(capability))
302        .flat_map(|capability| {
303            hostcalls_by_capability
304                .get(&capability)
305                .into_iter()
306                .flatten()
307                .map(move |meta| {
308                    StubOperation::new(meta.name, capability) as Arc<dyn LinkableOperation>
309                })
310        })
311        .collect()
312}
313
314struct StubOperation {
315    module: &'static str,
316    capability: Capability,
317}
318
319impl StubOperation {
320    fn new(module: &'static str, capability: Capability) -> Arc<Self> {
321        Arc::new(Self { module, capability })
322    }
323
324    fn create_stub_future(
325        mut caller: Caller<'_, InstanceRegistry>,
326        module: &'static str,
327        capability: Capability,
328    ) -> Result<GuestUint, KernelError> {
329        debug!(%module, ?capability, "invoking stub capability binding");
330
331        let state = FutureSharedState::new();
332        state.resolve(Err(GuestError::PermissionDenied));
333        let handle = caller.data_mut().insert_future(state)?;
334
335        GuestUint::try_from(handle).map_err(KernelError::IntConvert)
336    }
337
338    fn poll_stub_future(
339        mut caller: Caller<'_, InstanceRegistry>,
340        state_id: GuestUint,
341        _task_id: GuestUint,
342        result_ptr: GuestInt,
343        result_capacity: GuestUint,
344        module: &'static str,
345        capability: Capability,
346    ) -> Result<GuestUint, KernelError> {
347        debug!(%module, ?capability, "polling stub capability binding");
348
349        let state_id = usize::try_from(state_id).map_err(KernelError::IntConvert)?;
350        let result = match caller.data_mut().remove_future(state_id) {
351            Some(state) => state
352                .take_result()
353                .unwrap_or(Err(GuestError::PermissionDenied)),
354            None => Err(GuestError::NotFound),
355        };
356
357        write_poll_result(&mut caller, result_ptr, result_capacity, result)
358    }
359
360    fn drop_stub_future(
361        mut caller: Caller<'_, InstanceRegistry>,
362        state_id: GuestUint,
363        result_ptr: GuestInt,
364        result_capacity: GuestUint,
365        module: &'static str,
366        capability: Capability,
367    ) -> Result<GuestUint, KernelError> {
368        debug!(%module, ?capability, "dropping stub capability binding");
369
370        let state_id = usize::try_from(state_id).map_err(KernelError::IntConvert)?;
371        let result = if let Some(state) = caller.data_mut().remove_future(state_id) {
372            state.abandon();
373            Ok(Vec::new())
374        } else {
375            Err(GuestError::NotFound)
376        };
377
378        write_poll_result(&mut caller, result_ptr, result_capacity, result)
379    }
380}
381
382impl LinkableOperation for StubOperation {
383    fn link(&self, linker: &mut Linker<InstanceRegistry>) -> Result<(), KernelError> {
384        let module = self.module;
385        let capability = self.capability;
386        linker.func_wrap(
387            module,
388            "create",
389            move |caller: Caller<'_, InstanceRegistry>,
390                  _args_ptr: GuestInt,
391                  _args_len: GuestUint| {
392                StubOperation::create_stub_future(caller, module, capability).map_err(Into::into)
393            },
394        )?;
395
396        let module = self.module;
397        let capability = self.capability;
398        linker.func_wrap(
399            module,
400            "poll",
401            move |caller: Caller<'_, InstanceRegistry>,
402                  state_id: GuestUint,
403                  task_id: GuestUint,
404                  result_ptr: GuestInt,
405                  result_capacity: GuestUint| {
406                StubOperation::poll_stub_future(
407                    caller,
408                    state_id,
409                    task_id,
410                    result_ptr,
411                    result_capacity,
412                    module,
413                    capability,
414                )
415                .map_err(Into::into)
416            },
417        )?;
418
419        let module = self.module;
420        let capability = self.capability;
421        linker.func_wrap(
422            module,
423            "drop",
424            move |caller: Caller<'_, InstanceRegistry>,
425                  state_id: GuestUint,
426                  result_ptr: GuestInt,
427                  result_capacity: GuestUint| {
428                StubOperation::drop_stub_future(
429                    caller,
430                    state_id,
431                    result_ptr,
432                    result_capacity,
433                    module,
434                    capability,
435                )
436                .map_err(Into::into)
437            },
438        )?;
439
440        Ok(())
441    }
442}
443
444async fn invoke_entrypoint(
445    func: Func,
446    mut store: Store<InstanceRegistry>,
447    memory: Memory,
448    params: Vec<Val>,
449    mut results: Vec<Val>,
450    signature: AbiSignature,
451) -> Result<Vec<AbiValue>, wasmtime::Error> {
452    func.call_async(&mut store, &params, &mut results).await?;
453    decode_results(&memory, &store, &results, &signature)
454}
455
456fn decode_results(
457    memory: &Memory,
458    store: &Store<InstanceRegistry>,
459    raw: &[Val],
460    signature: &AbiSignature,
461) -> Result<Vec<AbiValue>, wasmtime::Error> {
462    let mut iter = raw.iter();
463    let mut values = Vec::new();
464
465    for param in signature.results() {
466        match param {
467            AbiParam::Scalar(kind) => {
468                let scalar = decode_scalar(&mut iter, *kind)?;
469                values.push(AbiValue::Scalar(scalar));
470            }
471            AbiParam::Buffer => {
472                let ptr_val = iter
473                    .next()
474                    .ok_or_else(|| wasmtime::Error::msg("missing buffer pointer"))?;
475                let len_val = iter
476                    .next()
477                    .ok_or_else(|| wasmtime::Error::msg("missing buffer length"))?;
478                let ptr = match ptr_val {
479                    Val::I32(v) if *v >= 0 => *v as usize,
480                    _ => return Err(wasmtime::Error::msg("buffer pointer must be i32")),
481                };
482                let len = match len_val {
483                    Val::I32(v) if *v >= 0 => *v as usize,
484                    _ => return Err(wasmtime::Error::msg("buffer length must be i32")),
485                };
486
487                if len == 0 {
488                    values.push(AbiValue::Buffer(Vec::new()));
489                    continue;
490                }
491
492                let data = memory
493                    .data(store)
494                    .get(ptr..ptr + len)
495                    .ok_or_else(|| wasmtime::Error::msg("buffer result out of bounds"))?;
496                values.push(AbiValue::Buffer(data.to_vec()));
497            }
498        }
499    }
500
501    if iter.next().is_some() {
502        return Err(wasmtime::Error::msg("extra values returned by entrypoint"));
503    }
504
505    Ok(values)
506}
507
508fn scalar_to_val(value: &AbiScalarValue, ty: &ValType) -> Result<Val, String> {
509    match (value, ty) {
510        (AbiScalarValue::I32(v), ValType::I32) => Ok(Val::I32(*v)),
511        (AbiScalarValue::U32(v), ValType::I32) => {
512            let bits = i32::from_ne_bytes(v.to_ne_bytes());
513            Ok(Val::I32(bits))
514        }
515        (AbiScalarValue::I16(v), ValType::I32) => Ok(Val::I32(i32::from(*v))),
516        (AbiScalarValue::U16(v), ValType::I32) => Ok(Val::I32(i32::from(*v))),
517        (AbiScalarValue::I8(v), ValType::I32) => Ok(Val::I32(i32::from(*v))),
518        (AbiScalarValue::U8(v), ValType::I32) => Ok(Val::I32(i32::from(*v))),
519        (AbiScalarValue::I64(v), ValType::I64) => Ok(Val::I64(*v)),
520        (AbiScalarValue::F32(v), ValType::F32) => Ok(Val::F32(v.to_bits())),
521        (AbiScalarValue::F64(v), ValType::F64) => Ok(Val::F64(v.to_bits())),
522        _ => Err(format!(
523            "type mismatch: value {:?} cannot be passed as {:?}",
524            value, ty
525        )),
526    }
527}
528
529fn decode_scalar(
530    iter: &mut std::slice::Iter<Val>,
531    expected: AbiScalarType,
532) -> Result<AbiScalarValue, wasmtime::Error> {
533    match expected {
534        AbiScalarType::I8 => {
535            let raw = take_i32(iter, "missing i8 result")?;
536            i8::try_from(raw)
537                .map(AbiScalarValue::I8)
538                .map_err(|_| wasmtime::Error::msg("i8 result out of range"))
539        }
540        AbiScalarType::U8 => {
541            let raw = take_u32(iter, "missing u8 result")?;
542            u8::try_from(raw)
543                .map(AbiScalarValue::U8)
544                .map_err(|_| wasmtime::Error::msg("u8 result out of range"))
545        }
546        AbiScalarType::I16 => {
547            let raw = take_i32(iter, "missing i16 result")?;
548            i16::try_from(raw)
549                .map(AbiScalarValue::I16)
550                .map_err(|_| wasmtime::Error::msg("i16 result out of range"))
551        }
552        AbiScalarType::U16 => {
553            let raw = take_u32(iter, "missing u16 result")?;
554            u16::try_from(raw)
555                .map(AbiScalarValue::U16)
556                .map_err(|_| wasmtime::Error::msg("u16 result out of range"))
557        }
558        AbiScalarType::I32 => {
559            let raw = take_i32(iter, "missing i32 result")?;
560            Ok(AbiScalarValue::I32(raw))
561        }
562        AbiScalarType::U32 => {
563            let raw = take_u32(iter, "missing u32 result")?;
564            Ok(AbiScalarValue::U32(raw))
565        }
566        AbiScalarType::I64 => {
567            let lo = take_u32(iter, "missing low i64 result")?;
568            let hi = take_u32(iter, "missing high i64 result")?;
569            let combined = (u64::from(hi) << 32) | u64::from(lo);
570            Ok(AbiScalarValue::I64(i64::from_le_bytes(
571                combined.to_le_bytes(),
572            )))
573        }
574        AbiScalarType::U64 => {
575            let lo = take_u32(iter, "missing low u64 result")?;
576            let hi = take_u32(iter, "missing high u64 result")?;
577            let combined = (u64::from(hi) << 32) | u64::from(lo);
578            Ok(AbiScalarValue::U64(combined))
579        }
580        AbiScalarType::F32 => {
581            let val = iter
582                .next()
583                .ok_or_else(|| wasmtime::Error::msg("missing f32 result"))?;
584            match val {
585                Val::F32(bits) => Ok(AbiScalarValue::F32(f32::from_bits(*bits))),
586                _ => Err(wasmtime::Error::msg("f32 result must be f32")),
587            }
588        }
589        AbiScalarType::F64 => {
590            let val = iter
591                .next()
592                .ok_or_else(|| wasmtime::Error::msg("missing f64 result"))?;
593            match val {
594                Val::F64(bits) => Ok(AbiScalarValue::F64(f64::from_bits(*bits))),
595                _ => Err(wasmtime::Error::msg("f64 result must be f64")),
596            }
597        }
598    }
599}
600
601fn default_val(ty: ValType) -> Val {
602    match ty {
603        ValType::I32 => Val::I32(0),
604        ValType::I64 => Val::I64(0),
605        ValType::F32 => Val::F32(0u32),
606        ValType::F64 => Val::F64(0u64),
607        other => panic!("unsupported Wasm value type in entrypoint: {other:?}"),
608    }
609}
610
611fn flatten_signature_types(spec: &[AbiParam]) -> Vec<ValType> {
612    let mut types = Vec::new();
613    for param in spec {
614        match param {
615            AbiParam::Scalar(kind) => push_scalar_types(*kind, &mut types),
616            AbiParam::Buffer => {
617                types.push(ValType::I32);
618                types.push(ValType::I32);
619            }
620        }
621    }
622    types
623}
624
625fn push_scalar_types(kind: AbiScalarType, types: &mut Vec<ValType>) {
626    match kind {
627        AbiScalarType::F32 => types.push(ValType::F32),
628        AbiScalarType::F64 => types.push(ValType::F64),
629        AbiScalarType::I64 | AbiScalarType::U64 => {
630            types.push(ValType::I32);
631            types.push(ValType::I32);
632        }
633        AbiScalarType::I8
634        | AbiScalarType::U8
635        | AbiScalarType::I16
636        | AbiScalarType::U16
637        | AbiScalarType::I32
638        | AbiScalarType::U32 => types.push(ValType::I32),
639    }
640}
641
642fn valtype_eq(a: &ValType, b: &ValType) -> bool {
643    matches!(
644        (a, b),
645        (ValType::I32, ValType::I32)
646            | (ValType::I64, ValType::I64)
647            | (ValType::F32, ValType::F32)
648            | (ValType::F64, ValType::F64)
649    )
650}
651
652fn take_i32(iter: &mut std::slice::Iter<Val>, msg: &str) -> Result<i32, wasmtime::Error> {
653    let Some(val) = iter.next() else {
654        return Err(wasmtime::Error::msg(msg.to_owned()));
655    };
656
657    match val {
658        Val::I32(v) => Ok(*v),
659        _ => Err(wasmtime::Error::msg(msg.to_owned())),
660    }
661}
662
663fn take_u32(iter: &mut std::slice::Iter<Val>, msg: &str) -> Result<u32, wasmtime::Error> {
664    let raw = take_i32(iter, msg)?;
665    Ok(u32::from_ne_bytes(raw.to_ne_bytes()))
666}