wapo/
net.rs

1//! Networking support.
2
3use std::future::Future;
4use std::io::Error;
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use env::tls::TlsServerConfig;
10
11use crate::env::{self, tasks, Result};
12use crate::{ocall, ResourceId};
13
14/// A TCP socket server, listening for connections.
15pub struct TcpListener {
16    res_id: ResourceId,
17}
18
19/// A resource pointing to a connecting TCP socket.
20#[derive(Debug)]
21pub struct TcpConnector {
22    res: Result<ResourceId, env::OcallError>,
23}
24
25/// A connected TCP socket.
26#[derive(Debug)]
27pub struct TcpStream {
28    res_id: ResourceId,
29}
30
31/// Future returned by `TcpListener::accept`.
32pub struct Acceptor<'a> {
33    listener: &'a TcpListener,
34}
35
36impl Future for Acceptor<'_> {
37    type Output = Result<(TcpStream, SocketAddr)>;
38
39    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40        use env::OcallError;
41        let waker_id = tasks::intern_waker(cx.waker().clone());
42        match ocall::tcp_accept(waker_id, self.listener.res_id.0) {
43            Ok((res_id, remote_addr)) => Poll::Ready(Ok((
44                TcpStream::new(ResourceId(res_id)),
45                remote_addr
46                    .parse()
47                    .expect("ocall::tcp_accept returned an invalid remote address"),
48            ))),
49            Err(OcallError::Pending) => Poll::Pending,
50            Err(err) => Poll::Ready(Err(err)),
51        }
52    }
53}
54
55impl TcpListener {
56    /// Bind and listen on the specified address for incoming TCP connections.
57    pub async fn bind(addr: &str) -> Result<Self> {
58        // Side notes: could be used to probe enabled interfaces and occupied ports. We may
59        // consider to introduce some manifest file to further limit the capability in the future
60        // TODO.kevin: prevent local interface probing and port occupation
61        let res_id = ResourceId(ocall::tcp_listen(addr.into(), None)?);
62        Ok(Self { res_id })
63    }
64
65    /// Bind and listen on the specified address for incoming TLS-enabled TCP connections.
66    pub async fn bind_tls(addr: &str, config: TlsServerConfig) -> Result<Self> {
67        let res_id = ResourceId(ocall::tcp_listen(addr.into(), Some(config))?);
68        Ok(Self { res_id })
69    }
70
71    /// Accept a new incoming connection.
72    pub fn accept(&self) -> Acceptor {
73        Acceptor { listener: self }
74    }
75}
76
77impl Future for TcpConnector {
78    type Output = Result<TcpStream>;
79
80    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
81        use env::OcallError;
82
83        let res_id = match &self.get_mut().res {
84            Ok(res_id) => res_id,
85            Err(err) => return Poll::Ready(Err(*err)),
86        };
87
88        match ocall::poll_res(env::tasks::intern_waker(ctx.waker().clone()), res_id.0) {
89            Ok(res_id) => Poll::Ready(Ok(TcpStream::new(ResourceId(res_id)))),
90            Err(OcallError::Pending) => Poll::Pending,
91            Err(err) => Poll::Ready(Err(err)),
92        }
93    }
94}
95
96impl TcpStream {
97    /// Create a new TcpStream from a resource ID.
98    pub fn new(res_id: ResourceId) -> Self {
99        Self { res_id }
100    }
101
102    /// Initiate a TCP connection to a remote host.
103    pub fn connect(host: &str, port: u16, enable_tls: bool) -> TcpConnector {
104        let res = if enable_tls {
105            ocall::tcp_connect_tls(host.into(), port, env::tls::TlsClientConfig::V0)
106        } else {
107            ocall::tcp_connect(host, port)
108        };
109        let res = res.map(ResourceId);
110        TcpConnector { res }
111    }
112}
113
#[cfg(feature = "hyper")]
pub use impl_hyper::{AddrIncoming, AddrStream, HttpConnector};
#[cfg(feature = "hyper")]
mod impl_hyper {
    //! Integration with `hyper`: `Accept` implementations for serving, an HTTP/HTTPS client
    //! connector, and a stream wrapper that remembers the peer address.

