1use crate::config::BlacklistStrategy;
2use crate::state::{
3 CompiledRoute, DEFAULT_HTTP_PORT, MAX_BODY_SIZE, MAX_IP_ENTRIES, MAX_SESSION_ENTRIES,
4 ProxyState, TOTP_DIGITS, TOTP_SKEW, TOTP_STEP_SECS,
5};
6use crate::utils::{ProxyError, SessionId, UpstreamAddr};
7use async_trait::async_trait;
8use bytes::Bytes;
9use log::{info, warn};
10use pingora::http::ResponseHeader;
11use pingora::prelude::*;
12use std::net::IpAddr;
13use std::ops::Deref;
14use std::str::FromStr;
15use std::sync::Arc;
16use std::sync::atomic::Ordering;
17use std::time::{SystemTime, UNIX_EPOCH};
18use totp_rs::{Algorithm, Secret, TOTP};
19use url::form_urlencoded;
20use uuid::Uuid;
21
/// Pingora proxy service that only lets requests reach the upstream after a TOTP login
/// (or when the matched route is marked as unprotected).
///
/// The gateway itself is stateless; runtime configuration, sessions, per-IP failure
/// counters, the blacklist and the TOTP replay marker are all reached through `state`.
pub struct AuthGateway {
    /// Shared, reference-counted proxy state.
    pub state: Arc<ProxyState>,
}
25
26impl AuthGateway {
27 fn get_real_ip(&self, session: &Session) -> Option<IpAddr> {
28 let client_addr = session
29 .client_addr()
30 .and_then(|addr| addr.as_inet())
31 .map(|inet| inet.ip())?;
32
33 let runtime = self.state.runtime.load();
34
35 for (cidr, header_name) in &runtime.trusted_cidrs {
36 if cidr.contains(&client_addr)
37 && let Some(ip) = session
38 .req_header()
39 .headers
40 .get(header_name)
41 .and_then(|h| h.to_str().ok())
42 .and_then(|s| IpAddr::from_str(s).ok())
43 {
44 return Some(ip);
45 }
46 }
47 Some(client_addr)
48 }
49
50 fn verify_totp(&self, code: &str) -> Result<bool, ProxyError> {
51 let runtime = self.state.runtime.load();
52 let secret = Secret::Encoded(runtime.secret.clone());
53 let secret_bytes = secret
54 .to_bytes()
55 .map_err(|_| ProxyError::TotpSecretInvalid)?;
56
57 let totp = TOTP::new(
58 Algorithm::SHA1,
59 TOTP_DIGITS,
60 TOTP_SKEW as u8,
61 TOTP_STEP_SECS,
62 secret_bytes,
63 )
64 .map_err(|_| ProxyError::TotpCreationFailed)?;
65
66 let now = SystemTime::now()
67 .duration_since(UNIX_EPOCH)
68 .unwrap_or_default()
69 .as_secs();
70
71 let is_valid_format = totp.check(code, now)
72 || totp.check(code, now.saturating_sub(TOTP_STEP_SECS))
73 || totp.check(code, now.saturating_add(TOTP_STEP_SECS));
74
75 if !is_valid_format {
76 warn!("Invalid TOTP Format: {}", code);
77 return Ok(false);
78 }
79
80 let current_step = now / TOTP_STEP_SECS;
81 let last_step = self.state.last_verified_step.load(Ordering::Relaxed);
82
83 if current_step > last_step {
84 let result = self.state.last_verified_step.compare_exchange(
85 last_step,
86 current_step,
87 Ordering::SeqCst,
88 Ordering::Relaxed,
89 );
90
91 return Ok(result.is_ok());
92 }
93
94 warn!("Replay Attack Detected! Code used within same step.");
95 Ok(false)
96 }
97
98 fn get_session_cookie(&self, session: &Session) -> Option<SessionId> {
99 if let Some(header) = session.req_header().headers.get("Cookie")
100 && let Ok(cookie_str) = header.to_str()
101 {
102 for part in cookie_str.split(';') {
103 let part = part.trim();
104 if let Some(sid) = part.strip_prefix("SID=") {
105 return Some(SessionId::new(sid.to_string()));
106 }
107 }
108 }
109 None
110 }
111
112 fn is_blacklisted(&self, ip: IpAddr) -> bool {
113 let runtime = self.state.runtime.load();
114 if !runtime.config.security.enabled {
115 return false;
116 }
117 self.state.blacklist.load().contains_key(&ip)
118 }
119
120 fn register_failure(&self, ip: IpAddr) {
121 let runtime = self.state.runtime.load();
122 let security_config = &runtime.config.security;
123
124 if !security_config.enabled {
125 return;
126 }
127
128 if !self.state.ip_limits.contains_key(&ip)
129 && self.state.ip_limits.entry_count() >= MAX_IP_ENTRIES
130 {
131 warn!("IP Limit Table Full. Dropping failure tracking for: {}", ip);
132 return;
133 }
134
135 let entry = self.state.ip_limits.entry(ip).or_insert(0);
136 let mut val = *entry.value();
137 val += 1;
138
139 if val >= security_config.max_retries as u8 {
140 let blacklist = self.state.blacklist.load();
141
142 if security_config.blacklist_strategy == BlacklistStrategy::Block
143 && blacklist.iter().count() as u64 >= security_config.blacklist_size as u64
144 {
145 warn!(
146 "Blacklist is full and strategy is 'block', not adding new IP: {}",
147 ip
148 );
149 return;
150 }
151
152 warn!("IP {} added to blacklist due to repeated failures.", ip);
153 blacklist.insert(ip, ());
154 self.state.ip_limits.invalidate(&ip);
155 return;
156 }
157
158 self.state.ip_limits.insert(ip, val);
159 }
160
161 fn reset_failure(&self, ip: IpAddr) {
162 self.state.ip_limits.invalidate(&ip);
163 }
164
165 fn parse_upstream_addr(addr: &str) -> UpstreamAddr {
166 addr.parse()
167 .unwrap_or_else(|_| UpstreamAddr::new("127.0.0.1".to_string(), DEFAULT_HTTP_PORT))
168 }
169
170 fn check_route(host: &str, path: &str, r: &&&CompiledRoute) -> bool {
171 if let Some(prefix) = &r.path_prefix {
172 return path.starts_with(prefix);
173 }
174
175 let host_match = r.host.as_ref().map(|re| re.is_match(host)).unwrap_or(true);
176 let path_match = r.path.as_ref().map(|re| re.is_match(path)).unwrap_or(true);
177
178 host_match && path_match
179 }
180}
181
182#[async_trait]
183impl ProxyHttp for AuthGateway {
184 type CTX = ();
185
186 fn new_ctx(&self) -> Self::CTX {}
187
188 async fn upstream_peer(
189 &self,
190 session: &mut Session,
191 _ctx: &mut Self::CTX,
192 ) -> Result<Box<HttpPeer>> {
193 let runtime = self.state.runtime.load();
194
195 let host = session
196 .req_header()
197 .headers
198 .get("Host")
199 .and_then(|v| v.to_str().ok())
200 .unwrap_or("");
201
202 let host = host.split(':').next().unwrap_or(host);
203 let path = session.req_header().uri.path();
204
205 let upstream_addr = runtime
206 .routes
207 .iter()
208 .find(|r| Self::check_route(host, path, &r))
209 .map(|r| &r.upstream_addr)
210 .unwrap_or(&runtime.config.server.default_upstream);
211
212 let parsed = Self::parse_upstream_addr(upstream_addr);
213
214 Ok(Box::new(HttpPeer::new(
215 (parsed.host.as_str(), parsed.port),
216 false,
217 parsed.host.clone(),
218 )))
219 }
220
221 async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result<bool> {
222 let runtime_for_route = self.state.runtime.load();
223 let host_hdr = session
224 .req_header()
225 .headers
226 .get("Host")
227 .and_then(|v| v.to_str().ok())
228 .unwrap_or("");
229 let host_only = host_hdr.split(':').next().unwrap_or(host_hdr);
230 let path = session.req_header().uri.path();
231
232 let matched_route = runtime_for_route
233 .routes
234 .iter()
235 .find(|r| Self::check_route(host_only, path, &r));
236
237 if let Some(route) = matched_route {
238 if !route.protect {
239 return Ok(false);
240 }
241 }
242
243 let client_ip = match self.get_real_ip(session) {
244 Some(ip) => ip,
245 None => {
246 let mut header = ResponseHeader::build(400, Some(1))?;
247 header.insert_header("Content-Length", "0")?;
248 session
249 .write_response_header(Box::new(header), false)
250 .await?;
251 return Ok(true);
252 }
253 };
254
255 let runtime = self.state.runtime.load();
256 let auth_config = &runtime.config.auth;
257
258 if self.is_blacklisted(client_ip) {
259 warn!("Blocked Request from {}", client_ip);
260 let mut header = ResponseHeader::build(429, Some(1))?;
261 header.insert_header("Content-Length", "0")?;
262 session
263 .write_response_header(Box::new(header), false)
264 .await?;
265 return Ok(true);
266 }
267
268 if let Some(sid) = self.get_session_cookie(session)
269 && self.state.sessions.get(sid.as_str()).is_some()
270 {
271 return Ok(false);
272 }
273
274 if session.req_header().method == "POST" && session.req_header().uri.path() == "/auth" {
275 let content_len = session.req_header().deref().headers.get("Content-Length");
276 let content_len = content_len
277 .and_then(|v| v.to_str().ok())
278 .and_then(|v| v.parse().ok());
279 let content_len = content_len.unwrap_or(0);
280
281 if content_len > MAX_BODY_SIZE {
282 warn!("Payload too large: {} bytes", content_len);
283 let mut header = ResponseHeader::build(413, Some(1))?;
284 header.insert_header("Content-Length", "0")?;
285
286 session
287 .write_response_header(Box::new(header), true)
288 .await?;
289 return Ok(true);
290 }
291
292 let body_bytes = session.read_request_body().await?.unwrap_or_default();
293
294 if body_bytes.len() > MAX_BODY_SIZE {
295 let mut header = ResponseHeader::build(413, Some(1))?;
296 header.insert_header("Content-Length", "0")?;
297
298 session
299 .write_response_header(Box::new(header), true)
300 .await?;
301 return Ok(true);
302 }
303
304 let params: std::collections::HashMap<String, String> =
305 form_urlencoded::parse(&body_bytes).into_owned().collect();
306
307 if let Some(code) = params.get("code") {
308 match self.verify_totp(code) {
309 Ok(true) => {
310 self.reset_failure(client_ip);
311
312 if self.state.sessions.entry_count() >= MAX_SESSION_ENTRIES {
313 warn!("Session Table Full. Rejecting login.");
314 let mut header = ResponseHeader::build(503, Some(2))?;
315 header.insert_header("Retry-After", "60")?;
316 header.insert_header("Content-Length", "0")?;
317 session
318 .write_response_header(Box::new(header), true)
319 .await?;
320 return Ok(true);
321 }
322
323 let new_sid = SessionId::new(Uuid::new_v4().to_string());
324 info!("Login Success: (IP: {:?})", client_ip);
325
326 self.state.sessions.insert(new_sid.as_str().to_string(), ());
327
328 let cookie_val = format!(
329 "SID={}; Path=/; HttpOnly; Secure; SameSite=Strict; Max-Age={}",
330 new_sid, auth_config.session_duration
331 );
332 let mut header = ResponseHeader::build(302, Some(3))?;
333 header.insert_header("Set-Cookie", cookie_val)?;
334 header.insert_header("Location", "/")?;
335 header.insert_header("Content-Length", "0")?;
336
337 session
338 .write_response_header(Box::new(header), true)
339 .await?;
340 return Ok(true);
341 }
342 Ok(false) => {
343 warn!("Login Failed (Invalid TOTP). IP: {:?}", client_ip);
344 self.register_failure(client_ip);
345
346 if self.is_blacklisted(client_ip) {
347 let mut header = ResponseHeader::build(429, Some(1))?;
348 header.insert_header("Content-Length", "0")?;
349 session
350 .write_response_header(Box::new(header), true)
351 .await?;
352 return Ok(true);
353 }
354 }
355 Err(e) => {
356 warn!("TOTP verification error: {}", e);
357 self.register_failure(client_ip);
358 }
359 }
360 }
361
362 let mut header = ResponseHeader::build(302, Some(2))?;
363 header.insert_header("Location", "/?error=1")?;
364 header.insert_header("Content-Length", "0")?;
365 session
366 .write_response_header(Box::new(header), true)
367 .await?;
368 return Ok(true);
369 }
370
371 let mut header = ResponseHeader::build(200, Some(8))?;
372 header.insert_header("Content-Type", "text/html; charset=utf-8")?;
373 header.insert_header("Content-Length", runtime.login_page_len.as_str())?;
374 header.insert_header("X-Content-Type-Options", "nosniff")?;
375 header.insert_header("X-Frame-Options", "DENY")?;
376
377 header.insert_header(
378 "Cache-Control",
379 "no-store, no-cache, must-revalidate, private",
380 )?;
381 header.insert_header("Pragma", "no-cache")?;
382 header.insert_header("Expires", "0")?;
383 header.insert_header("CDN-Cache-Control", "no-store")?;
384
385 session
386 .write_response_header(Box::new(header), false)
387 .await?;
388 session
389 .write_response_body(
390 Some(Bytes::from(runtime.login_page_html.as_bytes().to_vec())),
391 true,
392 )
393 .await?;
394
395 Ok(true)
396 }
397
398 async fn response_filter(
399 &self,
400 session: &mut Session,
401 upstream_response: &mut ResponseHeader,
402 _ctx: &mut Self::CTX,
403 ) -> Result<()> {
404 let runtime = self.state.runtime.load();
405
406 let host_hdr = session
407 .req_header()
408 .headers
409 .get("Host")
410 .and_then(|v| v.to_str().ok())
411 .unwrap_or("");
412 let host_only = host_hdr.split(':').next().unwrap_or(host_hdr);
413 let path = session.req_header().uri.path();
414
415 let matched_route = runtime
416 .routes
417 .iter()
418 .find(|r| Self::check_route(host_only, path, &r));
419
420 let is_unprotected = matched_route.map(|r| !r.protect).unwrap_or(false);
421
422 if !is_unprotected {
423 upstream_response
424 .insert_header(
425 "Cache-Control",
426 "no-store, no-cache, must-revalidate, private",
427 )
428 .ok();
429 upstream_response.insert_header("Pragma", "no-cache").ok();
430 upstream_response.insert_header("Expires", "0").ok();
431
432 upstream_response
433 .insert_header("CDN-Cache-Control", "no-store")
434 .ok();
435 upstream_response
436 .insert_header("Cloudflare-CDN-Cache-Control", "no-store")
437 .ok();
438 }
439
440 Ok(())
441 }
442}