wrpc_runtime_wasmtime/
lib.rs

1#![allow(clippy::type_complexity)] // TODO: https://github.com/bytecodealliance/wrpc/issues/2
2
3use core::any::Any;
4use core::borrow::Borrow;
5use core::fmt;
6use core::future::Future;
7use core::iter::zip;
8use core::pin::pin;
9use core::time::Duration;
10
11use std::collections::{BTreeMap, HashMap};
12use std::sync::Arc;
13
14use anyhow::{anyhow, bail, Context as _};
15use bytes::{Bytes, BytesMut};
16use futures::future::try_join_all;
17use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _};
18use tokio_util::codec::Encoder;
19use tracing::{debug, instrument, trace, warn};
20use uuid::Uuid;
21use wasmtime::component::{types, Func, Resource, ResourceAny, ResourceType, Type, Val};
22use wasmtime::{AsContextMut, Engine};
23use wasmtime_wasi::{IoView, WasiView};
24use wrpc_transport::Invoke;
25
26use crate::bindings::rpc::context::Context;
27use crate::bindings::rpc::error::Error;
28use crate::bindings::rpc::transport::{IncomingChannel, Invocation, OutgoingChannel};
29
30pub mod bindings;
31mod codec;
32mod polyfill;
33pub mod rpc;
34mod serve;
35
36pub use codec::*;
37pub use polyfill::*;
38pub use serve::*;
39
/// Returns the RPC name for a wasmtime component function name.
///
/// Component function names may carry a kind prefix (`[constructor]`, `[static]` or
/// `[method]`). Unfortunately, [`types::ComponentFunc`] does not expose that kind information
/// separately and we want to avoid (re-)parsing the WIT here, so the prefix is stripped
/// textually. Names without one of these prefixes are returned unchanged.
fn rpc_func_name(name: &str) -> &str {
    const KIND_PREFIXES: [&str; 3] = ["[constructor]", "[static]", "[method]"];
    KIND_PREFIXES
        .iter()
        .find_map(|prefix| name.strip_prefix(*prefix))
        .unwrap_or(name)
}
54
55fn rpc_result_type<T: Borrow<Type>>(
56    host_resources: &HashMap<Box<str>, HashMap<Box<str>, (ResourceType, ResourceType)>>,
57    results_ty: impl IntoIterator<Item = T>,
58) -> Option<Option<Type>> {
59    let rpc_err_ty = host_resources
60        .get("wrpc:rpc/error@0.1.0")
61        .and_then(|instance| instance.get("error"));
62    let mut results_ty = results_ty.into_iter();
63    match (
64        rpc_err_ty,
65        results_ty.next().as_ref().map(Borrow::borrow),
66        results_ty.next(),
67    ) {
68        (Some((guest_rpc_err_ty, host_rpc_err_ty)), Some(Type::Result(result_ty)), None)
69            if *host_rpc_err_ty == ResourceType::host::<Error>()
70                && result_ty.err() == Some(Type::Own(*guest_rpc_err_ty)) =>
71        {
72            Some(result_ty.ok())
73        }
74        _ => None,
75    }
76}
77
78pub struct RemoteResource(pub Bytes);
79
80/// A table of shared resources exported by the component
81#[derive(Debug, Default)]
82pub struct SharedResourceTable(HashMap<Uuid, ResourceAny>);
83
84pub trait WrpcView: IoView + Send {
85    type Invoke: Invoke;
86
87    /// Returns context to use for invocation
88    fn context(&self) -> <Self::Invoke as Invoke>::Context;
89
90    /// Returns an [Invoke] implementation used to satisfy polyfilled imports
91    fn client(&self) -> &Self::Invoke;
92
93    /// Returns a table of shared exported resources
94    fn shared_resources(&mut self) -> &mut SharedResourceTable;
95
96    /// Optional invocation timeout, component will trap if invocation is not finished within the
97    /// returned [Duration]. If this method returns [None], then no timeout will be used.
98    fn timeout(&self) -> Option<Duration> {
99        None
100    }
101}
102
103impl<T: WrpcView> WrpcView for &mut T {
104    type Invoke = T::Invoke;
105
106    fn context(&self) -> <Self::Invoke as Invoke>::Context {
107        (**self).context()
108    }
109
110    fn client(&self) -> &Self::Invoke {
111        (**self).client()
112    }
113
114    fn shared_resources(&mut self) -> &mut SharedResourceTable {
115        (**self).shared_resources()
116    }
117
118    fn timeout(&self) -> Option<Duration> {
119        (**self).timeout()
120    }
121}
122
123pub trait WrpcViewExt: WrpcView {
124    fn push_invocation(
125        &mut self,
126        invocation: impl Future<
127                Output = anyhow::Result<(
128                    <Self::Invoke as Invoke>::Outgoing,
129                    <Self::Invoke as Invoke>::Incoming,
130                )>,
131            > + Send
132            + 'static,
133    ) -> anyhow::Result<Resource<Invocation>> {
134        self.table()
135            .push(Invocation::Future(Box::pin(async move {
136                let res = invocation.await;
137                Box::new(res) as Box<dyn Any + Send>
138            })))
139            .context("failed to push invocation to table")
140    }
141
142    fn get_invocation_result(
143        &mut self,
144        invocation: &Resource<Invocation>,
145    ) -> anyhow::Result<
146        Option<
147            &Box<
148                anyhow::Result<(
149                    <Self::Invoke as Invoke>::Outgoing,
150                    <Self::Invoke as Invoke>::Incoming,
151                )>,
152            >,
153        >,
154    > {
155        let invocation = self
156            .table()
157            .get(invocation)
158            .context("failed to get invocation from table")?;
159        match invocation {
160            Invocation::Future(..) => Ok(None),
161            Invocation::Ready(res) => {
162                let res = res.downcast_ref().context("invalid invocation type")?;
163                Ok(Some(res))
164            }
165        }
166    }
167
168    fn delete_invocation(
169        &mut self,
170        invocation: Resource<Invocation>,
171    ) -> anyhow::Result<
172        impl Future<
173            Output = anyhow::Result<(
174                <Self::Invoke as Invoke>::Outgoing,
175                <Self::Invoke as Invoke>::Incoming,
176            )>,
177        >,
178    > {
179        let invocation = self
180            .table()
181            .delete(invocation)
182            .context("failed to delete invocation from table")?;
183        Ok(async move {
184            let res = match invocation {
185                Invocation::Future(fut) => fut.await,
186                Invocation::Ready(res) => res,
187            };
188            let res = res
189                .downcast()
190                .map_err(|_| anyhow!("invalid invocation type"))?;
191            *res
192        })
193    }
194
195    fn push_outgoing_channel(
196        &mut self,
197        outgoing: <Self::Invoke as Invoke>::Outgoing,
198    ) -> anyhow::Result<Resource<OutgoingChannel>> {
199        self.table()
200            .push(OutgoingChannel(Arc::new(std::sync::RwLock::new(Box::new(
201                outgoing,
202            )))))
203            .context("failed to push outgoing channel to table")
204    }
205
206    fn delete_outgoing_channel(
207        &mut self,
208        outgoing: Resource<OutgoingChannel>,
209    ) -> anyhow::Result<<Self::Invoke as Invoke>::Outgoing> {
210        let OutgoingChannel(outgoing) = self
211            .table()
212            .delete(outgoing)
213            .context("failed to delete outgoing channel from table")?;
214        let outgoing =
215            Arc::into_inner(outgoing).context("outgoing channel has an active stream")?;
216        let Ok(outgoing) = outgoing.into_inner() else {
217            bail!("lock poisoned");
218        };
219        let outgoing = outgoing
220            .downcast()
221            .map_err(|_| anyhow!("invalid outgoing channel type"))?;
222        Ok(*outgoing)
223    }
224
225    fn push_incoming_channel(
226        &mut self,
227        incoming: <Self::Invoke as Invoke>::Incoming,
228    ) -> anyhow::Result<Resource<IncomingChannel>> {
229        self.table()
230            .push(IncomingChannel(Arc::new(std::sync::RwLock::new(Box::new(
231                incoming,
232            )))))
233            .context("failed to push incoming channel to table")
234    }
235
236    fn delete_incoming_channel(
237        &mut self,
238        incoming: Resource<IncomingChannel>,
239    ) -> anyhow::Result<<Self::Invoke as Invoke>::Incoming> {
240        let IncomingChannel(incoming) = self
241            .table()
242            .delete(incoming)
243            .context("failed to delete incoming channel from table")?;
244        let incoming =
245            Arc::into_inner(incoming).context("incoming channel has an active stream")?;
246        let Ok(incoming) = incoming.into_inner() else {
247            bail!("lock poisoned");
248        };
249        let incoming = incoming
250            .downcast()
251            .map_err(|_| anyhow!("invalid incoming channel type"))?;
252        Ok(*incoming)
253    }
254
255    fn push_error(&mut self, error: Error) -> anyhow::Result<Resource<Error>> {
256        self.table()
257            .push(error)
258            .context("failed to push error to table")
259    }
260
261    fn get_error(&mut self, error: &Resource<Error>) -> anyhow::Result<&Error> {
262        let error = self
263            .table()
264            .get(error)
265            .context("failed to get error from table")?;
266        Ok(error)
267    }
268
269    fn get_error_mut(&mut self, error: &Resource<Error>) -> anyhow::Result<&mut Error> {
270        let error = self
271            .table()
272            .get_mut(error)
273            .context("failed to get error from table")?;
274        Ok(error)
275    }
276
277    fn delete_error(&mut self, error: Resource<Error>) -> anyhow::Result<Error> {
278        let error = self
279            .table()
280            .delete(error)
281            .context("failed to delete error from table")?;
282        Ok(error)
283    }
284
285    fn push_context(
286        &mut self,
287        cx: <Self::Invoke as Invoke>::Context,
288    ) -> anyhow::Result<Resource<Context>>
289    where
290        <Self::Invoke as Invoke>::Context: 'static,
291    {
292        self.table()
293            .push(Context(Box::new(cx)))
294            .context("failed to push context to table")
295    }
296
297    fn delete_context(
298        &mut self,
299        cx: Resource<Context>,
300    ) -> anyhow::Result<<Self::Invoke as Invoke>::Context>
301    where
302        <Self::Invoke as Invoke>::Context: 'static,
303    {
304        let Context(cx) = self
305            .table()
306            .delete(cx)
307            .context("failed to delete context from table")?;
308        let cx = cx.downcast().map_err(|_| anyhow!("invalid context type"))?;
309        Ok(*cx)
310    }
311}
312
313impl<T: WrpcView> WrpcViewExt for T {}
314
315/// Error type returned by [call]
316pub enum CallError {
317    Decode(anyhow::Error),
318    Encode(anyhow::Error),
319    Table(anyhow::Error),
320    Call(anyhow::Error),
321    TypeMismatch(anyhow::Error),
322    Write(anyhow::Error),
323    Flush(anyhow::Error),
324    Deferred(anyhow::Error),
325    PostReturn(anyhow::Error),
326    Guest(Error),
327}
328
329impl core::error::Error for CallError {}
330
331impl fmt::Debug for CallError {
332    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333        match self {
334            CallError::Decode(error)
335            | CallError::Encode(error)
336            | CallError::Table(error)
337            | CallError::Call(error)
338            | CallError::TypeMismatch(error)
339            | CallError::Write(error)
340            | CallError::Flush(error)
341            | CallError::Deferred(error)
342            | CallError::PostReturn(error) => error.fmt(f),
343            CallError::Guest(error) => error.fmt(f),
344        }
345    }
346}
347
348impl fmt::Display for CallError {
349    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350        match self {
351            CallError::Decode(error)
352            | CallError::Encode(error)
353            | CallError::Table(error)
354            | CallError::Call(error)
355            | CallError::TypeMismatch(error)
356            | CallError::Write(error)
357            | CallError::Flush(error)
358            | CallError::Deferred(error)
359            | CallError::PostReturn(error) => error.fmt(f),
360            CallError::Guest(error) => error.fmt(f),
361        }
362    }
363}
364
/// Serves a single invocation of the component function `func`.
///
/// Reads `params_ty.len()` parameter values from `rx` and calls `func` in `store`. It then
/// encodes the results and writes them to `tx`, flushes `tx` and shuts it down. Next it drives
/// the deferred writers produced by the encoder concurrently, on the sub-streams
/// `tx.index(&[i])` where `i` is the position of the encoded value. Finally it runs the
/// function's post-return cleanup.
///
/// If the results have the shape `result<T, rpc-error>` or `result<_, rpc-error>` (as detected
/// by `rpc_result_type`), only the `ok` payload is transmitted. An `err` value is converted to a
/// host [`Error`], removed from the resource table and returned as [`CallError::Guest`]. In
/// that case nothing is written to `tx`.
///
/// # Errors
///
/// Returns the [`CallError`] variant for the stage that failed (decode, call, encode, table,
/// type mismatch, write, flush, deferred, post-return), or [`CallError::Guest`] for a
/// guest-returned RPC error. A failure to shut down `tx` is only logged at trace level.
///
/// NOTE(review): `post_return_async` is only reached on the success path. Every early return
/// after `call_async` succeeds skips it: encode, type-mismatch, table, guest, write, flush and
/// deferred errors. Per wasmtime's `Func::post_return` documentation, the instance cannot be
/// re-entered until post-return runs. Callers that reuse the store or instance after an error
/// should confirm how this is handled.
#[allow(clippy::too_many_arguments)]
pub async fn call<C, I, O>(
    mut store: C,
    rx: I,
    mut tx: O,
    guest_resources: &[ResourceType],
    host_resources: &HashMap<Box<str>, HashMap<Box<str>, (ResourceType, ResourceType)>>,
    params_ty: impl ExactSizeIterator<Item = &Type>,
    results_ty: &[Type],
    func: Func,
) -> Result<(), CallError>
where
    I: AsyncRead + wrpc_transport::Index<I> + Send + Sync + Unpin + 'static,
    O: AsyncWrite + wrpc_transport::Index<O> + Send + Sync + Unpin + 'static,
    C: AsContextMut,
    C::Data: WasiView + WrpcView,
{
    // Placeholder values; `read_value` receives each slot mutably to decode into it
    let mut params = vec![Val::Bool(false); params_ty.len()];
    let mut rx = pin!(rx);
    for (i, (v, ty)) in zip(&mut params, params_ty).enumerate() {
        read_value(&mut store, &mut rx, guest_resources, v, ty, &[i])
            .await
            .with_context(|| format!("failed to decode parameter value {i}"))
            .map_err(CallError::Decode)?;
    }
    // Placeholder values; `call_async` writes the function results into this buffer
    let mut results = vec![Val::Bool(false); results_ty.len()];
    func.call_async(&mut store, &params, &mut results)
        .await
        .context("failed to call function")
        .map_err(CallError::Call)?;

    let mut buf = BytesMut::default();
    // One entry per encoded value; `None` entries have nothing deferred (see `filter_map` below)
    let mut deferred = vec![];
    match (
        &rpc_result_type(host_resources, results_ty),
        results.as_slice(),
    ) {
        // Not an RPC result: encode every result value in order
        (None, results) => {
            for (i, (v, ty)) in zip(results, results_ty).enumerate() {
                let mut enc = ValEncoder::new(store.as_context_mut(), ty, guest_resources);
                enc.encode(v, &mut buf)
                    .with_context(|| format!("failed to encode result value {i}"))
                    .map_err(CallError::Encode)?;
                deferred.push(enc.deferred);
            }
        }
        // `result<_, rpc-error>`: nothing to encode
        (Some(None), [Val::Result(Ok(None))]) => {}
        // `result<T, rpc-error>`: only the `ok` payload is transmitted
        (Some(Some(ty)), [Val::Result(Ok(Some(v)))]) => {
            let mut enc = ValEncoder::new(store.as_context_mut(), ty, guest_resources);
            enc.encode(v, &mut buf)
                .context("failed to encode result value 0")
                .map_err(CallError::Encode)?;
            deferred.push(enc.deferred);
        }
        // The guest returned `err(rpc-error)`: take the error resource out of the table and
        // report it to the caller instead of transmitting anything
        (Some(..), [Val::Result(Err(Some(err)))]) => {
            let Val::Resource(err) = &**err else {
                return Err(CallError::TypeMismatch(anyhow!(
                    "RPC result error value is not a resource"
                )));
            };
            let mut store = store.as_context_mut();
            let err = err
                .try_into_resource(&mut store)
                .context("RPC result error resource type mismatch")
                .map_err(CallError::TypeMismatch)?;
            let err = store
                .data_mut()
                .delete_error(err)
                .map_err(CallError::Table)?;
            return Err(CallError::Guest(err));
        }
        // The result type is an RPC result, but the returned value does not have a matching shape
        _ => return Err(CallError::TypeMismatch(anyhow!("RPC result type mismatch"))),
    }

    debug!("transmitting results");
    tx.write_all(&buf)
        .await
        .context("failed to transmit results")
        .map_err(CallError::Write)?;
    tx.flush()
        .await
        .context("failed to flush outgoing stream")
        .map_err(CallError::Flush)?;
    // A failed shutdown is not treated as fatal; it is only traced
    if let Err(err) = tx.shutdown().await {
        trace!(?err, "failed to shutdown outgoing stream");
    }
    // Drive all deferred writers concurrently, each on the sub-stream indexed by the position
    // of the value it belongs to
    try_join_all(
        zip(0.., deferred)
            .filter_map(|(i, f)| f.map(|f| (tx.index(&[i]), f)))
            .map(|(w, f)| async move {
                let w = w?;
                f(w).await
            }),
    )
    .await
    .map_err(CallError::Deferred)?;
    func.post_return_async(&mut store)
        .await
        .context("failed to perform post-return cleanup")
        .map_err(CallError::PostReturn)?;
    Ok(())
}
469
470/// Recursively iterates the component item type and collects all exported resource types
471#[instrument(level = "debug", skip_all)]
472pub fn collect_item_resource_exports(
473    engine: &Engine,
474    ty: types::ComponentItem,
475    resources: &mut impl Extend<types::ResourceType>,
476) {
477    match ty {
478        types::ComponentItem::ComponentFunc(_)
479        | types::ComponentItem::CoreFunc(_)
480        | types::ComponentItem::Module(_)
481        | types::ComponentItem::Type(_) => {}
482        types::ComponentItem::Component(ty) => {
483            collect_component_resource_exports(engine, &ty, resources)
484        }
485
486        types::ComponentItem::ComponentInstance(ty) => {
487            collect_instance_resource_exports(engine, &ty, resources)
488        }
489        types::ComponentItem::Resource(ty) => {
490            debug!(?ty, "collect resource export");
491            resources.extend([ty])
492        }
493    }
494}
495
496/// Recursively iterates the instance type and collects all exported resource types
497#[instrument(level = "debug", skip_all)]
498pub fn collect_instance_resource_exports(
499    engine: &Engine,
500    ty: &types::ComponentInstance,
501    resources: &mut impl Extend<types::ResourceType>,
502) {
503    for (name, ty) in ty.exports(engine) {
504        trace!(name, ?ty, "collect instance item resource exports");
505        collect_item_resource_exports(engine, ty, resources);
506    }
507}
508
509/// Recursively iterates the component type and collects all exported resource types
510#[instrument(level = "debug", skip_all)]
511pub fn collect_component_resource_exports(
512    engine: &Engine,
513    ty: &types::Component,
514    resources: &mut impl Extend<types::ResourceType>,
515) {
516    for (name, ty) in ty.exports(engine) {
517        trace!(name, ?ty, "collect component item resource exports");
518        collect_item_resource_exports(engine, ty, resources);
519    }
520}
521
522/// Iterates the component type and collects all imported resource types
523#[instrument(level = "debug", skip_all)]
524pub fn collect_component_resource_imports(
525    engine: &Engine,
526    ty: &types::Component,
527    resources: &mut BTreeMap<Box<str>, HashMap<Box<str>, types::ResourceType>>,
528) {
529    for (name, ty) in ty.imports(engine) {
530        match ty {
531            types::ComponentItem::ComponentFunc(..)
532            | types::ComponentItem::CoreFunc(..)
533            | types::ComponentItem::Module(..)
534            | types::ComponentItem::Type(..)
535            | types::ComponentItem::Component(..) => {}
536            types::ComponentItem::ComponentInstance(ty) => {
537                let instance = name;
538                for (name, ty) in ty.exports(engine) {
539                    if let types::ComponentItem::Resource(ty) = ty {
540                        debug!(instance, name, ?ty, "collect instance resource import");
541                        if let Some(resources) = resources.get_mut(instance) {
542                            resources.insert(name.into(), ty);
543                        } else {
544                            resources.insert(instance.into(), HashMap::from([(name.into(), ty)]));
545                        }
546                    }
547                }
548            }
549            types::ComponentItem::Resource(ty) => {
550                debug!(name, "collect component resource import");
551                if let Some(resources) = resources.get_mut("") {
552                    resources.insert(name.into(), ty);
553                } else {
554                    resources.insert("".into(), HashMap::from([(name.into(), ty)]));
555                }
556            }
557        }
558    }
559}