veilid_tools/
socket_tools.rs

1use super::*;
2use async_io::Async;
3use std::io;
4
5cfg_if! {
6    if #[cfg(feature="rt-async-std")] {
7        pub use async_std::net::{TcpStream, TcpListener, UdpSocket};
8    } else if #[cfg(feature="rt-tokio")] {
9        pub use tokio::net::{TcpStream, TcpListener, UdpSocket};
10        pub use tokio_util::compat::*;
11    } else {
12        compile_error!("needs executor implementation");
13    }
14}
15
16use socket2::{Domain, Protocol, SockAddr, Socket, Type};
17
18//////////////////////////////////////////////////////////////////////////////////////////
19
20pub fn bind_async_udp_socket(local_address: SocketAddr) -> io::Result<Option<UdpSocket>> {
21    let Some(socket) = new_bound_default_socket2_udp(local_address)? else {
22        return Ok(None);
23    };
24
25    // Make an async UdpSocket from the socket2 socket
26    let std_udp_socket: std::net::UdpSocket = socket.into();
27    cfg_if! {
28        if #[cfg(feature="rt-async-std")] {
29            let udp_socket = UdpSocket::from(std_udp_socket);
30        } else if #[cfg(feature="rt-tokio")] {
31            std_udp_socket.set_nonblocking(true)?;
32            let udp_socket = UdpSocket::from_std(std_udp_socket)?;
33        } else {
34            compile_error!("needs executor implementation");
35        }
36    }
37    Ok(Some(udp_socket))
38}
39
40pub fn bind_async_tcp_listener(local_address: SocketAddr) -> io::Result<Option<TcpListener>> {
41    // Create a default non-shared socket and bind it
42    let Some(socket) = new_bound_default_socket2_tcp(local_address)? else {
43        return Ok(None);
44    };
45
46    // Drop the socket so we can make another shared socket in its place
47    drop(socket);
48
49    // Create a shared socket and bind it now we have determined the port is free
50    let Some(socket) = new_bound_shared_socket2_tcp(local_address)? else {
51        return Ok(None);
52    };
53
54    // Listen on the socket
55    if socket.listen(128).is_err() {
56        return Ok(None);
57    }
58
59    // Make an async tcplistener from the socket2 socket
60    let std_listener: std::net::TcpListener = socket.into();
61    cfg_if! {
62        if #[cfg(feature="rt-async-std")] {
63            let listener = TcpListener::from(std_listener);
64        } else if #[cfg(feature="rt-tokio")] {
65            std_listener.set_nonblocking(true)?;
66            let listener = TcpListener::from_std(std_listener)?;
67        } else {
68            compile_error!("needs executor implementation");
69        }
70    }
71    Ok(Some(listener))
72}
73
74pub async fn connect_async_tcp_stream(
75    local_address: Option<SocketAddr>,
76    remote_address: SocketAddr,
77    timeout_ms: u32,
78) -> io::Result<TimeoutOr<TcpStream>> {
79    let socket = match local_address {
80        Some(a) => {
81            new_bound_shared_socket2_tcp(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
82        }
83        None => new_default_socket2_tcp(domain_for_address(remote_address))?,
84    };
85
86    // Non-blocking connect to remote address
87    nonblocking_connect(socket, remote_address, timeout_ms).await
88}
89
90pub fn set_tcp_stream_linger(
91    tcp_stream: &TcpStream,
92    linger: Option<core::time::Duration>,
93) -> io::Result<()> {
94    #[cfg(all(feature = "rt-async-std", unix))]
95    {
96        // async-std does not directly support linger on TcpStream yet
97        use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd};
98        unsafe {
99            let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd());
100            let res = s.set_linger(linger);
101            let _ = s.into_raw_fd();
102            res
103        }
104    }
105    #[cfg(all(feature = "rt-async-std", windows))]
106    {
107        // async-std does not directly support linger on TcpStream yet
108        use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
109        unsafe {
110            let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket());
111            let res = s.set_linger(linger);
112            let _ = s.into_raw_socket();
113            res
114        }
115    }
116    #[cfg(not(feature = "rt-async-std"))]
117    tcp_stream.set_linger(linger)
118}
119
120cfg_if! {
121    if #[cfg(feature="rt-async-std")] {
122        pub type ReadHalf = futures_util::io::ReadHalf<TcpStream>;
123        pub type WriteHalf = futures_util::io::WriteHalf<TcpStream>;
124    } else if #[cfg(feature="rt-tokio")] {
125        pub type ReadHalf = tokio::net::tcp::OwnedReadHalf;
126        pub type WriteHalf = tokio::net::tcp::OwnedWriteHalf;
127    } else {
128        compile_error!("needs executor implementation");
129    }
130}
131
132pub fn async_tcp_listener_incoming(
133    tcp_listener: TcpListener,
134) -> Pin<Box<impl futures_util::stream::Stream<Item = std::io::Result<TcpStream>> + Send>> {
135    cfg_if! {
136        if #[cfg(feature="rt-async-std")] {
137            Box::pin(tcp_listener.into_incoming())
138        } else if #[cfg(feature="rt-tokio")] {
139            Box::pin(tokio_stream::wrappers::TcpListenerStream::new(tcp_listener))
140        } else {
141            compile_error!("needs executor implementation");
142        }
143    }
144}
145
146pub fn split_async_tcp_stream(tcp_stream: TcpStream) -> (ReadHalf, WriteHalf) {
147    cfg_if! {
148        if #[cfg(feature="rt-async-std")] {
149            use futures_util::AsyncReadExt;
150            tcp_stream.split()
151        } else if #[cfg(feature="rt-tokio")] {
152            tcp_stream.into_split()
153        } else {
154            compile_error!("needs executor implementation");
155        }
156    }
157}
158
159//////////////////////////////////////////////////////////////////////////////////////////
160
161fn new_default_udp_socket(domain: core::ffi::c_int) -> io::Result<Socket> {
162    let domain = Domain::from(domain);
163    let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
164    if domain == Domain::IPV6 {
165        socket.set_only_v6(true)?;
166    }
167
168    Ok(socket)
169}
170
171fn new_bound_default_socket2_udp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
172    let domain = domain_for_address(local_address);
173    let socket = new_default_udp_socket(domain)?;
174    let socket2_addr = SockAddr::from(local_address);
175
176    if socket.bind(&socket2_addr).is_err() {
177        return Ok(None);
178    }
179
180    Ok(Some(socket))
181}
182
183pub fn new_default_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
184    let domain = Domain::from(domain);
185    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
186    socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
187    socket.set_nodelay(true)?;
188    if domain == Domain::IPV6 {
189        socket.set_only_v6(true)?;
190    }
191    Ok(socket)
192}
193
194fn new_shared_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
195    let domain = Domain::from(domain);
196    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
197    socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
198    socket.set_nodelay(true)?;
199    if domain == Domain::IPV6 {
200        socket.set_only_v6(true)?;
201    }
202    socket.set_reuse_address(true)?;
203    cfg_if! {
204        if #[cfg(unix)] {
205            socket.set_reuse_port(true)?;
206        }
207    }
208
209    Ok(socket)
210}
211
212fn new_bound_default_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
213    let domain = domain_for_address(local_address);
214    let socket = new_default_socket2_tcp(domain)?;
215    let socket2_addr = SockAddr::from(local_address);
216    if socket.bind(&socket2_addr).is_err() {
217        return Ok(None);
218    }
219
220    Ok(Some(socket))
221}
222
223fn new_bound_shared_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
224    // Create the reuseaddr/reuseport socket now that we've asserted the port is free
225    let domain = domain_for_address(local_address);
226    let socket = new_shared_socket2_tcp(domain)?;
227    let socket2_addr = SockAddr::from(local_address);
228    if socket.bind(&socket2_addr).is_err() {
229        return Ok(None);
230    }
231
232    Ok(Some(socket))
233}
234
235// Non-blocking connect is tricky when you want to start with a prepared socket
236// Errors should not be logged as they are valid conditions for this function
237async fn nonblocking_connect(
238    socket: Socket,
239    addr: SocketAddr,
240    timeout_ms: u32,
241) -> io::Result<TimeoutOr<TcpStream>> {
242    // Set for non blocking connect
243    socket.set_nonblocking(true)?;
244
245    // Make socket2 SockAddr
246    let socket2_addr = socket2::SockAddr::from(addr);
247
248    // Connect to the remote address
249    match socket.connect(&socket2_addr) {
250        Ok(()) => Ok(()),
251        #[cfg(unix)]
252        Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
253        Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
254        Err(e) => Err(e),
255    }?;
256    let async_stream = Async::new(std::net::TcpStream::from(socket))?;
257
258    // The stream becomes writable when connected
259    timeout_or_try!(timeout(timeout_ms, async_stream.writable())
260        .await
261        .into_timeout_or()
262        .into_result()?);
263
264    // Check low level error
265    let async_stream = match async_stream.get_ref().take_error()? {
266        None => Ok(async_stream),
267        Some(err) => Err(err),
268    }?;
269
270    // Convert back to inner and then return async version
271    cfg_if! {
272        if #[cfg(feature="rt-async-std")] {
273            Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?)))
274        } else if #[cfg(feature="rt-tokio")] {
275            Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?))
276        } else {
277            compile_error!("needs executor implementation");
278        }
279    }
280}
281
282pub fn domain_for_address(address: SocketAddr) -> core::ffi::c_int {
283    socket2::Domain::for_address(address).into()
284}
285
286// Run operations on underlying socket
287cfg_if! {
288    if #[cfg(unix)] {
289        pub fn socket2_operation<S: std::os::fd::AsRawFd, F: FnOnce(&mut socket2::Socket) -> R, R>(
290            s: &S,
291            callback: F,
292        ) -> R {
293            use std::os::fd::{FromRawFd, IntoRawFd};
294            let mut s = unsafe { socket2::Socket::from_raw_fd(s.as_raw_fd()) };
295            let res = callback(&mut s);
296            let _ = s.into_raw_fd();
297            res
298        }
299    } else if #[cfg(windows)] {
300        pub fn socket2_operation<
301            S: std::os::windows::io::AsRawSocket,
302            F: FnOnce(&mut socket2::Socket) -> R,
303            R,
304        >(
305            s: &S,
306            callback: F,
307        ) -> R {
308            use std::os::windows::io::{FromRawSocket, IntoRawSocket};
309            let mut s = unsafe { socket2::Socket::from_raw_socket(s.as_raw_socket()) };
310            let res = callback(&mut s);
311            let _ = s.into_raw_socket();
312            res
313        }
314    } else {
315        #[compile_error("unimplemented")]
316    }
317}