1use std::{
2 io::{self, Read, Write},
3 mem::replace,
4 net::SocketAddr,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use futures::{ready, AsyncRead, AsyncWrite, Stream, StreamExt};
10use mio::Interest;
11
12use crate::io::Registration;
13
/// A TCP listening socket driven through a [`Registration`].
///
/// Create one with [`TcpListener::bind`], then turn it into a stream of
/// incoming connections with [`TcpListener::accept`].
pub struct TcpListener {
    /// The underlying mio listener together with its event registration
    /// (created with `Interest::READABLE` in `bind`).
    registration: Registration<mio::net::TcpListener>,
}
18
19impl TcpListener {
20 pub fn bind(addr: SocketAddr) -> std::io::Result<Self> {
22 let registration =
23 Registration::new(mio::net::TcpListener::bind(addr)?, Interest::READABLE)?;
24 Ok(Self { registration })
25 }
26
27 pub fn accept(self) -> Accept {
29 let Self { registration } = self;
30 Accept { registration }
31 }
32}
33
/// Stream of incoming connections produced by [`TcpListener::accept`].
///
/// Each item is an `io::Result<(TcpStream, SocketAddr)>` (see the `Stream`
/// implementation below).
pub struct Accept {
    /// The listener registration taken over from the [`TcpListener`].
    registration: Registration<mio::net::TcpListener>,
}
// Explicit `Unpin`: `poll_next` mutates the registration through
// `Pin<&mut Self>`, which requires the type to be `Unpin`.
impl Unpin for Accept {}
39
40impl Stream for Accept {
41 type Item = io::Result<(TcpStream, SocketAddr)>;
42
43 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44 if ready!(self.registration.events.poll_next_unpin(cx)).is_none() {
45 return Poll::Ready(None);
46 }
47 match self.registration.accept() {
48 Ok((stream, socket)) => Poll::Ready(Some(Ok((TcpStream::from_mio(stream)?, socket)))),
49 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Poll::Pending,
50 Err(e) => Poll::Ready(Some(Err(e))),
51 }
52 }
53}
54
/// A TCP connection implementing `AsyncRead` and `AsyncWrite`.
///
/// Readiness is cached in the `read` / `write` flags: both start out `true`
/// (see `from_mio`), `poll_io!` clears a flag before attempting the matching
/// operation, and `poll_event` sets the flags again from incoming events.
pub struct TcpStream {
    /// The underlying mio stream together with its event registration.
    registration: Registration<mio::net::TcpStream>,
    /// `true` while a read should be attempted without waiting for an event.
    read: bool,
    /// `true` while a write/flush should be attempted without waiting for an event.
    write: bool,
}

// Explicit `Unpin`: the poll methods mutate fields through `Pin<&mut Self>`.
impl Unpin for TcpStream {}
63
64impl TcpStream {
65 pub(crate) fn from_mio(stream: mio::net::TcpStream) -> std::io::Result<Self> {
66 let registration = Registration::new(stream, Interest::READABLE | Interest::WRITABLE)?;
68 Ok(Self {
69 registration,
70 read: true,
71 write: true,
72 })
73 }
74
75 fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
76 let Some(event) = ready!(self.registration.events.poll_next_unpin(cx)) else {
77 return Poll::Ready(Err(io::Error::new(
78 io::ErrorKind::BrokenPipe,
79 "channel disconnected",
80 )));
81 };
82 self.read |= event.is_readable();
83 self.write |= event.is_writable();
84 Poll::Ready(Ok(()))
85 }
86}
87
// Drives one non-blocking I/O operation to completion inside a `poll_*` method.
//
// Usage: `poll_io!(self, cx @ <flag>: let <pat> = <io expr>; <result expr>)`
//
// * `<flag>` names the boolean readiness field on `self` (`read` or `write`).
// * `<io expr>` is attempted only while the flag is set; the flag is cleared
//   before the attempt.
// * On `Ok`, the flag is set again and `Poll::Ready(<result expr>)` is returned.
// * On `WouldBlock`, the flag stays cleared and the macro waits for the next
//   event via `poll_event`, then retries.
// * On any other error, the flag is restored before the error is returned:
//   such an error (e.g. `Interrupted`) says nothing about readiness, and
//   leaving the flag cleared would make a retry wait for an event that may
//   never arrive.
//
// Expands to code that uses `replace`, `io`, `Poll` and `ready!` from the
// invoking scope, and `return`s from the enclosing function.
macro_rules! poll_io {
    ($self:ident, $cx:ident @ $mode:ident: let $pat:pat = $io:expr; $expr:expr) => {
        loop {
            if replace(&mut $self.$mode, false) {
                match $io {
                    Ok($pat) => {
                        $self.$mode = true;
                        return Poll::Ready($expr);
                    }
                    // Readiness is exhausted: keep the flag cleared and fall
                    // through to wait for the next event.
                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                    Err(e) => {
                        // Only `WouldBlock` proves the socket is not ready.
                        $self.$mode = true;
                        return Poll::Ready(Err(e));
                    }
                }
            }
            ready!($self.as_mut().poll_event($cx)?)
        }
    };
}
109
impl AsyncRead for TcpStream {
    /// Reads into `buf`, returning the number of bytes read.
    ///
    /// Driven by `poll_io!` in `read` mode: the read on the registration is
    /// attempted while the `read` flag is set; after a `WouldBlock` the call
    /// waits for the next event via `poll_event` and retries.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        poll_io!(self, cx @ read: let n = self.registration.read(buf); Ok(n))
    }
}
119
120impl AsyncWrite for TcpStream {
121 fn poll_write(
122 mut self: Pin<&mut Self>,
123 cx: &mut Context<'_>,
124 buf: &[u8],
125 ) -> Poll<io::Result<usize>> {
126 poll_io!(self, cx @ write: let n = self.registration.write(buf); Ok(n))
127 }
128
129 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
130 poll_io!(self, cx @ write: let () = self.registration.flush(); Ok(()))
131 }
132
133 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
134 Poll::Ready(self.registration.shutdown(std::net::Shutdown::Write))
135 }
136}