Skip to main content

soli_proxy/config/
mod.rs

1pub mod serializer;
2
3use crate::auth::BasicAuth;
4use anyhow::Result;
5use arc_swap::ArcSwap;
6use notify::{RecommendedWatcher, RecursiveMode, Watcher};
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use url::Url;
14
/// Object-safe interface to the live proxy configuration, so callers can read and
/// mutate routing rules without naming `ConfigManager` directly.
#[async_trait::async_trait]
pub trait ConfigManagerTrait: Send + Sync {
    /// Re-reads configuration from its source and makes it current.
    async fn reload(&self) -> Result<()>;
    /// Returns a shared snapshot of the current configuration.
    fn get_config(&self) -> Arc<Config>;
    /// Replaces the complete rule list and the global script list.
    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()>;
    /// Appends `rule` to the end of the rule list.
    fn add_route(&self, rule: ProxyRule) -> Result<()>;
    /// Removes the rule at position `index`.
    fn remove_route(&self, index: usize) -> Result<()>;
}
23
/// Raw deserialized shape of `config.toml`.
///
/// `server` and `tls` fall back to their `Default` values when the table is absent;
/// the remaining sections are `None` when absent.
#[derive(Deserialize, Default, Clone, Debug)]
pub struct TomlConfig {
    #[serde(default)]
    pub server: ServerConfig,
    #[serde(default)]
    pub tls: TlsConfig,
    pub letsencrypt: Option<LetsEncryptConfig>,
    pub scripting: Option<ScriptingTomlConfig>,
    pub admin: Option<AdminConfig>,
    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
}

/// `[circuit_breaker]` section of `config.toml`; every key is optional.
/// NOTE(review): defaults for unset keys are presumably applied by the circuit-breaker
/// consumer, not in this module — confirm.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct CircuitBreakerTomlConfig {
    pub failure_threshold: Option<u32>,
    pub recovery_timeout_secs: Option<u64>,
    pub success_threshold: Option<u32>,
    pub failure_status_codes: Option<Vec<u16>>,
}
43
44#[derive(Deserialize, Serialize, Clone, Debug)]
45pub struct AdminConfig {
46    pub enabled: bool,
47    pub bind: String,
48    pub api_key: Option<String>,
49    #[serde(default)]
50    pub username: Option<String>,
51    #[serde(default)]
52    pub password_hash: Option<String>,
53}
54
55impl Default for AdminConfig {
56    fn default() -> Self {
57        let _ = dotenv::dotenv();
58        let username = std::env::var("ADMIN_USER").ok();
59        let password_hash = std::env::var("ADMIN_PASSWORD").ok();
60        Self {
61            enabled: false,
62            bind: "127.0.0.1:9090".to_string(),
63            api_key: None,
64            username,
65            password_hash,
66        }
67    }
68}
69
/// `[scripting]` section of `config.toml`; `Config` uses `Default` (disabled) when absent.
#[derive(Deserialize, Serialize, Clone, Debug, Default)]
pub struct ScriptingTomlConfig {
    pub enabled: bool,
    // NOTE(review): presumably the directory route/global script names resolve against.
    pub scripts_dir: Option<String>,
    // NOTE(review): presumably a per-hook timeout in milliseconds — confirm at the use site.
    pub hook_timeout_ms: Option<u64>,
}
76
77#[derive(Deserialize, Serialize, Clone, Debug)]
78pub struct ServerConfig {
79    pub bind: String,
80    pub https_port: u16,
81}
82
83impl Default for ServerConfig {
84    fn default() -> Self {
85        Self {
86            bind: "0.0.0.0:8080".to_string(),
87            https_port: 443,
88        }
89    }
90}
91
/// `[tls]` section of `config.toml`. When the table is absent both fields are empty
/// strings (derived `Default`); when present, both keys are required.
#[derive(Deserialize, Serialize, Default, Clone, Debug)]
pub struct TlsConfig {
    // NOTE(review): free-form mode string; accepted values are interpreted elsewhere.
    pub mode: String,
    pub cache_dir: String,
}

