1use std::{collections::HashSet, net::IpAddr};
2
3use http::HeaderMap;
4use ipnet::IpNet;
5use rfc7239::parse as parse_forwarded;
6
7use crate::{
8 config::{
9 ChainDirection, FallbackStrategy, HeaderInputConfig, HeaderMode, RealIpResolveConfig,
10 },
11 error::RealIpResult,
12 extension::ProviderFactoryRegistry,
13 providers::{ProviderRegistry, ProviderSnapshot},
14};
15
/// Connection-level information available alongside a request.
///
/// The default value carries no transport-provided address.
#[derive(Default, Clone, Debug)]
pub struct TransportContext {
    /// Client address supplied by the transport layer. A source reports it as
    /// the client IP when that source accepts the `proxy-protocol` transport.
    pub proxy_protocol_addr: Option<IpAddr>,
}
20
/// Where a resolved client IP came from.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum ResolvedSourceKind {
    /// Taken from transport-level information (the `proxy-protocol` address).
    Transport,
    /// Taken from one of a source's accepted request headers.
    Header,
    /// No source produced an address; the fallback strategy was applied.
    Fallback,
}
27
28#[derive(Debug, Clone, PartialEq, Eq)]
29pub struct ResolvedClientIp {
30 pub client_ip: IpAddr,
31 pub peer_ip: IpAddr,
32 pub source_name: Option<String>,
33 pub source_kind: ResolvedSourceKind,
34 pub header_name: Option<String>,
35}
36
37#[derive(Debug, Clone)]
38struct CompiledSource {
39 name: String,
40 priority: i32,
41 peer_cidrs: Vec<IpNet>,
42 accept_transport: Vec<String>,
43 accept_headers: Vec<HeaderInputConfig>,
44}
45
46pub struct RealIpResolver {
47 config: RealIpResolveConfig,
48 providers: ProviderRegistry,
49}
50
51impl RealIpResolver {
52 pub async fn from_config(config: RealIpResolveConfig) -> RealIpResult<Self> {
53 let factories = ProviderFactoryRegistry::with_builtin_providers()?;
54 Self::from_config_with_factories(config, &factories).await
55 }
56
57 pub async fn from_config_with_factories(
58 config: RealIpResolveConfig,
59 factories: &ProviderFactoryRegistry,
60 ) -> RealIpResult<Self> {
61 config.validate()?;
62 let providers =
63 ProviderRegistry::from_configs_with_factories(&config.providers, factories).await?;
64 Ok(Self { config, providers })
65 }
66
67 pub async fn resolve(
68 &self,
69 peer_ip: IpAddr,
70 headers: &HeaderMap,
71 transport: &TransportContext,
72 ) -> ResolvedClientIp {
73 let compiled_sources = self.compile_sources().await;
74 let trusted_peers = self.providers.all_cidrs().await;
75 let trusted_set = TrustedSet::new(trusted_peers);
76
77 for source in compiled_sources {
78 if !source.matches_peer(peer_ip) {
79 continue;
80 }
81
82 if let Some(result) = source.resolve_transport(peer_ip, transport) {
83 return result;
84 }
85
86 if let Some(result) = source.resolve_headers(peer_ip, headers, &trusted_set) {
87 return result;
88 }
89 }
90
91 match self.config.fallback.strategy {
92 FallbackStrategy::RemoteAddr => ResolvedClientIp {
93 client_ip: peer_ip,
94 peer_ip,
95 source_name: None,
96 source_kind: ResolvedSourceKind::Fallback,
97 header_name: None,
98 },
99 }
100 }
101
102 async fn compile_sources(&self) -> Vec<CompiledSource> {
103 let mut compiled = Vec::with_capacity(self.config.sources.len());
104 for source in &self.config.sources {
105 let mut peer_cidrs = Vec::new();
106 for provider_name in &source.peers_from {
107 if let Some(ProviderSnapshot { cidrs, .. }) =
108 self.providers.snapshot(provider_name).await
109 {
110 peer_cidrs.extend(cidrs.iter().copied());
111 }
112 }
113
114 compiled.push(CompiledSource {
115 name: source.name.clone(),
116 priority: source.priority,
117 peer_cidrs,
118 accept_transport: source
119 .accept_transport
120 .iter()
121 .map(|item| item.kind.to_ascii_lowercase())
122 .collect(),
123 accept_headers: source.accept_headers.clone(),
124 });
125 }
126
127 compiled.sort_by_key(|right| std::cmp::Reverse(right.priority));
128 compiled
129 }
130}
131
132impl CompiledSource {
133 fn matches_peer(&self, peer_ip: IpAddr) -> bool {
134 self.peer_cidrs.iter().any(|cidr| cidr.contains(&peer_ip))
135 }
136
137 fn resolve_transport(
138 &self,
139 peer_ip: IpAddr,
140 transport: &TransportContext,
141 ) -> Option<ResolvedClientIp> {
142 if self
143 .accept_transport
144 .iter()
145 .any(|kind| kind == "proxy-protocol")
146 && let Some(proxy_ip) = transport.proxy_protocol_addr
147 {
148 return Some(ResolvedClientIp {
149 client_ip: proxy_ip,
150 peer_ip,
151 source_name: Some(self.name.clone()),
152 source_kind: ResolvedSourceKind::Transport,
153 header_name: Some("proxy-protocol".to_string()),
154 });
155 }
156
157 None
158 }
159
160 fn resolve_headers(
161 &self,
162 peer_ip: IpAddr,
163 headers: &HeaderMap,
164 trusted_set: &TrustedSet,
165 ) -> Option<ResolvedClientIp> {
166 for header in &self.accept_headers {
167 let kind = header.kind.to_ascii_lowercase();
168 let candidate = match header.mode {
169 HeaderMode::Single => resolve_single_header(headers, &kind),
170 HeaderMode::Recursive => resolve_chain_header(headers, &kind, header, trusted_set),
171 };
172
173 let candidate = match candidate {
174 Some(ip) => ip,
175 None => continue,
176 };
177
178 if header.use_only_if_not_in_trusted_peers && trusted_set.contains(candidate) {
179 continue;
180 }
181
182 return Some(ResolvedClientIp {
183 client_ip: candidate,
184 peer_ip,
185 source_name: Some(self.name.clone()),
186 source_kind: ResolvedSourceKind::Header,
187 header_name: Some(kind),
188 });
189 }
190
191 None
192 }
193}
194
195#[derive(Debug, Clone)]
196struct TrustedSet {
197 cidrs: Vec<IpNet>,
198}
199
200impl TrustedSet {
201 fn new(cidrs: Vec<IpNet>) -> Self {
202 let mut unique = HashSet::new();
203 let mut deduped = Vec::new();
204 for cidr in cidrs {
205 if unique.insert(cidr) {
206 deduped.push(cidr);
207 }
208 }
209 Self { cidrs: deduped }
210 }
211
212 fn contains(&self, ip: IpAddr) -> bool {
213 self.cidrs.iter().any(|cidr| cidr.contains(&ip))
214 }
215}
216
217fn resolve_single_header(headers: &HeaderMap, kind: &str) -> Option<IpAddr> {
218 headers
219 .get(kind)
220 .and_then(|value| value.to_str().ok())
221 .map(str::trim)
222 .filter(|value| !value.is_empty())
223 .and_then(|value| value.parse::<IpAddr>().ok())
224}
225
226fn resolve_chain_header(
227 headers: &HeaderMap,
228 kind: &str,
229 config: &HeaderInputConfig,
230 trusted_set: &TrustedSet,
231) -> Option<IpAddr> {
232 let chain = match kind {
233 "x-forwarded-for" => parse_x_forwarded_for(headers),
234 "forwarded" => parse_forwarded_for(headers, config.param.as_deref().unwrap_or("for")),
235 _ => Vec::new(),
236 };
237
238 resolve_from_chain(&chain, trusted_set, config.direction)
239}
240
241fn parse_x_forwarded_for(headers: &HeaderMap) -> Vec<IpAddr> {
242 headers
243 .get("x-forwarded-for")
244 .and_then(|value| value.to_str().ok())
245 .into_iter()
246 .flat_map(|value| value.split(','))
247 .map(str::trim)
248 .filter(|value| !value.is_empty())
249 .filter_map(|value| value.parse::<IpAddr>().ok())
250 .collect()
251}
252
253fn parse_forwarded_for(headers: &HeaderMap, param: &str) -> Vec<IpAddr> {
254 let Some(raw) = headers
255 .get(http::header::FORWARDED)
256 .and_then(|value| value.to_str().ok())
257 else {
258 return Vec::new();
259 };
260
261 let mut result = Vec::new();
262 for node in parse_forwarded(raw).flatten() {
263 if param != "for" {
264 continue;
265 }
266 let Some(value) = node.forwarded_for.map(|value| value.to_string()) else {
267 continue;
268 };
269 if let Some(ip) = parse_forwarded_ip(&value) {
270 result.push(ip);
271 }
272 }
273 result
274}
275
/// Parses a forwarded node identifier into an IP address.
///
/// Surrounding whitespace and double quotes are ignored. Accepted shapes:
/// - bare address: `192.0.2.43`, `2001:db8::1`
/// - bracketed address: `[2001:db8::1]`
/// - bracketed address with port: `[2001:db8:cafe::17]:4711` (RFC 7239 §6)
/// - address with port: `192.0.2.43:47011`
///
/// The port is discarded without validation. Identifiers that are not
/// addresses (such as `unknown` or `_hidden`) yield `None`.
fn parse_forwarded_ip(value: &str) -> Option<IpAddr> {
    let trimmed = value.trim().trim_matches('"');

    // `[v6]:port`: IPv6 must be bracketed when a port follows. The generic
    // path below cannot handle this shape. It neither strips the brackets
    // nor splits off the port.
    if let Some(ip) = trimmed
        .strip_prefix('[')
        .and_then(|rest| rest.split_once("]:"))
        .and_then(|(host, _port)| host.parse::<IpAddr>().ok())
    {
        return Some(ip);
    }

    let without_brackets = trimmed
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .unwrap_or(trimmed);

    if let Ok(ip) = without_brackets.parse::<IpAddr>() {
        return Some(ip);
    }

    // `v4:port`. A bare IPv6 address was already accepted by the parse above.
    without_brackets
        .rsplit_once(':')
        .and_then(|(host, _port)| host.parse::<IpAddr>().ok())
}
295
296fn resolve_from_chain(
297 chain: &[IpAddr],
298 trusted_set: &TrustedSet,
299 direction: ChainDirection,
300) -> Option<IpAddr> {
301 let iter: Box<dyn Iterator<Item = &IpAddr>> = match direction {
302 ChainDirection::LeftToRight => Box::new(chain.iter()),
303 ChainDirection::RightToLeft => Box::new(chain.iter().rev()),
304 };
305
306 let mut last = None;
307 for ip in iter {
308 last = Some(*ip);
309 if !trusted_set.contains(*ip) {
310 return Some(*ip);
311 }
312 }
313
314 match direction {
315 ChainDirection::LeftToRight => chain.last().copied().or(last),
316 ChainDirection::RightToLeft => chain.first().copied().or(last),
317 }
318}
319
#[cfg(test)]
mod tests {
    use std::{
        fs,
        net::IpAddr,
        path::{Path, PathBuf},
        sync::Arc,
    };

