Skip to main content

securitydept_realip/
resolve.rs

1use std::{collections::HashSet, net::IpAddr};
2
3use http::HeaderMap;
4use ipnet::IpNet;
5use rfc7239::parse as parse_forwarded;
6
7use crate::{
8    config::{
9        ChainDirection, FallbackStrategy, HeaderInputConfig, HeaderMode, RealIpResolveConfig,
10    },
11    error::RealIpResult,
12    extension::ProviderFactoryRegistry,
13    providers::{ProviderRegistry, ProviderSnapshot},
14};
15
/// Connection-level information about an inbound request that is not carried in HTTP headers.
#[derive(Clone, Debug, Default)]
pub struct TransportContext {
    /// Client address obtained from the PROXY protocol, if the caller supplied one.
    /// Only consulted by sources that list the `proxy-protocol` transport.
    pub proxy_protocol_addr: Option<IpAddr>,
}
20
/// Which kind of input produced a [`ResolvedClientIp`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ResolvedSourceKind {
    /// Taken from transport-level data (the PROXY protocol address).
    Transport,
    /// Taken from an HTTP request header.
    Header,
    /// No source matched; the configured fallback strategy was applied.
    Fallback,
}
27
28#[derive(Debug, Clone, PartialEq, Eq)]
29pub struct ResolvedClientIp {
30    pub client_ip: IpAddr,
31    pub peer_ip: IpAddr,
32    pub source_name: Option<String>,
33    pub source_kind: ResolvedSourceKind,
34    pub header_name: Option<String>,
35}
36
37#[derive(Debug, Clone)]
38struct CompiledSource {
39    name: String,
40    priority: i32,
41    peer_cidrs: Vec<IpNet>,
42    accept_transport: Vec<String>,
43    accept_headers: Vec<HeaderInputConfig>,
44}
45
46pub struct RealIpResolver {
47    config: RealIpResolveConfig,
48    providers: ProviderRegistry,
49}
50
51impl RealIpResolver {
52    pub async fn from_config(config: RealIpResolveConfig) -> RealIpResult<Self> {
53        let factories = ProviderFactoryRegistry::with_builtin_providers()?;
54        Self::from_config_with_factories(config, &factories).await
55    }
56
57    pub async fn from_config_with_factories(
58        config: RealIpResolveConfig,
59        factories: &ProviderFactoryRegistry,
60    ) -> RealIpResult<Self> {
61        config.validate()?;
62        let providers =
63            ProviderRegistry::from_configs_with_factories(&config.providers, factories).await?;
64        Ok(Self { config, providers })
65    }
66
67    pub async fn resolve(
68        &self,
69        peer_ip: IpAddr,
70        headers: &HeaderMap,
71        transport: &TransportContext,
72    ) -> ResolvedClientIp {
73        let compiled_sources = self.compile_sources().await;
74        let trusted_peers = self.providers.all_cidrs().await;
75        let trusted_set = TrustedSet::new(trusted_peers);
76
77        for source in compiled_sources {
78            if !source.matches_peer(peer_ip) {
79                continue;
80            }
81
82            if let Some(result) = source.resolve_transport(peer_ip, transport) {
83                return result;
84            }
85
86            if let Some(result) = source.resolve_headers(peer_ip, headers, &trusted_set) {
87                return result;
88            }
89        }
90
91        match self.config.fallback.strategy {
92            FallbackStrategy::RemoteAddr => ResolvedClientIp {
93                client_ip: peer_ip,
94                peer_ip,
95                source_name: None,
96                source_kind: ResolvedSourceKind::Fallback,
97                header_name: None,
98            },
99        }
100    }
101
102    async fn compile_sources(&self) -> Vec<CompiledSource> {
103        let mut compiled = Vec::with_capacity(self.config.sources.len());
104        for source in &self.config.sources {
105            let mut peer_cidrs = Vec::new();
106            for provider_name in &source.peers_from {
107                if let Some(ProviderSnapshot { cidrs, .. }) =
108                    self.providers.snapshot(provider_name).await
109                {
110                    peer_cidrs.extend(cidrs.iter().copied());
111                }
112            }
113
114            compiled.push(CompiledSource {
115                name: source.name.clone(),
116                priority: source.priority,
117                peer_cidrs,
118                accept_transport: source
119                    .accept_transport
120                    .iter()
121                    .map(|item| item.kind.to_ascii_lowercase())
122                    .collect(),
123                accept_headers: source.accept_headers.clone(),
124            });
125        }
126
127        compiled.sort_by_key(|right| std::cmp::Reverse(right.priority));
128        compiled
129    }
130}
131
132impl CompiledSource {
133    fn matches_peer(&self, peer_ip: IpAddr) -> bool {
134        self.peer_cidrs.iter().any(|cidr| cidr.contains(&peer_ip))
135    }
136
137    fn resolve_transport(
138        &self,
139        peer_ip: IpAddr,
140        transport: &TransportContext,
141    ) -> Option<ResolvedClientIp> {
142        if self
143            .accept_transport
144            .iter()
145            .any(|kind| kind == "proxy-protocol")
146            && let Some(proxy_ip) = transport.proxy_protocol_addr
147        {
148            return Some(ResolvedClientIp {
149                client_ip: proxy_ip,
150                peer_ip,
151                source_name: Some(self.name.clone()),
152                source_kind: ResolvedSourceKind::Transport,
153                header_name: Some("proxy-protocol".to_string()),
154            });
155        }
156
157        None
158    }
159
160    fn resolve_headers(
161        &self,
162        peer_ip: IpAddr,
163        headers: &HeaderMap,
164        trusted_set: &TrustedSet,
165    ) -> Option<ResolvedClientIp> {
166        for header in &self.accept_headers {
167            let kind = header.kind.to_ascii_lowercase();
168            let candidate = match header.mode {
169                HeaderMode::Single => resolve_single_header(headers, &kind),
170                HeaderMode::Recursive => resolve_chain_header(headers, &kind, header, trusted_set),
171            };
172
173            let candidate = match candidate {
174                Some(ip) => ip,
175                None => continue,
176            };
177
178            if header.use_only_if_not_in_trusted_peers && trusted_set.contains(candidate) {
179                continue;
180            }
181
182            return Some(ResolvedClientIp {
183                client_ip: candidate,
184                peer_ip,
185                source_name: Some(self.name.clone()),
186                source_kind: ResolvedSourceKind::Header,
187                header_name: Some(kind),
188            });
189        }
190
191        None
192    }
193}
194
195#[derive(Debug, Clone)]
196struct TrustedSet {
197    cidrs: Vec<IpNet>,
198}
199
200impl TrustedSet {
201    fn new(cidrs: Vec<IpNet>) -> Self {
202        let mut unique = HashSet::new();
203        let mut deduped = Vec::new();
204        for cidr in cidrs {
205            if unique.insert(cidr) {
206                deduped.push(cidr);
207            }
208        }
209        Self { cidrs: deduped }
210    }
211
212    fn contains(&self, ip: IpAddr) -> bool {
213        self.cidrs.iter().any(|cidr| cidr.contains(&ip))
214    }
215}
216
217fn resolve_single_header(headers: &HeaderMap, kind: &str) -> Option<IpAddr> {
218    headers
219        .get(kind)
220        .and_then(|value| value.to_str().ok())
221        .map(str::trim)
222        .filter(|value| !value.is_empty())
223        .and_then(|value| value.parse::<IpAddr>().ok())
224}
225
226fn resolve_chain_header(
227    headers: &HeaderMap,
228    kind: &str,
229    config: &HeaderInputConfig,
230    trusted_set: &TrustedSet,
231) -> Option<IpAddr> {
232    let chain = match kind {
233        "x-forwarded-for" => parse_x_forwarded_for(headers),
234        "forwarded" => parse_forwarded_for(headers, config.param.as_deref().unwrap_or("for")),
235        _ => Vec::new(),
236    };
237
238    resolve_from_chain(&chain, trusted_set, config.direction)
239}
240
241fn parse_x_forwarded_for(headers: &HeaderMap) -> Vec<IpAddr> {
242    headers
243        .get("x-forwarded-for")
244        .and_then(|value| value.to_str().ok())
245        .into_iter()
246        .flat_map(|value| value.split(','))
247        .map(str::trim)
248        .filter(|value| !value.is_empty())
249        .filter_map(|value| value.parse::<IpAddr>().ok())
250        .collect()
251}
252
253fn parse_forwarded_for(headers: &HeaderMap, param: &str) -> Vec<IpAddr> {
254    let Some(raw) = headers
255        .get(http::header::FORWARDED)
256        .and_then(|value| value.to_str().ok())
257    else {
258        return Vec::new();
259    };
260
261    let mut result = Vec::new();
262    for node in parse_forwarded(raw).flatten() {
263        if param != "for" {
264            continue;
265        }
266        let Some(value) = node.forwarded_for.map(|value| value.to_string()) else {
267            continue;
268        };
269        if let Some(ip) = parse_forwarded_ip(&value) {
270            result.push(ip);
271        }
272    }
273    result
274}
275
/// Parses one forwarded node identifier into an IP address.
///
/// Surrounding whitespace and double quotes are removed first. Accepted forms:
/// - a bare IPv4 or IPv6 address (`192.0.2.60`, `2001:db8::1`);
/// - `host:port` where `host` is an IP address (`192.0.2.60:4711`);
/// - a bracketed address, with or without a port (`[2001:db8::1]`, `[2001:db8::1]:4711`).
///
/// Anything else returns `None`, e.g. `unknown` or obfuscated identifiers.
fn parse_forwarded_ip(value: &str) -> Option<IpAddr> {
    let trimmed = value.trim().trim_matches('"');

    if let Ok(ip) = trimmed.parse::<IpAddr>() {
        return Some(ip);
    }

    // Bracketed form. RFC 7239 requires an IPv6 address followed by a port to be
    // bracketed, so the port (if any) follows the closing `]`.
    if let Some(rest) = trimmed.strip_prefix('[') {
        let (host, tail) = rest.split_once(']')?;
        if !tail.is_empty() && !tail.starts_with(':') {
            return None;
        }
        return host.parse::<IpAddr>().ok();
    }

    // Unbracketed `host:port`; the port itself is not validated.
    let (host, _port) = trimmed.rsplit_once(':')?;
    host.parse::<IpAddr>().ok()
}
295
296fn resolve_from_chain(
297    chain: &[IpAddr],
298    trusted_set: &TrustedSet,
299    direction: ChainDirection,
300) -> Option<IpAddr> {
301    let iter: Box<dyn Iterator<Item = &IpAddr>> = match direction {
302        ChainDirection::LeftToRight => Box::new(chain.iter()),
303        ChainDirection::RightToLeft => Box::new(chain.iter().rev()),
304    };
305
306    let mut last = None;
307    for ip in iter {
308        last = Some(*ip);
309        if !trusted_set.contains(*ip) {
310            return Some(*ip);
311        }
312    }
313
314    match direction {
315        ChainDirection::LeftToRight => chain.last().copied().or(last),
316        ChainDirection::RightToLeft => chain.first().copied().or(last),
317    }
318}
319
#[cfg(test)]
mod tests {
    use std::{fs, net::IpAddr, path::PathBuf, sync::Arc};

