scim_server/multi_tenant/
resolver.rs1use crate::resource::TenantContext;
8use std::collections::HashMap;
9use std::future::Future;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
13pub trait TenantResolver: Send + Sync {
60 type Error: std::error::Error + Send + Sync + 'static;
62
63 fn resolve_tenant(
78 &self,
79 credential: &str,
80 ) -> impl Future<Output = Result<TenantContext, Self::Error>> + Send;
81
82 fn validate_tenant(
93 &self,
94 tenant_id: &str,
95 ) -> impl Future<Output = Result<bool, Self::Error>> + Send;
96
97 fn list_tenants(&self) -> impl Future<Output = Result<Vec<String>, Self::Error>> + Send {
105 async move {
106 Ok(vec![])
108 }
109 }
110
111 fn is_valid_credential(
121 &self,
122 credential: &str,
123 ) -> impl Future<Output = Result<bool, Self::Error>> + Send {
124 async move {
125 match self.resolve_tenant(credential).await {
126 Ok(_) => Ok(true),
127 Err(_) => Ok(false),
128 }
129 }
130 }
131}
132
/// In-memory resolver backed by a map from credential string to
/// [`TenantContext`].
///
/// The map sits behind `Arc<RwLock<..>>`, so cloning the resolver yields a
/// handle to the *same* underlying map rather than an independent copy.
#[derive(Debug, Clone)]
pub struct StaticTenantResolver {
    // Keyed by credential; guarded by an async read/write lock.
    tenants: Arc<RwLock<HashMap<String, TenantContext>>>,
}
175
176impl StaticTenantResolver {
177 pub fn new() -> Self {
179 Self {
180 tenants: Arc::new(RwLock::new(HashMap::new())),
181 }
182 }
183
184 pub async fn add_tenant(&self, credential: &str, tenant_context: TenantContext) {
202 let mut tenants = self.tenants.write().await;
203 tenants.insert(credential.to_string(), tenant_context);
204 }
205
206 pub async fn remove_tenant(&self, credential: &str) -> Option<TenantContext> {
214 let mut tenants = self.tenants.write().await;
215 tenants.remove(credential)
216 }
217
218 pub async fn tenant_count(&self) -> usize {
220 let tenants = self.tenants.read().await;
221 tenants.len()
222 }
223
224 pub async fn clear(&self) {
226 let mut tenants = self.tenants.write().await;
227 tenants.clear();
228 }
229
230 pub async fn get_all_credentials(&self) -> Vec<String> {
232 let tenants = self.tenants.read().await;
233 tenants.keys().cloned().collect()
234 }
235}
236
237impl Default for StaticTenantResolver {
238 fn default() -> Self {
239 Self::new()
240 }
241}
242
/// Errors reported by [`StaticTenantResolver`].
#[derive(Debug, thiserror::Error)]
pub enum StaticResolverError {
    /// The supplied credential is not registered with the resolver.
    ///
    /// NOTE(review): the `Display` text embeds the credential value verbatim;
    /// confirm that is acceptable wherever this error is logged or returned.
    #[error("Invalid credentials: {credential}")]
    InvalidCredentials { credential: String },
    /// No tenant with the given ID exists.
    ///
    /// Not constructed anywhere in the visible part of this file.
    #[error("Tenant not found: {tenant_id}")]
    TenantNotFound { tenant_id: String },
}
251
252impl TenantResolver for StaticTenantResolver {
253 type Error = StaticResolverError;
254
255 async fn resolve_tenant(&self, credential: &str) -> Result<TenantContext, Self::Error> {
256 let tenants = self.tenants.read().await;
257 tenants
258 .get(credential)
259 .cloned()
260 .ok_or_else(|| StaticResolverError::InvalidCredentials {
261 credential: credential.to_string(),
262 })
263 }
264
265 async fn validate_tenant(&self, tenant_id: &str) -> Result<bool, Self::Error> {
266 let tenants = self.tenants.read().await;
267 Ok(tenants.values().any(|ctx| ctx.tenant_id == tenant_id))
268 }
269
270 async fn list_tenants(&self) -> Result<Vec<String>, Self::Error> {
271 let tenants = self.tenants.read().await;
272 Ok(tenants.values().map(|ctx| ctx.tenant_id.clone()).collect())
273 }
274}
275
276pub struct StaticTenantResolverBuilder {
306 tenants: Vec<(String, TenantContext)>,
307}
308
309impl StaticTenantResolverBuilder {
310 pub fn new() -> Self {
312 Self {
313 tenants: Vec::new(),
314 }
315 }
316
317 pub fn with_tenant(mut self, credential: &str, tenant_context: TenantContext) -> Self {
319 self.tenants.push((credential.to_string(), tenant_context));
320 self
321 }
322
323 pub async fn build(self) -> StaticTenantResolver {
325 let resolver = StaticTenantResolver::new();
326 for (credential, tenant_context) in self.tenants {
327 resolver.add_tenant(&credential, tenant_context).await;
328 }
329 resolver
330 }
331}
332
333impl Default for StaticTenantResolverBuilder {
334 fn default() -> Self {
335 Self::new()
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::resource::{IsolationLevel, TenantPermissions};

    /// Builds a context from string literals.
    fn ctx(tenant_id: &str, client_id: &str) -> TenantContext {
        TenantContext::new(tenant_id.to_string(), client_id.to_string())
    }

    /// Resolver holding `key1 -> tenant1/client1` and `key2 -> tenant2/client2`.
    async fn two_tenant_resolver() -> StaticTenantResolver {
        let resolver = StaticTenantResolver::new();
        resolver.add_tenant("key1", ctx("tenant1", "client1")).await;
        resolver.add_tenant("key2", ctx("tenant2", "client2")).await;
        resolver
    }

    #[tokio::test]
    async fn test_static_resolver_basic_operations() {
        let resolver = StaticTenantResolver::new();
        assert_eq!(resolver.tenant_count().await, 0);

        resolver
            .add_tenant("test-key", ctx("test-tenant", "test-client"))
            .await;
        assert_eq!(resolver.tenant_count().await, 1);

        let found = resolver.resolve_tenant("test-key").await.unwrap();
        assert_eq!(found.tenant_id, "test-tenant");
        assert_eq!(found.client_id, "test-client");
    }

    #[tokio::test]
    async fn test_static_resolver_invalid_credentials() {
        let resolver = StaticTenantResolver::new();
        let outcome = resolver.resolve_tenant("invalid-key").await;

        // An unknown key must fail specifically with `InvalidCredentials`.
        assert!(matches!(
            outcome,
            Err(StaticResolverError::InvalidCredentials { .. })
        ));
    }

    #[tokio::test]
    async fn test_static_resolver_tenant_validation() {
        let resolver = StaticTenantResolver::new();
        resolver.add_tenant("key", ctx("valid-tenant", "client")).await;

        let known = resolver.validate_tenant("valid-tenant").await.unwrap();
        let unknown = resolver.validate_tenant("invalid-tenant").await.unwrap();
        assert!(known);
        assert!(!unknown);
    }

    #[tokio::test]
    async fn test_static_resolver_list_tenants() {
        let resolver = two_tenant_resolver().await;

        let listed = resolver.list_tenants().await.unwrap();
        assert_eq!(listed.len(), 2);
        for expected in ["tenant1", "tenant2"] {
            assert!(listed.contains(&expected.to_string()));
        }
    }

    #[tokio::test]
    async fn test_static_resolver_remove_tenant() {
        let resolver = StaticTenantResolver::new();
        resolver.add_tenant("key", ctx("test", "client")).await;
        assert_eq!(resolver.tenant_count().await, 1);

        let removed = resolver.remove_tenant("key").await;
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().tenant_id, "test");
        assert_eq!(resolver.tenant_count().await, 0);

        // Removing a key that was never registered yields nothing.
        assert!(resolver.remove_tenant("nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn test_static_resolver_clear() {
        let resolver = two_tenant_resolver().await;
        assert_eq!(resolver.tenant_count().await, 2);

        resolver.clear().await;
        assert_eq!(resolver.tenant_count().await, 0);
    }

    #[tokio::test]
    async fn test_static_resolver_is_valid_credential() {
        let resolver = StaticTenantResolver::new();
        resolver.add_tenant("valid-key", ctx("tenant", "client")).await;

        let accepted = resolver.is_valid_credential("valid-key").await.unwrap();
        let rejected = resolver.is_valid_credential("invalid-key").await.unwrap();
        assert!(accepted);
        assert!(!rejected);
    }

    #[tokio::test]
    async fn test_static_resolver_builder() {
        let strict_ctx = ctx("tenant2", "client2").with_isolation_level(IsolationLevel::Strict);
        let resolver = StaticTenantResolverBuilder::new()
            .with_tenant("key1", ctx("tenant1", "client1"))
            .with_tenant("key2", strict_ctx)
            .build()
            .await;

        assert_eq!(resolver.tenant_count().await, 2);

        let first = resolver.resolve_tenant("key1").await.unwrap();
        assert_eq!(first.tenant_id, "tenant1");
        assert_eq!(first.isolation_level, IsolationLevel::Standard);

        let second = resolver.resolve_tenant("key2").await.unwrap();
        assert_eq!(second.tenant_id, "tenant2");
        assert_eq!(second.isolation_level, IsolationLevel::Strict);
    }

    #[tokio::test]
    async fn test_static_resolver_get_all_credentials() {
        let resolver = two_tenant_resolver().await;

        let keys = resolver.get_all_credentials().await;
        assert_eq!(keys.len(), 2);
        for expected in ["key1", "key2"] {
            assert!(keys.contains(&expected.to_string()));
        }
    }

    #[tokio::test]
    async fn test_complex_tenant_context() {
        let mut permissions = TenantPermissions::default();
        permissions.max_users = Some(100);
        permissions.can_delete = false;

        let configured = ctx("complex-tenant", "complex-client")
            .with_isolation_level(IsolationLevel::Strict)
            .with_permissions(permissions);

        let resolver = StaticTenantResolver::new();
        resolver.add_tenant("complex-key", configured).await;

        // Isolation level and permissions must survive the round trip.
        let resolved = resolver.resolve_tenant("complex-key").await.unwrap();
        assert_eq!(resolved.isolation_level, IsolationLevel::Strict);
        assert_eq!(resolved.permissions.max_users, Some(100));
        assert!(!resolved.permissions.can_delete);
        assert!(resolved.check_user_limit(50));
        assert!(!resolved.check_user_limit(100));
    }
}