Skip to main content

selium_kernel/drivers/
channel.rs

1use std::{
2    future::{Future, ready},
3    io,
4    pin::Pin,
5    sync::Arc,
6};
7
8use wasmtime::Caller;
9
10use crate::{
11    drivers::io::{
12        IoCapability, IoCreateReaderDriver, IoCreateWriterDriver, IoReadDriver, IoWriteDriver,
13        create_reader_op, create_writer_op, read_op, write_op,
14    },
15    guest_data::{GuestError, GuestResult, GuestUint},
16    operation::{Contract, Operation},
17    registry::{InstanceRegistry, ResourceType},
18};
19use selium_abi::{ChannelBackpressure, ChannelCreate, GuestResourceId};
20
21type ChannelLifecycleOps<C> = (
22    Arc<Operation<ChannelCreateDriver<C>>>,
23    Arc<Operation<ChannelDeleteDriver<C>>>,
24    Arc<Operation<ChannelDrainDriver<C>>>,
25);
26
27type ChannelHandoffOps = (
28    Arc<Operation<ChannelExportDriver>>,
29    Arc<Operation<ChannelAttachDriver>>,
30    Arc<Operation<ChannelDetachDriver>>,
31);
32
33type FrameReadFuture<'a> = Pin<Box<dyn Future<Output = io::Result<(u16, Vec<u8>)>> + Send + 'a>>;
34
35type ChannelReadOps<S, W> = (
36    Arc<Operation<IoCreateReaderDriver<S>>>,
37    Arc<Operation<IoCreateReaderDriver<W>>>,
38    Arc<Operation<IoReadDriver<S>>>,
39    Arc<Operation<IoReadDriver<W>>>,
40);
41
42type ChannelWriteOps<S, W> = (
43    Arc<Operation<IoCreateWriterDriver<S>>>,
44    Arc<Operation<IoCreateWriterDriver<W>>>,
45    Arc<Operation<IoWriteDriver<S>>>,
46    Arc<Operation<IoWriteDriver<W>>>,
47);
48
49/// The capabilities that any subsystem implementation needs to provide
50pub trait ChannelCapability: Send + Sync {
51    type Channel: Send;
52    type StrongWriter: Send + Unpin;
53    type WeakWriter: Send + Unpin;
54    type StrongReader: Send + Unpin;
55    type WeakReader: Send + Unpin;
56    type Error: Into<GuestError>;
57
58    /// Create a new channel for transporting bytes.
59    fn create(
60        &self,
61        size: GuestUint,
62        backpressure: ChannelBackpressure,
63    ) -> Result<Self::Channel, Self::Error>;
64
65    /// Delete this channel
66    fn delete(&self, channel: Self::Channel) -> Result<(), Self::Error>;
67
68    /// Terminate this channel whilst allowing unfinished reads/writes to continue
69    fn drain(&self, channel: &Self::Channel) -> Result<(), Self::Error>;
70
71    /// Downgrade this strong writer to a weak variant
72    fn downgrade_writer(&self, writer: Self::StrongWriter)
73    -> Result<Self::WeakWriter, Self::Error>;
74
75    /// Downgrade this strong writer to a weak variant
76    fn downgrade_reader(&self, writer: Self::StrongReader)
77    -> Result<Self::WeakReader, Self::Error>;
78
79    #[doc(hidden)]
80    fn ptr(&self, channel: &Self::Channel) -> String;
81}
82
83/// Reader capable of yielding whole frames with attribution.
84pub trait FrameReadable {
85    fn read_frame(&mut self, max_len: usize) -> FrameReadFuture<'_>;
86}
87
/// Hostcall driver for `CHANNEL_CREATE`, wrapping a channel capability `C`.
pub struct ChannelCreateDriver<C>(C);
/// Hostcall driver for `CHANNEL_DELETE`, wrapping a channel capability `C`.
pub struct ChannelDeleteDriver<C>(C);
/// Hostcall driver for `CHANNEL_DRAIN`, wrapping a channel capability `C`.
pub struct ChannelDrainDriver<C>(C);
/// Hostcall driver for `CHANNEL_WRITER_DOWNGRADE`, wrapping a channel capability `C`.
pub struct ChannelDowngradeStrongWriterDriver<C>(C);
/// Hostcall driver for `CHANNEL_SHARE`; stateless, works only on the caller's registry.
pub struct ChannelExportDriver;
/// Hostcall driver for `CHANNEL_ATTACH`; stateless, works only on the caller's registry.
pub struct ChannelAttachDriver;
/// Hostcall driver for `CHANNEL_DETACH`; stateless, works only on the caller's registry.
pub struct ChannelDetachDriver;
95
96impl<T> ChannelCapability for Arc<T>
97where
98    T: ChannelCapability + Send + Sync,
99{
100    type Channel = T::Channel;
101    type StrongWriter = T::StrongWriter;
102    type WeakWriter = T::WeakWriter;
103    type StrongReader = T::StrongReader;
104    type WeakReader = T::WeakReader;
105    type Error = T::Error;
106
107    fn create(
108        &self,
109        size: GuestUint,
110        backpressure: ChannelBackpressure,
111    ) -> Result<Self::Channel, Self::Error> {
112        self.as_ref().create(size, backpressure)
113    }
114
115    fn delete(&self, channel: Self::Channel) -> Result<(), Self::Error> {
116        self.as_ref().delete(channel)
117    }
118
119    fn drain(&self, channel: &Self::Channel) -> Result<(), Self::Error> {
120        self.as_ref().drain(channel)
121    }
122
123    fn downgrade_writer(
124        &self,
125        writer: Self::StrongWriter,
126    ) -> Result<Self::WeakWriter, Self::Error> {
127        self.as_ref().downgrade_writer(writer)
128    }
129
130    fn downgrade_reader(
131        &self,
132        reader: Self::StrongReader,
133    ) -> Result<Self::WeakReader, Self::Error> {
134        self.as_ref().downgrade_reader(reader)
135    }
136
137    fn ptr(&self, channel: &Self::Channel) -> String {
138        self.as_ref().ptr(channel)
139    }
140}
141
142impl<Impl> Contract for ChannelCreateDriver<Impl>
143where
144    Impl: ChannelCapability + Clone + Send + 'static,
145{
146    type Input = ChannelCreate;
147    type Output = GuestUint;
148
149    fn to_future(
150        &self,
151        caller: &mut Caller<'_, InstanceRegistry>,
152        args: Self::Input,
153    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
154        let inner = self.0.clone();
155        let registry = caller.data().registry_arc();
156
157        let result = (|| -> GuestResult<GuestUint> {
158            let channel = inner
159                .create(args.capacity, args.backpressure)
160                .map_err(Into::into)?;
161            let ptr = inner.ptr(&channel);
162            let slot = caller
163                .data_mut()
164                .insert(channel, None, ResourceType::Channel)
165                .map_err(GuestError::from)?;
166            if let Some(resource_id) = caller.data().entry(slot) {
167                registry.record_host_ptr(resource_id, &ptr);
168            }
169            let handle = GuestUint::try_from(slot).map_err(|_| GuestError::InvalidArgument)?;
170            Ok(handle)
171        })();
172
173        ready(result)
174    }
175}
176
177impl<Impl> Contract for ChannelDeleteDriver<Impl>
178where
179    Impl: ChannelCapability + Clone + Send + 'static,
180{
181    type Input = GuestUint;
182    type Output = ();
183
184    fn to_future(
185        &self,
186        caller: &mut Caller<'_, InstanceRegistry>,
187        channel_id: Self::Input,
188    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
189        let this = self.0.clone();
190        let result = (|| -> GuestResult<()> {
191            let slot = channel_id as usize;
192            let channel = caller
193                .data_mut()
194                .remove::<Impl::Channel>(slot)
195                .ok_or(GuestError::NotFound)?;
196
197            this.delete(channel).map_err(Into::into)?;
198            Ok(())
199        })();
200
201        ready(result)
202    }
203}
204
205impl<Impl> Contract for ChannelDrainDriver<Impl>
206where
207    Impl: ChannelCapability + Clone + Send + 'static,
208{
209    type Input = u32;
210    type Output = ();
211
212    fn to_future(
213        &self,
214        caller: &mut Caller<'_, InstanceRegistry>,
215        channel_id: Self::Input,
216    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
217        let this = self.0.clone();
218        let result = (|| -> GuestResult<()> {
219            let slot = channel_id as usize;
220            caller
221                .data()
222                .with(slot, |chan| this.drain(chan))
223                .ok_or(GuestError::NotFound)?
224                .map_err(Into::into)?;
225            Ok(())
226        })();
227
228        ready(result)
229    }
230}
231
232impl<Impl> Contract for ChannelDowngradeStrongWriterDriver<Impl>
233where
234    Impl: ChannelCapability + Send + 'static,
235{
236    type Input = GuestUint;
237    type Output = GuestUint;
238
239    fn to_future(
240        &self,
241        caller: &mut Caller<'_, InstanceRegistry>,
242        input: Self::Input,
243    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
244        match caller
245            .data_mut()
246            .remove::<Impl::StrongWriter>(input as usize)
247            .ok_or(GuestError::NotFound)
248            .and_then(|writer| self.0.downgrade_writer(writer).map_err(Into::into))
249        {
250            Ok(writer) => {
251                let result = caller
252                    .data_mut()
253                    .insert(writer, None, ResourceType::Writer)
254                    .map_err(Into::into)
255                    .and_then(|idx| {
256                        GuestUint::try_from(idx).map_err(|_| GuestError::InvalidArgument)
257                    });
258                ready(result)
259            }
260            Err(e) => ready(Err(e)),
261        }
262    }
263}
264
265impl Contract for ChannelExportDriver {
266    type Input = GuestUint;
267    type Output = GuestResourceId;
268
269    fn to_future(
270        &self,
271        caller: &mut Caller<'_, InstanceRegistry>,
272        handle: Self::Input,
273    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
274        let registry = caller.data().registry_arc();
275        let result = caller
276            .data()
277            .entry(handle as usize)
278            .ok_or(GuestError::NotFound)
279            .and_then(|rid| registry.share_handle(rid).map_err(GuestError::from));
280
281        ready(result)
282    }
283}
284
285impl Contract for ChannelAttachDriver {
286    type Input = GuestResourceId;
287    type Output = GuestUint;
288
289    fn to_future(
290        &self,
291        caller: &mut Caller<'_, InstanceRegistry>,
292        resource_id: Self::Input,
293    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
294        let registry = caller.data().registry_arc();
295        let result = registry
296            .resolve_shared(resource_id)
297            .ok_or(GuestError::NotFound)
298            .and_then(|rid| {
299                caller
300                    .data_mut()
301                    .insert_id(rid)
302                    .map_err(GuestError::from)
303                    .and_then(|slot| {
304                        GuestUint::try_from(slot).map_err(|_| GuestError::InvalidArgument)
305                    })
306            });
307
308        ready(result)
309    }
310}
311
312impl Contract for ChannelDetachDriver {
313    type Input = GuestUint;
314    type Output = ();
315
316    fn to_future(
317        &self,
318        caller: &mut Caller<'_, InstanceRegistry>,
319        handle: Self::Input,
320    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
321        let result = caller
322            .data_mut()
323            .detach_slot(handle as usize)
324            .ok_or(GuestError::NotFound)
325            .map(|_| ());
326
327        ready(result)
328    }
329}
330
331impl From<io::Error> for GuestError {
332    fn from(err: io::Error) -> Self {
333        GuestError::Subsystem(err.to_string())
334    }
335}
336
337pub fn read_ops<S, W>(strong_cap: S, weak_cap: W) -> ChannelReadOps<S, W>
338where
339    S: IoCapability + Clone + Send + 'static,
340    W: IoCapability + Clone + Send + 'static,
341{
342    (
343        create_reader_op(
344            strong_cap.clone(),
345            selium_abi::hostcall_contract!(CHANNEL_STRONG_READER_CREATE),
346        ),
347        create_reader_op(
348            weak_cap.clone(),
349            selium_abi::hostcall_contract!(CHANNEL_WEAK_READER_CREATE),
350        ),
351        read_op(
352            strong_cap,
353            selium_abi::hostcall_contract!(CHANNEL_STRONG_READ),
354        ),
355        read_op(weak_cap, selium_abi::hostcall_contract!(CHANNEL_WEAK_READ)),
356    )
357}
358
359pub fn write_ops<S, W>(strong_cap: S, weak_cap: W) -> ChannelWriteOps<S, W>
360where
361    S: IoCapability + Clone + Send + 'static,
362    W: IoCapability + Clone + Send + 'static,
363{
364    (
365        create_writer_op(
366            strong_cap.clone(),
367            selium_abi::hostcall_contract!(CHANNEL_STRONG_WRITER_CREATE),
368        ),
369        create_writer_op(
370            weak_cap.clone(),
371            selium_abi::hostcall_contract!(CHANNEL_WEAK_WRITER_CREATE),
372        ),
373        write_op(
374            strong_cap,
375            selium_abi::hostcall_contract!(CHANNEL_STRONG_WRITE),
376        ),
377        write_op(weak_cap, selium_abi::hostcall_contract!(CHANNEL_WEAK_WRITE)),
378    )
379}
380
381pub fn writer_downgrade_op<C>(ch_cap: C) -> Arc<Operation<ChannelDowngradeStrongWriterDriver<C>>>
382where
383    C: ChannelCapability + 'static,
384{
385    Operation::from_hostcall(
386        ChannelDowngradeStrongWriterDriver(ch_cap),
387        selium_abi::hostcall_contract!(CHANNEL_WRITER_DOWNGRADE),
388    )
389}
390
391pub fn lifecycle_ops<C>(cap: C) -> ChannelLifecycleOps<C>
392where
393    C: ChannelCapability + Clone + 'static,
394{
395    (
396        Operation::from_hostcall(
397            ChannelCreateDriver(cap.clone()),
398            selium_abi::hostcall_contract!(CHANNEL_CREATE),
399        ),
400        Operation::from_hostcall(
401            ChannelDeleteDriver(cap.clone()),
402            selium_abi::hostcall_contract!(CHANNEL_DELETE),
403        ),
404        Operation::from_hostcall(
405            ChannelDrainDriver(cap),
406            selium_abi::hostcall_contract!(CHANNEL_DRAIN),
407        ),
408    )
409}
410
411pub fn handoff_ops() -> ChannelHandoffOps {
412    (
413        Operation::from_hostcall(
414            ChannelExportDriver,
415            selium_abi::hostcall_contract!(CHANNEL_SHARE),
416        ),
417        Operation::from_hostcall(
418            ChannelAttachDriver,
419            selium_abi::hostcall_contract!(CHANNEL_ATTACH),
420        ),
421        Operation::from_hostcall(
422            ChannelDetachDriver,
423            selium_abi::hostcall_contract!(CHANNEL_DETACH),
424        ),
425    )
426}