Skip to main content

securitydept_realip/
providers.rs

1use std::{
2    collections::HashMap,
3    process::Stdio,
4    sync::Arc,
5    time::{Duration, Instant},
6};
7
8use arc_swap::ArcSwap;
9use ipnet::IpNet;
10use notify::{RecommendedWatcher, RecursiveMode, Watcher};
11use tokio::{
12    process::Command,
13    sync::watch,
14    task::JoinHandle,
15    time::{sleep, timeout},
16};
17use tracing::{debug, warn};
18
19use crate::{
20    config::{CoreProviderConfig, ProviderConfig, RefreshFailurePolicy, parse_ip_or_cidr},
21    error::{RealIpError, RealIpResult},
22    extension::{DynamicProvider, ProviderFactoryRegistry},
23};
24
25#[derive(Debug, Clone)]
26pub struct ProviderSnapshot {
27    pub cidrs: Arc<Vec<IpNet>>,
28    pub updated_at: Instant,
29    pub stale_after: Option<Duration>,
30}
31
32impl ProviderSnapshot {
33    fn new(cidrs: Vec<IpNet>, stale_after: Option<Duration>) -> Self {
34        Self {
35            cidrs: Arc::new(cidrs),
36            updated_at: Instant::now(),
37            stale_after,
38        }
39    }
40}
41
42#[derive(Debug)]
43pub struct ProviderRegistry {
44    state: Arc<ArcSwap<ProviderState>>,
45    tasks: Vec<JoinHandle<()>>,
46    _watchers: Vec<RecommendedWatcher>,
47}
48
49#[derive(Debug, Default)]
50struct ProviderState {
51    by_name: HashMap<String, ProviderSnapshot>,
52    all_cidrs: Vec<IpNet>,
53}
54
55impl ProviderRegistry {
56    pub async fn from_configs(configs: &[ProviderConfig]) -> RealIpResult<Self> {
57        let factories = ProviderFactoryRegistry::with_builtin_providers()?;
58        Self::from_configs_with_factories(configs, &factories).await
59    }
60
61    pub async fn from_configs_with_factories(
62        configs: &[ProviderConfig],
63        factories: &ProviderFactoryRegistry,
64    ) -> RealIpResult<Self> {
65        let mut by_name = HashMap::new();
66        let mut runtime_configs = Vec::with_capacity(configs.len());
67        let mut tasks = Vec::new();
68        let mut watchers = Vec::new();
69
70        for config in configs {
71            let custom_provider = build_custom_provider(config, factories)?;
72            let snapshot = load_provider(config, custom_provider.as_deref()).await?;
73            by_name.insert(config.name().to_string(), snapshot);
74            runtime_configs.push((config.clone(), custom_provider));
75        }
76
77        let state = Arc::new(ArcSwap::from_pointee(ProviderState {
78            all_cidrs: collect_all_cidrs(&by_name),
79            by_name,
80        }));
81
82        for (config, custom_provider) in runtime_configs {
83            if let Some(handle) =
84                spawn_refresh_task(config.clone(), custom_provider.clone(), state.clone())
85            {
86                tasks.push(handle);
87            }
88
89            if let Some(watcher) = spawn_file_watcher(config, state.clone())? {
90                watchers.push(watcher);
91            }
92        }
93
94        Ok(Self {
95            state,
96            tasks,
97            _watchers: watchers,
98        })
99    }
100
101    pub async fn snapshot(&self, name: &str) -> Option<ProviderSnapshot> {
102        self.state.load().by_name.get(name).cloned()
103    }
104
105    pub async fn all_cidrs(&self) -> Vec<IpNet> {
106        self.state.load().all_cidrs.clone()
107    }
108}
109
110impl Drop for ProviderRegistry {
111    fn drop(&mut self) {
112        for task in &self.tasks {
113            task.abort();
114        }
115    }
116}
117
118fn spawn_refresh_task(
119    config: ProviderConfig,
120    custom_provider: Option<Arc<dyn DynamicProvider>>,
121    state: Arc<ArcSwap<ProviderState>>,
122) -> Option<JoinHandle<()>> {
123    let refresh = config.refresh()?;
124
125    Some(tokio::spawn(async move {
126        loop {
127            sleep(refresh).await;
128            if let Err(error) = refresh_provider(&config, custom_provider.as_deref(), &state).await
129            {
130                warn!(provider = %config.name(), error = %error, "Failed to refresh real-ip provider");
131            }
132        }
133    }))
134}
135
136fn spawn_file_watcher(
137    config: ProviderConfig,
138    state: Arc<ArcSwap<ProviderState>>,
139) -> RealIpResult<Option<RecommendedWatcher>> {
140    let (path, debounce) = match config.watch_path() {
141        Some((path, debounce)) => (path.clone(), debounce),
142        None => return Ok(None),
143    };
144    let handle = tokio::runtime::Handle::current();
145    let (tx, mut rx) = watch::channel(());
146
147    let mut watcher = notify::recommended_watcher(move |event: notify::Result<notify::Event>| {
148        if event.is_ok() {
149            let _ = tx.send(());
150        }
151    })
152    .map_err(|error| RealIpError::WatchProvider {
153        path: path.clone(),
154        details: error.to_string(),
155    })?;
156    watcher
157        .watch(&path, RecursiveMode::NonRecursive)
158        .map_err(|error| RealIpError::WatchProvider {
159            path: path.clone(),
160            details: error.to_string(),
161        })?;
162
163    handle.spawn(async move {
164        while rx.changed().await.is_ok() {
165            sleep(debounce).await;
166            if let Err(error) = refresh_provider(&config, None, &state).await {
167                warn!(provider = %config.name(), error = %error, "Failed to refresh watched local-file provider");
168            }
169        }
170    });
171
172    Ok(Some(watcher))
173}
174
175async fn refresh_provider(
176    config: &ProviderConfig,
177    custom_provider: Option<&dyn DynamicProvider>,
178    state: &Arc<ArcSwap<ProviderState>>,
179) -> RealIpResult<()> {
180    match load_provider(config, custom_provider).await {
181        Ok(snapshot) => {
182            replace_provider_snapshot(state, config.name(), Some(snapshot));
183            debug!(provider = %config.name(), "Refreshed real-ip provider");
184            Ok(())
185        }
186        Err(error) => {
187            if matches!(config.on_refresh_failure(), RefreshFailurePolicy::Clear) {
188                replace_provider_snapshot(state, config.name(), None);
189            }
190            Err(error)
191        }
192    }
193}
194
195async fn load_provider(
196    config: &ProviderConfig,
197    custom_provider: Option<&dyn DynamicProvider>,
198) -> RealIpResult<ProviderSnapshot> {
199    let cidrs = match config {
200        ProviderConfig::Core(CoreProviderConfig::Inline(config)) => config.cidrs.clone(),
201        ProviderConfig::Core(CoreProviderConfig::LocalFile(_)) => {
202            parse_entries(config.name(), &read_local_file(config).await?)?
203        }
204        ProviderConfig::Core(CoreProviderConfig::RemoteFile(_)) => {
205            parse_entries(config.name(), &read_remote_file(config).await?)?
206        }
207        ProviderConfig::Core(CoreProviderConfig::Command(_)) => {
208            parse_entries(config.name(), &run_command_provider(config).await?)?
209        }
210        ProviderConfig::Custom(config) => {
211            custom_provider
212                .ok_or_else(|| RealIpError::MissingProviderFactory {
213                    kind: config.kind.clone(),
214                })?
215                .load()
216                .await?
217        }
218    };
219
220    if cidrs.is_empty() {
221        return Err(RealIpError::EmptyProviderOutput {
222            provider: config.name().to_string(),
223        });
224    }
225
226    Ok(ProviderSnapshot::new(cidrs, config.max_stale()))
227}
228
229fn build_custom_provider(
230    config: &ProviderConfig,
231    factories: &ProviderFactoryRegistry,
232) -> RealIpResult<Option<Arc<dyn DynamicProvider>>> {
233    let Some(custom) = config.custom() else {
234        return Ok(None);
235    };
236    let Some(factory) = factories.get(&custom.kind) else {
237        return Err(RealIpError::MissingProviderFactory {
238            kind: custom.kind.clone(),
239        });
240    };
241    factory.create(custom).map(Some)
242}
243
244async fn read_local_file(config: &ProviderConfig) -> RealIpResult<String> {
245    let path = config.local_file_path().expect("validated path").clone();
246    tokio::fs::read_to_string(&path)
247        .await
248        .map_err(|source| RealIpError::ReadProviderFile { path, source })
249}
250
251async fn read_remote_file(config: &ProviderConfig) -> RealIpResult<String> {
252    let url = config.remote_file_url().expect("validated url").to_string();
253    let mut builder = reqwest::Client::builder();
254    if let Some(timeout) = config.timeout() {
255        builder = builder.timeout(timeout);
256    }
257    let client = builder
258        .build()
259        .map_err(|source| RealIpError::ProviderHttp {
260            url: url.clone(),
261            source,
262        })?;
263    let response = client
264        .get(&url)
265        .send()
266        .await
267        .and_then(reqwest::Response::error_for_status)
268        .map_err(|source| RealIpError::ProviderHttp {
269            url: url.clone(),
270            source,
271        })?;
272    response
273        .text()
274        .await
275        .map_err(|source| RealIpError::ProviderHttp { url, source })
276}
277
278async fn run_command_provider(config: &ProviderConfig) -> RealIpResult<String> {
279    let (command, args) = config.command_spec().expect("validated command");
280    let command = command.to_string();
281    let mut child = Command::new(&command);
282    child.args(args);
283    child.stdout(Stdio::piped());
284    child.stderr(Stdio::piped());
285
286    let output = if let Some(limit) = config.timeout() {
287        timeout(limit, child.output())
288            .await
289            .map_err(|_| RealIpError::ProviderCommand {
290                command: command.clone(),
291                details: format!("timed out after {:?}", limit),
292            })?
293            .map_err(|error| RealIpError::ProviderCommand {
294                command: command.clone(),
295                details: error.to_string(),
296            })?
297    } else {
298        child
299            .output()
300            .await
301            .map_err(|error| RealIpError::ProviderCommand {
302                command: command.clone(),
303                details: error.to_string(),
304            })?
305    };
306
307    if !output.status.success() {
308        let stderr = String::from_utf8_lossy(&output.stderr);
309        return Err(RealIpError::ProviderCommand {
310            command,
311            details: stderr.trim().to_string(),
312        });
313    }
314
315    Ok(String::from_utf8_lossy(&output.stdout).into_owned())
316}
317
318fn parse_entries(provider: &str, content: &str) -> RealIpResult<Vec<IpNet>> {
319    let mut cidrs = Vec::new();
320    for raw_line in content.lines() {
321        let line = raw_line.split('#').next().unwrap_or("").trim();
322        if line.is_empty() {
323            continue;
324        }
325
326        for entry in line
327            .split(|ch: char| ch == ',' || ch.is_ascii_whitespace())
328            .filter(|entry| !entry.is_empty())
329        {
330            let cidr = parse_ip_or_cidr(entry).map_err(|_| RealIpError::InvalidProviderEntry {
331                provider: provider.to_string(),
332                entry: entry.to_string(),
333            })?;
334            cidrs.push(cidr);
335        }
336    }
337
338    Ok(cidrs)
339}
340
341fn replace_provider_snapshot(
342    state: &Arc<ArcSwap<ProviderState>>,
343    name: &str,
344    snapshot: Option<ProviderSnapshot>,
345) {
346    let current = state.load();
347    let mut by_name = current.by_name.clone();
348    match snapshot {
349        Some(snapshot) => {
350            by_name.insert(name.to_string(), snapshot);
351        }
352        None => {
353            by_name.remove(name);
354        }
355    }
356    state.store(Arc::new(ProviderState {
357        all_cidrs: collect_all_cidrs(&by_name),
358        by_name,
359    }));
360}
361
362fn collect_all_cidrs(by_name: &HashMap<String, ProviderSnapshot>) -> Vec<IpNet> {
363    by_name
364        .values()
365        .flat_map(|snapshot| snapshot.cidrs.iter().copied())
366        .collect()
367}