1use std::env;
17use std::path::PathBuf;
18use std::time::Duration;
19
/// Bytes in one kibibyte.
pub const KB: usize = 1024;
/// Bytes in one mebibyte (1024 KiB).
pub const MB: usize = KB * 1024;

/// Namespace tried first for every setting; the bare name is the fallback
/// (see `find_env`).
const ENV_PREFIX: &str = "WARPDRIVE_";

/// Fallback for `Config::target_host` (`TARGET_HOST`).
const DEFAULT_TARGET_HOST: &str = "127.0.0.1";

/// Fallback for `Config::target_port` (`TARGET_PORT`).
const DEFAULT_TARGET_PORT: u16 = 3000;

/// Fallback for `Config::cache_size_bytes` (`CACHE_SIZE`): 64 MiB.
const DEFAULT_CACHE_SIZE_BYTES: usize = 64 * MB;

/// Fallback for `Config::max_cache_item_size_bytes` (`MAX_CACHE_ITEM_SIZE`): 1 MiB.
const DEFAULT_MAX_CACHE_ITEM_SIZE_BYTES: usize = MB;

/// Fallback for `Config::acme_directory_url` (`ACME_DIRECTORY`): the Let's
/// Encrypt production (`acme-v02`) directory.
const DEFAULT_ACME_DIRECTORY_URL: &str = "https://acme-v02.api.letsencrypt.org/directory";

/// Fallback for `Config::storage_path` (`STORAGE_PATH`).
const DEFAULT_STORAGE_PATH: &str = "./storage/warpdrive";

/// Fallback for `Config::bad_gateway_page` (`BAD_GATEWAY_PAGE`).
const DEFAULT_BAD_GATEWAY_PAGE: &str = "./public/502.html";

/// Fallback for `Config::http_port` (`HTTP_PORT`).
const DEFAULT_HTTP_PORT: u16 = 8080;

/// Fallback for `Config::https_port` (`HTTPS_PORT`).
const DEFAULT_HTTPS_PORT: u16 = 8443;

/// Fallback for `Config::http_idle_timeout` (`HTTP_IDLE_TIMEOUT`, whole seconds).
const DEFAULT_HTTP_IDLE_TIMEOUT: Duration = Duration::from_secs(60);

/// Fallback for `Config::http_read_timeout` (`HTTP_READ_TIMEOUT`, whole seconds).
const DEFAULT_HTTP_READ_TIMEOUT: Duration = Duration::from_secs(30);

/// Fallback for `Config::http_write_timeout` (`HTTP_WRITE_TIMEOUT`, whole seconds).
const DEFAULT_HTTP_WRITE_TIMEOUT: Duration = Duration::from_secs(30);
62
63const DEFAULT_LOG_LEVEL: LogLevel = LogLevel::Info;
65
66const DEFAULT_SHUTDOWN_TIMEOUT_SECS: u64 = 30;
68
69const DEFAULT_UPSTREAM_TIMEOUT: Duration = Duration::from_secs(30);
71
72const DEFAULT_RATE_LIMIT_RPS: u32 = 100;
74
75const DEFAULT_RATE_LIMIT_BURST: u32 = 200;
77
78const DEFAULT_CB_FAILURE_THRESHOLD: u32 = 5;
80
81const DEFAULT_CB_TIMEOUT_SECS: u64 = 60;
83
84const DEFAULT_METRICS_PORT: u16 = 9090;
86
87const DEFAULT_STATIC_PATHS: &[&str] = &["/assets", "/packs", "/images", "/favicon.ico"];
89
90pub const DEFAULT_STATIC_INLINE_SIZE_LIMIT: u64 = 4 * MB as u64;
92
93const DEFAULT_STATIC_ROOT: &str = "./public";
95
96const DEFAULT_STATIC_CACHE_CONTROL: &str = "public, max-age=31536000, immutable";
98
/// Logging verbosity, selected with the `LOG_LEVEL` setting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogLevel {
    Error,
    Warn,
    Info,
    Debug,
}