/// `[letsencrypt]` section of `config.toml`; all keys are required when present.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct LetsEncryptConfig {
    // NOTE(review): presumably selects the Let's Encrypt staging directory — confirm.
    pub staging: bool,
    pub email: String,
    pub terms_agreed: bool,
}

/// Fully resolved runtime configuration: settings from `config.toml` (with defaults
/// filled in for `scripting` and `admin`) plus the rules and global scripts parsed
/// from the rules file.
#[derive(Clone, Debug, Serialize)]
pub struct Config {
    pub server: ServerConfig,
    pub tls: TlsConfig,
    pub letsencrypt: Option<LetsEncryptConfig>,
    pub scripting: ScriptingTomlConfig,
    pub admin: AdminConfig,
    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
    // Ordered list of routing rules as written in the rules file.
    pub rules: Vec<ProxyRule>,
    // Script names from `[global] @script:...` lines.
    pub global_scripts: Vec<String>,
}
116
/// How a rule picks among its targets. Serialized as `round-robin`, `weighted` or
/// `failover` (the same spellings accepted by `@lb:` in the rules file).
/// NOTE(review): the selection algorithms themselves live outside this module.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
pub enum LoadBalancingStrategy {
    #[default]
    #[serde(rename = "round-robin")]
    RoundRobin,
    #[serde(rename = "weighted")]
    Weighted,
    #[serde(rename = "failover")]
    Failover,
}
127
128#[derive(Clone, Debug, Serialize, Deserialize)]
129pub struct ProxyRule {
130    pub matcher: RuleMatcher,
131    pub targets: Vec<Target>,
132    pub headers: Vec<HeaderRule>,
133    pub scripts: Vec<String>,
134    #[serde(default)]
135    pub auth: Vec<BasicAuth>,
136    #[serde(default)]
137    pub load_balancing: LoadBalancingStrategy,
138}
139
/// Left-hand side of a rule, as classified by `parse_proxy_config`:
/// - `Default`: `default` or `*`;
/// - `Prefix(p)`: `/path/*`, stored with the trailing `*` removed (e.g. `/path/`);
/// - `Regex(r)`: `~pattern`;
/// - `Exact(p)`: any other path;
/// - `Domain(d)`: a host name or IP, optionally followed by `/` or `/*`;
/// - `DomainPath(d, p)`: a host followed by a path (text after the first `/`).
#[derive(Clone, Debug)]
pub enum RuleMatcher {
    Default,
    Prefix(String),
    Regex(RegexMatcher),
    Exact(String),
    Domain(String),
    DomainPath(String, String),
}

