1use futures_util::future::BoxFuture;
2use std::{future::Future, sync::Arc};
3
4use wasmtime::Caller;
5
6use crate::{
7 drivers::io::{self, IoCapability, IoReadDriver, IoWriteDriver},
8 guest_data::{GuestError, GuestResult},
9 operation::{Contract, Operation},
10 registry::{InstanceRegistry, ResourceHandle, ResourceType},
11};
12use selium_abi::{
13 GuestResourceId, NetAccept, NetAcceptReply, NetConnect, NetConnectReply, NetCreateListener,
14 NetCreateListenerReply, NetProtocol, hostcalls::Hostcall,
15};
16
17type NetFuture<'a, T, E> = BoxFuture<'a, Result<T, E>>;
18type NetIoFuture<'a, R, W, E> = BoxFuture<'a, Result<(R, W, String), E>>;
19
20pub trait NetCapability {
21 type Handle: Send + Unpin;
22 type Reader: Send + Unpin;
23 type Writer: Send + Unpin;
24 type Error: Into<GuestError>;
25
26 fn create(
29 &self,
30 protocol: NetProtocol,
31 domain: &str,
32 port: u16,
33 tls: Option<Arc<TlsServerConfig>>,
34 ) -> NetFuture<'_, Self::Handle, Self::Error>;
35
36 fn connect(
39 &self,
40 protocol: NetProtocol,
41 domain: &str,
42 port: u16,
43 tls: Option<Arc<TlsClientConfig>>,
44 ) -> NetIoFuture<'_, Self::Reader, Self::Writer, Self::Error>;
45
46 fn accept(
48 &self,
49 handle: &Self::Handle,
50 ) -> NetIoFuture<'_, Self::Reader, Self::Writer, Self::Error>;
51}
52
/// TLS material used when creating a listener.
///
/// Field declaration order is significant for the derived `Debug` output.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct TlsServerConfig {
    /// Certificate chain bytes (PEM-encoded, per the field name).
    pub cert_chain_pem: Vec<u8>,
    /// Private key bytes (PEM-encoded, per the field name).
    pub private_key_pem: Vec<u8>,
    /// Optional CA bundle for validating client certificates.
    pub client_ca_pem: Option<Vec<u8>>,
    /// Optional ALPN protocol identifiers to advertise.
    pub alpn: Option<Vec<String>>,
    /// Whether clients must present a certificate.
    pub require_client_auth: bool,
}
67
/// TLS material used when opening an outbound connection.
///
/// Field declaration order is significant for the derived `Debug` output.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct TlsClientConfig {
    /// Optional CA bundle used to verify the server (PEM, per the field name).
    pub ca_bundle_pem: Option<Vec<u8>>,
    /// Optional client certificate for mutual TLS (PEM, per the field name).
    pub client_cert_pem: Option<Vec<u8>>,
    /// Optional private key matching `client_cert_pem` (PEM, per the field name).
    pub client_key_pem: Option<Vec<u8>>,
    /// Optional ALPN protocol identifiers to offer.
    pub alpn: Option<Vec<String>>,
}
80
/// Hostcall driver that creates listeners through the wrapped [`NetCapability`].
pub struct BindDriver<C>(C);
/// Hostcall driver that opens outbound connections through the wrapped [`NetCapability`].
pub struct ConnectDriver<C>(C);
/// Hostcall driver that accepts inbound connections through the wrapped [`NetCapability`].
pub struct AcceptDriver<C>(C);
87
88impl<T> NetCapability for Arc<T>
89where
90 T: NetCapability,
91{
92 type Handle = T::Handle;
93 type Reader = T::Reader;
94 type Writer = T::Writer;
95 type Error = T::Error;
96
97 fn create(
98 &self,
99 protocol: NetProtocol,
100 domain: &str,
101 port: u16,
102 tls: Option<Arc<TlsServerConfig>>,
103 ) -> NetFuture<'_, Self::Handle, Self::Error> {
104 self.as_ref().create(protocol, domain, port, tls)
105 }
106
107 fn connect(
108 &self,
109 protocol: NetProtocol,
110 domain: &str,
111 port: u16,
112 tls: Option<Arc<TlsClientConfig>>,
113 ) -> NetIoFuture<'_, Self::Reader, Self::Writer, Self::Error> {
114 self.as_ref().connect(protocol, domain, port, tls)
115 }
116
117 fn accept(
118 &self,
119 handle: &Self::Handle,
120 ) -> NetIoFuture<'_, Self::Reader, Self::Writer, Self::Error> {
121 self.as_ref().accept(handle)
122 }
123}
124
125impl<Impl> Contract for BindDriver<Impl>
126where
127 Impl: NetCapability + Clone + Send + 'static,
128 Impl::Handle: Send + Unpin,
129{
130 type Input = NetCreateListener;
131 type Output = NetCreateListenerReply;
132
133 fn to_future(
134 &self,
135 caller: &mut Caller<'_, InstanceRegistry>,
136 input: Self::Input,
137 ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
138 let inner = self.0.clone();
139 let registrar = caller.data().registrar();
140 let registry = caller.data().registry_arc();
141 let NetCreateListener {
142 protocol,
143 domain,
144 port,
145 tls,
146 } = input;
147 let tls = resolve_tls_server_config(caller.data(), ®istry, protocol, tls);
148
149 async move {
150 let handle = inner
151 .create(protocol, &domain, port, tls?)
152 .await
153 .map_err(Into::into)?;
154
155 let slot = registrar
156 .insert(handle, None, ResourceType::Network)
157 .map_err(GuestError::from)?;
158 let handle =
159 GuestResourceId::try_from(slot).map_err(|_| GuestError::InvalidArgument)?;
160
161 Ok(NetCreateListenerReply { handle })
162 }
163 }
164}
165
166impl<Impl> Contract for AcceptDriver<Impl>
167where
168 Impl: NetCapability + Clone + Send + 'static,
169 Impl::Handle: Send + Unpin,
170 <Impl as NetCapability>::Reader: Send + Unpin,
171 <Impl as NetCapability>::Writer: Send + Unpin,
172{
173 type Input = NetAccept;
174 type Output = NetAcceptReply;
175
176 fn to_future(
177 &self,
178 caller: &mut Caller<'_, InstanceRegistry>,
179 input: Self::Input,
180 ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
181 let inner = self.0.clone();
182 let registrar = caller.data().registrar();
183 let registry = caller.data().registry_arc();
184 let handle = (|| {
185 let slot = usize::try_from(input.handle).map_err(|_| GuestError::InvalidArgument)?;
186 caller.data().entry(slot).ok_or(GuestError::NotFound)
187 })();
188
189 async move {
190 let handle_resource = handle?;
191
192 let (reader, writer, remote_addr) = registry
193 .with_async(
194 ResourceHandle::<Impl::Handle>::new(handle_resource),
195 move |handle| Box::pin(async move { inner.accept(handle).await }),
196 )
197 .await
198 .expect("Invalid resource id from InstanceRegistry")
199 .map_err(Into::into)?;
200
201 let reader_slot = registrar
202 .insert(reader, None, ResourceType::Reader)
203 .map_err(GuestError::from)?;
204 let writer_slot = registrar
205 .insert(writer, None, ResourceType::Writer)
206 .map_err(GuestError::from)?;
207
208 let reader =
209 GuestResourceId::try_from(reader_slot).map_err(|_| GuestError::InvalidArgument)?;
210 let writer =
211 GuestResourceId::try_from(writer_slot).map_err(|_| GuestError::InvalidArgument)?;
212
213 Ok(NetAcceptReply {
214 reader,
215 writer,
216 remote_addr,
217 })
218 }
219 }
220}
221
222impl<Impl> Contract for ConnectDriver<Impl>
223where
224 Impl: NetCapability + Clone + Send + 'static,
225 <Impl as NetCapability>::Reader: Send + Unpin,
226 <Impl as NetCapability>::Writer: Send + Unpin,
227{
228 type Input = NetConnect;
229 type Output = NetConnectReply;
230
231 fn to_future(
232 &self,
233 caller: &mut Caller<'_, InstanceRegistry>,
234 input: Self::Input,
235 ) -> impl Future<Output = GuestResult<Self::Output>> + 'static {
236 let inner = self.0.clone();
237 let registrar = caller.data().registrar();
238 let registry = caller.data().registry_arc();
239 let NetConnect {
240 protocol,
241 domain,
242 port,
243 tls,
244 } = input;
245 let tls = resolve_tls_client_config(caller.data(), ®istry, protocol, tls);
246
247 async move {
248 let (reader, writer, remote_addr) = inner
249 .connect(protocol, &domain, port, tls?)
250 .await
251 .map_err(Into::into)?;
252
253 let reader_slot = registrar
254 .insert(reader, None, ResourceType::Reader)
255 .map_err(GuestError::from)?;
256 let writer_slot = registrar
257 .insert(writer, None, ResourceType::Writer)
258 .map_err(GuestError::from)?;
259
260 let reader =
261 GuestResourceId::try_from(reader_slot).map_err(|_| GuestError::InvalidArgument)?;
262 let writer =
263 GuestResourceId::try_from(writer_slot).map_err(|_| GuestError::InvalidArgument)?;
264
265 Ok(NetConnectReply {
266 reader,
267 writer,
268 remote_addr,
269 })
270 }
271 }
272}
273
274pub fn listener_op<C>(cap: C, protocol: NetProtocol) -> Arc<Operation<BindDriver<C>>>
275where
276 C: NetCapability + Clone + Send + 'static,
277{
278 let hostcall = hostcall_for_protocol(
279 protocol,
280 selium_abi::hostcall_contract!(NET_QUIC_BIND),
281 selium_abi::hostcall_contract!(NET_HTTP_BIND),
282 );
283 Operation::from_hostcall(BindDriver(cap), hostcall)
284}
285
286pub fn connect_op<C>(cap: C, protocol: NetProtocol) -> Arc<Operation<ConnectDriver<C>>>
287where
288 C: NetCapability + Clone + Send + 'static,
289{
290 let hostcall = hostcall_for_protocol(
291 protocol,
292 selium_abi::hostcall_contract!(NET_QUIC_CONNECT),
293 selium_abi::hostcall_contract!(NET_HTTP_CONNECT),
294 );
295 Operation::from_hostcall(ConnectDriver(cap), hostcall)
296}
297
298pub fn accept_op<C>(cap: C, protocol: NetProtocol) -> Arc<Operation<AcceptDriver<C>>>
300where
301 C: NetCapability + Clone + Send + 'static,
302{
303 let hostcall = hostcall_for_protocol(
304 protocol,
305 selium_abi::hostcall_contract!(NET_QUIC_ACCEPT),
306 selium_abi::hostcall_contract!(NET_HTTP_ACCEPT),
307 );
308 Operation::from_hostcall(AcceptDriver(cap), hostcall)
309}
310
311pub fn read_op<C>(cap: C, protocol: NetProtocol) -> Arc<Operation<IoReadDriver<C>>>
312where
313 C: IoCapability + Clone + Send + 'static,
314{
315 let hostcall = hostcall_for_protocol(
316 protocol,
317 selium_abi::hostcall_contract!(NET_QUIC_READ),
318 selium_abi::hostcall_contract!(NET_HTTP_READ),
319 );
320 io::read_op(cap, hostcall)
321}
322
323pub fn write_op<C>(cap: C, protocol: NetProtocol) -> Arc<Operation<IoWriteDriver<C>>>
324where
325 C: IoCapability + Clone + Send + 'static,
326{
327 let hostcall = hostcall_for_protocol(
328 protocol,
329 selium_abi::hostcall_contract!(NET_QUIC_WRITE),
330 selium_abi::hostcall_contract!(NET_HTTP_WRITE),
331 );
332 io::write_op(cap, hostcall)
333}
334
335fn resolve_tls_server_config(
336 instance: &InstanceRegistry,
337 registry: &Arc<crate::registry::Registry>,
338 protocol: NetProtocol,
339 handle: Option<GuestResourceId>,
340) -> GuestResult<Option<Arc<TlsServerConfig>>> {
341 let Some(handle) = handle else {
342 return Ok(None);
343 };
344 if matches!(protocol, NetProtocol::Http) {
345 return Err(GuestError::InvalidArgument);
346 }
347 let slot = usize::try_from(handle).map_err(|_| GuestError::InvalidArgument)?;
348 let resource_id = instance.entry(slot).ok_or(GuestError::NotFound)?;
349 let config = registry
350 .with(
351 ResourceHandle::<Arc<TlsServerConfig>>::new(resource_id),
352 |config| Arc::clone(config),
353 )
354 .ok_or(GuestError::NotFound)?;
355 Ok(Some(config))
356}
357
358fn resolve_tls_client_config(
359 instance: &InstanceRegistry,
360 registry: &Arc<crate::registry::Registry>,
361 protocol: NetProtocol,
362 handle: Option<GuestResourceId>,
363) -> GuestResult<Option<Arc<TlsClientConfig>>> {
364 let Some(handle) = handle else {
365 return Ok(None);
366 };
367 if matches!(protocol, NetProtocol::Http) {
368 return Err(GuestError::InvalidArgument);
369 }
370 let slot = usize::try_from(handle).map_err(|_| GuestError::InvalidArgument)?;
371 let resource_id = instance.entry(slot).ok_or(GuestError::NotFound)?;
372 let config = registry
373 .with(
374 ResourceHandle::<Arc<TlsClientConfig>>::new(resource_id),
375 |config| Arc::clone(config),
376 )
377 .ok_or(GuestError::NotFound)?;
378 Ok(Some(config))
379}
380
381fn hostcall_for_protocol<I, O>(
382 protocol: NetProtocol,
383 quic: &'static Hostcall<I, O>,
384 http: &'static Hostcall<I, O>,
385) -> &'static Hostcall<I, O> {
386 match protocol {
387 NetProtocol::Quic => quic,
388 NetProtocol::Http | NetProtocol::Https => http,
389 }
390}