1use crate::AuthBackend;
4use async_trait::async_trait;
5use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
6use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
7use rusmes_proto::Username;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12use tokio::sync::RwLock;
13
/// OpenID Connect identity provider whose keys and endpoints the backend uses.
///
/// `Debug` is implemented by hand so that `client_secret` values are redacted
/// rather than written to logs (any struct embedding this type and deriving
/// `Debug` would otherwise leak them).
#[derive(Clone)]
pub enum OidcProvider {
    /// Google accounts; JWKS and token endpoints are Google's fixed URLs.
    Google {
        /// OAuth2 client ID, also used as the expected JWT audience.
        client_id: String,
        /// OAuth2 client secret, sent to the introspection/token endpoints.
        client_secret: String,
    },
    /// Microsoft identity platform, scoped to a single tenant.
    Microsoft {
        /// Tenant ID interpolated into the `login.microsoftonline.com` URLs.
        tenant_id: String,
        /// OAuth2 client ID, also used as the expected JWT audience.
        client_id: String,
        /// OAuth2 client secret, sent to the introspection/token endpoints.
        client_secret: String,
    },
    /// Any other provider, configured with explicit URLs.
    Generic {
        /// Issuer base URL; the refresh-token endpoint is derived from it.
        issuer_url: String,
        /// OAuth2 client ID, also used as the expected JWT audience.
        client_id: String,
        /// OAuth2 client secret, sent to the introspection/token endpoints.
        client_secret: String,
        /// URL of the provider's JSON Web Key Set.
        jwks_url: String,
    },
}

