Skip to main content

selium_kernel/drivers/
net.rs

1use futures_util::future::BoxFuture;
2use std::{future::Future, sync::Arc};
3
4use wasmtime::Caller;
5
6use crate::{
7    drivers::io::{self, IoCapability, IoReadDriver, IoWriteDriver},
8    guest_data::{GuestError, GuestResult},
9    operation::{Contract, Operation},
10    registry::{InstanceRegistry, ResourceHandle, ResourceType},
11};
12use selium_abi::{
13    GuestResourceId, NetAccept, NetAcceptReply, NetConnect, NetConnectReply, NetCreateListener,
14    NetCreateListenerReply, NetProtocol, hostcalls::Hostcall,
15};
16
17type NetFuture<'a, T, E> = BoxFuture<'a, Result<T, E>>;
18type NetIoFuture<'a, R, W, E> = BoxFuture<'a, Result<(R, W, String), E>>;
19
20pub trait NetCapability {
21    type Handle: Send + Unpin;
22    type Reader: Send + Unpin;
23    type Writer: Send + Unpin;
24    type Error: Into<GuestError>;
25
26    /// Creates a new network listener for the selected protocol, returning a handle
27    /// to the listener.
28    fn create(
29        &self,
30        protocol: NetProtocol,
31        domain: &str,
32        port: u16,
33        tls: Option<Arc<TlsServerConfig>>,
34    ) -> NetFuture<'_, Self::Handle, Self::Error>;
35
36    /// Connect to a remote listener for the selected protocol, returning
37    /// bidirectional comms.
38    fn connect(
39        &self,
40        protocol: NetProtocol,
41        domain: &str,
42        port: u16,
43        tls: Option<Arc<TlsClientConfig>>,
44    ) -> NetIoFuture<'_, Self::Reader, Self::Writer, Self::Error>;
45
46    /// Accept a new inbound connection for the listener represented by `handle`.
47    fn accept(
48        &self,
49        handle: &Self::Handle,
50    ) -> NetIoFuture<'_, Self::Reader, Self::Writer, Self::Error>;
51}
52
/// TLS configuration supplied for server listeners.
///
/// The `Debug` output never includes `private_key_pem`; it is rendered as `<redacted>`.
#[derive(Clone, PartialEq, Eq)]
pub struct TlsServerConfig {
    /// PEM-encoded certificate chain presented by the server.
    pub cert_chain_pem: Vec<u8>,
    /// PEM-encoded private key for the certificate chain.
    pub private_key_pem: Vec<u8>,
    /// PEM-encoded CA bundle used to verify client certificates.
    pub client_ca_pem: Option<Vec<u8>>,
    /// Optional ALPN protocol list override.
    pub alpn: Option<Vec<String>>,
    /// Require client authentication when true.
    pub require_client_auth: bool,
}

/// TLS configuration supplied for client connections.
///
/// The `Debug` output never includes `client_key_pem`; a present key is rendered as
/// `Some(<redacted>)`.
#[derive(Clone, PartialEq, Eq)]
pub struct TlsClientConfig {
    /// PEM-encoded CA bundle used to verify servers.
    pub ca_bundle_pem: Option<Vec<u8>>,
    /// PEM-encoded client certificate chain.
    pub client_cert_pem: Option<Vec<u8>>,
    /// PEM-encoded private key for the client certificate.
    pub client_key_pem: Option<Vec<u8>>,
    /// Optional ALPN protocol list override.
    pub alpn: Option<Vec<String>>,
}

/// Placeholder printed by the TLS config `Debug` impls in place of private key material.
struct RedactedSecret;

impl std::fmt::Debug for RedactedSecret {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("<redacted>")
    }
}

impl std::fmt::Debug for TlsServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field forces this impl to be revisited, so a
        // new secret cannot leak into logs by accident.
        let Self {
            cert_chain_pem,
            private_key_pem: _,
            client_ca_pem,
            alpn,
            require_client_auth,
        } = self;
        f.debug_struct("TlsServerConfig")
            .field("cert_chain_pem", cert_chain_pem)
            .field("private_key_pem", &RedactedSecret)
            .field("client_ca_pem", client_ca_pem)
            .field("alpn", alpn)
            .field("require_client_auth", require_client_auth)
            .finish()
    }
}

