// tcp_stream_utils/impl_tokio.rs

use core::time::Duration;

use socket2::{Socket, TcpKeepalive};
use std::{io::Error as IoError, net::TcpStream as StdTcpStream};
use tokio::net::TcpStream;

#[cfg(unix)]
use std::os::fd::{FromRawFd as _, IntoRawFd as _};
// The raw-socket conversion traits live in `std::os::windows::io`, not directly in
// `std::os::windows`; the previous path did not resolve on Windows.
#[cfg(windows)]
use std::os::windows::io::{FromRawSocket as _, IntoRawSocket as _};
11
12#[cfg(any(unix, windows))]
14pub fn tcp_stream_configure_keepalive(
15 tcp_stream: TcpStream,
16 time: Option<Duration>,
17 interval: Option<Duration>,
18 retries: Option<u32>,
19) -> Result<TcpStream, IoError> {
20 let tcp_keepalive = TcpKeepalive::new();
21
22 let tcp_keepalive = if let Some(time) = time {
23 tcp_keepalive.with_time(time)
24 } else {
25 tcp_keepalive
26 };
27
28 let tcp_keepalive = if let Some(interval) = interval {
29 tcp_keepalive.with_interval(interval)
30 } else {
31 tcp_keepalive
32 };
33
34 #[allow(unused_variables)]
35 let tcp_keepalive = if let Some(retries) = retries {
36 #[cfg(windows)]
37 {
38 tcp_keepalive.with_retries(retries)
39 }
40 #[cfg(unix)]
41 {
42 tcp_keepalive
43 }
44 } else {
45 tcp_keepalive
46 };
47
48 tcp_stream_configure(tcp_stream, move |socket| {
49 socket.set_keepalive(true)?;
50 socket.set_tcp_keepalive(&tcp_keepalive)?;
51 Ok(socket)
52 })
53}
54
55#[cfg(any(unix, windows))]
57pub fn tcp_stream_configure<F>(tcp_stream: TcpStream, f: F) -> Result<TcpStream, IoError>
58where
59 F: Fn(Socket) -> Result<Socket, IoError>,
60{
61 let std_tcp_stream = tcp_stream.into_std()?;
62
63 #[cfg(unix)]
64 let socket = unsafe { Socket::from_raw_fd(std_tcp_stream.into_raw_fd()) };
65 #[cfg(windows)]
66 let socket = unsafe { Socket::from_raw_socket(std_tcp_stream.into_raw_socket()) };
67
68 let socket = f(socket)?;
69
70 #[cfg(unix)]
71 let std_tcp_stream = unsafe { StdTcpStream::from_raw_fd(socket.into_raw_fd()) };
72 #[cfg(windows)]
73 let std_tcp_stream = unsafe { StdTcpStream::from_raw_socket(socket.into_raw_socket()) };
74
75 TcpStream::from_std(std_tcp_stream)
76}
77
#[cfg(all(test, any(unix, windows)))]
mod tests {
    use super::*;
    use tokio::net::TcpListener;

    /// Opens a connected client/server pair over loopback, so the tests are hermetic and do not
    /// depend on outbound network access.
    ///
    /// The server half is returned so the caller can keep the peer open for the test's duration.
    async fn loopback_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind loopback listener");
        let addr = listener.local_addr().expect("read listener address");
        let client = TcpStream::connect(addr)
            .await
            .expect("connect to loopback listener");
        let (server, _) = listener.accept().await.expect("accept loopback connection");
        (client, server)
    }

    /// Asserts, by inspecting the socket through [`tcp_stream_configure`], that `SO_KEEPALIVE`
    /// is enabled on `stream`.
    fn assert_keepalive_enabled(stream: TcpStream) {
        tcp_stream_configure(stream, |socket| {
            assert!(socket.keepalive()?, "SO_KEEPALIVE should be enabled");
            Ok(socket)
        })
        .expect("inspect socket options");
    }

    #[tokio::test]
    async fn test_tcp_stream_configure_keepalive() {
        let (client, _server) = loopback_pair().await;

        let client =
            tcp_stream_configure_keepalive(client, Some(Duration::from_secs(15)), None, None)
                .expect("configure keepalive time");

        assert_keepalive_enabled(client);
    }

    #[tokio::test]
    async fn test_tcp_stream_configure_keepalive_with_interval() {
        let (client, _server) = loopback_pair().await;

        let client = tcp_stream_configure_keepalive(
            client,
            Some(Duration::from_secs(15)),
            Some(Duration::from_secs(5)),
            None,
        )
        .expect("configure keepalive time and interval");

        assert_keepalive_enabled(client);
    }
}