1use std::future::Future;
4use std::io::Error;
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use env::tls::TlsServerConfig;
10
11use crate::env::{self, tasks, Result};
12use crate::{ocall, ResourceId};
13
14pub struct TcpListener {
16 res_id: ResourceId,
17}
18
19#[derive(Debug)]
21pub struct TcpConnector {
22 res: Result<ResourceId, env::OcallError>,
23}
24
25#[derive(Debug)]
27pub struct TcpStream {
28 res_id: ResourceId,
29}
30
31pub struct Acceptor<'a> {
33 listener: &'a TcpListener,
34}
35
36impl Future for Acceptor<'_> {
37 type Output = Result<(TcpStream, SocketAddr)>;
38
39 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40 use env::OcallError;
41 let waker_id = tasks::intern_waker(cx.waker().clone());
42 match ocall::tcp_accept(waker_id, self.listener.res_id.0) {
43 Ok((res_id, remote_addr)) => Poll::Ready(Ok((
44 TcpStream::new(ResourceId(res_id)),
45 remote_addr
46 .parse()
47 .expect("ocall::tcp_accept returned an invalid remote address"),
48 ))),
49 Err(OcallError::Pending) => Poll::Pending,
50 Err(err) => Poll::Ready(Err(err)),
51 }
52 }
53}
54
55impl TcpListener {
56 pub async fn bind(addr: &str) -> Result<Self> {
58 let res_id = ResourceId(ocall::tcp_listen(addr.into(), None)?);
62 Ok(Self { res_id })
63 }
64
65 pub async fn bind_tls(addr: &str, config: TlsServerConfig) -> Result<Self> {
67 let res_id = ResourceId(ocall::tcp_listen(addr.into(), Some(config))?);
68 Ok(Self { res_id })
69 }
70
71 pub fn accept(&self) -> Acceptor {
73 Acceptor { listener: self }
74 }
75}
76
77impl Future for TcpConnector {
78 type Output = Result<TcpStream>;
79
80 fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
81 use env::OcallError;
82
83 let res_id = match &self.get_mut().res {
84 Ok(res_id) => res_id,
85 Err(err) => return Poll::Ready(Err(*err)),
86 };
87
88 match ocall::poll_res(env::tasks::intern_waker(ctx.waker().clone()), res_id.0) {
89 Ok(res_id) => Poll::Ready(Ok(TcpStream::new(ResourceId(res_id)))),
90 Err(OcallError::Pending) => Poll::Pending,
91 Err(err) => Poll::Ready(Err(err)),
92 }
93 }
94}
95
96impl TcpStream {
97 pub fn new(res_id: ResourceId) -> Self {
99 Self { res_id }
100 }
101
102 pub fn connect(host: &str, port: u16, enable_tls: bool) -> TcpConnector {
104 let res = if enable_tls {
105 ocall::tcp_connect_tls(host.into(), port, env::tls::TlsClientConfig::V0)
106 } else {
107 ocall::tcp_connect(host, port)
108 };
109 let res = res.map(ResourceId);
110 TcpConnector { res }
111 }
112}
113
#[cfg(feature = "hyper")]
pub use impl_hyper::{AddrIncoming, AddrStream, HttpConnector};

/// `hyper` integration: server-side `Accept` impls and a client `Service`
/// connector built on the ocall-backed TCP types above.
#[cfg(feature = "hyper")]
mod impl_hyper {
    use super::*;
    use env::OcallError;
    use hyper::client::connect::{Connected, Connection};
    use hyper::server::accept::Accept;
    use hyper::{service::Service, Uri};
    use std::{io, task};

    /// Reshapes the result of polling an [`Acceptor`] into what
    /// `Accept::poll_accept` expects.
    ///
    /// `EndOfFile` ends the stream (`None`), any other error is yielded as
    /// `Some(Err(_))`, and an accepted value is passed through `wrap`.
    fn lift_accepted<T, U>(
        polled: Poll<Result<T, OcallError>>,
        wrap: impl FnOnce(T) -> U,
    ) -> Poll<Option<Result<U, OcallError>>> {
        let outcome = match polled {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(outcome) => outcome,
        };
        Poll::Ready(match outcome {
            Ok(accepted) => Some(Ok(wrap(accepted))),
            Err(OcallError::EndOfFile) => None,
            Err(err) => Some(Err(err)),
        })
    }

