selium_kernel/drivers/
channel.rs

1use std::{
2    future::{Future, ready},
3    io,
4    pin::Pin,
5    sync::Arc,
6};
7
8use wasmtime::Caller;
9
10use crate::{
11    drivers::io::{
12        IoCapability, IoCreateReaderDriver, IoCreateWriterDriver, IoReadDriver, IoWriteDriver,
13        create_reader_op, create_writer_op, read_op, write_op,
14    },
15    guest_data::{GuestError, GuestResult, GuestUint},
16    operation::{Contract, Operation},
17    registry::{InstanceRegistry, ResourceType},
18};
19use selium_abi::GuestResourceId;
20
/// Tuple of the channel lifecycle operations, in order: create, delete, drain.
///
/// `C` is the channel capability wrapped by each of the three drivers; the tuple is
/// built by `lifecycle_ops`.
21type ChannelLifecycleOps<C> = (
22    Arc<Operation<ChannelCreateDriver<C>>>,
23    Arc<Operation<ChannelDeleteDriver<C>>>,
24    Arc<Operation<ChannelDrainDriver<C>>>,
25);
26
/// Tuple of the channel handoff operations, in order: export, attach, detach.
///
/// The tuple is built by `handoff_ops`.
27type ChannelHandoffOps = (
28    Arc<Operation<ChannelExportDriver>>,
29    Arc<Operation<ChannelAttachDriver>>,
30    Arc<Operation<ChannelDetachDriver>>,
31);
32
/// Boxed, pinned, `Send` future returned by `FrameReadable::read_frame`.
///
/// It resolves to an `io::Result` holding a `u16` paired with a byte vector.
/// NOTE(review): the `u16` is presumably the frame's attribution and the bytes its
/// payload — confirm against the implementors of `FrameReadable`.
33type FrameReadFuture<'a> = Pin<Box<dyn Future<Output = io::Result<(u16, Vec<u8>)>> + Send + 'a>>;
34
/// Tuple of the channel read operations, in order: strong reader creation, weak reader
/// creation, strong read, weak read.
///
/// `S` and `W` are the capabilities passed to `read_ops` as `strong_cap` and `weak_cap`.
35type ChannelReadOps<S, W> = (
36    Arc<Operation<IoCreateReaderDriver<S>>>,
37    Arc<Operation<IoCreateReaderDriver<W>>>,
38    Arc<Operation<IoReadDriver<S>>>,
39    Arc<Operation<IoReadDriver<W>>>,
40);
41
/// Tuple of the channel write operations, in order: strong writer creation, weak writer
/// creation, strong write, weak write.
///
/// `S` and `W` are the capabilities passed to `write_ops` as `strong_cap` and `weak_cap`.
42type ChannelWriteOps<S, W> = (
43    Arc<Operation<IoCreateWriterDriver<S>>>,
44    Arc<Operation<IoCreateWriterDriver<W>>>,
45    Arc<Operation<IoWriteDriver<S>>>,
46    Arc<Operation<IoWriteDriver<W>>>,
47);
48
49/// The capabilities that any subsystem implementation needs to provide
50pub trait ChannelCapability: Send + Sync {
51    type Channel: Send;
52    type StrongWriter: Send + Unpin;
53    type WeakWriter: Send + Unpin;
54    type StrongReader: Send + Unpin;
55    type WeakReader: Send + Unpin;
56    type Error: Into<GuestError>;
57
58    /// Create a new channel for transporting bytes
59    fn create(&self, size: GuestUint) -> Result<Self::Channel, Self::Error>;
60
61    /// Delete this channel
62    fn delete(&self, channel: Self::Channel) -> Result<(), Self::Error>;
63
64    /// Terminate this channel whilst allowing unfinished reads/writes to continue
65    fn drain(&self, channel: &Self::Channel) -> Result<(), Self::Error>;
66
67    /// Downgrade this strong writer to a weak variant
68    fn downgrade_writer(&self, writer: Self::StrongWriter)
69    -> Result<Self::WeakWriter, Self::Error>;
70
71    /// Downgrade this strong writer to a weak variant
72    fn downgrade_reader(&self, writer: Self::StrongReader)
73    -> Result<Self::WeakReader, Self::Error>;
74
75    #[doc(hidden)]
76    fn ptr(&self, channel: &Self::Channel) -> String;
77}
78
79/// Reader capable of yielding whole frames with attribution.
80pub trait FrameReadable {
    /// Returns a boxed future that resolves to the next frame or an `io::Error`.
    ///
    /// NOTE(review): `max_len` presumably caps the accepted frame length in bytes —
    /// confirm against the implementors.
81    fn read_frame(&mut self, max_len: usize) -> FrameReadFuture<'_>;
82}
83
/// Hostcall driver that creates a channel through `Impl` and registers it with the caller.
84pub struct ChannelCreateDriver<Impl>(Impl);
/// Hostcall driver that removes a channel from the caller's registry and deletes it through `Impl`.
85pub struct ChannelDeleteDriver<Impl>(Impl);
/// Hostcall driver that drains the channel held in a caller slot through `Impl`.
86pub struct ChannelDrainDriver<Impl>(Impl);
/// Hostcall driver that removes a strong writer, downgrades it through `Impl` and registers
/// the resulting weak writer.
87pub struct ChannelDowngradeStrongWriterDriver<Impl>(Impl);
/// Hostcall driver that shares the resource behind a caller handle via the registry.
88pub struct ChannelExportDriver;
/// Hostcall driver that resolves a shared resource id and inserts it into the caller's registry.
89pub struct ChannelAttachDriver;
/// Hostcall driver that detaches a slot from the caller's registry.
90pub struct ChannelDetachDriver;
91
92impl<T> ChannelCapability for Arc<T>
93where
94    T: ChannelCapability + Send + Sync,
95{
96    type Channel = T::Channel;
97    type StrongWriter = T::StrongWriter;
98    type WeakWriter = T::WeakWriter;
99    type StrongReader = T::StrongReader;
100    type WeakReader = T::WeakReader;
101    type Error = T::Error;
102
103    fn create(&self, size: GuestUint) -> Result<Self::Channel, Self::Error> {
104        self.as_ref().create(size)
105    }
106
107    fn delete(&self, channel: Self::Channel) -> Result<(), Self::Error> {
108        self.as_ref().delete(channel)
109    }
110
111    fn drain(&self, channel: &Self::Channel) -> Result<(), Self::Error> {
112        self.as_ref().drain(channel)
113    }
114
115    fn downgrade_writer(
116        &self,
117        writer: Self::StrongWriter,
118    ) -> Result<Self::WeakWriter, Self::Error> {
119        self.as_ref().downgrade_writer(writer)
120    }
121
122    fn downgrade_reader(
123        &self,
124        reader: Self::StrongReader,
125    ) -> Result<Self::WeakReader, Self::Error> {
126        self.as_ref().downgrade_reader(reader)
127    }
128
129    fn ptr(&self, channel: &Self::Channel) -> String {
130        self.as_ref().ptr(channel)
131    }
132}
133
134impl<Impl> Contract for ChannelCreateDriver<Impl>
135where
136    Impl: ChannelCapability + Clone + Send + 'static,
137{
138    type Input = GuestUint;
139    type Output = GuestUint;
140
141    fn to_future(
142        &self,
143        caller: &mut Caller<'_, InstanceRegistry>,
144        size: Self::Input,
145    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
146        let inner = self.0.clone();
147        let registry = caller.data().registry_arc();
148
149        let result = (|| -> GuestResult<GuestUint> {
150            let channel = inner.create(size).map_err(Into::into)?;
151            let ptr = inner.ptr(&channel);
152            let slot = caller
153                .data_mut()
154                .insert(channel, None, ResourceType::Channel)
155                .map_err(GuestError::from)?;
156            if let Some(resource_id) = caller.data().entry(slot) {
157                registry.record_host_ptr(resource_id, &ptr);
158            }
159            let handle = GuestUint::try_from(slot).map_err(|_| GuestError::InvalidArgument)?;
160            Ok(handle)
161        })();
162
163        ready(result)
164    }
165}
166
167impl<Impl> Contract for ChannelDeleteDriver<Impl>
168where
169    Impl: ChannelCapability + Clone + Send + 'static,
170{
171    type Input = GuestUint;
172    type Output = ();
173
174    fn to_future(
175        &self,
176        caller: &mut Caller<'_, InstanceRegistry>,
177        channel_id: Self::Input,
178    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
179        let this = self.0.clone();
180        let result = (|| -> GuestResult<()> {
181            let slot = channel_id as usize;
182            let channel = caller
183                .data_mut()
184                .remove::<Impl::Channel>(slot)
185                .ok_or(GuestError::NotFound)?;
186
187            this.delete(channel).map_err(Into::into)?;
188            Ok(())
189        })();
190
191        ready(result)
192    }
193}
194
195impl<Impl> Contract for ChannelDrainDriver<Impl>
196where
197    Impl: ChannelCapability + Clone + Send + 'static,
198{
199    type Input = u32;
200    type Output = ();
201
202    fn to_future(
203        &self,
204        caller: &mut Caller<'_, InstanceRegistry>,
205        channel_id: Self::Input,
206    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
207        let this = self.0.clone();
208        let result = (|| -> GuestResult<()> {
209            let slot = channel_id as usize;
210            caller
211                .data()
212                .with(slot, |chan| this.drain(chan))
213                .ok_or(GuestError::NotFound)?
214                .map_err(Into::into)?;
215            Ok(())
216        })();
217
218        ready(result)
219    }
220}
221
222impl<Impl> Contract for ChannelDowngradeStrongWriterDriver<Impl>
223where
224    Impl: ChannelCapability + Send + 'static,
225{
226    type Input = GuestUint;
227    type Output = GuestUint;
228
229    fn to_future(
230        &self,
231        caller: &mut Caller<'_, InstanceRegistry>,
232        input: Self::Input,
233    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
234        match caller
235            .data_mut()
236            .remove::<Impl::StrongWriter>(input as usize)
237            .ok_or(GuestError::NotFound)
238            .and_then(|writer| self.0.downgrade_writer(writer).map_err(Into::into))
239        {
240            Ok(writer) => {
241                let result = caller
242                    .data_mut()
243                    .insert(writer, None, ResourceType::Writer)
244                    .map_err(Into::into)
245                    .and_then(|idx| {
246                        GuestUint::try_from(idx).map_err(|_| GuestError::InvalidArgument)
247                    });
248                ready(result)
249            }
250            Err(e) => ready(Err(e)),
251        }
252    }
253}
254
255impl Contract for ChannelExportDriver {
256    type Input = GuestUint;
257    type Output = GuestResourceId;
258
259    fn to_future(
260        &self,
261        caller: &mut Caller<'_, InstanceRegistry>,
262        handle: Self::Input,
263    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
264        let registry = caller.data().registry_arc();
265        let result = caller
266            .data()
267            .entry(handle as usize)
268            .ok_or(GuestError::NotFound)
269            .and_then(|rid| registry.share_handle(rid).map_err(GuestError::from));
270
271        ready(result)
272    }
273}
274
275impl Contract for ChannelAttachDriver {
276    type Input = GuestResourceId;
277    type Output = GuestUint;
278
279    fn to_future(
280        &self,
281        caller: &mut Caller<'_, InstanceRegistry>,
282        resource_id: Self::Input,
283    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
284        let registry = caller.data().registry_arc();
285        let result = registry
286            .resolve_shared(resource_id)
287            .ok_or(GuestError::NotFound)
288            .and_then(|rid| {
289                caller
290                    .data_mut()
291                    .insert_id(rid)
292                    .map_err(GuestError::from)
293                    .and_then(|slot| {
294                        GuestUint::try_from(slot).map_err(|_| GuestError::InvalidArgument)
295                    })
296            });
297
298        ready(result)
299    }
300}
301
302impl Contract for ChannelDetachDriver {
303    type Input = GuestUint;
304    type Output = ();
305
306    fn to_future(
307        &self,
308        caller: &mut Caller<'_, InstanceRegistry>,
309        handle: Self::Input,
310    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
311        let result = caller
312            .data_mut()
313            .detach_slot(handle as usize)
314            .ok_or(GuestError::NotFound)
315            .map(|_| ());
316
317        ready(result)
318    }
319}
320
321impl From<io::Error> for GuestError {
322    fn from(err: io::Error) -> Self {
323        GuestError::Subsystem(err.to_string())
324    }
325}
326
327pub fn read_ops<S, W>(strong_cap: S, weak_cap: W) -> ChannelReadOps<S, W>
328where
329    S: IoCapability + Clone + Send + 'static,
330    W: IoCapability + Clone + Send + 'static,
331{
332    (
333        create_reader_op(
334            strong_cap.clone(),
335            selium_abi::hostcall_contract!(CHANNEL_STRONG_READER_CREATE),
336        ),
337        create_reader_op(
338            weak_cap.clone(),
339            selium_abi::hostcall_contract!(CHANNEL_WEAK_READER_CREATE),
340        ),
341        read_op(
342            strong_cap,
343            selium_abi::hostcall_contract!(CHANNEL_STRONG_READ),
344        ),
345        read_op(weak_cap, selium_abi::hostcall_contract!(CHANNEL_WEAK_READ)),
346    )
347}
348
349pub fn write_ops<S, W>(strong_cap: S, weak_cap: W) -> ChannelWriteOps<S, W>
350where
351    S: IoCapability + Clone + Send + 'static,
352    W: IoCapability + Clone + Send + 'static,
353{
354    (
355        create_writer_op(
356            strong_cap.clone(),
357            selium_abi::hostcall_contract!(CHANNEL_STRONG_WRITER_CREATE),
358        ),
359        create_writer_op(
360            weak_cap.clone(),
361            selium_abi::hostcall_contract!(CHANNEL_WEAK_WRITER_CREATE),
362        ),
363        write_op(
364            strong_cap,
365            selium_abi::hostcall_contract!(CHANNEL_STRONG_WRITE),
366        ),
367        write_op(weak_cap, selium_abi::hostcall_contract!(CHANNEL_WEAK_WRITE)),
368    )
369}
370
371pub fn writer_downgrade_op<C>(ch_cap: C) -> Arc<Operation<ChannelDowngradeStrongWriterDriver<C>>>
372where
373    C: ChannelCapability + 'static,
374{
375    Operation::from_hostcall(
376        ChannelDowngradeStrongWriterDriver(ch_cap),
377        selium_abi::hostcall_contract!(CHANNEL_WRITER_DOWNGRADE),
378    )
379}
380
381pub fn lifecycle_ops<C>(cap: C) -> ChannelLifecycleOps<C>
382where
383    C: ChannelCapability + Clone + 'static,
384{
385    (
386        Operation::from_hostcall(
387            ChannelCreateDriver(cap.clone()),
388            selium_abi::hostcall_contract!(CHANNEL_CREATE),
389        ),
390        Operation::from_hostcall(
391            ChannelDeleteDriver(cap.clone()),
392            selium_abi::hostcall_contract!(CHANNEL_DELETE),
393        ),
394        Operation::from_hostcall(
395            ChannelDrainDriver(cap),
396            selium_abi::hostcall_contract!(CHANNEL_DRAIN),
397        ),
398    )
399}
400
401pub fn handoff_ops() -> ChannelHandoffOps {
402    (
403        Operation::from_hostcall(
404            ChannelExportDriver,
405            selium_abi::hostcall_contract!(CHANNEL_SHARE),
406        ),
407        Operation::from_hostcall(
408            ChannelAttachDriver,
409            selium_abi::hostcall_contract!(CHANNEL_ATTACH),
410        ),
411        Operation::from_hostcall(
412            ChannelDetachDriver,
413            selium_abi::hostcall_contract!(CHANNEL_DETACH),
414        ),
415    )
416}