Skip to main content

securitydept_realip/
config.rs

1use std::{collections::BTreeMap, net::IpAddr, path::PathBuf, time::Duration};
2
3use ipnet::IpNet;
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5
6use crate::error::{RealIpError, RealIpResult};
7
8#[derive(Debug, Clone, Deserialize, Serialize, Default)]
9pub struct RealIpResolveConfig {
10    #[serde(default)]
11    pub providers: Vec<ProviderConfig>,
12    #[serde(default)]
13    pub sources: Vec<SourceConfig>,
14    #[serde(default)]
15    pub fallback: FallbackConfig,
16}
17
18impl RealIpResolveConfig {
19    pub fn validate(&self) -> RealIpResult<()> {
20        let mut provider_names = std::collections::BTreeSet::new();
21        for provider in &self.providers {
22            if !provider_names.insert(provider.name().to_string()) {
23                return Err(RealIpError::Config {
24                    message: format!("duplicate provider name `{}`", provider.name()),
25                });
26            }
27            provider.validate()?;
28        }
29
30        let known_providers = provider_names;
31        let mut source_names = std::collections::BTreeSet::new();
32        for source in &self.sources {
33            if !source_names.insert(source.name.clone()) {
34                return Err(RealIpError::Config {
35                    message: format!("duplicate source name `{}`", source.name),
36                });
37            }
38
39            for provider in &source.peers_from {
40                if !known_providers.contains(provider) {
41                    return Err(RealIpError::UnknownSourceProvider {
42                        source_name: source.name.clone(),
43                        provider: provider.clone(),
44                    });
45                }
46            }
47        }
48
49        Ok(())
50    }
51}
52
53#[derive(Debug, Clone)]
54pub enum ProviderConfig {
55    Core(CoreProviderConfig),
56    Custom(CustomProviderConfig),
57}
58
59impl ProviderConfig {
60    pub fn name(&self) -> &str {
61        match self {
62            Self::Core(config) => config.name(),
63            Self::Custom(config) => &config.name,
64        }
65    }
66
67    pub fn kind(&self) -> &str {
68        match self {
69            Self::Core(config) => config.kind(),
70            Self::Custom(config) => &config.kind,
71        }
72    }
73
74    pub fn refresh(&self) -> Option<Duration> {
75        match self {
76            Self::Core(config) => config.refresh(),
77            Self::Custom(config) => config.refresh,
78        }
79    }
80
81    pub fn timeout(&self) -> Option<Duration> {
82        match self {
83            Self::Core(config) => config.timeout(),
84            Self::Custom(config) => config.timeout,
85        }
86    }
87
88    pub fn on_refresh_failure(&self) -> RefreshFailurePolicy {
89        match self {
90            Self::Core(config) => config.on_refresh_failure(),
91            Self::Custom(config) => config.on_refresh_failure,
92        }
93    }
94
95    pub fn max_stale(&self) -> Option<Duration> {
96        match self {
97            Self::Core(config) => config.max_stale(),
98            Self::Custom(config) => config.max_stale,
99        }
100    }
101
102    pub fn watch_path(&self) -> Option<(&PathBuf, Duration)> {
103        match self {
104            Self::Core(config) => config.watch_path(),
105            Self::Custom(_) => None,
106        }
107    }
108
109    pub fn inline_cidrs(&self) -> Option<&[IpNet]> {
110        match self {
111            Self::Core(config) => config.inline_cidrs(),
112            Self::Custom(_) => None,
113        }
114    }
115
116    pub fn local_file_path(&self) -> Option<&PathBuf> {
117        match self {
118            Self::Core(config) => config.local_file_path(),
119            Self::Custom(_) => None,
120        }
121    }
122
123    pub fn remote_file_url(&self) -> Option<&str> {
124        match self {
125            Self::Core(config) => config.remote_file_url(),
126            Self::Custom(_) => None,
127        }
128    }
129
130    pub fn command_spec(&self) -> Option<(&str, &[String])> {
131        match self {
132            Self::Core(config) => config.command_spec(),
133            Self::Custom(_) => None,
134        }
135    }
136
137    pub fn custom(&self) -> Option<&CustomProviderConfig> {
138        match self {
139            Self::Custom(config) => Some(config),
140            Self::Core(_) => None,
141        }
142    }
143
144    pub fn validate(&self) -> RealIpResult<()> {
145        match self {
146            Self::Core(config) => config.validate(),
147            Self::Custom(config) => {
148                if config.kind.trim().is_empty() {
149                    return Err(RealIpError::Config {
150                        message: format!("custom provider `{}` has empty kind", config.name),
151                    });
152                }
153                Ok(())
154            }
155        }
156    }
157}
158
159impl<'de> Deserialize<'de> for ProviderConfig {
160    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
161    where
162        D: Deserializer<'de>,
163    {
164        let value = serde_json::Value::deserialize(deserializer)?;
165        let kind = value
166            .get("kind")
167            .and_then(serde_json::Value::as_str)
168            .ok_or_else(|| serde::de::Error::custom("provider requires string field `kind`"))?;
169
170        match kind {
171            "inline" | "local-file" | "remote-file" | "command" => {
172                CoreProviderConfig::deserialize(value)
173                    .map(ProviderConfig::Core)
174                    .map_err(serde::de::Error::custom)
175            }
176            _ => CustomProviderConfig::deserialize(value)
177                .map(ProviderConfig::Custom)
178                .map_err(serde::de::Error::custom),
179        }
180    }
181}
182
183impl Serialize for ProviderConfig {
184    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
185    where
186        S: Serializer,
187    {
188        match self {
189            Self::Core(config) => config.serialize(serializer),
190            Self::Custom(config) => config.serialize(serializer),
191        }
192    }
193}
194
195#[derive(Debug, Clone, Deserialize, Serialize)]
196#[serde(tag = "kind", rename_all = "kebab-case")]
197pub enum CoreProviderConfig {
198    Inline(InlineProviderConfig),
199    LocalFile(LocalFileProviderConfig),
200    RemoteFile(RemoteFileProviderConfig),
201    Command(CommandProviderConfig),
202}
203
204impl CoreProviderConfig {
205    pub fn name(&self) -> &str {
206        match self {
207            Self::Inline(config) => &config.name,
208            Self::LocalFile(config) => &config.name,
209            Self::RemoteFile(config) => &config.name,
210            Self::Command(config) => &config.name,
211        }
212    }
213
214    pub fn kind(&self) -> &str {
215        match self {
216            Self::Inline(_) => "inline",
217            Self::LocalFile(_) => "local-file",
218            Self::RemoteFile(_) => "remote-file",
219            Self::Command(_) => "command",
220        }
221    }
222
223    pub fn refresh(&self) -> Option<Duration> {
224        match self {
225            Self::RemoteFile(config) => config.refresh,
226            Self::Command(config) => config.refresh,
227            Self::Inline(_) | Self::LocalFile(_) => None,
228        }
229    }
230
231    pub fn timeout(&self) -> Option<Duration> {
232        match self {
233            Self::RemoteFile(config) => config.timeout,
234            Self::Command(config) => config.timeout,
235            Self::Inline(_) | Self::LocalFile(_) => None,
236        }
237    }
238
239    pub fn on_refresh_failure(&self) -> RefreshFailurePolicy {
240        match self {
241            Self::RemoteFile(config) => config.on_refresh_failure,
242            Self::Command(config) => config.on_refresh_failure,
243            Self::Inline(_) | Self::LocalFile(_) => RefreshFailurePolicy::KeepLastGood,
244        }
245    }
246
247    pub fn max_stale(&self) -> Option<Duration> {
248        match self {
249            Self::Inline(_) => None,
250            Self::LocalFile(config) => config.max_stale,
251            Self::RemoteFile(config) => config.max_stale,
252            Self::Command(config) => config.max_stale,
253        }
254    }
255
256    pub fn watch_path(&self) -> Option<(&PathBuf, Duration)> {
257        match self {
258            Self::LocalFile(config) if config.watch => Some((
259                &config.path,
260                config.debounce.unwrap_or(Duration::from_secs(2)),
261            )),
262            _ => None,
263        }
264    }
265
266    pub fn inline_cidrs(&self) -> Option<&[IpNet]> {
267        match self {
268            Self::Inline(config) => Some(&config.cidrs),
269            _ => None,
270        }
271    }
272
273    pub fn local_file_path(&self) -> Option<&PathBuf> {
274        match self {
275            Self::LocalFile(config) => Some(&config.path),
276            _ => None,
277        }
278    }
279
280    pub fn remote_file_url(&self) -> Option<&str> {
281        match self {
282            Self::RemoteFile(config) => Some(&config.url),
283            _ => None,
284        }
285    }
286
287    pub fn command_spec(&self) -> Option<(&str, &[String])> {
288        match self {
289            Self::Command(config) => Some((&config.command, &config.args)),
290            _ => None,
291        }
292    }
293
294    pub fn validate(&self) -> RealIpResult<()> {
295        if let Self::Inline(config) = self
296            && config.cidrs.is_empty()
297        {
298            return Err(RealIpError::MissingProviderField {
299                provider: config.name.clone(),
300                field: "cidrs",
301            });
302        }
303        Ok(())
304    }
305}
306
307#[derive(Debug, Clone, Deserialize, Serialize)]
308pub struct InlineProviderConfig {
309    pub name: String,
310    pub cidrs: Vec<IpNet>,
311    #[serde(flatten, default)]
312    pub extra: BTreeMap<String, serde_json::Value>,
313}
314
315#[derive(Debug, Clone, Deserialize, Serialize)]
316pub struct LocalFileProviderConfig {
317    pub name: String,
318    pub path: PathBuf,
319    #[serde(default)]
320    pub watch: bool,
321    #[serde(default, with = "humantime_serde::option")]
322    pub debounce: Option<Duration>,
323    #[serde(default, with = "humantime_serde::option")]
324    pub max_stale: Option<Duration>,
325    #[serde(flatten, default)]
326    pub extra: BTreeMap<String, serde_json::Value>,
327}
328
329#[derive(Debug, Clone, Deserialize, Serialize)]
330pub struct RemoteFileProviderConfig {
331    pub name: String,
332    pub url: String,
333    #[serde(default, with = "humantime_serde::option")]
334    pub refresh: Option<Duration>,
335    #[serde(default, with = "humantime_serde::option")]
336    pub timeout: Option<Duration>,
337    #[serde(default)]
338    pub on_refresh_failure: RefreshFailurePolicy,
339    #[serde(default, with = "humantime_serde::option")]
340    pub max_stale: Option<Duration>,
341    #[serde(flatten, default)]
342    pub extra: BTreeMap<String, serde_json::Value>,
343}
344
345#[derive(Debug, Clone, Deserialize, Serialize)]
346pub struct CommandProviderConfig {
347    pub name: String,
348    pub command: String,
349    #[serde(default)]
350    pub args: Vec<String>,
351    #[serde(default, with = "humantime_serde::option")]
352    pub refresh: Option<Duration>,
353    #[serde(default, with = "humantime_serde::option")]
354    pub timeout: Option<Duration>,
355    #[serde(default)]
356    pub on_refresh_failure: RefreshFailurePolicy,
357    #[serde(default, with = "humantime_serde::option")]
358    pub max_stale: Option<Duration>,
359    #[serde(flatten, default)]
360    pub extra: BTreeMap<String, serde_json::Value>,
361}
362
363#[derive(Debug, Clone, Deserialize, Serialize)]
364pub struct CustomProviderConfig {
365    pub name: String,
366    pub kind: String,
367    #[serde(default, with = "humantime_serde::option")]
368    pub refresh: Option<Duration>,
369    #[serde(default, with = "humantime_serde::option")]
370    pub timeout: Option<Duration>,
371    #[serde(default)]
372    pub on_refresh_failure: RefreshFailurePolicy,
373    #[serde(default, with = "humantime_serde::option")]
374    pub max_stale: Option<Duration>,
375    #[serde(flatten, default)]
376    pub extra: BTreeMap<String, serde_json::Value>,
377}
378
379#[derive(Debug, Clone, Deserialize, Serialize)]
380pub struct SourceConfig {
381    pub name: String,
382    #[serde(default)]
383    pub priority: i32,
384    #[serde(default)]
385    pub peers_from: Vec<String>,
386    #[serde(default)]
387    pub accept_transport: Vec<TransportInputConfig>,
388    #[serde(default)]
389    pub accept_headers: Vec<HeaderInputConfig>,
390}
391
392#[derive(Debug, Clone, Deserialize, Serialize)]
393pub struct TransportInputConfig {
394    pub kind: String,
395}
396
397#[derive(Debug, Clone, Deserialize, Serialize)]
398pub struct HeaderInputConfig {
399    pub kind: String,
400    #[serde(default)]
401    pub mode: HeaderMode,
402    #[serde(default)]
403    pub direction: ChainDirection,
404    #[serde(default)]
405    pub param: Option<String>,
406    #[serde(default)]
407    pub use_only_if_not_in_trusted_peers: bool,
408}
409
410#[derive(Debug, Clone, Copy, Deserialize, Serialize, Default, PartialEq, Eq)]
411#[serde(rename_all = "kebab-case")]
412pub enum HeaderMode {
413    #[default]
414    Single,
415    Recursive,
416}
417
418#[derive(Debug, Clone, Copy, Deserialize, Serialize, Default, PartialEq, Eq)]
419#[serde(rename_all = "kebab-case")]
420pub enum ChainDirection {
421    LeftToRight,
422    #[default]
423    RightToLeft,
424}
425
426#[derive(Debug, Clone, Copy, Deserialize, Serialize, Default, PartialEq, Eq)]
427#[serde(rename_all = "kebab-case")]
428pub enum RefreshFailurePolicy {
429    #[default]
430    KeepLastGood,
431    Clear,
432}
433
434#[derive(Debug, Clone, Deserialize, Serialize)]
435pub struct FallbackConfig {
436    #[serde(default)]
437    pub strategy: FallbackStrategy,
438}
439
440impl Default for FallbackConfig {
441    fn default() -> Self {
442        Self {
443            strategy: FallbackStrategy::RemoteAddr,
444        }
445    }
446}
447
448#[derive(Debug, Clone, Copy, Deserialize, Serialize, Default, PartialEq, Eq)]
449#[serde(rename_all = "kebab-case")]
450pub enum FallbackStrategy {
451    #[default]
452    RemoteAddr,
453}
454
455pub(crate) fn parse_ip_or_cidr(entry: &str) -> Result<IpNet, ()> {
456    if let Ok(net) = entry.parse::<IpNet>() {
457        return Ok(net);
458    }
459    let addr = entry.parse::<IpAddr>().map_err(|_| ())?;
460    Ok(IpNet::from(addr))
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deserialize_docker_provider_as_custom_provider() {
        let config: ProviderConfig = serde_json::from_value(serde_json::json!({
            "name": "docker-ingress",
            "kind": "docker-provider",
            "host": "unix:///var/run/docker.sock",
            "networks": ["edge-ingress", "internal-proxy"],
            "refresh": "30s",
            "timeout": "5s",
            "on_refresh_failure": "keep-last-good",
            "max_stale": "10m"
        }))
        .unwrap();

        let ProviderConfig::Custom(custom) = config else {
            panic!("expected custom provider");
        };
        assert_eq!(custom.kind, "docker-provider");
        assert_eq!(custom.name, "docker-ingress");
        assert_eq!(custom.refresh, Some(Duration::from_secs(30)));
        assert_eq!(custom.timeout, Some(Duration::from_secs(5)));
        assert_eq!(custom.max_stale, Some(Duration::from_secs(600)));
        assert_eq!(
            custom.extra.get("host").and_then(serde_json::Value::as_str),
            Some("unix:///var/run/docker.sock")
        );
        assert_eq!(
            custom
                .extra
                .get("networks")
                .and_then(serde_json::Value::as_array)
                .map(Vec::len),
            Some(2)
        );
    }

    #[test]
    fn deserialize_kube_provider_as_custom_provider() {
        let config: ProviderConfig = serde_json::from_value(serde_json::json!({
            "name": "kube-ingress-pods",
            "kind": "kube-provider",
            "resource": "pods",
            "namespace": "ingress-nginx",
            "label_selector": "app.kubernetes.io/name=ingress-nginx",
            "refresh": "30s",
            "timeout": "5s"
        }))
        .unwrap();

        let ProviderConfig::Custom(custom) = config else {
            panic!("expected custom provider");
        };
        assert_eq!(custom.kind, "kube-provider");
        assert_eq!(custom.name, "kube-ingress-pods");
        assert_eq!(custom.refresh, Some(Duration::from_secs(30)));
        assert_eq!(custom.timeout, Some(Duration::from_secs(5)));
        assert_eq!(
            custom
                .extra
                .get("resource")
                .and_then(serde_json::Value::as_str),
            Some("pods")
        );
        assert_eq!(
            custom
                .extra
                .get("namespace")
                .and_then(serde_json::Value::as_str),
            Some("ingress-nginx")
        );
        assert_eq!(
            custom
                .extra
                .get("label_selector")
                .and_then(serde_json::Value::as_str),
            Some("app.kubernetes.io/name=ingress-nginx")
        );
    }

    /// Built-in kinds must be routed to `ProviderConfig::Core`, not `Custom`.
    #[test]
    fn deserialize_inline_provider_as_core_provider() {
        let config: ProviderConfig = serde_json::from_value(serde_json::json!({
            "name": "edge-cidrs",
            "kind": "inline",
            "cidrs": ["10.0.0.0/8", "2001:db8::/32"]
        }))
        .unwrap();

        assert!(config.custom().is_none());
        assert_eq!(config.kind(), "inline");
        assert_eq!(config.name(), "edge-cidrs");
        assert_eq!(config.inline_cidrs().map(|cidrs| cidrs.len()), Some(2));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn deserialize_provider_without_kind_fails() {
        let err = serde_json::from_value::<ProviderConfig>(serde_json::json!({
            "name": "kindless"
        }))
        .unwrap_err();

        assert!(err.to_string().contains("`kind`"));
    }

    /// Serializing a custom provider and reading it back keeps the variant,
    /// the typed durations and the untyped extra settings.
    #[test]
    fn custom_provider_round_trips_through_serialization() {
        let original: ProviderConfig = serde_json::from_value(serde_json::json!({
            "name": "docker-ingress",
            "kind": "docker-provider",
            "host": "unix:///var/run/docker.sock",
            "refresh": "30s"
        }))
        .unwrap();

        let reparsed: ProviderConfig =
            serde_json::from_value(serde_json::to_value(&original).unwrap()).unwrap();

        let ProviderConfig::Custom(custom) = reparsed else {
            panic!("expected custom provider");
        };
        assert_eq!(custom.kind, "docker-provider");
        assert_eq!(custom.refresh, Some(Duration::from_secs(30)));
        assert_eq!(custom.timeout, None);
        assert_eq!(
            custom.extra.get("host").and_then(serde_json::Value::as_str),
            Some("unix:///var/run/docker.sock")
        );
    }

    #[test]
    fn validate_rejects_duplicate_provider_names() {
        let config: RealIpResolveConfig = serde_json::from_value(serde_json::json!({
            "providers": [
                { "name": "edge", "kind": "inline", "cidrs": ["10.0.0.0/8"] },
                { "name": "edge", "kind": "docker-provider" }
            ]
        }))
        .unwrap();

        assert!(matches!(
            config.validate(),
            Err(RealIpError::Config { .. })
        ));
    }

    #[test]
    fn validate_rejects_unknown_source_provider() {
        let config: RealIpResolveConfig = serde_json::from_value(serde_json::json!({
            "providers": [
                { "name": "edge", "kind": "inline", "cidrs": ["10.0.0.0/8"] }
            ],
            "sources": [
                { "name": "cdn", "peers_from": ["edge", "missing"] }
            ]
        }))
        .unwrap();

        assert!(matches!(
            config.validate(),
            Err(RealIpError::UnknownSourceProvider { ref provider, .. }) if provider == "missing"
        ));
    }

    #[test]
    fn validate_rejects_inline_provider_without_cidrs() {
        let config: ProviderConfig = serde_json::from_value(serde_json::json!({
            "name": "empty-inline",
            "kind": "inline",
            "cidrs": []
        }))
        .unwrap();

        assert!(matches!(
            config.validate(),
            Err(RealIpError::MissingProviderField { field: "cidrs", .. })
        ));
    }
}