    use http::HeaderMap;

    use super::*;
    use crate::{
        config::{
            CommandProviderConfig, CoreProviderConfig, CustomProviderConfig, HeaderInputConfig,
            HeaderMode, InlineProviderConfig, LocalFileProviderConfig, ProviderConfig,
            RefreshFailurePolicy, SourceConfig,
        },
        extension::{
            CustomProviderFactory, DynamicProvider, ProviderFactoryRegistry, ProviderLoadFuture,
        },
    };

    /// Writes `content` to a per-process file in the system temp directory and returns its path.
    fn temp_file(name: &str, content: &str) -> PathBuf {
        let file_name = format!("securitydept-realip-{name}-{}", std::process::id());
        let path = std::env::temp_dir().join(file_name);
        fs::write(&path, content).unwrap();
        path
    }

    /// Inline provider trusting exactly one CIDR.
    fn inline_provider(name: &str, cidr: &str) -> ProviderConfig {
        ProviderConfig::Core(CoreProviderConfig::Inline(InlineProviderConfig {
            name: name.to_string(),
            cidrs: vec![cidr.parse().unwrap()],
            extra: Default::default(),
        }))
    }

    /// Right-to-left header input without a `param`.
    fn header_input(
        kind: &str,
        mode: HeaderMode,
        use_only_if_not_in_trusted_peers: bool,
    ) -> HeaderInputConfig {
        HeaderInputConfig {
            kind: kind.to_string(),
            mode,
            direction: ChainDirection::RightToLeft,
            param: None,
            use_only_if_not_in_trusted_peers,
        }
    }

