1#![allow(clippy::type_complexity)] use core::any::Any;
4use core::borrow::Borrow;
5use core::fmt;
6use core::future::Future;
7use core::iter::zip;
8use core::pin::pin;
9use core::time::Duration;
10
11use std::collections::{BTreeMap, HashMap};
12use std::sync::Arc;
13
14use anyhow::{anyhow, bail, Context as _};
15use bytes::{Bytes, BytesMut};
16use futures::future::try_join_all;
17use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _};
18use tokio_util::codec::Encoder;
19use tracing::{debug, instrument, trace, warn};
20use uuid::Uuid;
21use wasmtime::component::{types, Func, Resource, ResourceAny, ResourceType, Type, Val};
22use wasmtime::{AsContextMut, Engine};
23use wasmtime_wasi::{IoView, WasiView};
24use wrpc_transport::Invoke;
25
26use crate::bindings::rpc::context::Context;
27use crate::bindings::rpc::error::Error;
28use crate::bindings::rpc::transport::{IncomingChannel, Invocation, OutgoingChannel};
29
30pub mod bindings;
31mod codec;
32mod polyfill;
33pub mod rpc;
34mod serve;
35
36pub use codec::*;
37pub use polyfill::*;
38pub use serve::*;
39
/// Strips a component-model function-kind annotation from an export name.
///
/// The first matching prefix out of `[constructor]`, `[static]` and
/// `[method]` (checked in that order) is removed; only one prefix is ever
/// stripped. Names without any of these prefixes are returned unchanged.
fn rpc_func_name(name: &str) -> &str {
    const KIND_PREFIXES: [&str; 3] = ["[constructor]", "[static]", "[method]"];
    KIND_PREFIXES
        .iter()
        .find_map(|&prefix| name.strip_prefix(prefix))
        .unwrap_or(name)
}
54
55fn rpc_result_type<T: Borrow<Type>>(
56 host_resources: &HashMap<Box<str>, HashMap<Box<str>, (ResourceType, ResourceType)>>,
57 results_ty: impl IntoIterator<Item = T>,
58) -> Option<Option<Type>> {
59 let rpc_err_ty = host_resources
60 .get("wrpc:rpc/error@0.1.0")
61 .and_then(|instance| instance.get("error"));
62 let mut results_ty = results_ty.into_iter();
63 match (
64 rpc_err_ty,
65 results_ty.next().as_ref().map(Borrow::borrow),
66 results_ty.next(),
67 ) {
68 (Some((guest_rpc_err_ty, host_rpc_err_ty)), Some(Type::Result(result_ty)), None)
69 if *host_rpc_err_ty == ResourceType::host::<Error>()
70 && result_ty.err() == Some(Type::Own(*guest_rpc_err_ty)) =>
71 {
72 Some(result_ty.ok())
73 }
74 _ => None,
75 }
76}
77
/// Newtype wrapping a raw [`Bytes`] payload.
///
/// NOTE(review): nothing in this part of the file constructs or reads this
/// type; the payload is presumably the encoded identifier of a resource that
/// lives on a remote wRPC peer — confirm against its users.
pub struct RemoteResource(pub Bytes);
79
/// Table mapping [`Uuid`] keys to component resource handles ([`ResourceAny`]).
///
/// Host code reaches it through `WrpcView::shared_resources`.
/// NOTE(review): the UUIDs are presumably identifiers handed out to remote
/// peers for resources shared over wRPC — confirm with the code that
/// populates this table.
#[derive(Debug, Default)]
pub struct SharedResourceTable(HashMap<Uuid, ResourceAny>);
83
/// Store-data view exposing the wRPC state that host implementations need.
///
/// Builds on [`IoView`] (which provides the resource table) and requires the
/// implementor to be `Send`.
pub trait WrpcView: IoView + Send {
    /// The wRPC transport client type.
    type Invoke: Invoke;

    /// Returns the transport context value associated with [`Self::Invoke`].
    ///
    /// NOTE(review): presumably passed along with outgoing invocations —
    /// confirm with callers.
    fn context(&self) -> <Self::Invoke as Invoke>::Context;

    /// Returns a reference to the wRPC transport client.
    fn client(&self) -> &Self::Invoke;

