// tower_ipfilter/connection_info_service.rs
use core::panic;
use std::{
    net::IpAddr,
    task::{Context, Poll},
};
use http::Request;
use tower::{Layer, Service};

/// Middleware service that may attach a [`ConnectionInfo`] extension to each
/// request before handing it to the wrapped service.
#[derive(Debug, Clone)]
pub struct AddConnectionInfo<S> {
    /// The wrapped service that ultimately handles every request.
    inner: S,
}

impl<S> AddConnectionInfo<S> {
    /// Builds the middleware around `inner`.
    pub fn new(inner: S) -> Self {
        AddConnectionInfo { inner }
    }
}

/// Forwards each request to the inner service.
///
/// Before forwarding, if [`extract_ip`] yields an address, it is stored in the
/// request's extensions as a [`ConnectionInfo`]. Otherwise the request passes
/// through untouched. Readiness, errors and the response future all come
/// straight from the inner service.
impl<S, B> Service<Request<B>> for AddConnectionInfo<S>
where
    S: Service<Request<B>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // The middleware holds no readiness state of its own.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: Request<B>) -> Self::Future {
        let info = extract_ip(&req).map(|ip_addr| ConnectionInfo { ip_addr });
        if let Some(info) = info {
            // The value returned by `insert` (any earlier entry) is not needed.
            req.extensions_mut().insert(info);
        }
        self.inner.call(req)
    }
}


fn extract_ip<B>(req: &Request<B>) -> Option<IpAddr> {
    cfg_if::cfg_if! {
            if #[cfg(feature = "axum")] {
                use axum_impl::extract_ip_axum;
                return extract_ip_axum(&req)
            } else if #[cfg(feature = "hyper")] {
                use hyper_impl::extract_ip_hyper;
                return extract_ip_hyper(&req)
            } else {
                panic!("Either axum or hyper feature must be enabled")
            }
        };
}

/// Layer that wraps services in the `AddConnectionInfo` middleware.
///
/// The layer carries no configuration, so it is `Copy` and can be built with
/// `AddConnectionInfoLayer` or `AddConnectionInfoLayer::default()`.
#[derive(Clone, Copy, Debug, Default)]
pub struct AddConnectionInfoLayer;

/// Wraps any service `S` in [`AddConnectionInfo`].
///
/// No `S: Clone` bound is needed: wrapping only moves `service` into the new
/// middleware. When `S: Clone`, the wrapper is still cloneable, because
/// `AddConnectionInfo<S>` derives `Clone`.
impl<S> Layer<S> for AddConnectionInfoLayer {
    type Service = AddConnectionInfo<S>;

    fn layer(&self, service: S) -> Self::Service {
        AddConnectionInfo::new(service)
    }
}

/// Client address attached to a request's extensions by the
/// `AddConnectionInfo` middleware when one can be determined.
///
/// Read it downstream with `req.extensions().get::<ConnectionInfo>()`. It is
/// `Copy`, `Eq` and `Hash`, so it can be compared directly or used as a map or
/// set key, e.g. in allow/deny lists.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConnectionInfo {
    /// Best-effort client IP.
    ///
    /// It may be taken from request headers such as `X-Forwarded-For`, which
    /// the client can set. Treat it as untrusted unless a trusted proxy
    /// overwrites those headers.
    pub ip_addr: IpAddr,
}

#[cfg(feature = "axum")]
mod axum_impl {
    use super::*;
    use axum::extract::connect_info::ConnectInfo;
    use std::net::SocketAddr;

    /// Forwarding headers consulted, highest priority first.
    const FORWARDED_HEADERS: [&str; 4] = [
        "CF-Connecting-IP",
        "True-Client-IP",
        "X-Real-IP",
        "X-Forwarded-For",
    ];

    /// Returns the client IP for a request served by axum.
    ///
    /// Each header in `FORWARDED_HEADERS` is tried in order. The first one
    /// whose value parses as an IP address wins; for comma-separated lists,
    /// only the first entry is tried. Headers that are missing, not visible
    /// ASCII, or unparsable are skipped. If none match, the peer address from
    /// axum's `ConnectInfo<SocketAddr>` request extension is used, if present.
    ///
    /// SECURITY(review): every header above is client-controlled unless a
    /// trusted reverse proxy strips or overwrites it. The leftmost
    /// `X-Forwarded-For` entry is whatever the client sent, even behind
    /// proxies that only append. A client reaching the server directly can
    /// therefore pick any address here. Consider trusting these headers only
    /// when the socket peer is a known proxy.
    pub fn extract_ip_axum<B>(req: &Request<B>) -> Option<IpAddr> {
        FORWARDED_HEADERS
            .iter()
            .find_map(|&name| {
                req.headers()
                    .get(name)
                    .and_then(|value| value.to_str().ok())
                    // `client, proxy1, proxy2`: the originating client is listed first.
                    .and_then(|value| value.split(',').next())
                    .and_then(|entry| entry.trim().parse::<IpAddr>().ok())
            })
            .or_else(|| {
                req.extensions()
                    .get::<ConnectInfo<SocketAddr>>()
                    .map(|info| info.0.ip())
            })
    }
}

#[cfg(feature = "hyper")]
mod hyper_impl {
    use super::*;
    use std::net::SocketAddr;

    /// Forwarding headers consulted, highest priority first.
    const FORWARDED_HEADERS: [&str; 4] = [
        "CF-Connecting-IP",
        "True-Client-IP",
        "X-Real-IP",
        "X-Forwarded-For",
    ];

    /// Returns the client IP for a request served through plain hyper.
    ///
    /// Each header in `FORWARDED_HEADERS` is tried in order. The first one
    /// whose value parses as an IP address wins; for comma-separated lists,
    /// only the first entry is tried. If none match, a peer `SocketAddr`
    /// stored in the request extensions is used. The server's accept loop must
    /// insert it, e.g. `req.extensions_mut().insert(remote_addr)`.
    /// NOTE(review): confirm callers do this, otherwise the fallback yields
    /// `None`.
    ///
    /// The request URI's host is deliberately not used as a fallback. It
    /// names the request target, not the peer. It is absent in origin-form
    /// requests and client-chosen in absolute-form ones, where
    /// `GET http://127.0.0.1/` would masquerade as localhost.
    ///
    /// SECURITY(review): every header above is client-controlled unless a
    /// trusted reverse proxy strips or overwrites it. Consider trusting these
    /// headers only when the socket peer is a known proxy.
    pub fn extract_ip_hyper<B>(req: &Request<B>) -> Option<IpAddr> {
        FORWARDED_HEADERS
            .iter()
            .find_map(|&name| {
                req.headers()
                    .get(name)
                    .and_then(|value| value.to_str().ok())
                    // `client, proxy1, proxy2`: the originating client is listed first.
                    .and_then(|value| value.split(',').next())
                    .and_then(|entry| entry.trim().parse::<IpAddr>().ok())
            })
            .or_else(|| req.extensions().get::<SocketAddr>().map(SocketAddr::ip))
    }
}

#[cfg(feature = "axum")]
pub use axum_impl::extract_ip_axum;

#[cfg(feature = "hyper")]
pub use hyper_impl::extract_ip_hyper;
use tracing::debug;