tower_real_ip/lib.rs
1//! # Tower Real IP
2//!
3//! A robust middleware for extracting the real client IP address from HTTP requests,
4//! designed for environments behind trusted proxies (Load Balancers, CDNs, Nginx).
5//!
6//! ## Features
7//! - Supports `X-Forwarded-For` parsing (Right-to-Left security traversal).
8//! - Supports CIDR ranges (IPv4 & IPv6).
9//! - Auto-configuration from Environment Variables (split by `;`).
10//! - Axum 0.8 Extractor support.
11
12use axum::extract::{ConnectInfo, FromRequestParts};
13use http::{request::Parts, Request, Response};
14use ipnetwork::IpNetwork;
15use std::{
16 env,
17 future::Future,
18 net::{IpAddr, SocketAddr},
19 pin::Pin,
20 str::FromStr,
21 sync::Arc,
22 task::{Context, Poll},
23};
24use tower::{Layer, Service};
25use tracing::{debug, warn};
26
27// ============================================================================
28// 1. Configuration Logic
29// ============================================================================
30
31/// Configuration holding the list of trusted networks.
32#[derive(Clone, Debug)]
33pub struct TrustedProxyConfig {
34 trusted_networks: Arc<Vec<IpNetwork>>,
35}
36
37impl TrustedProxyConfig {
38 /// Creates a new config from a list of IP networks.
39 pub fn new(networks: Vec<IpNetwork>) -> Self {
40 Self {
41 trusted_networks: Arc::new(networks),
42 }
43 }
44
45 /// Loads configuration from an environment variable.
46 ///
47 /// Expected format: "127.0.0.1;10.0.0.0/8;::1"
48 pub fn from_env(env_key: &str) -> Result<Self, String> {
49 let val =
50 env::var(env_key).map_err(|_| format!("Environment variable {} not found", env_key))?;
51 Self::parse_str(&val)
52 }
53
54 /// Parses a string separated by `;` into trusted networks.
55 pub fn parse_str(input: &str) -> Result<Self, String> {
56 let mut networks = Vec::new();
57 for part in input.split(';') {
58 let part = part.trim();
59 if part.is_empty() {
60 continue;
61 }
62
63 // Try parsing as CIDR first, then as single IP
64 match part.parse::<IpNetwork>() {
65 Ok(net) => networks.push(net),
66 Err(_) => match part.parse::<IpAddr>() {
67 Ok(ip) => networks.push(IpNetwork::from(ip)),
68 Err(_) => return Err(format!("Invalid IP or CIDR: {}", part)),
69 },
70 }
71 }
72
73 debug!("Loaded {} trusted proxy networks", networks.len());
74 Ok(Self::new(networks))
75 }
76
77 /// Checks if an IP is trusted.
78 pub fn is_trusted(&self, ip: &IpAddr) -> bool {
79 self.trusted_networks.iter().any(|net| net.contains(*ip))
80 }
81}
82
83// ============================================================================
84// 2. The Result Struct (What the user gets)
85// ============================================================================
86
/// Client IP address as resolved by the real-IP middleware.
///
/// The middleware stores it as a request extension; the address itself is the
/// public tuple field, e.g. `RealIp(ip): RealIp` in a handler signature.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RealIp(pub IpAddr);
90
91// ============================================================================
92// 3. Tower Middleware Implementation
93// ============================================================================
94
95#[derive(Clone)]
96pub struct RealIpLayer {
97 config: TrustedProxyConfig,
98}
99
100impl RealIpLayer {
101 pub fn new(config: TrustedProxyConfig) -> Self {
102 Self { config }
103 }
104}
105
106impl<S> Layer<S> for RealIpLayer {
107 type Service = RealIpService<S>;
108
109 fn layer(&self, inner: S) -> Self::Service {
110 RealIpService {
111 inner,
112 config: self.config.clone(),
113 }
114 }
115}
116
117#[derive(Clone)]
118pub struct RealIpService<S> {
119 inner: S,
120 config: TrustedProxyConfig,
121}
122
123impl<S, B> Service<Request<B>> for RealIpService<S>
124where
125 S: Service<Request<B>, Response = Response<B>> + Send + Clone + 'static,
126 S::Future: Send + 'static,
127 B: Send + 'static,
128{
129 type Response = S::Response;
130 type Error = S::Error;
131 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
132
133 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
134 self.inner.poll_ready(cx)
135 }
136
137 fn call(&mut self, mut req: Request<B>) -> Self::Future {
138 // 1. Extract the direct connection IP (Peer Address)
139 // Axum/Tower usually provides this via ConnectInfo extension
140 let remote_addr = req
141 .extensions()
142 .get::<ConnectInfo<SocketAddr>>()
143 .map(|ci| ci.0.ip());
144
145 let config = self.config.clone();
146 let headers = req.headers().clone(); // Clone headers to use in async block
147
148 let mut inner = self.inner.clone();
149
150 Box::pin(async move {
151 let mut resolved_ip = remote_addr.unwrap_or_else(|| {
152 // Fallback if no underlying TCP info is present (shouldn't happen in normal HTTP serving)
153 IpAddr::from([0, 0, 0, 0])
154 });
155
156 // 2. The Core Algorithm: Trusted Proxy Traversal
157 if let Some(peer_ip) = remote_addr {
158 // Only attempt to parse headers if the direct peer is trusted
159 if config.is_trusted(&peer_ip)
160 && let Some(xff_val) = headers.get("x-forwarded-for")
161 && let Ok(xff_str) = xff_val.to_str()
162 {
163 // Parse the comma-separated list
164 // List: Client, Proxy1, Proxy2
165 // We reverse iterate: Proxy2 -> Proxy1 -> Client
166 let ips: Vec<&str> = xff_str.split(',').map(|s| s.trim()).collect();
167
168 for ip_str in ips.iter().rev() {
169 if let Ok(ip) = IpAddr::from_str(ip_str) {
170 if !config.is_trusted(&ip) {
171 // Found the first untrusted IP (looking backwards)
172 // This is the Client.
173 resolved_ip = ip;
174 break;
175 }
176 // If trusted, continue strictly to the left
177 } else {
178 warn!("Skipping invalid IP in X-Forwarded-For: {}", ip_str);
179 }
180 }
181 // Edge case: If all IPs in header are trusted, the loop finishes.
182 // The `resolved_ip` remains the last trusted one (or peer),
183 // but technically if strictly all are trusted, the request originates
184 // from your internal network. We keep the peer or last logic.
185 }
186 }
187
188 // 3. Inject the result into extensions
189 req.extensions_mut().insert(RealIp(resolved_ip));
190
191 // 4. Forward request
192 inner.call(req).await
193 })
194 }
195}
196
197// ============================================================================
198// 4. Axum Extractor Support
199// ============================================================================
200
201/// Allows using `RealIp` directly in Axum handlers arguments.
202///
203/// Example:
204/// `async fn handler(RealIp(ip): RealIp) -> ...`
205impl<S> FromRequestParts<S> for RealIp
206where
207 S: Send + Sync,
208{
209 type Rejection = (http::StatusCode, &'static str);
210
211 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
212 parts.extensions.get::<RealIp>().cloned().ok_or((
213 http::StatusCode::INTERNAL_SERVER_ERROR,
214 "RealIp middleware is not configured correctly. Missing RealIp extension.",
215 ))
216 }
217}