wrpc_runtime_wasmtime/rpc/
mod.rs1use core::any::Any;
4use core::fmt;
5use core::future::Future;
6use core::marker::PhantomData;
7use core::pin::Pin;
8use core::task::{Context, Poll};
9
10use std::sync::Arc;
11
12use anyhow::Context as _;
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14use wasmtime::component::Linker;
15use wasmtime_wasi::Pollable;
16use wrpc_transport::Invoke;
17
18use crate::{bindings, WrpcView};
19
20mod host;
21
/// Newtype wrapper around `T` with the same in-memory representation as `T`.
///
/// `add_to_linker` in this module wraps a `&mut T` in this type before handing
/// it to the generated bindings.
#[repr(transparent)]
pub struct WrpcRpcImpl<T>(pub T);
25
26fn type_annotate<T, F>(val: F) -> F
27where
28 F: Fn(&mut T) -> WrpcRpcImpl<&mut T>,
29{
30 val
31}
32
33pub fn add_to_linker<T>(linker: &mut Linker<T>) -> anyhow::Result<()>
34where
35 T: WrpcView,
36 T::Invoke: Clone + 'static,
37 <T::Invoke as Invoke>::Context: 'static,
38{
39 let closure = type_annotate::<T, _>(|t| WrpcRpcImpl(t));
40 bindings::rpc::context::add_to_linker_get_host(linker, closure)
41 .context("failed to link `wrpc:rpc/context`")?;
42 bindings::rpc::error::add_to_linker_get_host(linker, closure)
43 .context("failed to link `wrpc:rpc/error`")?;
44 bindings::rpc::invoker::add_to_linker_get_host(linker, closure)
45 .context("failed to link `wrpc:rpc/invoker`")?;
46 bindings::rpc::transport::add_to_linker_get_host(linker, closure)
47 .context("failed to link `wrpc:rpc/transport`")?;
48 Ok(())
49}
50
51pub enum Error {
53 Invoke(anyhow::Error),
55 IncomingIndex(anyhow::Error),
57 OutgoingIndex(anyhow::Error),
60 Stream(StreamError),
62}
63
64impl fmt::Debug for Error {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 match self {
67 Error::Invoke(error) | Error::IncomingIndex(error) | Error::OutgoingIndex(error) => {
68 error.fmt(f)
69 }
70 Error::Stream(error) => error.fmt(f),
71 }
72 }
73}
74
75impl fmt::Display for Error {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 match self {
78 Error::Invoke(error) | Error::IncomingIndex(error) | Error::OutgoingIndex(error) => {
79 error.fmt(f)
80 }
81 Error::Stream(error) => error.fmt(f),
82 }
83 }
84}
85
/// Failure raised by the channel stream adapters defined in this module.
pub enum StreamError {
    /// The lock guarding the channel could not be acquired because it is
    /// poisoned.
    LockPoisoned,
    /// The value stored in the channel is not of the expected concrete type;
    /// the payload is a static description of the mismatch.
    TypeMismatch(&'static str),
    /// I/O error reported by the underlying stream while reading.
    Read(std::io::Error),
    /// I/O error reported by the underlying stream while writing.
    Write(std::io::Error),
    /// I/O error reported by the underlying stream while flushing.
    Flush(std::io::Error),
    /// I/O error reported by the underlying stream while shutting down.
    Shutdown(std::io::Error),
}
95
96impl core::error::Error for StreamError {}
97
98impl fmt::Debug for StreamError {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 match self {
101 StreamError::LockPoisoned => "lock poisoned".fmt(f),
102 StreamError::TypeMismatch(error) => error.fmt(f),
103 StreamError::Read(error)
104 | StreamError::Write(error)
105 | StreamError::Flush(error)
106 | StreamError::Shutdown(error) => error.fmt(f),
107 }
108 }
109}
110
111impl fmt::Display for StreamError {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 match self {
114 StreamError::LockPoisoned => "lock poisoned".fmt(f),
115 StreamError::TypeMismatch(error) => error.fmt(f),
116 StreamError::Read(error)
117 | StreamError::Write(error)
118 | StreamError::Flush(error)
119 | StreamError::Shutdown(error) => error.fmt(f),
120 }
121 }
122}
123
/// An operation whose type-erased result is either still being computed or
/// already available.
pub enum Invocation {
    /// Still pending: a boxed future that yields the type-erased result.
    Future(Pin<Box<dyn Future<Output = Box<dyn Any + Send>> + Send>>),
    /// Finished: the type-erased result.
    Ready(Box<dyn Any + Send>),
}
128
129#[wasmtime_wasi::async_trait]
130impl Pollable for Invocation {
131 async fn ready(&mut self) {
132 match self {
133 Self::Future(fut) => {
134 let res = fut.await;
135 *self = Self::Ready(res);
136 }
137 Self::Ready(..) => {}
138 }
139 }
140}
141
/// Shared handle to a type-erased outgoing stream.
///
/// The concrete value sits behind an `Arc<RwLock<..>>` as `dyn Any` and is
/// recovered by downcasting (see `OutgoingChannelStream`).
pub struct OutgoingChannel(pub Arc<std::sync::RwLock<Box<dyn Any + Send + Sync>>>);
143
/// Shared handle to a type-erased incoming stream.
///
/// The concrete value sits behind an `Arc<RwLock<..>>` as `dyn Any` and is
/// recovered by downcasting (see `IncomingChannelStream`).
pub struct IncomingChannel(pub Arc<std::sync::RwLock<Box<dyn Any + Send + Sync>>>);
145
146pub struct IncomingChannelStream<T> {
147 incoming: IncomingChannel,
148 _ty: PhantomData<T>,
149}
150
151impl<T: AsyncRead + Unpin + 'static> AsyncRead for IncomingChannelStream<T> {
152 fn poll_read(
153 self: Pin<&mut Self>,
154 cx: &mut Context<'_>,
155 buf: &mut ReadBuf<'_>,
156 ) -> Poll<std::io::Result<()>> {
157 let Ok(mut incoming) = self.incoming.0.write() else {
158 return Poll::Ready(Err(std::io::Error::new(
159 std::io::ErrorKind::Deadlock,
160 StreamError::LockPoisoned,
161 )));
162 };
163 let Some(incoming) = incoming.downcast_mut::<T>() else {
164 return Poll::Ready(Err(std::io::Error::new(
165 std::io::ErrorKind::InvalidData,
166 StreamError::TypeMismatch("invalid incoming channel type"),
167 )));
168 };
169 Pin::new(incoming)
170 .poll_read(cx, buf)
171 .map_err(|err| std::io::Error::new(err.kind(), StreamError::Read(err)))
172 }
173}
174
175pub struct OutgoingChannelStream<T> {
176 outgoing: OutgoingChannel,
177 _ty: PhantomData<T>,
178}
179
180impl<T: AsyncWrite + Unpin + 'static> AsyncWrite for OutgoingChannelStream<T> {
181 fn poll_write(
182 self: Pin<&mut Self>,
183 cx: &mut Context<'_>,
184 buf: &[u8],
185 ) -> Poll<Result<usize, std::io::Error>> {
186 let Ok(mut outgoing) = self.outgoing.0.write() else {
187 return Poll::Ready(Err(std::io::Error::new(
188 std::io::ErrorKind::Deadlock,
189 StreamError::LockPoisoned,
190 )));
191 };
192 let Some(outgoing) = outgoing.downcast_mut::<T>() else {
193 return Poll::Ready(Err(std::io::Error::new(
194 std::io::ErrorKind::InvalidData,
195 StreamError::TypeMismatch("invalid outgoing channel type"),
196 )));
197 };
198 Pin::new(outgoing)
199 .poll_write(cx, buf)
200 .map_err(|err| std::io::Error::new(err.kind(), StreamError::Write(err)))
201 }
202
203 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
204 let Ok(mut outgoing) = self.outgoing.0.write() else {
205 return Poll::Ready(Err(std::io::Error::new(
206 std::io::ErrorKind::Deadlock,
207 StreamError::LockPoisoned,
208 )));
209 };
210 let Some(outgoing) = outgoing.downcast_mut::<T>() else {
211 return Poll::Ready(Err(std::io::Error::new(
212 std::io::ErrorKind::InvalidData,
213 StreamError::TypeMismatch("invalid outgoing channel type"),
214 )));
215 };
216 Pin::new(outgoing)
217 .poll_flush(cx)
218 .map_err(|err| std::io::Error::new(err.kind(), StreamError::Flush(err)))
219 }
220
221 fn poll_shutdown(
222 self: Pin<&mut Self>,
223 cx: &mut Context<'_>,
224 ) -> Poll<Result<(), std::io::Error>> {
225 let Ok(mut outgoing) = self.outgoing.0.write() else {
226 return Poll::Ready(Err(std::io::Error::new(
227 std::io::ErrorKind::Deadlock,
228 StreamError::LockPoisoned,
229 )));
230 };
231 let Some(outgoing) = outgoing.downcast_mut::<T>() else {
232 return Poll::Ready(Err(std::io::Error::new(
233 std::io::ErrorKind::InvalidData,
234 StreamError::TypeMismatch("invalid outgoing channel type"),
235 )));
236 };
237 Pin::new(outgoing)
238 .poll_shutdown(cx)
239 .map_err(|err| std::io::Error::new(err.kind(), StreamError::Shutdown(err)))
240 }
241}