veilid_tools/
socket_tools.rs1use super::*;
2use async_io::Async;
3use std::io;
4
5cfg_if! {
6 if #[cfg(feature="rt-async-std")] {
7 pub use async_std::net::{TcpStream, TcpListener, UdpSocket};
8 } else if #[cfg(feature="rt-tokio")] {
9 pub use tokio::net::{TcpStream, TcpListener, UdpSocket};
10 pub use tokio_util::compat::*;
11 } else {
12 compile_error!("needs executor implementation");
13 }
14}
15
16use socket2::{Domain, Protocol, SockAddr, Socket, Type};
17
18pub fn bind_async_udp_socket(local_address: SocketAddr) -> io::Result<Option<UdpSocket>> {
21 let Some(socket) = new_bound_default_socket2_udp(local_address)? else {
22 return Ok(None);
23 };
24
25 let std_udp_socket: std::net::UdpSocket = socket.into();
27 cfg_if! {
28 if #[cfg(feature="rt-async-std")] {
29 let udp_socket = UdpSocket::from(std_udp_socket);
30 } else if #[cfg(feature="rt-tokio")] {
31 std_udp_socket.set_nonblocking(true)?;
32 let udp_socket = UdpSocket::from_std(std_udp_socket)?;
33 } else {
34 compile_error!("needs executor implementation");
35 }
36 }
37 Ok(Some(udp_socket))
38}
39
40pub fn bind_async_tcp_listener(local_address: SocketAddr) -> io::Result<Option<TcpListener>> {
41 let Some(socket) = new_bound_default_socket2_tcp(local_address)? else {
43 return Ok(None);
44 };
45
46 drop(socket);
48
49 let Some(socket) = new_bound_shared_socket2_tcp(local_address)? else {
51 return Ok(None);
52 };
53
54 if socket.listen(128).is_err() {
56 return Ok(None);
57 }
58
59 let std_listener: std::net::TcpListener = socket.into();
61 cfg_if! {
62 if #[cfg(feature="rt-async-std")] {
63 let listener = TcpListener::from(std_listener);
64 } else if #[cfg(feature="rt-tokio")] {
65 std_listener.set_nonblocking(true)?;
66 let listener = TcpListener::from_std(std_listener)?;
67 } else {
68 compile_error!("needs executor implementation");
69 }
70 }
71 Ok(Some(listener))
72}
73
74pub async fn connect_async_tcp_stream(
75 local_address: Option<SocketAddr>,
76 remote_address: SocketAddr,
77 timeout_ms: u32,
78) -> io::Result<TimeoutOr<TcpStream>> {
79 let socket = match local_address {
80 Some(a) => {
81 new_bound_shared_socket2_tcp(a)?.ok_or(io::Error::from(io::ErrorKind::AddrInUse))?
82 }
83 None => new_default_socket2_tcp(domain_for_address(remote_address))?,
84 };
85
86 nonblocking_connect(socket, remote_address, timeout_ms).await
88}
89
90pub fn set_tcp_stream_linger(
91 tcp_stream: &TcpStream,
92 linger: Option<core::time::Duration>,
93) -> io::Result<()> {
94 #[cfg(all(feature = "rt-async-std", unix))]
95 {
96 use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd};
98 unsafe {
99 let s = socket2::Socket::from_raw_fd(tcp_stream.as_raw_fd());
100 let res = s.set_linger(linger);
101 let _ = s.into_raw_fd();
102 res
103 }
104 }
105 #[cfg(all(feature = "rt-async-std", windows))]
106 {
107 use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
109 unsafe {
110 let s = socket2::Socket::from_raw_socket(tcp_stream.as_raw_socket());
111 let res = s.set_linger(linger);
112 let _ = s.into_raw_socket();
113 res
114 }
115 }
116 #[cfg(not(feature = "rt-async-std"))]
117 tcp_stream.set_linger(linger)
118}
119
120cfg_if! {
121 if #[cfg(feature="rt-async-std")] {
122 pub type ReadHalf = futures_util::io::ReadHalf<TcpStream>;
123 pub type WriteHalf = futures_util::io::WriteHalf<TcpStream>;
124 } else if #[cfg(feature="rt-tokio")] {
125 pub type ReadHalf = tokio::net::tcp::OwnedReadHalf;
126 pub type WriteHalf = tokio::net::tcp::OwnedWriteHalf;
127 } else {
128 compile_error!("needs executor implementation");
129 }
130}
131
132#[must_use]
133pub fn async_tcp_listener_incoming(
134 tcp_listener: TcpListener,
135) -> Pin<Box<impl futures_util::stream::Stream<Item = std::io::Result<TcpStream>> + Send>> {
136 cfg_if! {
137 if #[cfg(feature="rt-async-std")] {
138 Box::pin(tcp_listener.into_incoming())
139 } else if #[cfg(feature="rt-tokio")] {
140 Box::pin(tokio_stream::wrappers::TcpListenerStream::new(tcp_listener))
141 } else {
142 compile_error!("needs executor implementation");
143 }
144 }
145}
146
147#[must_use]
148pub fn split_async_tcp_stream(tcp_stream: TcpStream) -> (ReadHalf, WriteHalf) {
149 cfg_if! {
150 if #[cfg(feature="rt-async-std")] {
151 use futures_util::AsyncReadExt;
152 tcp_stream.split()
153 } else if #[cfg(feature="rt-tokio")] {
154 tcp_stream.into_split()
155 } else {
156 compile_error!("needs executor implementation");
157 }
158 }
159}
160
161fn new_default_udp_socket(domain: core::ffi::c_int) -> io::Result<Socket> {
164 let domain = Domain::from(domain);
165 let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
166 if domain == Domain::IPV6 {
167 socket.set_only_v6(true)?;
168 }
169
170 Ok(socket)
171}
172
173fn new_bound_default_socket2_udp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
174 let domain = domain_for_address(local_address);
175 let socket = new_default_udp_socket(domain)?;
176 let socket2_addr = SockAddr::from(local_address);
177
178 if socket.bind(&socket2_addr).is_err() {
179 return Ok(None);
180 }
181
182 Ok(Some(socket))
183}
184
185pub fn new_default_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
186 let domain = Domain::from(domain);
187 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
188 socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
189 socket.set_nodelay(true)?;
190 if domain == Domain::IPV6 {
191 socket.set_only_v6(true)?;
192 }
193 Ok(socket)
194}
195
196fn new_shared_socket2_tcp(domain: core::ffi::c_int) -> io::Result<Socket> {
197 let domain = Domain::from(domain);
198 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
199 socket.set_linger(Some(core::time::Duration::from_secs(0)))?;
200 socket.set_nodelay(true)?;
201 if domain == Domain::IPV6 {
202 socket.set_only_v6(true)?;
203 }
204 socket.set_reuse_address(true)?;
205 cfg_if! {
206 if #[cfg(unix)] {
207 socket.set_reuse_port(true)?;
208 }
209 }
210
211 Ok(socket)
212}
213
214fn new_bound_default_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
215 let domain = domain_for_address(local_address);
216 let socket = new_default_socket2_tcp(domain)?;
217 let socket2_addr = SockAddr::from(local_address);
218 if socket.bind(&socket2_addr).is_err() {
219 return Ok(None);
220 }
221
222 Ok(Some(socket))
223}
224
225fn new_bound_shared_socket2_tcp(local_address: SocketAddr) -> io::Result<Option<Socket>> {
226 let domain = domain_for_address(local_address);
228 let socket = new_shared_socket2_tcp(domain)?;
229 let socket2_addr = SockAddr::from(local_address);
230 if socket.bind(&socket2_addr).is_err() {
231 return Ok(None);
232 }
233
234 Ok(Some(socket))
235}
236
237async fn nonblocking_connect(
240 socket: Socket,
241 addr: SocketAddr,
242 timeout_ms: u32,
243) -> io::Result<TimeoutOr<TcpStream>> {
244 socket.set_nonblocking(true)?;
246
247 let socket2_addr = socket2::SockAddr::from(addr);
249
250 match socket.connect(&socket2_addr) {
252 Ok(()) => Ok(()),
253 #[cfg(unix)]
254 Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
255 Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(()),
256 Err(e) => Err(e),
257 }?;
258 let async_stream = Async::new(std::net::TcpStream::from(socket))?;
259
260 timeout_or_try!(timeout(timeout_ms, async_stream.writable())
262 .await
263 .into_timeout_or()
264 .into_result()?);
265
266 let async_stream = match async_stream.get_ref().take_error()? {
268 None => Ok(async_stream),
269 Some(err) => Err(err),
270 }?;
271
272 cfg_if! {
274 if #[cfg(feature="rt-async-std")] {
275 Ok(TimeoutOr::value(TcpStream::from(async_stream.into_inner()?)))
276 } else if #[cfg(feature="rt-tokio")] {
277 Ok(TimeoutOr::value(TcpStream::from_std(async_stream.into_inner()?)?))
278 } else {
279 compile_error!("needs executor implementation");
280 }
281 }
282}
283
284#[must_use]
285pub fn domain_for_address(address: SocketAddr) -> core::ffi::c_int {
286 socket2::Domain::for_address(address).into()
287}
288
289cfg_if! {
291 if #[cfg(unix)] {
292 pub fn socket2_operation<S: std::os::fd::AsRawFd, F: FnOnce(&mut socket2::Socket) -> R, R>(
293 s: &S,
294 callback: F,
295 ) -> R {
296 use std::os::fd::{FromRawFd, IntoRawFd};
297 let mut s = unsafe { socket2::Socket::from_raw_fd(s.as_raw_fd()) };
298 let res = callback(&mut s);
299 let _ = s.into_raw_fd();
300 res
301 }
302 } else if #[cfg(windows)] {
303 pub fn socket2_operation<
304 S: std::os::windows::io::AsRawSocket,
305 F: FnOnce(&mut socket2::Socket) -> R,
306 R,
307 >(
308 s: &S,
309 callback: F,
310 ) -> R {
311 use std::os::windows::io::{FromRawSocket, IntoRawSocket};
312 let mut s = unsafe { socket2::Socket::from_raw_socket(s.as_raw_socket()) };
313 let res = callback(&mut s);
314 let _ = s.into_raw_socket();
315 res
316 }
317 } else {
318 #[compile_error("unimplemented")]
319 }
320}