    use http::HeaderMap;

    use super::*;
    use crate::{
        config::{
            CommandProviderConfig, CoreProviderConfig, CustomProviderConfig, HeaderInputConfig,
            HeaderMode, InlineProviderConfig, LocalFileProviderConfig, ProviderConfig,
            RefreshFailurePolicy, SourceConfig,
        },
        extension::{
            CustomProviderFactory, DynamicProvider, ProviderFactoryRegistry, ProviderLoadFuture,
        },
    };

    /// Temporary file removed on drop. Cleanup therefore also happens when an
    /// assertion fails partway through a test.
    struct TempFile {
        path: PathBuf,
    }

    impl TempFile {
        /// Writes `content` to a temp-dir file named after `name` and the process id.
        fn new(name: &str, content: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("securitydept-realip-{name}-{}", std::process::id()));
            fs::write(&path, content).unwrap();
            Self { path }
        }

        fn path(&self) -> &Path {
            &self.path
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup: a missing file is not an error here.
            let _ = fs::remove_file(&self.path);
        }
    }

    /// Builds an inline provider named `name` that holds the single range `cidr`.
    fn inline_provider(name: &str, cidr: &str) -> ProviderConfig {
        ProviderConfig::Core(CoreProviderConfig::Inline(InlineProviderConfig {
            name: name.to_string(),
            cidrs: vec![cidr.parse().unwrap()],
            extra: Default::default(),
        }))
    }

    #[tokio::test]
    async fn resolves_recursive_xff_after_skipping_trusted_proxies() {
        let config = RealIpResolveConfig {
            providers: vec![
                inline_provider("cloudflare", "203.0.113.0/24"),
                inline_provider("edgeone", "198.51.100.0/24"),
            ],
            sources: vec![SourceConfig {
                name: "cloudflare".to_string(),
                priority: 100,
                peers_from: vec!["cloudflare".to_string()],
                accept_transport: vec![],
                accept_headers: vec![
                    HeaderInputConfig {
                        kind: "cf-connecting-ip".to_string(),
                        mode: HeaderMode::Single,
                        direction: ChainDirection::RightToLeft,
                        param: None,
                        use_only_if_not_in_trusted_peers: true,
                    },
                    HeaderInputConfig {
                        kind: "x-forwarded-for".to_string(),
                        mode: HeaderMode::Recursive,
                        direction: ChainDirection::RightToLeft,
                        param: None,
                        use_only_if_not_in_trusted_peers: false,
                    },
                ],
            }],
            fallback: Default::default(),
        };
        let resolver = RealIpResolver::from_config(config).await.unwrap();
        let peer_ip: IpAddr = "203.0.113.10".parse().unwrap();

        let mut headers = HeaderMap::new();
        headers.insert("cf-connecting-ip", "198.51.100.2".parse().unwrap());
        headers.insert(
            "x-forwarded-for",
            "198.18.0.10, 198.51.100.2".parse().unwrap(),
        );

        let resolved = resolver
            .resolve(peer_ip, &headers, &TransportContext::default())
            .await;

        assert_eq!(resolved.client_ip, "198.18.0.10".parse::<IpAddr>().unwrap());
        assert_eq!(resolved.header_name.as_deref(), Some("x-forwarded-for"));
    }

    #[tokio::test]
    async fn falls_back_to_peer_when_no_source_trusts_it() {
        let config = RealIpResolveConfig {
            providers: vec![inline_provider("cloudflare", "203.0.113.0/24")],
            sources: vec![SourceConfig {
                name: "cloudflare".to_string(),
                priority: 100,
                peers_from: vec!["cloudflare".to_string()],
                accept_transport: vec![],
                accept_headers: vec![HeaderInputConfig {
                    kind: "x-forwarded-for".to_string(),
                    mode: HeaderMode::Recursive,
                    direction: ChainDirection::RightToLeft,
                    param: None,
                    use_only_if_not_in_trusted_peers: false,
                }],
            }],
            fallback: Default::default(),
        };
        let resolver = RealIpResolver::from_config(config).await.unwrap();
        // The peer is outside 203.0.113.0/24, so its X-Forwarded-For must be ignored.
        let peer_ip: IpAddr = "192.0.2.1".parse().unwrap();

        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "198.18.0.10".parse().unwrap());

        let resolved = resolver
            .resolve(peer_ip, &headers, &TransportContext::default())
            .await;

        assert_eq!(resolved.client_ip, peer_ip);
        assert_eq!(resolved.peer_ip, peer_ip);
        assert_eq!(resolved.source_kind, ResolvedSourceKind::Fallback);
        assert_eq!(resolved.source_name, None);
        assert_eq!(resolved.header_name, None);
    }

    #[tokio::test]
    async fn loads_local_file_provider() {
        let file = TempFile::new("local-provider", "127.0.0.1/32\n::1/128\n");
        let config = RealIpResolveConfig {
            providers: vec![ProviderConfig::Core(CoreProviderConfig::LocalFile(
                LocalFileProviderConfig {
                    name: "local".to_string(),
                    path: file.path().to_path_buf(),
                    watch: false,
                    debounce: None,
                    max_stale: None,
                    extra: Default::default(),
                },
            ))],
            sources: vec![],
            fallback: Default::default(),
        };

        let resolver = RealIpResolver::from_config(config).await.unwrap();
        let trusted = resolver.providers.all_cidrs().await;
        assert_eq!(trusted.len(), 2);
    }

    #[tokio::test]
    async fn loads_command_provider() {
        let config = RealIpResolveConfig {
            providers: vec![ProviderConfig::Core(CoreProviderConfig::Command(
                CommandProviderConfig {
                    name: "command".to_string(),
                    command: "sh".to_string(),
                    args: vec![
                        "-c".to_string(),
                        "printf '10.0.0.1\\n10.0.0.0/24\\n'".to_string(),
                    ],
                    refresh: None,
                    timeout: Some(std::time::Duration::from_secs(5)),
                    on_refresh_failure: RefreshFailurePolicy::KeepLastGood,
                    max_stale: None,
                    extra: Default::default(),
                },
            ))],
            sources: vec![],
            fallback: Default::default(),
        };

        let resolver = RealIpResolver::from_config(config).await.unwrap();
        let trusted = resolver.providers.all_cidrs().await;
        assert_eq!(trusted.len(), 2);
    }

    /// Custom provider that always returns a fixed CIDR list.
    struct StaticCustomProvider {
        cidrs: Vec<IpNet>,
    }

    impl DynamicProvider for StaticCustomProvider {
        fn load<'a>(&'a self) -> ProviderLoadFuture<'a> {
            let cidrs = self.cidrs.clone();
            Box::pin(async move { Ok(cidrs) })
        }
    }

    /// Factory for [`StaticCustomProvider`] that reads CIDR strings from `extra["cidrs"]`.
    struct StaticCustomProviderFactory;

    impl CustomProviderFactory for StaticCustomProviderFactory {
        fn kind(&self) -> &'static str {
            "static-custom"
        }

        fn create(&self, config: &CustomProviderConfig) -> RealIpResult<Arc<dyn DynamicProvider>> {
            let cidrs = config
                .extra
                .get("cidrs")
                .and_then(|value| value.as_array())
                .into_iter()
                .flatten()
                .filter_map(|value| value.as_str())
                .map(|value| value.parse::<IpNet>().unwrap())
                .collect();
            Ok(Arc::new(StaticCustomProvider { cidrs }))
        }
    }

    #[tokio::test]
    async fn loads_custom_provider_via_factory_registry() {
        let mut factories = ProviderFactoryRegistry::new();
        factories.register(StaticCustomProviderFactory).unwrap();

        let config = RealIpResolveConfig {
            providers: vec![ProviderConfig::Custom(CustomProviderConfig {
                name: "custom".to_string(),
                kind: "static-custom".to_string(),
                refresh: None,
                timeout: None,
                on_refresh_failure: RefreshFailurePolicy::KeepLastGood,
                max_stale: None,
                extra: [(
                    "cidrs".to_string(),
                    serde_json::json!(["10.10.0.0/16", "127.0.0.1/32"]),
                )]
                .into_iter()
                .collect(),
            })],
            sources: vec![],
            fallback: Default::default(),
        };

        let resolver = RealIpResolver::from_config_with_factories(config, &factories)
            .await
            .unwrap();
        let trusted = resolver.providers.all_cidrs().await;
        assert_eq!(trusted.len(), 2);
    }
}