    /// Configuration with the given providers, no sources and the default fallback.
    fn providers_only(providers: Vec<ProviderConfig>) -> RealIpResolveConfig {
        RealIpResolveConfig {
            providers,
            sources: vec![],
            fallback: Default::default(),
        }
    }

    #[tokio::test]
    async fn resolves_recursive_xff_after_skipping_trusted_proxies() {
        let config = RealIpResolveConfig {
            providers: vec![
                inline_provider("cloudflare", "203.0.113.0/24"),
                inline_provider("edgeone", "198.51.100.0/24"),
            ],
            sources: vec![SourceConfig {
                name: "cloudflare".to_string(),
                priority: 100,
                peers_from: vec!["cloudflare".to_string()],
                accept_transport: vec![],
                accept_headers: vec![
                    header_input("cf-connecting-ip", HeaderMode::Single, true),
                    header_input("x-forwarded-for", HeaderMode::Recursive, false),
                ],
            }],
            fallback: Default::default(),
        };
        let resolver = RealIpResolver::from_config(config).await.unwrap();

        let mut headers = HeaderMap::new();
        // The single header names a trusted (edgeone) address, so it must be skipped.
        headers.insert("cf-connecting-ip", "198.51.100.2".parse().unwrap());
        headers.insert(
            "x-forwarded-for",
            "198.18.0.10, 198.51.100.2".parse().unwrap(),
        );

        let peer_ip: IpAddr = "203.0.113.10".parse().unwrap();
        let resolved = resolver
            .resolve(peer_ip, &headers, &TransportContext::default())
            .await;

        let expected: IpAddr = "198.18.0.10".parse().unwrap();
        assert_eq!(resolved.client_ip, expected);
        assert_eq!(resolved.header_name.as_deref(), Some("x-forwarded-for"));
    }