impl LogLevel {
    /// Parses a level name such as `"info"`.
    ///
    /// Matching ignores ASCII case and surrounding whitespace. Accepted names
    /// are `error`, `warn` (or its common alias `warning`), `info` and `debug`.
    /// Returns `None` for anything else so the caller can substitute a default.
    fn from_str(s: &str) -> Option<Self> {
        // Values copied into `.env` files or shell exports often carry stray
        // whitespace or upper case (`" INFO"`). Normalise both before matching
        // so they are not silently rejected.
        match s.trim().to_ascii_lowercase().as_str() {
            "error" => Some(LogLevel::Error),
            "warn" | "warning" => Some(LogLevel::Warn),
            "info" => Some(LogLevel::Info),
            "debug" => Some(LogLevel::Debug),
            _ => None,
        }
    }
}
120
/// Runtime configuration, normally assembled by [`Config::from_env`].
///
/// Each setting is read from `WARPDRIVE_<NAME>` first and from the bare
/// `<NAME>` as a fallback (see `find_env`). [`Config::validate`] checks
/// cross-field constraints. `Default` yields the same values `from_env` uses
/// when no relevant variable is set.
#[derive(Debug, Clone)]
pub struct Config {
    /// Upstream host (`TARGET_HOST`, default `"127.0.0.1"`); must not be blank.
    pub target_host: String,

    /// Upstream port (`TARGET_PORT`, default `3000`); must be non-zero.
    pub target_port: u16,

    /// Optional upstream command (`UPSTREAM_COMMAND`); unset by default.
    pub upstream_command: Option<String>,

    /// Arguments for `upstream_command` (`UPSTREAM_ARGS`, comma-separated, default empty).
    pub upstream_args: Vec<String>,

    /// Total cache size in bytes (`CACHE_SIZE`, default 64 MiB).
    pub cache_size_bytes: usize,

    /// Largest single cache item in bytes (`MAX_CACHE_ITEM_SIZE`, default 1 MiB).
    /// `validate` rejects values above `cache_size_bytes`.
    pub max_cache_item_size_bytes: usize,

    /// `X_SENDFILE_ENABLED`, default `true`.
    pub x_sendfile_enabled: bool,

    /// `GZIP_COMPRESSION_ENABLED`, default `true`.
    pub gzip_compression_enabled: bool,

    /// `MAX_REQUEST_BODY`, default `0`.
    /// NOTE(review): `0` presumably means "no limit" — confirm with the consumer.
    pub max_request_body: usize,

    /// TLS domain names (`TLS_DOMAIN`, comma-separated, default empty). A non-empty
    /// list without a manual cert/key pair makes `has_acme_domains()` true, and
    /// `validate` then requires a valid `acme_directory_url`.
    pub tls_domains: Vec<String>,

    /// `TLS_CERT_PATH`; must be set together with `tls_key_path`.
    pub tls_cert_path: Option<String>,

    /// `TLS_KEY_PATH`; must be set together with `tls_cert_path`.
    pub tls_key_path: Option<String>,

    /// `ACME_DIRECTORY`, default is the Let's Encrypt production directory.
    /// Must start with `http://` or `https://` when ACME domains are in use.
    pub acme_directory_url: String,

    /// `EAB_KID`; presumably the ACME external-account-binding key ID — confirm.
    pub eab_kid: Option<String>,

    /// `EAB_HMAC_KEY`; presumably the ACME external-account-binding HMAC key — confirm.
    pub eab_hmac_key: Option<String>,

    /// `STORAGE_PATH`, default `"./storage/warpdrive"`.
    pub storage_path: String,

    /// `BAD_GATEWAY_PAGE`, default `"./public/502.html"`.
    pub bad_gateway_page: String,

    /// `HTTP_PORT`, default `8080`; must be non-zero.
    pub http_port: u16,

    /// `HTTPS_PORT`, default `8443`; must be non-zero.
    pub https_port: u16,

    /// `HTTP_IDLE_TIMEOUT` in whole seconds, default 60 s.
    pub http_idle_timeout: Duration,

