totp_gateway/
app.rs

1use crate::config::load_config;
2use crate::proxy::AuthGateway;
3use crate::state::{
4    CompiledRoute, FILE_WATCH_DEBOUNCE_MS, MAX_IP_ENTRIES, MAX_SESSION_ENTRIES, ProxyState,
5    RuntimeState,
6};
7use crate::utils::glob_to_regex;
8use arc_swap::ArcSwap;
9use ipnet::IpNet;
10use log::{error, info, warn};
11use moka::sync::Cache;
12use notify::{RecommendedWatcher, RecursiveMode, Watcher};
13use pingora::prelude::*;
14use std::fs;
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17use std::sync::atomic::AtomicU64;
18use std::sync::mpsc::channel;
19use std::time::Duration;
20
/// Top-level gateway application.
///
/// Holds only the location of the configuration file; everything else is built from that
/// file when [`App::run`] is called.
#[derive(Debug)]
pub struct App {
    /// Path of the configuration file, read at startup and re-read on every detected change.
    config_path: PathBuf,
}
24
25impl App {
26    pub fn new<P: AsRef<Path>>(config_path: P) -> Self {
27        Self {
28            config_path: config_path.as_ref().to_path_buf(),
29        }
30    }
31
32    fn load_runtime_state(path: &Path) -> Result<RuntimeState, String> {
33        let config = load_config(path).map_err(|e| e.to_string())?;
34        let secret = config.auth.get_secret().map_err(|e| e.to_string())?;
35
36        let trusted_cidrs: Vec<(IpNet, String)> = config
37            .server
38            .trusted_proxies
39            .iter()
40            .filter_map(|(s, h)| {
41                s.parse::<IpNet>()
42                    .map(|cidr| (cidr, h.clone()))
43                    .map_err(|e| {
44                        warn!("Failed to parse trusted proxy CIDR '{}': {}", s, e);
45                        e
46                    })
47                    .ok()
48            })
49            .collect();
50
51        let routes = config
52            .routes
53            .iter()
54            .map(|r| {
55                let host = r.host.as_ref().and_then(|h| {
56                    glob_to_regex(h)
57                        .map_err(|e| {
58                            warn!("Failed to compile host pattern '{}': {}", h, e);
59                            e
60                        })
61                        .ok()
62                });
63
64                let path = r.path.as_ref().and_then(|p| {
65                    glob_to_regex(p)
66                        .map_err(|e| {
67                            warn!("Failed to compile path pattern '{}': {}", p, e);
68                            e
69                        })
70                        .ok()
71                });
72
73                CompiledRoute {
74                    host,
75                    path,
76                    path_prefix: r.path_prefix.clone(),
77                    upstream_addr: r.upstream_addr.clone(),
78                    protect: r.protect,
79                }
80            })
81            .collect();
82
83        let login_page_html = match &config.auth.login_page_file {
84            Some(path) => fs::read_to_string(path)
85                .map_err(|e| format!("Failed to read login page file {}: {}", path, e))?,
86            None => include_str!("../login_page.html").to_string(),
87        };
88
89        let login_page_len = login_page_html.len().to_string();
90
91        Ok(RuntimeState {
92            config,
93            secret,
94            trusted_cidrs,
95            routes,
96            login_page_html: Arc::new(login_page_html),
97            login_page_len: Arc::new(login_page_len),
98        })
99    }
100
101    fn handle_config_reload(config_path: &Path, state: &Arc<ProxyState>) {
102        match Self::load_runtime_state(config_path) {
103            Ok(new_runtime) => {
104                let new_sec = &new_runtime.config.security;
105                let old_sec = &state.runtime.load().config.security;
106
107                if new_sec.blacklist_size != old_sec.blacklist_size
108                    || new_sec.ban_duration != old_sec.ban_duration
109                {
110                    info!(
111                        "Blacklist config changed (Size: {}, Duration: {}s). Re-creating cache.",
112                        new_sec.blacklist_size, new_sec.ban_duration
113                    );
114                    let new_blacklist = Cache::builder()
115                        .time_to_live(Duration::from_secs(new_sec.ban_duration))
116                        .max_capacity(new_sec.blacklist_size as u64)
117                        .build();
118                    state.blacklist.store(Arc::new(new_blacklist));
119                }
120
121                state.runtime.store(Arc::new(new_runtime));
122                info!("Configuration reloaded successfully.");
123            }
124            Err(e) => {
125                error!("Failed to reload configuration: {}", e);
126            }
127        }
128    }
129
130    pub fn run(self) {
131        let initial_runtime = match Self::load_runtime_state(&self.config_path) {
132            Ok(runtime) => runtime,
133            Err(e) => {
134                error!("Failed to load initial configuration: {}", e);
135                std::process::exit(1);
136            }
137        };
138
139        let bind_addr = initial_runtime.config.server.bind_addr.clone();
140        let tls_config = initial_runtime.config.tls.clone();
141
142        let security_config = &initial_runtime.config.security;
143        let auth_config = &initial_runtime.config.auth;
144
145        let blacklist_size = security_config.blacklist_size as u64;
146        let ban_duration = Duration::from_secs(security_config.ban_duration);
147        let whitelist_duration = Duration::from_secs(security_config.whitelist_duration);
148        let ip_limit_duration = Duration::from_secs(security_config.ip_limit_duration);
149        let session_duration = Duration::from_secs(auth_config.session_duration);
150
151        let initial_blacklist = Arc::new(
152            Cache::builder()
153                .time_to_live(ban_duration)
154                .max_capacity(blacklist_size)
155                .build(),
156        );
157
158        let state = Arc::new(ProxyState {
159            runtime: ArcSwap::new(Arc::new(initial_runtime)),
160            sessions: Cache::builder()
161                .time_to_live(session_duration)
162                .max_capacity(MAX_SESSION_ENTRIES)
163                .build(),
164            whitelist: Cache::builder()
165                .time_to_live(whitelist_duration)
166                .max_capacity(MAX_IP_ENTRIES)
167                .build(),
168            blacklist: ArcSwap::new(initial_blacklist),
169            ip_limits: Cache::builder()
170                .time_to_live(ip_limit_duration)
171                .max_capacity(MAX_IP_ENTRIES)
172                .build(),
173            last_verified_step: AtomicU64::new(0),
174        });
175
176        let state_for_watcher = state.clone();
177        let config_path = self.config_path.clone();
178
179        std::thread::spawn(move || {
180            let (tx, rx) = channel();
181            let mut watcher = match RecommendedWatcher::new(tx, notify::Config::default()) {
182                Ok(w) => w,
183                Err(e) => {
184                    error!("Failed to create file watcher: {}", e);
185                    return;
186                }
187            };
188
189            if let Err(e) = watcher.watch(&config_path, RecursiveMode::NonRecursive) {
190                error!("Failed to watch config file: {}", e);
191                return;
192            }
193
194            info!("Watching config file: {:?}", config_path);
195
196            for res in rx {
197                match res {
198                    Ok(event) => {
199                        if event.kind.is_modify() || event.kind.is_create() {
200                            info!("Config file changed. Reloading...");
201                            std::thread::sleep(Duration::from_millis(FILE_WATCH_DEBOUNCE_MS));
202                            Self::handle_config_reload(&config_path, &state_for_watcher);
203                        }
204                    }
205                    Err(e) => error!("Watch error: {}", e),
206                }
207            }
208        });
209
210        let mut server = match Server::new(None) {
211            Ok(s) => s,
212            Err(e) => {
213                error!("Failed to create server: {}", e);
214                std::process::exit(1);
215            }
216        };
217
218        server.bootstrap();
219
220        let mut my_gateway = http_proxy_service(&server.configuration, AuthGateway { state });
221
222        if let Some(tls) = tls_config {
223            if let Err(e) = my_gateway.add_tls(&bind_addr, &tls.cert_file, &tls.key_file) {
224                error!("Failed to add TLS: {}", e);
225                std::process::exit(1);
226            }
227            info!("Gateway Server running on {} (HTTPS)", bind_addr);
228        } else {
229            my_gateway.add_tcp(&bind_addr);
230            info!("Gateway Server running on {} (HTTP)", bind_addr);
231        }
232
233        server.add_service(my_gateway);
234        server.run_forever();
235    }
236}