// totp_gateway/config.rs

use std::env;
use std::fs;
use std::path::Path;

use serde::{Deserialize, Serialize};
use totp_rs::Secret;

7#[derive(Debug, Deserialize, Serialize, Clone, Default)]
8pub struct Config {
9    pub server: ServerConfig,
10    #[serde(default)]
11    pub tls: Option<TlsConfig>,
12    pub auth: AuthConfig,
13    #[serde(default)]
14    pub security: SecurityConfig,
15    #[serde(default)]
16    pub routes: Vec<RouteConfig>,
17}
18
19#[derive(Debug, Deserialize, Serialize, Clone)]
20pub struct TlsConfig {
21    pub cert_file: String,
22    pub key_file: String,
23}
24
25#[derive(Debug, Deserialize, Serialize, Clone)]
26pub struct ServerConfig {
27    pub bind_addr: String,
28    pub default_upstream: String,
29    #[serde(default)]
30    pub trusted_proxies: Vec<(String, String)>,
31}
32
33impl Default for ServerConfig {
34    fn default() -> Self {
35        Self {
36            bind_addr: "0.0.0.0:25000".to_string(),
37            default_upstream: "127.0.0.1:25001".to_string(),
38            trusted_proxies: vec![],
39        }
40    }
41}
42
43#[derive(Debug, Deserialize, Serialize, Clone)]
44pub struct AuthConfig {
45    pub totp_secret: Option<String>,
46    pub totp_secret_file: Option<String>,
47    pub totp_secret_env: Option<String>,
48    pub login_page_file: Option<String>,
49    #[serde(default = "default_session_duration")]
50    pub session_duration: u64,
51}
52
53impl Default for AuthConfig {
54    fn default() -> Self {
55        let secret = Secret::generate_secret();
56        let encoded = secret.to_encoded().to_string();
57        Self {
58            totp_secret: Some(encoded),
59            totp_secret_file: None,
60            totp_secret_env: None,
61            login_page_file: None,
62            session_duration: default_session_duration(),
63        }
64    }
65}
66
67#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
68#[serde(rename_all = "lowercase")]
69pub enum BlacklistStrategy {
70    Overwrite,
71    Block,
72}
73
74#[derive(Debug, Deserialize, Serialize, Clone)]
75pub struct SecurityConfig {
76    #[serde(default = "default_security_enabled")]
77    pub enabled: bool,
78    #[serde(default = "default_blacklist_size")]
79    pub blacklist_size: usize,
80    #[serde(default = "default_blacklist_strategy")]
81    pub blacklist_strategy: BlacklistStrategy,
82    #[serde(default = "default_max_retries")]
83    pub max_retries: u32,
84    #[serde(default = "default_ip_limit_duration")]
85    pub ip_limit_duration: u64,
86    #[serde(default = "default_ban_duration")]
87    pub ban_duration: u64,
88    #[serde(default = "default_whitelist_duration")]
89    pub whitelist_duration: u64,
90}
91
/// Serde default for `AuthConfig::session_duration`: 1800.
/// That is 30 minutes, assuming the value is in seconds — confirm.
fn default_session_duration() -> u64 {
    30 * 60
}

/// Serde default for `SecurityConfig::enabled`: on.
fn default_security_enabled() -> bool {
    true
}

/// Serde default for `SecurityConfig::blacklist_size`.
fn default_blacklist_size() -> usize {
    1000
}

104fn default_blacklist_strategy() -> BlacklistStrategy {
105    BlacklistStrategy::Overwrite
106}
107
/// Serde default for `SecurityConfig::max_retries`.
fn default_max_retries() -> u32 {
    5
}

/// Serde default for `SecurityConfig::ip_limit_duration`: 3600 (one hour if seconds).
fn default_ip_limit_duration() -> u64 {
    60 * 60
}

/// Serde default for `SecurityConfig::ban_duration`: 3600 (one hour if seconds).
fn default_ban_duration() -> u64 {
    60 * 60
}

/// Serde default for `SecurityConfig::whitelist_duration`: 604800 (seven days if seconds).
fn default_whitelist_duration() -> u64 {
    7 * 24 * 60 * 60
}

/// Serde default for `RouteConfig::protect`: `true` unless the route says otherwise.
fn default_route_protect() -> bool {
    true
}

128impl Default for SecurityConfig {
129    fn default() -> Self {
130        Self {
131            enabled: default_security_enabled(),
132            blacklist_size: default_blacklist_size(),
133            blacklist_strategy: default_blacklist_strategy(),
134            max_retries: default_max_retries(),
135            ip_limit_duration: default_ip_limit_duration(),
136            ban_duration: default_ban_duration(),
137            whitelist_duration: default_whitelist_duration(),
138        }
139    }
140}
141
142impl AuthConfig {
143    pub fn get_secret(&self) -> Result<String, String> {
144        let secret = if let Some(s) = &self.totp_secret {
145            s.clone()
146        } else if let Some(path) = &self.totp_secret_file {
147            fs::read_to_string(path)
148                .map_err(|e| format!("Failed to read secret file {}: {}", path, e))?
149                .trim()
150                .to_string()
151        } else if let Some(env_var) = &self.totp_secret_env {
152            env::var(env_var).map_err(|_| format!("Environment variable {} not found", env_var))?
153        } else {
154            return Err(
155                "No TOTP secret configured. Provide totp_secret, totp_secret_file, or totp_secret_env"
156                    .to_string(),
157            );
158        };
159
160        if secret.is_empty() {
161            return Err("TOTP secret is empty".to_string());
162        }
163
164        Ok(secret.chars().filter(|c| !c.is_whitespace()).collect())
165    }
166}
167
168#[derive(Debug, Deserialize, Serialize, Clone)]
169pub struct RouteConfig {
170    #[serde(default)]
171    pub host: Option<String>,
172    #[serde(default)]
173    pub path: Option<String>,
174    #[serde(default)]
175    pub path_prefix: Option<String>,
176    pub upstream_addr: String,
177    #[serde(default = "default_route_protect")]
178    pub protect: bool,
179}
180
181pub fn load_config<P: AsRef<Path>>(path: P) -> Result<Config, Box<dyn std::error::Error>> {
182    if !path.as_ref().exists() {
183        let example = include_str!("../example_config.toml");
184        fs::write(path.as_ref(), example)?;
185    }
186
187    let content = fs::read_to_string(path)?;
188    let config: Config = toml::from_str(&content)?;
189    Ok(config)
190}