impl std::fmt::Debug for TlsClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: see the server impl for the rationale.
        let Self {
            ca_bundle_pem,
            client_cert_pem,
            client_key_pem,
            alpn,
        } = self;
        f.debug_struct("TlsClientConfig")
            .field("ca_bundle_pem", ca_bundle_pem)
            .field("client_cert_pem", client_cert_pem)
            // Keep the Some/None distinction visible, but never the key bytes.
            .field(
                "client_key_pem",
                &client_key_pem.as_ref().map(|_| RedactedSecret),
            )
            .field("alpn", alpn)
            .finish()
    }
}
80
/// Hostcall driver that opens network listeners through the wrapped capability.
pub struct BindDriver<Impl>(Impl);
/// Hostcall driver that dials outbound network connections through the wrapped capability.
pub struct ConnectDriver<Impl>(Impl);
/// Hostcall driver that accepts inbound connections on a listener through the wrapped
/// capability.
pub struct AcceptDriver<Impl>(Impl);
87
88impl<T> NetCapability for Arc<T>
89where
90    T: NetCapability,
91{
92    type Handle = T::Handle;
93    type Reader = T::Reader;
94    type Writer = T::Writer;
95    type Error = T::Error;
96
97    fn create(
98        &self,
99        protocol: NetProtocol,
100        domain: &str,
101        port: u16,
102        tls: Option<Arc<TlsServerConfig>>,
103    ) -> NetFuture<'_, Self::Handle, Self::Error> {
104        self.as_ref().create(protocol, domain, port, tls)
105    }
106
107    fn connect(
108        &self,
109        protocol: NetProtocol,
110        domain: &str,
111        port: u16,
112        tls: Option<Arc<TlsClientConfig>>,
113    ) -> NetIoFuture<'_, Self::Reader, Self::Writer, Self::Error> {
114        self.as_ref().connect(protocol, domain, port, tls)
115    }
116
117    fn accept(
118        &self,
119        handle: &Self::Handle,
120    ) -> NetIoFuture<'_, Self::Reader, Self::Writer, Self::Error> {
121        self.as_ref().accept(handle)
122    }
123}
124
125impl<Impl> Contract for BindDriver<Impl>
126where
127    Impl: NetCapability + Clone + Send + 'static,
128    Impl::Handle: Send + Unpin,
129{
130    type Input = NetCreateListener;
131    type Output = NetCreateListenerReply;
132
133    fn to_future(
134        &self,
135        caller: &mut Caller<'_, InstanceRegistry>,
136        input: Self::Input,
137    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
138        let inner = self.0.clone();
139        let registrar = caller.data().registrar();
140        let registry = caller.data().registry_arc();
141        let NetCreateListener {
142            protocol,
143            domain,
144            port,
145            tls,
146        } = input;
147        let tls = resolve_tls_server_config(caller.data(), &registry, protocol, tls);
148
149        async move {
150            let handle = inner
151                .create(protocol, &domain, port, tls?)
152                .await
153                .map_err(Into::into)?;
154
155            let slot = registrar
156                .insert(handle, None, ResourceType::Network)
157                .map_err(GuestError::from)?;
158            let handle =
159                GuestResourceId::try_from(slot).map_err(|_| GuestError::InvalidArgument)?;
160
161            Ok(NetCreateListenerReply { handle })
162        }
163    }
164}
165
166impl<Impl> Contract for AcceptDriver<Impl>
167where
168    Impl: NetCapability + Clone + Send + 'static,
169    Impl::Handle: Send + Unpin,
170    <Impl as NetCapability>::Reader: Send + Unpin,
171    <Impl as NetCapability>::Writer: Send + Unpin,
172{
173    type Input = NetAccept;
174    type Output = NetAcceptReply;
175
176    fn to_future(
177        &self,
178        caller: &mut Caller<'_, InstanceRegistry>,
179        input: Self::Input,
180    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
181        let inner = self.0.clone();
182        let registrar = caller.data().registrar();
183        let registry = caller.data().registry_arc();
184        let handle = (|| {
185            let slot = usize::try_from(input.handle).map_err(|_| GuestError::InvalidArgument)?;
186            caller.data().entry(slot).ok_or(GuestError::NotFound)
187        })();
188
189        async move {
190            let handle_resource = handle?;
191
192            let (reader, writer, remote_addr) = registry
193                .with_async(
194                    ResourceHandle::<Impl::Handle>::new(handle_resource),
195                    move |handle| Box::pin(async move { inner.accept(handle).await }),
196                )
197                .await
198                .expect("Invalid resource id from InstanceRegistry")
199                .map_err(Into::into)?;
200
201            let reader_slot = registrar
202                .insert(reader, None, ResourceType::Reader)
203                .map_err(GuestError::from)?;
204            let writer_slot = registrar
205                .insert(writer, None, ResourceType::Writer)
206                .map_err(GuestError::from)?;
207
208            let reader =
209                GuestResourceId::try_from(reader_slot).map_err(|_| GuestError::InvalidArgument)?;
210            let writer =
211                GuestResourceId::try_from(writer_slot).map_err(|_| GuestError::InvalidArgument)?;
212
213            Ok(NetAcceptReply {
214                reader,
215                writer,
216                remote_addr,
217            })
218        }
219    }
220}
221
222impl<Impl> Contract for ConnectDriver<Impl>
223where
224    Impl: NetCapability + Clone + Send + 'static,
225    <Impl as NetCapability>::Reader: Send + Unpin,
226    <Impl as NetCapability>::Writer: Send + Unpin,
227{
228    type Input = NetConnect;
229    type Output = NetConnectReply;
230
231    fn to_future(
232        &self,
233        caller: &mut Caller<'_, InstanceRegistry>,
234        input: Self::Input,
235    ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
236        let inner = self.0.clone();
237        let registrar = caller.data().registrar();
238        let registry = caller.data().registry_arc();
239        let NetConnect {
240            protocol,
241            domain,
242            port,
243            tls,
244        } = input;
245        let tls = resolve_tls_client_config(caller.data(), &registry, protocol, tls);
246
247        async move {
248            let (reader, writer, remote_addr) = inner
249                .connect(protocol, &domain, port, tls?)
250                .await
251                .map_err(Into::into)?;
252
253            let reader_slot = registrar
254                .insert(reader, None, ResourceType::Reader)
255                .map_err(GuestError::from)?;
256            let writer_slot = registrar
257                .insert(writer, None, ResourceType::Writer)
258                .map_err(GuestError::from)?;
259
260            let reader =
261                GuestResourceId::try_from(reader_slot).map_err(|_| GuestError::InvalidArgument)?;
262            let writer =
263                GuestResourceId::try_from(writer_slot).map_err(|_| GuestError::InvalidArgument)?;
264
265            Ok(NetConnectReply {
266                reader,
267                writer,
268                remote_addr,
269            })
270        }
271    }
272}
273
274pub fn listener_op<C>(cap: C, protocol: NetProtocol) -> Arc<Operation<BindDriver<C>>>
275where
276    C: NetCapability + Clone + Send + 'static,
277{
278    let hostcall = hostcall_for_protocol(
279        protocol,
280        selium_abi::hostcall_contract!(NET_QUIC_BIND),
281        selium_abi::hostcall_contract!(NET_HTTP_BIND),
282    );
283    Operation::from_hostcall(BindDriver(cap), hostcall)
284}
285
286pub fn connect_op<C>(cap: C, protocol: NetProtocol) -> Arc<Operation<ConnectDriver<C>>>
287where
288    C: NetCapability + Clone + Send + 'static,
289{
290    let hostcall = hostcall_for_protocol(
291        protocol,
292        selium_abi::hostcall_contract!(NET_QUIC_CONNECT),
293        selium_abi::hostcall_contract!(NET_HTTP_CONNECT),
294    );
295    Operation::from_hostcall(ConnectDriver(cap), hostcall)
296}
297
298/// Host operation for accepting inbound connections on a listener.
299pub fn accept_op<C>(cap: C, protocol: NetProtocol) -> Arc<Operation<AcceptDriver<C>>>
300where
301    C: NetCapability + Clone + Send + 'static,
302{
303    let hostcall = hostcall_for_protocol(
304        protocol,
305        selium_abi::hostcall_contract!(NET_QUIC_ACCEPT),
306        selium_abi::hostcall_contract!(NET_HTTP_ACCEPT),
307    );
308    Operation::from_hostcall(AcceptDriver(cap), hostcall)
309}
310
311pub fn read_op<C>(cap: C, protocol: NetProtocol) -> Arc<Operation<IoReadDriver<C>>>
312where
313    C: IoCapability + Clone + Send + 'static,
314{
315    let hostcall = hostcall_for_protocol(
316        protocol,
317        selium_abi::hostcall_contract!(NET_QUIC_READ),
318        selium_abi::hostcall_contract!(NET_HTTP_READ),
319    );
320    io::read_op(cap, hostcall)
321}
322
323pub fn write_op<C>(cap: C, protocol: NetProtocol) -> Arc<Operation<IoWriteDriver<C>>>
324where
325    C: IoCapability + Clone + Send + 'static,
326{
327    let hostcall = hostcall_for_protocol(
328        protocol,
329        selium_abi::hostcall_contract!(NET_QUIC_WRITE),
330        selium_abi::hostcall_contract!(NET_HTTP_WRITE),
331    );
332    io::write_op(cap, hostcall)
333}
334
335fn resolve_tls_server_config(
336    instance: &InstanceRegistry,
337    registry: &Arc<crate::registry::Registry>,
338    protocol: NetProtocol,
339    handle: Option<GuestResourceId>,
340) -> GuestResult<Option<Arc<TlsServerConfig>>> {
341    let Some(handle) = handle else {
342        return Ok(None);
343    };
344    if matches!(protocol, NetProtocol::Http) {
345        return Err(GuestError::InvalidArgument);
346    }
347    let slot = usize::try_from(handle).map_err(|_| GuestError::InvalidArgument)?;
348    let resource_id = instance.entry(slot).ok_or(GuestError::NotFound)?;
349    let config = registry
350        .with(
351            ResourceHandle::<Arc<TlsServerConfig>>::new(resource_id),
352            |config| Arc::clone(config),
353        )
354        .ok_or(GuestError::NotFound)?;
355    Ok(Some(config))
356}
357
358fn resolve_tls_client_config(
359    instance: &InstanceRegistry,
360    registry: &Arc<crate::registry::Registry>,
361    protocol: NetProtocol,
362    handle: Option<GuestResourceId>,
363) -> GuestResult<Option<Arc<TlsClientConfig>>> {
364    let Some(handle) = handle else {
365        return Ok(None);
366    };
367    if matches!(protocol, NetProtocol::Http) {
368        return Err(GuestError::InvalidArgument);
369    }
370    let slot = usize::try_from(handle).map_err(|_| GuestError::InvalidArgument)?;
371    let resource_id = instance.entry(slot).ok_or(GuestError::NotFound)?;
372    let config = registry
373        .with(
374            ResourceHandle::<Arc<TlsClientConfig>>::new(resource_id),
375            |config| Arc::clone(config),
376        )
377        .ok_or(GuestError::NotFound)?;
378    Ok(Some(config))
379}
380
381fn hostcall_for_protocol<I, O>(
382    protocol: NetProtocol,
383    quic: &'static Hostcall<I, O>,
384    http: &'static Hostcall<I, O>,
385) -> &'static Hostcall<I, O> {
386    match protocol {
387        NetProtocol::Quic => quic,
388        NetProtocol::Http | NetProtocol::Https => http,
389    }
390}