// uvb_storage_memory/ — enrollment.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::sync::RwLock;
5use uvb_core::TenantId;
6use uvb_storage_api::{EnrollmentError, EnrollmentRecord, EnrollmentStore};
7
8pub struct InMemoryEnrollmentStore {
9    enrollments: Arc<RwLock<HashMap<String, EnrollmentRecord>>>,
10}
11
12impl InMemoryEnrollmentStore {
13    pub fn new() -> Self {
14        Self {
15            enrollments: Arc::new(RwLock::new(HashMap::new())),
16        }
17    }
18}
19
20impl Default for InMemoryEnrollmentStore {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26#[async_trait]
27impl EnrollmentStore for InMemoryEnrollmentStore {
28    async fn create(&self, record: EnrollmentRecord) -> Result<String, EnrollmentError> {
29        let id = record.id.clone();
30        tracing::info!(
31            enrollment_id = %id,
32            user_id = %record.user_id,
33            tenant_id = %record.tenant_id,
34            factor_id = %record.factor_id,
35            "Storing enrollment in memory"
36        );
37        self.enrollments
38            .write()
39            .await
40            .insert(id.clone(), record.clone());
41        Ok(id)
42    }
43
44    async fn get(&self, id: &str) -> Result<Option<EnrollmentRecord>, EnrollmentError> {
45        Ok(self.enrollments.read().await.get(id).cloned())
46    }
47
48    async fn is_enrolled(
49        &self,
50        user_id: &str,
51        tenant_id: &TenantId,
52        factor_id: &str,
53    ) -> Result<bool, EnrollmentError> {
54        let enrollments = self.enrollments.read().await;
55        let count = enrollments.len();
56        let enrolled = enrollments.values().any(|e| {
57            e.user_id == user_id
58                && &e.tenant_id == tenant_id
59                && e.factor_id == factor_id
60                && e.status == uvb_storage_api::EnrollmentStatus::Active
61        });
62        tracing::info!(
63            user_id = %user_id,
64            tenant_id = %tenant_id,
65            factor_id = %factor_id,
66            enrolled = enrolled,
67            total_enrollments = count,
68            "Checking enrollment status"
69        );
70        Ok(enrolled)
71    }
72
73    async fn list_by_user(
74        &self,
75        user_id: &str,
76        tenant_id: &TenantId,
77    ) -> Result<Vec<EnrollmentRecord>, EnrollmentError> {
78        let enrollments = self.enrollments.read().await;
79        Ok(enrollments
80            .values()
81            .filter(|e| e.user_id == user_id && &e.tenant_id == tenant_id)
82            .cloned()
83            .collect())
84    }
85
86    async fn update(&self, record: EnrollmentRecord) -> Result<(), EnrollmentError> {
87        let mut enrollments = self.enrollments.write().await;
88        if !enrollments.contains_key(&record.id) {
89            return Err(EnrollmentError::NotFound);
90        }
91        enrollments.insert(record.id.clone(), record);
92        Ok(())
93    }
94
95    async fn delete(&self, id: &str) -> Result<(), EnrollmentError> {
96        self.enrollments
97            .write()
98            .await
99            .remove(id)
100            .ok_or(EnrollmentError::NotFound)?;
101        Ok(())
102    }
103
104    async fn record_usage(&self, id: &str) -> Result<(), EnrollmentError> {
105        let mut enrollments = self.enrollments.write().await;
106        if let Some(enrollment) = enrollments.get_mut(id) {
107            enrollment.use_count += 1;
108            enrollment.last_used_at = Some(std::time::SystemTime::now());
109            Ok(())
110        } else {
111            Err(EnrollmentError::NotFound)
112        }
113    }
114}