volo_http/server/utils/
client_ip.rs

//! Utilities for extracting the original client IP address.
//!
//! See [`ClientIp`] for more details.
4use std::{
5    net::{IpAddr, Ipv4Addr, Ipv6Addr},
6    str::FromStr,
7};
8
9use http::{HeaderMap, HeaderName};
10use ipnet::{IpNet, Ipv4Net, Ipv6Net};
11use motore::{Service, layer::Layer};
12use volo::{context::Context, net::Address};
13
14use crate::{context::ServerContext, request::Request};
15
16/// [`Layer`] for extracting client ip
17///
18/// See [`ClientIp`] for more details.
19#[derive(Clone, Debug, Default)]
20pub struct ClientIpLayer {
21    config: ClientIpConfig,
22}
23
24impl ClientIpLayer {
25    /// Create a new [`ClientIpLayer`] with default config
26    pub fn new() -> Self {
27        Default::default()
28    }
29
30    /// Create a new [`ClientIpLayer`] with the given [`ClientIpConfig`]
31    pub fn with_config(self, config: ClientIpConfig) -> Self {
32        Self { config }
33    }
34}
35
36impl<S> Layer<S> for ClientIpLayer
37where
38    S: Send + Sync + 'static,
39{
40    type Service = ClientIpService<S>;
41
42    fn layer(self, inner: S) -> Self::Service {
43        ClientIpService {
44            service: inner,
45            config: self.config,
46        }
47    }
48}
49
50/// Config for extract client ip
51#[derive(Clone, Debug)]
52pub struct ClientIpConfig {
53    remote_ip_headers: Vec<HeaderName>,
54    trusted_cidrs: Vec<IpNet>,
55}
56
57impl Default for ClientIpConfig {
58    fn default() -> Self {
59        Self {
60            remote_ip_headers: vec![
61                HeaderName::from_static("x-real-ip"),
62                HeaderName::from_static("x-forwarded-for"),
63            ],
64            trusted_cidrs: vec![
65                IpNet::V4(Ipv4Net::new_assert(Ipv4Addr::new(0, 0, 0, 0), 0)),
66                IpNet::V6(Ipv6Net::new_assert(
67                    Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
68                    0,
69                )),
70            ],
71        }
72    }
73}
74
75impl ClientIpConfig {
76    /// Create a new [`ClientIpConfig`] with default values
77    ///
78    /// default remote ip headers: `["X-Real-IP", "X-Forwarded-For"]`
79    ///
80    /// default trusted cidrs: `["0.0.0.0/0", "::/0"]`
81    pub fn new() -> Self {
82        Default::default()
83    }
84
85    /// Get Real Client IP by parsing the given headers.
86    ///
87    /// See [`ClientIp`] for more details.
88    ///
89    /// # Example
90    ///
91    /// ```rust
92    /// use volo_http::server::utils::client_ip::ClientIpConfig;
93    ///
94    /// let client_ip_config =
95    ///     ClientIpConfig::new().with_remote_ip_headers(vec!["X-Real-IP", "X-Forwarded-For"]);
96    /// ```
97    pub fn with_remote_ip_headers<I>(
98        self,
99        headers: I,
100    ) -> Result<Self, http::header::InvalidHeaderName>
101    where
102        I: IntoIterator,
103        I::Item: AsRef<str>,
104    {
105        let headers = headers.into_iter().collect::<Vec<_>>();
106        let mut remote_ip_headers = Vec::with_capacity(headers.len());
107        for header_str in headers {
108            let header_value = HeaderName::from_str(header_str.as_ref())?;
109            remote_ip_headers.push(header_value);
110        }
111
112        Ok(Self {
113            remote_ip_headers,
114            trusted_cidrs: self.trusted_cidrs,
115        })
116    }
117
118    /// Get Real Client IP if it is trusted, otherwise it will just return caller ip.
119    ///
120    /// See [`ClientIp`] for more details.
121    ///
122    /// # Example
123    ///
124    /// ```rust
125    /// use volo_http::server::utils::client_ip::ClientIpConfig;
126    ///
127    /// let client_ip_config = ClientIpConfig::new()
128    ///     .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]);
129    /// ```
130    pub fn with_trusted_cidrs<H>(self, cidrs: H) -> Self
131    where
132        H: IntoIterator<Item = IpNet>,
133    {
134        Self {
135            remote_ip_headers: self.remote_ip_headers,
136            trusted_cidrs: cidrs.into_iter().collect(),
137        }
138    }
139}
140
/// Original client IP address of a request.
///
/// [`ClientIpLayer`] resolves this value for every request and inserts it into the request
/// context's extensions; handlers then extract it as shown in the examples below. The inner
/// value is `None` when no address could be determined, for example when the connection has no
/// peer address.
///
/// To choose which headers carry the forwarded client IP, use
/// [`with_remote_ip_headers`](ClientIpConfig::with_remote_ip_headers).
///
/// To choose which networks are trusted, use
/// [`with_trusted_cidrs`](ClientIpConfig::with_trusted_cidrs).
///
/// # Example
///
/// ## Default config
///
/// default remote ip headers: `["X-Real-IP", "X-Forwarded-For"]`
///
/// default trusted cidrs: `["0.0.0.0/0", "::/0"]`
///
/// ```rust
/// use volo_http::server::{
///     route::{Router, get},
///     utils::client_ip::{ClientIp, ClientIpLayer},
/// };
///
/// async fn handler(ClientIp(client_ip): ClientIp) -> String {
///     client_ip.map(|ip| ip.to_string()).unwrap_or_default()
/// }
///
/// let router: Router = Router::new()
///     .route("/", get(handler))
///     .layer(ClientIpLayer::new());
/// ```
///
/// ## With custom config
///
/// ```rust
/// use volo_http::server::{
///     route::{Router, get},
///     utils::client_ip::{ClientIp, ClientIpConfig, ClientIpLayer},
/// };
///
/// async fn handler(ClientIp(client_ip): ClientIp) -> String {
///     client_ip.map(|ip| ip.to_string()).unwrap_or_default()
/// }
///
/// let router: Router = Router::new().route("/", get(handler)).layer(
///     ClientIpLayer::new().with_config(
///         ClientIpConfig::new()
///             .with_remote_ip_headers(vec!["x-real-ip", "x-forwarded-for"])
///             .unwrap()
///             .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]),
///     ),
/// );
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClientIp(pub Option<IpAddr>);
204
205/// [`ClientIpLayer`] generated [`Service`]
206///
207/// See [`ClientIp`] for more details.
208#[derive(Clone, Debug)]
209pub struct ClientIpService<S> {
210    service: S,
211    config: ClientIpConfig,
212}
213
214impl<S> ClientIpService<S> {
215    fn get_client_ip(&self, cx: &ServerContext, headers: &HeaderMap) -> ClientIp {
216        let remote_ip = match &cx.rpc_info().caller().address {
217            Some(Address::Ip(socket_addr)) => Some(socket_addr.ip()),
218            #[cfg(target_family = "unix")]
219            Some(Address::Unix(_)) => None,
220            None => return ClientIp(None),
221        };
222
223        if let Some(remote_ip) = &remote_ip {
224            if !self
225                .config
226                .trusted_cidrs
227                .iter()
228                .any(|cidr| cidr.contains(remote_ip))
229            {
230                return ClientIp(None);
231            }
232        }
233
234        for remote_ip_header in self.config.remote_ip_headers.iter() {
235            let Some(remote_ips) = headers.get(remote_ip_header).and_then(|v| v.to_str().ok())
236            else {
237                continue;
238            };
239            for remote_ip in remote_ips.split(',').map(str::trim) {
240                if let Ok(remote_ip_addr) = IpAddr::from_str(remote_ip) {
241                    if self
242                        .config
243                        .trusted_cidrs
244                        .iter()
245                        .any(|cidr| cidr.contains(&remote_ip_addr))
246                    {
247                        return ClientIp(Some(remote_ip_addr));
248                    }
249                }
250            }
251        }
252
253        ClientIp(remote_ip)
254    }
255}
256
257impl<S, B> Service<ServerContext, Request<B>> for ClientIpService<S>
258where
259    S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
260    B: Send,
261{
262    type Response = S::Response;
263    type Error = S::Error;
264
265    async fn call(
266        &self,
267        cx: &mut ServerContext,
268        req: Request<B>,
269    ) -> Result<Self::Response, Self::Error> {
270        let client_ip = self.get_client_ip(cx, req.headers());
271        cx.extensions_mut().insert(client_ip);
272
273        self.service.call(cx, req).await
274    }
275}
276
277#[cfg(test)]
278mod client_ip_tests {
279    use std::{net::SocketAddr, str::FromStr};
280
281    use http::{HeaderValue, Method};
282    use motore::{Service, layer::Layer};
283    use volo::net::Address;
284
285    use crate::{
286        body::BodyConversion,
287        context::ServerContext,
288        server::{
289            route::{Route, get},
290            utils::client_ip::{ClientIp, ClientIpConfig, ClientIpLayer},
291        },
292        utils::test_helpers::simple_req,
293    };
294
295    #[tokio::test]
296    async fn test_client_ip() {
297        async fn handler(ClientIp(client_ip): ClientIp) -> String {
298            client_ip.unwrap().to_string()
299        }
300
301        let route: Route<&str> = Route::new(get(handler));
302        let service = ClientIpLayer::new()
303            .with_config(
304                ClientIpConfig::default().with_trusted_cidrs(vec!["10.0.0.0/8".parse().unwrap()]),
305            )
306            .layer(route);
307
308        let mut cx = ServerContext::new(Address::from(
309            SocketAddr::from_str("10.0.0.1:8080").unwrap(),
310        ));
311
312        // Test case 1: no remote ip header
313        let req = simple_req(Method::GET, "/", "");
314        let resp = service.call(&mut cx, req).await.unwrap();
315        assert_eq!("10.0.0.1", resp.into_string().await.unwrap());
316
317        // Test case 2: with remote ip header
318        let mut req = simple_req(Method::GET, "/", "");
319        req.headers_mut()
320            .insert("X-Real-IP", HeaderValue::from_static("10.0.0.2"));
321        let resp = service.call(&mut cx, req).await.unwrap();
322        assert_eq!("10.0.0.2", resp.into_string().await.unwrap());
323
324        let mut req = simple_req(Method::GET, "/", "");
325        req.headers_mut()
326            .insert("X-Forwarded-For", HeaderValue::from_static("10.0.1.0"));
327        let resp = service.call(&mut cx, req).await.unwrap();
328        assert_eq!("10.0.1.0", resp.into_string().await.unwrap());
329
330        // Test case 3: with untrusted remote ip
331        let mut req = simple_req(Method::GET, "/", "");
332        req.headers_mut()
333            .insert("X-Real-IP", HeaderValue::from_static("11.0.0.1"));
334        let resp = service.call(&mut cx, req).await.unwrap();
335        assert_eq!("10.0.0.1", resp.into_string().await.unwrap());
336    }
337}