    /// `HTTP_READ_TIMEOUT` in whole seconds, default 30 s.
    pub http_read_timeout: Duration,

    /// `HTTP_WRITE_TIMEOUT` in whole seconds, default 30 s.
    pub http_write_timeout: Duration,

    /// `H2C_ENABLED`, default `false`.
    pub h2c_enabled: bool,

    /// `FORWARD_HEADERS`. When unset, `from_env` enables it only if no TLS is
    /// configured (`has_tls()` is false); `Default` uses `true`.
    pub forward_headers: bool,

    /// `LOG_REQUESTS`, default `true`.
    pub log_requests: bool,

    /// `LOG_LEVEL`, default `LogLevel::Info`.
    pub log_level: LogLevel,

    /// `SHUTDOWN_TIMEOUT_SECS`, default `30`.
    pub shutdown_timeout_secs: u64,

    /// `DATABASE_URL`; when set it must start with `postgres://` or `postgresql://`.
    pub database_url: Option<String>,

    /// `PG_CHANNEL_CACHE_INVALIDATION`, default `"warpdrive:cache:invalidate"`.
    /// Presumably a PostgreSQL notification channel name — confirm.
    pub pg_channel_cache_invalidation: String,

    /// `PG_CHANNEL_CONFIG_UPDATE`, default `"warpdrive:config:update"`.
    pub pg_channel_config_update: String,

    /// `PG_CHANNEL_HEALTH`, default `"warpdrive:health"`.
    pub pg_channel_health: String,

    /// `REDIS_URL`; when set it must start with `redis://` or `rediss://`.
    pub redis_url: Option<String>,

    /// TOML file named by `CONFIG`, if set. `from_env` returns an error when
    /// the file cannot be loaded.
    pub toml_config: Option<crate::config::toml::TomlConfig>,

    /// `RATE_LIMIT_ENABLED`, default `false`.
    pub rate_limit_enabled: bool,

    /// `RATE_LIMIT_RPS`, default `100`.
    pub rate_limit_requests_per_sec: u32,

    /// `RATE_LIMIT_BURST`, default `200`.
    pub rate_limit_burst_size: u32,

    /// `UPSTREAM_TIMEOUT` in whole seconds, default 30 s.
    pub upstream_timeout: Duration,

    /// `CIRCUIT_BREAKER_ENABLED`, default `false`.
    pub circuit_breaker_enabled: bool,

    /// `CIRCUIT_BREAKER_FAILURE_THRESHOLD`, default `5`.
    pub circuit_breaker_failure_threshold: u32,

    /// `CIRCUIT_BREAKER_TIMEOUT_SECS`, default `60`.
    pub circuit_breaker_timeout_secs: u64,

    /// `MAX_CONCURRENT_REQUESTS`, default `0`.
    /// NOTE(review): `0` presumably disables the limit — confirm with the consumer.
    pub max_concurrent_requests: usize,

    /// `METRICS_ENABLED`, default `false`.
    pub metrics_enabled: bool,

    /// `METRICS_PORT`, default `9090`.
    pub metrics_port: u16,

    /// `STATIC_ENABLED`, default `true`.
    pub static_enabled: bool,

    /// `STATIC_ROOT`, default `"./public"`.
    pub static_root: PathBuf,

    /// `STATIC_PATHS` (comma-separated), default
    /// `/assets`, `/packs`, `/images`, `/favicon.ico`.
    pub static_paths: Vec<String>,

    /// `STATIC_CACHE_CONTROL`, default `"public, max-age=31536000, immutable"`.
    pub static_cache_control: String,

    /// `STATIC_GZIP`, default `true`.
    pub static_gzip_enabled: bool,

    /// `STATIC_INDEX_FILES` (comma-separated), default `["index.html"]`.
    pub static_index_files: Vec<String>,

    /// `STATIC_FALLTHROUGH`, default `true`.
    pub static_fallthrough: bool,

