tower_ipfilter/
connection_info_service.rs1use std::{
2 net::IpAddr,
3 task::{Context, Poll},
4};
5use http::Request;
6use tower::{Layer, Service};
7
/// Tower middleware that records the client's IP address on each request.
///
/// Before a request is forwarded to the wrapped service, the middleware tries to
/// determine the client IP and, when one is found, stores it in the request
/// extensions as a [`ConnectionInfo`]. Requests for which no IP can be determined
/// are forwarded unchanged.
#[derive(Debug, Clone)]
pub struct AddConnectionInfo<S> {
    /// The wrapped service that receives every request after annotation.
    inner: S,
}

impl<S> AddConnectionInfo<S> {
    /// Wraps `inner` so that requests reaching it carry a [`ConnectionInfo`]
    /// extension whenever the client IP can be determined.
    pub fn new(inner: S) -> Self {
        AddConnectionInfo { inner }
    }
}
18
19impl<S, B> Service<Request<B>> for AddConnectionInfo<S>
20where
21 S: Service<Request<B>>,
22{
23 type Response = S::Response;
24 type Error = S::Error;
25 type Future = S::Future;
26
27 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
28 self.inner.poll_ready(cx)
29 }
30
31 fn call(&mut self, mut req: Request<B>) -> Self::Future {
32 if let Some(ip_addr) = extract_ip(&req) {
33 req.extensions_mut().insert(ConnectionInfo { ip_addr });
34 }
35 self.inner.call(req)
36 }
37}
38
39
40fn extract_ip<B>(req: &Request<B>) -> Option<IpAddr> {
41 cfg_if::cfg_if! {
42 if #[cfg(feature = "axum")] {
43 use axum_impl::extract_ip_axum;
44 return extract_ip_axum(&req)
45 } else if #[cfg(feature = "hyper")] {
46 use hyper_impl::extract_ip_hyper;
47 return extract_ip_hyper(&req)
48 } else {
49 panic!("Either axum or hyper feature must be enabled")
50 }
51 };
52}
53
54#[derive(Clone, Copy, Debug)]
55pub struct AddConnectionInfoLayer;
56
57impl<S: Clone> Layer<S> for AddConnectionInfoLayer {
58 type Service = AddConnectionInfo<S>;
59
60 fn layer(&self, service: S) -> Self::Service {
61 AddConnectionInfo::new(service)
62 }
63}
64
/// Client address that [`AddConnectionInfo`] stores in a request's extensions.
///
/// Downstream layers and handlers can read it with
/// `req.extensions().get::<ConnectionInfo>()`.
///
/// The type is a small `Copy` value with equality and hashing, so it can be
/// compared directly or used as a map/set key (e.g. for per-IP bookkeeping).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConnectionInfo {
    /// IP address the request is attributed to. Depending on the backend
    /// extractor, this may come from a client-supplied forwarding header rather
    /// than the socket peer.
    pub ip_addr: IpAddr,
}
69
#[cfg(feature = "axum")]
mod axum_impl {
    use super::*;
    use axum::extract::connect_info::ConnectInfo;
    use std::net::SocketAddr;

    /// Forwarding headers consulted, in priority order, before falling back to
    /// the connection's peer address.
    const FORWARDING_HEADERS: [&str; 4] = [
        "CF-Connecting-IP",
        "True-Client-IP",
        "X-Real-IP",
        "X-Forwarded-For",
    ];

    /// Extracts the client IP for an axum request.
    ///
    /// Each header in `FORWARDING_HEADERS` is checked in order, and only its
    /// first comma-separated entry is considered. The first entry that parses as
    /// an IP address wins. Both a bare address and one with a port
    /// (`203.0.113.7:51234`, `[2001:db8::1]:443`) are accepted. If no header
    /// yields an address, the peer address from a `ConnectInfo<SocketAddr>`
    /// request extension is used when present (axum adds it when the app is
    /// served with `into_make_service_with_connect_info::<SocketAddr>()`).
    ///
    /// # Security
    ///
    /// These headers are supplied by the client unless a trusted reverse proxy
    /// overwrites them. Without such a proxy, any client can choose the address
    /// returned here, which lets it bypass IP-based allow/deny decisions.
    pub fn extract_ip_axum<B>(req: &Request<B>) -> Option<IpAddr> {
        FORWARDING_HEADERS
            .iter()
            .find_map(|name| header_ip(req, name))
            .or_else(|| {
                req.extensions()
                    .get::<ConnectInfo<SocketAddr>>()
                    .map(|socket_addr| socket_addr.ip())
            })
    }

