tcp_stream_utils/
impl_tokio.rs

1use core::time::Duration;
2
3use socket2::{Socket, TcpKeepalive};
4use std::{io::Error as IoError, net::TcpStream as StdTcpStream};
5use tokio::net::TcpStream;
6
7#[cfg(unix)]
8use std::os::fd::{FromRawFd as _, IntoRawFd as _};
9#[cfg(windows)]
10use std::os::windows::{FromRawSocket as _, IntoRawSocket as _};
11
12//
13#[cfg(any(unix, windows))]
14pub fn tcp_stream_configure_keepalive(
15    tcp_stream: TcpStream,
16    time: Option<Duration>,
17    interval: Option<Duration>,
18    retries: Option<u32>,
19) -> Result<TcpStream, IoError> {
20    let tcp_keepalive = TcpKeepalive::new();
21
22    let tcp_keepalive = if let Some(time) = time {
23        tcp_keepalive.with_time(time)
24    } else {
25        tcp_keepalive
26    };
27
28    let tcp_keepalive = if let Some(interval) = interval {
29        tcp_keepalive.with_interval(interval)
30    } else {
31        tcp_keepalive
32    };
33
34    #[allow(unused_variables)]
35    let tcp_keepalive = if let Some(retries) = retries {
36        #[cfg(windows)]
37        {
38            tcp_keepalive.with_retries(retries)
39        }
40        #[cfg(unix)]
41        {
42            tcp_keepalive
43        }
44    } else {
45        tcp_keepalive
46    };
47
48    tcp_stream_configure(tcp_stream, move |socket| {
49        socket.set_keepalive(true)?;
50        socket.set_tcp_keepalive(&tcp_keepalive)?;
51        Ok(socket)
52    })
53}
54
55//
56#[cfg(any(unix, windows))]
57pub fn tcp_stream_configure<F>(tcp_stream: TcpStream, f: F) -> Result<TcpStream, IoError>
58where
59    F: Fn(Socket) -> Result<Socket, IoError>,
60{
61    let std_tcp_stream = tcp_stream.into_std()?;
62
63    #[cfg(unix)]
64    let socket = unsafe { Socket::from_raw_fd(std_tcp_stream.into_raw_fd()) };
65    #[cfg(windows)]
66    let socket = unsafe { Socket::from_raw_socket(std_tcp_stream.into_raw_socket()) };
67
68    let socket = f(socket)?;
69
70    #[cfg(unix)]
71    let std_tcp_stream = unsafe { StdTcpStream::from_raw_fd(socket.into_raw_fd()) };
72    #[cfg(windows)]
73    let std_tcp_stream = unsafe { StdTcpStream::from_raw_socket(socket.into_raw_socket()) };
74
75    TcpStream::from_std(std_tcp_stream)
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    // These tests use loopback connections so they are hermetic: they need no DNS or
    // internet access, and a connection failure fails the test instead of skipping it.

    /// Connects a client stream to a fresh loopback listener.
    ///
    /// The listener is returned as well so the caller keeps it alive for the whole test.
    #[cfg(any(unix, windows))]
    async fn loopback_pair() -> (tokio::net::TcpListener, TcpStream) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind loopback listener");
        let addr = listener.local_addr().expect("loopback listener address");
        let tcp_stream = TcpStream::connect(addr)
            .await
            .expect("connect to loopback listener");
        (listener, tcp_stream)
    }

    /// Asserts that `tcp_stream` is still connected to `listener` and has `SO_KEEPALIVE` set.
    #[cfg(any(unix, windows))]
    fn assert_keepalive_enabled(tcp_stream: TcpStream, listener: &tokio::net::TcpListener) {
        let expected_peer = listener.local_addr().expect("loopback listener address");
        assert_eq!(tcp_stream.peer_addr().expect("peer address"), expected_peer);

        tcp_stream_configure(tcp_stream, |socket| {
            assert!(socket.keepalive()?, "SO_KEEPALIVE should be enabled");
            Ok(socket)
        })
        .expect("inspect configured socket");
    }

    #[cfg(any(unix, windows))]
    #[tokio::test]
    async fn test_tcp_stream_configure_keepalive() {
        let (listener, tcp_stream) = loopback_pair().await;

        let tcp_stream = match tcp_stream_configure_keepalive(
            tcp_stream,
            Some(Duration::from_secs(15)),
            None,
            None,
        ) {
            Ok(x) => x,
            Err(err) => panic!("{err}"),
        };

        assert_keepalive_enabled(tcp_stream, &listener);
    }

    #[cfg(any(unix, windows))]
    #[tokio::test]
    async fn test_tcp_stream_configure_keepalive_all_parameters() {
        let (listener, tcp_stream) = loopback_pair().await;

        let tcp_stream = match tcp_stream_configure_keepalive(
            tcp_stream,
            Some(Duration::from_secs(15)),
            Some(Duration::from_secs(5)),
            Some(3),
        ) {
            Ok(x) => x,
            Err(err) => panic!("{err}"),
        };

        assert_keepalive_enabled(tcp_stream, &listener);
    }
}
96}