wrpc_runtime_wasmtime/rpc/
mod.rs

//! `wrpc:rpc` host implementation
2
3use core::any::Any;
4use core::fmt;
5use core::future::Future;
6use core::marker::PhantomData;
7use core::pin::Pin;
8use core::task::{Context, Poll};
9
10use std::sync::Arc;
11
12use anyhow::Context as _;
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14use wasmtime::component::Linker;
15use wasmtime_wasi::Pollable;
16use wrpc_transport::Invoke;
17
18use crate::{bindings, WrpcView};
19
20mod host;
21
/// Wrapper struct, for which the generated `wrpc:rpc` host traits are implemented.
///
/// [`add_to_linker`] passes `WrpcRpcImpl(&mut T)` as the host getter for the
/// `wrpc:rpc/context`, `wrpc:rpc/error`, `wrpc:rpc/invoker` and `wrpc:rpc/transport`
/// interfaces (`bindings::rpc::*`), so `WrpcRpcImpl<&mut T>` is the type implementing their
/// `Host` traits.
///
/// `#[repr(transparent)]` gives the wrapper the same layout as the wrapped `T`.
#[repr(transparent)]
pub struct WrpcRpcImpl<T>(pub T);
25
26fn type_annotate<T, F>(val: F) -> F
27where
28    F: Fn(&mut T) -> WrpcRpcImpl<&mut T>,
29{
30    val
31}
32
33pub fn add_to_linker<T>(linker: &mut Linker<T>) -> anyhow::Result<()>
34where
35    T: WrpcView,
36    T::Invoke: Clone + 'static,
37    <T::Invoke as Invoke>::Context: 'static,
38{
39    let closure = type_annotate::<T, _>(|t| WrpcRpcImpl(t));
40    bindings::rpc::context::add_to_linker_get_host(linker, closure)
41        .context("failed to link `wrpc:rpc/context`")?;
42    bindings::rpc::error::add_to_linker_get_host(linker, closure)
43        .context("failed to link `wrpc:rpc/error`")?;
44    bindings::rpc::invoker::add_to_linker_get_host(linker, closure)
45        .context("failed to link `wrpc:rpc/invoker`")?;
46    bindings::rpc::transport::add_to_linker_get_host(linker, closure)
47        .context("failed to link `wrpc:rpc/transport`")?;
48    Ok(())
49}
50
51/// RPC error
52pub enum Error {
53    /// Error originating from [Invoke::invoke] call
54    Invoke(anyhow::Error),
55    /// Error originating from [Index::index](wrpc_transport::Index::index) call on [Invoke::Incoming].
56    IncomingIndex(anyhow::Error),
57    /// Error originating from [Index::index](wrpc_transport::Index::index) call on
58    /// [Invoke::Outgoing].
59    OutgoingIndex(anyhow::Error),
60    /// Error originating from a `wasi:io` stream provided by this crate.
61    Stream(StreamError),
62}
63
64impl fmt::Debug for Error {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        match self {
67            Error::Invoke(error) | Error::IncomingIndex(error) | Error::OutgoingIndex(error) => {
68                error.fmt(f)
69            }
70            Error::Stream(error) => error.fmt(f),
71        }
72    }
73}
74
75impl fmt::Display for Error {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        match self {
78            Error::Invoke(error) | Error::IncomingIndex(error) | Error::OutgoingIndex(error) => {
79                error.fmt(f)
80            }
81            Error::Stream(error) => error.fmt(f),
82        }
83    }
84}
85
/// Error type originating from `wasi:io` streams provided by this crate.
///
/// Both `Debug` and `Display` print only the payload; the variant name is not included.
pub enum StreamError {
    /// The lock guarding the shared channel was poisoned.
    LockPoisoned,
    /// The channel does not hold a value of the expected concrete type; the payload is the
    /// message describing which channel.
    TypeMismatch(&'static str),
    /// Reading from the underlying stream failed.
    Read(std::io::Error),
    /// Writing to the underlying stream failed.
    Write(std::io::Error),
    /// Flushing the underlying stream failed.
    Flush(std::io::Error),
    /// Shutting down the underlying stream failed.
    Shutdown(std::io::Error),
}

impl core::error::Error for StreamError {}

impl fmt::Debug for StreamError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::LockPoisoned => fmt::Debug::fmt("lock poisoned", f),
            Self::TypeMismatch(msg) => fmt::Debug::fmt(msg, f),
            Self::Read(io) | Self::Write(io) | Self::Flush(io) | Self::Shutdown(io) => {
                fmt::Debug::fmt(io, f)
            }
        }
    }
}

