Skip to main content

what_core/auth/
mod.rs

1//! Authentication module for the What framework
2//!
3//! Provides JWT-based authentication with:
4//! - Cookie-based JWT storage (HttpOnly, Secure, SameSite)
5//! - Configurable protected routes
6//! - Session integration for user data
7//! - Backend API integration for login/logout
8
9use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value, json};
12use std::collections::HashMap;
13use std::sync::OnceLock;
14
15use crate::Result;
16use crate::config::AuthConfig;
17
18/// Auto-generated JWT secret used when no secret is configured.
19/// Generated once per process lifetime using cryptographically secure random bytes.
20static AUTO_JWT_SECRET: OnceLock<String> = OnceLock::new();
21
22/// Get or generate a fallback JWT secret (32 random bytes, hex-encoded).
23fn get_or_generate_jwt_secret() -> &'static str {
24    AUTO_JWT_SECRET.get_or_init(|| {
25        use rand::RngCore;
26        let mut bytes = [0u8; 32];
27        rand::thread_rng().fill_bytes(&mut bytes);
28        let secret: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
29        tracing::warn!(
30            "No jwt_secret configured — generated a random secret. \
31             JWTs signed by external services will fail validation. \
32             Set [auth] jwt_secret in what.toml or WHAT_AUTH_JWT_SECRET env var."
33        );
34        secret
35    })
36}
37
38/// JWT claims structure
39/// Uses a flexible HashMap to support any claims from the backend
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct JwtClaims {
42    /// Standard JWT claims
43    #[serde(default)]
44    pub exp: Option<u64>,
45    #[serde(default)]
46    pub iat: Option<u64>,
47    #[serde(default)]
48    pub sub: Option<String>,
49
50    /// Custom claims (user_id, email, full_name, etc.)
51    #[serde(flatten)]
52    pub custom: HashMap<String, Value>,
53}
54
55impl JwtClaims {
56    /// Extract specified claims into a context map for templates
57    pub fn to_context(&self, claim_names: &[String]) -> HashMap<String, Value> {
58        let mut context = HashMap::new();
59
60        // Add standard claims if present
61        if let Some(sub) = &self.sub {
62            context.insert("sub".to_string(), json!(sub));
63        }
64        if let Some(exp) = self.exp {
65            context.insert("exp".to_string(), json!(exp));
66        }
67
68        // Add requested custom claims
69        for name in claim_names {
70            if let Some(value) = self.custom.get(name) {
71                context.insert(name.clone(), value.clone());
72            }
73        }
74
75        context
76    }
77
78    /// Check if the token is expired
79    pub fn is_expired(&self) -> bool {
80        if let Some(exp) = self.exp {
81            let now = std::time::SystemTime::now()
82                .duration_since(std::time::UNIX_EPOCH)
83                .map(|d| d.as_secs())
84                .unwrap_or(0);
85            exp < now
86        } else {
87            false // No expiration = not expired
88        }
89    }
90}
91
92/// Authentication handler
93#[derive(Clone)]
94pub struct AuthHandler {
95    config: AuthConfig,
96}
97
98impl AuthHandler {
99    /// Create a new auth handler with the given configuration
100    pub fn new(config: AuthConfig) -> Self {
101        Self { config }
102    }
103
104    /// Load configuration with environment variable overrides
105    pub fn from_config_with_env(mut config: AuthConfig) -> Self {
106        // Override with environment variables
107        if let Ok(val) = std::env::var("WHAT_AUTH_ENABLED") {
108            config.enabled = val.parse().unwrap_or(config.enabled);
109        }
110        if let Ok(val) = std::env::var("WHAT_AUTH_LOGIN_ENDPOINT") {
111            config.login_endpoint = Some(val);
112        }
113        if let Ok(val) = std::env::var("WHAT_AUTH_LOGOUT_ENDPOINT") {
114            config.logout_endpoint = Some(val);
115        }
116        if let Ok(val) = std::env::var("WHAT_AUTH_JWT_SECRET") {
117            config.jwt_secret = Some(val);
118        }
119        if let Ok(val) = std::env::var("WHAT_AUTH_JWT_COOKIE_NAME") {
120            config.jwt_cookie_name = val;
121        }
122        if let Ok(val) = std::env::var("WHAT_AUTH_LOGIN_PATH") {
123            config.login_path = val;
124        }
125        if let Ok(val) = std::env::var("WHAT_AUTH_AFTER_LOGIN") {
126            config.after_login = val;
127        }
128
129        Self { config }
130    }
131
132    /// Check if authentication is enabled
133    pub fn is_enabled(&self) -> bool {
134        self.config.enabled
135    }
136
137    /// Check if a path is protected
138    pub fn is_protected(&self, path: &str) -> bool {
139        if !self.config.enabled {
140            return false;
141        }
142
143        for pattern in &self.config.protected_paths {
144            if pattern_matches(pattern, path) {
145                return true;
146            }
147        }
148        false
149    }
150
151    /// Get the login path
152    pub fn login_path(&self) -> &str {
153        &self.config.login_path
154    }
155
156    /// Get the after-login redirect path
157    pub fn after_login_path(&self) -> &str {
158        &self.config.after_login
159    }
160
161    /// Get the login endpoint URL
162    pub fn login_endpoint(&self) -> Option<&str> {
163        self.config.login_endpoint.as_deref()
164    }
165
166    /// Get the logout endpoint URL
167    pub fn logout_endpoint(&self) -> Option<&str> {
168        self.config.logout_endpoint.as_deref()
169    }
170
171    /// Get the JWT cookie name
172    pub fn jwt_cookie_name(&self) -> &str {
173        &self.config.jwt_cookie_name
174    }
175
176    /// Get the list of claims to extract
177    pub fn jwt_claims(&self) -> &[String] {
178        &self.config.jwt_claims
179    }
180
181    /// Parse JWT from cookie header
182    pub fn parse_jwt_cookie(&self, cookie_header: Option<&str>) -> Option<String> {
183        cookie_header.and_then(|header| {
184            header
185                .split(';')
186                .map(|s| s.trim())
187                .find(|s| s.starts_with(&format!("{}=", self.config.jwt_cookie_name)))
188                .map(|s| s[self.config.jwt_cookie_name.len() + 1..].to_string())
189        })
190    }
191
192    /// Decode and validate a JWT token.
193    ///
194    /// Always validates the signature. If no jwt_secret is configured, a random
195    /// secret is generated at startup (logged as WARN). This means externally-signed
196    /// tokens will be rejected unless the correct secret is provided.
197    pub fn decode_jwt(&self, token: &str) -> Result<JwtClaims> {
198        let secret = match self.config.jwt_secret {
199            Some(ref s) => s.as_str(),
200            None => get_or_generate_jwt_secret(),
201        };
202        let key = DecodingKey::from_secret(secret.as_bytes());
203        let validation = Validation::new(Algorithm::HS256);
204        let token_data = decode::<JwtClaims>(token, &key, &validation)?;
205        Ok(token_data.claims)
206    }
207
208    /// Build Set-Cookie header for JWT token
209    pub fn build_jwt_cookie(&self, token: &str, max_age: i64, secure: bool) -> String {
210        let mut cookie = format!(
211            "{}={}; HttpOnly; SameSite=Strict; Path=/; Max-Age={}",
212            self.config.jwt_cookie_name, token, max_age
213        );
214
215        if secure {
216            cookie.push_str("; Secure");
217        }
218
219        cookie
220    }
221
222    /// Build Set-Cookie header to clear the JWT cookie
223    pub fn build_clear_cookie(&self) -> String {
224        format!(
225            "{}=; HttpOnly; SameSite=Strict; Path=/; Max-Age=0",
226            self.config.jwt_cookie_name
227        )
228    }
229}
230
/// Simple glob-style pattern matching for request paths.
///
/// Supports:
/// - Exact match: "/admin" matches only "/admin"
/// - Wildcard suffix: "/admin/*" matches exactly one further segment, e.g.
///   "/admin/users" and "/admin/settings", but not "/admin/users/edit" or "/admin"
/// - Double wildcard: "/api/**" matches "/api" and everything beneath it, e.g.
///   "/api/v1/users" and "/api/v1/users/123"
///
/// Wildcards respect path-segment boundaries: "/api/**" does not match
/// "/apiary", which merely shares the leading characters.
fn pattern_matches(pattern: &str, path: &str) -> bool {
    if let Some(base) = pattern.strip_suffix("/**") {
        // After the base, the path must either end or continue with '/',
        // otherwise "/api/**" would also swallow "/api-docs".
        path.strip_prefix(base)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('/'))
    } else if let Some(base) = pattern.strip_suffix("/*") {
        // Exactly one segment after "base/": no further '/' may follow.
        path.strip_prefix(base)
            .and_then(|rest| rest.strip_prefix('/'))
            .is_some_and(|segment| !segment.contains('/'))
    } else {
        pattern == path
    }
}
249
250/// User context for templates
251/// Contains authenticated user information extracted from JWT
252#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct UserContext {
254    /// Whether the user is authenticated
255    pub authenticated: bool,
256    /// User claims from JWT
257    #[serde(flatten)]
258    pub claims: HashMap<String, Value>,
259}
260
261impl UserContext {
262    /// Create an unauthenticated context
263    pub fn unauthenticated() -> Self {
264        Self {
265            authenticated: false,
266            claims: HashMap::new(),
267        }
268    }
269
270    /// Create an authenticated context from JWT claims
271    pub fn from_claims(claims: HashMap<String, Value>) -> Self {
272        Self {
273            authenticated: true,
274            claims,
275        }
276    }
277
278    /// Convert to JSON Value for template context
279    pub fn to_context(&self) -> Value {
280        let mut map = serde_json::Map::new();
281        map.insert("authenticated".to_string(), json!(self.authenticated));
282
283        // Add all claims
284        for (key, value) in &self.claims {
285            map.insert(key.clone(), value.clone());
286        }
287
288        Value::Object(map)
289    }
290
291    /// Extract the user's roles from the JWT claims. Accepts a `roles` claim
292    /// (JSON array or comma-separated string) and falls back to a `role` claim.
293    pub fn roles(&self) -> Vec<String> {
294        self.claims
295            .get("roles")
296            .or_else(|| self.claims.get("role"))
297            .map(|v| match v {
298                Value::Array(arr) => arr
299                    .iter()
300                    .filter_map(|v| v.as_str().map(String::from))
301                    .collect(),
302                Value::String(s) => s.split(',').map(|r| r.trim().to_string()).collect(),
303                _ => Vec::new(),
304            })
305            .unwrap_or_default()
306    }
307
308    /// The `sub` (subject) claim, if present — the stable user identifier.
309    pub fn sub(&self) -> Option<String> {
310        self.claims.get("sub").and_then(|v| v.as_str().map(String::from))
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Current Unix time in whole seconds.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Claims that carry nothing but an optional `exp`.
    fn bare_claims(exp: Option<u64>) -> JwtClaims {
        JwtClaims {
            exp,
            iat: None,
            sub: None,
            custom: HashMap::new(),
        }
    }

    /// Sign `claims` with the default header and the given HMAC secret.
    fn sign(claims: &JwtClaims, secret: &[u8]) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret),
        )
        .unwrap()
    }

    /// Enabled handler that only overrides the cookie name.
    fn handler_with_cookie(name: &str) -> AuthHandler {
        AuthHandler::new(AuthConfig {
            enabled: true,
            jwt_cookie_name: name.to_string(),
            ..Default::default()
        })
    }

    /// Enabled handler that only overrides the JWT secret.
    fn handler_with_secret(secret: Option<&str>) -> AuthHandler {
        AuthHandler::new(AuthConfig {
            enabled: true,
            jwt_secret: secret.map(str::to_string),
            ..Default::default()
        })
    }

    #[test]
    fn test_pattern_matches() {
        // (pattern, path, expected)
        let cases = [
            // Exact match
            ("/admin", "/admin", true),
            ("/admin", "/admin/users", false),
            // Single wildcard
            ("/admin/*", "/admin/users", true),
            ("/admin/*", "/admin/settings", true),
            ("/admin/*", "/admin/users/123", false),
            ("/admin/*", "/admin", false),
            // Double wildcard
            ("/api/**", "/api/v1", true),
            ("/api/**", "/api/v1/users", true),
            ("/api/**", "/api/v1/users/123", true),
        ];

        for &(pattern, path, expected) in &cases {
            assert_eq!(
                pattern_matches(pattern, path),
                expected,
                "pattern {pattern:?} vs path {path:?}"
            );
        }
    }

    #[test]
    fn test_jwt_claims_to_context() {
        let mut custom = HashMap::new();
        custom.insert("email".to_string(), json!("user@example.com"));
        custom.insert("full_name".to_string(), json!("John Doe"));
        custom.insert("role".to_string(), json!("admin"));

        let claims = JwtClaims {
            exp: Some(1234567890),
            iat: Some(1234567800),
            sub: Some("user123".to_string()),
            custom,
        };

        let requested = ["email".to_string(), "full_name".to_string()];
        let context = claims.to_context(&requested);

        assert_eq!(context.get("email"), Some(&json!("user@example.com")));
        assert_eq!(context.get("full_name"), Some(&json!("John Doe")));
        assert_eq!(context.get("sub"), Some(&json!("user123")));
        assert!(!context.contains_key("role")); // Not requested
    }

    #[test]
    fn test_user_context() {
        assert!(!UserContext::unauthenticated().authenticated);

        let mut claims = HashMap::new();
        claims.insert("email".to_string(), json!("user@example.com"));
        let authenticated = UserContext::from_claims(claims);
        assert!(authenticated.authenticated);

        let context = authenticated.to_context();
        assert_eq!(context.get("authenticated"), Some(&json!(true)));
        assert_eq!(context.get("email"), Some(&json!("user@example.com")));
    }

    #[test]
    fn test_auth_handler_parse_jwt_cookie() {
        let handler = handler_with_cookie("w_token");

        // (Cookie header, expected token)
        let cases = [
            // Valid cookie header
            (Some("w_token=abc123; other_cookie=xyz"), Some("abc123")),
            // Cookie at end of header
            (Some("other=value; w_token=def456"), Some("def456")),
            // Missing cookie
            (Some("other_cookie=xyz"), None),
            // No header at all
            (None, None),
        ];

        for &(header, expected) in &cases {
            assert_eq!(
                handler.parse_jwt_cookie(header),
                expected.map(String::from)
            );
        }
    }

    #[test]
    fn test_auth_handler_is_protected() {
        let handler = AuthHandler::new(AuthConfig {
            enabled: true,
            protected_paths: vec![
                "/admin".to_string(),
                "/admin/*".to_string(),
                "/api/**".to_string(),
            ],
            ..Default::default()
        });

        let protected = [
            // Exact match
            "/admin",
            // Single wildcard match
            "/admin/users",
            "/admin/settings",
            // Double wildcard matches all depths
            "/api/v1",
            "/api/v1/users",
            "/api/v1/users/123",
        ];
        for path in protected.iter() {
            assert!(handler.is_protected(path), "{path} should be protected");
        }

        let unprotected = [
            // Single wildcard should NOT match deeper paths
            "/admin/users/123",
            // Non-matching paths
            "/",
            "/public",
            "/login",
        ];
        for path in unprotected.iter() {
            assert!(!handler.is_protected(path), "{path} should not be protected");
        }
    }

    #[test]
    fn test_auth_handler_disabled() {
        let handler = AuthHandler::new(AuthConfig {
            enabled: false,
            protected_paths: vec!["/admin/**".to_string()],
            ..Default::default()
        });

        // When auth is disabled, nothing should be protected
        assert!(!handler.is_protected("/admin"));
        assert!(!handler.is_protected("/admin/users"));
        assert!(!handler.is_enabled());
    }

    #[test]
    fn test_build_jwt_cookie() {
        let handler = handler_with_cookie("w_token");

        // Without the Secure flag
        let plain = handler.build_jwt_cookie("test_token_123", 3600, false);
        for expected in [
            "w_token=test_token_123",
            "HttpOnly",
            "SameSite=Strict",
            "Path=/",
            "Max-Age=3600",
        ]
        .iter()
        {
            assert!(plain.contains(expected));
        }
        assert!(!plain.contains("Secure"));

        // With the Secure flag
        let secure = handler.build_jwt_cookie("test_token_123", 3600, true);
        assert!(secure.contains("Secure"));
    }

    #[test]
    fn test_build_clear_cookie() {
        let cookie = handler_with_cookie("w_token").build_clear_cookie();

        for expected in [
            "w_token=",
            "Max-Age=0",
            "HttpOnly",
            "SameSite=Strict",
            "Path=/",
        ]
        .iter()
        {
            assert!(cookie.contains(expected));
        }
    }

    #[test]
    fn test_jwt_claims_is_expired() {
        // Not expired: one hour in the future
        assert!(!bare_claims(Some(unix_now() + 3600)).is_expired());

        // Expired: one hour in the past
        assert!(bare_claims(Some(unix_now() - 3600)).is_expired());

        // No expiration = not expired
        assert!(!bare_claims(None).is_expired());
    }

    #[test]
    fn test_decode_jwt_with_configured_secret() {
        let secret = "test_secret_123";
        let handler = handler_with_secret(Some(secret));

        // Valid token signed with the same secret
        let mut custom = HashMap::new();
        custom.insert("email".to_string(), json!("a@b.com"));
        let claims = JwtClaims {
            exp: Some(unix_now() + 3600),
            iat: None,
            sub: Some("user1".to_string()),
            custom,
        };

        let token = sign(&claims, secret.as_bytes());
        let decoded = handler.decode_jwt(&token).unwrap();
        assert_eq!(decoded.sub, Some("user1".to_string()));
        assert_eq!(decoded.custom.get("email"), Some(&json!("a@b.com")));
    }

    #[test]
    fn test_decode_jwt_rejects_wrong_secret() {
        let handler = handler_with_secret(Some("correct_secret"));

        // Sign with a different secret
        let token = sign(&bare_claims(Some(unix_now() + 3600)), b"wrong_secret");

        assert!(
            handler.decode_jwt(&token).is_err(),
            "Should reject JWT signed with wrong secret"
        );
    }

    #[test]
    fn test_decode_jwt_no_secret_uses_auto_generated() {
        // With no secret configured, decode_jwt falls back to the auto-generated
        // secret, so a token signed with an arbitrary secret must be rejected.
        let handler = handler_with_secret(None);

        // Sign with an arbitrary secret — this should NOT match the auto-generated one
        let token = sign(&bare_claims(Some(unix_now() + 3600)), b"attacker_secret");

        assert!(
            handler.decode_jwt(&token).is_err(),
            "Should reject JWT when no secret is configured (auto-generated secret won't match)"
        );
    }

    #[test]
    fn test_auth_handler_getters() {
        let handler = AuthHandler::new(AuthConfig {
            enabled: true,
            login_path: "/login".to_string(),
            after_login: "/dashboard".to_string(),
            login_endpoint: Some("https://api.example.com/login".to_string()),
            logout_endpoint: Some("https://api.example.com/logout".to_string()),
            jwt_cookie_name: "auth_token".to_string(),
            jwt_claims: vec!["email".to_string(), "name".to_string()],
            ..Default::default()
        });

        assert!(handler.is_enabled());
        assert_eq!(handler.login_path(), "/login");
        assert_eq!(handler.after_login_path(), "/dashboard");
        assert_eq!(
            handler.login_endpoint(),
            Some("https://api.example.com/login")
        );
        assert_eq!(
            handler.logout_endpoint(),
            Some("https://api.example.com/logout")
        );
        assert_eq!(handler.jwt_cookie_name(), "auth_token");
        assert_eq!(
            handler.jwt_claims(),
            &["email".to_string(), "name".to_string()]
        );
    }
}