// totp_gateway/proxy.rs

use crate::config::BlacklistStrategy;
use crate::state::{
    CompiledRoute, DEFAULT_HTTP_PORT, MAX_BODY_SIZE, MAX_IP_ENTRIES, MAX_SESSION_ENTRIES,
    ProxyState, TOTP_DIGITS, TOTP_SKEW, TOTP_STEP_SECS,
};
use crate::utils::{ProxyError, SessionId, UpstreamAddr};
use async_trait::async_trait;
use bytes::Bytes;
use log::{info, warn};
use pingora::http::ResponseHeader;
use pingora::prelude::*;
use std::collections::HashMap;
use std::net::IpAddr;
use std::ops::Deref;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::{SystemTime, UNIX_EPOCH};
use totp_rs::{Algorithm, Secret, TOTP};
use url::form_urlencoded;
use uuid::Uuid;

22pub struct AuthGateway {
23    pub state: Arc<ProxyState>,
24}
25
26impl AuthGateway {
27    fn get_real_ip(&self, session: &Session) -> Option<IpAddr> {
28        let client_addr = session
29            .client_addr()
30            .and_then(|addr| addr.as_inet())
31            .map(|inet| inet.ip())?;
32
33        let runtime = self.state.runtime.load();
34
35        for (cidr, header_name) in &runtime.trusted_cidrs {
36            if cidr.contains(&client_addr)
37                && let Some(ip) = session
38                    .req_header()
39                    .headers
40                    .get(header_name)
41                    .and_then(|h| h.to_str().ok())
42                    .and_then(|s| IpAddr::from_str(s).ok())
43            {
44                return Some(ip);
45            }
46        }
47        Some(client_addr)
48    }
49
50    fn verify_totp(&self, code: &str) -> Result<bool, ProxyError> {
51        let runtime = self.state.runtime.load();
52        let secret = Secret::Encoded(runtime.secret.clone());
53        let secret_bytes = secret
54            .to_bytes()
55            .map_err(|_| ProxyError::TotpSecretInvalid)?;
56
57        let totp = TOTP::new(
58            Algorithm::SHA1,
59            TOTP_DIGITS,
60            TOTP_SKEW as u8,
61            TOTP_STEP_SECS,
62            secret_bytes,
63        )
64        .map_err(|_| ProxyError::TotpCreationFailed)?;
65
66        let now = SystemTime::now()
67            .duration_since(UNIX_EPOCH)
68            .unwrap_or_default()
69            .as_secs();
70
71        let is_valid_format = totp.check(code, now)
72            || totp.check(code, now.saturating_sub(TOTP_STEP_SECS))
73            || totp.check(code, now.saturating_add(TOTP_STEP_SECS));
74
75        if !is_valid_format {
76            warn!("Invalid TOTP Format: {}", code);
77            return Ok(false);
78        }
79
80        let current_step = now / TOTP_STEP_SECS;
81        let last_step = self.state.last_verified_step.load(Ordering::Relaxed);
82
83        if current_step > last_step {
84            let result = self.state.last_verified_step.compare_exchange(
85                last_step,
86                current_step,
87                Ordering::SeqCst,
88                Ordering::Relaxed,
89            );
90
91            return Ok(result.is_ok());
92        }
93
94        warn!("Replay Attack Detected! Code used within same step.");
95        Ok(false)
96    }
97
98    fn get_session_cookie(&self, session: &Session) -> Option<SessionId> {
99        if let Some(header) = session.req_header().headers.get("Cookie")
100            && let Ok(cookie_str) = header.to_str()
101        {
102            for part in cookie_str.split(';') {
103                let part = part.trim();
104                if let Some(sid) = part.strip_prefix("SID=") {
105                    return Some(SessionId::new(sid.to_string()));
106                }
107            }
108        }
109        None
110    }
111
112    fn is_blacklisted(&self, ip: IpAddr) -> bool {
113        let runtime = self.state.runtime.load();
114        if !runtime.config.security.enabled {
115            return false;
116        }
117        self.state.blacklist.load().contains_key(&ip)
118    }
119
120    fn register_failure(&self, ip: IpAddr) {
121        let runtime = self.state.runtime.load();
122        let security_config = &runtime.config.security;
123
124        if !security_config.enabled {
125            return;
126        }
127
128        if !self.state.ip_limits.contains_key(&ip)
129            && self.state.ip_limits.entry_count() >= MAX_IP_ENTRIES
130        {
131            warn!("IP Limit Table Full. Dropping failure tracking for: {}", ip);
132            return;
133        }
134
135        let entry = self.state.ip_limits.entry(ip).or_insert(0);
136        let mut val = *entry.value();
137        val += 1;
138
139        if val >= security_config.max_retries as u8 {
140            let blacklist = self.state.blacklist.load();
141
142            if security_config.blacklist_strategy == BlacklistStrategy::Block
143                && blacklist.iter().count() as u64 >= security_config.blacklist_size as u64
144            {
145                warn!(
146                    "Blacklist is full and strategy is 'block', not adding new IP: {}",
147                    ip
148                );
149                return;
150            }
151
152            warn!("IP {} added to blacklist due to repeated failures.", ip);
153            blacklist.insert(ip, ());
154            self.state.ip_limits.invalidate(&ip);
155            return;
156        }
157
158        self.state.ip_limits.insert(ip, val);
159    }
160
161    fn reset_failure(&self, ip: IpAddr) {
162        self.state.ip_limits.invalidate(&ip);
163    }
164
165    fn parse_upstream_addr(addr: &str) -> UpstreamAddr {
166        addr.parse()
167            .unwrap_or_else(|_| UpstreamAddr::new("127.0.0.1".to_string(), DEFAULT_HTTP_PORT))
168    }
169
170    fn check_route(host: &str, path: &str, r: &&&CompiledRoute) -> bool {
171        if let Some(prefix) = &r.path_prefix {
172            return path.starts_with(prefix);
173        }
174
175        let host_match = r.host.as_ref().map(|re| re.is_match(host)).unwrap_or(true);
176        let path_match = r.path.as_ref().map(|re| re.is_match(path)).unwrap_or(true);
177
178        host_match && path_match
179    }
180}
181
182#[async_trait]
183impl ProxyHttp for AuthGateway {
184    type CTX = ();
185
186    fn new_ctx(&self) -> Self::CTX {}
187
188    async fn upstream_peer(
189        &self,
190        session: &mut Session,
191        _ctx: &mut Self::CTX,
192    ) -> Result<Box<HttpPeer>> {
193        let runtime = self.state.runtime.load();
194
195        let host = session
196            .req_header()
197            .headers
198            .get("Host")
199            .and_then(|v| v.to_str().ok())
200            .unwrap_or("");
201
202        let host = host.split(':').next().unwrap_or(host);
203        let path = session.req_header().uri.path();
204
205        let upstream_addr = runtime
206            .routes
207            .iter()
208            .find(|r| Self::check_route(host, path, &r))
209            .map(|r| &r.upstream_addr)
210            .unwrap_or(&runtime.config.server.default_upstream);
211
212        let parsed = Self::parse_upstream_addr(upstream_addr);
213
214        Ok(Box::new(HttpPeer::new(
215            (parsed.host.as_str(), parsed.port),
216            false,
217            parsed.host.clone(),
218        )))
219    }
220
221    async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result<bool> {
222        let runtime_for_route = self.state.runtime.load();
223        let host_hdr = session
224            .req_header()
225            .headers
226            .get("Host")
227            .and_then(|v| v.to_str().ok())
228            .unwrap_or("");
229        let host_only = host_hdr.split(':').next().unwrap_or(host_hdr);
230        let path = session.req_header().uri.path();
231
232        let matched_route = runtime_for_route
233            .routes
234            .iter()
235            .find(|r| Self::check_route(host_only, path, &r));
236
237        if let Some(route) = matched_route {
238            if !route.protect {
239                return Ok(false);
240            }
241        }
242
243        let client_ip = match self.get_real_ip(session) {
244            Some(ip) => ip,
245            None => {
246                let mut header = ResponseHeader::build(400, Some(1))?;
247                header.insert_header("Content-Length", "0")?;
248                session
249                    .write_response_header(Box::new(header), false)
250                    .await?;
251                return Ok(true);
252            }
253        };
254
255        let runtime = self.state.runtime.load();
256        let auth_config = &runtime.config.auth;
257
258        if self.is_blacklisted(client_ip) {
259            warn!("Blocked Request from {}", client_ip);
260            let mut header = ResponseHeader::build(429, Some(1))?;
261            header.insert_header("Content-Length", "0")?;
262            session
263                .write_response_header(Box::new(header), false)
264                .await?;
265            return Ok(true);
266        }
267
268        if let Some(sid) = self.get_session_cookie(session)
269            && self.state.sessions.get(sid.as_str()).is_some()
270        {
271            return Ok(false);
272        }
273
274        if session.req_header().method == "POST" && session.req_header().uri.path() == "/auth" {
275            let content_len = session.req_header().deref().headers.get("Content-Length");
276            let content_len = content_len
277                .and_then(|v| v.to_str().ok())
278                .and_then(|v| v.parse().ok());
279            let content_len = content_len.unwrap_or(0);
280
281            if content_len > MAX_BODY_SIZE {
282                warn!("Payload too large: {} bytes", content_len);
283                let mut header = ResponseHeader::build(413, Some(1))?;
284                header.insert_header("Content-Length", "0")?;
285
286                session
287                    .write_response_header(Box::new(header), true)
288                    .await?;
289                return Ok(true);
290            }
291
292            let body_bytes = session.read_request_body().await?.unwrap_or_default();
293
294            if body_bytes.len() > MAX_BODY_SIZE {
295                let mut header = ResponseHeader::build(413, Some(1))?;
296                header.insert_header("Content-Length", "0")?;
297
298                session
299                    .write_response_header(Box::new(header), true)
300                    .await?;
301                return Ok(true);
302            }
303
304            let params: std::collections::HashMap<String, String> =
305                form_urlencoded::parse(&body_bytes).into_owned().collect();
306
307            if let Some(code) = params.get("code") {
308                match self.verify_totp(code) {
309                    Ok(true) => {
310                        self.reset_failure(client_ip);
311
312                        if self.state.sessions.entry_count() >= MAX_SESSION_ENTRIES {
313                            warn!("Session Table Full. Rejecting login.");
314                            let mut header = ResponseHeader::build(503, Some(2))?;
315                            header.insert_header("Retry-After", "60")?;
316                            header.insert_header("Content-Length", "0")?;
317                            session
318                                .write_response_header(Box::new(header), true)
319                                .await?;
320                            return Ok(true);
321                        }
322
323                        let new_sid = SessionId::new(Uuid::new_v4().to_string());
324                        info!("Login Success: (IP: {:?})", client_ip);
325
326                        self.state.sessions.insert(new_sid.as_str().to_string(), ());
327
328                        let cookie_val = format!(
329                            "SID={}; Path=/; HttpOnly; Secure; SameSite=Strict; Max-Age={}",
330                            new_sid, auth_config.session_duration
331                        );
332                        let mut header = ResponseHeader::build(302, Some(3))?;
333                        header.insert_header("Set-Cookie", cookie_val)?;
334                        header.insert_header("Location", "/")?;
335                        header.insert_header("Content-Length", "0")?;
336
337                        session
338                            .write_response_header(Box::new(header), true)
339                            .await?;
340                        return Ok(true);
341                    }
342                    Ok(false) => {
343                        warn!("Login Failed (Invalid TOTP). IP: {:?}", client_ip);
344                        self.register_failure(client_ip);
345
346                        if self.is_blacklisted(client_ip) {
347                            let mut header = ResponseHeader::build(429, Some(1))?;
348                            header.insert_header("Content-Length", "0")?;
349                            session
350                                .write_response_header(Box::new(header), true)
351                                .await?;
352                            return Ok(true);
353                        }
354                    }
355                    Err(e) => {
356                        warn!("TOTP verification error: {}", e);
357                        self.register_failure(client_ip);
358                    }
359                }
360            }
361
362            let mut header = ResponseHeader::build(302, Some(2))?;
363            header.insert_header("Location", "/?error=1")?;
364            header.insert_header("Content-Length", "0")?;
365            session
366                .write_response_header(Box::new(header), true)
367                .await?;
368            return Ok(true);
369        }
370
371        let mut header = ResponseHeader::build(200, Some(8))?;
372        header.insert_header("Content-Type", "text/html; charset=utf-8")?;
373        header.insert_header("Content-Length", runtime.login_page_len.as_str())?;
374        header.insert_header("X-Content-Type-Options", "nosniff")?;
375        header.insert_header("X-Frame-Options", "DENY")?;
376
377        header.insert_header(
378            "Cache-Control",
379            "no-store, no-cache, must-revalidate, private",
380        )?;
381        header.insert_header("Pragma", "no-cache")?;
382        header.insert_header("Expires", "0")?;
383        header.insert_header("CDN-Cache-Control", "no-store")?;
384
385        session
386            .write_response_header(Box::new(header), false)
387            .await?;
388        session
389            .write_response_body(
390                Some(Bytes::from(runtime.login_page_html.as_bytes().to_vec())),
391                true,
392            )
393            .await?;
394
395        Ok(true)
396    }
397
398    async fn response_filter(
399        &self,
400        session: &mut Session,
401        upstream_response: &mut ResponseHeader,
402        _ctx: &mut Self::CTX,
403    ) -> Result<()> {
404        let runtime = self.state.runtime.load();
405
406        let host_hdr = session
407            .req_header()
408            .headers
409            .get("Host")
410            .and_then(|v| v.to_str().ok())
411            .unwrap_or("");
412        let host_only = host_hdr.split(':').next().unwrap_or(host_hdr);
413        let path = session.req_header().uri.path();
414
415        let matched_route = runtime
416            .routes
417            .iter()
418            .find(|r| Self::check_route(host_only, path, &r));
419
420        let is_unprotected = matched_route.map(|r| !r.protect).unwrap_or(false);
421
422        if !is_unprotected {
423            upstream_response
424                .insert_header(
425                    "Cache-Control",
426                    "no-store, no-cache, must-revalidate, private",
427                )
428                .ok();
429            upstream_response.insert_header("Pragma", "no-cache").ok();
430            upstream_response.insert_header("Expires", "0").ok();
431
432            upstream_response
433                .insert_header("CDN-Cache-Control", "no-store")
434                .ok();
435            upstream_response
436                .insert_header("Cloudflare-CDN-Cache-Control", "no-store")
437                .ok();
438        }
439
440        Ok(())
441    }
442}