turul_mcp_session_storage/
in_memory.rs1use std::collections::HashMap;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::SystemTime;
13
14use async_trait::async_trait;
15use tokio::sync::RwLock;
16use tracing::{debug, info};
17
18use crate::{SessionInfo, SessionStorage, SessionStorageError, SseEvent};
19use turul_mcp_protocol::ServerCapabilities;
20
21#[derive(Debug, Clone)]
23pub struct InMemorySessionStorage {
24 sessions: Arc<RwLock<HashMap<String, SessionInfo>>>,
26 events: Arc<RwLock<HashMap<String, Vec<SseEvent>>>>,
28 event_counter: Arc<AtomicU64>,
30 config: InMemoryConfig,
32}
33
/// Capacity limits for the in-memory session storage backend.
#[derive(Debug, Clone)]
pub struct InMemoryConfig {
    /// Maximum number of events kept in a single session's event log.
    pub max_events_per_session: usize,
    /// Maximum number of sessions that may be stored at once.
    pub max_sessions: usize,
}

impl Default for InMemoryConfig {
    /// Defaults: 10,000 events per session and 100,000 sessions.
    fn default() -> Self {
        let max_events_per_session = 10_000;
        let max_sessions = 100_000;
        Self {
            max_events_per_session,
            max_sessions,
        }
    }
}
51
52#[derive(Debug, thiserror::Error)]
54pub enum InMemoryError {
55 #[error("Session not found: {0}")]
56 SessionNotFound(String),
57 #[error("Maximum sessions limit reached: {0}")]
58 MaxSessionsReached(usize),
59 #[error("Maximum events per session limit reached: {0}")]
60 MaxEventsReached(usize),
61 #[error("Serialization error: {0}")]
62 SerializationError(#[from] serde_json::Error),
63}
64
65impl Default for InMemorySessionStorage {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl InMemorySessionStorage {
72 pub fn new() -> Self {
74 Self::with_config(InMemoryConfig::default())
75 }
76
77 pub fn with_config(config: InMemoryConfig) -> Self {
79 Self {
80 sessions: Arc::new(RwLock::new(HashMap::new())),
81 events: Arc::new(RwLock::new(HashMap::new())),
82 event_counter: Arc::new(AtomicU64::new(1)), config,
84 }
85 }
86
87 pub async fn stats(&self) -> InMemoryStats {
89 let sessions = self.sessions.read().await;
90 let events = self.events.read().await;
91
92 let total_events = events.values().map(|v| v.len()).sum();
93
94 InMemoryStats {
95 session_count: sessions.len(),
96 total_event_count: total_events,
97 max_events_per_session: self.config.max_events_per_session,
98 max_sessions: self.config.max_sessions,
99 }
100 }
101
102 async fn cleanup_events(&self) -> Result<u64, InMemoryError> {
104 let mut events = self.events.write().await;
105 let mut total_removed = 0u64;
106
107 for (session_id, event_list) in events.iter_mut() {
108 if event_list.len() > self.config.max_events_per_session {
109 let excess = event_list.len() - self.config.max_events_per_session;
110 event_list.drain(0..excess); total_removed += excess as u64;
112 debug!(
113 "Cleaned up {} old events for session {}",
114 excess, session_id
115 );
116 }
117 }
118
119 if total_removed > 0 {
120 info!(
121 "Cleaned up {} old events across all sessions",
122 total_removed
123 );
124 }
125
126 Ok(total_removed)
127 }
128}
129
/// Point-in-time snapshot of the in-memory storage's size and limits.
#[derive(Debug, Clone)]
pub struct InMemoryStats {
    /// Number of sessions currently stored.
    pub session_count: usize,
    /// Total number of events across all session event logs.
    pub total_event_count: usize,
    /// Configured per-session event limit.
    pub max_events_per_session: usize,
    /// Configured session limit.
    pub max_sessions: usize,
}
138
139#[async_trait]
140impl SessionStorage for InMemorySessionStorage {
141 type Error = SessionStorageError;
142
143 fn backend_name(&self) -> &'static str {
144 "InMemory"
145 }
146
147 async fn create_session(
152 &self,
153 capabilities: ServerCapabilities,
154 ) -> Result<SessionInfo, Self::Error> {
155 let mut sessions = self.sessions.write().await;
156
157 if sessions.len() >= self.config.max_sessions {
158 return Err(SessionStorageError::MaxSessionsReached(
159 self.config.max_sessions,
160 ));
161 }
162
163 let mut session = SessionInfo::new();
164 session.server_capabilities = Some(capabilities);
165
166 let session_id = session.session_id.clone();
167 sessions.insert(session_id.clone(), session.clone());
168
169 debug!("Created session: {}", session_id);
170 Ok(session)
171 }
172
173 async fn create_session_with_id(
174 &self,
175 session_id: String,
176 capabilities: ServerCapabilities,
177 ) -> Result<SessionInfo, Self::Error> {
178 let mut sessions = self.sessions.write().await;
179
180 if sessions.len() >= self.config.max_sessions {
181 return Err(SessionStorageError::MaxSessionsReached(
182 self.config.max_sessions,
183 ));
184 }
185
186 let mut session = SessionInfo::with_id(session_id.clone());
187 session.server_capabilities = Some(capabilities);
188
189 sessions.insert(session_id.clone(), session.clone());
190
191 debug!("Created session with ID: {}", session_id);
192 Ok(session)
193 }
194
195 async fn get_session(&self, session_id: &str) -> Result<Option<SessionInfo>, Self::Error> {
196 let sessions = self.sessions.read().await;
197 Ok(sessions.get(session_id).cloned())
198 }
199
200 async fn update_session(&self, session_info: SessionInfo) -> Result<(), Self::Error> {
201 let mut sessions = self.sessions.write().await;
202 sessions.insert(session_info.session_id.clone(), session_info);
203 Ok(())
204 }
205
206 async fn set_session_state(
207 &self,
208 session_id: &str,
209 key: &str,
210 value: serde_json::Value,
211 ) -> Result<(), Self::Error> {
212 let mut sessions = self.sessions.write().await;
213
214 if let Some(session) = sessions.get_mut(session_id) {
215 session.state.insert(key.to_string(), value);
216 session.touch(); Ok(())
218 } else {
219 Err(SessionStorageError::SessionNotFound(session_id.to_string()))
220 }
221 }
222
223 async fn get_session_state(
224 &self,
225 session_id: &str,
226 key: &str,
227 ) -> Result<Option<serde_json::Value>, Self::Error> {
228 let sessions = self.sessions.read().await;
229
230 if let Some(session) = sessions.get(session_id) {
231 Ok(session.state.get(key).cloned())
232 } else {
233 Err(SessionStorageError::SessionNotFound(session_id.to_string()))
234 }
235 }
236
237 async fn remove_session_state(
238 &self,
239 session_id: &str,
240 key: &str,
241 ) -> Result<Option<serde_json::Value>, Self::Error> {
242 let mut sessions = self.sessions.write().await;
243
244 if let Some(session) = sessions.get_mut(session_id) {
245 let removed = session.state.remove(key);
246 session.touch(); Ok(removed)
248 } else {
249 Err(SessionStorageError::SessionNotFound(session_id.to_string()))
250 }
251 }
252
253 async fn delete_session(&self, session_id: &str) -> Result<bool, Self::Error> {
254 let mut sessions = self.sessions.write().await;
255 let mut events = self.events.write().await;
256
257 let removed = sessions.remove(session_id).is_some();
259
260 if removed {
261 events.remove(session_id);
263
264 debug!("Deleted session and all associated data: {}", session_id);
265 }
266
267 Ok(removed)
268 }
269
270 async fn list_sessions(&self) -> Result<Vec<String>, Self::Error> {
271 let sessions = self.sessions.read().await;
272 Ok(sessions.keys().cloned().collect())
273 }
274
275 async fn store_event(
280 &self,
281 session_id: &str,
282 mut event: SseEvent,
283 ) -> Result<SseEvent, Self::Error> {
284 let mut events = self.events.write().await;
285
286 event.id = self.event_counter.fetch_add(1, Ordering::SeqCst);
288
289 let event_list = events
290 .entry(session_id.to_string())
291 .or_insert_with(Vec::new);
292
293 if event_list.len() >= self.config.max_events_per_session {
295 return Err(SessionStorageError::MaxEventsReached(
296 self.config.max_events_per_session,
297 ));
298 }
299
300 event_list.push(event.clone());
301
302 debug!(
303 "Stored event: session={}, event_id={}",
304 session_id, event.id
305 );
306 Ok(event)
307 }
308
309 async fn get_events_after(
310 &self,
311 session_id: &str,
312 after_event_id: u64,
313 ) -> Result<Vec<SseEvent>, Self::Error> {
314 let events = self.events.read().await;
315
316 if let Some(event_list) = events.get(session_id) {
317 let filtered: Vec<SseEvent> = event_list
318 .iter()
319 .filter(|event| event.id > after_event_id)
320 .cloned()
321 .collect();
322 Ok(filtered)
323 } else {
324 Ok(Vec::new())
325 }
326 }
327
328 async fn get_recent_events(
329 &self,
330 session_id: &str,
331 limit: usize,
332 ) -> Result<Vec<SseEvent>, Self::Error> {
333 let events = self.events.read().await;
334
335 if let Some(event_list) = events.get(session_id) {
336 let recent: Vec<SseEvent> =
337 event_list.iter().rev().take(limit).rev().cloned().collect();
338 Ok(recent)
339 } else {
340 Ok(Vec::new())
341 }
342 }
343
344 async fn delete_events_before(
345 &self,
346 session_id: &str,
347 before_event_id: u64,
348 ) -> Result<u64, Self::Error> {
349 let mut events = self.events.write().await;
350
351 if let Some(event_list) = events.get_mut(session_id) {
352 let original_len = event_list.len();
353 event_list.retain(|event| event.id >= before_event_id);
354 let removed = original_len - event_list.len();
355 Ok(removed as u64)
356 } else {
357 Ok(0)
358 }
359 }
360
361 async fn expire_sessions(&self, older_than: SystemTime) -> Result<Vec<String>, Self::Error> {
366 let mut sessions = self.sessions.write().await;
367 let mut events = self.events.write().await;
368
369 let cutoff_millis = older_than
370 .duration_since(SystemTime::UNIX_EPOCH)
371 .unwrap_or_default()
372 .as_millis() as u64;
373
374 let mut expired_sessions = Vec::new();
375
376 sessions.retain(|session_id, session_info| {
378 if session_info.last_activity < cutoff_millis {
379 expired_sessions.push(session_id.clone());
380 false
381 } else {
382 true
383 }
384 });
385
386 for session_id in &expired_sessions {
388 events.remove(session_id);
389 }
390
391 if !expired_sessions.is_empty() {
392 info!("Expired {} sessions", expired_sessions.len());
393 }
394
395 Ok(expired_sessions)
396 }
397
398 async fn session_count(&self) -> Result<usize, Self::Error> {
399 let sessions = self.sessions.read().await;
400 Ok(sessions.len())
401 }
402
403 async fn event_count(&self) -> Result<usize, Self::Error> {
404 let events = self.events.read().await;
405 let total = events.values().map(|v| v.len()).sum();
406 Ok(total)
407 }
408
409 async fn maintenance(&self) -> Result<(), Self::Error> {
410 self.cleanup_events().await?;
411 Ok(())
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_protocol::ServerCapabilities;

    /// Builds a default storage holding one freshly created session and
    /// returns the storage together with that session's ID.
    async fn storage_with_session() -> (InMemorySessionStorage, String) {
        let storage = InMemorySessionStorage::new();
        let created = storage
            .create_session(ServerCapabilities::default())
            .await
            .unwrap();
        let id = created.session_id.clone();
        (storage, id)
    }

    #[tokio::test]
    async fn test_session_lifecycle() {
        let (storage, session_id) = storage_with_session().await;

        // A freshly created session can be fetched by its ID.
        let fetched = storage
            .get_session(&session_id)
            .await
            .unwrap()
            .expect("session should exist right after creation");
        assert_eq!(fetched.session_id, session_id);

        // Deleting reports success, after which the session is gone.
        assert!(storage.delete_session(&session_id).await.unwrap());
        assert!(storage.get_session(&session_id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_session_state() {
        let (storage, session_id) = storage_with_session().await;
        let key = "test_key";
        let payload = serde_json::json!({"test": "value"});

        storage
            .set_session_state(&session_id, key, payload.clone())
            .await
            .unwrap();
        let fetched = storage.get_session_state(&session_id, key).await.unwrap();
        assert_eq!(fetched, Some(payload.clone()));

        let removed = storage
            .remove_session_state(&session_id, key)
            .await
            .unwrap();
        assert_eq!(removed, Some(payload));

        let after_removal = storage.get_session_state(&session_id, key).await.unwrap();
        assert_eq!(after_removal, None);
    }

    #[tokio::test]
    async fn test_event_storage_and_retrieval() {
        let (storage, session_id) = storage_with_session().await;

        let make_event = |text: &str| {
            SseEvent::new("data".to_string(), serde_json::json!({ "message": text }))
        };
        let first = storage
            .store_event(&session_id, make_event("test1"))
            .await
            .unwrap();
        let second = storage
            .store_event(&session_id, make_event("test2"))
            .await
            .unwrap();

        // IDs grow with insertion order.
        assert!(first.id < second.id);

        // Only the second event follows the first one's ID.
        let after_first = storage
            .get_events_after(&session_id, first.id)
            .await
            .unwrap();
        assert_eq!(after_first.len(), 1);
        assert_eq!(after_first[0].id, second.id);

        // A limit larger than the log returns everything.
        let recent = storage.get_recent_events(&session_id, 10).await.unwrap();
        assert_eq!(recent.len(), 2);
    }
}