uvb_storage_memory/
enrollment.rs1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::sync::RwLock;
5use uvb_core::TenantId;
6use uvb_storage_api::{EnrollmentError, EnrollmentRecord, EnrollmentStore};
7
8pub struct InMemoryEnrollmentStore {
9 enrollments: Arc<RwLock<HashMap<String, EnrollmentRecord>>>,
10}
11
12impl InMemoryEnrollmentStore {
13 pub fn new() -> Self {
14 Self {
15 enrollments: Arc::new(RwLock::new(HashMap::new())),
16 }
17 }
18}
19
20impl Default for InMemoryEnrollmentStore {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26#[async_trait]
27impl EnrollmentStore for InMemoryEnrollmentStore {
28 async fn create(&self, record: EnrollmentRecord) -> Result<String, EnrollmentError> {
29 let id = record.id.clone();
30 tracing::info!(
31 enrollment_id = %id,
32 user_id = %record.user_id,
33 tenant_id = %record.tenant_id,
34 factor_id = %record.factor_id,
35 "Storing enrollment in memory"
36 );
37 self.enrollments
38 .write()
39 .await
40 .insert(id.clone(), record.clone());
41 Ok(id)
42 }
43
44 async fn get(&self, id: &str) -> Result<Option<EnrollmentRecord>, EnrollmentError> {
45 Ok(self.enrollments.read().await.get(id).cloned())
46 }
47
48 async fn is_enrolled(
49 &self,
50 user_id: &str,
51 tenant_id: &TenantId,
52 factor_id: &str,
53 ) -> Result<bool, EnrollmentError> {
54 let enrollments = self.enrollments.read().await;
55 let count = enrollments.len();
56 let enrolled = enrollments.values().any(|e| {
57 e.user_id == user_id
58 && &e.tenant_id == tenant_id
59 && e.factor_id == factor_id
60 && e.status == uvb_storage_api::EnrollmentStatus::Active
61 });
62 tracing::info!(
63 user_id = %user_id,
64 tenant_id = %tenant_id,
65 factor_id = %factor_id,
66 enrolled = enrolled,
67 total_enrollments = count,
68 "Checking enrollment status"
69 );
70 Ok(enrolled)
71 }
72
73 async fn list_by_user(
74 &self,
75 user_id: &str,
76 tenant_id: &TenantId,
77 ) -> Result<Vec<EnrollmentRecord>, EnrollmentError> {
78 let enrollments = self.enrollments.read().await;
79 Ok(enrollments
80 .values()
81 .filter(|e| e.user_id == user_id && &e.tenant_id == tenant_id)
82 .cloned()
83 .collect())
84 }
85
86 async fn update(&self, record: EnrollmentRecord) -> Result<(), EnrollmentError> {
87 let mut enrollments = self.enrollments.write().await;
88 if !enrollments.contains_key(&record.id) {
89 return Err(EnrollmentError::NotFound);
90 }
91 enrollments.insert(record.id.clone(), record);
92 Ok(())
93 }
94
95 async fn delete(&self, id: &str) -> Result<(), EnrollmentError> {
96 self.enrollments
97 .write()
98 .await
99 .remove(id)
100 .ok_or(EnrollmentError::NotFound)?;
101 Ok(())
102 }
103
104 async fn record_usage(&self, id: &str) -> Result<(), EnrollmentError> {
105 let mut enrollments = self.enrollments.write().await;
106 if let Some(enrollment) = enrollments.get_mut(id) {
107 enrollment.use_count += 1;
108 enrollment.last_used_at = Some(std::time::SystemTime::now());
109 Ok(())
110 } else {
111 Err(EnrollmentError::NotFound)
112 }
113 }
114}