    #[tokio::test]
    async fn loads_local_file_provider() {
        let path = temp_file("local-provider", "127.0.0.1/32\n::1/128\n");
        let provider = ProviderConfig::Core(CoreProviderConfig::LocalFile(
            LocalFileProviderConfig {
                name: "local".to_string(),
                path: path.clone(),
                watch: false,
                debounce: None,
                max_stale: None,
                extra: Default::default(),
            },
        ));

        let resolver = RealIpResolver::from_config(providers_only(vec![provider]))
            .await
            .unwrap();
        assert_eq!(resolver.providers.all_cidrs().await.len(), 2);

        let _ = fs::remove_file(path);
    }

    #[tokio::test]
    async fn loads_command_provider() {
        let provider = ProviderConfig::Core(CoreProviderConfig::Command(CommandProviderConfig {
            name: "command".to_string(),
            command: "sh".to_string(),
            args: vec![
                "-c".to_string(),
                "printf '10.0.0.1\\n10.0.0.0/24\\n'".to_string(),
            ],
            refresh: None,
            timeout: Some(std::time::Duration::from_secs(5)),
            on_refresh_failure: RefreshFailurePolicy::KeepLastGood,
            max_stale: None,
            extra: Default::default(),
        }));

        let resolver = RealIpResolver::from_config(providers_only(vec![provider]))
            .await
            .unwrap();
        assert_eq!(resolver.providers.all_cidrs().await.len(), 2);
    }

    /// Provider that always yields the CIDRs it was created with.
    struct StaticCustomProvider {
        cidrs: Vec<IpNet>,
    }

    impl DynamicProvider for StaticCustomProvider {
        fn load<'a>(&'a self) -> ProviderLoadFuture<'a> {
            let snapshot = self.cidrs.clone();
            Box::pin(async move { Ok(snapshot) })
        }
    }

    /// Factory building [`StaticCustomProvider`] from the `cidrs` array in `extra`.
    struct StaticCustomProviderFactory;

    impl CustomProviderFactory for StaticCustomProviderFactory {
        fn kind(&self) -> &'static str {
            "static-custom"
        }

        fn create(&self, config: &CustomProviderConfig) -> RealIpResult<Arc<dyn DynamicProvider>> {
            let listed = config.extra.get("cidrs").and_then(|value| value.as_array());
            let cidrs = listed
                .into_iter()
                .flatten()
                .filter_map(|value| value.as_str())
                .map(|value| value.parse::<IpNet>().unwrap())
                .collect();
            Ok(Arc::new(StaticCustomProvider { cidrs }))
        }
    }

    #[tokio::test]
    async fn loads_custom_provider_via_factory_registry() {
        let mut factories = ProviderFactoryRegistry::new();
        factories.register(StaticCustomProviderFactory).unwrap();

        let provider = ProviderConfig::Custom(CustomProviderConfig {
            name: "custom".to_string(),
            kind: "static-custom".to_string(),
            refresh: None,
            timeout: None,
            on_refresh_failure: RefreshFailurePolicy::KeepLastGood,
            max_stale: None,
            extra: [(
                "cidrs".to_string(),
                serde_json::json!(["10.10.0.0/16", "127.0.0.1/32"]),
            )]
            .into_iter()
            .collect(),
        });

        let resolver =
            RealIpResolver::from_config_with_factories(providers_only(vec![provider]), &factories)
                .await
                .unwrap();
        assert_eq!(resolver.providers.all_cidrs().await.len(), 2);
    }
}