1pub mod serializer;
2
3use crate::auth::BasicAuth;
4use anyhow::Result;
5use arc_swap::ArcSwap;
6use notify::{RecommendedWatcher, RecursiveMode, Watcher};
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use url::Url;
14
/// Interface to the runtime configuration store, so callers can be written
/// against this trait instead of the concrete [`ConfigManager`].
///
/// Implementors must be `Send + Sync`.
#[async_trait::async_trait]
pub trait ConfigManagerTrait: Send + Sync {
    /// Re-read the configuration from its backing source and publish it.
    async fn reload(&self) -> Result<()>;
    /// Return the current configuration snapshot.
    fn get_config(&self) -> Arc<Config>;
    /// Replace the complete rule list and the global script list.
    fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()>;
    /// Append one routing rule.
    fn add_route(&self, rule: ProxyRule) -> Result<()>;
    /// Remove the routing rule at position `index`.
    fn remove_route(&self, index: usize) -> Result<()>;
}
23
/// Settings deserialized from the `config.toml` that sits next to the
/// proxy-config file.
///
/// `server` and `tls` fall back to their `Default` impls when their sections
/// are absent; the remaining sections are `None` when absent.
#[derive(Deserialize, Default, Clone, Debug)]
pub struct TomlConfig {
    #[serde(default)]
    pub server: ServerConfig,
    #[serde(default)]
    pub tls: TlsConfig,
    pub letsencrypt: Option<LetsEncryptConfig>,
    pub scripting: Option<ScriptingTomlConfig>,
    pub admin: Option<AdminConfig>,
    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
}
35
/// Optional `[circuit_breaker]` section of `config.toml`. Every field is
/// optional.
///
/// NOTE(review): the defaults for unset fields, and the exact meaning of each
/// threshold, are decided by the consumer of this struct, which is not visible
/// here.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct CircuitBreakerTomlConfig {
    pub failure_threshold: Option<u32>,
    pub recovery_timeout_secs: Option<u64>,
    pub success_threshold: Option<u32>,
    // Presumably upstream HTTP status codes that count as failures — confirm.
    pub failure_status_codes: Option<Vec<u16>>,
}
43
/// `[admin]` section of `config.toml`.
///
/// If the section is absent entirely, `AdminConfig::default()` is used. That
/// default fills `username`/`password_hash` from the `ADMIN_USER` and
/// `ADMIN_PASSWORD` environment variables.
///
/// If the section is present, `enabled` and `bind` are required. The two
/// `#[serde(default)]` fields fall back to `None`, because field-level defaults
/// use `Option::default`, not the struct's `Default`. The environment is
/// therefore NOT consulted in that case.
///
/// NOTE(review): the env var is named `ADMIN_PASSWORD` but is stored as
/// `password_hash`; confirm whether a hash or plaintext is expected.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct AdminConfig {
    pub enabled: bool,
    pub bind: String,
    pub api_key: Option<String>,
    #[serde(default)]
    pub username: Option<String>,
    #[serde(default)]
    pub password_hash: Option<String>,
}
54
55impl Default for AdminConfig {
56 fn default() -> Self {
57 let _ = dotenv::dotenv();
58 let username = std::env::var("ADMIN_USER").ok();
59 let password_hash = std::env::var("ADMIN_PASSWORD").ok();
60 Self {
61 enabled: false,
62 bind: "127.0.0.1:9090".to_string(),
63 api_key: None,
64 username,
65 password_hash,
66 }
67 }
68}
69
/// `[scripting]` section of `config.toml`.
///
/// When the section is absent, the derived `Default` is used: `enabled` is
/// `false` and both options are `None`. When the section is present,
/// `enabled` is required.
///
/// NOTE(review): what `None` means for `scripts_dir` and `hook_timeout_ms` is
/// decided by the scripting engine, which is not visible here.
#[derive(Deserialize, Serialize, Clone, Debug, Default)]
pub struct ScriptingTomlConfig {
    pub enabled: bool,
    pub scripts_dir: Option<String>,
    pub hook_timeout_ms: Option<u64>,
}
76
77#[derive(Deserialize, Serialize, Clone, Debug)]
78pub struct ServerConfig {
79 pub bind: String,
80 pub https_port: u16,
81}
82
83impl Default for ServerConfig {
84 fn default() -> Self {
85 Self {
86 bind: "0.0.0.0:8080".to_string(),
87 https_port: 443,
88 }
89 }
90}
91
/// `[tls]` section of `config.toml`.
///
/// When the section is absent, both fields are empty strings (the derived
/// `Default`). When it is present, both are required.
#[derive(Deserialize, Serialize, Default, Clone, Debug)]
pub struct TlsConfig {
    // NOTE(review): accepted mode values are interpreted elsewhere — confirm.
    pub mode: String,
    // Presumably where certificates are cached (e.g. for ACME) — confirm.
    pub cache_dir: String,
}
97
/// Optional `[letsencrypt]` section of `config.toml`. All fields are required
/// when the section is present.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct LetsEncryptConfig {
    // Presumably selects the CA's staging environment — confirm with the ACME client.
    pub staging: bool,
    // Contact e-mail, presumably passed to the CA on account registration.
    pub email: String,
    // Presumably records acceptance of the CA's terms of service — confirm.
    pub terms_agreed: bool,
}
104
/// Effective configuration snapshot.
///
/// Combines the `config.toml` settings with the rules and global scripts
/// parsed from the proxy-config file. Absent `[scripting]` and `[admin]`
/// sections are replaced by their `Default` impls when the snapshot is built.
#[derive(Clone, Debug, Serialize)]
pub struct Config {
    pub server: ServerConfig,
    pub tls: TlsConfig,
    pub letsencrypt: Option<LetsEncryptConfig>,
    pub scripting: ScriptingTomlConfig,
    pub admin: AdminConfig,
    pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
    /// Routing rules, in file order.
    pub rules: Vec<ProxyRule>,
    /// Script names from `[global] @script:...` lines.
    pub global_scripts: Vec<String>,
}
116
117#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
118pub enum LoadBalancingStrategy {
119 #[default]
120 #[serde(rename = "round-robin")]
121 RoundRobin,
122 #[serde(rename = "weighted")]
123 Weighted,
124 #[serde(rename = "failover")]
125 Failover,
126}
127
128#[derive(Clone, Debug, Serialize, Deserialize)]
129pub struct ProxyRule {
130 pub matcher: RuleMatcher,
131 pub targets: Vec<Target>,
132 pub headers: Vec<HeaderRule>,
133 pub scripts: Vec<String>,
134 #[serde(default)]
135 pub auth: Vec<BasicAuth>,
136 #[serde(default)]
137 pub load_balancing: LoadBalancingStrategy,
138}
139
/// Request-matching criterion of a [`ProxyRule`].
///
/// Serialized as a map tagged by `"type"` (see the `Serialize`/`Deserialize`
/// impls for `RuleMatcher` in this module). The notes below describe how the
/// proxy-config parser in this module produces each variant.
#[derive(Clone, Debug)]
pub enum RuleMatcher {
    /// Catch-all; written `default` or `*`.
    Default,
    /// Path prefix; `/api/*` is stored as `/api/` (trailing `*` removed).
    Prefix(String),
    /// Regular expression; written `~pattern`.
    Regex(RegexMatcher),
    /// Exact path: any source not recognised as another form. A bare
    /// `localhost` has no `.`, so the parser classifies it here too.
    Exact(String),
    /// Host: a source containing `.` or parsing as an IP address, optionally
    /// followed by `/` or `/*`.
    Domain(String),
    /// Host plus path, as `(domain, path)`. The parser stores the path WITHOUT
    /// its leading `/`, so `example.com/api/*` becomes `("example.com", "api/")`.
    /// NOTE(review): confirm the request matcher expects that form.
    DomainPath(String, String),
}
149
/// A compiled regular expression together with its source text.
///
/// The source `pattern` is kept because `RuleMatcher`'s `Serialize` impl
/// writes the pattern rather than the compiled `Regex`.
#[derive(Clone, Debug)]
pub struct RegexMatcher {
    pub pattern: String,
    pub regex: Regex,
}
156
157impl RegexMatcher {
158 pub fn new(pattern: &str) -> Result<Self> {
159 Ok(Self {
160 pattern: pattern.to_string(),
161 regex: Regex::new(pattern)?,
162 })
163 }
164
165 pub fn is_match(&self, text: &str) -> bool {
166 self.regex.is_match(text)
167 }
168}
169
170impl Serialize for RuleMatcher {
171 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
172 where
173 S: serde::Serializer,
174 {
175 use serde::ser::SerializeMap;
176 let mut map = serializer.serialize_map(Some(2))?;
177 match self {
178 RuleMatcher::Default => {
179 map.serialize_entry("type", "default")?;
180 }
181 RuleMatcher::Prefix(v) => {
182 map.serialize_entry("type", "prefix")?;
183 map.serialize_entry("value", v)?;
184 }
185 RuleMatcher::Regex(rm) => {
186 map.serialize_entry("type", "regex")?;
187 map.serialize_entry("value", &rm.pattern)?;
188 }
189 RuleMatcher::Exact(v) => {
190 map.serialize_entry("type", "exact")?;
191 map.serialize_entry("value", v)?;
192 }
193 RuleMatcher::Domain(v) => {
194 map.serialize_entry("type", "domain")?;
195 map.serialize_entry("value", v)?;
196 }
197 RuleMatcher::DomainPath(d, p) => {
198 map.serialize_entry("type", "domain_path")?;
199 map.serialize_entry("domain", d)?;
200 map.serialize_entry("path", p)?;
201 }
202 }
203 map.end()
204 }
205}
206
207impl<'de> Deserialize<'de> for RuleMatcher {
208 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
209 where
210 D: serde::Deserializer<'de>,
211 {
212 use serde::de::Error;
213 let value: serde_json::Value = Deserialize::deserialize(deserializer)?;
214 let obj = value
215 .as_object()
216 .ok_or_else(|| D::Error::custom("expected object"))?;
217 let matcher_type = obj
218 .get("type")
219 .and_then(|v| v.as_str())
220 .ok_or_else(|| D::Error::custom("missing 'type' field"))?;
221
222 match matcher_type {
223 "default" => Ok(RuleMatcher::Default),
224 "exact" => {
225 let v = obj
226 .get("value")
227 .and_then(|v| v.as_str())
228 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
229 Ok(RuleMatcher::Exact(v.to_string()))
230 }
231 "prefix" => {
232 let v = obj
233 .get("value")
234 .and_then(|v| v.as_str())
235 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
236 Ok(RuleMatcher::Prefix(v.to_string()))
237 }
238 "regex" => {
239 let v = obj
240 .get("value")
241 .and_then(|v| v.as_str())
242 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
243 let rm = RegexMatcher::new(v)
244 .map_err(|e| D::Error::custom(format!("invalid regex: {}", e)))?;
245 Ok(RuleMatcher::Regex(rm))
246 }
247 "domain" => {
248 let v = obj
249 .get("value")
250 .and_then(|v| v.as_str())
251 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
252 Ok(RuleMatcher::Domain(v.to_string()))
253 }
254 "domain_path" => {
255 let d = obj
256 .get("domain")
257 .and_then(|v| v.as_str())
258 .ok_or_else(|| D::Error::custom("missing 'domain'"))?;
259 let p = obj
260 .get("path")
261 .and_then(|v| v.as_str())
262 .ok_or_else(|| D::Error::custom("missing 'path'"))?;
263 Ok(RuleMatcher::DomainPath(d.to_string(), p.to_string()))
264 }
265 other => Err(D::Error::custom(format!("unknown matcher type: {}", other))),
266 }
267 }
268}
269
/// One upstream of a [`ProxyRule`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Target {
    /// Upstream URL.
    pub url: Url,
    /// Relative weight. The proxy-config parser always sets 100.
    /// Presumably consulted by the `weighted` strategy — confirm with the balancer.
    pub weight: u8,
}
275
/// A header name/value pair attached to a rule.
///
/// The proxy-config parser never produces these.
/// NOTE(review): whether the header is applied to the request or the response
/// is decided elsewhere — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HeaderRule {
    pub name: String,
    pub value: String,
}
281
282impl Config {
283 pub fn acme_domains(&self) -> Vec<String> {
286 let mut domains = Vec::new();
287 let mut seen = std::collections::HashSet::new();
288
289 for rule in &self.rules {
290 let domain = match &rule.matcher {
291 RuleMatcher::Domain(d) => Some(d.as_str()),
292 RuleMatcher::DomainPath(d, _) => Some(d.as_str()),
293 _ => None,
294 };
295
296 if let Some(d) = domain {
297 if d == "localhost" || d.parse::<std::net::IpAddr>().is_ok() {
298 continue;
299 }
300 if seen.insert(d.to_string()) {
301 domains.push(d.to_string());
302 }
303 }
304 }
305
306 domains
307 }
308}
309
310pub struct ConfigManager {
311 config: ArcSwap<Config>,
312 config_path: PathBuf,
313 _watcher: Option<RecommendedWatcher>,
314 suppress_watch: Arc<AtomicBool>,
315}
316
317impl Clone for ConfigManager {
318 fn clone(&self) -> Self {
319 Self {
320 config: ArcSwap::new(self.config.load().clone()),
321 config_path: self.config_path.clone(),
322 _watcher: None,
323 suppress_watch: self.suppress_watch.clone(),
324 }
325 }
326}
327
328impl ConfigManager {
329 pub fn new(config_path: &str) -> Result<Self> {
330 let path = PathBuf::from(config_path);
331 let config = Self::load_config(&path, &path)?;
332 Ok(Self {
333 config: ArcSwap::new(Arc::new(config)),
334 config_path: path,
335 _watcher: None,
336 suppress_watch: Arc::new(AtomicBool::new(false)),
337 })
338 }
339
340 pub fn config_path(&self) -> &Path {
341 &self.config_path
342 }
343
344 pub fn suppress_watch(&self) -> &Arc<AtomicBool> {
345 &self.suppress_watch
346 }
347
348 fn load_config(proxy_conf_path: &Path, config_path: &Path) -> Result<Config> {
349 let content = std::fs::read_to_string(proxy_conf_path).unwrap_or_default();
350 let (rules, global_scripts) = parse_proxy_config(&content)?;
351 let toml_content = std::fs::read_to_string(
352 config_path
353 .parent()
354 .unwrap_or(Path::new("."))
355 .join("config.toml"),
356 )
357 .ok();
358 let toml_config: TomlConfig = toml_content
359 .as_ref()
360 .and_then(|c| toml::from_str(c).ok())
361 .unwrap_or_default();
362
363 Ok(Config {
364 server: toml_config.server,
365 tls: toml_config.tls,
366 letsencrypt: toml_config.letsencrypt,
367 scripting: toml_config.scripting.unwrap_or_default(),
368 admin: toml_config.admin.unwrap_or_default(),
369 circuit_breaker: toml_config.circuit_breaker,
370 rules,
371 global_scripts,
372 })
373 }
374
375 pub fn get_config(&self) -> Arc<Config> {
376 self.config.load().clone()
377 }
378
379 pub fn start_watcher(&self) -> Result<()> {
380 if !self.config_path.exists() {
382 if let Some(parent) = self.config_path.parent() {
383 std::fs::create_dir_all(parent).ok();
384 }
385 std::fs::write(&self.config_path, "")?;
386 }
387
388 let (tx, mut rx) = mpsc::channel(1);
389 let config_path = self.config_path.clone();
390 let suppress = self.suppress_watch.clone();
391
392 let mut watcher = RecommendedWatcher::new(
393 move |res| {
394 let _ = tx.blocking_send(res);
395 },
396 notify::Config::default(),
397 )?;
398
399 watcher.watch(&config_path, RecursiveMode::NonRecursive)?;
400
401 tracing::info!("Watching config file: {}", config_path.display());
402
403 std::thread::spawn(move || {
404 while let Some(res) = rx.blocking_recv() {
405 match res {
406 Ok(event) => {
407 if event.kind.is_modify() {
408 if suppress.swap(false, Ordering::SeqCst) {
409 tracing::debug!(
410 "Suppressing file watcher reload (admin API write)"
411 );
412 continue;
413 }
414 tracing::info!("Config file changed, reloading...");
415 }
416 }
417 Err(e) => tracing::error!("Watch error: {}", e),
418 }
419 }
420 });
421
422 Ok(())
423 }
424
425 pub async fn reload(&self) -> Result<()> {
426 let new_config = Self::load_config(&self.config_path, &self.config_path)?;
427 self.config.store(Arc::new(new_config));
428 tracing::info!("Configuration reloaded successfully");
429 Ok(())
430 }
431
432 fn persist_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
434 let content = serializer::serialize_proxy_conf(&rules, &global_scripts);
435 self.suppress_watch.store(true, Ordering::SeqCst);
436 std::fs::write(&self.config_path, &content)?;
437 let mut config = (*self.config.load().as_ref()).clone();
438 config.rules = rules;
439 config.global_scripts = global_scripts;
440 self.config.store(Arc::new(config));
441 tracing::info!("Configuration persisted to {}", self.config_path.display());
442 Ok(())
443 }
444
445 pub fn add_route(&self, rule: ProxyRule) -> Result<()> {
446 let cfg = self.get_config();
447 let mut rules = cfg.rules.clone();
448 rules.push(rule);
449 self.persist_rules(rules, cfg.global_scripts.clone())
450 }
451
452 pub fn update_route(&self, index: usize, rule: ProxyRule) -> Result<()> {
453 let cfg = self.get_config();
454 let mut rules = cfg.rules.clone();
455 if index >= rules.len() {
456 anyhow::bail!(
457 "Route index {} out of range (have {} routes)",
458 index,
459 rules.len()
460 );
461 }
462 rules[index] = rule;
463 self.persist_rules(rules, cfg.global_scripts.clone())
464 }
465
466 pub fn remove_route(&self, index: usize) -> Result<()> {
467 let cfg = self.get_config();
468 let mut rules = cfg.rules.clone();
469 if index >= rules.len() {
470 anyhow::bail!(
471 "Route index {} out of range (have {} routes)",
472 index,
473 rules.len()
474 );
475 }
476 rules.remove(index);
477 self.persist_rules(rules, cfg.global_scripts.clone())
478 }
479
480 pub fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
481 self.persist_rules(rules, global_scripts)
482 }
483}
484
485#[async_trait::async_trait]
486impl ConfigManagerTrait for ConfigManager {
487 async fn reload(&self) -> Result<()> {
488 self.reload().await
489 }
490
491 fn get_config(&self) -> Arc<Config> {
492 self.get_config()
493 }
494
495 fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
496 self.update_rules(rules, global_scripts)
497 }
498
499 fn add_route(&self, rule: ProxyRule) -> Result<()> {
500 self.add_route(rule)
501 }
502
503 fn remove_route(&self, index: usize) -> Result<()> {
504 self.remove_route(index)
505 }
506}
507
/// Split `@script:` directives off a rule or `[global]` line.
///
/// Returns the text before the first `@script:` (trimmed) and the script
/// names from every `@script:` directive on the line, in order.
///
/// - Each directive's list is the first whitespace-delimited token after the
///   marker (leading whitespace is skipped), split on `,`.
/// - Empty names are dropped.
/// - Other text after the first marker is discarded, so callers must strip
///   other directives first.
/// - Without a marker, `s` is returned unchanged (untrimmed) with no scripts.
fn extract_scripts(s: &str) -> (&str, Vec<String>) {
    const MARKER: &str = "@script:";

    let first = match s.find(MARKER) {
        Some(idx) => idx,
        None => return (s, Vec::new()),
    };

    let mut scripts = Vec::new();
    let mut rest = &s[first..];
    // Visit every directive, not just the first, so a second `@script:` on the
    // same line is honoured instead of silently dropped.
    while let Some(idx) = rest.find(MARKER) {
        let after = rest[idx + MARKER.len()..].trim_start();
        let list_len = after.find(char::is_whitespace).unwrap_or(after.len());
        let list = &after[..list_len];
        scripts.extend(
            list.split(',')
                .map(str::trim)
                .filter(|name| !name.is_empty())
                .map(str::to_string),
        );
        // Resume scanning after the list just consumed.
        rest = &after[list_len..];
    }

    (s[..first].trim(), scripts)
}
525
526fn extract_auth(s: &str) -> (String, Vec<BasicAuth>) {
530 let mut auth_entries = Vec::new();
531 let mut remaining = s.to_string();
532
533 while let Some(idx) = remaining.find("@auth:") {
534 let before = &remaining[..idx];
536 let after = &remaining[idx + "@auth:".len()..];
537
538 let end_idx = after
540 .find(|c: char| c.is_whitespace())
541 .unwrap_or(after.len());
542
543 let auth_part = &after[..end_idx];
544
545 if let Some((username, hash)) = auth_part.split_once(':') {
547 if !username.is_empty() && !hash.is_empty() {
548 auth_entries.push(BasicAuth {
549 username: username.to_string(),
550 hash: hash.to_string(),
551 });
552 }
553 }
554
555 let rest = &after[end_idx..];
557 remaining = if rest.is_empty() {
558 before.to_string()
559 } else {
560 format!("{}{}", before, rest)
561 };
562 }
563
564 (remaining.trim().to_string(), auth_entries)
565}
566
567fn extract_load_balancing(s: &str) -> (String, LoadBalancingStrategy) {
570 if let Some(idx) = s.find("@lb:") {
571 let before = &s[..idx];
572 let after = &s[idx + "@lb:".len()..];
573
574 let end_idx = after
575 .find(|c: char| c.is_whitespace())
576 .unwrap_or(after.len());
577
578 let strategy_str = &after[..end_idx];
579 let strategy = match strategy_str {
580 "round-robin" => LoadBalancingStrategy::RoundRobin,
581 "weighted" => LoadBalancingStrategy::Weighted,
582 "failover" => LoadBalancingStrategy::Failover,
583 _ => LoadBalancingStrategy::default(),
584 };
585
586 let rest = &after[end_idx..];
587 let remaining = if rest.is_empty() {
588 before.to_string()
589 } else {
590 format!("{}{}", before, rest)
591 };
592
593 (remaining.trim().to_string(), strategy)
594 } else {
595 (s.to_string(), LoadBalancingStrategy::default())
596 }
597}
598
599fn parse_proxy_config(content: &str) -> Result<(Vec<ProxyRule>, Vec<String>)> {
600 let mut rules = Vec::new();
601 let mut global_scripts = Vec::new();
602
603 let mut joined_lines: Vec<String> = Vec::new();
605 for line in content.lines() {
606 if let Some(current) = joined_lines.last_mut() {
607 if current.ends_with('\\') {
608 current.pop(); current.push_str(line.trim());
610 continue;
611 }
612 }
613 joined_lines.push(line.to_string());
614 }
615
616 for line in &joined_lines {
617 let trimmed = line.trim();
618 if trimmed.is_empty() || trimmed.starts_with('#') {
619 continue;
620 }
621
622 if trimmed.starts_with("[global]") {
624 let rest = trimmed.strip_prefix("[global]").unwrap().trim();
625 let (_, scripts) = extract_scripts(rest);
626 global_scripts.extend(scripts);
627 continue;
628 }
629
630 if let Some((source, target_str)) = trimmed.split_once("->") {
631 let source = source.trim();
632 let (target_str, route_scripts) = extract_scripts(target_str.trim());
634 let (target_str, auth_entries) = extract_auth(target_str);
636 let (target_str, load_balancing) = extract_load_balancing(&target_str);
638
639 let matcher = if source == "default" || source == "*" {
640 RuleMatcher::Default
641 } else if let Some(pattern) = source.strip_prefix("~") {
642 RuleMatcher::Regex(RegexMatcher::new(pattern)?)
643 } else if !source.starts_with('/')
644 && (source.contains('.') || source.parse::<std::net::IpAddr>().is_ok())
645 {
646 if let Some((domain, path)) = source.split_once('/') {
647 if path.is_empty() || path == "*" {
648 RuleMatcher::Domain(domain.to_string())
649 } else if path.ends_with("/*") {
650 RuleMatcher::DomainPath(
651 domain.to_string(),
652 path.trim_end_matches('*').to_string(),
653 )
654 } else {
655 RuleMatcher::DomainPath(domain.to_string(), path.to_string())
656 }
657 } else {
658 RuleMatcher::Domain(source.to_string())
659 }
660 } else if source.ends_with("/*") {
661 RuleMatcher::Prefix(source.trim_end_matches('*').to_string())
662 } else {
663 RuleMatcher::Exact(source.to_string())
664 };
665
666 let targets: Vec<Target> = target_str
667 .split(',')
668 .map(|t| {
669 Ok(Target {
670 url: Url::parse(t.trim())?,
671 weight: 100,
672 })
673 })
674 .collect::<Result<Vec<_>>>()?;
675
676 rules.push(ProxyRule {
677 matcher,
678 targets,
679 headers: vec![],
680 scripts: route_scripts,
681 auth: auth_entries,
682 load_balancing,
683 });
684 }
685 }
686
687 Ok((rules, global_scripts))
688}
689
// Unit tests for `parse_proxy_config`: line continuation, `@auth:`, `@lb:`
// and `@script:` directive handling.
#[cfg(test)]
mod tests {
    use super::*;