    /// Returns the table of UUID-keyed resource handles.
    fn shared_resources(&mut self) -> &mut SharedResourceTable;

    /// Returns an optional timeout.
    ///
    /// The default implementation returns `None`.
    /// NOTE(review): presumably applied to outgoing invocations — confirm
    /// where it is consumed.
    fn timeout(&self) -> Option<Duration> {
        None
    }
}
102
103impl<T: WrpcView> WrpcView for &mut T {
104 type Invoke = T::Invoke;
105
106 fn context(&self) -> <Self::Invoke as Invoke>::Context {
107 (**self).context()
108 }
109
110 fn client(&self) -> &Self::Invoke {
111 (**self).client()
112 }
113
114 fn shared_resources(&mut self) -> &mut SharedResourceTable {
115 (**self).shared_resources()
116 }
117
118 fn timeout(&self) -> Option<Duration> {
119 (**self).timeout()
120 }
121}
122
123pub trait WrpcViewExt: WrpcView {
124 fn push_invocation(
125 &mut self,
126 invocation: impl Future<
127 Output = anyhow::Result<(
128 <Self::Invoke as Invoke>::Outgoing,
129 <Self::Invoke as Invoke>::Incoming,
130 )>,
131 > + Send
132 + 'static,
133 ) -> anyhow::Result<Resource<Invocation>> {
134 self.table()
135 .push(Invocation::Future(Box::pin(async move {
136 let res = invocation.await;
137 Box::new(res) as Box<dyn Any + Send>
138 })))
139 .context("failed to push invocation to table")
140 }
141
142 fn get_invocation_result(
143 &mut self,
144 invocation: &Resource<Invocation>,
145 ) -> anyhow::Result<
146 Option<
147 &Box<
148 anyhow::Result<(
149 <Self::Invoke as Invoke>::Outgoing,
150 <Self::Invoke as Invoke>::Incoming,
151 )>,
152 >,
153 >,
154 > {
155 let invocation = self
156 .table()
157 .get(invocation)
158 .context("failed to get invocation from table")?;
159 match invocation {
160 Invocation::Future(..) => Ok(None),
161 Invocation::Ready(res) => {
162 let res = res.downcast_ref().context("invalid invocation type")?;
163 Ok(Some(res))
164 }
165 }
166 }
167
168 fn delete_invocation(
169 &mut self,
170 invocation: Resource<Invocation>,
171 ) -> anyhow::Result<
172 impl Future<
173 Output = anyhow::Result<(
174 <Self::Invoke as Invoke>::Outgoing,
175 <Self::Invoke as Invoke>::Incoming,
176 )>,
177 >,
178 > {
179 let invocation = self
180 .table()
181 .delete(invocation)
182 .context("failed to delete invocation from table")?;
183 Ok(async move {
184 let res = match invocation {
185 Invocation::Future(fut) => fut.await,
186 Invocation::Ready(res) => res,
187 };
188 let res = res
189 .downcast()
190 .map_err(|_| anyhow!("invalid invocation type"))?;
191 *res
192 })
193 }
194
195 fn push_outgoing_channel(
196 &mut self,
197 outgoing: <Self::Invoke as Invoke>::Outgoing,
198 ) -> anyhow::Result<Resource<OutgoingChannel>> {
199 self.table()
200 .push(OutgoingChannel(Arc::new(std::sync::RwLock::new(Box::new(
201 outgoing,
202 )))))
203 .context("failed to push outgoing channel to table")
204 }
205
206 fn delete_outgoing_channel(
207 &mut self,
208 outgoing: Resource<OutgoingChannel>,
209 ) -> anyhow::Result<<Self::Invoke as Invoke>::Outgoing> {
210 let OutgoingChannel(outgoing) = self
211 .table()
212 .delete(outgoing)
213 .context("failed to delete outgoing channel from table")?;
214 let outgoing =
215 Arc::into_inner(outgoing).context("outgoing channel has an active stream")?;
216 let Ok(outgoing) = outgoing.into_inner() else {
217 bail!("lock poisoned");
218 };
219 let outgoing = outgoing
220 .downcast()
221 .map_err(|_| anyhow!("invalid outgoing channel type"))?;
222 Ok(*outgoing)
223 }
224
225 fn push_incoming_channel(
226 &mut self,
227 incoming: <Self::Invoke as Invoke>::Incoming,
228 ) -> anyhow::Result<Resource<IncomingChannel>> {
229 self.table()
230 .push(IncomingChannel(Arc::new(std::sync::RwLock::new(Box::new(
231 incoming,
232 )))))
233 .context("failed to push incoming channel to table")
234 }
235
236 fn delete_incoming_channel(
237 &mut self,
238 incoming: Resource<IncomingChannel>,
239 ) -> anyhow::Result<<Self::Invoke as Invoke>::Incoming> {
240 let IncomingChannel(incoming) = self
241 .table()
242 .delete(incoming)
243 .context("failed to delete incoming channel from table")?;
244 let incoming =
245 Arc::into_inner(incoming).context("incoming channel has an active stream")?;
246 let Ok(incoming) = incoming.into_inner() else {
247 bail!("lock poisoned");
248 };
249 let incoming = incoming
250 .downcast()
251 .map_err(|_| anyhow!("invalid incoming channel type"))?;
252 Ok(*incoming)
253 }
254
255 fn push_error(&mut self, error: Error) -> anyhow::Result<Resource<Error>> {
256 self.table()
257 .push(error)
258 .context("failed to push error to table")
259 }
260
261 fn get_error(&mut self, error: &Resource<Error>) -> anyhow::Result<&Error> {
262 let error = self
263 .table()
264 .get(error)
265 .context("failed to get error from table")?;
266 Ok(error)
267 }
268
269 fn get_error_mut(&mut self, error: &Resource<Error>) -> anyhow::Result<&mut Error> {
270 let error = self
271 .table()
272 .get_mut(error)
273 .context("failed to get error from table")?;
274 Ok(error)
275 }
276
277 fn delete_error(&mut self, error: Resource<Error>) -> anyhow::Result<Error> {
278 let error = self
279 .table()
280 .delete(error)
281 .context("failed to delete error from table")?;
282 Ok(error)
283 }
284
285 fn push_context(
286 &mut self,
287 cx: <Self::Invoke as Invoke>::Context,
288 ) -> anyhow::Result<Resource<Context>>
289 where
290 <Self::Invoke as Invoke>::Context: 'static,
291 {
292 self.table()
293 .push(Context(Box::new(cx)))
294 .context("failed to push context to table")
295 }
296
297 fn delete_context(
298 &mut self,
299 cx: Resource<Context>,
300 ) -> anyhow::Result<<Self::Invoke as Invoke>::Context>
301 where
302 <Self::Invoke as Invoke>::Context: 'static,
303 {
304 let Context(cx) = self
305 .table()
306 .delete(cx)
307 .context("failed to delete context from table")?;
308 let cx = cx.downcast().map_err(|_| anyhow!("invalid context type"))?;
309 Ok(*cx)
310 }
311}
312
/// Blanket impl: every [`WrpcView`] gets the [`WrpcViewExt`] helpers through
/// their default method implementations.
impl<T: WrpcView> WrpcViewExt for T {}
314
/// Error returned by [`call`]; the variant identifies the stage that failed.
pub enum CallError {
    /// Decoding a parameter value from the incoming stream failed.
    Decode(anyhow::Error),
    /// Encoding a result value failed.
    Encode(anyhow::Error),
    /// A resource-table operation failed (in [`call`]: removing the guest's
    /// RPC error resource from the table).
    Table(anyhow::Error),
    /// Invoking the component function failed.
    Call(anyhow::Error),
    /// The function's results did not have the expected shape, e.g. the RPC
    /// error value was not a resource of the expected type.
    TypeMismatch(anyhow::Error),
    /// Writing the encoded results to the outgoing stream failed.
    Write(anyhow::Error),
    /// Flushing the outgoing stream failed.
    Flush(anyhow::Error),
    /// A deferred encoder, driven on an indexed sub-stream of the outgoing
    /// stream, failed.
    Deferred(anyhow::Error),
    /// The component function's post-return cleanup failed.
    PostReturn(anyhow::Error),
    /// The guest returned an RPC error value (a `wrpc:rpc/error` resource).
    Guest(Error),
}
328
329impl core::error::Error for CallError {}
330
331impl fmt::Debug for CallError {
332 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333 match self {
334 CallError::Decode(error)
335 | CallError::Encode(error)
336 | CallError::Table(error)
337 | CallError::Call(error)
338 | CallError::TypeMismatch(error)
339 | CallError::Write(error)
340 | CallError::Flush(error)
341 | CallError::Deferred(error)
342 | CallError::PostReturn(error) => error.fmt(f),
343 CallError::Guest(error) => error.fmt(f),
344 }
345 }
346}
347
348impl fmt::Display for CallError {
349 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350 match self {
351 CallError::Decode(error)
352 | CallError::Encode(error)
353 | CallError::Table(error)
354 | CallError::Call(error)
355 | CallError::TypeMismatch(error)
356 | CallError::Write(error)
357 | CallError::Flush(error)
358 | CallError::Deferred(error)
359 | CallError::PostReturn(error) => error.fmt(f),
360 CallError::Guest(error) => error.fmt(f),
361 }
362 }
363}
364
365#[allow(clippy::too_many_arguments)]
366pub async fn call<C, I, O>(
367 mut store: C,
368 rx: I,
369 mut tx: O,
370 guest_resources: &[ResourceType],
371 host_resources: &HashMap<Box<str>, HashMap<Box<str>, (ResourceType, ResourceType)>>,
372 params_ty: impl ExactSizeIterator<Item = &Type>,
373 results_ty: &[Type],
374 func: Func,
375) -> Result<(), CallError>
376where
377 I: AsyncRead + wrpc_transport::Index<I> + Send + Sync + Unpin + 'static,
378 O: AsyncWrite + wrpc_transport::Index<O> + Send + Sync + Unpin + 'static,
379 C: AsContextMut,
380 C::Data: WasiView + WrpcView,
381{
382 let mut params = vec![Val::Bool(false); params_ty.len()];
383 let mut rx = pin!(rx);
384 for (i, (v, ty)) in zip(&mut params, params_ty).enumerate() {
385 read_value(&mut store, &mut rx, guest_resources, v, ty, &[i])
386 .await
387 .with_context(|| format!("failed to decode parameter value {i}"))
388 .map_err(CallError::Decode)?;
389 }
390 let mut results = vec![Val::Bool(false); results_ty.len()];
391 func.call_async(&mut store, ¶ms, &mut results)
392 .await
393 .context("failed to call function")
394 .map_err(CallError::Call)?;
395
396 let mut buf = BytesMut::default();
397 let mut deferred = vec![];
398 match (
399 &rpc_result_type(host_resources, results_ty),
400 results.as_slice(),
401 ) {
402 (None, results) => {
403 for (i, (v, ty)) in zip(results, results_ty).enumerate() {
404 let mut enc = ValEncoder::new(store.as_context_mut(), ty, guest_resources);
405 enc.encode(v, &mut buf)
406 .with_context(|| format!("failed to encode result value {i}"))
407 .map_err(CallError::Encode)?;
408 deferred.push(enc.deferred);
409 }
410 }
411 (Some(None), [Val::Result(Ok(None))]) => {}
413 (Some(Some(ty)), [Val::Result(Ok(Some(v)))]) => {
415 let mut enc = ValEncoder::new(store.as_context_mut(), ty, guest_resources);
416 enc.encode(v, &mut buf)
417 .context("failed to encode result value 0")
418 .map_err(CallError::Encode)?;
419 deferred.push(enc.deferred);
420 }
421 (Some(..), [Val::Result(Err(Some(err)))]) => {
422 let Val::Resource(err) = &**err else {
423 return Err(CallError::TypeMismatch(anyhow!(
424 "RPC result error value is not a resource"
425 )));
426 };
427 let mut store = store.as_context_mut();
428 let err = err
429 .try_into_resource(&mut store)
430 .context("RPC result error resource type mismatch")
431 .map_err(CallError::TypeMismatch)?;
432 let err = store
433 .data_mut()
434 .delete_error(err)
435 .map_err(CallError::Table)?;
436 return Err(CallError::Guest(err));
437 }
438 _ => return Err(CallError::TypeMismatch(anyhow!("RPC result type mismatch"))),
439 }
440
441 debug!("transmitting results");
442 tx.write_all(&buf)
443 .await
444 .context("failed to transmit results")
445 .map_err(CallError::Write)?;
446 tx.flush()
447 .await
448 .context("failed to flush outgoing stream")
449 .map_err(CallError::Flush)?;
450 if let Err(err) = tx.shutdown().await {
451 trace!(?err, "failed to shutdown outgoing stream");
452 }
453 try_join_all(
454 zip(0.., deferred)
455 .filter_map(|(i, f)| f.map(|f| (tx.index(&[i]), f)))
456 .map(|(w, f)| async move {
457 let w = w?;
458 f(w).await
459 }),
460 )
461 .await
462 .map_err(CallError::Deferred)?;
463 func.post_return_async(&mut store)
464 .await
465 .context("failed to perform post-return cleanup")
466 .map_err(CallError::PostReturn)?;
467 Ok(())
468}
469
470#[instrument(level = "debug", skip_all)]
472pub fn collect_item_resource_exports(
473 engine: &Engine,
474 ty: types::ComponentItem,
475 resources: &mut impl Extend<types::ResourceType>,
476) {
477 match ty {
478 types::ComponentItem::ComponentFunc(_)
479 | types::ComponentItem::CoreFunc(_)
480 | types::ComponentItem::Module(_)
481 | types::ComponentItem::Type(_) => {}
482 types::ComponentItem::Component(ty) => {
483 collect_component_resource_exports(engine, &ty, resources)
484 }
485
486 types::ComponentItem::ComponentInstance(ty) => {
487 collect_instance_resource_exports(engine, &ty, resources)
488 }
489 types::ComponentItem::Resource(ty) => {
490 debug!(?ty, "collect resource export");
491 resources.extend([ty])
492 }
493 }
494}
495
496#[instrument(level = "debug", skip_all)]
498pub fn collect_instance_resource_exports(
499 engine: &Engine,
500 ty: &types::ComponentInstance,
501 resources: &mut impl Extend<types::ResourceType>,
502) {
503 for (name, ty) in ty.exports(engine) {
504 trace!(name, ?ty, "collect instance item resource exports");
505 collect_item_resource_exports(engine, ty, resources);
506 }
507}
508
509#[instrument(level = "debug", skip_all)]
511pub fn collect_component_resource_exports(
512 engine: &Engine,
513 ty: &types::Component,
514 resources: &mut impl Extend<types::ResourceType>,
515) {
516 for (name, ty) in ty.exports(engine) {
517 trace!(name, ?ty, "collect component item resource exports");
518 collect_item_resource_exports(engine, ty, resources);
519 }
520}
521
522#[instrument(level = "debug", skip_all)]
524pub fn collect_component_resource_imports(
525 engine: &Engine,
526 ty: &types::Component,
527 resources: &mut BTreeMap<Box<str>, HashMap<Box<str>, types::ResourceType>>,
528) {
529 for (name, ty) in ty.imports(engine) {
530 match ty {
531 types::ComponentItem::ComponentFunc(..)
532 | types::ComponentItem::CoreFunc(..)
533 | types::ComponentItem::Module(..)
534 | types::ComponentItem::Type(..)
535 | types::ComponentItem::Component(..) => {}
536 types::ComponentItem::ComponentInstance(ty) => {
537 let instance = name;
538 for (name, ty) in ty.exports(engine) {
539 if let types::ComponentItem::Resource(ty) = ty {
540 debug!(instance, name, ?ty, "collect instance resource import");
541 if let Some(resources) = resources.get_mut(instance) {
542 resources.insert(name.into(), ty);
543 } else {
544 resources.insert(instance.into(), HashMap::from([(name.into(), ty)]));
545 }
546 }
547 }
548 }
549 types::ComponentItem::Resource(ty) => {
550 debug!(name, "collect component resource import");
551 if let Some(resources) = resources.get_mut("") {
552 resources.insert(name.into(), ty);
553 } else {
554 resources.insert("".into(), HashMap::from([(name.into(), ty)]));
555 }
556 }
557 }
558 }
559}