    use super::*;
    use env::OcallError;
    use hyper::client::connect::{Connected, Connection};
    use hyper::server::accept::Accept;
    use hyper::{service::Service, Uri};
    use std::io;

    // Unwraps a ready `Ok` value inside a `poll_accept` body, returning early otherwise:
    // `Pending` stays pending, `OcallError::EndOfFile` ends the connection stream with
    // `Ready(None)`, and any other error is yielded as `Ready(Some(Err(..)))`.
    macro_rules! ready_ok {
        ($poll: expr) => {
            match $poll {
                Poll::Ready(Ok(val)) => val,
                Poll::Ready(Err(OcallError::EndOfFile)) => return Poll::Ready(None),
                Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
                Poll::Pending => return Poll::Pending,
            }
        };
    }

    impl Accept for TcpListener {
        type Conn = TcpStream;

        type Error = env::OcallError;

        fn poll_accept(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
            // The peer address is discarded here; `into_addr_incoming` keeps it.
            let (conn, _addr) = ready_ok!(Pin::new(&mut self.accept()).poll(cx));
            Poll::Ready(Some(Ok(conn)))
        }
    }

    impl Connection for TcpStream {
        fn connected(&self) -> Connected {
            Connected::new()
        }
    }

    impl TcpListener {
        /// Convert the listener into an [`AddrIncoming`], whose accepted connections carry
        /// the remote address.
        pub fn into_addr_incoming(self) -> AddrIncoming {
            AddrIncoming { listener: self }
        }
    }

    /// An HTTP/HTTPS connector for hyper running under wapo.
    ///
    /// `https` URIs are connected with TLS enabled. When the URI has no explicit port, 443 is
    /// used for `https` and 80 otherwise.
    #[derive(Clone, Default, Debug)]
    pub struct HttpConnector;

    impl HttpConnector {
        /// Create a new HttpConnector.
        pub fn new() -> Self {
            Self
        }
    }

    impl Service<Uri> for HttpConnector {
        type Response = TcpStream;
        type Error = OcallError;
        type Future = TcpConnector;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, dst: Uri) -> Self::Future {
            let is_https = dst.scheme_str() == Some("https");
            // Strip the brackets around IPv6 literal hosts (e.g. `[::1]`).
            let host = dst
                .host()
                .unwrap_or("")
                .trim_matches(|c| c == '[' || c == ']');
            let port = dst.port_u16().unwrap_or(if is_https { 443 } else { 80 });
            TcpStream::connect(host, port, is_https)
        }
    }

    /// A wrapper of [`TcpListener`] whose accepted connections are [`AddrStream`]s.
    pub struct AddrIncoming {
        listener: TcpListener,
    }

    impl Accept for AddrIncoming {
        type Conn = AddrStream;
        type Error = env::OcallError;

        fn poll_accept(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
            let (stream, remote_addr) = ready_ok!(Pin::new(&mut self.listener.accept()).poll(cx));
            Poll::Ready(Some(Ok(AddrStream {
                stream,
                remote_addr,
            })))
        }
    }

    /// A wrapper of [`TcpStream`] that keeps the remote address of the connection.
    #[pin_project::pin_project]
    #[derive(Debug)]
    pub struct AddrStream {
        #[pin]
        stream: TcpStream,
        remote_addr: SocketAddr,
    }

    impl AddrStream {
        /// Get the remote address of the connection.
        pub fn remote_addr(&self) -> SocketAddr {
            self.remote_addr
        }
    }

    // `tokio::io` traits, delegating to the inner `TcpStream`.
    #[cfg(feature = "tokio")]
    const _: () = {
        use tokio::io::{AsyncRead, AsyncWrite};
        impl AsyncRead for AddrStream {
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                self.project().stream.poll_read(cx, buf)
            }
        }

        impl AsyncWrite for AddrStream {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                self.project().stream.poll_write(cx, buf)
            }

            fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                self.project().stream.poll_flush(cx)
            }

            fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                self.project().stream.poll_shutdown(cx)
            }
        }
    };

    // `futures::io` traits, delegating to the inner `TcpStream`.
    const _: () = {
        use futures::{AsyncRead, AsyncWrite};

        impl AsyncRead for AddrStream {
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<io::Result<usize>> {
                self.project().stream.poll_read(cx, buf)
            }
        }

        impl AsyncWrite for AddrStream {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                self.project().stream.poll_write(cx, buf)
            }

            fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                self.project().stream.poll_flush(cx)
            }

            fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                self.project().stream.poll_close(cx)
            }
        }
    };
}
300
#[cfg(feature = "tokio")]
mod impl_tokio {
    //! `tokio::io` trait implementations for [`TcpStream`].

