1use std::{
5 net::{IpAddr, Ipv4Addr, Ipv6Addr},
6 str::FromStr,
7};
8
9use http::{HeaderMap, HeaderName};
10use ipnet::{IpNet, Ipv4Net, Ipv6Net};
11use motore::{Service, layer::Layer};
12use volo::{context::Context, net::Address};
13
14use crate::{context::ServerContext, request::Request};
15
16#[derive(Clone, Debug, Default)]
20pub struct ClientIpLayer {
21 config: ClientIpConfig,
22}
23
24impl ClientIpLayer {
25 pub fn new() -> Self {
27 Default::default()
28 }
29
30 pub fn with_config(self, config: ClientIpConfig) -> Self {
32 Self { config }
33 }
34}
35
36impl<S> Layer<S> for ClientIpLayer
37where
38 S: Send + Sync + 'static,
39{
40 type Service = ClientIpService<S>;
41
42 fn layer(self, inner: S) -> Self::Service {
43 ClientIpService {
44 service: inner,
45 config: self.config,
46 }
47 }
48}
49
50#[derive(Clone, Debug)]
52pub struct ClientIpConfig {
53 remote_ip_headers: Vec<HeaderName>,
54 trusted_cidrs: Vec<IpNet>,
55}
56
57impl Default for ClientIpConfig {
58 fn default() -> Self {
59 Self {
60 remote_ip_headers: vec![
61 HeaderName::from_static("x-real-ip"),
62 HeaderName::from_static("x-forwarded-for"),
63 ],
64 trusted_cidrs: vec![
65 IpNet::V4(Ipv4Net::new_assert(Ipv4Addr::new(0, 0, 0, 0), 0)),
66 IpNet::V6(Ipv6Net::new_assert(
67 Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
68 0,
69 )),
70 ],
71 }
72 }
73}
74
75impl ClientIpConfig {
76 pub fn new() -> Self {
82 Default::default()
83 }
84
85 pub fn with_remote_ip_headers<I>(
98 self,
99 headers: I,
100 ) -> Result<Self, http::header::InvalidHeaderName>
101 where
102 I: IntoIterator,
103 I::Item: AsRef<str>,
104 {
105 let headers = headers.into_iter().collect::<Vec<_>>();
106 let mut remote_ip_headers = Vec::with_capacity(headers.len());
107 for header_str in headers {
108 let header_value = HeaderName::from_str(header_str.as_ref())?;
109 remote_ip_headers.push(header_value);
110 }
111
112 Ok(Self {
113 remote_ip_headers,
114 trusted_cidrs: self.trusted_cidrs,
115 })
116 }
117
118 pub fn with_trusted_cidrs<H>(self, cidrs: H) -> Self
131 where
132 H: IntoIterator<Item = IpNet>,
133 {
134 Self {
135 remote_ip_headers: self.remote_ip_headers,
136 trusted_cidrs: cidrs.into_iter().collect(),
137 }
138 }
139}
140
/// The client IP resolved for the current request.
///
/// [`ClientIpService`] stores a value of this type in the request context's extensions. The
/// inner value is `None` in two cases:
/// - the connecting peer's IP is not trusted, or
/// - no IP could be determined.
///
/// The type wraps a plain `Option<IpAddr>`, so it is `Copy` and `Hash`. That makes it cheap to
/// pass by value and usable as a map key (for example, keying per-client state).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ClientIp(pub Option<IpAddr>);
204
205#[derive(Clone, Debug)]
209pub struct ClientIpService<S> {
210 service: S,
211 config: ClientIpConfig,
212}
213
214impl<S> ClientIpService<S> {
215 fn get_client_ip(&self, cx: &ServerContext, headers: &HeaderMap) -> ClientIp {
216 let remote_ip = match &cx.rpc_info().caller().address {
217 Some(Address::Ip(socket_addr)) => Some(socket_addr.ip()),
218 #[cfg(target_family = "unix")]
219 Some(Address::Unix(_)) => None,
220 #[allow(unreachable_patterns)]
221 Some(_) => unimplemented!("unsupported type of address"),
222 None => return ClientIp(None),
223 };
224
225 if let Some(remote_ip) = &remote_ip {
226 if !self
227 .config
228 .trusted_cidrs
229 .iter()
230 .any(|cidr| cidr.contains(remote_ip))
231 {
232 return ClientIp(None);
233 }
234 }
235
236 for remote_ip_header in self.config.remote_ip_headers.iter() {
237 let Some(remote_ips) = headers.get(remote_ip_header).and_then(|v| v.to_str().ok())
238 else {
239 continue;
240 };
241 for remote_ip in remote_ips.split(',').map(str::trim) {
242 if let Ok(remote_ip_addr) = IpAddr::from_str(remote_ip) {
243 if self
244 .config
245 .trusted_cidrs
246 .iter()
247 .any(|cidr| cidr.contains(&remote_ip_addr))
248 {
249 return ClientIp(Some(remote_ip_addr));
250 }
251 }
252 }
253 }
254
255 ClientIp(remote_ip)
256 }
257}
258
259impl<S, B> Service<ServerContext, Request<B>> for ClientIpService<S>
260where
261 S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
262 B: Send,
263{
264 type Response = S::Response;
265 type Error = S::Error;
266
267 async fn call(
268 &self,
269 cx: &mut ServerContext,
270 req: Request<B>,
271 ) -> Result<Self::Response, Self::Error> {
272 let client_ip = self.get_client_ip(cx, req.headers());
273 cx.extensions_mut().insert(client_ip);
274
275 self.service.call(cx, req).await
276 }
277}
278
#[cfg(test)]
mod client_ip_tests {
    use std::{net::SocketAddr, str::FromStr};

    use http::{HeaderValue, Method};
    use motore::{Service, layer::Layer};
    use volo::net::Address;

    use crate::{
        body::BodyConversion,
        context::ServerContext,
        server::{
            route::{Route, get},
            utils::client_ip::{ClientIp, ClientIpConfig, ClientIpLayer},
        },
        utils::test_helpers::simple_req,
    };

    #[tokio::test]
    async fn test_client_ip() {
        async fn handler(ClientIp(client_ip): ClientIp) -> String {
            client_ip.unwrap().to_string()
        }

        // Only 10.0.0.0/8 is trusted, so header IPs outside it are ignored.
        let config =
            ClientIpConfig::default().with_trusted_cidrs(vec!["10.0.0.0/8".parse().unwrap()]);
        let route: Route<&str> = Route::new(get(handler));
        let service = ClientIpLayer::new().with_config(config).layer(route);

        let peer = SocketAddr::from_str("10.0.0.1:8080").unwrap();
        let mut cx = ServerContext::new(Address::from(peer));

        // (optional header to set on the request, expected client IP)
        let cases = [
            // No header: fall back to the trusted peer address.
            (None, "10.0.0.1"),
            (Some(("X-Real-IP", "10.0.0.2")), "10.0.0.2"),
            (Some(("X-Forwarded-For", "10.0.1.0")), "10.0.1.0"),
            // Untrusted header IP: fall back to the peer address.
            (Some(("X-Real-IP", "11.0.0.1")), "10.0.0.1"),
        ];

        for (header, expected) in cases {
            let mut req = simple_req(Method::GET, "/", "");
            if let Some((name, value)) = header {
                req.headers_mut()
                    .insert(name, HeaderValue::from_static(value));
            }
            let resp = service.call(&mut cx, req).await.unwrap();
            assert_eq!(expected, resp.into_string().await.unwrap());
        }
    }
}