1use crate::config::load_config;
2use crate::proxy::AuthGateway;
3use crate::state::{
4 CompiledRoute, FILE_WATCH_DEBOUNCE_MS, MAX_IP_ENTRIES, MAX_SESSION_ENTRIES, ProxyState,
5 RuntimeState,
6};
7use crate::utils::glob_to_regex;
8use arc_swap::ArcSwap;
9use ipnet::IpNet;
10use log::{error, info, warn};
11use moka::sync::Cache;
12use notify::{RecommendedWatcher, RecursiveMode, Watcher};
13use pingora::prelude::*;
14use std::fs;
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17use std::sync::atomic::AtomicU64;
18use std::sync::mpsc::channel;
19use std::time::Duration;
20
/// Entry point of the gateway: owns the location of the configuration file and drives startup.
#[derive(Debug)]
pub struct App {
    /// Path to the configuration file; it is loaded at startup and watched for changes.
    config_path: PathBuf,
}
24
25impl App {
26 pub fn new<P: AsRef<Path>>(config_path: P) -> Self {
27 Self {
28 config_path: config_path.as_ref().to_path_buf(),
29 }
30 }
31
32 fn load_runtime_state(path: &Path) -> Result<RuntimeState, String> {
33 let config = load_config(path).map_err(|e| e.to_string())?;
34 let secret = config.auth.get_secret().map_err(|e| e.to_string())?;
35
36 let trusted_cidrs: Vec<(IpNet, String)> = config
37 .server
38 .trusted_proxies
39 .iter()
40 .filter_map(|(s, h)| {
41 s.parse::<IpNet>()
42 .map(|cidr| (cidr, h.clone()))
43 .map_err(|e| {
44 warn!("Failed to parse trusted proxy CIDR '{}': {}", s, e);
45 e
46 })
47 .ok()
48 })
49 .collect();
50
51 let routes = config
52 .routes
53 .iter()
54 .map(|r| {
55 let host = r.host.as_ref().and_then(|h| {
56 glob_to_regex(h)
57 .map_err(|e| {
58 warn!("Failed to compile host pattern '{}': {}", h, e);
59 e
60 })
61 .ok()
62 });
63
64 let path = r.path.as_ref().and_then(|p| {
65 glob_to_regex(p)
66 .map_err(|e| {
67 warn!("Failed to compile path pattern '{}': {}", p, e);
68 e
69 })
70 .ok()
71 });
72
73 CompiledRoute {
74 host,
75 path,
76 path_prefix: r.path_prefix.clone(),
77 upstream_addr: r.upstream_addr.clone(),
78 protect: r.protect,
79 }
80 })
81 .collect();
82
83 let login_page_html = match &config.auth.login_page_file {
84 Some(path) => fs::read_to_string(path)
85 .map_err(|e| format!("Failed to read login page file {}: {}", path, e))?,
86 None => include_str!("../login_page.html").to_string(),
87 };
88
89 let login_page_len = login_page_html.len().to_string();
90
91 Ok(RuntimeState {
92 config,
93 secret,
94 trusted_cidrs,
95 routes,
96 login_page_html: Arc::new(login_page_html),
97 login_page_len: Arc::new(login_page_len),
98 })
99 }
100
101 fn handle_config_reload(config_path: &Path, state: &Arc<ProxyState>) {
102 match Self::load_runtime_state(config_path) {
103 Ok(new_runtime) => {
104 let new_sec = &new_runtime.config.security;
105 let old_sec = &state.runtime.load().config.security;
106
107 if new_sec.blacklist_size != old_sec.blacklist_size
108 || new_sec.ban_duration != old_sec.ban_duration
109 {
110 info!(
111 "Blacklist config changed (Size: {}, Duration: {}s). Re-creating cache.",
112 new_sec.blacklist_size, new_sec.ban_duration
113 );
114 let new_blacklist = Cache::builder()
115 .time_to_live(Duration::from_secs(new_sec.ban_duration))
116 .max_capacity(new_sec.blacklist_size as u64)
117 .build();
118 state.blacklist.store(Arc::new(new_blacklist));
119 }
120
121 state.runtime.store(Arc::new(new_runtime));
122 info!("Configuration reloaded successfully.");
123 }
124 Err(e) => {
125 error!("Failed to reload configuration: {}", e);
126 }
127 }
128 }
129
130 pub fn run(self) {
131 let initial_runtime = match Self::load_runtime_state(&self.config_path) {
132 Ok(runtime) => runtime,
133 Err(e) => {
134 error!("Failed to load initial configuration: {}", e);
135 std::process::exit(1);
136 }
137 };
138
139 let bind_addr = initial_runtime.config.server.bind_addr.clone();
140 let tls_config = initial_runtime.config.tls.clone();
141
142 let security_config = &initial_runtime.config.security;
143 let auth_config = &initial_runtime.config.auth;
144
145 let blacklist_size = security_config.blacklist_size as u64;
146 let ban_duration = Duration::from_secs(security_config.ban_duration);
147 let whitelist_duration = Duration::from_secs(security_config.whitelist_duration);
148 let ip_limit_duration = Duration::from_secs(security_config.ip_limit_duration);
149 let session_duration = Duration::from_secs(auth_config.session_duration);
150
151 let initial_blacklist = Arc::new(
152 Cache::builder()
153 .time_to_live(ban_duration)
154 .max_capacity(blacklist_size)
155 .build(),
156 );
157
158 let state = Arc::new(ProxyState {
159 runtime: ArcSwap::new(Arc::new(initial_runtime)),
160 sessions: Cache::builder()
161 .time_to_live(session_duration)
162 .max_capacity(MAX_SESSION_ENTRIES)
163 .build(),
164 whitelist: Cache::builder()
165 .time_to_live(whitelist_duration)
166 .max_capacity(MAX_IP_ENTRIES)
167 .build(),
168 blacklist: ArcSwap::new(initial_blacklist),
169 ip_limits: Cache::builder()
170 .time_to_live(ip_limit_duration)
171 .max_capacity(MAX_IP_ENTRIES)
172 .build(),
173 last_verified_step: AtomicU64::new(0),
174 });
175
176 let state_for_watcher = state.clone();
177 let config_path = self.config_path.clone();
178
179 std::thread::spawn(move || {
180 let (tx, rx) = channel();
181 let mut watcher = match RecommendedWatcher::new(tx, notify::Config::default()) {
182 Ok(w) => w,
183 Err(e) => {
184 error!("Failed to create file watcher: {}", e);
185 return;
186 }
187 };
188
189 if let Err(e) = watcher.watch(&config_path, RecursiveMode::NonRecursive) {
190 error!("Failed to watch config file: {}", e);
191 return;
192 }
193
194 info!("Watching config file: {:?}", config_path);
195
196 for res in rx {
197 match res {
198 Ok(event) => {
199 if event.kind.is_modify() || event.kind.is_create() {
200 info!("Config file changed. Reloading...");
201 std::thread::sleep(Duration::from_millis(FILE_WATCH_DEBOUNCE_MS));
202 Self::handle_config_reload(&config_path, &state_for_watcher);
203 }
204 }
205 Err(e) => error!("Watch error: {}", e),
206 }
207 }
208 });
209
210 let mut server = match Server::new(None) {
211 Ok(s) => s,
212 Err(e) => {
213 error!("Failed to create server: {}", e);
214 std::process::exit(1);
215 }
216 };
217
218 server.bootstrap();
219
220 let mut my_gateway = http_proxy_service(&server.configuration, AuthGateway { state });
221
222 if let Some(tls) = tls_config {
223 if let Err(e) = my_gateway.add_tls(&bind_addr, &tls.cert_file, &tls.key_file) {
224 error!("Failed to add TLS: {}", e);
225 std::process::exit(1);
226 }
227 info!("Gateway Server running on {} (HTTPS)", bind_addr);
228 } else {
229 my_gateway.add_tcp(&bind_addr);
230 info!("Gateway Server running on {} (HTTP)", bind_addr);
231 }
232
233 server.add_service(my_gateway);
234 server.run_forever();
235 }
236}