    use super::*;
    use std::io;
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    impl AsyncRead for TcpStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            // The host fills a plain `&mut [u8]`, so the whole unfilled region is initialized
            // before it is handed over.
            let dst = buf.initialize_unfilled();
            let waker_id = tasks::intern_waker(cx.waker().clone());
            let read = match into_poll(ocall::poll_read(waker_id, self.res_id.0, dst)) {
                Poll::Ready(Ok(len)) => len as usize,
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                Poll::Pending => return Poll::Pending,
            };
            // Reject a reported length larger than the space that was offered to the host.
            if read > buf.remaining() {
                return Poll::Ready(Err(io::Error::from_raw_os_error(
                    env::OcallError::InvalidEncoding as i32,
                )));
            }
            buf.advance(read);
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for TcpStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let waker_id = tasks::intern_waker(cx.waker().clone());
            let written = ocall::poll_write(waker_id, self.res_id.0, buf);
            into_poll(written.map(|len| len as usize))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            // `TcpStream` keeps no write buffer of its own, so there is nothing to flush here.
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let waker_id = tasks::intern_waker(cx.waker().clone());
            into_poll(ocall::poll_shutdown(waker_id, self.res_id.0))
        }
    }
}
357
358mod impl_futures_io {
359    use super::*;
360    use futures::io::{AsyncRead, AsyncWrite};
361
362    impl AsyncRead for TcpStream {
363        fn poll_read(
364            self: Pin<&mut Self>,
365            cx: &mut Context<'_>,
366            buf: &mut [u8],
367        ) -> Poll<std::io::Result<usize>> {
368            let waker_id = tasks::intern_waker(cx.waker().clone());
369            into_poll(ocall::poll_read(waker_id, self.res_id.0, buf).map(|len| len as usize))
370        }
371    }
372
373    impl AsyncWrite for TcpStream {
374        fn poll_write(
375            self: Pin<&mut Self>,
376            cx: &mut Context<'_>,
377            buf: &[u8],
378        ) -> Poll<std::io::Result<usize>> {
379            let waker_id = tasks::intern_waker(cx.waker().clone());
380            into_poll(ocall::poll_write(waker_id, self.res_id.0, buf).map(|len| len as usize))
381        }
382
383        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
384            Poll::Ready(Ok(()))
385        }
386
387        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
388            let waker_id = tasks::intern_waker(cx.waker().clone());
389            into_poll(ocall::poll_shutdown(waker_id, self.res_id.0))
390        }
391    }
392}
393
394fn into_poll<T>(res: Result<T, env::OcallError>) -> Poll<std::io::Result<T>> {
395    match res {
396        Ok(v) => Poll::Ready(Ok(v)),
397        Err(env::OcallError::Pending) => Poll::Pending,
398        Err(err) => Poll::Ready(Err(std::io::Error::from_raw_os_error(err as i32))),
399    }
400}