tcp_stream_utils/
impl_async_io.rs1use core::time::Duration;
2
3use async_io::Async;
4use socket2::{Socket, TcpKeepalive};
5use std::{
6 io::Error as IoError,
7 net::{TcpStream as StdTcpStream, TcpStream},
8};
9
10#[cfg(unix)]
11use std::os::fd::{FromRawFd as _, IntoRawFd as _};
12#[cfg(windows)]
13use std::os::windows::{FromRawSocket as _, IntoRawSocket as _};
14
15#[cfg(any(unix, windows))]
17pub fn tcp_stream_configure_keepalive(
18 tcp_stream: Async<TcpStream>,
19 time: Option<Duration>,
20 interval: Option<Duration>,
21 retries: Option<u32>,
22) -> Result<Async<TcpStream>, IoError> {
23 let tcp_keepalive = TcpKeepalive::new();
24
25 let tcp_keepalive = if let Some(time) = time {
26 tcp_keepalive.with_time(time)
27 } else {
28 tcp_keepalive
29 };
30
31 let tcp_keepalive = if let Some(interval) = interval {
32 tcp_keepalive.with_interval(interval)
33 } else {
34 tcp_keepalive
35 };
36
37 #[allow(unused_variables)]
38 let tcp_keepalive = if let Some(retries) = retries {
39 #[cfg(windows)]
40 {
41 tcp_keepalive.with_retries(retries)
42 }
43 #[cfg(unix)]
44 {
45 tcp_keepalive
46 }
47 } else {
48 tcp_keepalive
49 };
50
51 tcp_stream_configure(tcp_stream, move |socket| {
52 socket.set_keepalive(true)?;
53 socket.set_tcp_keepalive(&tcp_keepalive)?;
54 Ok(socket)
55 })
56}
57
58#[cfg(any(unix, windows))]
60pub fn tcp_stream_configure<F>(
61 tcp_stream: Async<TcpStream>,
62 f: F,
63) -> Result<Async<TcpStream>, IoError>
64where
65 F: Fn(Socket) -> Result<Socket, IoError>,
66{
67 let std_tcp_stream = tcp_stream.into_inner()?;
68
69 #[cfg(unix)]
70 let socket = unsafe { Socket::from_raw_fd(std_tcp_stream.into_raw_fd()) };
71 #[cfg(windows)]
72 let socket = unsafe { Socket::from_raw_socket(std_tcp_stream.into_raw_socket()) };
73
74 let socket = f(socket)?;
75
76 #[cfg(unix)]
77 let std_tcp_stream = unsafe { StdTcpStream::from_raw_fd(socket.into_raw_fd()) };
78 #[cfg(windows)]
79 let std_tcp_stream = unsafe { StdTcpStream::from_raw_socket(socket.into_raw_socket()) };
80
81 Async::new(std_tcp_stream)
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    use std::net::TcpListener;

    /// Checks that keepalive configuration succeeds and actually enables `SO_KEEPALIVE`.
    ///
    /// A loopback listener is used as the peer so the test needs no DNS or external
    /// network access and always runs, instead of silently skipping (or panicking on a
    /// failed DNS lookup) when offline.
    #[cfg(any(unix, windows))]
    #[tokio::test]
    async fn test_tcp_stream_configure_keepalive() {
        // Keep the listener alive for the whole test so the connection's peer exists.
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind loopback listener");
        let addr = listener.local_addr().expect("loopback listener address");

        let tcp_stream = Async::<TcpStream>::connect(addr)
            .await
            .expect("connect to loopback listener");

        let tcp_stream = tcp_stream_configure_keepalive(
            tcp_stream,
            Some(Duration::from_secs(15)),
            Some(Duration::from_secs(5)),
            None,
        )
        .unwrap_or_else(|err| panic!("{err}"));

        // The re-wrapped stream must really have keepalive switched on.
        let keepalive = socket2::SockRef::from(tcp_stream.get_ref())
            .keepalive()
            .expect("read SO_KEEPALIVE");
        assert!(keepalive);

        drop(listener);
    }
}