xitca_http/util/middleware/
socket_config.rs

1use core::{net::SocketAddr, time::Duration};
2
3use std::io;
4
5use socket2::{SockRef, TcpKeepalive};
6
7use tracing::warn;
8use xitca_io::net::{Stream as ServerStream, TcpStream};
9use xitca_service::{Service, ready::ReadyService};
10
11#[cfg(unix)]
12use xitca_io::net::UnixStream;
13
14/// A middleware for socket options config of `TcpStream` and `UnixStream`.
15#[derive(Clone, Debug)]
16pub struct SocketConfig {
17    ka: Option<TcpKeepalive>,
18    nodelay: bool,
19}
20
21impl Default for SocketConfig {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl SocketConfig {
28    pub const fn new() -> Self {
29        Self {
30            ka: None,
31            nodelay: false,
32        }
33    }
34
35    /// For more information about this option, see [`set_nodelay`].
36    ///
37    /// [`set_nodelay`]: socket2::Socket::set_nodelay
38    pub fn set_nodelay(mut self, value: bool) -> Self {
39        self.nodelay = value;
40        self
41    }
42
43    /// For more information about this option, see [`with_time`].
44    ///
45    /// [`with_time`]: TcpKeepalive::with_time
46    pub fn keep_alive_with_time(mut self, time: Duration) -> Self {
47        self.ka = Some(self.ka.unwrap_or_else(TcpKeepalive::new).with_time(time));
48        self
49    }
50
51    /// For more information about this option, see [`with_interval`].
52    ///
53    /// [`with_interval`]: TcpKeepalive::with_interval
54    pub fn keep_alive_with_interval(mut self, time: Duration) -> Self {
55        self.ka = Some(self.ka.unwrap_or_else(TcpKeepalive::new).with_interval(time));
56        self
57    }
58
59    #[cfg(not(windows))]
60    /// For more information about this option, see [`with_retries`].
61    ///
62    /// [`with_retries`]: TcpKeepalive::with_retries
63    pub fn keep_alive_with_retries(mut self, retries: u32) -> Self {
64        self.ka = Some(self.ka.unwrap_or_else(TcpKeepalive::new).with_retries(retries));
65        self
66    }
67}
68
69impl<S, E> Service<Result<S, E>> for SocketConfig {
70    type Response = SocketConfigService<S>;
71    type Error = E;
72
73    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
74        res.map(|service| SocketConfigService {
75            config: self.clone(),
76            service,
77        })
78    }
79}
80
81impl<S> ReadyService for SocketConfigService<S>
82where
83    S: ReadyService,
84{
85    type Ready = S::Ready;
86
87    #[inline]
88    async fn ready(&self) -> Self::Ready {
89        self.service.ready().await
90    }
91}
92
93impl<S> Service<(TcpStream, SocketAddr)> for SocketConfigService<S>
94where
95    S: Service<(TcpStream, SocketAddr)>,
96{
97    type Response = S::Response;
98    type Error = S::Error;
99
100    async fn call(&self, (stream, addr): (TcpStream, SocketAddr)) -> Result<Self::Response, Self::Error> {
101        self.try_apply_config(&stream);
102        self.service.call((stream, addr)).await
103    }
104}
105
106#[cfg(unix)]
107impl<S> Service<(UnixStream, SocketAddr)> for SocketConfigService<S>
108where
109    S: Service<(UnixStream, SocketAddr)>,
110{
111    type Response = S::Response;
112    type Error = S::Error;
113
114    async fn call(&self, (stream, addr): (UnixStream, SocketAddr)) -> Result<Self::Response, Self::Error> {
115        self.try_apply_config(&stream);
116        self.service.call((stream, addr)).await
117    }
118}
119
120impl<S> Service<ServerStream> for SocketConfigService<S>
121where
122    S: Service<ServerStream>,
123{
124    type Response = S::Response;
125    type Error = S::Error;
126
127    #[inline]
128    async fn call(&self, stream: ServerStream) -> Result<Self::Response, Self::Error> {
129        #[cfg_attr(windows, allow(irrefutable_let_patterns))]
130        if let ServerStream::Tcp(ref tcp, _) = stream {
131            self.try_apply_config(tcp)
132        };
133
134        #[cfg(unix)]
135        if let ServerStream::Unix(ref unix, _) = stream {
136            self.try_apply_config(unix)
137        };
138
139        self.service.call(stream).await
140    }
141}
142
143pub struct SocketConfigService<S> {
144    config: SocketConfig,
145    service: S,
146}
147
148impl<S> SocketConfigService<S> {
149    fn apply_config<'s>(&self, stream: impl Into<SockRef<'s>>) -> io::Result<()> {
150        let stream_ref = stream.into();
151
152        stream_ref.set_nodelay(self.config.nodelay)?;
153
154        if let Some(ka) = self.config.ka.as_ref() {
155            stream_ref.set_tcp_keepalive(ka)?;
156        }
157
158        Ok(())
159    }
160
161    fn try_apply_config<'s>(&self, stream: impl Into<SockRef<'s>>) {
162        if let Err(e) = self.apply_config(stream) {
163            warn!(target: "SocketConfig", "Failed to apply configuration to SocketConfig. {:?}", e);
164        };
165    }
166}