    impl Accept for TcpListener {
        type Conn = TcpStream;

        type Error = env::OcallError;

        fn poll_accept(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
            // An `Acceptor` carries only a listener reference, so building a
            // fresh one on every poll is equivalent to keeping one around.
            let mut acceptor = self.accept();
            lift_accepted(Pin::new(&mut acceptor).poll(cx), |(stream, _remote)| stream)
        }
    }

    impl Connection for TcpStream {
        /// Reports a plain connection with no extra metadata.
        fn connected(&self) -> Connected {
            Connected::new()
        }
    }

    impl TcpListener {
        /// Converts the listener into an acceptor that also exposes each
        /// connection's peer address.
        pub fn into_addr_incoming(self) -> AddrIncoming {
            AddrIncoming { listener: self }
        }
    }

    /// hyper client connector that opens connections through the ocall layer.
    #[derive(Clone, Default, Debug)]
    pub struct HttpConnector;

    impl HttpConnector {
        /// Creates a connector (equivalent to `HttpConnector::default()`).
        pub fn new() -> Self {
            HttpConnector
        }
    }

    impl Service<Uri> for HttpConnector {
        type Response = TcpStream;
        type Error = OcallError;
        type Future = TcpConnector;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            // The connector holds no state, so it is always ready.
            Poll::Ready(Ok(()))
        }

        /// Connects to the host/port of `dst`, enabling TLS for `https` URIs
        /// and defaulting the port to 443 (https) or 80 (anything else).
        fn call(&mut self, dst: Uri) -> Self::Future {
            let tls = matches!(dst.scheme_str(), Some("https"));
            let default_port = if tls { 443 } else { 80 };
            let port = dst.port_u16().unwrap_or(default_port);
            // Strip brackets so IPv6 literals reach the ocall as bare addresses.
            let host = dst
                .host()
                .map_or("", |h| h.trim_matches(|c| c == '[' || c == ']'));
            TcpStream::connect(host, port, tls)
        }
    }

    /// Acceptor that yields [`AddrStream`]s (stream plus peer address).
    pub struct AddrIncoming {
        listener: TcpListener,
    }

    impl Accept for AddrIncoming {
        type Conn = AddrStream;
        type Error = env::OcallError;

        fn poll_accept(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
        ) -> task::Poll<Option<Result<Self::Conn, Self::Error>>> {
            let mut acceptor = self.listener.accept();
            lift_accepted(
                Pin::new(&mut acceptor).poll(cx),
                |(stream, remote_addr)| AddrStream {
                    stream,
                    remote_addr,
                },
            )
        }
    }

    /// An accepted [`TcpStream`] paired with the peer address it came from.
    #[pin_project::pin_project]
    pub struct AddrStream {
        #[pin]
        stream: TcpStream,
        remote_addr: SocketAddr,
    }

    impl AddrStream {
        /// Returns the peer address reported when the connection was accepted.
        pub fn remote_addr(&self) -> SocketAddr {
            self.remote_addr
        }
    }

    // tokio I/O traits: forward every call to the inner stream.
    #[cfg(feature = "tokio")]
    const _: () = {
        use tokio::io::{AsyncRead, AsyncWrite};

        impl AsyncRead for AddrStream {
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut task::Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> task::Poll<io::Result<()>> {
                AsyncRead::poll_read(self.project().stream, cx, buf)
            }
        }

        impl AsyncWrite for AddrStream {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut task::Context<'_>,
                buf: &[u8],
            ) -> task::Poll<io::Result<usize>> {
                AsyncWrite::poll_write(self.project().stream, cx, buf)
            }

            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut task::Context<'_>,
            ) -> task::Poll<io::Result<()>> {
                AsyncWrite::poll_flush(self.project().stream, cx)
            }

            fn poll_shutdown(
                self: Pin<&mut Self>,
                cx: &mut task::Context<'_>,
            ) -> task::Poll<io::Result<()>> {
                AsyncWrite::poll_shutdown(self.project().stream, cx)
            }
        }
    };

    // futures-io traits: forward every call to the inner stream.
    const _: () = {
        use futures::{AsyncRead, AsyncWrite};

        impl AsyncRead for AddrStream {
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<io::Result<usize>> {
                AsyncRead::poll_read(self.project().stream, cx, buf)
            }
        }

        impl AsyncWrite for AddrStream {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                AsyncWrite::poll_write(self.project().stream, cx, buf)
            }

            fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                AsyncWrite::poll_flush(self.project().stream, cx)
            }

            fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                AsyncWrite::poll_close(self.project().stream, cx)
            }
        }
    };
}
300
/// tokio `AsyncRead`/`AsyncWrite` for [`TcpStream`], backed by the poll ocalls.
#[cfg(feature = "tokio")]
mod impl_tokio {
    use super::*;
    use tokio::io::{AsyncRead, AsyncWrite};