    // A trailing `\` joins the next line, so both URLs belong to one rule.
    #[test]
    fn test_backslash_continuation_joins_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n http://backend2:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://backend1:8080/");
        assert_eq!(rules[0].targets[1].url.as_str(), "http://backend2:8080/");
    }

    // Continuations chain across more than two physical lines.
    #[test]
    fn test_multiple_continuation_lines() {
        let config = "/api/* -> http://backend1:8080, \\\n\
                      http://backend2:8080, \\\n\
                      http://backend3:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
        assert_eq!(rules[0].targets[2].url.as_str(), "http://backend3:8080/");
    }

    // A backslash inside a line (here in a regex) must not join lines.
    #[test]
    fn test_backslash_mid_line_not_continuation() {
        let config = "/path -> http://localhost:8080\n\
                      ~^/foo\\dbar$ -> http://localhost:9090\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    // Leading whitespace on continuation lines is trimmed before joining.
    #[test]
    fn test_continuation_trims_whitespace() {
        let config = "/api/* -> http://a:8080, \\\n http://b:8080, \\\n http://c:8080\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
    }

    // A directive on the final continuation line still applies to the rule.
    #[test]
    fn test_continuation_with_scripts() {
        let config = "/api/* -> http://a:8080, \\\n\
                      http://b:8080 @script:auth.lua\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    // Plain one-line rules are unaffected by continuation handling.
    #[test]
    fn test_no_continuation_normal_config() {
        let config = "/api/* -> http://backend:8080\ndefault -> http://localhost:3000\n";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 2);
    }

    // `@auth:user:hash` is split at the first `:`; the bcrypt hash keeps its `$`s.
    #[test]
    fn test_auth_parsing() {
        let config = r#"
/db/* -> http://localhost:8080/ @auth:demo:$2b$12$YFlnIiACnSaAcxDWQlYjeedxq/3GvhvoGhRTYHMqLifJrETSqOZQa
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 1);
        assert_eq!(rules[0].auth[0].username, "demo");
        assert!(rules[0].auth[0].hash.starts_with("$2b$"));
        assert_eq!(rules[0].targets.len(), 1);
        assert_eq!(rules[0].targets[0].url.as_str(), "http://localhost:8080/");
    }

    // Several `@auth:` directives accumulate in order.
    #[test]
    fn test_multiple_auth_users() {
        let config = r#"
secure.example.com -> http://localhost:9000/ @auth:admin:$2b$12$hash1 @auth:user:$2b$12$hash2
"#;
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].auth.len(), 2);
        assert_eq!(rules[0].auth[0].username, "admin");
        assert_eq!(rules[0].auth[1].username, "user");
    }

    // Each recognised `@lb:` name maps to its strategy.
    #[test]
    fn test_load_balancing_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:round-robin";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
        assert_eq!(rules[0].targets.len(), 2);
    }

    #[test]
    fn test_load_balancing_weighted() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:weighted";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Weighted);
    }

    #[test]
    fn test_load_balancing_failover() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:failover";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Failover);
    }

    // Without `@lb:` the strategy is the enum's default.
    #[test]
    fn test_load_balancing_default_is_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
    }

    // `@lb:` and `@script:` on the same line are both honoured.
    #[test]
    fn test_load_balancing_with_scripts() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:failover @script:auth.lua";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Failover);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    // Unknown strategy names fall back to the default instead of erroring.
    #[test]
    fn test_load_balancing_unknown_strategy_defaults_to_round_robin() {
        let config = "/api/* -> http://b1:8080, http://b2:8080 @lb:unknown";
        let (rules, _) = parse_proxy_config(config).unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
    }
}