turbomcp_core/
session.rs

1//! Session management for `TurboMCP` applications
2//!
3//! Provides comprehensive session tracking, client management, and request analytics
4//! for MCP servers that need to manage multiple clients and track usage patterns.
5
6use std::collections::{HashMap, VecDeque};
7use std::sync::Arc;
8use std::time::Duration as StdDuration;
9
10use chrono::{DateTime, Duration, Utc};
11use dashmap::DashMap;
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use tokio::time::{Interval, interval};
15
16use crate::context::{ClientIdExtractor, ClientSession, RequestInfo};
17
18/// Configuration for session management
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct SessionConfig {
21    /// Maximum number of sessions to track
22    pub max_sessions: usize,
23    /// Session timeout (inactive sessions will be removed)
24    pub session_timeout: Duration,
25    /// Maximum request history to keep per session
26    pub max_request_history: usize,
27    /// Optional hard cap on requests per individual session
28    pub max_requests_per_session: Option<usize>,
29    /// Cleanup interval for expired sessions
30    pub cleanup_interval: StdDuration,
31    /// Whether to track request analytics
32    pub enable_analytics: bool,
33}
34
35impl Default for SessionConfig {
36    fn default() -> Self {
37        Self {
38            max_sessions: 1000,
39            session_timeout: Duration::hours(24),
40            max_request_history: 1000,
41            max_requests_per_session: None,
42            cleanup_interval: StdDuration::from_secs(300), // 5 minutes
43            enable_analytics: true,
44        }
45    }
46}
47
48/// Session analytics and usage statistics
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct SessionAnalytics {
51    /// Total number of sessions created
52    pub total_sessions: usize,
53    /// Currently active sessions
54    pub active_sessions: usize,
55    /// Total requests processed
56    pub total_requests: usize,
57    /// Total successful requests
58    pub successful_requests: usize,
59    /// Total failed requests
60    pub failed_requests: usize,
61    /// Average session duration
62    pub avg_session_duration: Duration,
63    /// Most active clients (top 10)
64    pub top_clients: Vec<(String, usize)>,
65    /// Most used tools/methods (top 10)
66    pub top_methods: Vec<(String, usize)>,
67    /// Request rate (requests per minute)
68    pub requests_per_minute: f64,
69}
70
71/// Comprehensive session manager for MCP applications
72#[derive(Debug)]
73pub struct SessionManager {
74    /// Configuration
75    config: SessionConfig,
76    /// Active client sessions
77    sessions: Arc<DashMap<String, ClientSession>>,
78    /// Client ID extractor for authentication
79    client_extractor: Arc<ClientIdExtractor>,
80    /// Request history for analytics
81    request_history: Arc<RwLock<VecDeque<RequestInfo>>>,
82    /// Session creation history for analytics
83    session_history: Arc<RwLock<VecDeque<SessionEvent>>>,
84    /// Cleanup timer
85    cleanup_timer: Arc<RwLock<Option<Interval>>>,
86    /// Global statistics
87    stats: Arc<RwLock<SessionStats>>,
88}
89
90/// Internal statistics tracking
91#[derive(Debug, Default)]
92struct SessionStats {
93    total_sessions: usize,
94    total_requests: usize,
95    successful_requests: usize,
96    failed_requests: usize,
97    total_session_duration: Duration,
98}
99
100/// Session lifecycle events
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct SessionEvent {
103    /// Event timestamp
104    pub timestamp: DateTime<Utc>,
105    /// Client ID
106    pub client_id: String,
107    /// Event type
108    pub event_type: SessionEventType,
109    /// Additional metadata
110    pub metadata: HashMap<String, serde_json::Value>,
111}
112
113/// Types of session events
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub enum SessionEventType {
116    /// Session created
117    Created,
118    /// Session authenticated
119    Authenticated,
120    /// Session updated (activity)
121    Updated,
122    /// Session expired
123    Expired,
124    /// Session terminated
125    Terminated,
126}
127
128impl SessionManager {
129    /// Create a new session manager
130    #[must_use]
131    pub fn new(config: SessionConfig) -> Self {
132        Self {
133            config,
134            sessions: Arc::new(DashMap::new()),
135            client_extractor: Arc::new(ClientIdExtractor::new()),
136            request_history: Arc::new(RwLock::new(VecDeque::new())),
137            session_history: Arc::new(RwLock::new(VecDeque::new())),
138            cleanup_timer: Arc::new(RwLock::new(None)),
139            stats: Arc::new(RwLock::new(SessionStats::default())),
140        }
141    }
142
143    /// Start the session manager (begin cleanup task)
144    pub fn start(&self) {
145        let mut timer_guard = self.cleanup_timer.write();
146        if timer_guard.is_none() {
147            *timer_guard = Some(interval(self.config.cleanup_interval));
148        }
149        drop(timer_guard);
150
151        // Start cleanup task
152        let sessions = self.sessions.clone();
153        let config = self.config.clone();
154        let session_history = self.session_history.clone();
155        let stats = self.stats.clone();
156
157        tokio::spawn(async move {
158            let mut timer = interval(config.cleanup_interval);
159            loop {
160                timer.tick().await;
161                Self::cleanup_expired_sessions(&sessions, &config, &session_history, &stats);
162            }
163        });
164    }
165
166    /// Create or get existing session for a client
167    #[must_use]
168    pub fn get_or_create_session(
169        &self,
170        client_id: String,
171        transport_type: String,
172    ) -> ClientSession {
173        self.sessions.get(&client_id).map_or_else(
174            || {
175                // Enforce capacity before inserting a new session
176                self.enforce_capacity();
177
178                let session = ClientSession::new(client_id.clone(), transport_type);
179                self.sessions.insert(client_id.clone(), session.clone());
180
181                // Record session creation
182                let mut stats = self.stats.write();
183                stats.total_sessions += 1;
184                drop(stats);
185
186                self.record_session_event(client_id, SessionEventType::Created, HashMap::new());
187
188                session
189            },
190            |session| session.clone(),
191        )
192    }
193
194    /// Update client activity
195    pub fn update_client_activity(&self, client_id: &str) {
196        if let Some(mut session) = self.sessions.get_mut(client_id) {
197            session.update_activity();
198
199            // Optional: enforce per-session request cap by early termination
200            if let Some(cap) = self.config.max_requests_per_session
201                && session.request_count > cap
202            {
203                // Terminate the session when the cap is exceeded
204                // This is a conservative protection to prevent abusive sessions
205                drop(session);
206                let _ = self.terminate_session(client_id);
207            }
208        }
209    }
210
211    /// Authenticate a client session
212    #[must_use]
213    pub fn authenticate_client(
214        &self,
215        client_id: &str,
216        client_name: Option<String>,
217        token: Option<String>,
218    ) -> bool {
219        if let Some(mut session) = self.sessions.get_mut(client_id) {
220            session.authenticate(client_name.clone());
221
222            if let Some(token) = token {
223                self.client_extractor
224                    .register_token(token, client_id.to_string());
225            }
226
227            let mut metadata = HashMap::new();
228            if let Some(name) = client_name {
229                metadata.insert("client_name".to_string(), serde_json::json!(name));
230            }
231
232            self.record_session_event(
233                client_id.to_string(),
234                SessionEventType::Authenticated,
235                metadata,
236            );
237
238            return true;
239        }
240        false
241    }
242
243    /// Record a request for analytics
244    pub fn record_request(&self, mut request_info: RequestInfo) {
245        if !self.config.enable_analytics {
246            return;
247        }
248
249        // Update session activity
250        self.update_client_activity(&request_info.client_id);
251
252        // Update statistics
253        let mut stats = self.stats.write();
254        stats.total_requests += 1;
255        if request_info.success {
256            stats.successful_requests += 1;
257        } else {
258            stats.failed_requests += 1;
259        }
260        drop(stats);
261
262        // Add to request history
263        let mut history = self.request_history.write();
264        if history.len() >= self.config.max_request_history {
265            history.pop_front();
266        }
267
268        // Sanitize sensitive data before storing
269        request_info.parameters = self.sanitize_parameters(request_info.parameters);
270        history.push_back(request_info);
271    }
272
273    /// Get session analytics
274    #[must_use]
275    pub fn get_analytics(&self) -> SessionAnalytics {
276        let sessions = self.sessions.clone();
277
278        // Calculate active sessions
279        let active_sessions = sessions.len();
280
281        // Calculate average session duration
282        let total_duration = sessions
283            .iter()
284            .map(|entry| entry.session_duration())
285            .reduce(|acc, dur| acc + dur)
286            .unwrap_or_else(Duration::zero);
287
288        let avg_session_duration = if active_sessions > 0 {
289            total_duration / active_sessions as i32
290        } else {
291            Duration::zero()
292        };
293
294        // Calculate top clients by request count
295        let mut client_requests: HashMap<String, usize> = HashMap::new();
296        let mut method_requests: HashMap<String, usize> = HashMap::new();
297
298        let (recent_requests, top_clients, top_methods) = {
299            let history = self.request_history.read();
300            for request in history.iter() {
301                *client_requests
302                    .entry(request.client_id.clone())
303                    .or_insert(0) += 1;
304                *method_requests
305                    .entry(request.method_name.clone())
306                    .or_insert(0) += 1;
307            }
308
309            let mut top_clients: Vec<(String, usize)> = client_requests.into_iter().collect();
310            top_clients.sort_by(|a, b| b.1.cmp(&a.1));
311            top_clients.truncate(10);
312
313            let mut top_methods: Vec<(String, usize)> = method_requests.into_iter().collect();
314            top_methods.sort_by(|a, b| b.1.cmp(&a.1));
315            top_methods.truncate(10);
316
317            // Calculate request rate (requests per minute over last hour)
318            let one_hour_ago = Utc::now() - Duration::hours(1);
319            let recent_requests = history
320                .iter()
321                .filter(|req| req.timestamp > one_hour_ago)
322                .count();
323            drop(history);
324
325            (recent_requests, top_clients, top_methods)
326        };
327        let requests_per_minute = recent_requests as f64 / 60.0;
328
329        let stats = self.stats.read();
330        SessionAnalytics {
331            total_sessions: stats.total_sessions,
332            active_sessions,
333            total_requests: stats.total_requests,
334            successful_requests: stats.successful_requests,
335            failed_requests: stats.failed_requests,
336            avg_session_duration,
337            top_clients,
338            top_methods,
339            requests_per_minute,
340        }
341    }
342
343    /// Get all active sessions
344    #[must_use]
345    pub fn get_active_sessions(&self) -> Vec<ClientSession> {
346        self.sessions
347            .iter()
348            .map(|entry| entry.value().clone())
349            .collect()
350    }
351
352    /// Get session by client ID
353    #[must_use]
354    pub fn get_session(&self, client_id: &str) -> Option<ClientSession> {
355        self.sessions.get(client_id).map(|session| session.clone())
356    }
357
358    /// Get client ID extractor
359    #[must_use]
360    pub fn client_extractor(&self) -> Arc<ClientIdExtractor> {
361        self.client_extractor.clone()
362    }
363
364    /// Terminate a session
365    #[must_use]
366    pub fn terminate_session(&self, client_id: &str) -> bool {
367        if let Some((_, session)) = self.sessions.remove(client_id) {
368            let mut stats = self.stats.write();
369            stats.total_session_duration += session.session_duration();
370            drop(stats);
371
372            self.record_session_event(
373                client_id.to_string(),
374                SessionEventType::Terminated,
375                HashMap::new(),
376            );
377
378            true
379        } else {
380            false
381        }
382    }
383
384    /// Get request history
385    #[must_use]
386    pub fn get_request_history(&self, limit: Option<usize>) -> Vec<RequestInfo> {
387        let history = self.request_history.read();
388        let limit = limit.unwrap_or(100);
389
390        history.iter().rev().take(limit).cloned().collect()
391    }
392
393    /// Get session events
394    #[must_use]
395    pub fn get_session_events(&self, limit: Option<usize>) -> Vec<SessionEvent> {
396        let events = self.session_history.read();
397        let limit = limit.unwrap_or(100);
398
399        events.iter().rev().take(limit).cloned().collect()
400    }
401
402    // Private helper methods
403
404    fn cleanup_expired_sessions(
405        sessions: &Arc<DashMap<String, ClientSession>>,
406        config: &SessionConfig,
407        session_history: &Arc<RwLock<VecDeque<SessionEvent>>>,
408        stats: &Arc<RwLock<SessionStats>>,
409    ) {
410        let cutoff_time = Utc::now() - config.session_timeout;
411        let mut expired_sessions = Vec::new();
412
413        for entry in sessions.iter() {
414            if entry.last_activity < cutoff_time {
415                expired_sessions.push(entry.client_id.clone());
416            }
417        }
418
419        for client_id in expired_sessions {
420            if let Some((_, session)) = sessions.remove(&client_id) {
421                // Update stats
422                let mut stats_guard = stats.write();
423                stats_guard.total_session_duration += session.session_duration();
424                drop(stats_guard);
425
426                // Record event
427                let event = SessionEvent {
428                    timestamp: Utc::now(),
429                    client_id,
430                    event_type: SessionEventType::Expired,
431                    metadata: HashMap::new(),
432                };
433
434                let mut history = session_history.write();
435                if history.len() >= 1000 {
436                    history.pop_front();
437                }
438                history.push_back(event);
439            }
440        }
441    }
442
443    fn record_session_event(
444        &self,
445        client_id: String,
446        event_type: SessionEventType,
447        metadata: HashMap<String, serde_json::Value>,
448    ) {
449        let event = SessionEvent {
450            timestamp: Utc::now(),
451            client_id,
452            event_type,
453            metadata,
454        };
455
456        let mut history = self.session_history.write();
457        if history.len() >= 1000 {
458            history.pop_front();
459        }
460        history.push_back(event);
461    }
462
463    /// Ensure the number of active sessions does not exceed `max_sessions`.
464    /// This uses an LRU-like policy (evict least recently active sessions first).
465    fn enforce_capacity(&self) {
466        let target = self.config.max_sessions;
467        // Fast path
468        if self.sessions.len() < target {
469            return;
470        }
471
472        // Collect sessions sorted by last_activity ascending (least recent first)
473        let mut entries: Vec<_> = self
474            .sessions
475            .iter()
476            .map(|entry| (entry.key().clone(), entry.last_activity))
477            .collect();
478        entries.sort_by_key(|(_, ts)| *ts);
479
480        // Evict until under capacity
481        let mut to_evict = self.sessions.len().saturating_sub(target) + 1; // make room for 1 new
482        for (client_id, _) in entries {
483            if to_evict == 0 {
484                break;
485            }
486            if let Some((_, session)) = self.sessions.remove(&client_id) {
487                let mut stats = self.stats.write();
488                stats.total_session_duration += session.session_duration();
489                drop(stats);
490
491                // Record eviction as termination event
492                let event = SessionEvent {
493                    timestamp: Utc::now(),
494                    client_id: client_id.clone(),
495                    event_type: SessionEventType::Terminated,
496                    metadata: {
497                        let mut m = HashMap::new();
498                        m.insert("reason".to_string(), serde_json::json!("capacity_eviction"));
499                        m
500                    },
501                };
502                {
503                    let mut history = self.session_history.write();
504                    if history.len() >= 1000 {
505                        history.pop_front();
506                    }
507                    history.push_back(event);
508                } // Drop history lock early
509                to_evict = to_evict.saturating_sub(1);
510            }
511        }
512    }
513
514    fn sanitize_parameters(&self, mut params: serde_json::Value) -> serde_json::Value {
515        let _ = self; // Currently unused, may use config in future
516        // Remove or mask sensitive fields
517        if let Some(obj) = params.as_object_mut() {
518            let sensitive_keys = &["password", "token", "api_key", "secret", "auth"];
519            for key in sensitive_keys {
520                if obj.contains_key(*key) {
521                    obj.insert(
522                        (*key).to_string(),
523                        serde_json::Value::String("[REDACTED]".to_string()),
524                    );
525                }
526            }
527        }
528        params
529    }
530}
531
532impl Default for SessionManager {
533    fn default() -> Self {
534        Self::new(SessionConfig::default())
535    }
536}
537
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_session_creation() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.get_or_create_session("client-1".to_string(), "http".to_string());

        assert_eq!(session.client_id, "client-1");
        assert_eq!(session.transport_type, "http");
        assert!(!session.authenticated);

        let analytics = manager.get_analytics();
        assert_eq!(analytics.total_sessions, 1);
        assert_eq!(analytics.active_sessions, 1);
    }

    #[tokio::test]
    async fn test_get_or_create_is_idempotent() {
        let manager = SessionManager::new(SessionConfig::default());

        let first = manager.get_or_create_session("client-1".to_string(), "http".to_string());
        let second = manager.get_or_create_session("client-1".to_string(), "stdio".to_string());

        assert_eq!(first.client_id, second.client_id);
        // The existing session is returned; the second transport type is ignored.
        assert_eq!(second.transport_type, "http");

        let analytics = manager.get_analytics();
        assert_eq!(analytics.total_sessions, 1);
        assert_eq!(analytics.active_sessions, 1);
    }

    #[tokio::test]
    async fn test_capacity_eviction() {
        let config = SessionConfig {
            max_sessions: 2,
            ..SessionConfig::default()
        };
        let manager = SessionManager::new(config);

        for id in ["a", "b", "c"] {
            let _ = manager.get_or_create_session(id.to_string(), "http".to_string());
        }

        let analytics = manager.get_analytics();
        assert_eq!(analytics.active_sessions, 2);
        assert_eq!(analytics.total_sessions, 3);
        // The session being created is never the one evicted.
        assert!(manager.get_session("c").is_some());
    }

    #[tokio::test]
    async fn test_session_authentication() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.get_or_create_session("client-1".to_string(), "http".to_string());
        assert!(!session.authenticated);

        let success = manager.authenticate_client(
            "client-1",
            Some("Test Client".to_string()),
            Some("token123".to_string()),
        );

        assert!(success);

        let updated_session = manager.get_session("client-1").unwrap();
        assert!(updated_session.authenticated);
        assert_eq!(updated_session.client_name, Some("Test Client".to_string()));
    }

    #[tokio::test]
    async fn test_unknown_session_operations() {
        let manager = SessionManager::default();

        assert!(!manager.terminate_session("missing"));
        assert!(!manager.authenticate_client("missing", None, None));
        assert!(manager.get_session("missing").is_none());
        assert!(manager.get_session_events(None).is_empty());
    }

    #[tokio::test]
    async fn test_request_recording() {
        let mut manager = SessionManager::new(SessionConfig::default());
        manager.config.enable_analytics = true;

        let request = RequestInfo::new(
            "client-1".to_string(),
            "test_method".to_string(),
            serde_json::json!({"param": "value"}),
        )
        .complete_success(100);

        manager.record_request(request);

        let analytics = manager.get_analytics();
        assert_eq!(analytics.total_requests, 1);
        assert_eq!(analytics.successful_requests, 1);
        assert_eq!(analytics.failed_requests, 0);

        let history = manager.get_request_history(Some(10));
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].method_name, "test_method");
    }

    #[tokio::test]
    async fn test_request_history_is_bounded() {
        let config = SessionConfig {
            max_request_history: 2,
            ..SessionConfig::default()
        };
        let manager = SessionManager::new(config);

        for method in ["m1", "m2", "m3"] {
            manager.record_request(
                RequestInfo::new(
                    "client-1".to_string(),
                    method.to_string(),
                    serde_json::json!({}),
                )
                .complete_success(1),
            );
        }

        let history = manager.get_request_history(None);
        let methods: Vec<&str> = history.iter().map(|r| r.method_name.as_str()).collect();
        // Oldest entry dropped; newest returned first.
        assert_eq!(methods, vec!["m3", "m2"]);
        assert_eq!(manager.get_analytics().total_requests, 3);
    }

    #[tokio::test]
    async fn test_session_termination() {
        let manager = SessionManager::new(SessionConfig::default());

        let _ = manager.get_or_create_session("client-1".to_string(), "http".to_string());
        assert!(manager.get_session("client-1").is_some());

        let terminated = manager.terminate_session("client-1");
        assert!(terminated);
        assert!(manager.get_session("client-1").is_none());

        let analytics = manager.get_analytics();
        assert_eq!(analytics.active_sessions, 0);
    }

    #[tokio::test]
    async fn test_session_events_are_recorded_newest_first() {
        let manager = SessionManager::default();

        let _ = manager.get_or_create_session("client-1".to_string(), "http".to_string());
        assert!(manager.terminate_session("client-1"));

        let events = manager.get_session_events(None);
        assert_eq!(events.len(), 2);
        assert!(matches!(events[0].event_type, SessionEventType::Terminated));
        assert!(matches!(events[1].event_type, SessionEventType::Created));
        assert!(events.iter().all(|event| event.client_id == "client-1"));
    }

    #[tokio::test]
    async fn test_parameter_sanitization() {
        let manager = SessionManager::new(SessionConfig::default());

        let sensitive_params = serde_json::json!({
            "username": "testuser",
            "password": "secret123",
            "api_key": "key456",
            "data": "normal_data"
        });

        let sanitized = manager.sanitize_parameters(sensitive_params);
        let obj = sanitized.as_object().unwrap();

        assert_eq!(
            obj["username"],
            serde_json::Value::String("testuser".to_string())
        );
        assert_eq!(
            obj["password"],
            serde_json::Value::String("[REDACTED]".to_string())
        );
        assert_eq!(
            obj["api_key"],
            serde_json::Value::String("[REDACTED]".to_string())
        );
        assert_eq!(
            obj["data"],
            serde_json::Value::String("normal_data".to_string())
        );
    }
}