wta_reactor/
net.rs

1use std::{
2    io::{self, Read, Write},
3    mem::replace,
4    net::SocketAddr,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use futures::{ready, AsyncRead, AsyncWrite, Stream, StreamExt};
10use mio::Interest;
11
12use crate::io::Registration;
13
14/// Listener for TCP connectors
15pub struct TcpListener {
16    registration: Registration<mio::net::TcpListener>,
17}
18
19impl TcpListener {
20    /// Create a new `TcpListener` bound to the socket
21    pub fn bind(addr: SocketAddr) -> std::io::Result<Self> {
22        let registration =
23            Registration::new(mio::net::TcpListener::bind(addr)?, Interest::READABLE)?;
24        Ok(Self { registration })
25    }
26
27    /// Accept [`TcpStream`]s to communicate with
28    pub fn accept(self) -> Accept {
29        let Self { registration } = self;
30        Accept { registration }
31    }
32}
33
34/// A [`Stream`] of [`TcpStream`] that are connecting to this tcp server
35pub struct Accept {
36    registration: Registration<mio::net::TcpListener>,
37}
38impl Unpin for Accept {}
39
40impl Stream for Accept {
41    type Item = io::Result<(TcpStream, SocketAddr)>;
42
43    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44        if ready!(self.registration.events.poll_next_unpin(cx)).is_none() {
45            return Poll::Ready(None);
46        }
47        match self.registration.accept() {
48            Ok((stream, socket)) => Poll::Ready(Some(Ok((TcpStream::from_mio(stream)?, socket)))),
49            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Poll::Pending,
50            Err(e) => Poll::Ready(Some(Err(e))),
51        }
52    }
53}
54
55/// Handles communication over a TCP connection
56pub struct TcpStream {
57    registration: Registration<mio::net::TcpStream>,
58    read: bool,
59    write: bool,
60}
61
62impl Unpin for TcpStream {}
63
64impl TcpStream {
65    pub(crate) fn from_mio(stream: mio::net::TcpStream) -> std::io::Result<Self> {
66        // register the stream to the OS
67        let registration = Registration::new(stream, Interest::READABLE | Interest::WRITABLE)?;
68        Ok(Self {
69            registration,
70            read: true,
71            write: true,
72        })
73    }
74
75    fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
76        let Some(event) = ready!(self.registration.events.poll_next_unpin(cx)) else {
77            return Poll::Ready(Err(io::Error::new(
78                io::ErrorKind::BrokenPipe,
79                "channel disconnected",
80            )));
81        };
82        self.read |= event.is_readable();
83        self.write |= event.is_writable();
84        Poll::Ready(Ok(()))
85    }
86}
87
// Drives one non-blocking IO operation to completion.
//
// Usage: `poll_io!(self, cx @ <flag>: let <pat> = <io expr>; <result expr>)`
// where `<flag>` is the `read` or `write` field of the stream.
//
// Each iteration attempts the IO only if the flag is set, clearing it first.
// On success the flag is set again and the enclosing function returns
// `Poll::Ready(<result expr>)`. On `WouldBlock` the flag stays cleared and the
// macro waits for the next OS event via `poll_event`; any other error is
// returned immediately.
macro_rules! poll_io {
    ($self:ident, $cx:ident @ $mode:ident: let $pat:pat = $io:expr; $expr:expr) => {
        loop {
            // Take the flag: it stays `false` unless the IO below succeeds.
            let attempt = replace(&mut $self.$mode, false);
            if attempt {
                match $io {
                    Ok($pat) => {
                        // Readiness has not been drained yet, and no fresh event
                        // will arrive until it is, so try again on the next call.
                        // https://docs.rs/mio/0.8.0/mio/struct.Poll.html#draining-readiness
                        $self.$mode = true;
                        return Poll::Ready($expr);
                    }
                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                    Err(e) => return Poll::Ready(Err(e)),
                }
            }
            // Blocked (or flag unset): wait for an event to update the flags.
            ready!($self.as_mut().poll_event($cx)?);
        }
    };
}
109
110impl AsyncRead for TcpStream {
111    fn poll_read(
112        mut self: Pin<&mut Self>,
113        cx: &mut Context<'_>,
114        buf: &mut [u8],
115    ) -> Poll<io::Result<usize>> {
116        poll_io!(self, cx @ read: let n = self.registration.read(buf); Ok(n))
117    }
118}
119
120impl AsyncWrite for TcpStream {
121    fn poll_write(
122        mut self: Pin<&mut Self>,
123        cx: &mut Context<'_>,
124        buf: &[u8],
125    ) -> Poll<io::Result<usize>> {
126        poll_io!(self, cx @ write: let n = self.registration.write(buf); Ok(n))
127    }
128
129    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
130        poll_io!(self, cx @ write: let () = self.registration.flush(); Ok(()))
131    }
132
133    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
134        Poll::Ready(self.registration.shutdown(std::net::Shutdown::Write))
135    }
136}