/// Wrapper around Regex that stores the original pattern for serialization
/// (`Regex` itself only exposes a compiled form, so the source text is kept alongside).
#[derive(Clone, Debug)]
pub struct RegexMatcher {
    pub pattern: String,
    pub regex: Regex,
}
156
157impl RegexMatcher {
158    pub fn new(pattern: &str) -> Result<Self> {
159        Ok(Self {
160            pattern: pattern.to_string(),
161            regex: Regex::new(pattern)?,
162        })
163    }
164
165    pub fn is_match(&self, text: &str) -> bool {
166        self.regex.is_match(text)
167    }
168}
169
170impl Serialize for RuleMatcher {
171    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
172    where
173        S: serde::Serializer,
174    {
175        use serde::ser::SerializeMap;
176        let mut map = serializer.serialize_map(Some(2))?;
177        match self {
178            RuleMatcher::Default => {
179                map.serialize_entry("type", "default")?;
180            }
181            RuleMatcher::Prefix(v) => {
182                map.serialize_entry("type", "prefix")?;
183                map.serialize_entry("value", v)?;
184            }
185            RuleMatcher::Regex(rm) => {
186                map.serialize_entry("type", "regex")?;
187                map.serialize_entry("value", &rm.pattern)?;
188            }
189            RuleMatcher::Exact(v) => {
190                map.serialize_entry("type", "exact")?;
191                map.serialize_entry("value", v)?;
192            }
193            RuleMatcher::Domain(v) => {
194                map.serialize_entry("type", "domain")?;
195                map.serialize_entry("value", v)?;
196            }
197            RuleMatcher::DomainPath(d, p) => {
198                map.serialize_entry("type", "domain_path")?;
199                map.serialize_entry("domain", d)?;
200                map.serialize_entry("path", p)?;
201            }
202        }
203        map.end()
204    }
205}
206
207impl<'de> Deserialize<'de> for RuleMatcher {
208    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
209    where
210        D: serde::Deserializer<'de>,
211    {
212        use serde::de::Error;
213        let value: serde_json::Value = Deserialize::deserialize(deserializer)?;
214        let obj = value
215            .as_object()
216            .ok_or_else(|| D::Error::custom("expected object"))?;
217        let matcher_type = obj
218            .get("type")
219            .and_then(|v| v.as_str())
220            .ok_or_else(|| D::Error::custom("missing 'type' field"))?;
221
222        match matcher_type {
223            "default" => Ok(RuleMatcher::Default),
224            "exact" => {
225                let v = obj
226                    .get("value")
227                    .and_then(|v| v.as_str())
228                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
229                Ok(RuleMatcher::Exact(v.to_string()))
230            }
231            "prefix" => {
232                let v = obj
233                    .get("value")
234                    .and_then(|v| v.as_str())
235                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
236                Ok(RuleMatcher::Prefix(v.to_string()))
237            }
238            "regex" => {
239                let v = obj
240                    .get("value")
241                    .and_then(|v| v.as_str())
242                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
243                let rm = RegexMatcher::new(v)
244                    .map_err(|e| D::Error::custom(format!("invalid regex: {}", e)))?;
245                Ok(RuleMatcher::Regex(rm))
246            }
247            "domain" => {
248                let v = obj
249                    .get("value")
250                    .and_then(|v| v.as_str())
251                    .ok_or_else(|| D::Error::custom("missing 'value'"))?;
252                Ok(RuleMatcher::Domain(v.to_string()))
253            }
254            "domain_path" => {
255                let d = obj
256                    .get("domain")
257                    .and_then(|v| v.as_str())
258                    .ok_or_else(|| D::Error::custom("missing 'domain'"))?;
259                let p = obj
260                    .get("path")
261                    .and_then(|v| v.as_str())
262                    .ok_or_else(|| D::Error::custom("missing 'path'"))?;
263                Ok(RuleMatcher::DomainPath(d.to_string(), p.to_string()))
264            }
265            other => Err(D::Error::custom(format!("unknown matcher type: {}", other))),
266        }
267    }
268}
269
/// One upstream destination of a [`ProxyRule`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Target {
    /// Upstream URL, parsed from the rule's comma-separated target list.
    pub url: Url,
    /// Relative weight; `parse_proxy_config` always assigns `100`.
    /// NOTE(review): presumably only consulted by the `weighted` strategy — confirm.
    pub weight: u8,
}

