1use std::{
2 collections::HashMap,
3 process::Stdio,
4 sync::Arc,
5 time::{Duration, Instant},
6};
7
8use arc_swap::ArcSwap;
9use ipnet::IpNet;
10use notify::{RecommendedWatcher, RecursiveMode, Watcher};
11use tokio::{
12 process::Command,
13 sync::watch,
14 task::JoinHandle,
15 time::{sleep, timeout},
16};
17use tracing::{debug, warn};
18
19use crate::{
20 config::{CoreProviderConfig, ProviderConfig, RefreshFailurePolicy, parse_ip_or_cidr},
21 error::{RealIpError, RealIpResult},
22 extension::{DynamicProvider, ProviderFactoryRegistry},
23};
24
25#[derive(Debug, Clone)]
26pub struct ProviderSnapshot {
27 pub cidrs: Arc<Vec<IpNet>>,
28 pub updated_at: Instant,
29 pub stale_after: Option<Duration>,
30}
31
32impl ProviderSnapshot {
33 fn new(cidrs: Vec<IpNet>, stale_after: Option<Duration>) -> Self {
34 Self {
35 cidrs: Arc::new(cidrs),
36 updated_at: Instant::now(),
37 stale_after,
38 }
39 }
40}
41
42#[derive(Debug)]
43pub struct ProviderRegistry {
44 state: Arc<ArcSwap<ProviderState>>,
45 tasks: Vec<JoinHandle<()>>,
46 _watchers: Vec<RecommendedWatcher>,
47}
48
49#[derive(Debug, Default)]
50struct ProviderState {
51 by_name: HashMap<String, ProviderSnapshot>,
52 all_cidrs: Vec<IpNet>,
53}
54
55impl ProviderRegistry {
56 pub async fn from_configs(configs: &[ProviderConfig]) -> RealIpResult<Self> {
57 let factories = ProviderFactoryRegistry::with_builtin_providers()?;
58 Self::from_configs_with_factories(configs, &factories).await
59 }
60
61 pub async fn from_configs_with_factories(
62 configs: &[ProviderConfig],
63 factories: &ProviderFactoryRegistry,
64 ) -> RealIpResult<Self> {
65 let mut by_name = HashMap::new();
66 let mut runtime_configs = Vec::with_capacity(configs.len());
67 let mut tasks = Vec::new();
68 let mut watchers = Vec::new();
69
70 for config in configs {
71 let custom_provider = build_custom_provider(config, factories)?;
72 let snapshot = load_provider(config, custom_provider.as_deref()).await?;
73 by_name.insert(config.name().to_string(), snapshot);
74 runtime_configs.push((config.clone(), custom_provider));
75 }
76
77 let state = Arc::new(ArcSwap::from_pointee(ProviderState {
78 all_cidrs: collect_all_cidrs(&by_name),
79 by_name,
80 }));
81
82 for (config, custom_provider) in runtime_configs {
83 if let Some(handle) =
84 spawn_refresh_task(config.clone(), custom_provider.clone(), state.clone())
85 {
86 tasks.push(handle);
87 }
88
89 if let Some(watcher) = spawn_file_watcher(config, state.clone())? {
90 watchers.push(watcher);
91 }
92 }
93
94 Ok(Self {
95 state,
96 tasks,
97 _watchers: watchers,
98 })
99 }
100
101 pub async fn snapshot(&self, name: &str) -> Option<ProviderSnapshot> {
102 self.state.load().by_name.get(name).cloned()
103 }
104
105 pub async fn all_cidrs(&self) -> Vec<IpNet> {
106 self.state.load().all_cidrs.clone()
107 }
108}
109
110impl Drop for ProviderRegistry {
111 fn drop(&mut self) {
112 for task in &self.tasks {
113 task.abort();
114 }
115 }
116}
117
118fn spawn_refresh_task(
119 config: ProviderConfig,
120 custom_provider: Option<Arc<dyn DynamicProvider>>,
121 state: Arc<ArcSwap<ProviderState>>,
122) -> Option<JoinHandle<()>> {
123 let refresh = config.refresh()?;
124
125 Some(tokio::spawn(async move {
126 loop {
127 sleep(refresh).await;
128 if let Err(error) = refresh_provider(&config, custom_provider.as_deref(), &state).await
129 {
130 warn!(provider = %config.name(), error = %error, "Failed to refresh real-ip provider");
131 }
132 }
133 }))
134}
135
136fn spawn_file_watcher(
137 config: ProviderConfig,
138 state: Arc<ArcSwap<ProviderState>>,
139) -> RealIpResult<Option<RecommendedWatcher>> {
140 let (path, debounce) = match config.watch_path() {
141 Some((path, debounce)) => (path.clone(), debounce),
142 None => return Ok(None),
143 };
144 let handle = tokio::runtime::Handle::current();
145 let (tx, mut rx) = watch::channel(());
146
147 let mut watcher = notify::recommended_watcher(move |event: notify::Result<notify::Event>| {
148 if event.is_ok() {
149 let _ = tx.send(());
150 }
151 })
152 .map_err(|error| RealIpError::WatchProvider {
153 path: path.clone(),
154 details: error.to_string(),
155 })?;
156 watcher
157 .watch(&path, RecursiveMode::NonRecursive)
158 .map_err(|error| RealIpError::WatchProvider {
159 path: path.clone(),
160 details: error.to_string(),
161 })?;
162
163 handle.spawn(async move {
164 while rx.changed().await.is_ok() {
165 sleep(debounce).await;
166 if let Err(error) = refresh_provider(&config, None, &state).await {
167 warn!(provider = %config.name(), error = %error, "Failed to refresh watched local-file provider");
168 }
169 }
170 });
171
172 Ok(Some(watcher))
173}
174
175async fn refresh_provider(
176 config: &ProviderConfig,
177 custom_provider: Option<&dyn DynamicProvider>,
178 state: &Arc<ArcSwap<ProviderState>>,
179) -> RealIpResult<()> {
180 match load_provider(config, custom_provider).await {
181 Ok(snapshot) => {
182 replace_provider_snapshot(state, config.name(), Some(snapshot));
183 debug!(provider = %config.name(), "Refreshed real-ip provider");
184 Ok(())
185 }
186 Err(error) => {
187 if matches!(config.on_refresh_failure(), RefreshFailurePolicy::Clear) {
188 replace_provider_snapshot(state, config.name(), None);
189 }
190 Err(error)
191 }
192 }
193}
194
195async fn load_provider(
196 config: &ProviderConfig,
197 custom_provider: Option<&dyn DynamicProvider>,
198) -> RealIpResult<ProviderSnapshot> {
199 let cidrs = match config {
200 ProviderConfig::Core(CoreProviderConfig::Inline(config)) => config.cidrs.clone(),
201 ProviderConfig::Core(CoreProviderConfig::LocalFile(_)) => {
202 parse_entries(config.name(), &read_local_file(config).await?)?
203 }
204 ProviderConfig::Core(CoreProviderConfig::RemoteFile(_)) => {
205 parse_entries(config.name(), &read_remote_file(config).await?)?
206 }
207 ProviderConfig::Core(CoreProviderConfig::Command(_)) => {
208 parse_entries(config.name(), &run_command_provider(config).await?)?
209 }
210 ProviderConfig::Custom(config) => {
211 custom_provider
212 .ok_or_else(|| RealIpError::MissingProviderFactory {
213 kind: config.kind.clone(),
214 })?
215 .load()
216 .await?
217 }
218 };
219
220 if cidrs.is_empty() {
221 return Err(RealIpError::EmptyProviderOutput {
222 provider: config.name().to_string(),
223 });
224 }
225
226 Ok(ProviderSnapshot::new(cidrs, config.max_stale()))
227}
228
229fn build_custom_provider(
230 config: &ProviderConfig,
231 factories: &ProviderFactoryRegistry,
232) -> RealIpResult<Option<Arc<dyn DynamicProvider>>> {
233 let Some(custom) = config.custom() else {
234 return Ok(None);
235 };
236 let Some(factory) = factories.get(&custom.kind) else {
237 return Err(RealIpError::MissingProviderFactory {
238 kind: custom.kind.clone(),
239 });
240 };
241 factory.create(custom).map(Some)
242}
243
244async fn read_local_file(config: &ProviderConfig) -> RealIpResult<String> {
245 let path = config.local_file_path().expect("validated path").clone();
246 tokio::fs::read_to_string(&path)
247 .await
248 .map_err(|source| RealIpError::ReadProviderFile { path, source })
249}
250
251async fn read_remote_file(config: &ProviderConfig) -> RealIpResult<String> {
252 let url = config.remote_file_url().expect("validated url").to_string();
253 let mut builder = reqwest::Client::builder();
254 if let Some(timeout) = config.timeout() {
255 builder = builder.timeout(timeout);
256 }
257 let client = builder
258 .build()
259 .map_err(|source| RealIpError::ProviderHttp {
260 url: url.clone(),
261 source,
262 })?;
263 let response = client
264 .get(&url)
265 .send()
266 .await
267 .and_then(reqwest::Response::error_for_status)
268 .map_err(|source| RealIpError::ProviderHttp {
269 url: url.clone(),
270 source,
271 })?;
272 response
273 .text()
274 .await
275 .map_err(|source| RealIpError::ProviderHttp { url, source })
276}
277
278async fn run_command_provider(config: &ProviderConfig) -> RealIpResult<String> {
279 let (command, args) = config.command_spec().expect("validated command");
280 let command = command.to_string();
281 let mut child = Command::new(&command);
282 child.args(args);
283 child.stdout(Stdio::piped());
284 child.stderr(Stdio::piped());
285
286 let output = if let Some(limit) = config.timeout() {
287 timeout(limit, child.output())
288 .await
289 .map_err(|_| RealIpError::ProviderCommand {
290 command: command.clone(),
291 details: format!("timed out after {:?}", limit),
292 })?
293 .map_err(|error| RealIpError::ProviderCommand {
294 command: command.clone(),
295 details: error.to_string(),
296 })?
297 } else {
298 child
299 .output()
300 .await
301 .map_err(|error| RealIpError::ProviderCommand {
302 command: command.clone(),
303 details: error.to_string(),
304 })?
305 };
306
307 if !output.status.success() {
308 let stderr = String::from_utf8_lossy(&output.stderr);
309 return Err(RealIpError::ProviderCommand {
310 command,
311 details: stderr.trim().to_string(),
312 });
313 }
314
315 Ok(String::from_utf8_lossy(&output.stdout).into_owned())
316}
317
318fn parse_entries(provider: &str, content: &str) -> RealIpResult<Vec<IpNet>> {
319 let mut cidrs = Vec::new();
320 for raw_line in content.lines() {
321 let line = raw_line.split('#').next().unwrap_or("").trim();
322 if line.is_empty() {
323 continue;
324 }
325
326 for entry in line
327 .split(|ch: char| ch == ',' || ch.is_ascii_whitespace())
328 .filter(|entry| !entry.is_empty())
329 {
330 let cidr = parse_ip_or_cidr(entry).map_err(|_| RealIpError::InvalidProviderEntry {
331 provider: provider.to_string(),
332 entry: entry.to_string(),
333 })?;
334 cidrs.push(cidr);
335 }
336 }
337
338 Ok(cidrs)
339}
340
341fn replace_provider_snapshot(
342 state: &Arc<ArcSwap<ProviderState>>,
343 name: &str,
344 snapshot: Option<ProviderSnapshot>,
345) {
346 let current = state.load();
347 let mut by_name = current.by_name.clone();
348 match snapshot {
349 Some(snapshot) => {
350 by_name.insert(name.to_string(), snapshot);
351 }
352 None => {
353 by_name.remove(name);
354 }
355 }
356 state.store(Arc::new(ProviderState {
357 all_cidrs: collect_all_cidrs(&by_name),
358 by_name,
359 }));
360}
361
362fn collect_all_cidrs(by_name: &HashMap<String, ProviderSnapshot>) -> Vec<IpNet> {
363 by_name
364 .values()
365 .flat_map(|snapshot| snapshot.cidrs.iter().copied())
366 .collect()
367}