tl_async_runtime/io/
net.rs

1use std::{
2    io::{self, Read, Write},
3    net::SocketAddr,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use futures::{channel::mpsc::UnboundedReceiver, Stream, StreamExt};
9use pin_project::pin_project;
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11
12use super::{Event, Registration};
13
14/// Listener for TCP events
15pub struct TcpListener {
16    registration: Registration<mio::net::TcpListener>,
17}
18
19impl TcpListener {
20    /// Create a new TcpListener bound to the socket
21    pub fn bind(addr: SocketAddr) -> std::io::Result<Self> {
22        let listener = mio::net::TcpListener::bind(addr)?;
23        let registration = super::Registration::new(listener, mio::Interest::READABLE)?;
24        Ok(Self { registration })
25    }
26
27    /// Accept a new TcpStream to communicate with
28    pub async fn accept(&self) -> std::io::Result<(TcpStream, SocketAddr)> {
29        loop {
30            self.registration.events().next().await;
31            match self.registration.accept() {
32                Ok((stream, socket)) => break Ok((TcpStream::from_mio(stream)?, socket)),
33                Err(e) if e.kind() != std::io::ErrorKind::WouldBlock => break Err(e),
34                _ => {}
35            }
36        }
37    }
38}
39
40/// Handles communication over a TCP connection
41#[pin_project]
42pub struct TcpStream {
43    registration: Registration<mio::net::TcpStream>,
44
45    readable: Option<()>,
46    writeable: Option<()>,
47
48    #[pin]
49    events: UnboundedReceiver<Event>,
50}
51
52impl TcpStream {
53    pub(crate) fn from_mio(stream: mio::net::TcpStream) -> std::io::Result<Self> {
54        // register the stream to the OS
55        let registration =
56            super::Registration::new(stream, mio::Interest::READABLE | mio::Interest::WRITABLE)?;
57        let events = registration.events();
58        Ok(Self {
59            registration,
60            readable: Some(()),
61            writeable: Some(()),
62            events,
63        })
64    }
65
66    fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
67        let mut this = self.as_mut().project();
68        let event = match this.events.as_mut().poll_next(cx) {
69            Poll::Ready(Some(event)) => event,
70            Poll::Ready(None) => {
71                return Poll::Ready(Err(io::Error::new(
72                    io::ErrorKind::BrokenPipe,
73                    "channel disconnected",
74                )))
75            }
76            Poll::Pending => return Poll::Pending,
77        };
78
79        if event.is_readable() {
80            *this.readable = Some(());
81        }
82        if event.is_writable() {
83            *this.writeable = Some(());
84        }
85        Poll::Ready(Ok(()))
86    }
87}
88
89impl AsyncRead for TcpStream {
90    fn poll_read(
91        mut self: Pin<&mut Self>,
92        cx: &mut Context<'_>,
93        buf: &mut ReadBuf<'_>,
94    ) -> Poll<io::Result<()>> {
95        loop {
96            // if the stream is readable
97            if let Some(()) = self.readable.take() {
98                // try read some bytes
99                let b = buf.initialize_unfilled();
100                match self.registration.read(b) {
101                    Ok(n) => {
102                        // if bytes were read, mark them
103                        buf.advance(n);
104                        // ensure that we attempt another read next time
105                        // since no new readable events will come through
106                        self.readable = Some(());
107                        return Poll::Ready(Ok(()));
108                    }
109                    // if reading would block the thread, continue to event polling
110                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
111                    // if there was some other io error, bail
112                    Err(e) => return Poll::Ready(Err(e)),
113                }
114            }
115
116            match self.as_mut().poll_event(cx)? {
117                Poll::Ready(()) => {}
118                Poll::Pending => return Poll::Pending,
119            }
120        }
121    }
122}
123
124impl AsyncWrite for TcpStream {
125    fn poll_write(
126        mut self: Pin<&mut Self>,
127        cx: &mut Context<'_>,
128        buf: &[u8],
129    ) -> Poll<std::io::Result<usize>> {
130        loop {
131            // if the stream is writeable
132            if let Some(()) = self.writeable.take() {
133                // try write some bytes
134                match self.registration.write(buf) {
135                    Ok(n) => {
136                        // ensure that we attempt another write next time
137                        // since no new writeable events will come through
138                        self.writeable = Some(());
139                        return Poll::Ready(Ok(n));
140                    }
141                    // if writing would block the thread, continue to event polling
142                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
143                    // if there was some other io error, bail
144                    Err(e) => return Poll::Ready(Err(e)),
145                }
146            }
147
148            match self.as_mut().poll_event(cx)? {
149                Poll::Ready(()) => {}
150                Poll::Pending => return Poll::Pending,
151            }
152        }
153    }
154
155    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
156        loop {
157            // if the stream is writeable
158            if let Some(()) = self.writeable.take() {
159                // try flush the bytes
160                match self.registration.flush() {
161                    Ok(()) => {
162                        // ensure that we attempt another write next time
163                        // since no new writeable events will come through
164                        self.writeable = Some(());
165                        return Poll::Ready(Ok(()));
166                    }
167                    // if flushing would block the thread, continue to event polling
168                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
169                    // if there was some other io error, bail
170                    Err(e) => return Poll::Ready(Err(e)),
171                }
172            }
173
174            match self.as_mut().poll_event(cx)? {
175                Poll::Ready(()) => {}
176                Poll::Pending => return Poll::Pending,
177            }
178        }
179    }
180
181    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
182        // shutdowns are immediate
183        Poll::Ready(self.registration.shutdown(std::net::Shutdown::Write))
184    }
185}