Skip to main content

soli_proxy/config/
mod.rs

1pub mod serializer;
2
3use crate::auth::BasicAuth;
4use anyhow::Result;
5use arc_swap::ArcSwap;
6use notify::{RecommendedWatcher, RecursiveMode, Watcher};
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use url::Url;
14
/// Operations a configuration manager exposes to the rest of the proxy.
///
/// `ConfigManager` in this file implements the trait by delegating each
/// method to its inherent method of the same name.
#[async_trait::async_trait]
pub trait ConfigManagerTrait: Send + Sync {
    /// Re-read the configuration from its backing store and make it current.
    async fn reload(&self) -> Result<()>;
    /// Return a shared snapshot of the current configuration.
    fn get_config(&self) -> Arc<Config>;
    /// Replace the whole rule list and the global script list.
    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()>;
    /// Append one routing rule.
    fn add_route(&self, rule: ProxyRule) -> Result<()>;
    /// Remove the routing rule at `index` (position in `Config::rules`).
    fn remove_route(&self, index: usize) -> Result<()>;
}
23
/// Deserialized shape of the `config.toml` file read by
/// `ConfigManager::load_config`.
///
/// `server` and `tls` fall back to their `Default` values when the table is
/// absent; the remaining sections are optional.
#[derive(Deserialize, Default, Clone, Debug)]
pub struct TomlConfig {
    #[serde(default)]
    pub server: ServerConfig,
    #[serde(default)]
    pub tls: TlsConfig,
    pub letsencrypt: Option<LetsEncryptConfig>,
    pub scripting: Option<ScriptingTomlConfig>,
    pub admin: Option<AdminConfig>,
    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
}
35
/// Optional `[circuit_breaker]` section of `config.toml`.
///
/// Every field is optional; this file only stores the values and does not
/// define what is used when a field is left out.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct CircuitBreakerTomlConfig {
    pub failure_threshold: Option<u32>,
    // Unit is seconds, per the field name.
    pub recovery_timeout_secs: Option<u64>,
    pub success_threshold: Option<u32>,
    // NOTE(review): presumably HTTP status codes counted as failures — confirm against the consumer.
    pub failure_status_codes: Option<Vec<u16>>,
}
43
44#[derive(Deserialize, Serialize, Clone, Debug)]
45pub struct AdminConfig {
46    pub enabled: bool,
47    pub bind: String,
48    pub api_key: Option<String>,
49}
50
51impl Default for AdminConfig {
52    fn default() -> Self {
53        Self {
54            enabled: false,
55            bind: "127.0.0.1:9090".to_string(),
56            api_key: None,
57        }
58    }
59}
60
61#[derive(Deserialize, Serialize, Clone, Debug, Default)]
62pub struct ScriptingTomlConfig {
63    pub enabled: bool,
64    pub scripts_dir: Option<String>,
65    pub hook_timeout_ms: Option<u64>,
66}
67
68#[derive(Deserialize, Serialize, Default, Clone, Debug)]
69pub struct ServerConfig {
70    pub bind: String,
71    pub https_port: u16,
72}
73
74#[derive(Deserialize, Serialize, Default, Clone, Debug)]
75pub struct TlsConfig {
76    pub mode: String,
77    pub cache_dir: String,
78}
79
/// Optional `[letsencrypt]` section of `config.toml`.
///
/// All three fields are required whenever the section is present.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct LetsEncryptConfig {
    // NOTE(review): presumably selects the Let's Encrypt staging environment — confirm against the ACME client code.
    pub staging: bool,
    pub email: String,
    pub terms_agreed: bool,
}
86
/// Effective configuration assembled by `ConfigManager::load_config`.
///
/// It combines the `config.toml` settings (with `scripting` and `admin`
/// defaulted when their sections are absent) with the routing rules and
/// global scripts parsed from the proxy rule file.
#[derive(Clone, Debug, Serialize)]
pub struct Config {
    pub server: ServerConfig,
    pub tls: TlsConfig,
    pub letsencrypt: Option<LetsEncryptConfig>,
    pub scripting: ScriptingTomlConfig,
    pub admin: AdminConfig,
    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
    /// Routing rules in the order they appear in the rule file.
    pub rules: Vec<ProxyRule>,
    /// Script names collected from `[global] @script:...` lines.
    pub global_scripts: Vec<String>,
}
98
/// One routing rule: a matcher plus the targets and per-route extras.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProxyRule {
    /// What the rule matches on.
    pub matcher: RuleMatcher,
    /// Upstream targets, parsed from the comma-separated list after `->`.
    pub targets: Vec<Target>,
    /// Header rules; `parse_proxy_config` always produces an empty list here.
    pub headers: Vec<HeaderRule>,
    /// Script names taken from the rule's `@script:` marker.
    pub scripts: Vec<String>,
    /// Credentials taken from the rule's `@auth:user:hash` markers; empty when absent from the input.
    #[serde(default)]
    pub auth: Vec<BasicAuth>,
}
108
/// How a rule selects requests.
///
/// The notes below describe how `parse_proxy_config` produces each variant
/// from the text on the left of `->`.
#[derive(Clone, Debug)]
pub enum RuleMatcher {
    /// Source is `default` or `*`.
    Default,
    /// Source ends with `/*`; the stored string has the trailing `*` removed.
    Prefix(String),
    /// Source starts with `~`; the rest is compiled as a regular expression.
    Regex(RegexMatcher),
    /// Any other source, stored verbatim.
    Exact(String),
    /// Source that does not start with `/` and contains a `.` or parses as an IP address.
    Domain(String),
    /// Domain source with a path part after the first `/`: `(domain, path)`.
    DomainPath(String, String),
}
118
/// Wrapper around `Regex` that also keeps the original pattern text, which is
/// what `RuleMatcher`'s `Serialize` impl writes out for regex rules.
#[derive(Clone, Debug)]
pub struct RegexMatcher {
    /// Pattern source exactly as passed to [`RegexMatcher::new`].
    pub pattern: String,
    /// Compiled form of `pattern`.
    pub regex: Regex,
}
125
126impl RegexMatcher {
127    pub fn new(pattern: &str) -> Result<Self> {
128        Ok(Self {
129            pattern: pattern.to_string(),
130            regex: Regex::new(pattern)?,
131        })
132    }
133
134    pub fn is_match(&self, text: &str) -> bool {
135        self.regex.is_match(text)
136    }
137}
138
139impl Serialize for RuleMatcher {
140    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
141    where
142        S: serde::Serializer,
143    {
144        use serde::ser::SerializeMap;
145        let mut map = serializer.serialize_map(Some(2))?;
146        match self {
147            RuleMatcher::Default => {
148                map.serialize_entry("type", "default")?;
149            }
150            RuleMatcher::Prefix(v) => {
151                map.serialize_entry("type", "prefix")?;
152                map.serialize_entry("value", v)?;
153            }
154            RuleMatcher::Regex(rm) => {
155                map.serialize_entry("type", "regex")?;
156                map.serialize_entry("value", &rm.pattern)?;
157            }
158            RuleMatcher::Exact(v) => {
159                map.serialize_entry("type", "exact")?;
160                map.serialize_entry("value", v)?;
161            }
162            RuleMatcher::Domain(v) => {
163                map.serialize_entry("type", "domain")?;
164                map.serialize_entry("value", v)?;
165            }
166            RuleMatcher::DomainPath(d, p) => {
167                map.serialize_entry("type", "domain_path")?;
168                map.serialize_entry("domain", d)?;
169                map.serialize_entry("path", p)?;
170            }
171        }
172        map.end()
173    }
174}
175
176impl<'de> Deserialize<'de> for RuleMatcher {
177    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
178    where
179        D: serde::Deserializer<'de>,
180    {
181        use serde::de::Error;
182        let value: serde_json::Value = Deserialize::deserialize(deserializer)?;
183        let obj = value
184            .as_object()
185            .ok_or_else(|| D::Error::custom("expected object"))?;
186        let matcher_type = obj
187            .get("type")
188            .and_then(|v| v.as_str())
189            .ok_or_else(|| D::Error::custom("missing 'type' field"))?;
190
191        match matcher_type {
192            "default" => Ok(RuleMatcher::Default),
193            "exact" => {
194                let v = obj
195                    .get("value")
196                    .and_then(|v| v.as_str())
197                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
198                Ok(RuleMatcher::Exact(v.to_string()))
199            }
200            "prefix" => {
201                let v = obj
202                    .get("value")
203                    .and_then(|v| v.as_str())
204                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
205                Ok(RuleMatcher::Prefix(v.to_string()))
206            }
207            "regex" => {
208                let v = obj
209                    .get("value")
210                    .and_then(|v| v.as_str())
211                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
212                let rm = RegexMatcher::new(v)
213                    .map_err(|e| D::Error::custom(format!("invalid regex: {}", e)))?;
214                Ok(RuleMatcher::Regex(rm))
215            }
216            "domain" => {
217                let v = obj
218                    .get("value")
219                    .and_then(|v| v.as_str())
220                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
221                Ok(RuleMatcher::Domain(v.to_string()))
222            }
223            "domain_path" => {
224                let d = obj
225                    .get("domain")
226                    .and_then(|v| v.as_str())
227                    .ok_or_else(|| D::Error::custom("missing 'domain'"))?;
228                let p = obj
229                    .get("path")
230                    .and_then(|v| v.as_str())
231                    .ok_or_else(|| D::Error::custom("missing 'path'"))?;
232                Ok(RuleMatcher::DomainPath(d.to_string(), p.to_string()))
233            }
234            other => Err(D::Error::custom(format!("unknown matcher type: {}", other))),
235        }
236    }
237}
238
/// One upstream a rule can forward to.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Target {
    /// Parsed upstream URL.
    pub url: Url,
    /// Weight of this target; `parse_proxy_config` assigns 100 to every target.
    // NOTE(review): how the weight is consumed (e.g. load balancing) is not visible in this file — confirm.
    pub weight: u8,
}
244
/// A header name/value pair attached to a rule.
// NOTE(review): whether this sets, adds or matches a header is decided by the consumer, not by this file — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HeaderRule {
    pub name: String,
    pub value: String,
}
250
251impl Config {
252    /// Extract unique domain names from Domain and DomainPath rules,
253    /// filtering out IPs and "localhost".
254    pub fn acme_domains(&self) -> Vec<String> {
255        let mut domains = Vec::new();
256        let mut seen = std::collections::HashSet::new();
257
258        for rule in &self.rules {
259            let domain = match &rule.matcher {
260                RuleMatcher::Domain(d) => Some(d.as_str()),
261                RuleMatcher::DomainPath(d, _) => Some(d.as_str()),
262                _ => None,
263            };
264
265            if let Some(d) = domain {
266                if d == "localhost" || d.parse::<std::net::IpAddr>().is_ok() {
267                    continue;
268                }
269                if seen.insert(d.to_string()) {
270                    domains.push(d.to_string());
271                }
272            }
273        }
274
275        domains
276    }
277}
278
279pub struct ConfigManager {
280    config: ArcSwap<Config>,
281    config_path: PathBuf,
282    _watcher: Option<RecommendedWatcher>,
283    suppress_watch: Arc<AtomicBool>,
284}
285
286impl Clone for ConfigManager {
287    fn clone(&self) -> Self {
288        Self {
289            config: ArcSwap::new(self.config.load().clone()),
290            config_path: self.config_path.clone(),
291            _watcher: None,
292            suppress_watch: self.suppress_watch.clone(),
293        }
294    }
295}
296
297impl ConfigManager {
298    pub fn new(config_path: &str) -> Result<Self> {
299        let path = PathBuf::from(config_path);
300        let config = Self::load_config(&path, &path)?;
301        Ok(Self {
302            config: ArcSwap::new(Arc::new(config)),
303            config_path: path,
304            _watcher: None,
305            suppress_watch: Arc::new(AtomicBool::new(false)),
306        })
307    }
308
309    pub fn config_path(&self) -> &Path {
310        &self.config_path
311    }
312
313    pub fn suppress_watch(&self) -> &Arc<AtomicBool> {
314        &self.suppress_watch
315    }
316
317    fn load_config(proxy_conf_path: &Path, config_path: &Path) -> Result<Config> {
318        let content = std::fs::read_to_string(proxy_conf_path).unwrap_or_default();
319        let (rules, global_scripts) = parse_proxy_config(&content)?;
320        let toml_content = std::fs::read_to_string(
321            config_path
322                .parent()
323                .unwrap_or(Path::new("."))
324                .join("config.toml"),
325        )
326        .ok();
327        let toml_config: TomlConfig = toml_content
328            .as_ref()
329            .and_then(|c| toml::from_str(c).ok())
330            .unwrap_or_default();
331
332        Ok(Config {
333            server: toml_config.server,
334            tls: toml_config.tls,
335            letsencrypt: toml_config.letsencrypt,
336            scripting: toml_config.scripting.unwrap_or_default(),
337            admin: toml_config.admin.unwrap_or_default(),
338            circuit_breaker: toml_config.circuit_breaker,
339            rules,
340            global_scripts,
341        })
342    }
343
344    pub fn get_config(&self) -> Arc<Config> {
345        self.config.load().clone()
346    }
347
348    pub fn start_watcher(&self) -> Result<()> {
349        // Ensure the file exists so the watcher has something to watch
350        if !self.config_path.exists() {
351            if let Some(parent) = self.config_path.parent() {
352                std::fs::create_dir_all(parent).ok();
353            }
354            std::fs::write(&self.config_path, "")?;
355        }
356
357        let (tx, mut rx) = mpsc::channel(1);
358        let config_path = self.config_path.clone();
359        let suppress = self.suppress_watch.clone();
360
361        let mut watcher = RecommendedWatcher::new(
362            move |res| {
363                let _ = tx.blocking_send(res);
364            },
365            notify::Config::default(),
366        )?;
367
368        watcher.watch(&config_path, RecursiveMode::NonRecursive)?;
369
370        tracing::info!("Watching config file: {}", config_path.display());
371
372        std::thread::spawn(move || {
373            while let Some(res) = rx.blocking_recv() {
374                match res {
375                    Ok(event) => {
376                        if event.kind.is_modify() {
377                            if suppress.swap(false, Ordering::SeqCst) {
378                                tracing::debug!(
379                                    "Suppressing file watcher reload (admin API write)"
380                                );
381                                continue;
382                            }
383                            tracing::info!("Config file changed, reloading...");
384                        }
385                    }
386                    Err(e) => tracing::error!("Watch error: {}", e),
387                }
388            }
389        });
390
391        Ok(())
392    }
393
394    pub async fn reload(&self) -> Result<()> {
395        let new_config = Self::load_config(&self.config_path, &self.config_path)?;
396        self.config.store(Arc::new(new_config));
397        tracing::info!("Configuration reloaded successfully");
398        Ok(())
399    }
400
401    /// Persist current rules to proxy.conf and swap in-memory config
402    fn persist_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
403        let content = serializer::serialize_proxy_conf(&rules, &global_scripts);
404        self.suppress_watch.store(true, Ordering::SeqCst);
405        std::fs::write(&self.config_path, &content)?;
406        let mut config = (*self.config.load().as_ref()).clone();
407        config.rules = rules;
408        config.global_scripts = global_scripts;
409        self.config.store(Arc::new(config));
410        tracing::info!("Configuration persisted to {}", self.config_path.display());
411        Ok(())
412    }
413
414    pub fn add_route(&self, rule: ProxyRule) -> Result<()> {
415        let cfg = self.get_config();
416        let mut rules = cfg.rules.clone();
417        rules.push(rule);
418        self.persist_rules(rules, cfg.global_scripts.clone())
419    }
420
421    pub fn update_route(&self, index: usize, rule: ProxyRule) -> Result<()> {
422        let cfg = self.get_config();
423        let mut rules = cfg.rules.clone();
424        if index >= rules.len() {
425            anyhow::bail!(
426                "Route index {} out of range (have {} routes)",
427                index,
428                rules.len()
429            );
430        }
431        rules[index] = rule;
432        self.persist_rules(rules, cfg.global_scripts.clone())
433    }
434
435    pub fn remove_route(&self, index: usize) -> Result<()> {
436        let cfg = self.get_config();
437        let mut rules = cfg.rules.clone();
438        if index >= rules.len() {
439            anyhow::bail!(
440                "Route index {} out of range (have {} routes)",
441                index,
442                rules.len()
443            );
444        }
445        rules.remove(index);
446        self.persist_rules(rules, cfg.global_scripts.clone())
447    }
448
449    pub fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
450        self.persist_rules(rules, global_scripts)
451    }
452}
453
454#[async_trait::async_trait]
455impl ConfigManagerTrait for ConfigManager {
456    async fn reload(&self) -> Result<()> {
457        self.reload().await
458    }
459
460    fn get_config(&self) -> Arc<Config> {
461        self.get_config()
462    }
463
464    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
465        self.update_rules(rules, global_scripts)
466    }
467
468    fn add_route(&self, rule: ProxyRule) -> Result<()> {
469        self.add_route(rule)
470    }
471
472    fn remove_route(&self, index: usize) -> Result<()> {
473        self.remove_route(index)
474    }
475}
476
/// Split a string at its first `@script:` marker.
///
/// Returns the trimmed text before the marker and the script names listed
/// after it. The list is the first whitespace-delimited token following the
/// marker, split on commas with empty names dropped; anything after that
/// token is discarded. Without a marker the input is returned unchanged with
/// an empty list.
fn extract_scripts(s: &str) -> (&str, Vec<String>) {
    const MARKER: &str = "@script:";

    match s.split_once(MARKER) {
        None => (s, Vec::new()),
        Some((head, tail)) => {
            // Only the first token after the marker carries script names.
            let list = tail.split_whitespace().next().unwrap_or(tail);
            let names = list
                .split(',')
                .map(str::trim)
                .filter(|name| !name.is_empty())
                .map(String::from)
                .collect();
            (head.trim(), names)
        }
    }
}
494
495/// Extract `@auth:user:hash` entries from a string, returning (remaining_str, auth_vec).
496/// Multiple @auth entries can appear: `@auth:user1:hash1 @auth:user2:hash2`
497/// Hash is everything after the second colon (bcrypt hashes start with $2a$, $2b$, $2y$)
498fn extract_auth(s: &str) -> (String, Vec<BasicAuth>) {
499    let mut auth_entries = Vec::new();
500    let mut remaining = s.to_string();
501
502    while let Some(idx) = remaining.find("@auth:") {
503        // Keep the part BEFORE @auth:
504        let before = &remaining[..idx];
505        let after = &remaining[idx + "@auth:".len()..];
506
507        // Find the end of this auth entry (whitespace or end of string)
508        let end_idx = after
509            .find(|c: char| c.is_whitespace())
510            .unwrap_or(after.len());
511
512        let auth_part = &after[..end_idx];
513
514        // Parse username:hash - hash is everything after the first colon
515        if let Some((username, hash)) = auth_part.split_once(':') {
516            if !username.is_empty() && !hash.is_empty() {
517                auth_entries.push(BasicAuth {
518                    username: username.to_string(),
519                    hash: hash.to_string(),
520                });
521            }
522        }
523
524        // Continue with the part BEFORE this @auth, plus any remaining after it
525        let rest = &after[end_idx..];
526        remaining = if rest.is_empty() {
527            before.to_string()
528        } else {
529            format!("{}{}", before, rest)
530        };
531    }
532
533    (remaining.trim().to_string(), auth_entries)
534}
535
536fn parse_proxy_config(content: &str) -> Result<(Vec<ProxyRule>, Vec<String>)> {
537    let mut rules = Vec::new();
538    let mut global_scripts = Vec::new();
539
540    // Join continuation lines (backslash at end of line)
541    let mut joined_lines: Vec<String> = Vec::new();
542    for line in content.lines() {
543        if let Some(current) = joined_lines.last_mut() {
544            if current.ends_with('\\') {
545                current.pop(); // remove the backslash
546                current.push_str(line.trim());
547                continue;
548            }
549        }
550        joined_lines.push(line.to_string());
551    }
552
553    for line in &joined_lines {
554        let trimmed = line.trim();
555        if trimmed.is_empty() || trimmed.starts_with('#') {
556            continue;
557        }
558
559        // Handle [global] @script:cors.lua,logging.lua
560        if trimmed.starts_with("[global]") {
561            let rest = trimmed.strip_prefix("[global]").unwrap().trim();
562            let (_, scripts) = extract_scripts(rest);
563            global_scripts.extend(scripts);
564            continue;
565        }
566
567        if let Some((source, target_str)) = trimmed.split_once("->") {
568            let source = source.trim();
569            // Extract @script: from the target side
570            let (target_str, route_scripts) = extract_scripts(target_str.trim());
571            // Extract @auth: entries from the target side
572            let (target_str, auth_entries) = extract_auth(target_str);
573
574            let matcher = if source == "default" || source == "*" {
575                RuleMatcher::Default
576            } else if let Some(pattern) = source.strip_prefix("~") {
577                RuleMatcher::Regex(RegexMatcher::new(pattern)?)
578            } else if !source.starts_with('/')
579                && (source.contains('.') || source.parse::<std::net::IpAddr>().is_ok())
580            {
581                if let Some((domain, path)) = source.split_once('/') {
582                    if path.is_empty() || path == "*" {
583                        RuleMatcher::Domain(domain.to_string())
584                    } else if path.ends_with("/*") {
585                        RuleMatcher::DomainPath(
586                            domain.to_string(),
587                            path.trim_end_matches('*').to_string(),
588                        )
589                    } else {
590                        RuleMatcher::DomainPath(domain.to_string(), path.to_string())
591                    }
592                } else {
593                    RuleMatcher::Domain(source.to_string())
594                }
595            } else if source.ends_with("/*") {
596                RuleMatcher::Prefix(source.trim_end_matches('*').to_string())
597            } else {
598                RuleMatcher::Exact(source.to_string())
599            };
600
601            let targets: Vec<Target> = target_str
602                .split(',')
603                .map(|t| {
604                    Ok(Target {
605                        url: Url::parse(t.trim())?,
606                        weight: 100,
607                    })
608                })
609                .collect::<Result<Vec<_>>>()?;
610
611            rules.push(ProxyRule {
612                matcher,
613                targets,
614                headers: vec![],
615                scripts: route_scripts,
616                auth: auth_entries,
617            });
618        }
619    }
620
621    Ok((rules, global_scripts))
622}
623
/// Unit tests for `parse_proxy_config`: backslash line continuation and
/// `@auth:` / `@script:` extraction.
#[cfg(test)]
mod tests {
    use super::*;

