1pub mod serializer;
2
use crate::auth::BasicAuth;
use anyhow::Result;
use arc_swap::ArcSwap;
use notify::{RecommendedWatcher, RecursiveMode, Watcher};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use url::Url;
14
15#[async_trait::async_trait]
16pub trait ConfigManagerTrait: Send + Sync {
17 async fn reload(&self) -> Result<()>;
18 fn get_config(&self) -> Arc<Config>;
19 fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()>;
20 fn add_route(&self, rule: ProxyRule) -> Result<()>;
21 fn remove_route(&self, index: usize) -> Result<()>;
22}
23
24#[derive(Deserialize, Default, Clone, Debug)]
25pub struct TomlConfig {
26 #[serde(default)]
27 pub server: ServerConfig,
28 #[serde(default)]
29 pub tls: TlsConfig,
30 pub letsencrypt: Option<LetsEncryptConfig>,
31 pub scripting: Option<ScriptingTomlConfig>,
32 pub admin: Option<AdminConfig>,
33 pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
34}
35
36#[derive(Deserialize, Serialize, Clone, Debug)]
37pub struct CircuitBreakerTomlConfig {
38 pub failure_threshold: Option<u32>,
39 pub recovery_timeout_secs: Option<u64>,
40 pub success_threshold: Option<u32>,
41 pub failure_status_codes: Option<Vec<u16>>,
42}
43
44#[derive(Deserialize, Serialize, Clone, Debug)]
45pub struct AdminConfig {
46 pub enabled: bool,
47 pub bind: String,
48 pub api_key: Option<String>,
49}
50
51impl Default for AdminConfig {
52 fn default() -> Self {
53 Self {
54 enabled: false,
55 bind: "127.0.0.1:9090".to_string(),
56 api_key: None,
57 }
58 }
59}
60
61#[derive(Deserialize, Serialize, Clone, Debug, Default)]
62pub struct ScriptingTomlConfig {
63 pub enabled: bool,
64 pub scripts_dir: Option<String>,
65 pub hook_timeout_ms: Option<u64>,
66}
67
68#[derive(Deserialize, Serialize, Default, Clone, Debug)]
69pub struct ServerConfig {
70 pub bind: String,
71 pub https_port: u16,
72}
73
74#[derive(Deserialize, Serialize, Default, Clone, Debug)]
75pub struct TlsConfig {
76 pub mode: String,
77 pub cache_dir: String,
78}
79
80#[derive(Deserialize, Serialize, Clone, Debug)]
81pub struct LetsEncryptConfig {
82 pub staging: bool,
83 pub email: String,
84 pub terms_agreed: bool,
85}
86
87#[derive(Clone, Debug, Serialize)]
88pub struct Config {
89 pub server: ServerConfig,
90 pub tls: TlsConfig,
91 pub letsencrypt: Option<LetsEncryptConfig>,
92 pub scripting: ScriptingTomlConfig,
93 pub admin: AdminConfig,
94 pub circuit_breaker: Option<CircuitBreakerTomlConfig>,
95 pub rules: Vec<ProxyRule>,
96 pub global_scripts: Vec<String>,
97}
98
99#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
100pub enum LoadBalancingStrategy {
101 #[default]
102 #[serde(rename = "round-robin")]
103 RoundRobin,
104 #[serde(rename = "weighted")]
105 Weighted,
106 #[serde(rename = "failover")]
107 Failover,
108}
109
110#[derive(Clone, Debug, Serialize, Deserialize)]
111pub struct ProxyRule {
112 pub matcher: RuleMatcher,
113 pub targets: Vec<Target>,
114 pub headers: Vec<HeaderRule>,
115 pub scripts: Vec<String>,
116 #[serde(default)]
117 pub auth: Vec<BasicAuth>,
118 #[serde(default)]
119 pub load_balancing: LoadBalancingStrategy,
120}
121
122#[derive(Clone, Debug)]
123pub enum RuleMatcher {
124 Default,
125 Prefix(String),
126 Regex(RegexMatcher),
127 Exact(String),
128 Domain(String),
129 DomainPath(String, String),
130}
131
132#[derive(Clone, Debug)]
134pub struct RegexMatcher {
135 pub pattern: String,
136 pub regex: Regex,
137}
138
139impl RegexMatcher {
140 pub fn new(pattern: &str) -> Result<Self> {
141 Ok(Self {
142 pattern: pattern.to_string(),
143 regex: Regex::new(pattern)?,
144 })
145 }
146
147 pub fn is_match(&self, text: &str) -> bool {
148 self.regex.is_match(text)
149 }
150}
151
152impl Serialize for RuleMatcher {
153 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
154 where
155 S: serde::Serializer,
156 {
157 use serde::ser::SerializeMap;
158 let mut map = serializer.serialize_map(Some(2))?;
159 match self {
160 RuleMatcher::Default => {
161 map.serialize_entry("type", "default")?;
162 }
163 RuleMatcher::Prefix(v) => {
164 map.serialize_entry("type", "prefix")?;
165 map.serialize_entry("value", v)?;
166 }
167 RuleMatcher::Regex(rm) => {
168 map.serialize_entry("type", "regex")?;
169 map.serialize_entry("value", &rm.pattern)?;
170 }
171 RuleMatcher::Exact(v) => {
172 map.serialize_entry("type", "exact")?;
173 map.serialize_entry("value", v)?;
174 }
175 RuleMatcher::Domain(v) => {
176 map.serialize_entry("type", "domain")?;
177 map.serialize_entry("value", v)?;
178 }
179 RuleMatcher::DomainPath(d, p) => {
180 map.serialize_entry("type", "domain_path")?;
181 map.serialize_entry("domain", d)?;
182 map.serialize_entry("path", p)?;
183 }
184 }
185 map.end()
186 }
187}
188
189impl<'de> Deserialize<'de> for RuleMatcher {
190 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
191 where
192 D: serde::Deserializer<'de>,
193 {
194 use serde::de::Error;
195 let value: serde_json::Value = Deserialize::deserialize(deserializer)?;
196 let obj = value
197 .as_object()
198 .ok_or_else(|| D::Error::custom("expected object"))?;
199 let matcher_type = obj
200 .get("type")
201 .and_then(|v| v.as_str())
202 .ok_or_else(|| D::Error::custom("missing 'type' field"))?;
203
204 match matcher_type {
205 "default" => Ok(RuleMatcher::Default),
206 "exact" => {
207 let v = obj
208 .get("value")
209 .and_then(|v| v.as_str())
210 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
211 Ok(RuleMatcher::Exact(v.to_string()))
212 }
213 "prefix" => {
214 let v = obj
215 .get("value")
216 .and_then(|v| v.as_str())
217 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
218 Ok(RuleMatcher::Prefix(v.to_string()))
219 }
220 "regex" => {
221 let v = obj
222 .get("value")
223 .and_then(|v| v.as_str())
224 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
225 let rm = RegexMatcher::new(v)
226 .map_err(|e| D::Error::custom(format!("invalid regex: {}", e)))?;
227 Ok(RuleMatcher::Regex(rm))
228 }
229 "domain" => {
230 let v = obj
231 .get("value")
232 .and_then(|v| v.as_str())
233 .ok_or_else(|| D::Error::custom("missing 'value'"))?;
234 Ok(RuleMatcher::Domain(v.to_string()))
235 }
236 "domain_path" => {
237 let d = obj
238 .get("domain")
239 .and_then(|v| v.as_str())
240 .ok_or_else(|| D::Error::custom("missing 'domain'"))?;
241 let p = obj
242 .get("path")
243 .and_then(|v| v.as_str())
244 .ok_or_else(|| D::Error::custom("missing 'path'"))?;
245 Ok(RuleMatcher::DomainPath(d.to_string(), p.to_string()))
246 }
247 other => Err(D::Error::custom(format!("unknown matcher type: {}", other))),
248 }
249 }
250}
251
252#[derive(Clone, Debug, Serialize, Deserialize)]
253pub struct Target {
254 pub url: Url,
255 pub weight: u8,
256}
257
258#[derive(Clone, Debug, Serialize, Deserialize)]
259pub struct HeaderRule {
260 pub name: String,
261 pub value: String,
262}
263
264impl Config {
265 pub fn acme_domains(&self) -> Vec<String> {
268 let mut domains = Vec::new();
269 let mut seen = std::collections::HashSet::new();
270
271 for rule in &self.rules {
272 let domain = match &rule.matcher {
273 RuleMatcher::Domain(d) => Some(d.as_str()),
274 RuleMatcher::DomainPath(d, _) => Some(d.as_str()),
275 _ => None,
276 };
277
278 if let Some(d) = domain {
279 if d == "localhost" || d.parse::<std::net::IpAddr>().is_ok() {
280 continue;
281 }
282 if seen.insert(d.to_string()) {
283 domains.push(d.to_string());
284 }
285 }
286 }
287
288 domains
289 }
290}
291
292pub struct ConfigManager {
293 config: ArcSwap<Config>,
294 config_path: PathBuf,
295 _watcher: Option<RecommendedWatcher>,
296 suppress_watch: Arc<AtomicBool>,
297}
298
299impl Clone for ConfigManager {
300 fn clone(&self) -> Self {
301 Self {
302 config: ArcSwap::new(self.config.load().clone()),
303 config_path: self.config_path.clone(),
304 _watcher: None,
305 suppress_watch: self.suppress_watch.clone(),
306 }
307 }
308}
309
310impl ConfigManager {
311 pub fn new(config_path: &str) -> Result<Self> {
312 let path = PathBuf::from(config_path);
313 let config = Self::load_config(&path, &path)?;
314 Ok(Self {
315 config: ArcSwap::new(Arc::new(config)),
316 config_path: path,
317 _watcher: None,
318 suppress_watch: Arc::new(AtomicBool::new(false)),
319 })
320 }
321
322 pub fn config_path(&self) -> &Path {
323 &self.config_path
324 }
325
326 pub fn suppress_watch(&self) -> &Arc<AtomicBool> {
327 &self.suppress_watch
328 }
329
330 fn load_config(proxy_conf_path: &Path, config_path: &Path) -> Result<Config> {
331 let content = std::fs::read_to_string(proxy_conf_path).unwrap_or_default();
332 let (rules, global_scripts) = parse_proxy_config(&content)?;
333 let toml_content = std::fs::read_to_string(
334 config_path
335 .parent()
336 .unwrap_or(Path::new("."))
337 .join("config.toml"),
338 )
339 .ok();
340 let toml_config: TomlConfig = toml_content
341 .as_ref()
342 .and_then(|c| toml::from_str(c).ok())
343 .unwrap_or_default();
344
345 Ok(Config {
346 server: toml_config.server,
347 tls: toml_config.tls,
348 letsencrypt: toml_config.letsencrypt,
349 scripting: toml_config.scripting.unwrap_or_default(),
350 admin: toml_config.admin.unwrap_or_default(),
351 circuit_breaker: toml_config.circuit_breaker,
352 rules,
353 global_scripts,
354 })
355 }
356
357 pub fn get_config(&self) -> Arc<Config> {
358 self.config.load().clone()
359 }
360
361 pub fn start_watcher(&self) -> Result<()> {
362 if !self.config_path.exists() {
364 if let Some(parent) = self.config_path.parent() {
365 std::fs::create_dir_all(parent).ok();
366 }
367 std::fs::write(&self.config_path, "")?;
368 }
369
370 let (tx, mut rx) = mpsc::channel(1);
371 let config_path = self.config_path.clone();
372 let suppress = self.suppress_watch.clone();
373
374 let mut watcher = RecommendedWatcher::new(
375 move |res| {
376 let _ = tx.blocking_send(res);
377 },
378 notify::Config::default(),
379 )?;
380
381 watcher.watch(&config_path, RecursiveMode::NonRecursive)?;
382
383 tracing::info!("Watching config file: {}", config_path.display());
384
385 std::thread::spawn(move || {
386 while let Some(res) = rx.blocking_recv() {
387 match res {
388 Ok(event) => {
389 if event.kind.is_modify() {
390 if suppress.swap(false, Ordering::SeqCst) {
391 tracing::debug!(
392 "Suppressing file watcher reload (admin API write)"
393 );
394 continue;
395 }
396 tracing::info!("Config file changed, reloading...");
397 }
398 }
399 Err(e) => tracing::error!("Watch error: {}", e),
400 }
401 }
402 });
403
404 Ok(())
405 }
406
407 pub async fn reload(&self) -> Result<()> {
408 let new_config = Self::load_config(&self.config_path, &self.config_path)?;
409 self.config.store(Arc::new(new_config));
410 tracing::info!("Configuration reloaded successfully");
411 Ok(())
412 }
413
414 fn persist_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
416 let content = serializer::serialize_proxy_conf(&rules, &global_scripts);
417 self.suppress_watch.store(true, Ordering::SeqCst);
418 std::fs::write(&self.config_path, &content)?;
419 let mut config = (*self.config.load().as_ref()).clone();
420 config.rules = rules;
421 config.global_scripts = global_scripts;
422 self.config.store(Arc::new(config));
423 tracing::info!("Configuration persisted to {}", self.config_path.display());
424 Ok(())
425 }
426
427 pub fn add_route(&self, rule: ProxyRule) -> Result<()> {
428 let cfg = self.get_config();
429 let mut rules = cfg.rules.clone();
430 rules.push(rule);
431 self.persist_rules(rules, cfg.global_scripts.clone())
432 }
433
434 pub fn update_route(&self, index: usize, rule: ProxyRule) -> Result<()> {
435 let cfg = self.get_config();
436 let mut rules = cfg.rules.clone();
437 if index >= rules.len() {
438 anyhow::bail!(
439 "Route index {} out of range (have {} routes)",
440 index,
441 rules.len()
442 );
443 }
444 rules[index] = rule;
445 self.persist_rules(rules, cfg.global_scripts.clone())
446 }
447
448 pub fn remove_route(&self, index: usize) -> Result<()> {
449 let cfg = self.get_config();
450 let mut rules = cfg.rules.clone();
451 if index >= rules.len() {
452 anyhow::bail!(
453 "Route index {} out of range (have {} routes)",
454 index,
455 rules.len()
456 );
457 }
458 rules.remove(index);
459 self.persist_rules(rules, cfg.global_scripts.clone())
460 }
461
462 pub fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
463 self.persist_rules(rules, global_scripts)
464 }
465}
466
467#[async_trait::async_trait]
468impl ConfigManagerTrait for ConfigManager {
469 async fn reload(&self) -> Result<()> {
470 self.reload().await
471 }
472
473 fn get_config(&self) -> Arc<Config> {
474 self.get_config()
475 }
476
477 fn update_rules(&self, rules: Vec<ProxyRule>, global_scripts: Vec<String>) -> Result<()> {
478 self.update_rules(rules, global_scripts)
479 }
480
481 fn add_route(&self, rule: ProxyRule) -> Result<()> {
482 self.add_route(rule)
483 }
484
485 fn remove_route(&self, index: usize) -> Result<()> {
486 self.remove_route(index)
487 }
488}
489
/// Splits an `@script:` directive off a rule or `[global]` fragment.
///
/// Returns the trimmed text before the first `@script:` marker, together with the
/// comma-separated, non-empty script names from the whitespace-delimited token that follows it.
/// Any text after that token is not returned. Without a marker, the input comes back unchanged
/// (untrimmed) with no scripts.
fn extract_scripts(s: &str) -> (&str, Vec<String>) {
    const MARKER: &str = "@script:";
    match s.find(MARKER) {
        None => (s, Vec::new()),
        Some(pos) => {
            let head = s[..pos].trim();
            let token = s[pos + MARKER.len()..]
                .split_whitespace()
                .next()
                .unwrap_or_default();
            let names = token
                .split(',')
                .map(str::trim)
                .filter(|name| !name.is_empty())
                .map(String::from)
                .collect();
            (head, names)
        }
    }
}
507
508fn extract_auth(s: &str) -> (String, Vec<BasicAuth>) {
512 let mut auth_entries = Vec::new();
513 let mut remaining = s.to_string();
514
515 while let Some(idx) = remaining.find("@auth:") {
516 let before = &remaining[..idx];
518 let after = &remaining[idx + "@auth:".len()..];
519
520 let end_idx = after
522 .find(|c: char| c.is_whitespace())
523 .unwrap_or(after.len());
524
525 let auth_part = &after[..end_idx];
526
527 if let Some((username, hash)) = auth_part.split_once(':') {
529 if !username.is_empty() && !hash.is_empty() {
530 auth_entries.push(BasicAuth {
531 username: username.to_string(),
532 hash: hash.to_string(),
533 });
534 }
535 }
536
537 let rest = &after[end_idx..];
539 remaining = if rest.is_empty() {
540 before.to_string()
541 } else {
542 format!("{}{}", before, rest)
543 };
544 }
545
546 (remaining.trim().to_string(), auth_entries)
547}
548
549fn extract_load_balancing(s: &str) -> (String, LoadBalancingStrategy) {
552 if let Some(idx) = s.find("@lb:") {
553 let before = &s[..idx];
554 let after = &s[idx + "@lb:".len()..];
555
556 let end_idx = after
557 .find(|c: char| c.is_whitespace())
558 .unwrap_or(after.len());
559
560 let strategy_str = &after[..end_idx];
561 let strategy = match strategy_str {
562 "round-robin" => LoadBalancingStrategy::RoundRobin,
563 "weighted" => LoadBalancingStrategy::Weighted,
564 "failover" => LoadBalancingStrategy::Failover,
565 _ => LoadBalancingStrategy::default(),
566 };
567
568 let rest = &after[end_idx..];
569 let remaining = if rest.is_empty() {
570 before.to_string()
571 } else {
572 format!("{}{}", before, rest)
573 };
574
575 (remaining.trim().to_string(), strategy)
576 } else {
577 (s.to_string(), LoadBalancingStrategy::default())
578 }
579}
580
581fn parse_proxy_config(content: &str) -> Result<(Vec<ProxyRule>, Vec<String>)> {
582 let mut rules = Vec::new();
583 let mut global_scripts = Vec::new();
584
585 let mut joined_lines: Vec<String> = Vec::new();
587 for line in content.lines() {
588 if let Some(current) = joined_lines.last_mut() {
589 if current.ends_with('\\') {
590 current.pop(); current.push_str(line.trim());
592 continue;
593 }
594 }
595 joined_lines.push(line.to_string());
596 }
597
598 for line in &joined_lines {
599 let trimmed = line.trim();
600 if trimmed.is_empty() || trimmed.starts_with('#') {
601 continue;
602 }
603
604 if trimmed.starts_with("[global]") {
606 let rest = trimmed.strip_prefix("[global]").unwrap().trim();
607 let (_, scripts) = extract_scripts(rest);
608 global_scripts.extend(scripts);
609 continue;
610 }
611
612 if let Some((source, target_str)) = trimmed.split_once("->") {
613 let source = source.trim();
614 let (target_str, route_scripts) = extract_scripts(target_str.trim());
616 let (target_str, auth_entries) = extract_auth(target_str);
618 let (target_str, load_balancing) = extract_load_balancing(&target_str);
620
621 let matcher = if source == "default" || source == "*" {
622 RuleMatcher::Default
623 } else if let Some(pattern) = source.strip_prefix("~") {
624 RuleMatcher::Regex(RegexMatcher::new(pattern)?)
625 } else if !source.starts_with('/')
626 && (source.contains('.') || source.parse::<std::net::IpAddr>().is_ok())
627 {
628 if let Some((domain, path)) = source.split_once('/') {
629 if path.is_empty() || path == "*" {
630 RuleMatcher::Domain(domain.to_string())
631 } else if path.ends_with("/*") {
632 RuleMatcher::DomainPath(
633 domain.to_string(),
634 path.trim_end_matches('*').to_string(),
635 )
636 } else {
637 RuleMatcher::DomainPath(domain.to_string(), path.to_string())
638 }
639 } else {
640 RuleMatcher::Domain(source.to_string())
641 }
642 } else if source.ends_with("/*") {
643 RuleMatcher::Prefix(source.trim_end_matches('*').to_string())
644 } else {
645 RuleMatcher::Exact(source.to_string())
646 };
647
648 let targets: Vec<Target> = target_str
649 .split(',')
650 .map(|t| {
651 Ok(Target {
652 url: Url::parse(t.trim())?,
653 weight: 100,
654 })
655 })
656 .collect::<Result<Vec<_>>>()?;
657
658 rules.push(ProxyRule {
659 matcher,
660 targets,
661 headers: vec![],
662 scripts: route_scripts,
663 auth: auth_entries,
664 load_balancing,
665 });
666 }
667 }
668
669 Ok((rules, global_scripts))
670}
671
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `config`, which must be valid, and returns only its rules.
    fn rules_of(config: &str) -> Vec<ProxyRule> {
        parse_proxy_config(config).expect("test config should parse").0
    }

    /// The target URLs of `rule`, as strings.
    fn target_urls(rule: &ProxyRule) -> Vec<&str> {
        rule.targets.iter().map(|t| t.url.as_str()).collect()
    }

    /// Parses a single-rule config and returns that rule's load balancing strategy.
    fn strategy_of(config: &str) -> LoadBalancingStrategy {
        let rules = rules_of(config);
        assert_eq!(rules.len(), 1);
        rules[0].load_balancing.clone()
    }

    #[test]
    fn test_backslash_continuation_joins_lines() {
        let rules = rules_of("/api/* -> http://backend1:8080, \\\n http://backend2:8080\n");
        assert_eq!(rules.len(), 1);
        assert_eq!(
            target_urls(&rules[0]),
            ["http://backend1:8080/", "http://backend2:8080/"]
        );
    }

    #[test]
    fn test_multiple_continuation_lines() {
        let rules = rules_of(
            "/api/* -> http://backend1:8080, \\\n\
             http://backend2:8080, \\\n\
             http://backend3:8080\n",
        );
        assert_eq!(rules.len(), 1);
        let urls = target_urls(&rules[0]);
        assert_eq!(urls.len(), 3);
        assert_eq!(urls[2], "http://backend3:8080/");
    }

    #[test]
    fn test_backslash_mid_line_not_continuation() {
        let rules = rules_of(
            "/path -> http://localhost:8080\n\
             ~^/foo\\dbar$ -> http://localhost:9090\n",
        );
        assert_eq!(rules.len(), 2);
    }

    #[test]
    fn test_continuation_trims_whitespace() {
        let rules =
            rules_of("/api/* -> http://a:8080, \\\n http://b:8080, \\\n http://c:8080\n");
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 3);
    }

    #[test]
    fn test_continuation_with_scripts() {
        let rules = rules_of(
            "/api/* -> http://a:8080, \\\n\
             http://b:8080 @script:auth.lua\n",
        );
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].targets.len(), 2);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    #[test]
    fn test_no_continuation_normal_config() {
        let rules = rules_of("/api/* -> http://backend:8080\ndefault -> http://localhost:3000\n");
        assert_eq!(rules.len(), 2);
    }

    #[test]
    fn test_auth_parsing() {
        let rules = rules_of(
            r#"
/db/* -> http://localhost:8080/ @auth:demo:$2b$12$YFlnIiACnSaAcxDWQlYjeedxq/3GvhvoGhRTYHMqLifJrETSqOZQa
"#,
        );
        assert_eq!(rules.len(), 1);
        let rule = &rules[0];
        assert_eq!(rule.auth.len(), 1);
        assert_eq!(rule.auth[0].username, "demo");
        assert!(rule.auth[0].hash.starts_with("$2b$"));
        assert_eq!(target_urls(rule), ["http://localhost:8080/"]);
    }

    #[test]
    fn test_multiple_auth_users() {
        let rules = rules_of(
            r#"
secure.example.com -> http://localhost:9000/ @auth:admin:$2b$12$hash1 @auth:user:$2b$12$hash2
"#,
        );
        assert_eq!(rules.len(), 1);
        let users: Vec<&str> = rules[0].auth.iter().map(|a| a.username.as_str()).collect();
        assert_eq!(users, ["admin", "user"]);
    }

    #[test]
    fn test_load_balancing_round_robin() {
        let rules = rules_of("/api/* -> http://b1:8080, http://b2:8080 @lb:round-robin");
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::RoundRobin);
        assert_eq!(rules[0].targets.len(), 2);
    }

    #[test]
    fn test_load_balancing_weighted() {
        assert_eq!(
            strategy_of("/api/* -> http://b1:8080, http://b2:8080 @lb:weighted"),
            LoadBalancingStrategy::Weighted
        );
    }

    #[test]
    fn test_load_balancing_failover() {
        assert_eq!(
            strategy_of("/api/* -> http://b1:8080, http://b2:8080 @lb:failover"),
            LoadBalancingStrategy::Failover
        );
    }

    #[test]
    fn test_load_balancing_default_is_round_robin() {
        assert_eq!(
            strategy_of("/api/* -> http://b1:8080, http://b2:8080"),
            LoadBalancingStrategy::RoundRobin
        );
    }

    #[test]
    fn test_load_balancing_with_scripts() {
        let rules =
            rules_of("/api/* -> http://b1:8080, http://b2:8080 @lb:failover @script:auth.lua");
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].load_balancing, LoadBalancingStrategy::Failover);
        assert_eq!(rules[0].scripts, vec!["auth.lua"]);
    }

    #[test]
    fn test_load_balancing_unknown_strategy_defaults_to_round_robin() {
        assert_eq!(
            strategy_of("/api/* -> http://b1:8080, http://b2:8080 @lb:unknown"),
            LoadBalancingStrategy::RoundRobin
        );
    }
}