Skip to main content

uvb_storage_memory/
session.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime};
5use tokio::sync::RwLock;
6use uvb_core::TenantId;
7use uvb_storage_api::{SessionError, SessionRecord, SessionStore};
8
9pub struct InMemorySessionStore {
10    sessions: Arc<RwLock<HashMap<String, SessionRecord>>>,
11}
12
13impl InMemorySessionStore {
14    pub fn new() -> Self {
15        Self {
16            sessions: Arc::new(RwLock::new(HashMap::new())),
17        }
18    }
19}
20
21impl Default for InMemorySessionStore {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27#[async_trait]
28impl SessionStore for InMemorySessionStore {
29    async fn create(&self, record: SessionRecord) -> Result<String, SessionError> {
30        let id = record.id.clone();
31        self.sessions.write().await.insert(id.clone(), record);
32        Ok(id)
33    }
34
35    async fn get(&self, id: &str) -> Result<Option<SessionRecord>, SessionError> {
36        let sessions = self.sessions.read().await;
37        if let Some(session) = sessions.get(id) {
38            if session.expires_at <= SystemTime::now() {
39                return Err(SessionError::Expired);
40            }
41            Ok(Some(session.clone()))
42        } else {
43            Ok(None)
44        }
45    }
46
47    async fn update(&self, record: SessionRecord) -> Result<(), SessionError> {
48        let mut sessions = self.sessions.write().await;
49        if !sessions.contains_key(&record.id) {
50            return Err(SessionError::NotFound);
51        }
52        sessions.insert(record.id.clone(), record);
53        Ok(())
54    }
55
56    async fn delete(&self, id: &str) -> Result<(), SessionError> {
57        self.sessions
58            .write()
59            .await
60            .remove(id)
61            .ok_or(SessionError::NotFound)?;
62        Ok(())
63    }
64
65    async fn delete_by_user(
66        &self,
67        user_id: &str,
68        tenant_id: &TenantId,
69    ) -> Result<usize, SessionError> {
70        let mut sessions = self.sessions.write().await;
71        let to_delete: Vec<String> = sessions
72            .iter()
73            .filter(|(_, s)| s.user_id == user_id && &s.tenant_id == tenant_id)
74            .map(|(id, _)| id.clone())
75            .collect();
76
77        let count = to_delete.len();
78        for id in to_delete {
79            sessions.remove(&id);
80        }
81        Ok(count)
82    }
83
84    async fn extend(&self, id: &str, duration: Duration) -> Result<(), SessionError> {
85        let mut sessions = self.sessions.write().await;
86        if let Some(session) = sessions.get_mut(id) {
87            session.expires_at = session
88                .expires_at
89                .checked_add(duration)
90                .unwrap_or(session.expires_at);
91            Ok(())
92        } else {
93            Err(SessionError::NotFound)
94        }
95    }
96
97    async fn touch(&self, id: &str) -> Result<(), SessionError> {
98        let mut sessions = self.sessions.write().await;
99        if let Some(session) = sessions.get_mut(id) {
100            session.last_activity_at = SystemTime::now();
101            Ok(())
102        } else {
103            Err(SessionError::NotFound)
104        }
105    }
106
107    async fn cleanup_expired(&self) -> Result<usize, SessionError> {
108        let mut sessions = self.sessions.write().await;
109        let now = SystemTime::now();
110        let expired: Vec<String> = sessions
111            .iter()
112            .filter(|(_, s)| s.expires_at <= now)
113            .map(|(id, _)| id.clone())
114            .collect();
115
116        let count = expired.len();
117        for id in expired {
118            sessions.remove(&id);
119        }
120        Ok(count)
121    }
122
123    async fn list_by_user(
124        &self,
125        user_id: &str,
126        tenant_id: &TenantId,
127    ) -> Result<Vec<SessionRecord>, SessionError> {
128        let sessions = self.sessions.read().await;
129        let now = SystemTime::now();
130        Ok(sessions
131            .values()
132            .filter(|s| s.user_id == user_id && &s.tenant_id == tenant_id && s.expires_at > now)
133            .cloned()
134            .collect())
135    }
136}