    /// Returns the IP in the first comma-separated entry of header `name`, if
    /// the header exists, is valid visible ASCII, and that entry parses.
    fn header_ip<B>(req: &Request<B>, name: &str) -> Option<IpAddr> {
        let value = req.headers().get(name)?.to_str().ok()?;
        let first = value.split(',').next()?;
        parse_ip_maybe_with_port(first.trim())
    }

    /// Parses a bare IP address, or an `ip:port` / `[ipv6]:port` socket address
    /// (some proxies append the client port to forwarding headers).
    fn parse_ip_maybe_with_port(s: &str) -> Option<IpAddr> {
        s.parse::<IpAddr>()
            .ok()
            .or_else(|| s.parse::<SocketAddr>().ok().map(|addr| addr.ip()))
    }
}
102
#[cfg(feature = "hyper")]
mod hyper_impl {
    use super::*;
    use std::net::SocketAddr;

    /// Forwarding headers consulted, in priority order, before falling back to
    /// a peer address stored in the request extensions.
    const FORWARDING_HEADERS: [&str; 4] = [
        "CF-Connecting-IP",
        "True-Client-IP",
        "X-Real-IP",
        "X-Forwarded-For",
    ];

    /// Extracts the client IP for a hyper request.
    ///
    /// Each header in `FORWARDING_HEADERS` is checked in order, and only its
    /// first comma-separated entry is considered. The first entry that parses as
    /// an IP address wins. Both a bare address and one with a port
    /// (`203.0.113.7:51234`, `[2001:db8::1]:443`) are accepted.
    ///
    /// If no header yields an address, a `SocketAddr` stored in the request
    /// extensions is used. The server is expected to insert the peer address
    /// there, e.g. when accepting the connection. Without it, only the headers
    /// are consulted.
    ///
    /// The request URI's host is deliberately *not* used as a fallback. It names
    /// the server the request targets (and is absent for origin-form requests),
    /// not the client, so treating it as the client IP would misattribute
    /// requests.
    ///
    /// # Security
    ///
    /// These headers are supplied by the client unless a trusted reverse proxy
    /// overwrites them. Without such a proxy, any client can choose the address
    /// returned here, which lets it bypass IP-based allow/deny decisions.
    pub fn extract_ip_hyper<B>(req: &Request<B>) -> Option<IpAddr> {
        FORWARDING_HEADERS
            .iter()
            .find_map(|name| header_ip(req, name))
            .or_else(|| req.extensions().get::<SocketAddr>().map(SocketAddr::ip))
    }

    /// Returns the IP in the first comma-separated entry of header `name`, if
    /// the header exists, is valid visible ASCII, and that entry parses.
    fn header_ip<B>(req: &Request<B>, name: &str) -> Option<IpAddr> {
        let value = req.headers().get(name)?.to_str().ok()?;
        let first = value.split(',').next()?;
        parse_ip_maybe_with_port(first.trim())
    }

    /// Parses a bare IP address, or an `ip:port` / `[ipv6]:port` socket address
    /// (some proxies append the client port to forwarding headers).
    fn parse_ip_maybe_with_port(s: &str) -> Option<IpAddr> {
        s.parse::<IpAddr>()
            .ok()
            .or_else(|| s.parse::<SocketAddr>().ok().map(|addr| addr.ip()))
    }
}
130
131#[cfg(feature = "axum")]
132pub use axum_impl::extract_ip_axum;
133
134#[cfg(feature = "hyper")]
135pub use hyper_impl::extract_ip_hyper;