1use async_trait::async_trait;
6use dashmap::DashMap;
7use serde::{Serialize, de::DeserializeOwned};
8use std::{sync::Arc, time::Duration};
9use tokio::time::Instant;
10
11#[async_trait]
16pub trait SessionStore: Clone + Send + Sync + 'static {
17 async fn get(&self, session_id: &str) -> Option<String>;
19
20 async fn set(&self, session_id: &str, data: &str, ttl: Duration);
22
23 async fn remove(&self, session_id: &str);
25
26 async fn exists(&self, session_id: &str) -> bool;
28
29 async fn refresh(&self, session_id: &str, ttl: Duration);
31
32 async fn get_typed<T: DeserializeOwned + Send + Sync>(&self, session_id: &str) -> Option<T> {
34 let data = self.get(session_id).await?;
35 serde_json::from_str(&data).ok()
36 }
37
38 async fn set_typed<T: Serialize + Send + Sync>(&self, session_id: &str, data: &T, ttl: Duration) -> bool {
40 match serde_json::to_string(data) {
41 Ok(json) => {
42 self.set(session_id, &json, ttl).await;
43 true
44 }
45 Err(_) => false,
46 }
47 }
48}
49
/// One stored session: the raw payload plus the deadline after which it is stale.
#[derive(Clone, Debug)]
struct SessionEntry {
    /// Serialized session payload exactly as it was handed to the store.
    data: String,
    /// Expiry deadline; callers treat the entry as live only while `now < expires_at`.
    expires_at: Instant,
}
58
59#[derive(Debug, Clone)]
64pub struct MemorySessionStore {
65 storage: Arc<DashMap<String, SessionEntry>>,
67 cleanup_interval: Duration,
69 #[allow(dead_code)]
71 auto_cleanup: bool,
72}
73
74impl MemorySessionStore {
75 pub fn new() -> Self {
77 Self::with_config(Duration::from_secs(60), true)
78 }
79
80 pub fn with_config(cleanup_interval: Duration, auto_cleanup: bool) -> Self {
82 let store = Self { storage: Arc::new(DashMap::new()), cleanup_interval, auto_cleanup };
83
84 if auto_cleanup {
85 store.start_cleanup_task();
86 }
87
88 store
89 }
90
91 fn start_cleanup_task(&self) {
93 let storage = Arc::clone(&self.storage);
94 let interval = self.cleanup_interval;
95
96 tokio::spawn(async move {
97 loop {
98 tokio::time::sleep(interval).await;
99 let now = Instant::now();
100
101 storage.retain(|_, entry| entry.expires_at > now);
102 }
103 });
104 }
105
106 pub fn cleanup_expired(&self) {
108 let now = Instant::now();
109 self.storage.retain(|_, entry| entry.expires_at > now);
110 }
111
112 pub fn len(&self) -> usize {
114 self.storage.len()
115 }
116
117 pub fn is_empty(&self) -> bool {
119 self.storage.is_empty()
120 }
121
122 pub fn clear(&self) {
124 self.storage.clear();
125 }
126}
127
128impl Default for MemorySessionStore {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134#[async_trait]
135impl SessionStore for MemorySessionStore {
136 async fn get(&self, session_id: &str) -> Option<String> {
137 self.storage.get(session_id).filter(|entry| entry.expires_at > Instant::now()).map(|entry| entry.data.clone())
138 }
139
140 async fn set(&self, session_id: &str, data: &str, ttl: Duration) {
141 let entry = SessionEntry { data: data.to_string(), expires_at: Instant::now() + ttl };
142 self.storage.insert(session_id.to_string(), entry);
143 }
144
145 async fn remove(&self, session_id: &str) {
146 self.storage.remove(session_id);
147 }
148
149 async fn exists(&self, session_id: &str) -> bool {
150 self.storage.get(session_id).map(|entry| entry.expires_at > Instant::now()).unwrap_or(false)
151 }
152
153 async fn refresh(&self, session_id: &str, ttl: Duration) {
154 if let Some(mut entry) = self.storage.get_mut(session_id) {
155 entry.expires_at = Instant::now() + ttl;
156 }
157 }
158}