tcp_stream_utils/
impl_async_io.rs

1use core::time::Duration;
2
3use async_io::Async;
4use socket2::{Socket, TcpKeepalive};
5use std::{
6    io::Error as IoError,
7    net::{TcpStream as StdTcpStream, TcpStream},
8};
9
10#[cfg(unix)]
11use std::os::fd::{FromRawFd as _, IntoRawFd as _};
12#[cfg(windows)]
13use std::os::windows::{FromRawSocket as _, IntoRawSocket as _};
14
15//
16#[cfg(any(unix, windows))]
17pub fn tcp_stream_configure_keepalive(
18    tcp_stream: Async<TcpStream>,
19    time: Option<Duration>,
20    interval: Option<Duration>,
21    retries: Option<u32>,
22) -> Result<Async<TcpStream>, IoError> {
23    let tcp_keepalive = TcpKeepalive::new();
24
25    let tcp_keepalive = if let Some(time) = time {
26        tcp_keepalive.with_time(time)
27    } else {
28        tcp_keepalive
29    };
30
31    let tcp_keepalive = if let Some(interval) = interval {
32        tcp_keepalive.with_interval(interval)
33    } else {
34        tcp_keepalive
35    };
36
37    #[allow(unused_variables)]
38    let tcp_keepalive = if let Some(retries) = retries {
39        #[cfg(windows)]
40        {
41            tcp_keepalive.with_retries(retries)
42        }
43        #[cfg(unix)]
44        {
45            tcp_keepalive
46        }
47    } else {
48        tcp_keepalive
49    };
50
51    tcp_stream_configure(tcp_stream, move |socket| {
52        socket.set_keepalive(true)?;
53        socket.set_tcp_keepalive(&tcp_keepalive)?;
54        Ok(socket)
55    })
56}
57
58//
59#[cfg(any(unix, windows))]
60pub fn tcp_stream_configure<F>(
61    tcp_stream: Async<TcpStream>,
62    f: F,
63) -> Result<Async<TcpStream>, IoError>
64where
65    F: Fn(Socket) -> Result<Socket, IoError>,
66{
67    let std_tcp_stream = tcp_stream.into_inner()?;
68
69    #[cfg(unix)]
70    let socket = unsafe { Socket::from_raw_fd(std_tcp_stream.into_raw_fd()) };
71    #[cfg(windows)]
72    let socket = unsafe { Socket::from_raw_socket(std_tcp_stream.into_raw_socket()) };
73
74    let socket = f(socket)?;
75
76    #[cfg(unix)]
77    let std_tcp_stream = unsafe { StdTcpStream::from_raw_fd(socket.into_raw_fd()) };
78    #[cfg(windows)]
79    let std_tcp_stream = unsafe { StdTcpStream::from_raw_socket(socket.into_raw_socket()) };
80
81    Async::new(std_tcp_stream)
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    use std::net::ToSocketAddrs as _;

    /// Smoke test: configuring keepalive on a live connection must succeed.
    ///
    /// The test needs outbound network access; it returns early (passing) when the
    /// host cannot be resolved or the connection cannot be established.
    #[cfg(any(unix, windows))]
    #[tokio::test]
    async fn test_tcp_stream_configure_keepalive() {
        // Treat a failed or empty DNS lookup like a failed connect: skip rather
        // than panic on machines without network access.
        let Some(addr) = "google.com:443"
            .to_socket_addrs()
            .ok()
            .and_then(|mut addrs| addrs.next())
        else {
            return;
        };

        let Ok(tcp_stream) = Async::<TcpStream>::connect(addr).await else {
            return;
        };

        if let Err(err) =
            tcp_stream_configure_keepalive(tcp_stream, Some(Duration::from_secs(15)), None, None)
        {
            panic!("{err}");
        }
    }
}