1use std::collections::HashMap;
14use std::time::{Duration, SystemTime, UNIX_EPOCH};
15
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18
19pub use crate::types::{TokenInfo, UserInfo};
21
/// Options controlling [`AuthContext::validate`].
///
/// The default validates `exp` and `nbf` with a 60-second leeway and does not check the
/// issuer or audience.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationConfig {
    /// When `Some`, the context's `iss` must be present and equal to this value.
    pub issuer: Option<String>,

    /// When `Some`, the context's `aud` must be present and equal to this value.
    pub audience: Option<String>,

    /// Clock-skew tolerance: added to `exp` before the expiry check, and to "now" before
    /// the not-before check.
    pub leeway: Duration,

    /// Whether to reject contexts whose `exp` (plus `leeway`) is in the past.
    pub validate_exp: bool,

    /// Whether to reject contexts whose `nbf` is later than now plus `leeway`.
    pub validate_nbf: bool,
}

impl Default for ValidationConfig {
    /// No issuer/audience checks; `exp` and `nbf` validated with a 60-second leeway.
    fn default() -> Self {
        Self {
            issuer: None,
            audience: None,
            leeway: Duration::from_secs(60),
            validate_exp: true,
            validate_nbf: true,
        }
    }
}
48
/// An authenticated principal together with the JWT-style claims describing it.
///
/// Converted to and from a `serde_json::Value` by `to_jwt_claims` / `from_jwt_claims`.
/// Optional claims are omitted from the serialized form when `None`; `roles`,
/// `permissions` and `scopes` default to empty when absent on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthContext {
    /// Subject identifier (`sub`). Required by the builder.
    pub sub: String,

    /// Issuer (`iss`); compared against `ValidationConfig::issuer` by `validate`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iss: Option<String>,

    /// Audience (`aud`) as a single string; compared against `ValidationConfig::audience`.
    /// NOTE(review): an array-valued `aud` in incoming claims will fail to deserialize.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aud: Option<String>,

    /// Expiry (`exp`) in seconds since the Unix epoch; checked by `is_expired` and `validate`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exp: Option<u64>,

    /// `iat` claim (presumably issued-at, Unix seconds). Stored only; not interpreted here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iat: Option<u64>,

    /// Not-before (`nbf`) in seconds since the Unix epoch; checked by `validate`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nbf: Option<u64>,

    /// `jti` claim (presumably a unique token id). Stored only; not interpreted here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jti: Option<String>,

    /// Profile of the authenticated user. Required by the builder.
    pub user: UserInfo,

    /// Role names queried by `has_role` / `has_any_role` / `has_all_roles`.
    #[serde(default)]
    pub roles: Vec<String>,

    /// Permission names queried by the `has_*permission*` helpers.
    #[serde(default)]
    pub permissions: Vec<String>,

    /// Scope names queried by the `has_*scope*` helpers.
    #[serde(default)]
    pub scopes: Vec<String>,

    /// Optional request identifier. Stored only; not interpreted here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,

    /// When authentication happened; serialized as whole seconds since the Unix epoch.
    /// The builder defaults it to the time `build()` is called.
    #[serde(with = "systemtime_serde")]
    pub authenticated_at: SystemTime,

    /// Optional expiry as a `SystemTime` (serialized as Unix seconds). Checked by
    /// `is_expired` but not by `validate`.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "systemtime_serde_opt"
    )]
    pub expires_at: Option<SystemTime>,

    /// Optional token details. Stored only; not interpreted here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token: Option<TokenInfo>,

    /// Name of the authentication provider (e.g. "oauth2:test"). Required by the builder.
    pub provider: String,

    /// Thumbprint bound to this context; compared with `DpopProof::jkt` by
    /// `validate_dpop_proof` (presumably the RFC 9449 DPoP `jkt` — confirm).
    #[cfg(feature = "dpop")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dpop_jkt: Option<String>,

    /// Extra claims. Flattened: on input, top-level keys that match no named field are
    /// collected here. On output, entries are merged into the top-level object.
    #[serde(flatten)]
    pub metadata: HashMap<String, Value>,
}
187
188mod systemtime_serde {
193 use super::*;
194 use serde::{Deserializer, Serializer};
195
196 pub fn serialize<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
197 where
198 S: Serializer,
199 {
200 let since_epoch = time
201 .duration_since(UNIX_EPOCH)
202 .map_err(serde::ser::Error::custom)?;
203 serializer.serialize_u64(since_epoch.as_secs())
204 }
205
206 pub fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
207 where
208 D: Deserializer<'de>,
209 {
210 let secs = u64::deserialize(deserializer)?;
211 Ok(UNIX_EPOCH + Duration::from_secs(secs))
212 }
213}
214
215mod systemtime_serde_opt {
216 use super::*;
217 use serde::{Deserializer, Serializer};
218
219 pub fn serialize<S>(time: &Option<SystemTime>, serializer: S) -> Result<S::Ok, S::Error>
220 where
221 S: Serializer,
222 {
223 match time {
224 Some(t) => {
225 let since_epoch = t
226 .duration_since(UNIX_EPOCH)
227 .map_err(serde::ser::Error::custom)?;
228 serializer.serialize_some(&since_epoch.as_secs())
229 }
230 None => serializer.serialize_none(),
231 }
232 }
233
234 pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<SystemTime>, D::Error>
235 where
236 D: Deserializer<'de>,
237 {
238 let opt: Option<u64> = Option::deserialize(deserializer)?;
239 Ok(opt.map(|secs| UNIX_EPOCH + Duration::from_secs(secs)))
240 }
241}
242
243impl AuthContext {
248 pub fn builder() -> AuthContextBuilder {
250 AuthContextBuilder::default()
251 }
252
253 pub fn to_jwt_claims(&self) -> Value {
269 serde_json::to_value(self).expect("AuthContext serialization should never fail")
270 }
271
272 pub fn from_jwt_claims(claims: Value) -> Result<Self, AuthError> {
283 serde_json::from_value(claims).map_err(|e| AuthError::InvalidClaims(e.to_string()))
284 }
285
286 pub fn is_expired(&self) -> bool {
294 if let Some(expires_at) = self.expires_at
296 && SystemTime::now() > expires_at
297 {
298 return true;
299 }
300
301 if let Some(exp) = self.exp {
303 let exp_time = UNIX_EPOCH + Duration::from_secs(exp);
304 if SystemTime::now() > exp_time {
305 return true;
306 }
307 }
308
309 false
310 }
311
312 pub fn validate(&self, config: &ValidationConfig) -> Result<(), AuthError> {
324 let now = SystemTime::now();
325
326 if config.validate_exp
328 && let Some(exp) = self.exp
329 {
330 let exp_time = UNIX_EPOCH + Duration::from_secs(exp);
331 let exp_with_leeway = exp_time + config.leeway;
332 if now > exp_with_leeway {
333 return Err(AuthError::TokenExpired);
334 }
335 }
336
337 if config.validate_nbf
339 && let Some(nbf) = self.nbf
340 {
341 let nbf_time = UNIX_EPOCH + Duration::from_secs(nbf);
342 if nbf_time > now + config.leeway {
343 return Err(AuthError::TokenNotYetValid);
344 }
345 }
346
347 if let Some(ref expected_aud) = config.audience {
349 match &self.aud {
350 Some(aud) if aud == expected_aud => {}
351 _ => return Err(AuthError::InvalidAudience),
352 }
353 }
354
355 if let Some(ref expected_iss) = config.issuer {
357 match &self.iss {
358 Some(iss) if iss == expected_iss => {}
359 _ => return Err(AuthError::InvalidIssuer),
360 }
361 }
362
363 Ok(())
364 }
365
366 pub fn has_role(&self, role: &str) -> bool {
372 self.roles.iter().any(|r| r == role)
373 }
374
375 pub fn has_any_role(&self, roles: &[&str]) -> bool {
377 roles.iter().any(|r| self.has_role(r))
378 }
379
380 pub fn has_all_roles(&self, roles: &[&str]) -> bool {
382 roles.iter().all(|r| self.has_role(r))
383 }
384
385 pub fn has_permission(&self, perm: &str) -> bool {
387 self.permissions.iter().any(|p| p == perm)
388 }
389
390 pub fn has_any_permission(&self, perms: &[&str]) -> bool {
392 perms.iter().any(|p| self.has_permission(p))
393 }
394
395 pub fn has_all_permissions(&self, perms: &[&str]) -> bool {
397 perms.iter().all(|p| self.has_permission(p))
398 }
399
400 pub fn has_scope(&self, scope: &str) -> bool {
402 self.scopes.iter().any(|s| s == scope)
403 }
404
405 pub fn has_any_scope(&self, scopes: &[&str]) -> bool {
407 scopes.iter().any(|s| self.has_scope(s))
408 }
409
410 pub fn has_all_scopes(&self, scopes: &[&str]) -> bool {
412 scopes.iter().all(|s| self.has_scope(s))
413 }
414
415 pub fn get_metadata<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
427 self.metadata
428 .get(key)
429 .and_then(|v| serde_json::from_value(v.clone()).ok())
430 }
431
432 #[cfg(feature = "dpop")]
437 pub fn validate_dpop_proof(&self, proof: &DpopProof) -> Result<(), AuthError> {
441 match &self.dpop_jkt {
442 Some(jkt) if jkt == &proof.jkt => Ok(()),
443 Some(_) => Err(AuthError::DpopMismatch),
444 None => Err(AuthError::DpopRequired),
445 }
446 }
447}
448
449#[derive(Default)]
471pub struct AuthContextBuilder {
472 sub: Option<String>,
473 iss: Option<String>,
474 aud: Option<String>,
475 exp: Option<u64>,
476 iat: Option<u64>,
477 nbf: Option<u64>,
478 jti: Option<String>,
479 user: Option<UserInfo>,
480 roles: Vec<String>,
481 permissions: Vec<String>,
482 scopes: Vec<String>,
483 request_id: Option<String>,
484 authenticated_at: Option<SystemTime>,
485 expires_at: Option<SystemTime>,
486 token: Option<TokenInfo>,
487 provider: Option<String>,
488 #[cfg(feature = "dpop")]
489 dpop_jkt: Option<String>,
490 metadata: HashMap<String, Value>,
491}
492
493impl AuthContextBuilder {
494 pub fn subject(mut self, sub: impl Into<String>) -> Self {
496 self.sub = Some(sub.into());
497 self
498 }
499
500 #[deprecated(
502 since = "2.0.5",
503 note = "Use `.subject()` instead to avoid confusion with std::ops::Sub trait"
504 )]
505 #[allow(clippy::should_implement_trait)]
506 pub fn sub(self, sub: impl Into<String>) -> Self {
507 self.subject(sub)
508 }
509
510 pub fn iss(mut self, iss: impl Into<String>) -> Self {
512 self.iss = Some(iss.into());
513 self
514 }
515
516 pub fn aud(mut self, aud: impl Into<String>) -> Self {
518 self.aud = Some(aud.into());
519 self
520 }
521
522 pub fn exp(mut self, exp: u64) -> Self {
524 self.exp = Some(exp);
525 self
526 }
527
528 pub fn iat(mut self, iat: u64) -> Self {
530 self.iat = Some(iat);
531 self
532 }
533
534 pub fn nbf(mut self, nbf: u64) -> Self {
536 self.nbf = Some(nbf);
537 self
538 }
539
540 pub fn jti(mut self, jti: impl Into<String>) -> Self {
542 self.jti = Some(jti.into());
543 self
544 }
545
546 pub fn user(mut self, user: UserInfo) -> Self {
548 self.user = Some(user);
549 self
550 }
551
552 pub fn roles(mut self, roles: Vec<String>) -> Self {
554 self.roles = roles;
555 self
556 }
557
558 pub fn role(mut self, role: impl Into<String>) -> Self {
560 self.roles.push(role.into());
561 self
562 }
563
564 pub fn permissions(mut self, permissions: Vec<String>) -> Self {
566 self.permissions = permissions;
567 self
568 }
569
570 pub fn permission(mut self, permission: impl Into<String>) -> Self {
572 self.permissions.push(permission.into());
573 self
574 }
575
576 pub fn scopes(mut self, scopes: Vec<String>) -> Self {
578 self.scopes = scopes;
579 self
580 }
581
582 pub fn scope(mut self, scope: impl Into<String>) -> Self {
584 self.scopes.push(scope.into());
585 self
586 }
587
588 pub fn request_id(mut self, request_id: impl Into<String>) -> Self {
594 self.request_id = Some(request_id.into());
595 self
596 }
597
598 pub fn authenticated_at(mut self, authenticated_at: SystemTime) -> Self {
600 self.authenticated_at = Some(authenticated_at);
601 self
602 }
603
604 pub fn expires_at(mut self, expires_at: SystemTime) -> Self {
606 self.expires_at = Some(expires_at);
607 self
608 }
609
610 pub fn token(mut self, token: TokenInfo) -> Self {
612 self.token = Some(token);
613 self
614 }
615
616 pub fn provider(mut self, provider: impl Into<String>) -> Self {
618 self.provider = Some(provider.into());
619 self
620 }
621
622 #[cfg(feature = "dpop")]
624 pub fn dpop_jkt(mut self, jkt: impl Into<String>) -> Self {
625 self.dpop_jkt = Some(jkt.into());
626 self
627 }
628
629 pub fn metadata(mut self, key: impl Into<String>, value: Value) -> Self {
631 self.metadata.insert(key.into(), value);
632 self
633 }
634
635 pub fn build(self) -> Result<AuthContext, AuthError> {
644 let sub = self.sub.ok_or(AuthError::MissingField("sub"))?;
645 let user = self.user.ok_or(AuthError::MissingField("user"))?;
646 let provider = self.provider.ok_or(AuthError::MissingField("provider"))?;
647 let authenticated_at = self.authenticated_at.unwrap_or_else(SystemTime::now);
648
649 Ok(AuthContext {
650 sub,
651 iss: self.iss,
652 aud: self.aud,
653 exp: self.exp,
654 iat: self.iat,
655 nbf: self.nbf,
656 jti: self.jti,
657 user,
658 roles: self.roles,
659 permissions: self.permissions,
660 scopes: self.scopes,
661 request_id: self.request_id,
662 authenticated_at,
663 expires_at: self.expires_at,
664 token: self.token,
665 provider,
666 #[cfg(feature = "dpop")]
667 dpop_jkt: self.dpop_jkt,
668 metadata: self.metadata,
669 })
670 }
671}
672
673#[derive(Debug, thiserror::Error)]
679pub enum AuthError {
680 #[error("Invalid claims: {0}")]
681 InvalidClaims(String),
682
683 #[error("Token expired")]
684 TokenExpired,
685
686 #[error("Token not yet valid")]
687 TokenNotYetValid,
688
689 #[error("Invalid audience")]
690 InvalidAudience,
691
692 #[error("Invalid issuer")]
693 InvalidIssuer,
694
695 #[error("Missing required field: {0}")]
696 MissingField(&'static str),
697
698 #[cfg(feature = "dpop")]
699 #[error("DPoP proof mismatch")]
700 DpopMismatch,
701
702 #[cfg(feature = "dpop")]
703 #[error("DPoP proof required but not provided")]
704 DpopRequired,
705}
706
/// DPoP proof data checked by `AuthContext::validate_dpop_proof`.
#[cfg(feature = "dpop")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DpopProof {
    /// Thumbprint compared for equality against `AuthContext::dpop_jkt`.
    pub jkt: String,
}
717
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture user shared by every test.
    fn create_test_user() -> UserInfo {
        UserInfo {
            id: "user123".to_string(),
            username: "testuser".to_string(),
            email: Some("test@example.com".to_string()),
            display_name: Some("Test User".to_string()),
            avatar_url: None,
            metadata: HashMap::new(),
        }
    }

    /// A builder with the three fields `build()` requires already filled in.
    fn base_builder() -> AuthContextBuilder {
        AuthContext::builder()
            .subject("user123")
            .user(create_test_user())
            .provider("test")
    }

    /// Context holding exactly the roles `admin` and `user`.
    fn admin_and_user() -> AuthContext {
        base_builder()
            .roles(vec!["admin".to_string(), "user".to_string()])
            .build()
            .unwrap()
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn test_builder_minimal() {
        let ctx = base_builder().build().unwrap();

        assert_eq!(ctx.sub, "user123");
        assert_eq!(ctx.provider, "test");
        assert!(ctx.roles.is_empty());
        assert!(ctx.permissions.is_empty());
    }

    #[test]
    fn test_builder_full() {
        let ctx = base_builder()
            .iss("test-issuer")
            .aud("test-audience")
            .roles(vec!["admin".to_string(), "user".to_string()])
            .permissions(vec!["read:posts".to_string()])
            .scopes(vec!["openid".to_string(), "email".to_string()])
            .provider("oauth2:test")
            .build()
            .unwrap();

        assert_eq!(ctx.sub, "user123");
        assert_eq!(ctx.iss.as_deref(), Some("test-issuer"));
        assert_eq!(ctx.aud.as_deref(), Some("test-audience"));
        assert_eq!(ctx.roles.len(), 2);
        assert_eq!(ctx.permissions.len(), 1);
        assert_eq!(ctx.scopes.len(), 2);
    }

    #[test]
    fn test_is_expired() {
        let hour = Duration::from_secs(3600);

        let fresh = base_builder()
            .expires_at(SystemTime::now() + hour)
            .build()
            .unwrap();
        assert!(!fresh.is_expired());

        let stale = base_builder()
            .expires_at(SystemTime::now() - hour)
            .build()
            .unwrap();
        assert!(stale.is_expired());
    }

    #[test]
    fn test_has_role() {
        let ctx = admin_and_user();

        assert!(ctx.has_role("admin"));
        assert!(ctx.has_role("user"));
        assert!(!ctx.has_role("superuser"));
    }

    #[test]
    fn test_has_any_role() {
        let ctx = admin_and_user();

        assert!(ctx.has_any_role(&["admin", "superuser"]));
        assert!(ctx.has_any_role(&["user", "guest"]));
        assert!(!ctx.has_any_role(&["superuser", "guest"]));
    }

    #[test]
    fn test_has_all_roles() {
        let ctx = admin_and_user();

        assert!(ctx.has_all_roles(&["admin", "user"]));
        assert!(ctx.has_all_roles(&["admin"]));
        assert!(!ctx.has_all_roles(&["admin", "user", "superuser"]));
    }

    #[test]
    fn test_has_permission() {
        let ctx = base_builder()
            .permissions(vec!["read:posts".to_string(), "write:posts".to_string()])
            .build()
            .unwrap();

        for granted in ["read:posts", "write:posts"] {
            assert!(ctx.has_permission(granted));
        }
        assert!(!ctx.has_permission("delete:posts"));
    }

    #[test]
    fn test_has_scope() {
        let ctx = base_builder()
            .scopes(vec!["openid".to_string(), "email".to_string()])
            .build()
            .unwrap();

        for granted in ["openid", "email"] {
            assert!(ctx.has_scope(granted));
        }
        assert!(!ctx.has_scope("profile"));
    }

    #[test]
    fn test_jwt_serialization() {
        let original = base_builder()
            .iss("test-issuer")
            .roles(vec!["admin".to_string()])
            .build()
            .unwrap();

        let claims = original.to_jwt_claims();
        assert!(claims.is_object());

        let restored = AuthContext::from_jwt_claims(claims).unwrap();
        assert_eq!(restored.sub, original.sub);
        assert_eq!(restored.iss, original.iss);
        assert_eq!(restored.roles, original.roles);
    }

    #[test]
    fn test_validation_expired() {
        let ctx = base_builder().exp(unix_now() - 3600).build().unwrap();

        let outcome = ctx.validate(&ValidationConfig::default());
        assert!(matches!(outcome, Err(AuthError::TokenExpired)));
    }

    #[test]
    fn test_validation_not_yet_valid() {
        let ctx = base_builder().nbf(unix_now() + 3600).build().unwrap();

        let outcome = ctx.validate(&ValidationConfig::default());
        assert!(matches!(outcome, Err(AuthError::TokenNotYetValid)));
    }

    #[test]
    fn test_validation_audience() {
        let ctx = base_builder().aud("wrong-audience").build().unwrap();
        let config = ValidationConfig {
            audience: Some("expected-audience".to_string()),
            ..Default::default()
        };

        let outcome = ctx.validate(&config);
        assert!(matches!(outcome, Err(AuthError::InvalidAudience)));
    }

    #[test]
    fn test_metadata() {
        let ctx = base_builder()
            .metadata("tenant_id", Value::String("tenant123".to_string()))
            .metadata("org_id", Value::Number(42.into()))
            .build()
            .unwrap();

        assert_eq!(
            ctx.get_metadata::<String>("tenant_id"),
            Some("tenant123".to_string())
        );
        assert_eq!(ctx.get_metadata::<i64>("org_id"), Some(42));
        assert_eq!(ctx.get_metadata::<String>("missing"), None);
    }
}