// tower_ipfilter/connection_info_service.rs

use std::{
    net::IpAddr,
    task::{Context, Poll},
};

use http::Request;
use tower::{Layer, Service};

/// Tower middleware that resolves the client IP address of each request and stores it as a
/// [`ConnectionInfo`] request extension before calling the wrapped service.
#[derive(Debug, Clone)]
pub struct AddConnectionInfo<S> {
    /// Service that receives every request once the extension has been attached.
    inner: S,
}

impl<S> AddConnectionInfo<S> {
    /// Creates the middleware around `inner`.
    pub fn new(inner: S) -> Self {
        AddConnectionInfo { inner }
    }
}

19impl<S, B> Service<Request<B>> for AddConnectionInfo<S>
20where
21    S: Service<Request<B>>,
22{
23    type Response = S::Response;
24    type Error = S::Error;
25    type Future = S::Future;
26
27    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
28        self.inner.poll_ready(cx)
29    }
30
31    fn call(&mut self, mut req: Request<B>) -> Self::Future {
32        if let Some(ip_addr) = extract_ip(&req) {
33            req.extensions_mut().insert(ConnectionInfo { ip_addr });
34        }
35        self.inner.call(req)
36    }
37}
38
39
40fn extract_ip<B>(req: &Request<B>) -> Option<IpAddr> {
41    cfg_if::cfg_if! {
42            if #[cfg(feature = "axum")] {
43                use axum_impl::extract_ip_axum;
44                return extract_ip_axum(&req)
45            } else if #[cfg(feature = "hyper")] {
46                use hyper_impl::extract_ip_hyper;
47                return extract_ip_hyper(&req)
48            } else {
49                panic!("Either axum or hyper feature must be enabled")
50            }
51        };
52}
53
/// [`Layer`] that wraps a service in [`AddConnectionInfo`].
///
/// The layer carries no configuration, so `AddConnectionInfoLayer` and
/// `AddConnectionInfoLayer::default()` are the same value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct AddConnectionInfoLayer;

57impl<S: Clone> Layer<S> for AddConnectionInfoLayer {
58    type Service = AddConnectionInfo<S>;
59
60    fn layer(&self, service: S) -> Self::Service {
61        AddConnectionInfo::new(service)
62    }
63}
64
/// Request extension inserted by [`AddConnectionInfo`] with the resolved client address.
///
/// Every field is a plain `IpAddr`, so the type is cheap to copy, compare and hash.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConnectionInfo {
    /// Client IP address resolved for the request.
    pub ip_addr: IpAddr,
}

/// Client IP resolution for applications served through `axum`.
#[cfg(feature = "axum")]
mod axum_impl {
    use super::*;
    use axum::extract::connect_info::ConnectInfo;
    use std::net::SocketAddr;

    /// Headers that may carry the originating client address, checked in this order.
    const CLIENT_IP_HEADERS: [&str; 4] = [
        "CF-Connecting-IP",
        "True-Client-IP",
        "X-Real-IP",
        "X-Forwarded-For",
    ];

    /// Returns the client IP address of `req`, or `None` when none can be determined.
    ///
    /// Each header in `CLIENT_IP_HEADERS` is tried in turn. The first one whose value
    /// parses as an `IpAddr` wins; for a comma-separated list, only the first element is
    /// parsed. If no header matches, the address from the `ConnectInfo<SocketAddr>` request
    /// extension is used.
    ///
    /// # Security
    ///
    /// NOTE(review): these headers come from the client unless a trusted proxy in front of the
    /// service overwrites them, and they take priority over the socket address. A client that
    /// reaches the service directly can therefore choose the address returned here. Only rely
    /// on this result for access control behind a proxy that sanitizes these headers.
    pub fn extract_ip_axum<B>(req: &Request<B>) -> Option<IpAddr> {
        let headers = req.headers();
        let forwarded = CLIENT_IP_HEADERS.iter().find_map(|name| {
            let raw = headers.get(*name)?.to_str().ok()?;
            let first = raw.split(',').next()?;
            first.trim().parse::<IpAddr>().ok()
        });

        forwarded.or_else(|| {
            req.extensions()
                .get::<ConnectInfo<SocketAddr>>()
                .map(|ConnectInfo(peer)| peer.ip())
        })
    }
}

/// Client IP resolution for services running directly on `hyper`.
#[cfg(feature = "hyper")]
mod hyper_impl {
    use super::*;
    use std::net::SocketAddr;

    /// Headers that may carry the originating client address, checked in this order.
    const CLIENT_IP_HEADERS: [&str; 4] = [
        "CF-Connecting-IP",
        "True-Client-IP",
        "X-Real-IP",
        "X-Forwarded-For",
    ];

    /// Returns the client IP address of `req`, or `None` when none can be determined.
    ///
    /// Each header in `CLIENT_IP_HEADERS` is tried in turn. The first one whose value
    /// parses as an `IpAddr` wins; for a comma-separated list, only the first element is
    /// parsed. Otherwise a `SocketAddr` stored in the request extensions is used.
    ///
    /// hyper does not attach the peer address to requests. For this fallback to work, the
    /// server must insert it into the extensions before this service runs, for example from
    /// the accepted connection. Without it, the result is `None`.
    ///
    /// # Security
    ///
    /// NOTE(review): these headers come from the client unless a trusted proxy in front of the
    /// service overwrites them, and they take priority over the peer address. A client that
    /// reaches the service directly can therefore choose the address returned here. Only rely
    /// on this result for access control behind a proxy that sanitizes these headers.
    pub fn extract_ip_hyper<B>(req: &Request<B>) -> Option<IpAddr> {
        for header in CLIENT_IP_HEADERS.iter() {
            if let Some(ip) = req
                .headers()
                .get(*header)
                .and_then(|hv| hv.to_str().ok())
                .and_then(|s| s.split(',').next())
                .and_then(|s| s.trim().parse().ok())
            {
                return Some(ip);
            }
        }

        // The request URI's host names the *server* being addressed (an absolute-form request
        // target, or the HTTP/2 authority), never the client. A client can even choose it
        // (`GET http://127.0.0.1/ HTTP/1.1`), so it must not be used as the client address.
        req.extensions().get::<SocketAddr>().map(SocketAddr::ip)
    }
}

// Public entry points for callers that want a client IP address without going through the
// `AddConnectionInfo` middleware; each one exists only when its backend feature is enabled.
#[cfg(feature = "hyper")]
pub use self::hyper_impl::extract_ip_hyper;

#[cfg(feature = "axum")]
pub use self::axum_impl::extract_ip_axum;