stuck/net/tcp.rs

1use std::io;
2use std::io::{Read, Write};
3use std::mem::ManuallyDrop;
4use std::net::SocketAddr;
5use std::os::unix::io::{AsRawFd, RawFd};
6use std::rc::Rc;
7use std::time::Duration;
8
9use ignore_result::Ignore;
10use mio::{net, Token};
11use static_assertions::{assert_impl_all, assert_not_impl_any};
12
13use crate::channel::parallel;
14use crate::channel::prelude::*;
15use crate::runtime::Scheduler;
16
17/// Listener for incoming TCP connections.
18pub struct TcpListener {
19    listener: ManuallyDrop<net::TcpListener>,
20    readable: parallel::Receiver<()>,
21    token: Token,
22}
23
24assert_impl_all!(TcpListener: Send, Sync);
25
26impl Drop for TcpListener {
27    fn drop(&mut self) {
28        let registry = unsafe { Scheduler::registry() };
29        let listener = unsafe { ManuallyDrop::take(&mut self.listener) };
30        registry.deregister_event_source(self.token, listener);
31    }
32}
33
34impl TcpListener {
35    /// Binds and listens to given socket address.
36    pub fn bind(addr: SocketAddr) -> io::Result<TcpListener> {
37        let mut listener = net::TcpListener::bind(addr)?;
38        let registry = unsafe { Scheduler::registry() };
39        let (token, readable) = registry.register_tcp_listener(&mut listener)?;
40        Ok(TcpListener { listener: ManuallyDrop::new(listener), readable, token })
41    }
42
43    /// Accepts an incoming connection.
44    pub fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> {
45        loop {
46            match self.listener.accept() {
47                Ok((stream, addr)) => {
48                    let stream = TcpStream::new(stream)?;
49                    return Ok((stream, addr));
50                },
51                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
52                    self.readable.recv().expect("runtime closing");
53                },
54                Err(err) => return Err(err),
55            }
56        }
57    }
58
59    /// Returns the local socket address this listener bound to.
60    pub fn local_addr(&self) -> io::Result<SocketAddr> {
61        self.listener.local_addr()
62    }
63
64    /// Sets the time-to-live (aka. TTL or hop limit) option for ip packets sent from incoming connections.
65    ///
66    /// Accepted connections will inherit this option.
67    pub fn set_ttl(&self, ttl: u8) -> io::Result<()> {
68        self.listener.set_ttl(ttl.into())
69    }
70
71    /// Gets the time-to-live option for this listening socket.
72    pub fn ttl(&self) -> io::Result<u8> {
73        self.listener.ttl().map(|ttl| ttl as u8)
74    }
75}
76
77/// A TCP stream between a local and a remote socket.
78pub struct TcpStream {
79    stream: ManuallyDrop<net::TcpStream>,
80    readable: parallel::Receiver<()>,
81    writable: parallel::Receiver<()>,
82    token: Token,
83}
84
85assert_impl_all!(TcpStream: Send, Sync);
86
87impl Drop for TcpStream {
88    fn drop(&mut self) {
89        let registry = unsafe { Scheduler::registry() };
90        let stream = unsafe { ManuallyDrop::take(&mut self.stream) };
91        registry.deregister_event_source(self.token, stream);
92    }
93}
94
95impl TcpStream {
96    fn new(mut stream: net::TcpStream) -> io::Result<Self> {
97        let registry = unsafe { Scheduler::registry() };
98        let (token, readable, mut writable) = registry.register_tcp_stream(&mut stream)?;
99        writable.recv().expect("runtime closing");
100        Ok(TcpStream { stream: ManuallyDrop::new(stream), readable, writable, token })
101    }
102
103    /// Sets the time-to-live (aka. TTL or hop limit) option for ip packets sent from this socket.
104    pub fn set_ttl(&self, ttl: u8) -> io::Result<()> {
105        self.stream.set_ttl(ttl.into())
106    }
107
108    /// Gets the time-to-live option for this socket.
109    pub fn ttl(&self) -> io::Result<u8> {
110        self.stream.ttl().map(|ttl| ttl as u8)
111    }
112
113    /// Connects to remote host.
114    pub fn connect(addr: SocketAddr) -> io::Result<Self> {
115        let stream = Self::new(net::TcpStream::connect(addr)?)?;
116        if let Some(err) = stream.stream.take_error()? {
117            return Err(err);
118        }
119        Ok(stream)
120    }
121
122    /// Sets the value of the `TCP_NODELAY` option on this socket.
123    ///
124    /// If set, this option disables the Nagle algorithm. This means that segments are always
125    /// sent as soon as possible, even if there is only a small amount of data. When not set,
126    /// data is buffered until there is a sufficient amount to send out, thereby avoiding the
127    /// frequent sending of small packets.
128    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
129        self.stream.set_nodelay(nodelay)
130    }
131
132    /// Gets the value of the `TCP_NODELAY` option on this socket.
133    ///
134    /// For more information about this option, see [TcpStream::set_nodelay].
135    pub fn nodelay(&self) -> io::Result<bool> {
136        self.stream.nodelay()
137    }
138
139    /// Sets the value of the `SO_LINGER` option on this socket.
140    ///
141    /// This value controls how the socket is closed when data remains to be sent. If `SO_LINGER`
142    /// is set, the socket will remain open for the specified duration as the system attempts to
143    /// send pending data. Otherwise, the system may close the socket immediately, or wait for a
144    /// default timeout.
145    pub fn set_linger(&self, linger: Option<Duration>) -> io::Result<()> {
146        let fd = self.stream.as_raw_fd();
147        let linger = libc::linger {
148            l_onoff: if linger.is_some() { 1 } else { 0 },
149            l_linger: linger.map(|d| d.as_secs() as libc::c_int).unwrap_or_default(),
150        };
151        let rc = unsafe {
152            libc::setsockopt(
153                fd,
154                libc::SOL_SOCKET,
155                libc::SO_LINGER,
156                &linger as *const _ as *const libc::c_void,
157                std::mem::size_of::<libc::linger>() as libc::socklen_t,
158            )
159        };
160        if rc != 0 {
161            return Err(io::Error::last_os_error());
162        }
163        Ok(())
164    }
165
166    /// Gets the value of the `SO_LINGER` option on this socket.
167    ///
168    /// For more information about this option, see [TcpStream::set_linger].
169    pub fn linger(&self) -> io::Result<Option<Duration>> {
170        let fd = self.stream.as_raw_fd();
171        let mut linger: libc::linger = unsafe { std::mem::zeroed() };
172        let mut optlen = std::mem::size_of::<libc::linger>() as libc::socklen_t;
173        let rc = unsafe {
174            libc::getsockopt(
175                fd,
176                libc::SOL_SOCKET,
177                libc::SO_LINGER,
178                &mut linger as *mut _ as *mut libc::c_void,
179                &mut optlen,
180            )
181        };
182        if rc != 0 {
183            return Err(io::Error::last_os_error());
184        }
185        Ok((linger.l_onoff != 0).then(|| Duration::from_secs(linger.l_linger as u64)))
186    }
187
188    /// Returns the local socket address of this connection.
189    pub fn local_addr(&self) -> io::Result<SocketAddr> {
190        self.stream.local_addr()
191    }
192
193    /// Returns the remote socket address of this connection.
194    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
195        self.stream.peer_addr()
196    }
197
198    /// Splits this connection to reader and writer.
199    pub fn into_split(mut self) -> (TcpReader, TcpWriter) {
200        let stream = Rc::new(unsafe { ManuallyDrop::take(&mut self.stream) });
201        let reader = TcpReader {
202            stream: ManuallyDrop::new(stream.clone()),
203            readable: unsafe { std::ptr::read(&self.readable) },
204            token: self.token,
205        };
206        let writer = TcpWriter {
207            stream: ManuallyDrop::new(stream),
208            writable: unsafe { std::ptr::read(&self.writable) },
209            token: self.token,
210        };
211        std::mem::forget(self);
212        (reader, writer)
213    }
214
215    /// Shuts down read half of this connection.
216    pub fn shutdown_read(&self) -> io::Result<()> {
217        self.stream.shutdown(std::net::Shutdown::Read)
218    }
219
220    /// Shuts down write half of this connection.
221    pub fn shutdown_write(&self) -> io::Result<()> {
222        self.stream.shutdown(std::net::Shutdown::Write)
223    }
224
225    fn read(stream: &mut net::TcpStream, readable: &mut parallel::Receiver<()>, buf: &mut [u8]) -> io::Result<usize> {
226        loop {
227            match stream.read(buf) {
228                Ok(n) => return Ok(n),
229                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
230                    readable.recv().expect("runtime closing");
231                },
232                Err(err) => return Err(err),
233            }
234        }
235    }
236
237    fn write(stream: &mut net::TcpStream, writable: &mut parallel::Receiver<()>, buf: &[u8]) -> io::Result<usize> {
238        loop {
239            match stream.write(buf) {
240                Ok(n) => return Ok(n),
241                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
242                    writable.recv().expect("runtime closing");
243                },
244                Err(err) => return Err(err),
245            }
246        }
247    }
248}
249
250impl AsRawFd for TcpStream {
251    fn as_raw_fd(&self) -> RawFd {
252        self.stream.as_raw_fd()
253    }
254}
255
256/// Read half of [TcpStream].
257///
258/// The read half of this connection will be shutdown when this value is dropped.
259pub struct TcpReader {
260    stream: ManuallyDrop<Rc<net::TcpStream>>,
261    readable: parallel::Receiver<()>,
262    token: Token,
263}
264
265assert_not_impl_any!(TcpReader: Send, Sync);
266
267impl Drop for TcpReader {
268    fn drop(&mut self) {
269        let stream = unsafe { ManuallyDrop::take(&mut self.stream) };
270        stream.shutdown(std::net::Shutdown::Read).ignore();
271        if let Some(stream) = Rc::into_inner(stream) {
272            let registry = unsafe { Scheduler::registry() };
273            registry.deregister_event_source(self.token, stream);
274        }
275    }
276}
277
278impl io::Read for TcpReader {
279    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
280        let stream = Rc::as_ptr(&self.stream) as *mut _;
281        TcpStream::read(unsafe { &mut *stream }, &mut self.readable, buf)
282    }
283}
284
285/// Write half of [TcpStream].
286///
287/// The write half of this connection will be shutdown when this value is dropped.
288pub struct TcpWriter {
289    stream: ManuallyDrop<Rc<net::TcpStream>>,
290    writable: parallel::Receiver<()>,
291    token: Token,
292}
293
294assert_not_impl_any!(TcpReader: Send, Sync);
295
296impl Drop for TcpWriter {
297    fn drop(&mut self) {
298        let stream = unsafe { ManuallyDrop::take(&mut self.stream) };
299        stream.shutdown(std::net::Shutdown::Write).ignore();
300        if let Some(stream) = Rc::into_inner(stream) {
301            let registry = unsafe { Scheduler::registry() };
302            registry.deregister_event_source(self.token, stream);
303        }
304    }
305}
306
307impl io::Write for TcpWriter {
308    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
309        let stream = Rc::as_ptr(&self.stream) as *mut _;
310        TcpStream::write(unsafe { &mut *stream }, &mut self.writable, buf)
311    }
312
313    fn flush(&mut self) -> io::Result<()> {
314        Ok(())
315    }
316}
317
318impl io::Read for TcpStream {
319    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
320        TcpStream::read(&mut self.stream, &mut self.readable, buf)
321    }
322}
323
324impl io::Write for TcpStream {
325    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
326        TcpStream::write(&mut self.stream, &mut self.writable, buf)
327    }
328
329    fn flush(&mut self) -> io::Result<()> {
330        Ok(())
331    }
332}