use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use time::OffsetDateTime;
use tokio::sync::Mutex;
use tower_sessions_core::{
session::{Id, Record},
session_store, SessionStore,
};
#[derive(Clone, Debug, Default)]
pub struct MemoryStore(Arc<Mutex<HashMap<Id, Record>>>);
#[async_trait]
impl SessionStore for MemoryStore {
async fn create(&self, record: &mut Record) -> session_store::Result<()> {
let mut store_guard = self.0.lock().await;
while store_guard.contains_key(&record.id) {
record.id = Id::default();
}
store_guard.insert(record.id, record.clone());
Ok(())
}
async fn save(&self, record: &Record) -> session_store::Result<()> {
self.0.lock().await.insert(record.id, record.clone());
Ok(())
}
async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>> {
Ok(self
.0
.lock()
.await
.get(session_id)
.filter(|Record { expiry_date, .. }| is_active(*expiry_date))
.cloned())
}
async fn delete(&self, session_id: &Id) -> session_store::Result<()> {
self.0.lock().await.remove(session_id);
Ok(())
}
}
fn is_active(expiry_date: OffsetDateTime) -> bool {
expiry_date > OffsetDateTime::now_utc()
}
#[cfg(test)]
mod tests {
use time::Duration;
use super::*;
#[tokio::test]
async fn test_create() {
let store = MemoryStore::default();
let mut record = Record {
id: Default::default(),
data: Default::default(),
expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30),
};
assert!(store.create(&mut record).await.is_ok());
}
#[tokio::test]
async fn test_save() {
let store = MemoryStore::default();
let record = Record {
id: Default::default(),
data: Default::default(),
expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30),
};
assert!(store.save(&record).await.is_ok());
}
#[tokio::test]
async fn test_load() {
let store = MemoryStore::default();
let mut record = Record {
id: Default::default(),
data: Default::default(),
expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30),
};
store.create(&mut record).await.unwrap();
let loaded_record = store.load(&record.id).await.unwrap();
assert_eq!(Some(record), loaded_record);
}
#[tokio::test]
async fn test_delete() {
let store = MemoryStore::default();
let mut record = Record {
id: Default::default(),
data: Default::default(),
expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30),
};
store.create(&mut record).await.unwrap();
assert!(store.delete(&record.id).await.is_ok());
assert_eq!(None, store.load(&record.id).await.unwrap());
}
#[tokio::test]
async fn test_create_id_collision() {
let store = MemoryStore::default();
let expiry_date = OffsetDateTime::now_utc() + Duration::minutes(30);
let mut record1 = Record {
id: Default::default(),
data: Default::default(),
expiry_date,
};
let mut record2 = Record {
id: Default::default(),
data: Default::default(),
expiry_date,
};
store.create(&mut record1).await.unwrap();
record2.id = record1.id; store.create(&mut record2).await.unwrap();
assert_ne!(record1.id, record2.id); }
}