impl std::fmt::Debug for OidcProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed marker printed in place of every client secret.
        const REDACTED: &str = "<redacted>";
        match self {
            OidcProvider::Google { client_id, .. } => f
                .debug_struct("Google")
                .field("client_id", client_id)
                .field("client_secret", &REDACTED)
                .finish(),
            OidcProvider::Microsoft {
                tenant_id,
                client_id,
                ..
            } => f
                .debug_struct("Microsoft")
                .field("tenant_id", tenant_id)
                .field("client_id", client_id)
                .field("client_secret", &REDACTED)
                .finish(),
            OidcProvider::Generic {
                issuer_url,
                client_id,
                jwks_url,
                ..
            } => f
                .debug_struct("Generic")
                .field("issuer_url", issuer_url)
                .field("client_id", client_id)
                .field("client_secret", &REDACTED)
                .field("jwks_url", jwks_url)
                .finish(),
        }
    }
}
36
37#[derive(Debug, Serialize, Deserialize)]
39struct Claims {
40 sub: String,
41 email: Option<String>,
42 exp: u64,
43 iat: u64,
44 iss: String,
45 aud: String,
46}
47
/// Response body of an OAuth2 token introspection request, as parsed by
/// `introspect_token`.
///
/// Only the fields this backend reads are modelled; other JSON fields are
/// ignored by serde's derived deserializer.
#[derive(Debug, Deserialize)]
// `exp` is deserialized but not read anywhere in this file.
#[allow(dead_code)]
struct IntrospectionResponse {
    /// Whether the provider considers the token currently active; inactive
    /// tokens are rejected by `xoauth2_authenticate`.
    active: bool,
    /// Username reported by the provider; used when `email` is absent.
    #[serde(default)]
    username: Option<String>,
    /// Email reported by the provider; preferred over `username`.
    #[serde(default)]
    email: Option<String>,
    /// Token expiry reported by the provider.
    /// NOTE(review): presumably seconds since the Unix epoch — confirm.
    #[serde(default)]
    exp: Option<u64>,
}
60
/// JSON Web Key Set document fetched from the provider's JWKS URL.
#[derive(Debug, Clone, Deserialize)]
struct Jwks {
    /// Published keys; `validate_jwt` selects one by matching `kid`.
    keys: Vec<Jwk>,
}
66
67#[derive(Debug, Clone, Deserialize)]
69#[allow(dead_code)]
70struct Jwk {
71 kid: String,
72 kty: String,
73 #[serde(rename = "use")]
74 key_use: Option<String>,
75 alg: Option<String>,
76 n: Option<String>,
77 e: Option<String>,
78}
79
/// Record of a successful authentication, stored in
/// `OAuth2Backend::token_cache` keyed by username.
#[derive(Debug, Clone)]
// `username` is written but only read by tests in this file.
#[allow(dead_code)]
struct TokenCacheEntry {
    /// Identity the validated token resolved to.
    username: String,
    /// Wall-clock time after which the entry is considered expired.
    expires_at: SystemTime,
}
87
/// Configuration for [`OAuth2Backend`].
#[derive(Debug, Clone)]
pub struct OAuth2Config {
    /// Identity provider whose keys and endpoints are used.
    pub provider: OidcProvider,
    /// Token introspection endpoint. When `None`, only tokens that validate
    /// locally as JWTs can be accepted.
    pub introspection_endpoint: Option<String>,
    /// How long a fetched JWKS is reused before re-downloading, in seconds.
    pub jwks_cache_ttl: u64,
    /// Whether `refresh_token` is allowed to exchange refresh tokens.
    pub enable_refresh_tokens: bool,
    /// JWT signature algorithms accepted during validation.
    pub allowed_algorithms: Vec<Algorithm>,
}
102
103impl Default for OAuth2Config {
104 fn default() -> Self {
105 Self {
106 provider: OidcProvider::Generic {
107 issuer_url: "https://example.com".to_string(),
108 client_id: "client-id".to_string(),
109 client_secret: "client-secret".to_string(),
110 jwks_url: "https://example.com/.well-known/jwks.json".to_string(),
111 },
112 introspection_endpoint: None,
113 jwks_cache_ttl: 3600,
114 enable_refresh_tokens: true,
115 allowed_algorithms: vec![Algorithm::RS256],
116 }
117 }
118}
119
/// [`AuthBackend`] that treats the supplied password as an OAuth2 bearer
/// token and validates it as a JWT or through token introspection.
pub struct OAuth2Backend {
    /// Provider and validation settings.
    config: OAuth2Config,
    /// Successful authentications keyed by username; each entry carries its
    /// own expiry time.
    token_cache: Arc<RwLock<HashMap<String, TokenCacheEntry>>>,
    /// Most recently fetched JWKS together with the time it was fetched.
    jwks_cache: Arc<RwLock<Option<(Jwks, SystemTime)>>>,
    /// HTTP client shared by JWKS, introspection and token requests.
    client: reqwest::Client,
}
127
128impl OAuth2Backend {
129 pub fn new(config: OAuth2Config) -> Self {
131 Self {
132 config,
133 token_cache: Arc::new(RwLock::new(HashMap::new())),
134 jwks_cache: Arc::new(RwLock::new(None)),
135 client: reqwest::Client::new(),
136 }
137 }
138
139 pub fn parse_xoauth2_response(response: &str) -> anyhow::Result<(String, String)> {
143 let decoded = BASE64
145 .decode(response.as_bytes())
146 .map_err(|e| anyhow::anyhow!("Failed to decode XOAUTH2 response: {}", e))?;
147
148 let decoded_str = String::from_utf8(decoded)
149 .map_err(|e| anyhow::anyhow!("Invalid UTF-8 in XOAUTH2 response: {}", e))?;
150
151 let parts: Vec<&str> = decoded_str.split('\x01').collect();
153
154 let mut username = None;
156 let mut token = None;
157
158 for part in &parts {
159 if part.starts_with("user=") {
160 username = part.strip_prefix("user=").map(|s| s.to_string());
161 } else if part.starts_with("auth=Bearer ") {
162 token = part.strip_prefix("auth=Bearer ").map(|s| s.to_string());
163 }
164 }
165
166 let username = username.ok_or_else(|| anyhow::anyhow!("Missing username in XOAUTH2"))?;
167 let token = token.ok_or_else(|| anyhow::anyhow!("Missing token in XOAUTH2"))?;
168
169 Ok((username, token))
170 }
171
172 #[allow(dead_code)]
176 pub fn encode_xoauth2_response(username: &str, token: &str) -> String {
177 let response = format!("user={}\x01auth=Bearer {}\x01\x01", username, token);
178 BASE64.encode(response.as_bytes())
179 }
180
181 pub async fn cleanup_expired_tokens(&self) {
183 let mut cache = self.token_cache.write().await;
184 let now = SystemTime::now();
185 cache.retain(|_, entry| entry.expires_at > now);
186 }
187
188 #[allow(dead_code)]
190 pub async fn token_cache_size(&self) -> usize {
191 let cache = self.token_cache.read().await;
192 cache.len()
193 }
194
195 #[allow(dead_code)]
197 pub async fn invalidate_token(&self, username: &str) {
198 let mut cache = self.token_cache.write().await;
199 cache.remove(username);
200 }
201
202 #[allow(dead_code)]
204 pub async fn clear_jwks_cache(&self) {
205 let mut cache = self.jwks_cache.write().await;
206 *cache = None;
207 }
208
209 async fn get_jwks(&self) -> anyhow::Result<Jwks> {
211 {
213 let cache = self.jwks_cache.read().await;
214 if let Some((jwks, cached_at)) = &*cache {
215 if cached_at.elapsed().unwrap_or(Duration::MAX).as_secs()
216 < self.config.jwks_cache_ttl
217 {
218 return Ok(jwks.clone());
219 }
220 }
221 }
222
223 let jwks_url = match &self.config.provider {
225 OidcProvider::Google { .. } => "https://www.googleapis.com/oauth2/v3/certs",
226 OidcProvider::Microsoft { tenant_id, .. } => &format!(
227 "https://login.microsoftonline.com/{}/discovery/v2.0/keys",
228 tenant_id
229 ),
230 OidcProvider::Generic { jwks_url, .. } => jwks_url.as_str(),
231 };
232
233 let jwks: Jwks = self.client.get(jwks_url).send().await?.json().await?;
234
235 {
237 let mut cache = self.jwks_cache.write().await;
238 *cache = Some((jwks.clone(), SystemTime::now()));
239 }
240
241 Ok(jwks)
242 }
243
244 async fn validate_jwt(&self, token: &str) -> anyhow::Result<Claims> {
246 let header = decode_header(token)?;
248 let kid = header
249 .kid
250 .ok_or_else(|| anyhow::anyhow!("No kid in JWT header"))?;
251
252 let jwks = self.get_jwks().await?;
254
255 let jwk = jwks
257 .keys
258 .iter()
259 .find(|k| k.kid == kid)
260 .ok_or_else(|| anyhow::anyhow!("No matching key found in JWKS"))?;
261
262 let n = jwk
264 .n
265 .as_ref()
266 .ok_or_else(|| anyhow::anyhow!("Missing n in JWK"))?;
267 let e = jwk
268 .e
269 .as_ref()
270 .ok_or_else(|| anyhow::anyhow!("Missing e in JWK"))?;
271
272 let n_bytes = BASE64.decode(n)?;
273 let e_bytes = BASE64.decode(e)?;
274
275 let decoding_key = DecodingKey::from_rsa_raw_components(&n_bytes, &e_bytes);
278
279 let mut validation = Validation::new(Algorithm::RS256);
281 validation.algorithms = self.config.allowed_algorithms.clone();
282
283 let expected_aud = match &self.config.provider {
285 OidcProvider::Google { client_id, .. } => client_id.clone(),
286 OidcProvider::Microsoft { client_id, .. } => client_id.clone(),
287 OidcProvider::Generic { client_id, .. } => client_id.clone(),
288 };
289 validation.set_audience(&[&expected_aud]);
290
291 let token_data = decode::<Claims>(token, &decoding_key, &validation)?;
292
293 Ok(token_data.claims)
294 }
295
296 async fn introspect_token(&self, token: &str) -> anyhow::Result<IntrospectionResponse> {
298 let endpoint = self
299 .config
300 .introspection_endpoint
301 .as_ref()
302 .ok_or_else(|| anyhow::anyhow!("Token introspection endpoint not configured"))?;
303
304 let (client_id, client_secret) = match &self.config.provider {
305 OidcProvider::Google {
306 client_id,
307 client_secret,
308 } => (client_id, client_secret),
309 OidcProvider::Microsoft {
310 client_id,
311 client_secret,
312 ..
313 } => (client_id, client_secret),
314 OidcProvider::Generic {
315 client_id,
316 client_secret,
317 ..
318 } => (client_id, client_secret),
319 };
320
321 let mut params = HashMap::new();
322 params.insert("token", token);
323 params.insert("client_id", client_id);
324 params.insert("client_secret", client_secret);
325
326 let response = self
327 .client
328 .post(endpoint)
329 .form(¶ms)
330 .send()
331 .await?
332 .json::<IntrospectionResponse>()
333 .await?;
334
335 Ok(response)
336 }
337
338 async fn xoauth2_authenticate(&self, token: &str) -> anyhow::Result<String> {
340 if let Ok(claims) = self.validate_jwt(token).await {
342 return Ok(claims.email.or(Some(claims.sub)).unwrap_or_default());
343 }
344
345 let introspection = self.introspect_token(token).await?;
347
348 if !introspection.active {
349 return Err(anyhow::anyhow!("Token is not active"));
350 }
351
352 introspection
353 .email
354 .or(introspection.username)
355 .ok_or_else(|| anyhow::anyhow!("No username in token"))
356 }
357
358 #[allow(dead_code)]
360 async fn refresh_token(&self, refresh_token: &str) -> anyhow::Result<String> {
361 if !self.config.enable_refresh_tokens {
362 return Err(anyhow::anyhow!("Refresh tokens not enabled"));
363 }
364
365 let token_endpoint = match &self.config.provider {
366 OidcProvider::Google { .. } => "https://oauth2.googleapis.com/token",
367 OidcProvider::Microsoft { tenant_id, .. } => &format!(
368 "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
369 tenant_id
370 ),
371 OidcProvider::Generic { issuer_url, .. } => &format!("{}/token", issuer_url),
372 };
373
374 let (client_id, client_secret) = match &self.config.provider {
375 OidcProvider::Google {
376 client_id,
377 client_secret,
378 } => (client_id, client_secret),
379 OidcProvider::Microsoft {
380 client_id,
381 client_secret,
382 ..
383 } => (client_id, client_secret),
384 OidcProvider::Generic {
385 client_id,
386 client_secret,
387 ..
388 } => (client_id, client_secret),
389 };
390
391 let mut params = HashMap::new();
392 params.insert("grant_type", "refresh_token");
393 params.insert("refresh_token", refresh_token);
394 params.insert("client_id", client_id);
395 params.insert("client_secret", client_secret);
396
397 #[derive(Deserialize)]
398 struct TokenResponse {
399 access_token: String,
400 }
401
402 let response = self
403 .client
404 .post(token_endpoint)
405 .form(¶ms)
406 .send()
407 .await?
408 .json::<TokenResponse>()
409 .await?;
410
411 Ok(response.access_token)
412 }
413}
414
415#[async_trait]
416impl AuthBackend for OAuth2Backend {
417 async fn authenticate(&self, username: &Username, password: &str) -> anyhow::Result<bool> {
418 let token = password;
420
421 {
423 let cache = self.token_cache.read().await;
424 if let Some(entry) = cache.get(&username.to_string()) {
425 if SystemTime::now() < entry.expires_at {
426 return Ok(true);
427 }
428 }
429 }
430
431 match self.xoauth2_authenticate(token).await {
433 Ok(token_username) => {
434 if token_username == username.to_string() {
435 let mut cache = self.token_cache.write().await;
437 cache.insert(
438 username.to_string(),
439 TokenCacheEntry {
440 username: token_username,
441 expires_at: SystemTime::now() + Duration::from_secs(300),
442 },
443 );
444 Ok(true)
445 } else {
446 Ok(false)
447 }
448 }
449 Err(_) => Ok(false),
450 }
451 }
452
453 async fn verify_identity(&self, username: &Username) -> anyhow::Result<bool> {
454 let cache = self.token_cache.read().await;
456 Ok(cache.contains_key(&username.to_string()))
457 }
458
459 async fn list_users(&self) -> anyhow::Result<Vec<Username>> {
460 let cache = self.token_cache.read().await;
462 Ok(cache
463 .keys()
464 .filter_map(|k| Username::new(k.clone()).ok())
465 .collect())
466 }
467
468 async fn create_user(&self, _username: &Username, _password: &str) -> anyhow::Result<()> {
469 Err(anyhow::anyhow!(
470 "OAuth2 backend does not support user creation (external provider)"
471 ))
472 }
473
474 async fn delete_user(&self, _username: &Username) -> anyhow::Result<()> {
475 Err(anyhow::anyhow!(
476 "OAuth2 backend does not support user deletion (external provider)"
477 ))
478 }
479
480 async fn change_password(
481 &self,
482 _username: &Username,
483 _new_password: &str,
484 ) -> anyhow::Result<()> {
485 Err(anyhow::anyhow!(
486 "OAuth2 backend does not support password changes (external provider)"
487 ))
488 }
489
490 async fn verify_bearer_token(&self, token: &str) -> anyhow::Result<Username> {
496 let raw = self.xoauth2_authenticate(token).await?;
497 let username = Username::new(raw)
498 .map_err(|e| anyhow::anyhow!("Bearer token resolved to invalid username: {}", e))?;
499 Ok(username)
500 }
501}
502
503#[cfg(test)]
504mod tests {
505 use super::*;
506
507 #[test]
512 fn test_oauth2_config_default() {
513 let config = OAuth2Config::default();
514 assert_eq!(config.jwks_cache_ttl, 3600);
515 assert!(config.enable_refresh_tokens);
516 assert_eq!(config.allowed_algorithms.len(), 1);
517 }
518
519 #[test]
520 fn test_oauth2_config_google() {
521 let config = OAuth2Config {
522 provider: OidcProvider::Google {
523 client_id: "test-client-id".to_string(),
524 client_secret: "test-secret".to_string(),
525 },
526 ..Default::default()
527 };
528 assert!(matches!(config.provider, OidcProvider::Google { .. }));
529 }
530
531 #[test]
532 fn test_oauth2_config_microsoft() {
533 let config = OAuth2Config {
534 provider: OidcProvider::Microsoft {
535 tenant_id: "test-tenant".to_string(),
536 client_id: "test-client".to_string(),
537 client_secret: "test-secret".to_string(),
538 },
539 ..Default::default()
540 };
541 assert!(matches!(config.provider, OidcProvider::Microsoft { .. }));
542 }
543
544 #[test]
545 fn test_oauth2_config_generic() {
546 let config = OAuth2Config {
547 provider: OidcProvider::Generic {
548 issuer_url: "https://oidc.example.com".to_string(),
549 client_id: "client".to_string(),
550 client_secret: "secret".to_string(),
551 jwks_url: "https://oidc.example.com/jwks".to_string(),
552 },
553 ..Default::default()
554 };
555 assert!(matches!(config.provider, OidcProvider::Generic { .. }));
556 }
557
558 #[test]
559 fn test_allowed_algorithms() {
560 let config = OAuth2Config {
561 allowed_algorithms: vec![Algorithm::RS256, Algorithm::RS384, Algorithm::RS512],
562 ..Default::default()
563 };
564 assert_eq!(config.allowed_algorithms.len(), 3);
565 }
566
567 #[test]
568 fn test_introspection_endpoint_optional() {
569 let config = OAuth2Config::default();
570 assert!(config.introspection_endpoint.is_none());
571
572 let config_with_introspection = OAuth2Config {
573 introspection_endpoint: Some("https://example.com/introspect".to_string()),
574 ..Default::default()
575 };
576 assert!(config_with_introspection.introspection_endpoint.is_some());
577 }
578
579 #[test]
580 fn test_refresh_tokens_enabled() {
581 let config = OAuth2Config {
582 enable_refresh_tokens: true,
583 ..Default::default()
584 };
585 assert!(config.enable_refresh_tokens);
586
587 let config_disabled = OAuth2Config {
588 enable_refresh_tokens: false,
589 ..Default::default()
590 };
591 assert!(!config_disabled.enable_refresh_tokens);
592 }
593
594 #[test]
595 fn test_jwks_cache_ttl() {
596 let config = OAuth2Config {
597 jwks_cache_ttl: 7200,
598 ..Default::default()
599 };
600 assert_eq!(config.jwks_cache_ttl, 7200);
601 }
602
603 #[test]
604 fn test_config_clone() {
605 let config = OAuth2Config::default();
606 let cloned = config.clone();
607 assert_eq!(config.jwks_cache_ttl, cloned.jwks_cache_ttl);
608 }
609
610 #[tokio::test]
615 async fn test_oauth2_backend_creation() {
616 let config = OAuth2Config::default();
617 let backend = OAuth2Backend::new(config);
618 let cache = backend.token_cache.read().await;
619 assert_eq!(cache.len(), 0);
620 }
621
622 #[tokio::test]
623 async fn test_token_cache_empty_on_creation() {
624 let backend = OAuth2Backend::new(OAuth2Config::default());
625 let cache = backend.token_cache.read().await;
626 assert!(cache.is_empty());
627 }
628
629 #[tokio::test]
630 async fn test_jwks_cache_empty_on_creation() {
631 let backend = OAuth2Backend::new(OAuth2Config::default());
632 let cache = backend.jwks_cache.read().await;
633 assert!(cache.is_none());
634 }
635
636 #[tokio::test]
641 async fn test_create_user_not_supported() {
642 let backend = OAuth2Backend::new(OAuth2Config::default());
643 let username = Username::new("user@example.com".to_string()).unwrap();
644 let result = backend.create_user(&username, "token").await;
645 assert!(result.is_err());
646 assert!(result
647 .unwrap_err()
648 .to_string()
649 .contains("external provider"));
650 }
651
652 #[tokio::test]
653 async fn test_delete_user_not_supported() {
654 let backend = OAuth2Backend::new(OAuth2Config::default());
655 let username = Username::new("user@example.com".to_string()).unwrap();
656 let result = backend.delete_user(&username).await;
657 assert!(result.is_err());
658 assert!(result
659 .unwrap_err()
660 .to_string()
661 .contains("external provider"));
662 }
663
664 #[tokio::test]
665 async fn test_change_password_not_supported() {
666 let backend = OAuth2Backend::new(OAuth2Config::default());
667 let username = Username::new("user@example.com".to_string()).unwrap();
668 let result = backend.change_password(&username, "newtoken").await;
669 assert!(result.is_err());
670 assert!(result
671 .unwrap_err()
672 .to_string()
673 .contains("external provider"));
674 }
675
676 #[tokio::test]
677 async fn test_list_users_empty() {
678 let backend = OAuth2Backend::new(OAuth2Config::default());
679 let users = backend.list_users().await.unwrap();
680 assert_eq!(users.len(), 0);
681 }
682
683 #[tokio::test]
684 async fn test_verify_identity_not_cached() {
685 let backend = OAuth2Backend::new(OAuth2Config::default());
686 let username = Username::new("user@example.com".to_string()).unwrap();
687 let verified = backend.verify_identity(&username).await.unwrap();
688 assert!(!verified);
689 }
690
691 #[tokio::test]
692 async fn test_verify_identity_cached() {
693 let backend = OAuth2Backend::new(OAuth2Config::default());
694 let username = Username::new("cached@example.com".to_string()).unwrap();
695
696 {
697 let mut cache = backend.token_cache.write().await;
698 cache.insert(
699 username.to_string(),
700 TokenCacheEntry {
701 username: username.to_string(),
702 expires_at: SystemTime::now() + Duration::from_secs(300),
703 },
704 );
705 }
706
707 let verified = backend.verify_identity(&username).await.unwrap();
708 assert!(verified);
709 }
710
711 #[tokio::test]
716 async fn test_token_cache_insertion() {
717 let backend = OAuth2Backend::new(OAuth2Config::default());
718
719 {
720 let mut cache = backend.token_cache.write().await;
721 cache.insert(
722 "user@example.com".to_string(),
723 TokenCacheEntry {
724 username: "user@example.com".to_string(),
725 expires_at: SystemTime::now() + Duration::from_secs(300),
726 },
727 );
728 }
729
730 let cache = backend.token_cache.read().await;
731 assert_eq!(cache.len(), 1);
732 assert!(cache.contains_key("user@example.com"));
733 }
734
735 #[tokio::test]
736 async fn test_token_cache_expiration() {
737 let backend = OAuth2Backend::new(OAuth2Config::default());
738
739 {
740 let mut cache = backend.token_cache.write().await;
741 cache.insert(
742 "expired@example.com".to_string(),
743 TokenCacheEntry {
744 username: "expired@example.com".to_string(),
745 expires_at: SystemTime::now() - Duration::from_secs(1),
746 },
747 );
748 }
749
750 let cache = backend.token_cache.read().await;
752 let entry = cache.get("expired@example.com").unwrap();
753 assert!(entry.expires_at < SystemTime::now());
754 }
755
756 #[tokio::test]
757 async fn test_token_cache_multiple_users() {
758 let backend = OAuth2Backend::new(OAuth2Config::default());
759
760 {
761 let mut cache = backend.token_cache.write().await;
762 for i in 1..=5 {
763 cache.insert(
764 format!("user{}@example.com", i),
765 TokenCacheEntry {
766 username: format!("user{}@example.com", i),
767 expires_at: SystemTime::now() + Duration::from_secs(300),
768 },
769 );
770 }
771 }
772
773 let cache = backend.token_cache.read().await;
774 assert_eq!(cache.len(), 5);
775 }
776
777 #[tokio::test]
778 async fn test_list_users_with_cached_tokens() {
779 let backend = OAuth2Backend::new(OAuth2Config::default());
780
781 {
782 let mut cache = backend.token_cache.write().await;
783 cache.insert(
784 "user1@example.com".to_string(),
785 TokenCacheEntry {
786 username: "user1@example.com".to_string(),
787 expires_at: SystemTime::now() + Duration::from_secs(300),
788 },
789 );
790 cache.insert(
791 "user2@example.com".to_string(),
792 TokenCacheEntry {
793 username: "user2@example.com".to_string(),
794 expires_at: SystemTime::now() + Duration::from_secs(300),
795 },
796 );
797 }
798
799 let users = backend.list_users().await.unwrap();
800 assert_eq!(users.len(), 2);
801 }
802
803 #[test]
808 fn test_claims_structure() {
809 let claims = Claims {
810 sub: "user123".to_string(),
811 email: Some("user@example.com".to_string()),
812 exp: 1234567890,
813 iat: 1234567800,
814 iss: "https://accounts.google.com".to_string(),
815 aud: "client-id".to_string(),
816 };
817 assert_eq!(claims.sub, "user123");
818 assert_eq!(claims.email.unwrap(), "user@example.com");
819 }
820
821 #[test]
822 fn test_claims_without_email() {
823 let claims = Claims {
824 sub: "user123".to_string(),
825 email: None,
826 exp: 1234567890,
827 iat: 1234567800,
828 iss: "https://accounts.google.com".to_string(),
829 aud: "client-id".to_string(),
830 };
831 assert_eq!(claims.sub, "user123");
832 assert!(claims.email.is_none());
833 }
834
835 #[test]
840 fn test_token_cache_entry() {
841 let entry = TokenCacheEntry {
842 username: "user@example.com".to_string(),
843 expires_at: SystemTime::now() + Duration::from_secs(300),
844 };
845 assert_eq!(entry.username, "user@example.com");
846 assert!(entry.expires_at > SystemTime::now());
847 }
848
849 #[test]
850 fn test_token_cache_entry_expired() {
851 let entry = TokenCacheEntry {
852 username: "user@example.com".to_string(),
853 expires_at: SystemTime::now() - Duration::from_secs(10),
854 };
855 assert!(entry.expires_at < SystemTime::now());
856 }
857
858 #[test]
863 fn test_google_provider_config() {
864 let provider = OidcProvider::Google {
865 client_id: "google-client-id".to_string(),
866 client_secret: "google-secret".to_string(),
867 };
868
869 if let OidcProvider::Google { client_id, .. } = &provider {
870 assert_eq!(client_id, "google-client-id");
871 } else {
872 panic!("Expected Google provider");
873 }
874 }
875
876 #[test]
877 fn test_microsoft_provider_config() {
878 let provider = OidcProvider::Microsoft {
879 tenant_id: "tenant-123".to_string(),
880 client_id: "ms-client-id".to_string(),
881 client_secret: "ms-secret".to_string(),
882 };
883
884 if let OidcProvider::Microsoft { tenant_id, .. } = &provider {
885 assert_eq!(tenant_id, "tenant-123");
886 } else {
887 panic!("Expected Microsoft provider");
888 }
889 }
890
891 #[test]
892 fn test_generic_provider_config() {
893 let provider = OidcProvider::Generic {
894 issuer_url: "https://auth.example.com".to_string(),
895 client_id: "generic-client".to_string(),
896 client_secret: "generic-secret".to_string(),
897 jwks_url: "https://auth.example.com/.well-known/jwks.json".to_string(),
898 };
899
900 if let OidcProvider::Generic { issuer_url, .. } = &provider {
901 assert_eq!(issuer_url, "https://auth.example.com");
902 } else {
903 panic!("Expected Generic provider");
904 }
905 }
906
907 #[test]
912 fn test_multiple_allowed_algorithms() {
913 let config = OAuth2Config {
914 allowed_algorithms: vec![
915 Algorithm::RS256,
916 Algorithm::RS384,
917 Algorithm::RS512,
918 Algorithm::ES256,
919 ],
920 ..Default::default()
921 };
922 assert_eq!(config.allowed_algorithms.len(), 4);
923 assert!(config.allowed_algorithms.contains(&Algorithm::RS256));
924 assert!(config.allowed_algorithms.contains(&Algorithm::ES256));
925 }
926
927 #[test]
928 fn test_single_algorithm_rs256() {
929 let config = OAuth2Config {
930 allowed_algorithms: vec![Algorithm::RS256],
931 ..Default::default()
932 };
933 assert_eq!(config.allowed_algorithms.len(), 1);
934 assert_eq!(config.allowed_algorithms[0], Algorithm::RS256);
935 }
936
937 #[test]
942 fn test_jwks_structure() {
943 let jwks = Jwks { keys: vec![] };
944 assert_eq!(jwks.keys.len(), 0);
945 }
946
947 #[test]
948 fn test_jwk_structure() {
949 let jwk = Jwk {
950 kid: "key-1".to_string(),
951 kty: "RSA".to_string(),
952 key_use: Some("sig".to_string()),
953 alg: Some("RS256".to_string()),
954 n: Some("modulus".to_string()),
955 e: Some("AQAB".to_string()),
956 };
957 assert_eq!(jwk.kid, "key-1");
958 assert_eq!(jwk.kty, "RSA");
959 }
960
961 #[tokio::test]
966 async fn test_authenticate_empty_token() {
967 let backend = OAuth2Backend::new(OAuth2Config::default());
968 let username = Username::new("user@example.com".to_string()).unwrap();
969 let result = backend.authenticate(&username, "").await;
970 assert!(result.is_ok());
971 assert!(!result.unwrap());
972 }
973
974 #[tokio::test]
975 async fn test_authenticate_invalid_token() {
976 let backend = OAuth2Backend::new(OAuth2Config::default());
977 let username = Username::new("user@example.com".to_string()).unwrap();
978 let result = backend.authenticate(&username, "invalid-token").await;
979 assert!(result.is_ok());
980 assert!(!result.unwrap());
981 }
982
983 #[test]
984 fn test_config_with_all_options() {
985 let config = OAuth2Config {
986 provider: OidcProvider::Google {
987 client_id: "client".to_string(),
988 client_secret: "secret".to_string(),
989 },
990 introspection_endpoint: Some("https://oauth.example.com/introspect".to_string()),
991 jwks_cache_ttl: 1800,
992 enable_refresh_tokens: false,
993 allowed_algorithms: vec![Algorithm::RS256, Algorithm::RS384],
994 };
995
996 assert!(config.introspection_endpoint.is_some());
997 assert_eq!(config.jwks_cache_ttl, 1800);
998 assert!(!config.enable_refresh_tokens);
999 assert_eq!(config.allowed_algorithms.len(), 2);
1000 }
1001
1002 #[tokio::test]
1007 async fn test_verify_identity_invalid_username() {
1008 let backend = OAuth2Backend::new(OAuth2Config::default());
1009 let username = Username::new("nonexistent@example.com".to_string()).unwrap();
1011 let result = backend.verify_identity(&username).await;
1012 assert!(result.is_ok());
1013 assert!(!result.unwrap());
1014 }
1015
1016 #[tokio::test]
1021 async fn test_concurrent_cache_access() {
1022 let backend = Arc::new(OAuth2Backend::new(OAuth2Config::default()));
1023
1024 let mut handles = vec![];
1025 for i in 0..10 {
1026 let backend = Arc::clone(&backend);
1027 let handle = tokio::spawn(async move {
1028 let mut cache = backend.token_cache.write().await;
1029 cache.insert(
1030 format!("user{}@example.com", i),
1031 TokenCacheEntry {
1032 username: format!("user{}@example.com", i),
1033 expires_at: SystemTime::now() + Duration::from_secs(300),
1034 },
1035 );
1036 });
1037 handles.push(handle);
1038 }
1039
1040 for handle in handles {
1041 handle.await.unwrap();
1042 }
1043
1044 let cache = backend.token_cache.read().await;
1045 assert_eq!(cache.len(), 10);
1046 }
1047
1048 #[tokio::test]
1053 async fn test_introspect_without_endpoint() {
1054 let backend = OAuth2Backend::new(OAuth2Config::default());
1055 let result = backend.introspect_token("test-token").await;
1056 assert!(result.is_err());
1057 assert!(result.unwrap_err().to_string().contains("not configured"));
1058 }
1059
1060 #[tokio::test]
1061 async fn test_refresh_token_disabled() {
1062 let config = OAuth2Config {
1063 enable_refresh_tokens: false,
1064 ..Default::default()
1065 };
1066 let backend = OAuth2Backend::new(config);
1067 let result = backend.refresh_token("refresh-token").await;
1068 assert!(result.is_err());
1069 assert!(result.unwrap_err().to_string().contains("not enabled"));
1070 }
1071
1072 #[test]
1077 fn test_parse_xoauth2_response_valid() {
1078 let response =
1079 OAuth2Backend::encode_xoauth2_response("user@example.com", "ya29.a0AfH6SMBx...");
1080 let result = OAuth2Backend::parse_xoauth2_response(&response);
1081 assert!(result.is_ok());
1082 let (username, token) = result.unwrap();
1083 assert_eq!(username, "user@example.com");
1084 assert_eq!(token, "ya29.a0AfH6SMBx...");
1085 }
1086
1087 #[test]
1088 fn test_encode_xoauth2_response() {
1089 let encoded = OAuth2Backend::encode_xoauth2_response("test@example.com", "token123");
1090 assert!(!encoded.is_empty());
1091
1092 let (username, token) = OAuth2Backend::parse_xoauth2_response(&encoded).unwrap();
1094 assert_eq!(username, "test@example.com");
1095 assert_eq!(token, "token123");
1096 }
1097
1098 #[test]
1099 fn test_parse_xoauth2_response_invalid_base64() {
1100 let result = OAuth2Backend::parse_xoauth2_response("not-valid-base64!");
1101 assert!(result.is_err());
1102 assert!(result.unwrap_err().to_string().contains("decode"));
1103 }
1104
1105 #[test]
1106 fn test_parse_xoauth2_response_missing_username() {
1107 let invalid = BASE64.encode(b"auth=Bearer token123\x01\x01");
1109 let result = OAuth2Backend::parse_xoauth2_response(&invalid);
1110 assert!(result.is_err());
1111 assert!(result.unwrap_err().to_string().contains("username"));
1112 }
1113
1114 #[test]
1115 fn test_parse_xoauth2_response_missing_token() {
1116 let invalid = BASE64.encode(b"user=test@example.com\x01\x01");
1118 let result = OAuth2Backend::parse_xoauth2_response(&invalid);
1119 assert!(result.is_err());
1120 assert!(result.unwrap_err().to_string().contains("token"));
1121 }
1122
1123 #[test]
1124 fn test_xoauth2_round_trip() {
1125 let original_username = "roundtrip@example.com";
1126 let original_token = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9...";
1127
1128 let encoded = OAuth2Backend::encode_xoauth2_response(original_username, original_token);
1129 let (decoded_username, decoded_token) =
1130 OAuth2Backend::parse_xoauth2_response(&encoded).unwrap();
1131
1132 assert_eq!(decoded_username, original_username);
1133 assert_eq!(decoded_token, original_token);
1134 }
1135
1136 #[test]
1137 fn test_xoauth2_special_characters() {
1138 let username = "user+tag@example.com";
1139 let token = "token-with-special_chars.123";
1140
1141 let encoded = OAuth2Backend::encode_xoauth2_response(username, token);
1142 let (decoded_username, decoded_token) =
1143 OAuth2Backend::parse_xoauth2_response(&encoded).unwrap();
1144
1145 assert_eq!(decoded_username, username);
1146 assert_eq!(decoded_token, token);
1147 }
1148
1149 #[tokio::test]
1154 async fn test_cleanup_expired_tokens() {
1155 let backend = OAuth2Backend::new(OAuth2Config::default());
1156
1157 {
1158 let mut cache = backend.token_cache.write().await;
1159 cache.insert(
1161 "expired@example.com".to_string(),
1162 TokenCacheEntry {
1163 username: "expired@example.com".to_string(),
1164 expires_at: SystemTime::now() - Duration::from_secs(10),
1165 },
1166 );
1167 cache.insert(
1169 "valid@example.com".to_string(),
1170 TokenCacheEntry {
1171 username: "valid@example.com".to_string(),
1172 expires_at: SystemTime::now() + Duration::from_secs(300),
1173 },
1174 );
1175 }
1176
1177 backend.cleanup_expired_tokens().await;
1178
1179 let cache = backend.token_cache.read().await;
1180 assert_eq!(cache.len(), 1);
1181 assert!(cache.contains_key("valid@example.com"));
1182 assert!(!cache.contains_key("expired@example.com"));
1183 }
1184
1185 #[tokio::test]
1186 async fn test_token_cache_size() {
1187 let backend = OAuth2Backend::new(OAuth2Config::default());
1188
1189 {
1190 let mut cache = backend.token_cache.write().await;
1191 for i in 1..=3 {
1192 cache.insert(
1193 format!("user{}@example.com", i),
1194 TokenCacheEntry {
1195 username: format!("user{}@example.com", i),
1196 expires_at: SystemTime::now() + Duration::from_secs(300),
1197 },
1198 );
1199 }
1200 }
1201
1202 let size = backend.token_cache_size().await;
1203 assert_eq!(size, 3);
1204 }
1205
1206 #[tokio::test]
1207 async fn test_invalidate_token() {
1208 let backend = OAuth2Backend::new(OAuth2Config::default());
1209
1210 {
1211 let mut cache = backend.token_cache.write().await;
1212 cache.insert(
1213 "user@example.com".to_string(),
1214 TokenCacheEntry {
1215 username: "user@example.com".to_string(),
1216 expires_at: SystemTime::now() + Duration::from_secs(300),
1217 },
1218 );
1219 }
1220
1221 assert_eq!(backend.token_cache_size().await, 1);
1222
1223 backend.invalidate_token("user@example.com").await;
1224
1225 assert_eq!(backend.token_cache_size().await, 0);
1226 }
1227
1228 #[tokio::test]
1229 async fn test_clear_jwks_cache() {
1230 let backend = OAuth2Backend::new(OAuth2Config::default());
1231
1232 {
1233 let mut cache = backend.jwks_cache.write().await;
1234 *cache = Some((Jwks { keys: vec![] }, SystemTime::now()));
1235 }
1236
1237 backend.clear_jwks_cache().await;
1238
1239 let cache = backend.jwks_cache.read().await;
1240 assert!(cache.is_none());
1241 }
1242
1243 #[test]
1248 fn test_google_jwks_url() {
1249 let config = OAuth2Config {
1250 provider: OidcProvider::Google {
1251 client_id: "client".to_string(),
1252 client_secret: "secret".to_string(),
1253 },
1254 ..Default::default()
1255 };
1256 let backend = OAuth2Backend::new(config);
1257
1258 assert!(matches!(
1260 backend.config.provider,
1261 OidcProvider::Google { .. }
1262 ));
1263 }
1264
1265 #[test]
1266 fn test_microsoft_urls() {
1267 let tenant_id = "tenant-abc-123";
1268 let provider = OidcProvider::Microsoft {
1269 tenant_id: tenant_id.to_string(),
1270 client_id: "client".to_string(),
1271 client_secret: "secret".to_string(),
1272 };
1273
1274 if let OidcProvider::Microsoft { tenant_id: tid, .. } = &provider {
1275 let expected_jwks = format!(
1276 "https://login.microsoftonline.com/{}/discovery/v2.0/keys",
1277 tid
1278 );
1279 assert!(expected_jwks.contains(tenant_id));
1280 }
1281 }
1282
1283 #[test]
1284 fn test_generic_provider_urls() {
1285 let issuer = "https://auth.company.com";
1286 let jwks_url = "https://auth.company.com/.well-known/jwks.json";
1287
1288 let provider = OidcProvider::Generic {
1289 issuer_url: issuer.to_string(),
1290 client_id: "client".to_string(),
1291 client_secret: "secret".to_string(),
1292 jwks_url: jwks_url.to_string(),
1293 };
1294
1295 if let OidcProvider::Generic {
1296 issuer_url,
1297 jwks_url: jwks,
1298 ..
1299 } = &provider
1300 {
1301 assert_eq!(issuer_url, issuer);
1302 assert_eq!(jwks, jwks_url);
1303 }
1304 }
1305
1306 #[tokio::test]
1311 async fn test_multiple_cleanup_calls() {
1312 let backend = OAuth2Backend::new(OAuth2Config::default());
1313
1314 {
1315 let mut cache = backend.token_cache.write().await;
1316 cache.insert(
1317 "expired@example.com".to_string(),
1318 TokenCacheEntry {
1319 username: "expired@example.com".to_string(),
1320 expires_at: SystemTime::now() - Duration::from_secs(10),
1321 },
1322 );
1323 }
1324
1325 backend.cleanup_expired_tokens().await;
1327 backend.cleanup_expired_tokens().await;
1328 backend.cleanup_expired_tokens().await;
1329
1330 let cache = backend.token_cache.read().await;
1331 assert_eq!(cache.len(), 0);
1332 }
1333
1334 #[tokio::test]
1335 async fn test_invalidate_nonexistent_token() {
1336 let backend = OAuth2Backend::new(OAuth2Config::default());
1337 backend.invalidate_token("nonexistent@example.com").await;
1339 assert_eq!(backend.token_cache_size().await, 0);
1340 }
1341
1342 #[test]
1343 fn test_xoauth2_empty_username() {
1344 let encoded = OAuth2Backend::encode_xoauth2_response("", "token");
1345 let result = OAuth2Backend::parse_xoauth2_response(&encoded);
1346 assert!(result.is_ok());
1347 let (username, _) = result.unwrap();
1348 assert_eq!(username, "");
1349 }
1350
1351 #[test]
1352 fn test_xoauth2_empty_token() {
1353 let encoded = OAuth2Backend::encode_xoauth2_response("user@example.com", "");
1354 let result = OAuth2Backend::parse_xoauth2_response(&encoded);
1355 assert!(result.is_ok());
1356 let (_, token) = result.unwrap();
1357 assert_eq!(token, "");
1358 }
1359
1360 #[test]
1361 fn test_xoauth2_long_token() {
1362 let long_token = "a".repeat(1000);
1363 let encoded = OAuth2Backend::encode_xoauth2_response("user@example.com", &long_token);
1364 let result = OAuth2Backend::parse_xoauth2_response(&encoded);
1365 assert!(result.is_ok());
1366 let (_, token) = result.unwrap();
1367 assert_eq!(token.len(), 1000);
1368 }
1369
1370 #[test]
1375 fn test_config_validation_minimal() {
1376 let config = OAuth2Config {
1377 provider: OidcProvider::Generic {
1378 issuer_url: "https://minimal.example.com".to_string(),
1379 client_id: "c".to_string(),
1380 client_secret: "s".to_string(),
1381 jwks_url: "https://minimal.example.com/jwks".to_string(),
1382 },
1383 introspection_endpoint: None,
1384 jwks_cache_ttl: 60,
1385 enable_refresh_tokens: false,
1386 allowed_algorithms: vec![Algorithm::RS256],
1387 };
1388
1389 let backend = OAuth2Backend::new(config);
1390 assert!(backend.config.jwks_cache_ttl >= 60);
1391 }
1392
1393 #[test]
1394 fn test_config_validation_maximal() {
1395 let config = OAuth2Config {
1396 provider: OidcProvider::Google {
1397 client_id: "very-long-client-id-with-many-characters".to_string(),
1398 client_secret: "very-long-secret-with-special-chars!@#$%".to_string(),
1399 },
1400 introspection_endpoint: Some(
1401 "https://oauth.googleapis.com/token/introspect".to_string(),
1402 ),
1403 jwks_cache_ttl: 86400,
1404 enable_refresh_tokens: true,
1405 allowed_algorithms: vec![
1406 Algorithm::RS256,
1407 Algorithm::RS384,
1408 Algorithm::RS512,
1409 Algorithm::ES256,
1410 Algorithm::ES384,
1411 ],
1412 };
1413
1414 let backend = OAuth2Backend::new(config);
1415 assert_eq!(backend.config.allowed_algorithms.len(), 5);
1416 assert!(backend.config.enable_refresh_tokens);
1417 }
1418
1419 #[tokio::test]
1424 async fn test_concurrent_jwks_cache_access() {
1425 let backend = Arc::new(OAuth2Backend::new(OAuth2Config::default()));
1426
1427 let mut handles = vec![];
1428 for _ in 0..5 {
1429 let backend = Arc::clone(&backend);
1430 let handle = tokio::spawn(async move {
1431 backend.clear_jwks_cache().await;
1432 });
1433 handles.push(handle);
1434 }
1435
1436 for handle in handles {
1437 handle.await.unwrap();
1438 }
1439
1440 let cache = backend.jwks_cache.read().await;
1441 assert!(cache.is_none());
1442 }
1443
1444 #[tokio::test]
1445 async fn test_concurrent_cleanup() {
1446 let backend = Arc::new(OAuth2Backend::new(OAuth2Config::default()));
1447
1448 {
1449 let mut cache = backend.token_cache.write().await;
1450 for i in 0..100 {
1451 cache.insert(
1452 format!("user{}@example.com", i),
1453 TokenCacheEntry {
1454 username: format!("user{}@example.com", i),
1455 expires_at: if i % 2 == 0 {
1456 SystemTime::now() + Duration::from_secs(300)
1457 } else {
1458 SystemTime::now() - Duration::from_secs(10)
1459 },
1460 },
1461 );
1462 }
1463 }
1464
1465 let mut handles = vec![];
1466 for _ in 0..10 {
1467 let backend = Arc::clone(&backend);
1468 let handle = tokio::spawn(async move {
1469 backend.cleanup_expired_tokens().await;
1470 });
1471 handles.push(handle);
1472 }
1473
1474 for handle in handles {
1475 handle.await.unwrap();
1476 }
1477
1478 let cache = backend.token_cache.read().await;
1479 assert_eq!(cache.len(), 50);
1480 }
1481
1482 #[cfg(test)]
1489 fn make_signed_jwt_and_jwks(
1490 client_id: &str,
1491 email: &str,
1492 ) -> anyhow::Result<(String, Jwks, String)> {
1493 use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
1494 use rsa::pkcs8::EncodePrivateKey;
1495 use rsa::RsaPrivateKey;
1496
1497 let bits = 512_usize;
1499 let mut rng = rand_core::OsRng;
1500 let private_key = RsaPrivateKey::new(&mut rng, bits)?;
1501 let public_key = private_key.to_public_key();
1502
1503 let private_pem = private_key
1505 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
1506 .map_err(|e| anyhow::anyhow!("pkcs8 pem error: {}", e))?;
1507
1508 let encoding_key = EncodingKey::from_rsa_pem(private_pem.as_bytes())
1509 .map_err(|e| anyhow::anyhow!("encoding key error: {}", e))?;
1510
1511 let kid = "test-key-1".to_string();
1512
1513 let mut header = Header::new(Algorithm::RS256);
1514 header.kid = Some(kid.clone());
1515
1516 let now = std::time::SystemTime::now()
1517 .duration_since(std::time::UNIX_EPOCH)
1518 .unwrap_or_default()
1519 .as_secs();
1520
1521 let claims = Claims {
1522 sub: email.to_string(),
1523 email: Some(email.to_string()),
1524 exp: now + 3600,
1525 iat: now,
1526 iss: "https://test.example.com".to_string(),
1527 aud: client_id.to_string(),
1528 };
1529
1530 let token = encode(&header, &claims, &encoding_key)
1531 .map_err(|e| anyhow::anyhow!("jwt encode error: {}", e))?;
1532
1533 use rsa::traits::PublicKeyParts;
1535 let n_bytes = public_key.n().to_bytes_be();
1536 let e_bytes = public_key.e().to_bytes_be();
1537
1538 let jwk = Jwk {
1539 kid: kid.clone(),
1540 kty: "RSA".to_string(),
1541 key_use: Some("sig".to_string()),
1542 alg: Some("RS256".to_string()),
1543 n: Some(BASE64.encode(&n_bytes)),
1544 e: Some(BASE64.encode(&e_bytes)),
1545 };
1546
1547 let jwks = Jwks { keys: vec![jwk] };
1548
1549 Ok((token, jwks, kid))
1550 }
1551
1552 #[tokio::test]
1553 async fn test_oauth2_bearer_valid_token() {
1554 let client_id = "test-client";
1555 let email = "alice@example.com";
1556
1557 let (token, jwks, _kid) = make_signed_jwt_and_jwks(client_id, email).unwrap();
1558
1559 let config = OAuth2Config {
1560 provider: OidcProvider::Generic {
1561 issuer_url: "https://test.example.com".to_string(),
1562 client_id: client_id.to_string(),
1563 client_secret: "secret".to_string(),
1564 jwks_url: "https://test.example.com/jwks".to_string(),
1565 },
1566 introspection_endpoint: None,
1567 jwks_cache_ttl: 3600,
1568 enable_refresh_tokens: false,
1569 allowed_algorithms: vec![Algorithm::RS256],
1570 };
1571
1572 let backend = OAuth2Backend::new(config);
1573
1574 {
1576 let mut cache = backend.jwks_cache.write().await;
1577 *cache = Some((jwks, SystemTime::now()));
1578 }
1579
1580 let result = backend.verify_bearer_token(&token).await;
1581 assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
1582 let username = result.unwrap();
1583 assert_eq!(username.to_string(), email);
1584 }
1585
1586 #[tokio::test]
1587 async fn test_bearer_malformed_token() {
1588 let backend = OAuth2Backend::new(OAuth2Config::default());
1590 let result = backend.verify_bearer_token("not.a.jwt").await;
1591 assert!(
1592 result.is_err(),
1593 "malformed token should be rejected, got Ok"
1594 );
1595 }
1596
1597 #[tokio::test]
1598 async fn test_file_backend_bearer_rejected() {
1599 use crate::file::FileAuthBackend;
1600 use std::env::temp_dir;
1601
1602 let path = temp_dir().join("test_file_backend_bearer.passwd");
1604 tokio::fs::write(&path, b"").await.unwrap();
1605
1606 let backend = FileAuthBackend::new(path.to_str().unwrap()).await.unwrap();
1607
1608 let result = backend.verify_bearer_token("some-token").await;
1609 assert!(
1610 result.is_err(),
1611 "FileAuthBackend should reject Bearer tokens"
1612 );
1613 }
1614
1615 #[tokio::test]
1616 async fn test_sql_backend_bearer_rejected() {
1617 use crate::backends::sql::{SqlBackend, SqlConfig};
1618
1619 let url = format!(
1621 "sqlite:file:rusmes_sql_bearer_test_{}?mode=memory&cache=shared",
1622 std::time::SystemTime::now()
1623 .duration_since(std::time::UNIX_EPOCH)
1624 .map(|d| d.as_nanos())
1625 .unwrap_or(0)
1626 );
1627
1628 let config = SqlConfig {
1629 database_url: url,
1630 ..Default::default()
1631 };
1632
1633 match SqlBackend::new(config).await {
1634 Ok(backend) => {
1635 let result = backend.verify_bearer_token("some-token").await;
1636 assert!(result.is_err(), "SqlBackend should reject Bearer tokens");
1637 }
1638 Err(_) => {
1639 }
1643 }
1644 }
1645}