Skip to main content

soli_proxy/config/
mod.rs

1pub mod serializer;
2
3use anyhow::Result;
4use arc_swap::ArcSwap;
5use notify::{RecommendedWatcher, RecursiveMode, Watcher};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use std::path::{Path, PathBuf};
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::Arc;
11use tokio::sync::mpsc;
12use url::Url;
13
14#[async_trait::async_trait]
15pub trait ConfigManagerTrait: Send + Sync {
16    async fn reload(&self) -> Result<()>;
17    fn get_config(&self) -> Arc<Config>;
18    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()>;
19    fn add_route(&self, rule: ProxyRule) -> Result<()>;
20    fn remove_route(&self, index: usize) -> Result<()>;
21}
22
23#[derive(Deserialize, Default, Clone, Debug)]
24pub struct TomlConfig {
25    #[serde(default)]
26    pub server: ServerConfig,
27    #[serde(default)]
28    pub tls: TlsConfig,
29    pub letsencrypt: Option<LetsEncryptConfig>,
30    pub scripting: Option<ScriptingTomlConfig>,
31    pub admin: Option<AdminConfig>,
32    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
33}
34
35#[derive(Deserialize, Serialize, Clone, Debug)]
36pub struct CircuitBreakerTomlConfig {
37    pub failure_threshold: Option<u32>,
38    pub recovery_timeout_secs: Option<u64>,
39    pub success_threshold: Option<u32>,
40    pub failure_status_codes: Option<Vec<u16>>,
41}
42
43#[derive(Deserialize, Serialize, Clone, Debug)]
44pub struct AdminConfig {
45    pub enabled: bool,
46    pub bind: String,
47    pub api_key: Option<String>,
48}
49
50impl Default for AdminConfig {
51    fn default() -> Self {
52        Self {
53            enabled: false,
54            bind: "127.0.0.1:9090".to_string(),
55            api_key: None,
56        }
57    }
58}
59
60#[derive(Deserialize, Serialize, Clone, Debug, Default)]
61pub struct ScriptingTomlConfig {
62    pub enabled: bool,
63    pub scripts_dir: Option<String>,
64    pub hook_timeout_ms: Option<u64>,
65}
66
67#[derive(Deserialize, Serialize, Default, Clone, Debug)]
68pub struct ServerConfig {
69    pub bind: String,
70    pub https_port: u16,
71}
72
73#[derive(Deserialize, Serialize, Default, Clone, Debug)]
74pub struct TlsConfig {
75    pub mode: String,
76    pub cache_dir: String,
77}
78
79#[derive(Deserialize, Serialize, Clone, Debug)]
80pub struct LetsEncryptConfig {
81    pub staging: bool,
82    pub email: String,
83    pub terms_agreed: bool,
84}
85
86#[derive(Clone, Debug, Serialize)]
87pub struct Config {
88    pub server: ServerConfig,
89    pub tls: TlsConfig,
90    pub letsencrypt: Option<LetsEncryptConfig>,
91    pub scripting: ScriptingTomlConfig,
92    pub admin: AdminConfig,
93    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
94    pub rules: Vec<ProxyRule>,
95    pub global_scripts: Vec<String>,
96}
97
98#[derive(Clone, Debug, Serialize, Deserialize)]
99pub struct ProxyRule {
100    pub matcher: RuleMatcher,
101    pub targets: Vec<Target>,
102    pub headers: Vec<HeaderRule>,
103    pub scripts: Vec<String>,
104}
105
106#[derive(Clone, Debug)]
107pub enum RuleMatcher {
108    Default,
109    Prefix(String),
110    Regex(RegexMatcher),
111    Exact(String),
112    Domain(String),
113    DomainPath(String, String),
114}
115
116/// Wrapper around Regex that stores the original pattern for serialization
117#[derive(Clone, Debug)]
118pub struct RegexMatcher {
119    pub pattern: String,
120    pub regex: Regex,
121}
122
123impl RegexMatcher {
124    pub fn new(pattern: &str) -> Result<Self> {
125        Ok(Self {
126            pattern: pattern.to_string(),
127            regex: Regex::new(pattern)?,
128        })
129    }
130
131    pub fn is_match(&self, text: &str) -> bool {
132        self.regex.is_match(text)
133    }
134}
135
136impl Serialize for RuleMatcher {
137    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
138    where
139        S: serde::Serializer,
140    {
141        use serde::ser::SerializeMap;
142        let mut map = serializer.serialize_map(Some(2))?;
143        match self {
144            RuleMatcher::Default => {
145                map.serialize_entry("type", "default")?;
146            }
147            RuleMatcher::Prefix(v) => {
148                map.serialize_entry("type", "prefix")?;
149                map.serialize_entry("value", v)?;
150            }
151            RuleMatcher::Regex(rm) => {
152                map.serialize_entry("type", "regex")?;
153                map.serialize_entry("value", &rm.pattern)?;
154            }
155            RuleMatcher::Exact(v) => {
156                map.serialize_entry("type", "exact")?;
157                map.serialize_entry("value", v)?;
158            }
159            RuleMatcher::Domain(v) => {
160                map.serialize_entry("type", "domain")?;
161                map.serialize_entry("value", v)?;
162            }
163            RuleMatcher::DomainPath(d, p) => {
164                map.serialize_entry("type", "domain_path")?;
165                map.serialize_entry("domain", d)?;
166                map.serialize_entry("path", p)?;
167            }
168        }
169        map.end()
170    }
171}
172
173impl<'de> Deserialize<'de> for RuleMatcher {
174    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
175    where
176        D: serde::Deserializer<'de>,
177    {
178        use serde::de::Error;
179        let value: serde_json::Value = Deserialize::deserialize(deserializer)?;
180        let obj = value
181            .as_object()
182            .ok_or_else(|| D::Error::custom("expected object"))?;
183        let matcher_type = obj
184            .get("type")
185            .and_then(|v| v.as_str())
186            .ok_or_else(|| D::Error::custom("missing 'type' field"))?;
187
188        match matcher_type {
189            "default" => Ok(RuleMatcher::Default),
190            "exact" => {
191                let v = obj
192                    .get("value")
193                    .and_then(|v| v.as_str())
194                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
195                Ok(RuleMatcher::Exact(v.to_string()))
196            }
197            "prefix" => {
198                let v = obj
199                    .get("value")
200                    .and_then(|v| v.as_str())
201                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
202                Ok(RuleMatcher::Prefix(v.to_string()))
203            }
204            "regex" => {
205                let v = obj
206                    .get("value")
207                    .and_then(|v| v.as_str())
208                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
209                let rm = RegexMatcher::new(v)
210                    .map_err(|e| D::Error::custom(format!("invalid regex: {}", e)))?;
211                Ok(RuleMatcher::Regex(rm))
212            }
213            "domain" => {
214                let v = obj
215                    .get("value")
216                    .and_then(|v| v.as_str())
217                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
218                Ok(RuleMatcher::Domain(v.to_string()))
219            }
220            "domain_path" => {
221                let d = obj
222                    .get("domain")
223                    .and_then(|v| v.as_str())
224                    .ok_or_else(|| D::Error::custom("missing 'domain'"))?;
225                let p = obj
226                    .get("path")
227                    .and_then(|v| v.as_str())
228                    .ok_or_else(|| D::Error::custom("missing 'path'"))?;
229                Ok(RuleMatcher::DomainPath(d.to_string(), p.to_string()))
230            }
231            other => Err(D::Error::custom(format!("unknown matcher type: {}", other))),
232        }
233    }
234}
235
236#[derive(Clone, Debug, Serialize, Deserialize)]
237pub struct Target {
238    pub url: Url,
239    pub weight: u8,
240}
241
242#[derive(Clone, Debug, Serialize, Deserialize)]
243pub struct HeaderRule {
244    pub name: String,
245    pub value: String,
246}
247
248impl Config {
249    /// Extract unique domain names from Domain and DomainPath rules,
250    /// filtering out IPs and "localhost".
251    pub fn acme_domains(&self) -> Vec<String> {
252        let mut domains = Vec::new();
253        let mut seen = std::collections::HashSet::new();
254
255        for rule in &self.rules {
256            let domain = match &rule.matcher {
257                RuleMatcher::Domain(d) => Some(d.as_str()),
258                RuleMatcher::DomainPath(d, _) => Some(d.as_str()),
259                _ => None,
260            };
261
262            if let Some(d) = domain {
263                if d == "localhost" || d.parse::<std::net::IpAddr>().is_ok() {
264                    continue;
265                }
266                if seen.insert(d.to_string()) {
267                    domains.push(d.to_string());
268                }
269            }
270        }
271
272        domains
273    }
274}
275
276pub struct ConfigManager {
277    config: ArcSwap<Config>,
278    config_path: PathBuf,
279    _watcher: Option<RecommendedWatcher>,
280    suppress_watch: Arc<AtomicBool>,
281}
282
283impl Clone for ConfigManager {
284    fn clone(&self) -> Self {
285        Self {
286            config: ArcSwap::new(self.config.load().clone()),
287            config_path: self.config_path.clone(),
288            _watcher: None,
289            suppress_watch: self.suppress_watch.clone(),
290        }
291    }
292}
293
294impl ConfigManager {
295    pub fn new(config_path: &str) -> Result<Self> {
296        let path = PathBuf::from(config_path);
297        let config = Self::load_config(&path, &path)?;
298        Ok(Self {
299            config: ArcSwap::new(Arc::new(config)),
300            config_path: path,
301            _watcher: None,
302            suppress_watch: Arc::new(AtomicBool::new(false)),
303        })
304    }
305
306    pub fn config_path(&self) -> &Path {
307        &self.config_path
308    }
309
310    pub fn suppress_watch(&self) -> &Arc<AtomicBool> {
311        &self.suppress_watch
312    }
313
314    fn load_config(proxy_conf_path: &Path, config_path: &Path) -> Result<Config> {
315        let content = std::fs::read_to_string(proxy_conf_path).unwrap_or_default();
316        let (rules, global_scripts) = parse_proxy_config(&content)?;
317        let toml_content = std::fs::read_to_string(
318            config_path
319                .parent()
320                .unwrap_or(Path::new("."))
321                .join("config.toml"),
322        )
323        .ok();
324        let toml_config: TomlConfig = toml_content
325            .as_ref()
326            .and_then(|c| toml::from_str(c).ok())
327            .unwrap_or_default();
328
329        Ok(Config {
330            server: toml_config.server,
331            tls: toml_config.tls,
332            letsencrypt: toml_config.letsencrypt,
333            scripting: toml_config.scripting.unwrap_or_default(),
334            admin: toml_config.admin.unwrap_or_default(),
335            circuit_breaker: toml_config.circuit_breaker,
336            rules,
337            global_scripts,
338        })
339    }
340
341    pub fn get_config(&self) -> Arc<Config> {
342        self.config.load().clone()
343    }
344
345    pub fn start_watcher(&self) -> Result<()> {
346        // Ensure the file exists so the watcher has something to watch
347        if !self.config_path.exists() {
348            if let Some(parent) = self.config_path.parent() {
349                std::fs::create_dir_all(parent).ok();
350            }
351            std::fs::write(&self.config_path, "")?;
352        }
353
354        let (tx, mut rx) = mpsc::channel(1);
355        let config_path = self.config_path.clone();
356        let suppress = self.suppress_watch.clone();
357
358        let mut watcher = RecommendedWatcher::new(
359            move |res| {
360                let _ = tx.blocking_send(res);
361            },
362            notify::Config::default(),
363        )?;
364
365        watcher.watch(&config_path, RecursiveMode::NonRecursive)?;
366
367        tracing::info!("Watching config file: {}", config_path.display());
368
369        std::thread::spawn(move || {
370            while let Some(res) = rx.blocking_recv() {
371                match res {
372                    Ok(event) => {
373                        if event.kind.is_modify() {
374                            if suppress.swap(false, Ordering::SeqCst) {
375                                tracing::debug!(
376                                    "Suppressing file watcher reload (admin API write)"
377                                );
378                                continue;
379                            }
380                            tracing::info!("Config file changed, reloading...");
381                        }
382                    }
383                    Err(e) => tracing::error!("Watch error: {}", e),
384                }
385            }
386        });
387
388        Ok(())
389    }
390
391    pub async fn reload(&self) -> Result<()> {
392        let new_config = Self::load_config(&self.config_path, &self.config_path)?;
393        self.config.store(Arc::new(new_config));
394        tracing::info!("Configuration reloaded successfully");
395        Ok(())
396    }
397
398    /// Persist current rules to proxy.conf and swap in-memory config
399    fn persist_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
400        let content = serializer::serialize_proxy_conf(&rules, &global_scripts);
401        self.suppress_watch.store(true, Ordering::SeqCst);
402        std::fs::write(&self.config_path, &content)?;
403        let mut config = (*self.config.load().as_ref()).clone();
404        config.rules = rules;
405        config.global_scripts = global_scripts;
406        self.config.store(Arc::new(config));
407        tracing::info!("Configuration persisted to {}", self.config_path.display());
408        Ok(())
409    }
410
411    pub fn add_route(&self, rule: ProxyRule) -> Result<()> {
412        let cfg = self.get_config();
413        let mut rules = cfg.rules.clone();
414        rules.push(rule);
415        self.persist_rules(rules, cfg.global_scripts.clone())
416    }
417
418    pub fn update_route(&self, index: usize, rule: ProxyRule) -> Result<()> {
419        let cfg = self.get_config();
420        let mut rules = cfg.rules.clone();
421        if index >= rules.len() {
422            anyhow::bail!(
423                "Route index {} out of range (have {} routes)",
424                index,
425                rules.len()
426            );
427        }
428        rules[index] = rule;
429        self.persist_rules(rules, cfg.global_scripts.clone())
430    }
431
432    pub fn remove_route(&self, index: usize) -> Result<()> {
433        let cfg = self.get_config();
434        let mut rules = cfg.rules.clone();
435        if index >= rules.len() {
436            anyhow::bail!(
437                "Route index {} out of range (have {} routes)",
438                index,
439                rules.len()
440            );
441        }
442        rules.remove(index);
443        self.persist_rules(rules, cfg.global_scripts.clone())
444    }
445
446    pub fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
447        self.persist_rules(rules, global_scripts)
448    }
449}
450
451#[async_trait::async_trait]
452impl ConfigManagerTrait for ConfigManager {
453    async fn reload(&self) -> Result<()> {
454        self.reload().await
455    }
456
457    fn get_config(&self) -> Arc<Config> {
458        self.get_config()
459    }
460
461    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
462        self.update_rules(rules, global_scripts)
463    }
464
465    fn add_route(&self, rule: ProxyRule) -> Result<()> {
466        self.add_route(rule)
467    }
468
469    fn remove_route(&self, index: usize) -> Result<()> {
470        self.remove_route(index)
471    }
472}
473
/// Split an `@script:a.lua,b.lua` annotation off `s`.
///
/// Returns `(text_before_marker, scripts)`. The text before `@script:` is
/// trimmed. The script list is the comma-separated run that follows the marker
/// (leading whitespace skipped) and ends at the next whitespace. Anything after
/// that run is discarded, as are empty list entries. When `s` has no marker it
/// is returned unchanged with an empty list.
fn extract_scripts(s: &str) -> (&str, Vec<String>) {
    const MARKER: &str = "@script:";

    match s.find(MARKER) {
        None => (s, Vec::new()),
        Some(pos) => {
            let tail = &s[pos + MARKER.len()..];
            // First whitespace-delimited token, or the (blank) tail itself.
            let list = tail.split_whitespace().next().unwrap_or(tail);
            let scripts = list
                .split(',')
                .map(str::trim)
                .filter(|name| !name.is_empty())
                .map(String::from)
                .collect();
            (s[..pos].trim(), scripts)
        }
    }
}
491
492fn parse_proxy_config(content: &str) -> Result<(Vec<ProxyRule>, Vec<String>)> {
493    let mut rules = Vec::new();
494    let mut global_scripts = Vec::new();
495
496    // Join continuation lines (backslash at end of line)
497    let mut joined_lines: Vec<String> = Vec::new();
498    for line in content.lines() {
499        if let Some(current) = joined_lines.last_mut() {
500            if current.ends_with('\\') {
501                current.pop(); // remove the backslash
502                current.push_str(line.trim());
503                continue;
504            }
505        }
506        joined_lines.push(line.to_string());
507    }
508
509    for line in &joined_lines {
510        let trimmed = line.trim();
511        if trimmed.is_empty() || trimmed.starts_with('#') {
512            continue;
513        }
514
515        // Handle [global] @script:cors.lua,logging.lua
516        if trimmed.starts_with("[global]") {
517            let rest = trimmed.strip_prefix("[global]").unwrap().trim();
518            let (_, scripts) = extract_scripts(rest);
519            global_scripts.extend(scripts);
520            continue;
521        }
522
523        if let Some((source, target_str)) = trimmed.split_once("->") {
524            let source = source.trim();
525            // Extract @script: from the target side
526            let (target_str, route_scripts) = extract_scripts(target_str.trim());
527
528            let matcher = if source == "default" || source == "*" {
529                RuleMatcher::Default
530            } else if let Some(pattern) = source.strip_prefix("~") {
531                RuleMatcher::Regex(RegexMatcher::new(pattern)?)
532            } else if !source.starts_with('/')
533                && (source.contains('.') || source.parse::<std::net::IpAddr>().is_ok())
534            {
535                if let Some((domain, path)) = source.split_once('/') {
536                    if path.is_empty() || path == "*" {
537                        RuleMatcher::Domain(domain.to_string())
538                    } else if path.ends_with("/*") {
539                        RuleMatcher::DomainPath(
540                            domain.to_string(),
541                            path.trim_end_matches('*').to_string(),
542                        )
543                    } else {
544                        RuleMatcher::DomainPath(domain.to_string(), path.to_string())
545                    }
546                } else {
547                    RuleMatcher::Domain(source.to_string())
548                }
549            } else if source.ends_with("/*") {
550                RuleMatcher::Prefix(source.trim_end_matches('*').to_string())
551            } else {
552                RuleMatcher::Exact(source.to_string())
553            };
554
555            let targets: Vec<Target> = target_str
556                .split(',')
557                .map(|t| {
558                    Ok(Target {
559                        url: Url::parse(t.trim())?,
560                        weight: 100,
561                    })
562                })
563                .collect::<Result<Vec<_>>>()?;
564
565            rules.push(ProxyRule {
566                matcher,
567                targets,
568                headers: vec![],
569                scripts: route_scripts,
570            });
571        }
572    }
573
574    Ok((rules, global_scripts))
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    // --- continuation lines ---

    #[test]
    fn test_backslash_continuation_joins_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n          http://backend2:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://backend1:8080/");
        assert_eq!(rules[0].targets[1].url.as_str(), "http://backend2:8080/");
    }

    #[test]
    fn test_multiple_continuation_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n\
                      http://backend2:8080, \\\n\
                      http://backend3:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
        assert_eq!(rules[0].targets[2].url.as_str(), "http://backend3:8080/");
    }

    #[test]
    fn test_backslash_mid_line_not_continuation() {
        let config = "/path -> http://localhost:8080\n\
                      ~^/foo\\dbar$ -> http://localhost:9090\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    #[test]
    fn test_continuation_trims_whitespace() {
        let config = "/api/* -> http://a:8080,   \\\n   http://b:8080,  \\\n   http://c:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
    }

    #[test]
    fn test_continuation_with_scripts() {
        let config = "/api/* -> http://a:8080, \\\n\
                      http://b:8080 @script:auth.lua\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    #[test]
    fn test_no_continuation_normal_config() {
        let config = "/api/* -> http://backend:8080\ndefault -> http://localhost:3000\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    // --- general parsing ---

    #[test]
    fn test_comments_and_blank_lines_are_ignored() {
        let (rules, global) = parse_proxy_config("# comment\n\n   \n# another\n").unwrap();
        assert!(rules.is_empty());
        assert!(global.is_empty());
    }

    #[test]
    fn test_global_scripts_are_collected() {
        let config = "[global] @script:cors.lua,logging.lua\n/a -> http://a:1\n";
        let (rules, global) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(global, vec!["cors.lua", "logging.lua"]);
    }

    #[test]
    fn test_route_scripts_and_default_weight() {
        let config = "/a -> http://a:1, http://b:2 @script:auth.lua,rate.lua\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules[0].scripts, vec!["auth.lua", "rate.lua"]);
        assert_eq!(rules[0].targets.len(), 2);
        assert!(rules[0].targets.iter().all(|t| t.weight == 100));
        assert!(rules[0].headers.is_empty());
    }

    #[test]
    fn test_matcher_classification() {
        let config = "default -> http://a:1\n\
                      /exact -> http://a:1\n\
                      /api/* -> http://a:1\n\
                      ~^/v[0-9]+/ -> http://a:1\n\
                      example.com -> http://a:1\n\
                      example.com/app/* -> http://a:1\n\
                      * -> http://a:1\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 7);
        assert!(matches!(rules[0].matcher, RuleMatcher::Default));
        assert!(matches!(&rules[1].matcher, RuleMatcher::Exact(p) if p == "/exact"));
        assert!(matches!(&rules[2].matcher, RuleMatcher::Prefix(p) if p == "/api/"));
        match &rules[3].matcher {
            RuleMatcher::Regex(rm) => {
                assert!(rm.is_match("/v2/users"));
                assert!(!rm.is_match("/users"));
            }
            other => panic!("expected regex matcher, got {:?}", other),
        }
        assert!(matches!(&rules[4].matcher, RuleMatcher::Domain(d) if d == "example.com"));
        assert!(matches!(&rules[5].matcher, RuleMatcher::DomainPath(d, _) if d == "example.com"));
        assert!(matches!(rules[6].matcher, RuleMatcher::Default));
    }

    #[test]
    fn test_invalid_rules_are_errors() {
        assert!(parse_proxy_config("/api -> not a url\n").is_err());
        assert!(parse_proxy_config("~( -> http://a:1\n").is_err());
    }

    #[test]
    fn test_extract_scripts_stops_at_whitespace() {
        let (rest, scripts) = extract_scripts("http://a @script:x.lua,y.lua extra");
        assert_eq!(rest, "http://a");
        assert_eq!(scripts, vec!["x.lua", "y.lua"]);
    }

    // --- ACME domain selection ---

    #[test]
    fn test_acme_domains_skips_ips_localhost_and_duplicates() {
        let config = "example.com -> http://a:1\n\
                      example.com/api/* -> http://a:1\n\
                      api.example.org -> http://a:1\n\
                      10.0.0.1 -> http://a:1\n\
                      /local -> http://a:1\n";
        let (mut rules, global_scripts) = parse_proxy_config(config).unwrap();
        rules.push(ProxyRule {
            matcher: RuleMatcher::Domain("localhost".to_string()),
            targets: vec![],
            headers: vec![],
            scripts: vec![],
        });
        let cfg = Config {
            server: ServerConfig::default(),
            tls: TlsConfig::default(),
            letsencrypt: None,
            scripting: ScriptingTomlConfig::default(),
            admin: AdminConfig::default(),
            circuit_breaker: None,
            rules,
            global_scripts,
        };
        assert_eq!(cfg.acme_domains(), vec!["example.com", "api.example.org"]);
    }
}