impl fmt::Display for StreamError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::LockPoisoned => fmt::Display::fmt("lock poisoned", f),
            Self::TypeMismatch(msg) => fmt::Display::fmt(msg, f),
            Self::Read(io) | Self::Write(io) | Self::Flush(io) | Self::Shutdown(io) => {
                fmt::Display::fmt(io, f)
            }
        }
    }
}
123
124pub enum Invocation {
125    Future(Pin<Box<dyn Future<Output = Box<dyn Any + Send>> + Send>>),
126    Ready(Box<dyn Any + Send>),
127}
128
129#[wasmtime_wasi::async_trait]
130impl Pollable for Invocation {
131    async fn ready(&mut self) {
132        match self {
133            Self::Future(fut) => {
134                let res = fut.await;
135                *self = Self::Ready(res);
136            }
137            Self::Ready(..) => {}
138        }
139    }
140}
141
/// Shared, lock-protected, type-erased outgoing channel.
///
/// [`OutgoingChannelStream`] takes the write lock and downcasts the stored value to its concrete
/// writer type on every operation.
pub struct OutgoingChannel(pub Arc<std::sync::RwLock<Box<dyn Any + Send + Sync>>>);

/// Shared, lock-protected, type-erased incoming channel.
///
/// [`IncomingChannelStream`] takes the write lock and downcasts the stored value to its concrete
/// reader type on every read.
pub struct IncomingChannel(pub Arc<std::sync::RwLock<Box<dyn Any + Send + Sync>>>);
145
146pub struct IncomingChannelStream<T> {
147    incoming: IncomingChannel,
148    _ty: PhantomData<T>,
149}
150
151impl<T: AsyncRead + Unpin + 'static> AsyncRead for IncomingChannelStream<T> {
152    fn poll_read(
153        self: Pin<&mut Self>,
154        cx: &mut Context<'_>,
155        buf: &mut ReadBuf<'_>,
156    ) -> Poll<std::io::Result<()>> {
157        let Ok(mut incoming) = self.incoming.0.write() else {
158            return Poll::Ready(Err(std::io::Error::new(
159                std::io::ErrorKind::Deadlock,
160                StreamError::LockPoisoned,
161            )));
162        };
163        let Some(incoming) = incoming.downcast_mut::<T>() else {
164            return Poll::Ready(Err(std::io::Error::new(
165                std::io::ErrorKind::InvalidData,
166                StreamError::TypeMismatch("invalid incoming channel type"),
167            )));
168        };
169        Pin::new(incoming)
170            .poll_read(cx, buf)
171            .map_err(|err| std::io::Error::new(err.kind(), StreamError::Read(err)))
172    }
173}
174
175pub struct OutgoingChannelStream<T> {
176    outgoing: OutgoingChannel,
177    _ty: PhantomData<T>,
178}
179
180impl<T: AsyncWrite + Unpin + 'static> AsyncWrite for OutgoingChannelStream<T> {
181    fn poll_write(
182        self: Pin<&mut Self>,
183        cx: &mut Context<'_>,
184        buf: &[u8],
185    ) -> Poll<Result<usize, std::io::Error>> {
186        let Ok(mut outgoing) = self.outgoing.0.write() else {
187            return Poll::Ready(Err(std::io::Error::new(
188                std::io::ErrorKind::Deadlock,
189                StreamError::LockPoisoned,
190            )));
191        };
192        let Some(outgoing) = outgoing.downcast_mut::<T>() else {
193            return Poll::Ready(Err(std::io::Error::new(
194                std::io::ErrorKind::InvalidData,
195                StreamError::TypeMismatch("invalid outgoing channel type"),
196            )));
197        };
198        Pin::new(outgoing)
199            .poll_write(cx, buf)
200            .map_err(|err| std::io::Error::new(err.kind(), StreamError::Write(err)))
201    }
202
203    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
204        let Ok(mut outgoing) = self.outgoing.0.write() else {
205            return Poll::Ready(Err(std::io::Error::new(
206                std::io::ErrorKind::Deadlock,
207                StreamError::LockPoisoned,
208            )));
209        };
210        let Some(outgoing) = outgoing.downcast_mut::<T>() else {
211            return Poll::Ready(Err(std::io::Error::new(
212                std::io::ErrorKind::InvalidData,
213                StreamError::TypeMismatch("invalid outgoing channel type"),
214            )));
215        };
216        Pin::new(outgoing)
217            .poll_flush(cx)
218            .map_err(|err| std::io::Error::new(err.kind(), StreamError::Flush(err)))
219    }
220
221    fn poll_shutdown(
222        self: Pin<&mut Self>,
223        cx: &mut Context<'_>,
224    ) -> Poll<Result<(), std::io::Error>> {
225        let Ok(mut outgoing) = self.outgoing.0.write() else {
226            return Poll::Ready(Err(std::io::Error::new(
227                std::io::ErrorKind::Deadlock,
228                StreamError::LockPoisoned,
229            )));
230        };
231        let Some(outgoing) = outgoing.downcast_mut::<T>() else {
232            return Poll::Ready(Err(std::io::Error::new(
233                std::io::ErrorKind::InvalidData,
234                StreamError::TypeMismatch("invalid outgoing channel type"),
235            )));
236        };
237        Pin::new(outgoing)
238            .poll_shutdown(cx)
239            .map_err(|err| std::io::Error::new(err.kind(), StreamError::Shutdown(err)))
240    }
241}