Skip to main content

wrpc_runtime_wasmtime/rpc/host/
transport.rs

1use core::marker::PhantomData;
2
3use std::sync::Arc;
4
5use wasmtime::bail;
6use wasmtime::component::Resource;
7use wasmtime::error::Context as _;
8use wasmtime_wasi::p2::bindings::io::poll::Pollable;
9use wasmtime_wasi::p2::bindings::io::streams::{InputStream, OutputStream};
10use wasmtime_wasi::p2::pipe::{AsyncReadStream, AsyncWriteStream};
11use wasmtime_wasi::p2::subscribe;
12use wrpc_transport::{Index as _, Invoke};
13
14use crate::bindings::rpc::error::Error;
15use crate::bindings::rpc::transport::{
16    Host, HostIncomingChannel, HostInvocation, HostOutgoingChannel, IncomingChannel, Invocation,
17    OutgoingChannel,
18};
19use crate::rpc::{IncomingChannelStream, OutgoingChannelStream, WrpcRpcImpl};
20use crate::{WrpcView, WrpcViewExt as _};
21
22impl<T: WrpcView> Host for WrpcRpcImpl<T> {}
23
24impl<T: WrpcView> HostInvocation for WrpcRpcImpl<T> {
25    fn subscribe(
26        &mut self,
27        invocation: Resource<Invocation>,
28    ) -> wasmtime::Result<Resource<Pollable>> {
29        subscribe(self.0.wrpc().table, invocation)
30    }
31
32    async fn finish(
33        &mut self,
34        invocation: Resource<Invocation>,
35    ) -> wasmtime::Result<
36        Result<(Resource<OutgoingChannel>, Resource<IncomingChannel>), Resource<Error>>,
37    > {
38        let invocation = self.0.delete_invocation(invocation)?;
39        match invocation.await {
40            Ok((tx, rx)) => {
41                let rx = self.0.push_incoming_channel(rx)?;
42                let tx = self.0.push_outgoing_channel(tx)?;
43                Ok(Ok((tx, rx)))
44            }
45            Err(error) => {
46                let error = self.0.push_error(Error::Invoke(error))?;
47                Ok(Err(error))
48            }
49        }
50    }
51
52    fn drop(&mut self, invocation: Resource<Invocation>) -> wasmtime::Result<()> {
53        _ = self.0.delete_invocation(invocation)?;
54        Ok(())
55    }
56}
57
58impl<T: WrpcView> HostIncomingChannel for WrpcRpcImpl<T> {
59    fn data(
60        &mut self,
61        incoming: Resource<IncomingChannel>,
62    ) -> wasmtime::Result<Option<Resource<InputStream>>> {
63        let IncomingChannel(stream) = self
64            .0
65            .wrpc()
66            .table
67            .get_mut(&incoming)
68            .context("failed to get incoming channel from table")?;
69        if Arc::get_mut(stream).is_none() {
70            return Ok(None);
71        }
72        let stream = Arc::clone(stream);
73        let stream = self
74            .0
75            .wrpc()
76            .table
77            .push_child(
78                Box::new(AsyncReadStream::new(IncomingChannelStream {
79                    incoming: IncomingChannel(stream),
80                    _ty: PhantomData::<<T::Invoke as Invoke>::Incoming>,
81                })) as InputStream,
82                &incoming,
83            )
84            .context("failed to push input stream to table")?;
85        Ok(Some(stream))
86    }
87
88    fn index(
89        &mut self,
90        incoming: Resource<IncomingChannel>,
91        path: Vec<u32>,
92    ) -> wasmtime::Result<Result<Resource<IncomingChannel>, Resource<Error>>> {
93        let path = path
94            .into_iter()
95            .map(usize::try_from)
96            .collect::<Result<Box<[_]>, _>>()
97            .context("failed to construct subscription path")?;
98        let IncomingChannel(incoming) = self
99            .0
100            .wrpc()
101            .table
102            .get(&incoming)
103            .context("failed to get incoming channel from table")?;
104        let incoming = {
105            let Ok(incoming) = incoming.read() else {
106                bail!("lock poisoned");
107            };
108            let incoming = incoming
109                .downcast_ref::<<T::Invoke as Invoke>::Incoming>()
110                .context("invalid incoming channel type")?;
111            incoming.index(&path)
112        };
113        match incoming {
114            Ok(incoming) => {
115                let incoming = self.0.push_incoming_channel(incoming)?;
116                Ok(Ok(incoming))
117            }
118            Err(error) => {
119                let error = self.0.push_error(Error::IncomingIndex(error))?;
120                Ok(Err(error))
121            }
122        }
123    }
124
125    fn drop(&mut self, incoming: Resource<IncomingChannel>) -> wasmtime::Result<()> {
126        self.0.delete_incoming_channel(incoming)?;
127        Ok(())
128    }
129}
130
131impl<T: WrpcView> HostOutgoingChannel for WrpcRpcImpl<T> {
132    fn data(
133        &mut self,
134        outgoing: Resource<OutgoingChannel>,
135    ) -> wasmtime::Result<Option<Resource<OutputStream>>> {
136        let OutgoingChannel(stream) = self
137            .0
138            .wrpc()
139            .table
140            .get_mut(&outgoing)
141            .context("failed to get outgoing channel from table")?;
142        if Arc::get_mut(stream).is_none() {
143            return Ok(None);
144        }
145        let stream = Arc::clone(stream);
146        let stream = self
147            .0
148            .wrpc()
149            .table
150            .push_child(
151                Box::new(AsyncWriteStream::new(
152                    8192,
153                    OutgoingChannelStream {
154                        outgoing: OutgoingChannel(stream),
155                        _ty: PhantomData::<<T::Invoke as Invoke>::Outgoing>,
156                    },
157                )) as OutputStream,
158                &outgoing,
159            )
160            .context("failed to push output stream to table")?;
161        Ok(Some(stream))
162    }
163
164    fn index(
165        &mut self,
166        outgoing: Resource<OutgoingChannel>,
167        path: Vec<u32>,
168    ) -> wasmtime::Result<Result<Resource<OutgoingChannel>, Resource<Error>>> {
169        let path = path
170            .into_iter()
171            .map(usize::try_from)
172            .collect::<Result<Box<[_]>, _>>()
173            .context("failed to construct subscription path")?;
174        let OutgoingChannel(outgoing) = self
175            .0
176            .wrpc()
177            .table
178            .get(&outgoing)
179            .context("failed to get outgoing channel from table")?;
180        let incoming = {
181            let Ok(outgoing) = outgoing.read() else {
182                bail!("lock poisoned");
183            };
184            let outgoing = outgoing
185                .downcast_ref::<<T::Invoke as Invoke>::Outgoing>()
186                .context("invalid outgoing channel type")?;
187            outgoing.index(&path)
188        };
189        match incoming {
190            Ok(outgoing) => {
191                let outgoing = self.0.push_outgoing_channel(outgoing)?;
192                Ok(Ok(outgoing))
193            }
194            Err(error) => {
195                let error = self.0.push_error(Error::OutgoingIndex(error))?;
196                Ok(Err(error))
197            }
198        }
199    }
200
201    fn drop(&mut self, outgoing: Resource<OutgoingChannel>) -> wasmtime::Result<()> {
202        self.0.delete_outgoing_channel(outgoing)?;
203        Ok(())
204    }
205}