    // A trailing backslash joins the next line into the same rule.
    #[test]
    fn test_backslash_continuation_joins_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n          http://backend2:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://backend1:8080/");
        assert_eq!(rules[0].targets[1].url.as_str(), "http://backend2:8080/");
    }

    // Continuation can chain across more than two physical lines.
    #[test]
    fn test_multiple_continuation_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n\
                       http://backend2:8080, \\\n\
                       http://backend3:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
        assert_eq!(rules[0].targets[2].url.as_str(), "http://backend3:8080/");
    }

    // A backslash inside a line (here in a regex) must not join lines.
    #[test]
    fn test_backslash_mid_line_not_continuation() {
        let config = "/path -> http://localhost:8080\n\
                       ~^/foo\\dbar$ -> http://localhost:9090\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    // Whitespace around the backslash and on continued lines is tolerated.
    #[test]
    fn test_continuation_trims_whitespace() {
        let config = "/api/* -> http://a:8080,   \\\n   http://b:8080,  \\\n   http://c:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
    }

    // An @script: marker on the continued line still applies to the rule.
    #[test]
    fn test_continuation_with_scripts() {
        let config = "/api/* -> http://a:8080, \\\n\
                       http://b:8080 @script:auth.lua\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    // Plain multi-line input without continuation yields one rule per line.
    #[test]
    fn test_no_continuation_normal_config() {
        let config = "/api/* -> http://backend:8080\ndefault -> http://localhost:3000\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    // A single @auth:user:hash entry is split at the first colon and removed
    // from the target text.
    #[test]
    fn test_auth_parsing() {
        let config = r#"
/db/* -> http://localhost:8080/ @auth:demo:$2b$12$YFlnIiACnSaAcxDWQlYjeedxq/3GvhvoGhRTYHMqLifJrETSqOZQa
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 1);
        assert_eq!(rules[0].auth[0].username, "demo");
        assert!(rules[0].auth[0].hash.starts_with("$2b$"));
        assert_eq!(rules[0].targets.len(), 1);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://localhost:8080/");
    }

    // Several @auth: entries on one rule are all collected, in order.
    #[test]
    fn test_multiple_auth_users() {
        let config = r#"
secure.example.com -> http://localhost:9000/ @auth:admin:$2b$12$hash1 @auth:user:$2b$12$hash2
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 2);
        assert_eq!(rules[0].auth[0].username, "admin");
        assert_eq!(rules[0].auth[1].username, "user");
    }
}