    /// Validates a byte count reported by the host against the space it was
    /// offered.
    ///
    /// A count larger than `capacity` would violate the tokio I/O contract:
    /// `ReadBuf::advance` panics on it, and `write_all`-style callers would
    /// slice out of bounds. Such a count is reported as an `InvalidEncoding`
    /// I/O error instead.
    fn checked_len(len: usize, capacity: usize) -> std::io::Result<usize> {
        if len > capacity {
            Err(Error::from_raw_os_error(
                env::OcallError::InvalidEncoding as i32,
            ))
        } else {
            Ok(len)
        }
    }

    impl AsyncRead for TcpStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let capacity = buf.remaining();
            let waker_id = tasks::intern_waker(cx.waker().clone());
            // The ocall writes into a plain `&mut [u8]`, so the unfilled part of
            // the `ReadBuf` has to be initialized before it is handed over.
            let result = ocall::poll_read(waker_id, self.res_id.0, buf.initialize_unfilled());
            let len = match into_poll(result.map(|len| len as usize)) {
                Poll::Ready(Ok(len)) => len,
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                Poll::Pending => return Poll::Pending,
            };
            Poll::Ready(checked_len(len, capacity).map(|len| buf.advance(len)))
        }
    }

    impl AsyncWrite for TcpStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, Error>> {
            let waker_id = tasks::intern_waker(cx.waker().clone());
            into_poll(ocall::poll_write(waker_id, self.res_id.0, buf).map(|len| len as usize))
                .map(|written| written.and_then(|len| checked_len(len, buf.len())))
        }

        /// Nothing is buffered on this side (the stream holds only a resource
        /// id), so flushing completes immediately.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
            let waker_id = tasks::intern_waker(cx.waker().clone());
            into_poll(ocall::poll_shutdown(waker_id, self.res_id.0))
        }
    }
}
357
358mod impl_futures_io {
359 use super::*;
360 use futures::io::{AsyncRead, AsyncWrite};
361
362 impl AsyncRead for TcpStream {
363 fn poll_read(
364 self: Pin<&mut Self>,
365 cx: &mut Context<'_>,
366 buf: &mut [u8],
367 ) -> Poll<std::io::Result<usize>> {
368 let waker_id = tasks::intern_waker(cx.waker().clone());
369 into_poll(ocall::poll_read(waker_id, self.res_id.0, buf).map(|len| len as usize))
370 }
371 }
372
373 impl AsyncWrite for TcpStream {
374 fn poll_write(
375 self: Pin<&mut Self>,
376 cx: &mut Context<'_>,
377 buf: &[u8],
378 ) -> Poll<std::io::Result<usize>> {
379 let waker_id = tasks::intern_waker(cx.waker().clone());
380 into_poll(ocall::poll_write(waker_id, self.res_id.0, buf).map(|len| len as usize))
381 }
382
383 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
384 Poll::Ready(Ok(()))
385 }
386
387 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
388 let waker_id = tasks::intern_waker(cx.waker().clone());
389 into_poll(ocall::poll_shutdown(waker_id, self.res_id.0))
390 }
391 }
392}
393
394fn into_poll<T>(res: Result<T, env::OcallError>) -> Poll<std::io::Result<T>> {
395 match res {
396 Ok(v) => Poll::Ready(Ok(v)),
397 Err(env::OcallError::Pending) => Poll::Pending,
398 Err(err) => Poll::Ready(Err(std::io::Error::from_raw_os_error(err as i32))),
399 }
400}