    /// `STATIC_INLINE_SIZE_LIMIT` in bytes, default 4 MiB.
    pub static_inline_size_limit: u64,

    /// `CLIENT_IP_HEADER`; unset by default.
    pub client_ip_header: Option<String>,

    /// `TRUSTED_RANGES_FILE`; unset by default.
    pub trusted_ranges_file: Option<PathBuf>,
}
299
300impl Config {
301 pub fn from_env() -> Result<Self, String> {
319 let _ = dotenvy::dotenv();
321
322 let config = Config {
323 target_host: get_env_string("TARGET_HOST")
324 .unwrap_or_else(|| DEFAULT_TARGET_HOST.to_string()),
325 target_port: get_env_u16("TARGET_PORT", DEFAULT_TARGET_PORT),
326 upstream_command: get_env_string("UPSTREAM_COMMAND"),
327 upstream_args: get_env_strings("UPSTREAM_ARGS", vec![]),
328
329 cache_size_bytes: get_env_usize("CACHE_SIZE", DEFAULT_CACHE_SIZE_BYTES),
330 max_cache_item_size_bytes: get_env_usize(
331 "MAX_CACHE_ITEM_SIZE",
332 DEFAULT_MAX_CACHE_ITEM_SIZE_BYTES,
333 ),
334 x_sendfile_enabled: get_env_bool("X_SENDFILE_ENABLED", true),
335 gzip_compression_enabled: get_env_bool("GZIP_COMPRESSION_ENABLED", true),
336 max_request_body: get_env_usize("MAX_REQUEST_BODY", 0),
337
338 tls_domains: get_env_strings("TLS_DOMAIN", vec![]),
339 tls_cert_path: get_env_string("TLS_CERT_PATH"),
340 tls_key_path: get_env_string("TLS_KEY_PATH"),
341 acme_directory_url: get_env_string("ACME_DIRECTORY")
342 .unwrap_or_else(|| DEFAULT_ACME_DIRECTORY_URL.to_string()),
343 eab_kid: get_env_string("EAB_KID"),
344 eab_hmac_key: get_env_string("EAB_HMAC_KEY"),
345 storage_path: get_env_string("STORAGE_PATH")
346 .unwrap_or_else(|| DEFAULT_STORAGE_PATH.to_string()),
347 bad_gateway_page: get_env_string("BAD_GATEWAY_PAGE")
348 .unwrap_or_else(|| DEFAULT_BAD_GATEWAY_PAGE.to_string()),
349
350 http_port: get_env_u16("HTTP_PORT", DEFAULT_HTTP_PORT),
351 https_port: get_env_u16("HTTPS_PORT", DEFAULT_HTTPS_PORT),
352 http_idle_timeout: get_env_duration("HTTP_IDLE_TIMEOUT", DEFAULT_HTTP_IDLE_TIMEOUT),
353 http_read_timeout: get_env_duration("HTTP_READ_TIMEOUT", DEFAULT_HTTP_READ_TIMEOUT),
354 http_write_timeout: get_env_duration("HTTP_WRITE_TIMEOUT", DEFAULT_HTTP_WRITE_TIMEOUT),
355 h2c_enabled: get_env_bool("H2C_ENABLED", false),
356
357 forward_headers: get_env_bool("FORWARD_HEADERS", false), log_requests: get_env_bool("LOG_REQUESTS", true),
359 log_level: get_env_log_level("LOG_LEVEL", DEFAULT_LOG_LEVEL),
360
361 shutdown_timeout_secs: get_env_u64(
362 "SHUTDOWN_TIMEOUT_SECS",
363 DEFAULT_SHUTDOWN_TIMEOUT_SECS,
364 ),
365
366 database_url: get_env_string("DATABASE_URL"),
367 pg_channel_cache_invalidation: get_env_string("PG_CHANNEL_CACHE_INVALIDATION")
368 .unwrap_or_else(|| "warpdrive:cache:invalidate".to_string()),
369 pg_channel_config_update: get_env_string("PG_CHANNEL_CONFIG_UPDATE")
370 .unwrap_or_else(|| "warpdrive:config:update".to_string()),
371 pg_channel_health: get_env_string("PG_CHANNEL_HEALTH")
372 .unwrap_or_else(|| "warpdrive:health".to_string()),
373
374 redis_url: get_env_string("REDIS_URL"),
375
376 toml_config: get_env_string("CONFIG")
378 .map(|path| {
379 crate::config::toml::TomlConfig::from_file(&path)
380 .map_err(|e| format!("Failed to load TOML config from {}: {}", path, e))
381 })
382 .transpose()?,
383
384 rate_limit_enabled: get_env_bool("RATE_LIMIT_ENABLED", false),
386 rate_limit_requests_per_sec: get_env_u32("RATE_LIMIT_RPS", DEFAULT_RATE_LIMIT_RPS),
387 rate_limit_burst_size: get_env_u32("RATE_LIMIT_BURST", DEFAULT_RATE_LIMIT_BURST),
388 upstream_timeout: get_env_duration("UPSTREAM_TIMEOUT", DEFAULT_UPSTREAM_TIMEOUT),
389 circuit_breaker_enabled: get_env_bool("CIRCUIT_BREAKER_ENABLED", false),
390 circuit_breaker_failure_threshold: get_env_u32(
391 "CIRCUIT_BREAKER_FAILURE_THRESHOLD",
392 DEFAULT_CB_FAILURE_THRESHOLD,
393 ),
394 circuit_breaker_timeout_secs: get_env_u64(
395 "CIRCUIT_BREAKER_TIMEOUT_SECS",
396 DEFAULT_CB_TIMEOUT_SECS,
397 ),
398 max_concurrent_requests: get_env_usize("MAX_CONCURRENT_REQUESTS", 0),
399
400 metrics_enabled: get_env_bool("METRICS_ENABLED", false),
402 metrics_port: get_env_u16("METRICS_PORT", DEFAULT_METRICS_PORT),
403
404 static_enabled: get_env_bool("STATIC_ENABLED", true),
406 static_root: PathBuf::from(
407 get_env_string("STATIC_ROOT").unwrap_or_else(|| DEFAULT_STATIC_ROOT.to_string()),
408 ),
409 static_paths: get_env_strings(
410 "STATIC_PATHS",
411 DEFAULT_STATIC_PATHS.iter().map(|s| s.to_string()).collect(),
412 ),
413 static_cache_control: get_env_string("STATIC_CACHE_CONTROL")
414 .unwrap_or_else(|| DEFAULT_STATIC_CACHE_CONTROL.to_string()),
415 static_gzip_enabled: get_env_bool("STATIC_GZIP", true),
416 static_index_files: get_env_strings(
417 "STATIC_INDEX_FILES",
418 vec!["index.html".to_string()],
419 ),
420 static_fallthrough: get_env_bool("STATIC_FALLTHROUGH", true),
421 static_inline_size_limit: get_env_u64(
422 "STATIC_INLINE_SIZE_LIMIT",
423 DEFAULT_STATIC_INLINE_SIZE_LIMIT,
424 ),
425
426 client_ip_header: get_env_string("CLIENT_IP_HEADER"),
428 trusted_ranges_file: get_env_string("TRUSTED_RANGES_FILE").map(PathBuf::from),
429 };
430
431 let config = if find_env("FORWARD_HEADERS").is_none() {
433 Config {
434 forward_headers: !config.has_tls(),
435 ..config
436 }
437 } else {
438 config
439 };
440
441 Ok(config)
442 }
443
444 pub fn has_tls(&self) -> bool {
450 (self.tls_cert_path.is_some() && self.tls_key_path.is_some())
451 || !self.tls_domains.is_empty()
452 }
453
454 pub fn has_manual_tls(&self) -> bool {
456 self.tls_cert_path.is_some() && self.tls_key_path.is_some()
457 }
458
459 pub fn has_acme_domains(&self) -> bool {
463 !self.tls_domains.is_empty() && !self.has_manual_tls()
464 }
465
466 pub fn validate(&self) -> Result<(), String> {
480 if self.tls_cert_path.is_some() != self.tls_key_path.is_some() {
482 return Err(
483 "Both TLS_CERT_PATH and TLS_KEY_PATH must be specified together".to_string(),
484 );
485 }
486
487 if !self.tls_domains.is_empty() && !self.has_manual_tls() {
489 if self.acme_directory_url.is_empty() {
490 return Err(
491 "ACME directory URL required when TLS domains are specified".to_string()
492 );
493 }
494
495 if !self.acme_directory_url.starts_with("http://")
497 && !self.acme_directory_url.starts_with("https://")
498 {
499 return Err("ACME directory URL must be a valid HTTP(S) URL".to_string());
500 }
501 }
502
503 if self.http_port == 0 {
505 return Err("HTTP port cannot be 0".to_string());
506 }
507 if self.https_port == 0 {
508 return Err("HTTPS port cannot be 0".to_string());
509 }
510 if self.target_port == 0 {
511 return Err("Target port cannot be 0".to_string());
512 }
513
514 if self.target_host.trim().is_empty() {
515 return Err("Target host cannot be empty".to_string());
516 }
517
518 if self.max_cache_item_size_bytes > self.cache_size_bytes {
520 return Err("Maximum cache item size cannot exceed total cache size".to_string());
521 }
522
523 if let Some(ref db_url) = self.database_url {
525 if !db_url.starts_with("postgres://") && !db_url.starts_with("postgresql://") {
526 return Err("Database URL must be a valid PostgreSQL connection string".to_string());
527 }
528 }
529
530 if let Some(ref redis_url) = self.redis_url {
532 if !redis_url.starts_with("redis://") && !redis_url.starts_with("rediss://") {
533 return Err("Redis URL must be a valid Redis connection string".to_string());
534 }
535 }
536
537 Ok(())
538 }
539}
540
541impl Default for Config {
542 fn default() -> Self {
543 Config {
544 target_host: DEFAULT_TARGET_HOST.to_string(),
545 target_port: DEFAULT_TARGET_PORT,
546 upstream_command: None,
547 upstream_args: vec![],
548
549 cache_size_bytes: DEFAULT_CACHE_SIZE_BYTES,
550 max_cache_item_size_bytes: DEFAULT_MAX_CACHE_ITEM_SIZE_BYTES,
551 x_sendfile_enabled: true,
552 gzip_compression_enabled: true,
553 max_request_body: 0,
554
555 tls_domains: vec![],
556 tls_cert_path: None,
557 tls_key_path: None,
558 acme_directory_url: DEFAULT_ACME_DIRECTORY_URL.to_string(),
559 eab_kid: None,
560 eab_hmac_key: None,
561 storage_path: DEFAULT_STORAGE_PATH.to_string(),
562 bad_gateway_page: DEFAULT_BAD_GATEWAY_PAGE.to_string(),
563
564 http_port: DEFAULT_HTTP_PORT,
565 https_port: DEFAULT_HTTPS_PORT,
566 http_idle_timeout: DEFAULT_HTTP_IDLE_TIMEOUT,
567 http_read_timeout: DEFAULT_HTTP_READ_TIMEOUT,
568 http_write_timeout: DEFAULT_HTTP_WRITE_TIMEOUT,
569 h2c_enabled: false,
570
571 forward_headers: true,
572 log_requests: true,
573 log_level: DEFAULT_LOG_LEVEL,
574
575 shutdown_timeout_secs: DEFAULT_SHUTDOWN_TIMEOUT_SECS,
576
577 database_url: None,
578 pg_channel_cache_invalidation: "warpdrive:cache:invalidate".to_string(),
579 pg_channel_config_update: "warpdrive:config:update".to_string(),
580 pg_channel_health: "warpdrive:health".to_string(),
581
582 redis_url: None,
583
584 toml_config: None,
585
586 rate_limit_enabled: false,
588 rate_limit_requests_per_sec: DEFAULT_RATE_LIMIT_RPS,
589 rate_limit_burst_size: DEFAULT_RATE_LIMIT_BURST,
590 upstream_timeout: DEFAULT_UPSTREAM_TIMEOUT,
591 circuit_breaker_enabled: false,
592 circuit_breaker_failure_threshold: DEFAULT_CB_FAILURE_THRESHOLD,
593 circuit_breaker_timeout_secs: DEFAULT_CB_TIMEOUT_SECS,
594 max_concurrent_requests: 0,
595
596 metrics_enabled: false,
598 metrics_port: DEFAULT_METRICS_PORT,
599
600 static_enabled: true,
602 static_root: PathBuf::from(DEFAULT_STATIC_ROOT),
603 static_paths: DEFAULT_STATIC_PATHS.iter().map(|s| s.to_string()).collect(),
604 static_cache_control: DEFAULT_STATIC_CACHE_CONTROL.to_string(),
605 static_gzip_enabled: true,
606 static_index_files: vec!["index.html".to_string()],
607 static_fallthrough: true,
608 static_inline_size_limit: DEFAULT_STATIC_INLINE_SIZE_LIMIT,
609
610 client_ip_header: None,
612 trusted_ranges_file: None,
613 }
614 }
615}
616
617fn find_env(key: &str) -> Option<String> {
624 if let Ok(value) = env::var(format!("{}{}", ENV_PREFIX, key)) {
626 return Some(value);
627 }
628
629 env::var(key).ok()
631}
632
633fn get_env_string(key: &str) -> Option<String> {
635 find_env(key)
636}
637
638fn get_env_strings(key: &str, default: Vec<String>) -> Vec<String> {
640 match find_env(key) {
641 Some(value) => value
642 .split(',')
643 .map(|s| s.trim().to_string())
644 .filter(|s| !s.is_empty())
645 .collect(),
646 None => default,
647 }
648}
649
650fn get_env_u16(key: &str, default: u16) -> u16 {
652 find_env(key)
653 .and_then(|v| v.parse().ok())
654 .unwrap_or(default)
655}
656
657fn get_env_u32(key: &str, default: u32) -> u32 {
659 find_env(key)
660 .and_then(|v| v.parse().ok())
661 .unwrap_or(default)
662}
663
664fn get_env_usize(key: &str, default: usize) -> usize {
666 find_env(key)
667 .and_then(|v| v.parse().ok())
668 .unwrap_or(default)
669}
670
671fn get_env_u64(key: &str, default: u64) -> u64 {
673 find_env(key)
674 .and_then(|v| v.parse().ok())
675 .unwrap_or(default)
676}
677
678fn get_env_bool(key: &str, default: bool) -> bool {
680 find_env(key)
681 .and_then(|v| match v.to_lowercase().as_str() {
682 "true" | "1" | "yes" | "on" => Some(true),
683 "false" | "0" | "no" | "off" => Some(false),
684 _ => None,
685 })
686 .unwrap_or(default)
687}
688
689fn get_env_duration(key: &str, default: Duration) -> Duration {
691 find_env(key)
692 .and_then(|v| v.parse::<u64>().ok())
693 .map(Duration::from_secs)
694 .unwrap_or(default)
695}
696
697fn get_env_log_level(key: &str, default: LogLevel) -> LogLevel {
699 find_env(key)
700 .and_then(|v| LogLevel::from_str(&v))
701 .unwrap_or(default)
702}
703
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.target_host, "127.0.0.1");
        assert_eq!(config.target_port, 3000);
        assert_eq!(config.cache_size_bytes, 64 * MB);
        assert_eq!(config.max_cache_item_size_bytes, MB);
        assert!(config.x_sendfile_enabled);
        assert!(config.gzip_compression_enabled);
        assert_eq!(config.max_request_body, 0);
        assert!(config.tls_domains.is_empty());
        assert_eq!(config.http_port, 8080);
        assert_eq!(config.https_port, 8443);
        assert!(!config.h2c_enabled);
        assert!(config.log_requests);
        assert!(config.forward_headers);
        assert_eq!(config.log_level, LogLevel::Info);
    }

    #[test]
    fn test_has_tls() {
        let mut config = Config::default();
        assert!(!config.has_tls());

        config.tls_domains = vec!["example.com".to_string()];
        assert!(config.has_tls());
        assert!(config.has_acme_domains());
        assert!(!config.has_manual_tls());
    }

    #[test]
    fn test_has_tls_with_manual_pair_only() {
        let config = Config {
            tls_cert_path: Some("cert.pem".to_string()),
            tls_key_path: Some("key.pem".to_string()),
            ..Default::default()
        };
        assert!(config.has_tls());
        assert!(config.has_manual_tls());
        assert!(!config.has_acme_domains());
    }

    #[test]
    fn test_manual_tls_takes_precedence_over_acme() {
        let config = Config {
            tls_domains: vec!["example.com".to_string()],
            tls_cert_path: Some("cert.pem".to_string()),
            tls_key_path: Some("key.pem".to_string()),
            acme_directory_url: String::new(),
            ..Default::default()
        };
        assert!(!config.has_acme_domains());
        // The ACME directory is not consulted once a manual pair is present.
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_valid_config() {
        let config = Config::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_tls_cert_key_must_be_paired() {
        let cert_only = Config {
            tls_cert_path: Some("cert.pem".to_string()),
            ..Default::default()
        };
        assert!(cert_only.validate().is_err());

        let key_only = Config {
            tls_key_path: Some("key.pem".to_string()),
            ..Default::default()
        };
        assert!(key_only.validate().is_err());
    }

    #[test]
    fn test_validate_zero_ports() {
        let zero_http = Config {
            http_port: 0,
            ..Default::default()
        };
        assert!(zero_http.validate().is_err());

        let zero_https = Config {
            https_port: 0,
            ..Default::default()
        };
        assert!(zero_https.validate().is_err());

        let zero_target = Config {
            target_port: 0,
            ..Default::default()
        };
        assert!(zero_target.validate().is_err());
    }

    #[test]
    fn test_validate_blank_target_host() {
        let config = Config {
            target_host: "   ".to_string(),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_cache_size() {
        let config = Config {
            max_cache_item_size_bytes: DEFAULT_CACHE_SIZE_BYTES + 1,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_tls_without_acme() {
        let config = Config {
            tls_domains: vec!["example.com".to_string()],
            acme_directory_url: String::new(),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_invalid_acme_url() {
        let config = Config {
            tls_domains: vec!["example.com".to_string()],
            acme_directory_url: "invalid-url".to_string(),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_invalid_database_url() {
        let config = Config {
            database_url: Some("invalid-db-url".to_string()),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_accepts_postgres_schemes() {
        for url in ["postgres://user@localhost/db", "postgresql://user@localhost/db"].iter() {
            let config = Config {
                database_url: Some(url.to_string()),
                ..Default::default()
            };
            assert!(config.validate().is_ok(), "rejected {}", url);
        }
    }

    #[test]
    fn test_validate_invalid_redis_url() {
        let config = Config {
            redis_url: Some("invalid-redis-url".to_string()),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_accepts_redis_schemes() {
        for url in ["redis://localhost:6379", "rediss://localhost:6380"].iter() {
            let config = Config {
                redis_url: Some(url.to_string()),
                ..Default::default()
            };
            assert!(config.validate().is_ok(), "rejected {}", url);
        }
    }

    #[test]
    fn test_log_level_from_str() {
        assert_eq!(LogLevel::from_str("error"), Some(LogLevel::Error));
        assert_eq!(LogLevel::from_str("ERROR"), Some(LogLevel::Error));
        assert_eq!(LogLevel::from_str("warn"), Some(LogLevel::Warn));
        assert_eq!(LogLevel::from_str("info"), Some(LogLevel::Info));
        assert_eq!(LogLevel::from_str("debug"), Some(LogLevel::Debug));
        assert_eq!(LogLevel::from_str("invalid"), None);
    }
}