/// A header name/value pair attached to a rule.
/// NOTE(review): whether it applies to the upstream request or the response is decided
/// outside this module — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HeaderRule {
    pub name: String,
    pub value: String,
}
281
282impl Config {
283    /// Extract unique domain names from Domain and DomainPath rules,
284    /// filtering out IPs and "localhost".
285    pub fn acme_domains(&self) -> Vec<String> {
286        let mut domains = Vec::new();
287        let mut seen = std::collections::HashSet::new();
288
289        for rule in &self.rules {
290            let domain = match &rule.matcher {
291                RuleMatcher::Domain(d) => Some(d.as_str()),
292                RuleMatcher::DomainPath(d, _) => Some(d.as_str()),
293                _ => None,
294            };
295
296            if let Some(d) = domain {
297                if d == "localhost" || d.parse::<std::net::IpAddr>().is_ok() {
298                    continue;
299                }
300                if seen.insert(d.to_string()) {
301                    domains.push(d.to_string());
302                }
303            }
304        }
305
306        domains
307    }
308}
309
310pub struct ConfigManager {
311    config: ArcSwap<Config>,
312    config_path: PathBuf,
313    _watcher: Option<RecommendedWatcher>,
314    suppress_watch: Arc<AtomicBool>,
315}
316
317impl Clone for ConfigManager {
318    fn clone(&self) -> Self {
319        Self {
320            config: ArcSwap::new(self.config.load().clone()),
321            config_path: self.config_path.clone(),
322            _watcher: None,
323            suppress_watch: self.suppress_watch.clone(),
324        }
325    }
326}
327
328impl ConfigManager {
329    pub fn new(config_path: &str) -> Result<Self> {
330        let path = PathBuf::from(config_path);
331        let config = Self::load_config(&path, &path)?;
332        Ok(Self {
333            config: ArcSwap::new(Arc::new(config)),
334            config_path: path,
335            _watcher: None,
336            suppress_watch: Arc::new(AtomicBool::new(false)),
337        })
338    }
339
340    pub fn config_path(&self) -> &Path {
341        &self.config_path
342    }
343
344    pub fn suppress_watch(&self) -> &Arc<AtomicBool> {
345        &self.suppress_watch
346    }
347
348    fn load_config(proxy_conf_path: &Path, config_path: &Path) -> Result<Config> {
349        let content = std::fs::read_to_string(proxy_conf_path).unwrap_or_default();
350        let (rules, global_scripts) = parse_proxy_config(&content)?;
351        let toml_content = std::fs::read_to_string(
352            config_path
353                .parent()
354                .unwrap_or(Path::new("."))
355                .join("config.toml"),
356        )
357        .ok();
358        let toml_config: TomlConfig = toml_content
359            .as_ref()
360            .and_then(|c| toml::from_str(c).ok())
361            .unwrap_or_default();
362
363        Ok(Config {
364            server: toml_config.server,
365            tls: toml_config.tls,
366            letsencrypt: toml_config.letsencrypt,
367            scripting: toml_config.scripting.unwrap_or_default(),
368            admin: toml_config.admin.unwrap_or_default(),
369            circuit_breaker: toml_config.circuit_breaker,
370            rules,
371            global_scripts,
372        })
373    }
374
375    pub fn get_config(&self) -> Arc<Config> {
376        self.config.load().clone()
377    }
378
379    pub fn start_watcher(&self) -> Result<()> {
380        // Ensure the file exists so the watcher has something to watch
381        if !self.config_path.exists() {
382            if let Some(parent) = self.config_path.parent() {
383                std::fs::create_dir_all(parent).ok();
384            }
385            std::fs::write(&self.config_path, "")?;
386        }
387
388        let (tx, mut rx) = mpsc::channel(1);
389        let config_path = self.config_path.clone();
390        let suppress = self.suppress_watch.clone();
391
392        let mut watcher = RecommendedWatcher::new(
393            move |res| {
394                let _ = tx.blocking_send(res);
395            },
396            notify::Config::default(),
397        )?;
398
399        watcher.watch(&config_path, RecursiveMode::NonRecursive)?;
400
401        tracing::info!("Watching config file: {}", config_path.display());
402
403        std::thread::spawn(move || {
404            while let Some(res) = rx.blocking_recv() {
405                match res {
406                    Ok(event) => {
407                        if event.kind.is_modify() {
408                            if suppress.swap(false, Ordering::SeqCst) {
409                                tracing::debug!(
410                                    "Suppressing file watcher reload (admin API write)"
411                                );
412                                continue;
413                            }
414                            tracing::info!("Config file changed, reloading...");
415                        }
416                    }
417                    Err(e) => tracing::error!("Watch error: {}", e),
418                }
419            }
420        });
421
422        Ok(())
423    }
424
425    pub async fn reload(&self) -> Result<()> {
426        let new_config = Self::load_config(&self.config_path, &self.config_path)?;
427        self.config.store(Arc::new(new_config));
428        tracing::info!("Configuration reloaded successfully");
429        Ok(())
430    }
431
432    /// Persist current rules to proxy.conf and swap in-memory config
433    fn persist_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
434        let content = serializer::serialize_proxy_conf(&rules, &global_scripts);
435        self.suppress_watch.store(true, Ordering::SeqCst);
436        std::fs::write(&self.config_path, &content)?;
437        let mut config = (*self.config.load().as_ref()).clone();
438        config.rules = rules;
439        config.global_scripts = global_scripts;
440        self.config.store(Arc::new(config));
441        tracing::info!("Configuration persisted to {}", self.config_path.display());
442        Ok(())
443    }
444
445    pub fn add_route(&self, rule: ProxyRule) -> Result<()> {
446        let cfg = self.get_config();
447        let mut rules = cfg.rules.clone();
448        rules.push(rule);
449        self.persist_rules(rules, cfg.global_scripts.clone())
450    }
451
452    pub fn update_route(&self, index: usize, rule: ProxyRule) -> Result<()> {
453        let cfg = self.get_config();
454        let mut rules = cfg.rules.clone();
455        if index >= rules.len() {
456            anyhow::bail!(
457                "Route index {} out of range (have {} routes)",
458                index,
459                rules.len()
460            );
461        }
462        rules[index] = rule;
463        self.persist_rules(rules, cfg.global_scripts.clone())
464    }
465
466    pub fn remove_route(&self, index: usize) -> Result<()> {
467        let cfg = self.get_config();
468        let mut rules = cfg.rules.clone();
469        if index >= rules.len() {
470            anyhow::bail!(
471                "Route index {} out of range (have {} routes)",
472                index,
473                rules.len()
474            );
475        }
476        rules.remove(index);
477        self.persist_rules(rules, cfg.global_scripts.clone())
478    }
479
480    pub fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
481        self.persist_rules(rules, global_scripts)
482    }
483}
484
485#[async_trait::async_trait]
486impl ConfigManagerTrait for ConfigManager {
487    async fn reload(&self) -> Result<()> {
488        self.reload().await
489    }
490
491    fn get_config(&self) -> Arc<Config> {
492        self.get_config()
493    }
494
495    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
496        self.update_rules(rules, global_scripts)
497    }
498
499    fn add_route(&self, rule: ProxyRule) -> Result<()> {
500        self.add_route(rule)
501    }
502
503    fn remove_route(&self, index: usize) -> Result<()> {
504        self.remove_route(index)
505    }
506}
507
/// Splits a `@script:a.lua,b.lua` modifier off `s`.
///
/// Returns the trimmed text before the first `@script:` and the non-empty, trimmed
/// comma-separated names up to the next whitespace. Everything after that first
/// whitespace-delimited list is discarded, so callers must strip other modifiers first.
/// Without a marker, `s` is returned unchanged with no scripts.
fn extract_scripts(s: &str) -> (&str, Vec<String>) {
    const MARKER: &str = "@script:";

    match s.find(MARKER) {
        None => (s, Vec::new()),
        Some(pos) => {
            let head = s[..pos].trim();
            let tail = &s[pos + MARKER.len()..];
            let list = tail.split_whitespace().next().unwrap_or(tail);
            let names = list
                .split(',')
                .map(str::trim)
                .filter(|name| !name.is_empty())
                .map(String::from)
                .collect();
            (head, names)
        }
    }
}
525
526/// Extract `@auth:user:hash` entries from a string, returning (remaining_str, auth_vec).
527/// Multiple @auth entries can appear: `@auth:user1:hash1 @auth:user2:hash2`
528/// Hash is everything after the second colon (bcrypt hashes start with $2a$, $2b$, $2y$)
529fn extract_auth(s: &str) -> (String, Vec<BasicAuth>) {
530    let mut auth_entries = Vec::new();
531    let mut remaining = s.to_string();
532
533    while let Some(idx) = remaining.find("@auth:") {
534        // Keep the part BEFORE @auth:
535        let before = &remaining[..idx];
536        let after = &remaining[idx + "@auth:".len()..];
537
538        // Find the end of this auth entry (whitespace or end of string)
539        let end_idx = after
540            .find(|c: char| c.is_whitespace())
541            .unwrap_or(after.len());
542
543        let auth_part = &after[..end_idx];
544
545        // Parse username:hash - hash is everything after the first colon
546        if let Some((username, hash)) = auth_part.split_once(':') {
547            if !username.is_empty() && !hash.is_empty() {
548                auth_entries.push(BasicAuth {
549                    username: username.to_string(),
550                    hash: hash.to_string(),
551                });
552            }
553        }
554
555        // Continue with the part BEFORE this @auth, plus any remaining after it
556        let rest = &after[end_idx..];
557        remaining = if rest.is_empty() {
558            before.to_string()
559        } else {
560            format!("{}{}", before, rest)
561        };
562    }
563
564    (remaining.trim().to_string(), auth_entries)
565}
566
567/// Extract `@lb:strategy` from a string, returning (remaining_str, strategy).
568/// Example: `@lb:round-robin` or `@lb:weighted` or `@lb:failover`
569fn extract_load_balancing(s: &str) -> (String, LoadBalancingStrategy) {
570    if let Some(idx) = s.find("@lb:") {
571        let before = &s[..idx];
572        let after = &s[idx + "@lb:".len()..];
573
574        let end_idx = after
575            .find(|c: char| c.is_whitespace())
576            .unwrap_or(after.len());
577
578        let strategy_str = &after[..end_idx];
579        let strategy = match strategy_str {
580            "round-robin" => LoadBalancingStrategy::RoundRobin,
581            "weighted" => LoadBalancingStrategy::Weighted,
582            "failover" => LoadBalancingStrategy::Failover,
583            _ => LoadBalancingStrategy::default(),
584        };
585
586        let rest = &after[end_idx..];
587        let remaining = if rest.is_empty() {
588            before.to_string()
589        } else {
590            format!("{}{}", before, rest)
591        };
592
593        (remaining.trim().to_string(), strategy)
594    } else {
595        (s.to_string(), LoadBalancingStrategy::default())
596    }
597}
598
599fn parse_proxy_config(content: &str) -> Result<(Vec<ProxyRule>, Vec<String>)> {
600    let mut rules = Vec::new();
601    let mut global_scripts = Vec::new();
602
603    // Join continuation lines (backslash at end of line)
604    let mut joined_lines: Vec<String> = Vec::new();
605    for line in content.lines() {
606        if let Some(current) = joined_lines.last_mut() {
607            if current.ends_with('\\') {
608                current.pop(); // remove the backslash
609                current.push_str(line.trim());
610                continue;
611            }
612        }
613        joined_lines.push(line.to_string());
614    }
615
616    for line in &joined_lines {
617        let trimmed = line.trim();
618        if trimmed.is_empty() || trimmed.starts_with('#') {
619            continue;
620        }
621
622        // Handle [global] @script:cors.lua,logging.lua
623        if trimmed.starts_with("[global]") {
624            let rest = trimmed.strip_prefix("[global]").unwrap().trim();
625            let (_, scripts) = extract_scripts(rest);
626            global_scripts.extend(scripts);
627            continue;
628        }
629
630        if let Some((source, target_str)) = trimmed.split_once("->") {
631            let source = source.trim();
632            // Extract @script: from the target side
633            let (target_str, route_scripts) = extract_scripts(target_str.trim());
634            // Extract @auth: entries from the target side
635            let (target_str, auth_entries) = extract_auth(target_str);
636            // Extract @lb: load balancing strategy
637            let (target_str, load_balancing) = extract_load_balancing(&target_str);
638
639            let matcher = if source == "default" || source == "*" {
640                RuleMatcher::Default
641            } else if let Some(pattern) = source.strip_prefix("~") {
642                RuleMatcher::Regex(RegexMatcher::new(pattern)?)
643            } else if !source.starts_with('/')
644                && (source.contains('.') || source.parse::<std::net::IpAddr>().is_ok())
645            {
646                if let Some((domain, path)) = source.split_once('/') {
647                    if path.is_empty() || path == "*" {
648                        RuleMatcher::Domain(domain.to_string())
649                    } else if path.ends_with("/*") {
650                        RuleMatcher::DomainPath(
651                            domain.to_string(),
652                            path.trim_end_matches('*').to_string(),
653                        )
654                    } else {
655                        RuleMatcher::DomainPath(domain.to_string(), path.to_string())
656                    }
657                } else {
658                    RuleMatcher::Domain(source.to_string())
659                }
660            } else if source.ends_with("/*") {
661                RuleMatcher::Prefix(source.trim_end_matches('*').to_string())
662            } else {
663                RuleMatcher::Exact(source.to_string())
664            };
665
666            let targets: Vec<Target> = target_str
667                .split(',')
668                .map(|t| {
669                    Ok(Target {
670                        url: Url::parse(t.trim())?,
671                        weight: 100,
672                    })
673                })
674                .collect::<Result<Vec<_>>>()?;
675
676            rules.push(ProxyRule {
677                matcher,
678                targets,
679                headers: vec![],
680                scripts: route_scripts,
681                auth: auth_entries,
682                load_balancing,
683            });
684        }
685    }
686
687    Ok((rules, global_scripts))
688}
689
// Unit tests for `parse_proxy_config`: backslash line continuation, `@auth:` parsing
// and `@lb:` load-balancing modifiers.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Backslash line continuation ---

    // A trailing `\` joins the next physical line into the same rule's target list.
    #[test]
    fn test_backslash_continuation_joins_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n          http://backend2:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://backend1:8080/");
        assert_eq!(rules[0].targets[1].url.as_str(), "http://backend2:8080/");
    }

    // Continuations chain across more than two physical lines.
    #[test]
    fn test_multiple_continuation_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n\
                       http://backend2:8080, \\\n\
                       http://backend3:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
        assert_eq!(rules[0].targets[2].url.as_str(), "http://backend3:8080/");
    }

    // A backslash inside a line (here in a regex) must not trigger joining.
    #[test]
    fn test_backslash_mid_line_not_continuation() {
        let config = "/path -> http://localhost:8080\n\
                       ~^/foo\\dbar$ -> http://localhost:9090\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    // Leading whitespace on continuation lines is trimmed before joining.
    #[test]
    fn test_continuation_trims_whitespace() {
        let config = "/api/* -> http://a:8080,   \\\n   http://b:8080,  \\\n   http://c:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
    }

    // Modifiers on a continuation line still apply to the joined rule.
    #[test]
    fn test_continuation_with_scripts() {
        let config = "/api/* -> http://a:8080, \\\n\
                       http://b:8080 @script:auth.lua\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    // Ordinary one-rule-per-line input is unaffected.
    #[test]
    fn test_no_continuation_normal_config() {
        let config = "/api/* -> http://backend:8080\ndefault -> http://localhost:3000\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    // --- @auth: basic-auth modifiers ---

    // A bcrypt hash (containing `$` and `/`) is kept whole and removed from the targets.
    #[test]
    fn test_auth_parsing() {
        let config = r#"
/db/* -> http://localhost:8080/ @auth:demo:$2b$12$YFlnIiACnSaAcxDWQlYjeedxq/3GvhvoGhRTYHMqLifJrETSqOZQa
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 1);
        assert_eq!(rules[0].auth[0].username, "demo");
        assert!(rules[0].auth[0].hash.starts_with("$2b$"));
        assert_eq!(rules[0].targets.len(), 1);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://localhost:8080/");
    }

    // Several @auth: entries on one rule are all collected, in order.
    #[test]
    fn test_multiple_auth_users() {
        let config = r#"
secure.example.com -> http://localhost:9000/ @auth:admin:$2b$12$hash1 @auth:user:$2b$12$hash2
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 2);
        assert_eq!(rules[0].auth[0].username, "admin");
        assert_eq!(rules[0].auth[1].username, "user");
    }

    // --- @lb: load-balancing modifiers ---

    #[test]
    fn test_load_balancing_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:round-robin";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
        assert_eq!(rules[0].targets.len(), 2);
    }

    #[test]
    fn test_load_balancing_weighted() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:weighted";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Weighted);
    }

    #[test]
    fn test_load_balancing_failover() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:failover";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Failover);
    }

    // No @lb: modifier means the enum's default strategy.
    #[test]
    fn test_load_balancing_default_is_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
    }

    // @lb: and @script: on the same rule are both honoured.
    #[test]
    fn test_load_balancing_with_scripts() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:failover @script:auth.lua";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Failover);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    // Unrecognised strategy names fall back to the default rather than erroring.
    #[test]
    fn test_load_balancing_unknown_strategy_defaults_to_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:unknown";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
    }
}