Skip to main content

soli_proxy/config/
mod.rs

1pub mod serializer;
2
3use crate::auth::BasicAuth;
4use anyhow::Result;
5use arc_swap::ArcSwap;
6use notify::{RecommendedWatcher, RecursiveMode, Watcher};
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use url::Url;
14
15#[async_trait::async_trait]
16pub trait ConfigManagerTrait: Send + Sync {
17    async fn reload(&self) -> Result<()>;
18    fn get_config(&self) -> Arc<Config>;
19    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()>;
20    fn add_route(&self, rule: ProxyRule) -> Result<()>;
21    fn remove_route(&self, index: usize) -> Result<()>;
22}
23
/// Settings deserialized from the `config.toml` that sits next to the proxy rules file.
///
/// `server` and `tls` fall back to their `Default` implementations when their section is
/// missing; every other section becomes `None` when missing.
#[derive(Deserialize, Default, Clone, Debug)]
pub struct TomlConfig {
    /// `[server]` section.
    #[serde(default)]
    pub server: ServerConfig,
    /// `[tls]` section.
    #[serde(default)]
    pub tls: TlsConfig,
    /// `[letsencrypt]` section, if present.
    pub letsencrypt: Option<LetsEncryptConfig>,
    /// `[scripting]` section, if present.
    pub scripting: Option<ScriptingTomlConfig>,
    /// `[admin]` section, if present.
    pub admin: Option<AdminConfig>,
    /// `[circuit_breaker]` section, if present.
    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
}
35
/// Optional `[circuit_breaker]` section of `config.toml`.
///
/// Every field is optional. How unset fields are defaulted is decided by the code that
/// consumes this section, which is not in this file.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct CircuitBreakerTomlConfig {
    pub failure_threshold: Option<u32>,
    // Unit taken from the field name (seconds) — NOTE(review): confirm against the consumer.
    pub recovery_timeout_secs: Option<u64>,
    pub success_threshold: Option<u32>,
    // NOTE(review): presumably upstream HTTP status codes counted as failures — confirm.
    pub failure_status_codes: Option<Vec<u16>>,
}
43
44#[derive(Deserialize, Serialize, Clone, Debug)]
45pub struct AdminConfig {
46    pub enabled: bool,
47    pub bind: String,
48    pub api_key: Option<String>,
49}
50
51impl Default for AdminConfig {
52    fn default() -> Self {
53        Self {
54            enabled: false,
55            bind: "127.0.0.1:9090".to_string(),
56            api_key: None,
57        }
58    }
59}
60
61#[derive(Deserialize, Serialize, Clone, Debug, Default)]
62pub struct ScriptingTomlConfig {
63    pub enabled: bool,
64    pub scripts_dir: Option<String>,
65    pub hook_timeout_ms: Option<u64>,
66}
67
68#[derive(Deserialize, Serialize, Clone, Debug)]
69pub struct ServerConfig {
70    pub bind: String,
71    pub https_port: u16,
72}
73
74impl Default for ServerConfig {
75    fn default() -> Self {
76        Self {
77            bind: "0.0.0.0:8080".to_string(),
78            https_port: 443,
79        }
80    }
81}
82
83#[derive(Deserialize, Serialize, Default, Clone, Debug)]
84pub struct TlsConfig {
85    pub mode: String,
86    pub cache_dir: String,
87}
88
/// Optional `[letsencrypt]` section of `config.toml`.
///
/// There are no serde defaults here: when the section is present, all three fields must be given.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct LetsEncryptConfig {
    // NOTE(review): presumably selects the Let's Encrypt staging environment — confirm.
    pub staging: bool,
    // Contact email for the account.
    pub email: String,
    // Explicit acknowledgement of the CA's terms of service.
    pub terms_agreed: bool,
}
95
/// Effective runtime configuration.
///
/// Assembled by `ConfigManager::load_config` from two sources:
/// - the settings sections of `config.toml`;
/// - the routing rules and global scripts parsed from the proxy rules file.
#[derive(Clone, Debug, Serialize)]
pub struct Config {
    /// From `[server]` (defaults when the section is absent).
    pub server: ServerConfig,
    /// From `[tls]` (defaults when the section is absent).
    pub tls: TlsConfig,
    /// From `[letsencrypt]`; `None` when the section is absent.
    pub letsencrypt: Option<LetsEncryptConfig>,
    /// From `[scripting]`, or `ScriptingTomlConfig::default()` when absent.
    pub scripting: ScriptingTomlConfig,
    /// From `[admin]`, or `AdminConfig::default()` when absent.
    pub admin: AdminConfig,
    /// From `[circuit_breaker]`; `None` when absent.
    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
    /// Routing rules, in the order they appear in the rules file.
    pub rules: Vec<ProxyRule>,
    /// Script names from `[global] @script:` lines, in file order.
    pub global_scripts: Vec<String>,
}
107
/// Load-balancing strategy selected for a rule's targets.
///
/// Chosen per rule with `@lb:<name>`. The serde names below are the same `<name>` spellings.
/// `RoundRobin` is the default, both for rules without `@lb:` and for unrecognised names.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
pub enum LoadBalancingStrategy {
    #[default]
    #[serde(rename = "round-robin")]
    RoundRobin,
    #[serde(rename = "weighted")]
    Weighted,
    #[serde(rename = "failover")]
    Failover,
}
118
119#[derive(Clone, Debug, Serialize, Deserialize)]
120pub struct ProxyRule {
121    pub matcher: RuleMatcher,
122    pub targets: Vec<Target>,
123    pub headers: Vec<HeaderRule>,
124    pub scripts: Vec<String>,
125    #[serde(default)]
126    pub auth: Vec<BasicAuth>,
127    #[serde(default)]
128    pub load_balancing: LoadBalancingStrategy,
129}
130
/// Selects the requests a [`ProxyRule`] applies to.
///
/// `parse_proxy_config` builds it from the left-hand side of a `source -> targets` line:
/// - `Default`: source `default` or `*`.
/// - `Prefix(p)`: a path ending in `/*`, stored with the `*` removed (so `p` ends in `/`).
/// - `Regex`: source `~<pattern>`.
/// - `Exact(p)`: any other source.
/// - `Domain(host)`: a source that does not start with `/` and contains `.` or parses as an
///   IP address, with no path or a path of just `*`.
/// - `DomainPath(host, path)`: such a host followed by `/path`. `path` is the text after
///   the first `/`, without a leading `/`; a trailing `*` is removed when it ends in `/*`.
///
/// Serialization is hand-written as a map tagged by `"type"`.
#[derive(Clone, Debug)]
pub enum RuleMatcher {
    Default,
    Prefix(String),
    Regex(RegexMatcher),
    Exact(String),
    Domain(String),
    DomainPath(String, String),
}
140
141/// Wrapper around Regex that stores the original pattern for serialization
142#[derive(Clone, Debug)]
143pub struct RegexMatcher {
144    pub pattern: String,
145    pub regex: Regex,
146}
147
148impl RegexMatcher {
149    pub fn new(pattern: &str) -> Result<Self> {
150        Ok(Self {
151            pattern: pattern.to_string(),
152            regex: Regex::new(pattern)?,
153        })
154    }
155
156    pub fn is_match(&self, text: &str) -> bool {
157        self.regex.is_match(text)
158    }
159}
160
161impl Serialize for RuleMatcher {
162    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
163    where
164        S: serde::Serializer,
165    {
166        use serde::ser::SerializeMap;
167        let mut map = serializer.serialize_map(Some(2))?;
168        match self {
169            RuleMatcher::Default => {
170                map.serialize_entry("type", "default")?;
171            }
172            RuleMatcher::Prefix(v) => {
173                map.serialize_entry("type", "prefix")?;
174                map.serialize_entry("value", v)?;
175            }
176            RuleMatcher::Regex(rm) => {
177                map.serialize_entry("type", "regex")?;
178                map.serialize_entry("value", &rm.pattern)?;
179            }
180            RuleMatcher::Exact(v) => {
181                map.serialize_entry("type", "exact")?;
182                map.serialize_entry("value", v)?;
183            }
184            RuleMatcher::Domain(v) => {
185                map.serialize_entry("type", "domain")?;
186                map.serialize_entry("value", v)?;
187            }
188            RuleMatcher::DomainPath(d, p) => {
189                map.serialize_entry("type", "domain_path")?;
190                map.serialize_entry("domain", d)?;
191                map.serialize_entry("path", p)?;
192            }
193        }
194        map.end()
195    }
196}
197
198impl<'de> Deserialize<'de> for RuleMatcher {
199    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
200    where
201        D: serde::Deserializer<'de>,
202    {
203        use serde::de::Error;
204        let value: serde_json::Value = Deserialize::deserialize(deserializer)?;
205        let obj = value
206            .as_object()
207            .ok_or_else(|| D::Error::custom("expected object"))?;
208        let matcher_type = obj
209            .get("type")
210            .and_then(|v| v.as_str())
211            .ok_or_else(|| D::Error::custom("missing 'type' field"))?;
212
213        match matcher_type {
214            "default" => Ok(RuleMatcher::Default),
215            "exact" => {
216                let v = obj
217                    .get("value")
218                    .and_then(|v| v.as_str())
219                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
220                Ok(RuleMatcher::Exact(v.to_string()))
221            }
222            "prefix" => {
223                let v = obj
224                    .get("value")
225                    .and_then(|v| v.as_str())
226                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
227                Ok(RuleMatcher::Prefix(v.to_string()))
228            }
229            "regex" => {
230                let v = obj
231                    .get("value")
232                    .and_then(|v| v.as_str())
233                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
234                let rm = RegexMatcher::new(v)
235                    .map_err(|e| D::Error::custom(format!("invalid regex: {}", e)))?;
236                Ok(RuleMatcher::Regex(rm))
237            }
238            "domain" => {
239                let v = obj
240                    .get("value")
241                    .and_then(|v| v.as_str())
242                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
243                Ok(RuleMatcher::Domain(v.to_string()))
244            }
245            "domain_path" => {
246                let d = obj
247                    .get("domain")
248                    .and_then(|v| v.as_str())
249                    .ok_or_else(|| D::Error::custom("missing 'domain'"))?;
250                let p = obj
251                    .get("path")
252                    .and_then(|v| v.as_str())
253                    .ok_or_else(|| D::Error::custom("missing 'path'"))?;
254                Ok(RuleMatcher::DomainPath(d.to_string(), p.to_string()))
255            }
256            other => Err(D::Error::custom(format!("unknown matcher type: {}", other))),
257        }
258    }
259}
260
/// One upstream target of a [`ProxyRule`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Target {
    /// Parsed upstream URL.
    pub url: Url,
    /// Relative weight; `parse_proxy_config` always assigns 100.
    /// NOTE(review): presumably only consulted by the `Weighted` strategy — confirm.
    pub weight: u8,
}
266
/// A header name/value pair attached to a [`ProxyRule`].
///
/// `parse_proxy_config` never populates these; it always uses an empty list.
/// NOTE(review): how they are applied (request vs. response, set vs. append) is decided
/// outside this file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HeaderRule {
    pub name: String,
    pub value: String,
}
272
273impl Config {
274    /// Extract unique domain names from Domain and DomainPath rules,
275    /// filtering out IPs and "localhost".
276    pub fn acme_domains(&self) -> Vec<String> {
277        let mut domains = Vec::new();
278        let mut seen = std::collections::HashSet::new();
279
280        for rule in &self.rules {
281            let domain = match &rule.matcher {
282                RuleMatcher::Domain(d) => Some(d.as_str()),
283                RuleMatcher::DomainPath(d, _) => Some(d.as_str()),
284                _ => None,
285            };
286
287            if let Some(d) = domain {
288                if d == "localhost" || d.parse::<std::net::IpAddr>().is_ok() {
289                    continue;
290                }
291                if seen.insert(d.to_string()) {
292                    domains.push(d.to_string());
293                }
294            }
295        }
296
297        domains
298    }
299}
300
301pub struct ConfigManager {
302    config: ArcSwap<Config>,
303    config_path: PathBuf,
304    _watcher: Option<RecommendedWatcher>,
305    suppress_watch: Arc<AtomicBool>,
306}
307
308impl Clone for ConfigManager {
309    fn clone(&self) -> Self {
310        Self {
311            config: ArcSwap::new(self.config.load().clone()),
312            config_path: self.config_path.clone(),
313            _watcher: None,
314            suppress_watch: self.suppress_watch.clone(),
315        }
316    }
317}
318
319impl ConfigManager {
320    pub fn new(config_path: &str) -> Result<Self> {
321        let path = PathBuf::from(config_path);
322        let config = Self::load_config(&path, &path)?;
323        Ok(Self {
324            config: ArcSwap::new(Arc::new(config)),
325            config_path: path,
326            _watcher: None,
327            suppress_watch: Arc::new(AtomicBool::new(false)),
328        })
329    }
330
331    pub fn config_path(&self) -> &Path {
332        &self.config_path
333    }
334
335    pub fn suppress_watch(&self) -> &Arc<AtomicBool> {
336        &self.suppress_watch
337    }
338
339    fn load_config(proxy_conf_path: &Path, config_path: &Path) -> Result<Config> {
340        let content = std::fs::read_to_string(proxy_conf_path).unwrap_or_default();
341        let (rules, global_scripts) = parse_proxy_config(&content)?;
342        let toml_content = std::fs::read_to_string(
343            config_path
344                .parent()
345                .unwrap_or(Path::new("."))
346                .join("config.toml"),
347        )
348        .ok();
349        let toml_config: TomlConfig = toml_content
350            .as_ref()
351            .and_then(|c| toml::from_str(c).ok())
352            .unwrap_or_default();
353
354        Ok(Config {
355            server: toml_config.server,
356            tls: toml_config.tls,
357            letsencrypt: toml_config.letsencrypt,
358            scripting: toml_config.scripting.unwrap_or_default(),
359            admin: toml_config.admin.unwrap_or_default(),
360            circuit_breaker: toml_config.circuit_breaker,
361            rules,
362            global_scripts,
363        })
364    }
365
366    pub fn get_config(&self) -> Arc<Config> {
367        self.config.load().clone()
368    }
369
370    pub fn start_watcher(&self) -> Result<()> {
371        // Ensure the file exists so the watcher has something to watch
372        if !self.config_path.exists() {
373            if let Some(parent) = self.config_path.parent() {
374                std::fs::create_dir_all(parent).ok();
375            }
376            std::fs::write(&self.config_path, "")?;
377        }
378
379        let (tx, mut rx) = mpsc::channel(1);
380        let config_path = self.config_path.clone();
381        let suppress = self.suppress_watch.clone();
382
383        let mut watcher = RecommendedWatcher::new(
384            move |res| {
385                let _ = tx.blocking_send(res);
386            },
387            notify::Config::default(),
388        )?;
389
390        watcher.watch(&config_path, RecursiveMode::NonRecursive)?;
391
392        tracing::info!("Watching config file: {}", config_path.display());
393
394        std::thread::spawn(move || {
395            while let Some(res) = rx.blocking_recv() {
396                match res {
397                    Ok(event) => {
398                        if event.kind.is_modify() {
399                            if suppress.swap(false, Ordering::SeqCst) {
400                                tracing::debug!(
401                                    "Suppressing file watcher reload (admin API write)"
402                                );
403                                continue;
404                            }
405                            tracing::info!("Config file changed, reloading...");
406                        }
407                    }
408                    Err(e) => tracing::error!("Watch error: {}", e),
409                }
410            }
411        });
412
413        Ok(())
414    }
415
416    pub async fn reload(&self) -> Result<()> {
417        let new_config = Self::load_config(&self.config_path, &self.config_path)?;
418        self.config.store(Arc::new(new_config));
419        tracing::info!("Configuration reloaded successfully");
420        Ok(())
421    }
422
423    /// Persist current rules to proxy.conf and swap in-memory config
424    fn persist_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
425        let content = serializer::serialize_proxy_conf(&rules, &global_scripts);
426        self.suppress_watch.store(true, Ordering::SeqCst);
427        std::fs::write(&self.config_path, &content)?;
428        let mut config = (*self.config.load().as_ref()).clone();
429        config.rules = rules;
430        config.global_scripts = global_scripts;
431        self.config.store(Arc::new(config));
432        tracing::info!("Configuration persisted to {}", self.config_path.display());
433        Ok(())
434    }
435
436    pub fn add_route(&self, rule: ProxyRule) -> Result<()> {
437        let cfg = self.get_config();
438        let mut rules = cfg.rules.clone();
439        rules.push(rule);
440        self.persist_rules(rules, cfg.global_scripts.clone())
441    }
442
443    pub fn update_route(&self, index: usize, rule: ProxyRule) -> Result<()> {
444        let cfg = self.get_config();
445        let mut rules = cfg.rules.clone();
446        if index >= rules.len() {
447            anyhow::bail!(
448                "Route index {} out of range (have {} routes)",
449                index,
450                rules.len()
451            );
452        }
453        rules[index] = rule;
454        self.persist_rules(rules, cfg.global_scripts.clone())
455    }
456
457    pub fn remove_route(&self, index: usize) -> Result<()> {
458        let cfg = self.get_config();
459        let mut rules = cfg.rules.clone();
460        if index >= rules.len() {
461            anyhow::bail!(
462                "Route index {} out of range (have {} routes)",
463                index,
464                rules.len()
465            );
466        }
467        rules.remove(index);
468        self.persist_rules(rules, cfg.global_scripts.clone())
469    }
470
471    pub fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
472        self.persist_rules(rules, global_scripts)
473    }
474}
475
476#[async_trait::async_trait]
477impl ConfigManagerTrait for ConfigManager {
478    async fn reload(&self) -> Result<()> {
479        self.reload().await
480    }
481
482    fn get_config(&self) -> Arc<Config> {
483        self.get_config()
484    }
485
486    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
487        self.update_rules(rules, global_scripts)
488    }
489
490    fn add_route(&self, rule: ProxyRule) -> Result<()> {
491        self.add_route(rule)
492    }
493
494    fn remove_route(&self, index: usize) -> Result<()> {
495        self.remove_route(index)
496    }
497}
498
/// Splits an `@script:a.lua,b.lua` directive off `s`.
///
/// Returns a pair:
/// - the trimmed text before the first `@script:` marker;
/// - the non-empty, trimmed names from the comma-separated list in the first
///   whitespace-delimited token after the marker.
///
/// Any text after that token is discarded. Without a marker, `s` is returned unchanged
/// and the list is empty.
fn extract_scripts(s: &str) -> (&str, Vec<String>) {
    const MARKER: &str = "@script:";
    match s.split_once(MARKER) {
        None => (s, Vec::new()),
        Some((head, tail)) => {
            let list = tail.split_whitespace().next().unwrap_or(tail);
            let names = list
                .split(',')
                .map(str::trim)
                .filter(|name| !name.is_empty())
                .map(String::from)
                .collect();
            (head.trim(), names)
        }
    }
}
516
517/// Extract `@auth:user:hash` entries from a string, returning (remaining_str, auth_vec).
518/// Multiple @auth entries can appear: `@auth:user1:hash1 @auth:user2:hash2`
519/// Hash is everything after the second colon (bcrypt hashes start with $2a$, $2b$, $2y$)
520fn extract_auth(s: &str) -> (String, Vec<BasicAuth>) {
521    let mut auth_entries = Vec::new();
522    let mut remaining = s.to_string();
523
524    while let Some(idx) = remaining.find("@auth:") {
525        // Keep the part BEFORE @auth:
526        let before = &remaining[..idx];
527        let after = &remaining[idx + "@auth:".len()..];
528
529        // Find the end of this auth entry (whitespace or end of string)
530        let end_idx = after
531            .find(|c: char| c.is_whitespace())
532            .unwrap_or(after.len());
533
534        let auth_part = &after[..end_idx];
535
536        // Parse username:hash - hash is everything after the first colon
537        if let Some((username, hash)) = auth_part.split_once(':') {
538            if !username.is_empty() && !hash.is_empty() {
539                auth_entries.push(BasicAuth {
540                    username: username.to_string(),
541                    hash: hash.to_string(),
542                });
543            }
544        }
545
546        // Continue with the part BEFORE this @auth, plus any remaining after it
547        let rest = &after[end_idx..];
548        remaining = if rest.is_empty() {
549            before.to_string()
550        } else {
551            format!("{}{}", before, rest)
552        };
553    }
554
555    (remaining.trim().to_string(), auth_entries)
556}
557
558/// Extract `@lb:strategy` from a string, returning (remaining_str, strategy).
559/// Example: `@lb:round-robin` or `@lb:weighted` or `@lb:failover`
560fn extract_load_balancing(s: &str) -> (String, LoadBalancingStrategy) {
561    if let Some(idx) = s.find("@lb:") {
562        let before = &s[..idx];
563        let after = &s[idx + "@lb:".len()..];
564
565        let end_idx = after
566            .find(|c: char| c.is_whitespace())
567            .unwrap_or(after.len());
568
569        let strategy_str = &after[..end_idx];
570        let strategy = match strategy_str {
571            "round-robin" => LoadBalancingStrategy::RoundRobin,
572            "weighted" => LoadBalancingStrategy::Weighted,
573            "failover" => LoadBalancingStrategy::Failover,
574            _ => LoadBalancingStrategy::default(),
575        };
576
577        let rest = &after[end_idx..];
578        let remaining = if rest.is_empty() {
579            before.to_string()
580        } else {
581            format!("{}{}", before, rest)
582        };
583
584        (remaining.trim().to_string(), strategy)
585    } else {
586        (s.to_string(), LoadBalancingStrategy::default())
587    }
588}
589
590fn parse_proxy_config(content: &str) -> Result<(Vec<ProxyRule>, Vec<String>)> {
591    let mut rules = Vec::new();
592    let mut global_scripts = Vec::new();
593
594    // Join continuation lines (backslash at end of line)
595    let mut joined_lines: Vec<String> = Vec::new();
596    for line in content.lines() {
597        if let Some(current) = joined_lines.last_mut() {
598            if current.ends_with('\\') {
599                current.pop(); // remove the backslash
600                current.push_str(line.trim());
601                continue;
602            }
603        }
604        joined_lines.push(line.to_string());
605    }
606
607    for line in &joined_lines {
608        let trimmed = line.trim();
609        if trimmed.is_empty() || trimmed.starts_with('#') {
610            continue;
611        }
612
613        // Handle [global] @script:cors.lua,logging.lua
614        if trimmed.starts_with("[global]") {
615            let rest = trimmed.strip_prefix("[global]").unwrap().trim();
616            let (_, scripts) = extract_scripts(rest);
617            global_scripts.extend(scripts);
618            continue;
619        }
620
621        if let Some((source, target_str)) = trimmed.split_once("->") {
622            let source = source.trim();
623            // Extract @script: from the target side
624            let (target_str, route_scripts) = extract_scripts(target_str.trim());
625            // Extract @auth: entries from the target side
626            let (target_str, auth_entries) = extract_auth(target_str);
627            // Extract @lb: load balancing strategy
628            let (target_str, load_balancing) = extract_load_balancing(&target_str);
629
630            let matcher = if source == "default" || source == "*" {
631                RuleMatcher::Default
632            } else if let Some(pattern) = source.strip_prefix("~") {
633                RuleMatcher::Regex(RegexMatcher::new(pattern)?)
634            } else if !source.starts_with('/')
635                && (source.contains('.') || source.parse::<std::net::IpAddr>().is_ok())
636            {
637                if let Some((domain, path)) = source.split_once('/') {
638                    if path.is_empty() || path == "*" {
639                        RuleMatcher::Domain(domain.to_string())
640                    } else if path.ends_with("/*") {
641                        RuleMatcher::DomainPath(
642                            domain.to_string(),
643                            path.trim_end_matches('*').to_string(),
644                        )
645                    } else {
646                        RuleMatcher::DomainPath(domain.to_string(), path.to_string())
647                    }
648                } else {
649                    RuleMatcher::Domain(source.to_string())
650                }
651            } else if source.ends_with("/*") {
652                RuleMatcher::Prefix(source.trim_end_matches('*').to_string())
653            } else {
654                RuleMatcher::Exact(source.to_string())
655            };
656
657            let targets: Vec<Target> = target_str
658                .split(',')
659                .map(|t| {
660                    Ok(Target {
661                        url: Url::parse(t.trim())?,
662                        weight: 100,
663                    })
664                })
665                .collect::<Result<Vec<_>>>()?;
666
667            rules.push(ProxyRule {
668                matcher,
669                targets,
670                headers: vec![],
671                scripts: route_scripts,
672                auth: auth_entries,
673                load_balancing,
674            });
675        }
676    }
677
678    Ok((rules, global_scripts))
679}
680
// Unit tests for `parse_proxy_config`: backslash line continuation plus the `@script:`,
// `@auth:` and `@lb:` target-side directives.
#[cfg(test)]
mod tests {
    use super::*;

