1use std::{collections::BTreeMap, net::IpAddr, path::PathBuf, time::Duration};
2
3use ipnet::IpNet;
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5
6use crate::error::{RealIpError, RealIpResult};
7
8#[derive(Debug, Clone, Deserialize, Serialize, Default)]
9pub struct RealIpResolveConfig {
10 #[serde(default)]
11 pub providers: Vec<ProviderConfig>,
12 #[serde(default)]
13 pub sources: Vec<SourceConfig>,
14 #[serde(default)]
15 pub fallback: FallbackConfig,
16}
17
18impl RealIpResolveConfig {
19 pub fn validate(&self) -> RealIpResult<()> {
20 let mut provider_names = std::collections::BTreeSet::new();
21 for provider in &self.providers {
22 if !provider_names.insert(provider.name().to_string()) {
23 return Err(RealIpError::Config {
24 message: format!("duplicate provider name `{}`", provider.name()),
25 });
26 }
27 provider.validate()?;
28 }
29
30 let known_providers = provider_names;
31 let mut source_names = std::collections::BTreeSet::new();
32 for source in &self.sources {
33 if !source_names.insert(source.name.clone()) {
34 return Err(RealIpError::Config {
35 message: format!("duplicate source name `{}`", source.name),
36 });
37 }
38
39 for provider in &source.peers_from {
40 if !known_providers.contains(provider) {
41 return Err(RealIpError::UnknownSourceProvider {
42 source_name: source.name.clone(),
43 provider: provider.clone(),
44 });
45 }
46 }
47 }
48
49 Ok(())
50 }
51}
52
53#[derive(Debug, Clone)]
54pub enum ProviderConfig {
55 Core(CoreProviderConfig),
56 Custom(CustomProviderConfig),
57}
58
59impl ProviderConfig {
60 pub fn name(&self) -> &str {
61 match self {
62 Self::Core(config) => config.name(),
63 Self::Custom(config) => &config.name,
64 }
65 }
66
67 pub fn kind(&self) -> &str {
68 match self {
69 Self::Core(config) => config.kind(),
70 Self::Custom(config) => &config.kind,
71 }
72 }
73
74 pub fn refresh(&self) -> Option<Duration> {
75 match self {
76 Self::Core(config) => config.refresh(),
77 Self::Custom(config) => config.refresh,
78 }
79 }
80
81 pub fn timeout(&self) -> Option<Duration> {
82 match self {
83 Self::Core(config) => config.timeout(),
84 Self::Custom(config) => config.timeout,
85 }
86 }
87
88 pub fn on_refresh_failure(&self) -> RefreshFailurePolicy {
89 match self {
90 Self::Core(config) => config.on_refresh_failure(),
91 Self::Custom(config) => config.on_refresh_failure,
92 }
93 }
94
95 pub fn max_stale(&self) -> Option<Duration> {
96 match self {
97 Self::Core(config) => config.max_stale(),
98 Self::Custom(config) => config.max_stale,
99 }
100 }
101
102 pub fn watch_path(&self) -> Option<(&PathBuf, Duration)> {
103 match self {
104 Self::Core(config) => config.watch_path(),
105 Self::Custom(_) => None,
106 }
107 }
108
109 pub fn inline_cidrs(&self) -> Option<&[IpNet]> {
110 match self {
111 Self::Core(config) => config.inline_cidrs(),
112 Self::Custom(_) => None,
113 }
114 }
115
116 pub fn local_file_path(&self) -> Option<&PathBuf> {
117 match self {
118 Self::Core(config) => config.local_file_path(),
119 Self::Custom(_) => None,
120 }
121 }
122
123 pub fn remote_file_url(&self) -> Option<&str> {
124 match self {
125 Self::Core(config) => config.remote_file_url(),
126 Self::Custom(_) => None,
127 }
128 }
129
130 pub fn command_spec(&self) -> Option<(&str, &[String])> {
131 match self {
132 Self::Core(config) => config.command_spec(),
133 Self::Custom(_) => None,
134 }
135 }
136
137 pub fn custom(&self) -> Option<&CustomProviderConfig> {
138 match self {
139 Self::Custom(config) => Some(config),
140 Self::Core(_) => None,
141 }
142 }
143
144 pub fn validate(&self) -> RealIpResult<()> {
145 match self {
146 Self::Core(config) => config.validate(),
147 Self::Custom(config) => {
148 if config.kind.trim().is_empty() {
149 return Err(RealIpError::Config {
150 message: format!("custom provider `{}` has empty kind", config.name),
151 });
152 }
153 Ok(())
154 }
155 }
156 }
157}
158
159impl<'de> Deserialize<'de> for ProviderConfig {
160 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
161 where
162 D: Deserializer<'de>,
163 {
164 let value = serde_json::Value::deserialize(deserializer)?;
165 let kind = value
166 .get("kind")
167 .and_then(serde_json::Value::as_str)
168 .ok_or_else(|| serde::de::Error::custom("provider requires string field `kind`"))?;
169
170 match kind {
171 "inline" | "local-file" | "remote-file" | "command" => {
172 CoreProviderConfig::deserialize(value)
173 .map(ProviderConfig::Core)
174 .map_err(serde::de::Error::custom)
175 }
176 _ => CustomProviderConfig::deserialize(value)
177 .map(ProviderConfig::Custom)
178 .map_err(serde::de::Error::custom),
179 }
180 }
181}
182
183impl Serialize for ProviderConfig {
184 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
185 where
186 S: Serializer,
187 {
188 match self {
189 Self::Core(config) => config.serialize(serializer),
190 Self::Custom(config) => config.serialize(serializer),
191 }
192 }
193}
194
195#[derive(Debug, Clone, Deserialize, Serialize)]
196#[serde(tag = "kind", rename_all = "kebab-case")]
197pub enum CoreProviderConfig {
198 Inline(InlineProviderConfig),
199 LocalFile(LocalFileProviderConfig),
200 RemoteFile(RemoteFileProviderConfig),
201 Command(CommandProviderConfig),
202}
203
204impl CoreProviderConfig {
205 pub fn name(&self) -> &str {
206 match self {
207 Self::Inline(config) => &config.name,
208 Self::LocalFile(config) => &config.name,
209 Self::RemoteFile(config) => &config.name,
210 Self::Command(config) => &config.name,
211 }
212 }
213
214 pub fn kind(&self) -> &str {
215 match self {
216 Self::Inline(_) => "inline",
217 Self::LocalFile(_) => "local-file",
218 Self::RemoteFile(_) => "remote-file",
219 Self::Command(_) => "command",
220 }
221 }
222
223 pub fn refresh(&self) -> Option<Duration> {
224 match self {
225 Self::RemoteFile(config) => config.refresh,
226 Self::Command(config) => config.refresh,
227 Self::Inline(_) | Self::LocalFile(_) => None,
228 }
229 }
230
231 pub fn timeout(&self) -> Option<Duration> {
232 match self {
233 Self::RemoteFile(config) => config.timeout,
234 Self::Command(config) => config.timeout,
235 Self::Inline(_) | Self::LocalFile(_) => None,
236 }
237 }
238
239 pub fn on_refresh_failure(&self) -> RefreshFailurePolicy {
240 match self {
241 Self::RemoteFile(config) => config.on_refresh_failure,
242 Self::Command(config) => config.on_refresh_failure,
243 Self::Inline(_) | Self::LocalFile(_) => RefreshFailurePolicy::KeepLastGood,
244 }
245 }
246
247 pub fn max_stale(&self) -> Option<Duration> {
248 match self {
249 Self::Inline(_) => None,
250 Self::LocalFile(config) => config.max_stale,
251 Self::RemoteFile(config) => config.max_stale,
252 Self::Command(config) => config.max_stale,
253 }
254 }
255
256 pub fn watch_path(&self) -> Option<(&PathBuf, Duration)> {
257 match self {
258 Self::LocalFile(config) if config.watch => Some((
259 &config.path,
260 config.debounce.unwrap_or(Duration::from_secs(2)),
261 )),
262 _ => None,
263 }
264 }
265
266 pub fn inline_cidrs(&self) -> Option<&[IpNet]> {
267 match self {
268 Self::Inline(config) => Some(&config.cidrs),
269 _ => None,
270 }
271 }
272
273 pub fn local_file_path(&self) -> Option<&PathBuf> {
274 match self {
275 Self::LocalFile(config) => Some(&config.path),
276 _ => None,
277 }
278 }
279
280 pub fn remote_file_url(&self) -> Option<&str> {
281 match self {
282 Self::RemoteFile(config) => Some(&config.url),
283 _ => None,
284 }
285 }
286
287 pub fn command_spec(&self) -> Option<(&str, &[String])> {
288 match self {
289 Self::Command(config) => Some((&config.command, &config.args)),
290 _ => None,
291 }
292 }
293
294 pub fn validate(&self) -> RealIpResult<()> {
295 if let Self::Inline(config) = self
296 && config.cidrs.is_empty()
297 {
298 return Err(RealIpError::MissingProviderField {
299 provider: config.name.clone(),
300 field: "cidrs",
301 });
302 }
303 Ok(())
304 }
305}
306
307#[derive(Debug, Clone, Deserialize, Serialize)]
308pub struct InlineProviderConfig {
309 pub name: String,
310 pub cidrs: Vec<IpNet>,
311 #[serde(flatten, default)]
312 pub extra: BTreeMap<String, serde_json::Value>,
313}
314
315#[derive(Debug, Clone, Deserialize, Serialize)]
316pub struct LocalFileProviderConfig {
317 pub name: String,
318 pub path: PathBuf,
319 #[serde(default)]
320 pub watch: bool,
321 #[serde(default, with = "humantime_serde::option")]
322 pub debounce: Option<Duration>,
323 #[serde(default, with = "humantime_serde::option")]
324 pub max_stale: Option<Duration>,
325 #[serde(flatten, default)]
326 pub extra: BTreeMap<String, serde_json::Value>,
327}
328
329#[derive(Debug, Clone, Deserialize, Serialize)]
330pub struct RemoteFileProviderConfig {
331 pub name: String,
332 pub url: String,
333 #[serde(default, with = "humantime_serde::option")]
334 pub refresh: Option<Duration>,
335 #[serde(default, with = "humantime_serde::option")]
336 pub timeout: Option<Duration>,
337 #[serde(default)]
338 pub on_refresh_failure: RefreshFailurePolicy,
339 #[serde(default, with = "humantime_serde::option")]
340 pub max_stale: Option<Duration>,
341 #[serde(flatten, default)]
342 pub extra: BTreeMap<String, serde_json::Value>,
343}
344
345#[derive(Debug, Clone, Deserialize, Serialize)]
346pub struct CommandProviderConfig {
347 pub name: String,
348 pub command: String,
349 #[serde(default)]
350 pub args: Vec<String>,
351 #[serde(default, with = "humantime_serde::option")]
352 pub refresh: Option<Duration>,
353 #[serde(default, with = "humantime_serde::option")]
354 pub timeout: Option<Duration>,
355 #[serde(default)]
356 pub on_refresh_failure: RefreshFailurePolicy,
357 #[serde(default, with = "humantime_serde::option")]
358 pub max_stale: Option<Duration>,
359 #[serde(flatten, default)]
360 pub extra: BTreeMap<String, serde_json::Value>,
361}
362
363#[derive(Debug, Clone, Deserialize, Serialize)]
364pub struct CustomProviderConfig {
365 pub name: String,
366 pub kind: String,
367 #[serde(default, with = "humantime_serde::option")]
368 pub refresh: Option<Duration>,
369 #[serde(default, with = "humantime_serde::option")]
370 pub timeout: Option<Duration>,
371 #[serde(default)]
372 pub on_refresh_failure: RefreshFailurePolicy,
373 #[serde(default, with = "humantime_serde::option")]
374 pub max_stale: Option<Duration>,
375 #[serde(flatten, default)]
376 pub extra: BTreeMap<String, serde_json::Value>,
377}
378
379#[derive(Debug, Clone, Deserialize, Serialize)]
380pub struct SourceConfig {
381 pub name: String,
382 #[serde(default)]
383 pub priority: i32,
384 #[serde(default)]
385 pub peers_from: Vec<String>,
386 #[serde(default)]
387 pub accept_transport: Vec<TransportInputConfig>,
388 #[serde(default)]
389 pub accept_headers: Vec<HeaderInputConfig>,
390}
391
392#[derive(Debug, Clone, Deserialize, Serialize)]
393pub struct TransportInputConfig {
394 pub kind: String,
395}
396
397#[derive(Debug, Clone, Deserialize, Serialize)]
398pub struct HeaderInputConfig {
399 pub kind: String,
400 #[serde(default)]
401 pub mode: HeaderMode,
402 #[serde(default)]
403 pub direction: ChainDirection,
404 #[serde(default)]
405 pub param: Option<String>,
406 #[serde(default)]
407 pub use_only_if_not_in_trusted_peers: bool,
408}
409
410#[derive(Debug, Clone, Copy, Deserialize, Serialize, Default, PartialEq, Eq)]
411#[serde(rename_all = "kebab-case")]
412pub enum HeaderMode {
413 #[default]
414 Single,
415 Recursive,
416}
417
418#[derive(Debug, Clone, Copy, Deserialize, Serialize, Default, PartialEq, Eq)]
419#[serde(rename_all = "kebab-case")]
420pub enum ChainDirection {
421 LeftToRight,
422 #[default]
423 RightToLeft,
424}
425
426#[derive(Debug, Clone, Copy, Deserialize, Serialize, Default, PartialEq, Eq)]
427#[serde(rename_all = "kebab-case")]
428pub enum RefreshFailurePolicy {
429 #[default]
430 KeepLastGood,
431 Clear,
432}
433
434#[derive(Debug, Clone, Deserialize, Serialize)]
435pub struct FallbackConfig {
436 #[serde(default)]
437 pub strategy: FallbackStrategy,
438}
439
440impl Default for FallbackConfig {
441 fn default() -> Self {
442 Self {
443 strategy: FallbackStrategy::RemoteAddr,
444 }
445 }
446}
447
448#[derive(Debug, Clone, Copy, Deserialize, Serialize, Default, PartialEq, Eq)]
449#[serde(rename_all = "kebab-case")]
450pub enum FallbackStrategy {
451 #[default]
452 RemoteAddr,
453}
454
455pub(crate) fn parse_ip_or_cidr(entry: &str) -> Result<IpNet, ()> {
456 if let Ok(net) = entry.parse::<IpNet>() {
457 return Ok(net);
458 }
459 let addr = entry.parse::<IpAddr>().map_err(|_| ())?;
460 Ok(IpNet::from(addr))
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes one provider entry, panicking with context on failure.
    fn parse_provider(value: serde_json::Value) -> ProviderConfig {
        serde_json::from_value(value).expect("provider config should deserialize")
    }

    /// Deserializes a whole resolver config, panicking with context on failure.
    fn parse_config(value: serde_json::Value) -> RealIpResolveConfig {
        serde_json::from_value(value).expect("resolver config should deserialize")
    }

    /// Unwraps the custom variant, panicking if the entry resolved to a core kind.
    fn expect_custom(config: ProviderConfig) -> CustomProviderConfig {
        let ProviderConfig::Custom(custom) = config else {
            panic!("expected custom provider");
        };
        custom
    }

    /// Looks up a string-valued key in a custom provider's `extra` map.
    fn extra_str<'a>(custom: &'a CustomProviderConfig, key: &str) -> Option<&'a str> {
        custom.extra.get(key).and_then(serde_json::Value::as_str)
    }

    #[test]
    fn deserialize_docker_provider_as_custom_provider() {
        let custom = expect_custom(parse_provider(serde_json::json!({
            "name": "docker-ingress",
            "kind": "docker-provider",
            "host": "unix:///var/run/docker.sock",
            "networks": ["edge-ingress", "internal-proxy"],
            "refresh": "30s",
            "timeout": "5s",
            "on_refresh_failure": "keep-last-good",
            "max_stale": "10m"
        })));

        assert_eq!(custom.kind, "docker-provider");
        assert_eq!(custom.name, "docker-ingress");
        assert_eq!(custom.refresh, Some(Duration::from_secs(30)));
        assert_eq!(custom.timeout, Some(Duration::from_secs(5)));
        assert_eq!(custom.max_stale, Some(Duration::from_secs(600)));
        assert_eq!(custom.on_refresh_failure, RefreshFailurePolicy::KeepLastGood);
        assert_eq!(
            extra_str(&custom, "host"),
            Some("unix:///var/run/docker.sock")
        );
        assert_eq!(
            custom
                .extra
                .get("networks")
                .and_then(serde_json::Value::as_array)
                .map(Vec::len),
            Some(2)
        );
    }

    #[test]
    fn deserialize_kube_provider_as_custom_provider() {
        let custom = expect_custom(parse_provider(serde_json::json!({
            "name": "kube-ingress-pods",
            "kind": "kube-provider",
            "resource": "pods",
            "namespace": "ingress-nginx",
            "label_selector": "app.kubernetes.io/name=ingress-nginx",
            "refresh": "30s",
            "timeout": "5s"
        })));

        assert_eq!(custom.kind, "kube-provider");
        assert_eq!(custom.name, "kube-ingress-pods");
        assert_eq!(custom.refresh, Some(Duration::from_secs(30)));
        assert_eq!(custom.timeout, Some(Duration::from_secs(5)));
        assert_eq!(extra_str(&custom, "resource"), Some("pods"));
        assert_eq!(extra_str(&custom, "namespace"), Some("ingress-nginx"));
        assert_eq!(
            extra_str(&custom, "label_selector"),
            Some("app.kubernetes.io/name=ingress-nginx")
        );
    }

    #[test]
    fn deserialize_builtin_kind_as_core_provider() {
        let config = parse_provider(serde_json::json!({
            "name": "edge-lb",
            "kind": "inline",
            "cidrs": ["10.0.0.0/8", "192.168.1.1/32"]
        }));

        assert!(matches!(
            config,
            ProviderConfig::Core(CoreProviderConfig::Inline(_))
        ));
        assert_eq!(config.name(), "edge-lb");
        assert_eq!(config.kind(), "inline");
        assert_eq!(config.inline_cidrs().map(|cidrs| cidrs.len()), Some(2));
        assert!(config.custom().is_none());
        assert!(config.validate().is_ok());
    }

    #[test]
    fn deserialize_rejects_provider_without_kind() {
        let result =
            serde_json::from_value::<ProviderConfig>(serde_json::json!({ "name": "no-kind" }));
        assert!(result.is_err());
    }

    #[test]
    fn validate_rejects_inline_provider_without_cidrs() {
        let config = parse_provider(serde_json::json!({
            "name": "empty",
            "kind": "inline",
            "cidrs": []
        }));
        assert!(matches!(
            config.validate(),
            Err(RealIpError::MissingProviderField { .. })
        ));
    }

    #[test]
    fn validate_rejects_duplicate_provider_names() {
        let config = parse_config(serde_json::json!({
            "providers": [
                { "name": "lb", "kind": "inline", "cidrs": ["10.0.0.0/8"] },
                { "name": "lb", "kind": "inline", "cidrs": ["172.16.0.0/12"] }
            ]
        }));
        assert!(matches!(
            config.validate(),
            Err(RealIpError::Config { .. })
        ));
    }

    #[test]
    fn validate_rejects_unknown_source_provider() {
        let config = parse_config(serde_json::json!({
            "providers": [
                { "name": "lb", "kind": "inline", "cidrs": ["10.0.0.0/8"] }
            ],
            "sources": [
                { "name": "edge", "peers_from": ["lb", "missing"] }
            ]
        }));
        assert!(matches!(
            config.validate(),
            Err(RealIpError::UnknownSourceProvider { ref provider, .. })
                if provider.as_str() == "missing"
        ));
    }
}