xitca_http/util/middleware/
socket_config.rs1use core::{net::SocketAddr, time::Duration};
2
3use std::io;
4
5use socket2::{SockRef, TcpKeepalive};
6
7use tracing::warn;
8use xitca_io::net::{Stream as ServerStream, TcpStream};
9use xitca_service::{Service, ready::ReadyService};
10
11#[cfg(unix)]
12use xitca_io::net::UnixStream;
13
14#[derive(Clone, Debug)]
16pub struct SocketConfig {
17 ka: Option<TcpKeepalive>,
18 nodelay: bool,
19}
20
21impl Default for SocketConfig {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl SocketConfig {
28 pub const fn new() -> Self {
29 Self {
30 ka: None,
31 nodelay: false,
32 }
33 }
34
35 pub fn set_nodelay(mut self, value: bool) -> Self {
39 self.nodelay = value;
40 self
41 }
42
43 pub fn keep_alive_with_time(mut self, time: Duration) -> Self {
47 self.ka = Some(self.ka.unwrap_or_else(TcpKeepalive::new).with_time(time));
48 self
49 }
50
51 pub fn keep_alive_with_interval(mut self, time: Duration) -> Self {
55 self.ka = Some(self.ka.unwrap_or_else(TcpKeepalive::new).with_interval(time));
56 self
57 }
58
59 #[cfg(not(windows))]
60 pub fn keep_alive_with_retries(mut self, retries: u32) -> Self {
64 self.ka = Some(self.ka.unwrap_or_else(TcpKeepalive::new).with_retries(retries));
65 self
66 }
67}
68
69impl<S, E> Service<Result<S, E>> for SocketConfig {
70 type Response = SocketConfigService<S>;
71 type Error = E;
72
73 async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
74 res.map(|service| SocketConfigService {
75 config: self.clone(),
76 service,
77 })
78 }
79}
80
81impl<S> ReadyService for SocketConfigService<S>
82where
83 S: ReadyService,
84{
85 type Ready = S::Ready;
86
87 #[inline]
88 async fn ready(&self) -> Self::Ready {
89 self.service.ready().await
90 }
91}
92
93impl<S> Service<(TcpStream, SocketAddr)> for SocketConfigService<S>
94where
95 S: Service<(TcpStream, SocketAddr)>,
96{
97 type Response = S::Response;
98 type Error = S::Error;
99
100 async fn call(&self, (stream, addr): (TcpStream, SocketAddr)) -> Result<Self::Response, Self::Error> {
101 self.try_apply_config(&stream);
102 self.service.call((stream, addr)).await
103 }
104}
105
106#[cfg(unix)]
107impl<S> Service<(UnixStream, SocketAddr)> for SocketConfigService<S>
108where
109 S: Service<(UnixStream, SocketAddr)>,
110{
111 type Response = S::Response;
112 type Error = S::Error;
113
114 async fn call(&self, (stream, addr): (UnixStream, SocketAddr)) -> Result<Self::Response, Self::Error> {
115 self.try_apply_config(&stream);
116 self.service.call((stream, addr)).await
117 }
118}
119
120impl<S> Service<ServerStream> for SocketConfigService<S>
121where
122 S: Service<ServerStream>,
123{
124 type Response = S::Response;
125 type Error = S::Error;
126
127 #[inline]
128 async fn call(&self, stream: ServerStream) -> Result<Self::Response, Self::Error> {
129 #[cfg_attr(windows, allow(irrefutable_let_patterns))]
130 if let ServerStream::Tcp(ref tcp, _) = stream {
131 self.try_apply_config(tcp)
132 };
133
134 #[cfg(unix)]
135 if let ServerStream::Unix(ref unix, _) = stream {
136 self.try_apply_config(unix)
137 };
138
139 self.service.call(stream).await
140 }
141}
142
143pub struct SocketConfigService<S> {
144 config: SocketConfig,
145 service: S,
146}
147
148impl<S> SocketConfigService<S> {
149 fn apply_config<'s>(&self, stream: impl Into<SockRef<'s>>) -> io::Result<()> {
150 let stream_ref = stream.into();
151
152 stream_ref.set_nodelay(self.config.nodelay)?;
153
154 if let Some(ka) = self.config.ka.as_ref() {
155 stream_ref.set_tcp_keepalive(ka)?;
156 }
157
158 Ok(())
159 }
160
161 fn try_apply_config<'s>(&self, stream: impl Into<SockRef<'s>>) {
162 if let Err(e) = self.apply_config(stream) {
163 warn!(target: "SocketConfig", "Failed to apply configuration to SocketConfig. {:?}", e);
164 };
165 }
166}