1use std::{
2 future::{Future, ready},
3 io,
4 pin::Pin,
5 sync::Arc,
6};
7
8use wasmtime::Caller;
9
10use crate::{
11 drivers::io::{
12 IoCapability, IoCreateReaderDriver, IoCreateWriterDriver, IoReadDriver, IoWriteDriver,
13 create_reader_op, create_writer_op, read_op, write_op,
14 },
15 guest_data::{GuestError, GuestResult, GuestUint},
16 operation::{Contract, Operation},
17 registry::{InstanceRegistry, ResourceType},
18};
19use selium_abi::{ChannelBackpressure, ChannelCreate, GuestResourceId};
20
/// Operations returned by `lifecycle_ops`: channel create, delete and drain, in that order.
type ChannelLifecycleOps<C> = (
    Arc<Operation<ChannelCreateDriver<C>>>,
    Arc<Operation<ChannelDeleteDriver<C>>>,
    Arc<Operation<ChannelDrainDriver<C>>>,
);
26
/// Operations returned by `handoff_ops`: share (export), attach and detach, in that order.
type ChannelHandoffOps = (
    Arc<Operation<ChannelExportDriver>>,
    Arc<Operation<ChannelAttachDriver>>,
    Arc<Operation<ChannelDetachDriver>>,
);
32
/// Boxed, `Send` future returned by `FrameReadable::read_frame`, borrowing the reader for `'a`.
///
/// Resolves to a `(u16, Vec<u8>)` pair on success. NOTE(review): the meaning of the `u16` is not
/// visible in this module — confirm with the implementors.
type FrameReadFuture<'a> = Pin<Box<dyn Future<Output = io::Result<(u16, Vec<u8>)>> + Send + 'a>>;
34
/// Operations returned by `read_ops`: strong reader create, weak reader create, strong read and
/// weak read, in that order.
type ChannelReadOps<S, W> = (
    Arc<Operation<IoCreateReaderDriver<S>>>,
    Arc<Operation<IoCreateReaderDriver<W>>>,
    Arc<Operation<IoReadDriver<S>>>,
    Arc<Operation<IoReadDriver<W>>>,
);
41
/// Operations returned by `write_ops`: strong writer create, weak writer create, strong write and
/// weak write, in that order.
type ChannelWriteOps<S, W> = (
    Arc<Operation<IoCreateWriterDriver<S>>>,
    Arc<Operation<IoCreateWriterDriver<W>>>,
    Arc<Operation<IoWriteDriver<S>>>,
    Arc<Operation<IoWriteDriver<W>>>,
);
48
49pub trait ChannelCapability: Send + Sync {
51 type Channel: Send;
52 type StrongWriter: Send + Unpin;
53 type WeakWriter: Send + Unpin;
54 type StrongReader: Send + Unpin;
55 type WeakReader: Send + Unpin;
56 type Error: Into<GuestError>;
57
58 fn create(
60 &self,
61 size: GuestUint,
62 backpressure: ChannelBackpressure,
63 ) -> Result<Self::Channel, Self::Error>;
64
65 fn delete(&self, channel: Self::Channel) -> Result<(), Self::Error>;
67
68 fn drain(&self, channel: &Self::Channel) -> Result<(), Self::Error>;
70
71 fn downgrade_writer(&self, writer: Self::StrongWriter)
73 -> Result<Self::WeakWriter, Self::Error>;
74
75 fn downgrade_reader(&self, writer: Self::StrongReader)
77 -> Result<Self::WeakReader, Self::Error>;
78
79 #[doc(hidden)]
80 fn ptr(&self, channel: &Self::Channel) -> String;
81}
82
/// Source that can asynchronously read one frame at a time.
pub trait FrameReadable {
    /// Returns a future that resolves to the next frame.
    ///
    /// NOTE(review): `max_len` presumably caps the frame payload length, and the `u16` in the
    /// result is implementor-defined — confirm both against the implementations.
    fn read_frame(&mut self, max_len: usize) -> FrameReadFuture<'_>;
}
87
/// Hostcall driver for `CHANNEL_CREATE`; wraps the channel capability used to create channels.
pub struct ChannelCreateDriver<Impl>(Impl);
/// Hostcall driver for `CHANNEL_DELETE`; wraps the channel capability used to delete channels.
pub struct ChannelDeleteDriver<Impl>(Impl);
/// Hostcall driver for `CHANNEL_DRAIN`; wraps the channel capability used to drain channels.
pub struct ChannelDrainDriver<Impl>(Impl);
/// Hostcall driver for `CHANNEL_WRITER_DOWNGRADE`; turns strong writers into weak writers.
pub struct ChannelDowngradeStrongWriterDriver<Impl>(Impl);
/// Hostcall driver for `CHANNEL_SHARE`; exports an instance-local handle as a shareable id.
pub struct ChannelExportDriver;
/// Hostcall driver for `CHANNEL_ATTACH`; binds a shared resource id to a new local slot.
pub struct ChannelAttachDriver;
/// Hostcall driver for `CHANNEL_DETACH`; releases an instance-local slot.
pub struct ChannelDetachDriver;
95
96impl<T> ChannelCapability for Arc<T>
97where
98 T: ChannelCapability + Send + Sync,
99{
100 type Channel = T::Channel;
101 type StrongWriter = T::StrongWriter;
102 type WeakWriter = T::WeakWriter;
103 type StrongReader = T::StrongReader;
104 type WeakReader = T::WeakReader;
105 type Error = T::Error;
106
107 fn create(
108 &self,
109 size: GuestUint,
110 backpressure: ChannelBackpressure,
111 ) -> Result<Self::Channel, Self::Error> {
112 self.as_ref().create(size, backpressure)
113 }
114
115 fn delete(&self, channel: Self::Channel) -> Result<(), Self::Error> {
116 self.as_ref().delete(channel)
117 }
118
119 fn drain(&self, channel: &Self::Channel) -> Result<(), Self::Error> {
120 self.as_ref().drain(channel)
121 }
122
123 fn downgrade_writer(
124 &self,
125 writer: Self::StrongWriter,
126 ) -> Result<Self::WeakWriter, Self::Error> {
127 self.as_ref().downgrade_writer(writer)
128 }
129
130 fn downgrade_reader(
131 &self,
132 reader: Self::StrongReader,
133 ) -> Result<Self::WeakReader, Self::Error> {
134 self.as_ref().downgrade_reader(reader)
135 }
136
137 fn ptr(&self, channel: &Self::Channel) -> String {
138 self.as_ref().ptr(channel)
139 }
140}
141
142impl<Impl> Contract for ChannelCreateDriver<Impl>
143where
144 Impl: ChannelCapability + Clone + Send + 'static,
145{
146 type Input = ChannelCreate;
147 type Output = GuestUint;
148
149 fn to_future(
150 &self,
151 caller: &mut Caller<'_, InstanceRegistry>,
152 args: Self::Input,
153 ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
154 let inner = self.0.clone();
155 let registry = caller.data().registry_arc();
156
157 let result = (|| -> GuestResult<GuestUint> {
158 let channel = inner
159 .create(args.capacity, args.backpressure)
160 .map_err(Into::into)?;
161 let ptr = inner.ptr(&channel);
162 let slot = caller
163 .data_mut()
164 .insert(channel, None, ResourceType::Channel)
165 .map_err(GuestError::from)?;
166 if let Some(resource_id) = caller.data().entry(slot) {
167 registry.record_host_ptr(resource_id, &ptr);
168 }
169 let handle = GuestUint::try_from(slot).map_err(|_| GuestError::InvalidArgument)?;
170 Ok(handle)
171 })();
172
173 ready(result)
174 }
175}
176
177impl<Impl> Contract for ChannelDeleteDriver<Impl>
178where
179 Impl: ChannelCapability + Clone + Send + 'static,
180{
181 type Input = GuestUint;
182 type Output = ();
183
184 fn to_future(
185 &self,
186 caller: &mut Caller<'_, InstanceRegistry>,
187 channel_id: Self::Input,
188 ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
189 let this = self.0.clone();
190 let result = (|| -> GuestResult<()> {
191 let slot = channel_id as usize;
192 let channel = caller
193 .data_mut()
194 .remove::<Impl::Channel>(slot)
195 .ok_or(GuestError::NotFound)?;
196
197 this.delete(channel).map_err(Into::into)?;
198 Ok(())
199 })();
200
201 ready(result)
202 }
203}
204
205impl<Impl> Contract for ChannelDrainDriver<Impl>
206where
207 Impl: ChannelCapability + Clone + Send + 'static,
208{
209 type Input = u32;
210 type Output = ();
211
212 fn to_future(
213 &self,
214 caller: &mut Caller<'_, InstanceRegistry>,
215 channel_id: Self::Input,
216 ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
217 let this = self.0.clone();
218 let result = (|| -> GuestResult<()> {
219 let slot = channel_id as usize;
220 caller
221 .data()
222 .with(slot, |chan| this.drain(chan))
223 .ok_or(GuestError::NotFound)?
224 .map_err(Into::into)?;
225 Ok(())
226 })();
227
228 ready(result)
229 }
230}
231
232impl<Impl> Contract for ChannelDowngradeStrongWriterDriver<Impl>
233where
234 Impl: ChannelCapability + Send + 'static,
235{
236 type Input = GuestUint;
237 type Output = GuestUint;
238
239 fn to_future(
240 &self,
241 caller: &mut Caller<'_, InstanceRegistry>,
242 input: Self::Input,
243 ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
244 match caller
245 .data_mut()
246 .remove::<Impl::StrongWriter>(input as usize)
247 .ok_or(GuestError::NotFound)
248 .and_then(|writer| self.0.downgrade_writer(writer).map_err(Into::into))
249 {
250 Ok(writer) => {
251 let result = caller
252 .data_mut()
253 .insert(writer, None, ResourceType::Writer)
254 .map_err(Into::into)
255 .and_then(|idx| {
256 GuestUint::try_from(idx).map_err(|_| GuestError::InvalidArgument)
257 });
258 ready(result)
259 }
260 Err(e) => ready(Err(e)),
261 }
262 }
263}
264
265impl Contract for ChannelExportDriver {
266 type Input = GuestUint;
267 type Output = GuestResourceId;
268
269 fn to_future(
270 &self,
271 caller: &mut Caller<'_, InstanceRegistry>,
272 handle: Self::Input,
273 ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
274 let registry = caller.data().registry_arc();
275 let result = caller
276 .data()
277 .entry(handle as usize)
278 .ok_or(GuestError::NotFound)
279 .and_then(|rid| registry.share_handle(rid).map_err(GuestError::from));
280
281 ready(result)
282 }
283}
284
285impl Contract for ChannelAttachDriver {
286 type Input = GuestResourceId;
287 type Output = GuestUint;
288
289 fn to_future(
290 &self,
291 caller: &mut Caller<'_, InstanceRegistry>,
292 resource_id: Self::Input,
293 ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
294 let registry = caller.data().registry_arc();
295 let result = registry
296 .resolve_shared(resource_id)
297 .ok_or(GuestError::NotFound)
298 .and_then(|rid| {
299 caller
300 .data_mut()
301 .insert_id(rid)
302 .map_err(GuestError::from)
303 .and_then(|slot| {
304 GuestUint::try_from(slot).map_err(|_| GuestError::InvalidArgument)
305 })
306 });
307
308 ready(result)
309 }
310}
311
312impl Contract for ChannelDetachDriver {
313 type Input = GuestUint;
314 type Output = ();
315
316 fn to_future(
317 &self,
318 caller: &mut Caller<'_, InstanceRegistry>,
319 handle: Self::Input,
320 ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
321 let result = caller
322 .data_mut()
323 .detach_slot(handle as usize)
324 .ok_or(GuestError::NotFound)
325 .map(|_| ());
326
327 ready(result)
328 }
329}
330
331impl From<io::Error> for GuestError {
332 fn from(err: io::Error) -> Self {
333 GuestError::Subsystem(err.to_string())
334 }
335}
336
337pub fn read_ops<S, W>(strong_cap: S, weak_cap: W) -> ChannelReadOps<S, W>
338where
339 S: IoCapability + Clone + Send + 'static,
340 W: IoCapability + Clone + Send + 'static,
341{
342 (
343 create_reader_op(
344 strong_cap.clone(),
345 selium_abi::hostcall_contract!(CHANNEL_STRONG_READER_CREATE),
346 ),
347 create_reader_op(
348 weak_cap.clone(),
349 selium_abi::hostcall_contract!(CHANNEL_WEAK_READER_CREATE),
350 ),
351 read_op(
352 strong_cap,
353 selium_abi::hostcall_contract!(CHANNEL_STRONG_READ),
354 ),
355 read_op(weak_cap, selium_abi::hostcall_contract!(CHANNEL_WEAK_READ)),
356 )
357}
358
359pub fn write_ops<S, W>(strong_cap: S, weak_cap: W) -> ChannelWriteOps<S, W>
360where
361 S: IoCapability + Clone + Send + 'static,
362 W: IoCapability + Clone + Send + 'static,
363{
364 (
365 create_writer_op(
366 strong_cap.clone(),
367 selium_abi::hostcall_contract!(CHANNEL_STRONG_WRITER_CREATE),
368 ),
369 create_writer_op(
370 weak_cap.clone(),
371 selium_abi::hostcall_contract!(CHANNEL_WEAK_WRITER_CREATE),
372 ),
373 write_op(
374 strong_cap,
375 selium_abi::hostcall_contract!(CHANNEL_STRONG_WRITE),
376 ),
377 write_op(weak_cap, selium_abi::hostcall_contract!(CHANNEL_WEAK_WRITE)),
378 )
379}
380
381pub fn writer_downgrade_op<C>(ch_cap: C) -> Arc<Operation<ChannelDowngradeStrongWriterDriver<C>>>
382where
383 C: ChannelCapability + 'static,
384{
385 Operation::from_hostcall(
386 ChannelDowngradeStrongWriterDriver(ch_cap),
387 selium_abi::hostcall_contract!(CHANNEL_WRITER_DOWNGRADE),
388 )
389}
390
391pub fn lifecycle_ops<C>(cap: C) -> ChannelLifecycleOps<C>
392where
393 C: ChannelCapability + Clone + 'static,
394{
395 (
396 Operation::from_hostcall(
397 ChannelCreateDriver(cap.clone()),
398 selium_abi::hostcall_contract!(CHANNEL_CREATE),
399 ),
400 Operation::from_hostcall(
401 ChannelDeleteDriver(cap.clone()),
402 selium_abi::hostcall_contract!(CHANNEL_DELETE),
403 ),
404 Operation::from_hostcall(
405 ChannelDrainDriver(cap),
406 selium_abi::hostcall_contract!(CHANNEL_DRAIN),
407 ),
408 )
409}
410
411pub fn handoff_ops() -> ChannelHandoffOps {
412 (
413 Operation::from_hostcall(
414 ChannelExportDriver,
415 selium_abi::hostcall_contract!(CHANNEL_SHARE),
416 ),
417 Operation::from_hostcall(
418 ChannelAttachDriver,
419 selium_abi::hostcall_contract!(CHANNEL_ATTACH),
420 ),
421 Operation::from_hostcall(
422 ChannelDetachDriver,
423 selium_abi::hostcall_contract!(CHANNEL_DETACH),
424 ),
425 )
426}