wrpc_runtime_wasmtime/rpc/host/
transport.rs1use core::marker::PhantomData;
2
3use std::sync::Arc;
4
5use wasmtime::bail;
6use wasmtime::component::Resource;
7use wasmtime::error::Context as _;
8use wasmtime_wasi::p2::bindings::io::poll::Pollable;
9use wasmtime_wasi::p2::bindings::io::streams::{InputStream, OutputStream};
10use wasmtime_wasi::p2::pipe::{AsyncReadStream, AsyncWriteStream};
11use wasmtime_wasi::p2::subscribe;
12use wrpc_transport::{Index as _, Invoke};
13
14use crate::bindings::rpc::error::Error;
15use crate::bindings::rpc::transport::{
16 Host, HostIncomingChannel, HostInvocation, HostOutgoingChannel, IncomingChannel, Invocation,
17 OutgoingChannel,
18};
19use crate::rpc::{IncomingChannelStream, OutgoingChannelStream, WrpcRpcImpl};
20use crate::{WrpcView, WrpcViewExt as _};
21
22impl<T: WrpcView> Host for WrpcRpcImpl<T> {}
23
24impl<T: WrpcView> HostInvocation for WrpcRpcImpl<T> {
25 fn subscribe(
26 &mut self,
27 invocation: Resource<Invocation>,
28 ) -> wasmtime::Result<Resource<Pollable>> {
29 subscribe(self.0.wrpc().table, invocation)
30 }
31
32 async fn finish(
33 &mut self,
34 invocation: Resource<Invocation>,
35 ) -> wasmtime::Result<
36 Result<(Resource<OutgoingChannel>, Resource<IncomingChannel>), Resource<Error>>,
37 > {
38 let invocation = self.0.delete_invocation(invocation)?;
39 match invocation.await {
40 Ok((tx, rx)) => {
41 let rx = self.0.push_incoming_channel(rx)?;
42 let tx = self.0.push_outgoing_channel(tx)?;
43 Ok(Ok((tx, rx)))
44 }
45 Err(error) => {
46 let error = self.0.push_error(Error::Invoke(error))?;
47 Ok(Err(error))
48 }
49 }
50 }
51
52 fn drop(&mut self, invocation: Resource<Invocation>) -> wasmtime::Result<()> {
53 _ = self.0.delete_invocation(invocation)?;
54 Ok(())
55 }
56}
57
58impl<T: WrpcView> HostIncomingChannel for WrpcRpcImpl<T> {
59 fn data(
60 &mut self,
61 incoming: Resource<IncomingChannel>,
62 ) -> wasmtime::Result<Option<Resource<InputStream>>> {
63 let IncomingChannel(stream) = self
64 .0
65 .wrpc()
66 .table
67 .get_mut(&incoming)
68 .context("failed to get incoming channel from table")?;
69 if Arc::get_mut(stream).is_none() {
70 return Ok(None);
71 }
72 let stream = Arc::clone(stream);
73 let stream = self
74 .0
75 .wrpc()
76 .table
77 .push_child(
78 Box::new(AsyncReadStream::new(IncomingChannelStream {
79 incoming: IncomingChannel(stream),
80 _ty: PhantomData::<<T::Invoke as Invoke>::Incoming>,
81 })) as InputStream,
82 &incoming,
83 )
84 .context("failed to push input stream to table")?;
85 Ok(Some(stream))
86 }
87
88 fn index(
89 &mut self,
90 incoming: Resource<IncomingChannel>,
91 path: Vec<u32>,
92 ) -> wasmtime::Result<Result<Resource<IncomingChannel>, Resource<Error>>> {
93 let path = path
94 .into_iter()
95 .map(usize::try_from)
96 .collect::<Result<Box<[_]>, _>>()
97 .context("failed to construct subscription path")?;
98 let IncomingChannel(incoming) = self
99 .0
100 .wrpc()
101 .table
102 .get(&incoming)
103 .context("failed to get incoming channel from table")?;
104 let incoming = {
105 let Ok(incoming) = incoming.read() else {
106 bail!("lock poisoned");
107 };
108 let incoming = incoming
109 .downcast_ref::<<T::Invoke as Invoke>::Incoming>()
110 .context("invalid incoming channel type")?;
111 incoming.index(&path)
112 };
113 match incoming {
114 Ok(incoming) => {
115 let incoming = self.0.push_incoming_channel(incoming)?;
116 Ok(Ok(incoming))
117 }
118 Err(error) => {
119 let error = self.0.push_error(Error::IncomingIndex(error))?;
120 Ok(Err(error))
121 }
122 }
123 }
124
125 fn drop(&mut self, incoming: Resource<IncomingChannel>) -> wasmtime::Result<()> {
126 self.0.delete_incoming_channel(incoming)?;
127 Ok(())
128 }
129}
130
131impl<T: WrpcView> HostOutgoingChannel for WrpcRpcImpl<T> {
132 fn data(
133 &mut self,
134 outgoing: Resource<OutgoingChannel>,
135 ) -> wasmtime::Result<Option<Resource<OutputStream>>> {
136 let OutgoingChannel(stream) = self
137 .0
138 .wrpc()
139 .table
140 .get_mut(&outgoing)
141 .context("failed to get outgoing channel from table")?;
142 if Arc::get_mut(stream).is_none() {
143 return Ok(None);
144 }
145 let stream = Arc::clone(stream);
146 let stream = self
147 .0
148 .wrpc()
149 .table
150 .push_child(
151 Box::new(AsyncWriteStream::new(
152 8192,
153 OutgoingChannelStream {
154 outgoing: OutgoingChannel(stream),
155 _ty: PhantomData::<<T::Invoke as Invoke>::Outgoing>,
156 },
157 )) as OutputStream,
158 &outgoing,
159 )
160 .context("failed to push output stream to table")?;
161 Ok(Some(stream))
162 }
163
164 fn index(
165 &mut self,
166 outgoing: Resource<OutgoingChannel>,
167 path: Vec<u32>,
168 ) -> wasmtime::Result<Result<Resource<OutgoingChannel>, Resource<Error>>> {
169 let path = path
170 .into_iter()
171 .map(usize::try_from)
172 .collect::<Result<Box<[_]>, _>>()
173 .context("failed to construct subscription path")?;
174 let OutgoingChannel(outgoing) = self
175 .0
176 .wrpc()
177 .table
178 .get(&outgoing)
179 .context("failed to get outgoing channel from table")?;
180 let incoming = {
181 let Ok(outgoing) = outgoing.read() else {
182 bail!("lock poisoned");
183 };
184 let outgoing = outgoing
185 .downcast_ref::<<T::Invoke as Invoke>::Outgoing>()
186 .context("invalid outgoing channel type")?;
187 outgoing.index(&path)
188 };
189 match incoming {
190 Ok(outgoing) => {
191 let outgoing = self.0.push_outgoing_channel(outgoing)?;
192 Ok(Ok(outgoing))
193 }
194 Err(error) => {
195 let error = self.0.push_error(Error::OutgoingIndex(error))?;
196 Ok(Err(error))
197 }
198 }
199 }
200
201 fn drop(&mut self, outgoing: Resource<OutgoingChannel>) -> wasmtime::Result<()> {
202 self.0.delete_outgoing_channel(outgoing)?;
203 Ok(())
204 }
205}