1use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value, json};
12use std::collections::HashMap;
13use std::sync::OnceLock;
14
15use crate::Result;
16use crate::config::AuthConfig;
17
18static AUTO_JWT_SECRET: OnceLock<String> = OnceLock::new();
21
22fn get_or_generate_jwt_secret() -> &'static str {
24 AUTO_JWT_SECRET.get_or_init(|| {
25 use rand::RngCore;
26 let mut bytes = [0u8; 32];
27 rand::thread_rng().fill_bytes(&mut bytes);
28 let secret: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
29 tracing::warn!(
30 "No jwt_secret configured — generated a random secret. \
31 JWTs signed by external services will fail validation. \
32 Set [auth] jwt_secret in what.toml or WHAT_AUTH_JWT_SECRET env var."
33 );
34 secret
35 })
36}
37
/// Claims carried by a JWT.
///
/// The registered `exp`, `iat` and `sub` claims have dedicated fields. Every
/// other top-level claim is collected into `custom` through `#[serde(flatten)]`.
/// Because no `skip_serializing_if` is set, a `None` field serializes as `null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Expiration time in seconds since the Unix epoch (compared against the
    /// system clock by `JwtClaims::is_expired`).
    #[serde(default)]
    pub exp: Option<u64>,
    /// Issued-at claim. Not read by any logic in this module; presumably Unix
    /// seconds like `exp` (TODO confirm with token issuer).
    #[serde(default)]
    pub iat: Option<u64>,
    /// Subject claim, exposed as `sub` in contexts built from these claims.
    #[serde(default)]
    pub sub: Option<String>,

    /// All remaining (non-registered) claims, keyed by claim name.
    #[serde(flatten)]
    pub custom: HashMap<String, Value>,
}
54
55impl JwtClaims {
56 pub fn to_context(&self, claim_names: &[String]) -> HashMap<String, Value> {
58 let mut context = HashMap::new();
59
60 if let Some(sub) = &self.sub {
62 context.insert("sub".to_string(), json!(sub));
63 }
64 if let Some(exp) = self.exp {
65 context.insert("exp".to_string(), json!(exp));
66 }
67
68 for name in claim_names {
70 if let Some(value) = self.custom.get(name) {
71 context.insert(name.clone(), value.clone());
72 }
73 }
74
75 context
76 }
77
78 pub fn is_expired(&self) -> bool {
80 if let Some(exp) = self.exp {
81 let now = std::time::SystemTime::now()
82 .duration_since(std::time::UNIX_EPOCH)
83 .map(|d| d.as_secs())
84 .unwrap_or(0);
85 exp < now
86 } else {
87 false }
89 }
90}
91
/// Authentication front-end built around an [`AuthConfig`].
///
/// It answers path-protection queries, parses and builds the JWT cookie, and
/// verifies JWTs using the configured (or auto-generated) HMAC secret.
#[derive(Clone)]
pub struct AuthHandler {
    config: AuthConfig,
}
97
98impl AuthHandler {
99 pub fn new(config: AuthConfig) -> Self {
101 Self { config }
102 }
103
104 pub fn from_config_with_env(mut config: AuthConfig) -> Self {
106 if let Ok(val) = std::env::var("WHAT_AUTH_ENABLED") {
108 config.enabled = val.parse().unwrap_or(config.enabled);
109 }
110 if let Ok(val) = std::env::var("WHAT_AUTH_LOGIN_ENDPOINT") {
111 config.login_endpoint = Some(val);
112 }
113 if let Ok(val) = std::env::var("WHAT_AUTH_LOGOUT_ENDPOINT") {
114 config.logout_endpoint = Some(val);
115 }
116 if let Ok(val) = std::env::var("WHAT_AUTH_JWT_SECRET") {
117 config.jwt_secret = Some(val);
118 }
119 if let Ok(val) = std::env::var("WHAT_AUTH_JWT_COOKIE_NAME") {
120 config.jwt_cookie_name = val;
121 }
122 if let Ok(val) = std::env::var("WHAT_AUTH_LOGIN_PATH") {
123 config.login_path = val;
124 }
125 if let Ok(val) = std::env::var("WHAT_AUTH_AFTER_LOGIN") {
126 config.after_login = val;
127 }
128
129 Self { config }
130 }
131
132 pub fn is_enabled(&self) -> bool {
134 self.config.enabled
135 }
136
137 pub fn is_protected(&self, path: &str) -> bool {
139 if !self.config.enabled {
140 return false;
141 }
142
143 for pattern in &self.config.protected_paths {
144 if pattern_matches(pattern, path) {
145 return true;
146 }
147 }
148 false
149 }
150
151 pub fn login_path(&self) -> &str {
153 &self.config.login_path
154 }
155
156 pub fn after_login_path(&self) -> &str {
158 &self.config.after_login
159 }
160
161 pub fn login_endpoint(&self) -> Option<&str> {
163 self.config.login_endpoint.as_deref()
164 }
165
166 pub fn logout_endpoint(&self) -> Option<&str> {
168 self.config.logout_endpoint.as_deref()
169 }
170
171 pub fn jwt_cookie_name(&self) -> &str {
173 &self.config.jwt_cookie_name
174 }
175
176 pub fn jwt_claims(&self) -> &[String] {
178 &self.config.jwt_claims
179 }
180
181 pub fn parse_jwt_cookie(&self, cookie_header: Option<&str>) -> Option<String> {
183 cookie_header.and_then(|header| {
184 header
185 .split(';')
186 .map(|s| s.trim())
187 .find(|s| s.starts_with(&format!("{}=", self.config.jwt_cookie_name)))
188 .map(|s| s[self.config.jwt_cookie_name.len() + 1..].to_string())
189 })
190 }
191
192 pub fn decode_jwt(&self, token: &str) -> Result<JwtClaims> {
198 let secret = match self.config.jwt_secret {
199 Some(ref s) => s.as_str(),
200 None => get_or_generate_jwt_secret(),
201 };
202 let key = DecodingKey::from_secret(secret.as_bytes());
203 let validation = Validation::new(Algorithm::HS256);
204 let token_data = decode::<JwtClaims>(token, &key, &validation)?;
205 Ok(token_data.claims)
206 }
207
208 pub fn build_jwt_cookie(&self, token: &str, max_age: i64, secure: bool) -> String {
210 let mut cookie = format!(
211 "{}={}; HttpOnly; SameSite=Strict; Path=/; Max-Age={}",
212 self.config.jwt_cookie_name, token, max_age
213 );
214
215 if secure {
216 cookie.push_str("; Secure");
217 }
218
219 cookie
220 }
221
222 pub fn build_clear_cookie(&self) -> String {
224 format!(
225 "{}=; HttpOnly; SameSite=Strict; Path=/; Max-Age=0",
226 self.config.jwt_cookie_name
227 )
228 }
229}
230
/// Returns `true` if request `path` matches the protected-path `pattern`.
///
/// Supported pattern forms:
/// * `/prefix/**`: `/prefix` itself and everything below it, at any depth.
///   Matching respects segment boundaries, so `/api/**` covers `/api/v1` but
///   not `/apiary` or `/api-docs`. `/**` matches every absolute path.
/// * `/prefix/*`: exactly one segment directly under `/prefix/` (`/admin/users`,
///   but neither `/admin` nor `/admin/users/123`).
/// * anything else: exact string equality.
///
/// No normalisation is performed. Trailing slashes, `//`, `.`/`..` segments
/// and percent-encoding are all compared literally.
fn pattern_matches(pattern: &str, path: &str) -> bool {
    if let Some(base) = pattern.strip_suffix("/**") {
        // The base itself, or the base followed by a `/` segment separator.
        path == base
            || path
                .strip_prefix(base)
                .is_some_and(|rest| rest.starts_with('/'))
    } else if let Some(dir) = pattern.strip_suffix("/*") {
        // `dir/` followed by a single segment that contains no further `/`.
        path.strip_prefix(dir)
            .and_then(|rest| rest.strip_prefix('/'))
            .is_some_and(|leaf| !leaf.contains('/'))
    } else {
        pattern == path
    }
}
249
/// Authentication state of a request: a flag plus the claims taken from its token.
///
/// When serialized, `claims` is flattened next to `authenticated`.
/// NOTE(review): a claim literally named `authenticated` would then appear
/// twice in the serde output. Confirm whether callers serialize this directly
/// or only via `to_context`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserContext {
    /// `true` for contexts built with `from_claims`, `false` for `unauthenticated`.
    pub authenticated: bool,
    /// Claims keyed by claim name (e.g. `sub`, `roles`/`role`).
    #[serde(flatten)]
    pub claims: HashMap<String, Value>,
}
260
261impl UserContext {
262 pub fn unauthenticated() -> Self {
264 Self {
265 authenticated: false,
266 claims: HashMap::new(),
267 }
268 }
269
270 pub fn from_claims(claims: HashMap<String, Value>) -> Self {
272 Self {
273 authenticated: true,
274 claims,
275 }
276 }
277
278 pub fn to_context(&self) -> Value {
280 let mut map = serde_json::Map::new();
281 map.insert("authenticated".to_string(), json!(self.authenticated));
282
283 for (key, value) in &self.claims {
285 map.insert(key.clone(), value.clone());
286 }
287
288 Value::Object(map)
289 }
290
291 pub fn roles(&self) -> Vec<String> {
294 self.claims
295 .get("roles")
296 .or_else(|| self.claims.get("role"))
297 .map(|v| match v {
298 Value::Array(arr) => arr
299 .iter()
300 .filter_map(|v| v.as_str().map(String::from))
301 .collect(),
302 Value::String(s) => s.split(',').map(|r| r.trim().to_string()).collect(),
303 _ => Vec::new(),
304 })
305 .unwrap_or_default()
306 }
307
308 pub fn sub(&self) -> Option<String> {
310 self.claims.get("sub").and_then(|v| v.as_str().map(String::from))
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Whole seconds elapsed since the Unix epoch, per the system clock.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Builds claims with `iat` unset and the given `exp`, `sub` and custom entries.
    fn claims_with(exp: Option<u64>, sub: Option<&str>, custom: &[(&str, Value)]) -> JwtClaims {
        JwtClaims {
            exp,
            iat: None,
            sub: sub.map(String::from),
            custom: custom
                .iter()
                .map(|(name, value)| (name.to_string(), value.clone()))
                .collect(),
        }
    }

    /// Encodes `claims` as a JWT signed with `secret` under the default (HS256) header.
    fn sign(claims: &JwtClaims, secret: &[u8]) -> String {
        encode(&Header::default(), claims, &EncodingKey::from_secret(secret)).unwrap()
    }

    /// An enabled handler whose JWT cookie is named `w_token`.
    fn w_token_handler() -> AuthHandler {
        AuthHandler::new(AuthConfig {
            enabled: true,
            jwt_cookie_name: "w_token".to_string(),
            ..Default::default()
        })
    }

    #[test]
    fn test_pattern_matches() {
        let cases = [
            // Exact patterns.
            ("/admin", "/admin", true),
            ("/admin", "/admin/users", false),
            // Single-segment wildcard.
            ("/admin/*", "/admin/users", true),
            ("/admin/*", "/admin/settings", true),
            ("/admin/*", "/admin/users/123", false),
            ("/admin/*", "/admin", false),
            // Recursive wildcard.
            ("/api/**", "/api/v1", true),
            ("/api/**", "/api/v1/users", true),
            ("/api/**", "/api/v1/users/123", true),
        ];
        for (pattern, path, expected) in cases {
            assert_eq!(
                pattern_matches(pattern, path),
                expected,
                "pattern {pattern:?} vs path {path:?}"
            );
        }
    }

    #[test]
    fn test_jwt_claims_to_context() {
        let claims = JwtClaims {
            exp: Some(1234567890),
            iat: Some(1234567800),
            sub: Some("user123".to_string()),
            custom: HashMap::from([
                ("email".to_string(), json!("user@example.com")),
                ("full_name".to_string(), json!("John Doe")),
                ("role".to_string(), json!("admin")),
            ]),
        };

        let requested = ["email".to_string(), "full_name".to_string()];
        let context = claims.to_context(&requested);

        assert_eq!(context.get("email"), Some(&json!("user@example.com")));
        assert_eq!(context.get("full_name"), Some(&json!("John Doe")));
        assert_eq!(context.get("sub"), Some(&json!("user123")));
        // `role` is on the token but was not requested.
        assert!(!context.contains_key("role"));
    }

    #[test]
    fn test_user_context() {
        assert!(!UserContext::unauthenticated().authenticated);

        let user = UserContext::from_claims(HashMap::from([(
            "email".to_string(),
            json!("user@example.com"),
        )]));
        assert!(user.authenticated);

        let context = user.to_context();
        assert_eq!(context.get("authenticated"), Some(&json!(true)));
        assert_eq!(context.get("email"), Some(&json!("user@example.com")));
    }

    #[test]
    fn test_auth_handler_parse_jwt_cookie() {
        let handler = w_token_handler();

        // Target cookie first, last, missing, and no header at all.
        assert_eq!(
            handler.parse_jwt_cookie(Some("w_token=abc123; other_cookie=xyz")),
            Some("abc123".to_string())
        );
        assert_eq!(
            handler.parse_jwt_cookie(Some("other=value; w_token=def456")),
            Some("def456".to_string())
        );
        assert_eq!(handler.parse_jwt_cookie(Some("other_cookie=xyz")), None);
        assert_eq!(handler.parse_jwt_cookie(None), None);
    }

    #[test]
    fn test_auth_handler_is_protected() {
        let handler = AuthHandler::new(AuthConfig {
            enabled: true,
            protected_paths: vec![
                "/admin".to_string(),
                "/admin/*".to_string(),
                "/api/**".to_string(),
            ],
            ..Default::default()
        });

        let protected = [
            "/admin",
            "/admin/users",
            "/admin/settings",
            "/api/v1",
            "/api/v1/users",
            "/api/v1/users/123",
        ];
        for path in protected {
            assert!(handler.is_protected(path), "{path} should be protected");
        }

        // `/admin/*` covers only one level, so the nested path stays public.
        let public = ["/admin/users/123", "/", "/public", "/login"];
        for path in public {
            assert!(!handler.is_protected(path), "{path} should be public");
        }
    }

    #[test]
    fn test_auth_handler_disabled() {
        let handler = AuthHandler::new(AuthConfig {
            enabled: false,
            protected_paths: vec!["/admin/**".to_string()],
            ..Default::default()
        });

        assert!(!handler.is_protected("/admin"));
        assert!(!handler.is_protected("/admin/users"));
        assert!(!handler.is_enabled());
    }

    #[test]
    fn test_build_jwt_cookie() {
        let handler = w_token_handler();

        let plain = handler.build_jwt_cookie("test_token_123", 3600, false);
        let expected_fragments = [
            "w_token=test_token_123",
            "HttpOnly",
            "SameSite=Strict",
            "Path=/",
            "Max-Age=3600",
        ];
        for fragment in expected_fragments {
            assert!(plain.contains(fragment), "missing {fragment:?} in {plain:?}");
        }
        assert!(!plain.contains("Secure"));

        let secure = handler.build_jwt_cookie("test_token_123", 3600, true);
        assert!(secure.contains("Secure"));
    }

    #[test]
    fn test_build_clear_cookie() {
        let cookie = w_token_handler().build_clear_cookie();
        for fragment in ["w_token=", "Max-Age=0", "HttpOnly", "SameSite=Strict", "Path=/"] {
            assert!(cookie.contains(fragment), "missing {fragment:?} in {cookie:?}");
        }
    }

    #[test]
    fn test_jwt_claims_is_expired() {
        let now = unix_now();
        assert!(!claims_with(Some(now + 3600), None, &[]).is_expired());
        assert!(claims_with(Some(now - 3600), None, &[]).is_expired());
        assert!(!claims_with(None, None, &[]).is_expired());
    }

    #[test]
    fn test_decode_jwt_with_configured_secret() {
        let secret = "test_secret_123";
        let handler = AuthHandler::new(AuthConfig {
            enabled: true,
            jwt_secret: Some(secret.to_string()),
            ..Default::default()
        });

        let claims = claims_with(
            Some(unix_now() + 3600),
            Some("user1"),
            &[("email", json!("a@b.com"))],
        );
        let token = sign(&claims, secret.as_bytes());

        let decoded = handler.decode_jwt(&token).unwrap();
        assert_eq!(decoded.sub, Some("user1".to_string()));
        assert_eq!(decoded.custom.get("email"), Some(&json!("a@b.com")));
    }

    #[test]
    fn test_decode_jwt_rejects_wrong_secret() {
        let handler = AuthHandler::new(AuthConfig {
            enabled: true,
            jwt_secret: Some("correct_secret".to_string()),
            ..Default::default()
        });

        let token = sign(
            &claims_with(Some(unix_now() + 3600), None, &[]),
            b"wrong_secret",
        );
        assert!(
            handler.decode_jwt(&token).is_err(),
            "Should reject JWT signed with wrong secret"
        );
    }

    #[test]
    fn test_decode_jwt_no_secret_uses_auto_generated() {
        // With no configured secret the handler verifies against a random
        // per-process key, so a token signed with any known key must fail.
        let handler = AuthHandler::new(AuthConfig {
            enabled: true,
            jwt_secret: None,
            ..Default::default()
        });

        let token = sign(
            &claims_with(Some(unix_now() + 3600), None, &[]),
            b"attacker_secret",
        );
        assert!(
            handler.decode_jwt(&token).is_err(),
            "Should reject JWT when no secret is configured (auto-generated secret won't match)"
        );
    }

    #[test]
    fn test_auth_handler_getters() {
        let handler = AuthHandler::new(AuthConfig {
            enabled: true,
            login_path: "/login".to_string(),
            after_login: "/dashboard".to_string(),
            login_endpoint: Some("https://api.example.com/login".to_string()),
            logout_endpoint: Some("https://api.example.com/logout".to_string()),
            jwt_cookie_name: "auth_token".to_string(),
            jwt_claims: vec!["email".to_string(), "name".to_string()],
            ..Default::default()
        });

        assert!(handler.is_enabled());
        assert_eq!(handler.login_path(), "/login");
        assert_eq!(handler.after_login_path(), "/dashboard");
        assert_eq!(
            handler.login_endpoint(),
            Some("https://api.example.com/login")
        );
        assert_eq!(
            handler.logout_endpoint(),
            Some("https://api.example.com/logout")
        );
        assert_eq!(handler.jwt_cookie_name(), "auth_token");
        assert_eq!(
            handler.jwt_claims(),
            &["email".to_string(), "name".to_string()]
        );
    }
}