// turul_mcp_session_storage/in_memory.rs

//! In-Memory Session Storage Implementation
//!
//! This implementation keeps all session data in process memory, guarded by
//! `Arc<RwLock<...>>` for thread safety. Nothing is persisted, so all data is
//! lost when the process exits. Suitable for:
//! - Development and testing
//! - Single-instance deployments where session persistence is not required
//! - High-performance scenarios where sessions are short-lived
8
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::SystemTime;
13
14use async_trait::async_trait;
15use tokio::sync::RwLock;
16use tracing::{debug, info};
17
18use crate::{SessionInfo, SessionStorage, SessionStorageError, SseEvent};
19use turul_mcp_protocol::ServerCapabilities;
20
21/// In-memory storage for sessions and events (SSE compliant)
22#[derive(Debug, Clone)]
23pub struct InMemorySessionStorage {
24    /// All sessions by session ID
25    sessions: Arc<RwLock<HashMap<String, SessionInfo>>>,
26    /// All events by session_id -> Vec\<SseEvent\>
27    events: Arc<RwLock<HashMap<String, Vec<SseEvent>>>>,
28    /// Global event ID counter for ordering
29    event_counter: Arc<AtomicU64>,
30    /// Configuration
31    config: InMemoryConfig,
32}
33
/// Capacity limits for the in-memory session storage.
///
/// Use `InMemoryConfig::default()` for the stock limits, or override selected
/// fields with struct-update syntax, e.g.
/// `InMemoryConfig { max_sessions: 10, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InMemoryConfig {
    /// Maximum number of events kept per session (bounds memory use).
    pub max_events_per_session: usize,
    /// Maximum number of sessions held at once (bounds memory use).
    pub max_sessions: usize,
}
42
43impl Default for InMemoryConfig {
44    fn default() -> Self {
45        Self {
46            max_events_per_session: 10_000, // 10k events per session
47            max_sessions: 100_000,          // 100k concurrent sessions
48        }
49    }
50}
51
52/// Error type for in-memory storage operations
53#[derive(Debug, thiserror::Error)]
54pub enum InMemoryError {
55    #[error("Session not found: {0}")]
56    SessionNotFound(String),
57    #[error("Maximum sessions limit reached: {0}")]
58    MaxSessionsReached(usize),
59    #[error("Maximum events per session limit reached: {0}")]
60    MaxEventsReached(usize),
61    #[error("Serialization error: {0}")]
62    SerializationError(#[from] serde_json::Error),
63}
64
65impl Default for InMemorySessionStorage {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl InMemorySessionStorage {
72    /// Create new in-memory session storage with default configuration
73    pub fn new() -> Self {
74        Self::with_config(InMemoryConfig::default())
75    }
76
77    /// Create new in-memory session storage with custom configuration
78    pub fn with_config(config: InMemoryConfig) -> Self {
79        Self {
80            sessions: Arc::new(RwLock::new(HashMap::new())),
81            events: Arc::new(RwLock::new(HashMap::new())),
82            event_counter: Arc::new(AtomicU64::new(1)), // Start at 1 for SSE compatibility
83            config,
84        }
85    }
86
87    /// Get current statistics
88    pub async fn stats(&self) -> InMemoryStats {
89        let sessions = self.sessions.read().await;
90        let events = self.events.read().await;
91
92        let total_events = events.values().map(|v| v.len()).sum();
93
94        InMemoryStats {
95            session_count: sessions.len(),
96            total_event_count: total_events,
97            max_events_per_session: self.config.max_events_per_session,
98            max_sessions: self.config.max_sessions,
99        }
100    }
101
102    /// Cleanup old events to prevent memory bloat
103    async fn cleanup_events(&self) -> Result<u64, InMemoryError> {
104        let mut events = self.events.write().await;
105        let mut total_removed = 0u64;
106
107        for (session_id, event_list) in events.iter_mut() {
108            if event_list.len() > self.config.max_events_per_session {
109                let excess = event_list.len() - self.config.max_events_per_session;
110                event_list.drain(0..excess); // Remove oldest events
111                total_removed += excess as u64;
112                debug!(
113                    "Cleaned up {} old events for session {}",
114                    excess, session_id
115                );
116            }
117        }
118
119        if total_removed > 0 {
120            info!(
121                "Cleaned up {} old events across all sessions",
122                total_removed
123            );
124        }
125
126        Ok(total_removed)
127    }
128}
129
/// Point-in-time snapshot of in-memory storage usage and its configured limits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InMemoryStats {
    /// Number of sessions currently stored.
    pub session_count: usize,
    /// Total number of events currently stored across all sessions.
    pub total_event_count: usize,
    /// Configured per-session event limit.
    pub max_events_per_session: usize,
    /// Configured session limit.
    pub max_sessions: usize,
}
138
139#[async_trait]
140impl SessionStorage for InMemorySessionStorage {
141    type Error = SessionStorageError;
142
143    fn backend_name(&self) -> &'static str {
144        "InMemory"
145    }
146
147    // ============================================================================
148    // Session Management
149    // ============================================================================
150
151    async fn create_session(
152        &self,
153        capabilities: ServerCapabilities,
154    ) -> Result<SessionInfo, Self::Error> {
155        let mut sessions = self.sessions.write().await;
156
157        if sessions.len() >= self.config.max_sessions {
158            return Err(SessionStorageError::MaxSessionsReached(
159                self.config.max_sessions,
160            ));
161        }
162
163        let mut session = SessionInfo::new();
164        session.server_capabilities = Some(capabilities);
165
166        let session_id = session.session_id.clone();
167        sessions.insert(session_id.clone(), session.clone());
168
169        debug!("Created session: {}", session_id);
170        Ok(session)
171    }
172
173    async fn create_session_with_id(
174        &self,
175        session_id: String,
176        capabilities: ServerCapabilities,
177    ) -> Result<SessionInfo, Self::Error> {
178        let mut sessions = self.sessions.write().await;
179
180        if sessions.len() >= self.config.max_sessions {
181            return Err(SessionStorageError::MaxSessionsReached(
182                self.config.max_sessions,
183            ));
184        }
185
186        let mut session = SessionInfo::with_id(session_id.clone());
187        session.server_capabilities = Some(capabilities);
188
189        sessions.insert(session_id.clone(), session.clone());
190
191        debug!("Created session with ID: {}", session_id);
192        Ok(session)
193    }
194
195    async fn get_session(&self, session_id: &str) -> Result<Option<SessionInfo>, Self::Error> {
196        let sessions = self.sessions.read().await;
197        Ok(sessions.get(session_id).cloned())
198    }
199
200    async fn update_session(&self, session_info: SessionInfo) -> Result<(), Self::Error> {
201        let mut sessions = self.sessions.write().await;
202        sessions.insert(session_info.session_id.clone(), session_info);
203        Ok(())
204    }
205
206    async fn set_session_state(
207        &self,
208        session_id: &str,
209        key: &str,
210        value: serde_json::Value,
211    ) -> Result<(), Self::Error> {
212        let mut sessions = self.sessions.write().await;
213
214        if let Some(session) = sessions.get_mut(session_id) {
215            session.state.insert(key.to_string(), value);
216            session.touch(); // Update last activity
217            Ok(())
218        } else {
219            Err(SessionStorageError::SessionNotFound(session_id.to_string()))
220        }
221    }
222
223    async fn get_session_state(
224        &self,
225        session_id: &str,
226        key: &str,
227    ) -> Result<Option<serde_json::Value>, Self::Error> {
228        let sessions = self.sessions.read().await;
229
230        if let Some(session) = sessions.get(session_id) {
231            Ok(session.state.get(key).cloned())
232        } else {
233            Err(SessionStorageError::SessionNotFound(session_id.to_string()))
234        }
235    }
236
237    async fn remove_session_state(
238        &self,
239        session_id: &str,
240        key: &str,
241    ) -> Result<Option<serde_json::Value>, Self::Error> {
242        let mut sessions = self.sessions.write().await;
243
244        if let Some(session) = sessions.get_mut(session_id) {
245            let removed = session.state.remove(key);
246            session.touch(); // Update last activity
247            Ok(removed)
248        } else {
249            Err(SessionStorageError::SessionNotFound(session_id.to_string()))
250        }
251    }
252
253    async fn delete_session(&self, session_id: &str) -> Result<bool, Self::Error> {
254        let mut sessions = self.sessions.write().await;
255        let mut events = self.events.write().await;
256
257        // Remove the session
258        let removed = sessions.remove(session_id).is_some();
259
260        if removed {
261            // Remove all events for this session
262            events.remove(session_id);
263
264            debug!("Deleted session and all associated data: {}", session_id);
265        }
266
267        Ok(removed)
268    }
269
270    async fn list_sessions(&self) -> Result<Vec<String>, Self::Error> {
271        let sessions = self.sessions.read().await;
272        Ok(sessions.keys().cloned().collect())
273    }
274
275    // ============================================================================
276    // Event Management
277    // ============================================================================
278
279    async fn store_event(
280        &self,
281        session_id: &str,
282        mut event: SseEvent,
283    ) -> Result<SseEvent, Self::Error> {
284        let mut events = self.events.write().await;
285
286        // Assign unique event ID
287        event.id = self.event_counter.fetch_add(1, Ordering::SeqCst);
288
289        let event_list = events
290            .entry(session_id.to_string())
291            .or_insert_with(Vec::new);
292
293        // Check event limit
294        if event_list.len() >= self.config.max_events_per_session {
295            return Err(SessionStorageError::MaxEventsReached(
296                self.config.max_events_per_session,
297            ));
298        }
299
300        event_list.push(event.clone());
301
302        debug!(
303            "Stored event: session={}, event_id={}",
304            session_id, event.id
305        );
306        Ok(event)
307    }
308
309    async fn get_events_after(
310        &self,
311        session_id: &str,
312        after_event_id: u64,
313    ) -> Result<Vec<SseEvent>, Self::Error> {
314        let events = self.events.read().await;
315
316        if let Some(event_list) = events.get(session_id) {
317            let filtered: Vec<SseEvent> = event_list
318                .iter()
319                .filter(|event| event.id > after_event_id)
320                .cloned()
321                .collect();
322            Ok(filtered)
323        } else {
324            Ok(Vec::new())
325        }
326    }
327
328    async fn get_recent_events(
329        &self,
330        session_id: &str,
331        limit: usize,
332    ) -> Result<Vec<SseEvent>, Self::Error> {
333        let events = self.events.read().await;
334
335        if let Some(event_list) = events.get(session_id) {
336            let recent: Vec<SseEvent> =
337                event_list.iter().rev().take(limit).rev().cloned().collect();
338            Ok(recent)
339        } else {
340            Ok(Vec::new())
341        }
342    }
343
344    async fn delete_events_before(
345        &self,
346        session_id: &str,
347        before_event_id: u64,
348    ) -> Result<u64, Self::Error> {
349        let mut events = self.events.write().await;
350
351        if let Some(event_list) = events.get_mut(session_id) {
352            let original_len = event_list.len();
353            event_list.retain(|event| event.id >= before_event_id);
354            let removed = original_len - event_list.len();
355            Ok(removed as u64)
356        } else {
357            Ok(0)
358        }
359    }
360
361    // ============================================================================
362    // Cleanup and Maintenance
363    // ============================================================================
364
365    async fn expire_sessions(&self, older_than: SystemTime) -> Result<Vec<String>, Self::Error> {
366        let mut sessions = self.sessions.write().await;
367        let mut events = self.events.write().await;
368
369        let cutoff_millis = older_than
370            .duration_since(SystemTime::UNIX_EPOCH)
371            .unwrap_or_default()
372            .as_millis() as u64;
373
374        let mut expired_sessions = Vec::new();
375
376        // Find expired sessions
377        sessions.retain(|session_id, session_info| {
378            if session_info.last_activity < cutoff_millis {
379                expired_sessions.push(session_id.clone());
380                false
381            } else {
382                true
383            }
384        });
385
386        // Remove events for expired sessions
387        for session_id in &expired_sessions {
388            events.remove(session_id);
389        }
390
391        if !expired_sessions.is_empty() {
392            info!("Expired {} sessions", expired_sessions.len());
393        }
394
395        Ok(expired_sessions)
396    }
397
398    async fn session_count(&self) -> Result<usize, Self::Error> {
399        let sessions = self.sessions.read().await;
400        Ok(sessions.len())
401    }
402
403    async fn event_count(&self) -> Result<usize, Self::Error> {
404        let events = self.events.read().await;
405        let total = events.values().map(|v| v.len()).sum();
406        Ok(total)
407    }
408
409    async fn maintenance(&self) -> Result<(), Self::Error> {
410        self.cleanup_events().await?;
411        Ok(())
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_protocol::ServerCapabilities;

    /// Fresh store containing exactly one default-capability session; returns
    /// the store together with that session's ID.
    async fn storage_with_session() -> (InMemorySessionStorage, String) {
        let storage = InMemorySessionStorage::new();
        let id = storage
            .create_session(ServerCapabilities::default())
            .await
            .expect("creating a session in an empty store should succeed")
            .session_id;
        (storage, id)
    }

    #[tokio::test]
    async fn test_session_lifecycle() {
        let (storage, id) = storage_with_session().await;

        // The session is retrievable under its own ID.
        let fetched = storage.get_session(&id).await.expect("get_session failed");
        assert_eq!(fetched.map(|s| s.session_id), Some(id.clone()));

        // Deleting reports success and the session disappears.
        assert!(storage.delete_session(&id).await.expect("delete_session failed"));
        let after_delete = storage.get_session(&id).await.expect("get_session failed");
        assert!(after_delete.is_none());
    }

    #[tokio::test]
    async fn test_session_state() {
        let (storage, id) = storage_with_session().await;
        let payload = serde_json::json!({"test": "value"});

        storage
            .set_session_state(&id, "test_key", payload.clone())
            .await
            .expect("set_session_state failed");

        let read_back = storage
            .get_session_state(&id, "test_key")
            .await
            .expect("get_session_state failed");
        assert_eq!(read_back, Some(payload));

        let taken = storage
            .remove_session_state(&id, "test_key")
            .await
            .expect("remove_session_state failed");
        assert_eq!(taken, Some(serde_json::json!({"test": "value"})));

        let gone = storage
            .get_session_state(&id, "test_key")
            .await
            .expect("get_session_state failed");
        assert_eq!(gone, None);
    }

    #[tokio::test]
    async fn test_event_storage_and_retrieval() {
        let (storage, id) = storage_with_session().await;

        let first = storage
            .store_event(
                &id,
                SseEvent::new("data".to_string(), serde_json::json!({"message": "test1"})),
            )
            .await
            .expect("storing first event failed");
        let second = storage
            .store_event(
                &id,
                SseEvent::new("data".to_string(), serde_json::json!({"message": "test2"})),
            )
            .await
            .expect("storing second event failed");

        // IDs are handed out in increasing order.
        assert!(first.id < second.id);

        // Only the later event follows the first one.
        let newer = storage
            .get_events_after(&id, first.id)
            .await
            .expect("get_events_after failed");
        assert_eq!(newer.len(), 1);
        assert_eq!(newer[0].id, second.id);

        // A generous limit returns both events.
        let recent = storage
            .get_recent_events(&id, 10)
            .await
            .expect("get_recent_events failed");
        assert_eq!(recent.len(), 2);
    }
}
513}