    // A trailing backslash joins the next (trimmed) line onto the current rule.
    #[test]
    fn test_backslash_continuation_joins_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n          http://backend2:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://backend1:8080/");
        assert_eq!(rules[0].targets[1].url.as_str(), "http://backend2:8080/");
    }

    // Continuations can chain across more than two physical lines.
    #[test]
    fn test_multiple_continuation_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n\
                       http://backend2:8080, \\\n\
                       http://backend3:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
        assert_eq!(rules[0].targets[2].url.as_str(), "http://backend3:8080/");
    }

    // A backslash that is not the last character (here inside a regex) must not join lines.
    #[test]
    fn test_backslash_mid_line_not_continuation() {
        let config = "/path -> http://localhost:8080\n\
                       ~^/foo\\dbar$ -> http://localhost:9090\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    // Whitespace around continuation points does not leak into the target URLs.
    #[test]
    fn test_continuation_trims_whitespace() {
        let config = "/api/* -> http://a:8080,   \\\n   http://b:8080,  \\\n   http://c:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
    }

    // Directives on a continued line are still recognised.
    #[test]
    fn test_continuation_with_scripts() {
        let config = "/api/* -> http://a:8080, \\\n\
                       http://b:8080 @script:auth.lua\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    // Plain one-rule-per-line input is unaffected by continuation handling.
    #[test]
    fn test_no_continuation_normal_config() {
        let config = "/api/* -> http://backend:8080\ndefault -> http://localhost:3000\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    // `@auth:user:hash` keeps the full bcrypt hash and is removed from the target list.
    #[test]
    fn test_auth_parsing() {
        let config = r#"
/db/* -> http://localhost:8080/ @auth:demo:$2b$12$YFlnIiACnSaAcxDWQlYjeedxq/3GvhvoGhRTYHMqLifJrETSqOZQa
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 1);
        assert_eq!(rules[0].auth[0].username, "demo");
        assert!(rules[0].auth[0].hash.starts_with("$2b$"));
        assert_eq!(rules[0].targets.len(), 1);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://localhost:8080/");
    }

    // Several `@auth:` entries on one rule are all collected, in order.
    #[test]
    fn test_multiple_auth_users() {
        let config = r#"
secure.example.com -> http://localhost:9000/ @auth:admin:$2b$12$hash1 @auth:user:$2b$12$hash2
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 2);
        assert_eq!(rules[0].auth[0].username, "admin");
        assert_eq!(rules[0].auth[1].username, "user");
    }

    // The three named `@lb:` strategies map to their enum variants.
    #[test]
    fn test_load_balancing_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:round-robin";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
        assert_eq!(rules[0].targets.len(), 2);
    }

    #[test]
    fn test_load_balancing_weighted() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:weighted";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Weighted);
    }

    #[test]
    fn test_load_balancing_failover() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:failover";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Failover);
    }

    // Without `@lb:` the strategy is the enum default (round-robin).
    #[test]
    fn test_load_balancing_default_is_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
    }

    // `@lb:` and `@script:` can be combined on one rule.
    #[test]
    fn test_load_balancing_with_scripts() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:failover @script:auth.lua";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Failover);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    // Unrecognised strategy names fall back to round-robin rather than failing.
    #[test]
    fn test_load_balancing_unknown_strategy_defaults_to_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:unknown";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
    }
}