tower_sessions_memory_store/
lib.rs

1use std::{collections::HashMap, sync::Arc};
2
3use async_trait::async_trait;
4use time::OffsetDateTime;
5use tokio::sync::Mutex;
6use tower_sessions_core::{
7    session::{Id, Record},
8    session_store, SessionStore,
9};
10
11/// A session store that lives only in memory.
12///
13/// This is useful for testing but not recommended for real applications.
14///
15/// # Examples
16///
17/// ```rust
18/// use tower_sessions::MemoryStore;
19/// MemoryStore::default();
20/// ```
21#[derive(Clone, Debug, Default)]
22pub struct MemoryStore(Arc<Mutex<HashMap<Id, Record>>>);
23
24#[async_trait]
25impl SessionStore for MemoryStore {
26    async fn create(&self, record: &mut Record) -> session_store::Result<()> {
27        let mut store_guard = self.0.lock().await;
28        while store_guard.contains_key(&record.id) {
29            // Session ID collision mitigation.
30            record.id = Id::default();
31        }
32        store_guard.insert(record.id, record.clone());
33        Ok(())
34    }
35
36    async fn save(&self, record: &Record) -> session_store::Result<()> {
37        self.0.lock().await.insert(record.id, record.clone());
38        Ok(())
39    }
40
41    async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>> {
42        Ok(self
43            .0
44            .lock()
45            .await
46            .get(session_id)
47            .filter(|Record { expiry_date, .. }| is_active(*expiry_date))
48            .cloned())
49    }
50
51    async fn delete(&self, session_id: &Id) -> session_store::Result<()> {
52        self.0.lock().await.remove(session_id);
53        Ok(())
54    }
55}
56
57fn is_active(expiry_date: OffsetDateTime) -> bool {
58    expiry_date > OffsetDateTime::now_utc()
59}
60
#[cfg(test)]
mod tests {
    use time::Duration;

    use super::*;

    /// Builds a record with default ID and data that expires at `expiry_date`.
    fn record_expiring_at(expiry_date: OffsetDateTime) -> Record {
        Record {
            id: Default::default(),
            data: Default::default(),
            expiry_date,
        }
    }

    /// Builds a record that remains active for the next 30 minutes.
    fn active_record() -> Record {
        record_expiring_at(OffsetDateTime::now_utc() + Duration::minutes(30))
    }

    #[tokio::test]
    async fn test_create() {
        let store = MemoryStore::default();
        let mut record = active_record();
        assert!(store.create(&mut record).await.is_ok());
        // The record must be retrievable under the ID `create` settled on.
        let loaded_record = store.load(&record.id).await.unwrap();
        assert_eq!(Some(record), loaded_record);
    }

    #[tokio::test]
    async fn test_save() {
        let store = MemoryStore::default();
        let record = active_record();
        assert!(store.save(&record).await.is_ok());
        let loaded_record = store.load(&record.id).await.unwrap();
        assert_eq!(Some(record), loaded_record);
    }

    #[tokio::test]
    async fn test_save_overwrites() {
        let store = MemoryStore::default();
        let mut record = active_record();
        store.create(&mut record).await.unwrap();
        record.expiry_date += Duration::minutes(5);
        store.save(&record).await.unwrap();
        let loaded_record = store.load(&record.id).await.unwrap();
        assert_eq!(Some(record), loaded_record);
    }

    #[tokio::test]
    async fn test_load() {
        let store = MemoryStore::default();
        let mut record = active_record();
        store.create(&mut record).await.unwrap();
        let loaded_record = store.load(&record.id).await.unwrap();
        assert_eq!(Some(record), loaded_record);
    }

    #[tokio::test]
    async fn test_load_missing() {
        let store = MemoryStore::default();
        assert_eq!(None, store.load(&Id::default()).await.unwrap());
    }

    #[tokio::test]
    async fn test_load_expired() {
        let store = MemoryStore::default();
        let record = record_expiring_at(OffsetDateTime::now_utc() - Duration::minutes(30));
        store.save(&record).await.unwrap();
        // Expired records must never be handed back to callers.
        assert_eq!(None, store.load(&record.id).await.unwrap());
    }

    #[tokio::test]
    async fn test_delete() {
        let store = MemoryStore::default();
        let mut record = active_record();
        store.create(&mut record).await.unwrap();
        assert!(store.delete(&record.id).await.is_ok());
        assert_eq!(None, store.load(&record.id).await.unwrap());
    }

    #[tokio::test]
    async fn test_create_id_collision() {
        let store = MemoryStore::default();
        let mut record1 = active_record();
        let mut record2 = active_record();
        store.create(&mut record1).await.unwrap();
        record2.id = record1.id; // Force record2 to collide with record1.
        store.create(&mut record2).await.unwrap();
        assert_ne!(record1.id, record2.id); // IDs should be different.

        // Neither record may have clobbered the other.
        let loaded1 = store.load(&record1.id).await.unwrap();
        let loaded2 = store.load(&record2.id).await.unwrap();
        assert_eq!(Some(record1), loaded1);
        assert_eq!(Some(record2), loaded2);
    }
}
134}