turbomcp_core/
session.rs

1//! Session management for `TurboMCP` applications
2//!
3//! Provides comprehensive session tracking, client management, and request analytics
4//! for MCP servers that need to manage multiple clients and track usage patterns.
5
6use std::collections::{HashMap, VecDeque};
7use std::sync::Arc;
8use std::time::Duration as StdDuration;
9
10use chrono::{DateTime, Duration, Utc};
11use dashmap::DashMap;
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use tokio::time::{Interval, interval};
15
16use crate::context::{
17    ClientIdExtractor, ClientSession, CompletionContext, ElicitationContext, RequestInfo,
18};
19
/// Configuration for session management.
///
/// All limits are read by [`SessionManager`]; the struct itself carries no logic.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
    /// Maximum number of sessions to track. When a new session would exceed this,
    /// the least recently active sessions are evicted first.
    pub max_sessions: usize,
    /// Session timeout: sessions whose last activity is older than this are removed
    /// by the periodic cleanup task.
    pub session_timeout: Duration,
    /// Maximum number of requests retained in the manager's request history.
    /// The history is a single queue shared by all sessions; the oldest entry is
    /// dropped when the limit is reached.
    pub max_request_history: usize,
    /// Optional hard cap on requests per individual session. A session whose
    /// `request_count` exceeds the cap is terminated on its next activity update.
    pub max_requests_per_session: Option<usize>,
    /// Period of the background task that removes expired sessions.
    pub cleanup_interval: StdDuration,
    /// Whether to track request analytics. When `false`, `record_request` returns
    /// without recording anything.
    pub enable_analytics: bool,
}
36
37impl Default for SessionConfig {
38    fn default() -> Self {
39        Self {
40            max_sessions: 1000,
41            session_timeout: Duration::hours(24),
42            max_request_history: 1000,
43            max_requests_per_session: None,
44            cleanup_interval: StdDuration::from_secs(300), // 5 minutes
45            enable_analytics: true,
46        }
47    }
48}
49
/// Session analytics and usage statistics.
///
/// A point-in-time snapshot produced by `SessionManager::get_analytics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionAnalytics {
    /// Total number of sessions created since the manager was constructed
    pub total_sessions: usize,
    /// Currently active sessions
    pub active_sessions: usize,
    /// Total requests processed (only counted while analytics are enabled)
    pub total_requests: usize,
    /// Total successful requests
    pub successful_requests: usize,
    /// Total failed requests
    pub failed_requests: usize,
    /// Average session duration, computed over the currently active sessions
    pub avg_session_duration: Duration,
    /// Most active clients (top 10), as `(client_id, request_count)` pairs counted
    /// over the retained request history
    pub top_clients: Vec<(String, usize)>,
    /// Most used tools/methods (top 10), as `(method_name, request_count)` pairs
    /// counted over the retained request history
    pub top_methods: Vec<(String, usize)>,
    /// Request rate (requests per minute): retained requests from the last hour
    /// divided by 60
    pub requests_per_minute: f64,
}
72
/// Comprehensive session manager for MCP applications.
///
/// Every piece of mutable state sits behind an `Arc` so that the background
/// cleanup task spawned by `start` can share it with the manager.
#[derive(Debug)]
pub struct SessionManager {
    /// Configuration (limits, timeouts, analytics switch)
    config: SessionConfig,
    /// Active client sessions, keyed by client ID
    sessions: Arc<DashMap<String, ClientSession>>,
    /// Client ID extractor for authentication; tokens are registered with it when
    /// a client authenticates with a token
    client_extractor: Arc<ClientIdExtractor>,
    /// Request history for analytics (one queue shared by all clients, oldest first)
    request_history: Arc<RwLock<VecDeque<RequestInfo>>>,
    /// Session lifecycle events for analytics (oldest first, capped at 1000 entries)
    session_history: Arc<RwLock<VecDeque<SessionEvent>>>,
    /// Cleanup timer; `None` until `start` has been called
    cleanup_timer: Arc<RwLock<Option<Interval>>>,
    /// Global statistics (session and request counters)
    stats: Arc<RwLock<SessionStats>>,
    /// Pending elicitations by client ID
    pending_elicitations: Arc<DashMap<String, Vec<ElicitationContext>>>,
    /// Active completions by client ID
    active_completions: Arc<DashMap<String, Vec<CompletionContext>>>,
}
95
/// Internal statistics tracking.
///
/// Counters only grow; they are never reset in the code visible in this file.
#[derive(Debug, Default)]
struct SessionStats {
    /// Number of sessions ever created
    total_sessions: usize,
    /// Number of requests recorded
    total_requests: usize,
    /// Recorded requests flagged as successful
    successful_requests: usize,
    /// Recorded requests flagged as failed
    failed_requests: usize,
    /// Summed duration of sessions that have been removed (terminated, expired or
    /// evicted). NOTE(review): written in several places but not read anywhere in
    /// this file — confirm whether analytics are meant to include it.
    total_session_duration: Duration,
}
105
/// Session lifecycle events.
///
/// One entry of the manager's session history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionEvent {
    /// Event timestamp (UTC, taken when the event is recorded)
    pub timestamp: DateTime<Utc>,
    /// Client ID of the session the event refers to
    pub client_id: String,
    /// Event type
    pub event_type: SessionEventType,
    /// Additional metadata, e.g. `client_name` for authentication events or
    /// `reason` for capacity evictions
    pub metadata: HashMap<String, serde_json::Value>,
}
118
/// Types of session events
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SessionEventType {
    /// Session created
    Created,
    /// Session authenticated
    Authenticated,
    /// Session updated (activity).
    /// NOTE(review): no code in this file emits this variant — confirm it is used elsewhere.
    Updated,
    /// Session expired (removed by the periodic cleanup after the inactivity timeout)
    Expired,
    /// Session terminated, either explicitly or by capacity eviction (the latter
    /// carries `reason = "capacity_eviction"` in the event metadata)
    Terminated,
}
133
134impl SessionManager {
135    /// Create a new session manager
136    #[must_use]
137    pub fn new(config: SessionConfig) -> Self {
138        Self {
139            config,
140            sessions: Arc::new(DashMap::new()),
141            client_extractor: Arc::new(ClientIdExtractor::new()),
142            request_history: Arc::new(RwLock::new(VecDeque::new())),
143            session_history: Arc::new(RwLock::new(VecDeque::new())),
144            cleanup_timer: Arc::new(RwLock::new(None)),
145            stats: Arc::new(RwLock::new(SessionStats::default())),
146            pending_elicitations: Arc::new(DashMap::new()),
147            active_completions: Arc::new(DashMap::new()),
148        }
149    }
150
151    /// Start the session manager (begin cleanup task)
152    pub fn start(&self) {
153        let mut timer_guard = self.cleanup_timer.write();
154        if timer_guard.is_none() {
155            *timer_guard = Some(interval(self.config.cleanup_interval));
156        }
157        drop(timer_guard);
158
159        // Start cleanup task
160        let sessions = self.sessions.clone();
161        let config = self.config.clone();
162        let session_history = self.session_history.clone();
163        let stats = self.stats.clone();
164        let pending_elicitations = self.pending_elicitations.clone();
165        let active_completions = self.active_completions.clone();
166
167        tokio::spawn(async move {
168            let mut timer = interval(config.cleanup_interval);
169            loop {
170                timer.tick().await;
171                Self::cleanup_expired_sessions(
172                    &sessions,
173                    &config,
174                    &session_history,
175                    &stats,
176                    &pending_elicitations,
177                    &active_completions,
178                );
179            }
180        });
181    }
182
183    /// Create or get existing session for a client
184    #[must_use]
185    pub fn get_or_create_session(
186        &self,
187        client_id: String,
188        transport_type: String,
189    ) -> ClientSession {
190        self.sessions.get(&client_id).map_or_else(
191            || {
192                // Enforce capacity before inserting a new session
193                self.enforce_capacity();
194
195                let session = ClientSession::new(client_id.clone(), transport_type);
196                self.sessions.insert(client_id.clone(), session.clone());
197
198                // Record session creation
199                let mut stats = self.stats.write();
200                stats.total_sessions += 1;
201                drop(stats);
202
203                self.record_session_event(client_id, SessionEventType::Created, HashMap::new());
204
205                session
206            },
207            |session| session.clone(),
208        )
209    }
210
211    /// Update client activity
212    pub fn update_client_activity(&self, client_id: &str) {
213        if let Some(mut session) = self.sessions.get_mut(client_id) {
214            session.update_activity();
215
216            // Optional: enforce per-session request cap by early termination
217            if let Some(cap) = self.config.max_requests_per_session
218                && session.request_count > cap
219            {
220                // Terminate the session when the cap is exceeded
221                // This is a conservative protection to prevent abusive sessions
222                drop(session);
223                let _ = self.terminate_session(client_id);
224            }
225        }
226    }
227
228    /// Authenticate a client session
229    #[must_use]
230    pub fn authenticate_client(
231        &self,
232        client_id: &str,
233        client_name: Option<String>,
234        token: Option<String>,
235    ) -> bool {
236        if let Some(mut session) = self.sessions.get_mut(client_id) {
237            session.authenticate(client_name.clone());
238
239            if let Some(token) = token {
240                self.client_extractor
241                    .register_token(token, client_id.to_string());
242            }
243
244            let mut metadata = HashMap::new();
245            if let Some(name) = client_name {
246                metadata.insert("client_name".to_string(), serde_json::json!(name));
247            }
248
249            self.record_session_event(
250                client_id.to_string(),
251                SessionEventType::Authenticated,
252                metadata,
253            );
254
255            return true;
256        }
257        false
258    }
259
260    /// Record a request for analytics
261    pub fn record_request(&self, mut request_info: RequestInfo) {
262        if !self.config.enable_analytics {
263            return;
264        }
265
266        // Update session activity
267        self.update_client_activity(&request_info.client_id);
268
269        // Update statistics
270        let mut stats = self.stats.write();
271        stats.total_requests += 1;
272        if request_info.success {
273            stats.successful_requests += 1;
274        } else {
275            stats.failed_requests += 1;
276        }
277        drop(stats);
278
279        // Add to request history
280        let mut history = self.request_history.write();
281        if history.len() >= self.config.max_request_history {
282            history.pop_front();
283        }
284
285        // Sanitize sensitive data before storing
286        request_info.parameters = self.sanitize_parameters(request_info.parameters);
287        history.push_back(request_info);
288    }
289
290    /// Get session analytics
291    #[must_use]
292    pub fn get_analytics(&self) -> SessionAnalytics {
293        let sessions = self.sessions.clone();
294
295        // Calculate active sessions
296        let active_sessions = sessions.len();
297
298        // Calculate average session duration
299        let total_duration = sessions
300            .iter()
301            .map(|entry| entry.session_duration())
302            .reduce(|acc, dur| acc + dur)
303            .unwrap_or_else(Duration::zero);
304
305        let avg_session_duration = if active_sessions > 0 {
306            total_duration / active_sessions as i32
307        } else {
308            Duration::zero()
309        };
310
311        // Calculate top clients by request count
312        let mut client_requests: HashMap<String, usize> = HashMap::new();
313        let mut method_requests: HashMap<String, usize> = HashMap::new();
314
315        let (recent_requests, top_clients, top_methods) = {
316            let history = self.request_history.read();
317            for request in history.iter() {
318                *client_requests
319                    .entry(request.client_id.clone())
320                    .or_insert(0) += 1;
321                *method_requests
322                    .entry(request.method_name.clone())
323                    .or_insert(0) += 1;
324            }
325
326            let mut top_clients: Vec<(String, usize)> = client_requests.into_iter().collect();
327            top_clients.sort_by(|a, b| b.1.cmp(&a.1));
328            top_clients.truncate(10);
329
330            let mut top_methods: Vec<(String, usize)> = method_requests.into_iter().collect();
331            top_methods.sort_by(|a, b| b.1.cmp(&a.1));
332            top_methods.truncate(10);
333
334            // Calculate request rate (requests per minute over last hour)
335            let one_hour_ago = Utc::now() - Duration::hours(1);
336            let recent_requests = history
337                .iter()
338                .filter(|req| req.timestamp > one_hour_ago)
339                .count();
340            drop(history);
341
342            (recent_requests, top_clients, top_methods)
343        };
344        let requests_per_minute = recent_requests as f64 / 60.0;
345
346        let stats = self.stats.read();
347        SessionAnalytics {
348            total_sessions: stats.total_sessions,
349            active_sessions,
350            total_requests: stats.total_requests,
351            successful_requests: stats.successful_requests,
352            failed_requests: stats.failed_requests,
353            avg_session_duration,
354            top_clients,
355            top_methods,
356            requests_per_minute,
357        }
358    }
359
360    /// Get all active sessions
361    #[must_use]
362    pub fn get_active_sessions(&self) -> Vec<ClientSession> {
363        self.sessions
364            .iter()
365            .map(|entry| entry.value().clone())
366            .collect()
367    }
368
369    /// Get session by client ID
370    #[must_use]
371    pub fn get_session(&self, client_id: &str) -> Option<ClientSession> {
372        self.sessions.get(client_id).map(|session| session.clone())
373    }
374
375    /// Get client ID extractor
376    #[must_use]
377    pub fn client_extractor(&self) -> Arc<ClientIdExtractor> {
378        self.client_extractor.clone()
379    }
380
381    /// Terminate a session
382    #[must_use]
383    pub fn terminate_session(&self, client_id: &str) -> bool {
384        if let Some((_, session)) = self.sessions.remove(client_id) {
385            let mut stats = self.stats.write();
386            stats.total_session_duration += session.session_duration();
387            drop(stats);
388
389            // Clean up associated elicitations and completions
390            self.pending_elicitations.remove(client_id);
391            self.active_completions.remove(client_id);
392
393            self.record_session_event(
394                client_id.to_string(),
395                SessionEventType::Terminated,
396                HashMap::new(),
397            );
398
399            true
400        } else {
401            false
402        }
403    }
404
405    /// Get request history
406    #[must_use]
407    pub fn get_request_history(&self, limit: Option<usize>) -> Vec<RequestInfo> {
408        let history = self.request_history.read();
409        let limit = limit.unwrap_or(100);
410
411        history.iter().rev().take(limit).cloned().collect()
412    }
413
414    /// Get session events
415    #[must_use]
416    pub fn get_session_events(&self, limit: Option<usize>) -> Vec<SessionEvent> {
417        let events = self.session_history.read();
418        let limit = limit.unwrap_or(100);
419
420        events.iter().rev().take(limit).cloned().collect()
421    }
422
423    // Elicitation Management
424
425    /// Add a pending elicitation for a client
426    pub fn add_pending_elicitation(&self, client_id: String, elicitation: ElicitationContext) {
427        self.pending_elicitations
428            .entry(client_id)
429            .or_default()
430            .push(elicitation);
431    }
432
433    /// Get all pending elicitations for a client
434    #[must_use]
435    pub fn get_pending_elicitations(&self, client_id: &str) -> Vec<ElicitationContext> {
436        self.pending_elicitations
437            .get(client_id)
438            .map(|entry| entry.clone())
439            .unwrap_or_default()
440    }
441
442    /// Update an elicitation state
443    pub fn update_elicitation_state(
444        &self,
445        client_id: &str,
446        elicitation_id: &str,
447        state: crate::context::ElicitationState,
448    ) -> bool {
449        if let Some(mut elicitations) = self.pending_elicitations.get_mut(client_id) {
450            for elicitation in elicitations.iter_mut() {
451                if elicitation.elicitation_id == elicitation_id {
452                    elicitation.set_state(state);
453                    return true;
454                }
455            }
456        }
457        false
458    }
459
460    /// Remove completed elicitations for a client
461    pub fn remove_completed_elicitations(&self, client_id: &str) {
462        if let Some(mut elicitations) = self.pending_elicitations.get_mut(client_id) {
463            elicitations.retain(|e| !e.is_complete());
464        }
465    }
466
467    /// Clear all elicitations for a client
468    pub fn clear_elicitations(&self, client_id: &str) {
469        self.pending_elicitations.remove(client_id);
470    }
471
472    // Completion Management
473
474    /// Add an active completion for a client
475    pub fn add_active_completion(&self, client_id: String, completion: CompletionContext) {
476        self.active_completions
477            .entry(client_id)
478            .or_default()
479            .push(completion);
480    }
481
482    /// Get all active completions for a client
483    #[must_use]
484    pub fn get_active_completions(&self, client_id: &str) -> Vec<CompletionContext> {
485        self.active_completions
486            .get(client_id)
487            .map(|entry| entry.clone())
488            .unwrap_or_default()
489    }
490
491    /// Remove a specific completion
492    pub fn remove_completion(&self, client_id: &str, completion_id: &str) -> bool {
493        if let Some(mut completions) = self.active_completions.get_mut(client_id) {
494            let original_len = completions.len();
495            completions.retain(|c| c.completion_id != completion_id);
496            return completions.len() < original_len;
497        }
498        false
499    }
500
501    /// Clear all completions for a client
502    pub fn clear_completions(&self, client_id: &str) {
503        self.active_completions.remove(client_id);
504    }
505
506    /// Get session statistics including elicitations and completions
507    #[must_use]
508    pub fn get_enhanced_analytics(&self) -> SessionAnalytics {
509        let analytics = self.get_analytics();
510
511        // Add elicitation and completion counts
512        let mut _total_elicitations = 0;
513        let mut _pending_elicitations = 0;
514        let mut _total_completions = 0;
515
516        for entry in self.pending_elicitations.iter() {
517            let elicitations = entry.value();
518            _total_elicitations += elicitations.len();
519            _pending_elicitations += elicitations.iter().filter(|e| !e.is_complete()).count();
520        }
521
522        for entry in self.active_completions.iter() {
523            _total_completions += entry.value().len();
524        }
525
526        // Additional analytics can be added to SessionAnalytics struct when extended
527        // Currently tracking: _total_elicitations, _pending_elicitations, _total_completions
528
529        analytics
530    }
531
532    // Private helper methods
533
534    fn cleanup_expired_sessions(
535        sessions: &Arc<DashMap<String, ClientSession>>,
536        config: &SessionConfig,
537        session_history: &Arc<RwLock<VecDeque<SessionEvent>>>,
538        stats: &Arc<RwLock<SessionStats>>,
539        pending_elicitations: &Arc<DashMap<String, Vec<ElicitationContext>>>,
540        active_completions: &Arc<DashMap<String, Vec<CompletionContext>>>,
541    ) {
542        let cutoff_time = Utc::now() - config.session_timeout;
543        let mut expired_sessions = Vec::new();
544
545        for entry in sessions.iter() {
546            if entry.last_activity < cutoff_time {
547                expired_sessions.push(entry.client_id.clone());
548            }
549        }
550
551        for client_id in expired_sessions {
552            if let Some((_, session)) = sessions.remove(&client_id) {
553                // Update stats
554                let mut stats_guard = stats.write();
555                stats_guard.total_session_duration += session.session_duration();
556                drop(stats_guard);
557
558                // Clean up associated elicitations and completions
559                pending_elicitations.remove(&client_id);
560                active_completions.remove(&client_id);
561
562                // Record event
563                let event = SessionEvent {
564                    timestamp: Utc::now(),
565                    client_id,
566                    event_type: SessionEventType::Expired,
567                    metadata: HashMap::new(),
568                };
569
570                let mut history = session_history.write();
571                if history.len() >= 1000 {
572                    history.pop_front();
573                }
574                history.push_back(event);
575            }
576        }
577    }
578
579    fn record_session_event(
580        &self,
581        client_id: String,
582        event_type: SessionEventType,
583        metadata: HashMap<String, serde_json::Value>,
584    ) {
585        let event = SessionEvent {
586            timestamp: Utc::now(),
587            client_id,
588            event_type,
589            metadata,
590        };
591
592        let mut history = self.session_history.write();
593        if history.len() >= 1000 {
594            history.pop_front();
595        }
596        history.push_back(event);
597    }
598
599    /// Ensure the number of active sessions does not exceed `max_sessions`.
600    /// This uses an LRU-like policy (evict least recently active sessions first).
601    fn enforce_capacity(&self) {
602        let target = self.config.max_sessions;
603        // Fast path
604        if self.sessions.len() < target {
605            return;
606        }
607
608        // Collect sessions sorted by last_activity ascending (least recent first)
609        let mut entries: Vec<_> = self
610            .sessions
611            .iter()
612            .map(|entry| (entry.key().clone(), entry.last_activity))
613            .collect();
614        entries.sort_by_key(|(_, ts)| *ts);
615
616        // Evict until under capacity
617        let mut to_evict = self.sessions.len().saturating_sub(target) + 1; // make room for 1 new
618        for (client_id, _) in entries {
619            if to_evict == 0 {
620                break;
621            }
622            if let Some((_, session)) = self.sessions.remove(&client_id) {
623                let mut stats = self.stats.write();
624                stats.total_session_duration += session.session_duration();
625                drop(stats);
626
627                // Record eviction as termination event
628                let event = SessionEvent {
629                    timestamp: Utc::now(),
630                    client_id: client_id.clone(),
631                    event_type: SessionEventType::Terminated,
632                    metadata: {
633                        let mut m = HashMap::new();
634                        m.insert("reason".to_string(), serde_json::json!("capacity_eviction"));
635                        m
636                    },
637                };
638                {
639                    let mut history = self.session_history.write();
640                    if history.len() >= 1000 {
641                        history.pop_front();
642                    }
643                    history.push_back(event);
644                } // Drop history lock early
645                to_evict = to_evict.saturating_sub(1);
646            }
647        }
648    }
649
650    fn sanitize_parameters(&self, mut params: serde_json::Value) -> serde_json::Value {
651        let _ = self; // Currently unused, may use config in future
652        // Remove or mask sensitive fields
653        if let Some(obj) = params.as_object_mut() {
654            let sensitive_keys = &["password", "token", "api_key", "secret", "auth"];
655            for key in sensitive_keys {
656                if obj.contains_key(*key) {
657                    obj.insert(
658                        (*key).to_string(),
659                        serde_json::Value::String("[REDACTED]".to_string()),
660                    );
661                }
662            }
663        }
664        params
665    }
666}
667
668impl Default for SessionManager {
669    fn default() -> Self {
670        Self::new(SessionConfig::default())
671    }
672}
673
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager with the default configuration, shared by the tests below.
    fn default_manager() -> SessionManager {
        SessionManager::new(SessionConfig::default())
    }

    /// JSON string value, used for comparing sanitized parameters.
    fn json_string(text: &str) -> serde_json::Value {
        serde_json::Value::String(text.to_string())
    }

    /// A first lookup creates an unauthenticated session and counts it.
    #[tokio::test]
    async fn test_session_creation() {
        let manager = default_manager();

        let session = manager.get_or_create_session("client-1".to_string(), "http".to_string());

        assert_eq!(session.client_id, "client-1");
        assert_eq!(session.transport_type, "http");
        assert!(!session.authenticated);

        let analytics = manager.get_analytics();
        assert_eq!(analytics.total_sessions, 1);
        assert_eq!(analytics.active_sessions, 1);
    }

    /// Authenticating an existing session flags it and stores the client name.
    #[tokio::test]
    async fn test_session_authentication() {
        let manager = default_manager();

        let session = manager.get_or_create_session("client-1".to_string(), "http".to_string());
        assert!(!session.authenticated);

        let client_name = Some("Test Client".to_string());
        let token = Some("token123".to_string());
        let success = manager.authenticate_client("client-1", client_name, token);

        assert!(success);

        let updated_session = manager.get_session("client-1").unwrap();
        assert!(updated_session.authenticated);
        assert_eq!(updated_session.client_name, Some("Test Client".to_string()));
    }

    /// A recorded successful request shows up in the counters and the history.
    #[tokio::test]
    async fn test_request_recording() {
        let mut manager = default_manager();
        manager.config.enable_analytics = true;

        let parameters = serde_json::json!({"param": "value"});
        let request = RequestInfo::new(
            "client-1".to_string(),
            "test_method".to_string(),
            parameters,
        )
        .complete_success(100);

        manager.record_request(request);

        let analytics = manager.get_analytics();
        assert_eq!(analytics.total_requests, 1);
        assert_eq!(analytics.successful_requests, 1);
        assert_eq!(analytics.failed_requests, 0);

        let history = manager.get_request_history(Some(10));
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].method_name, "test_method");
    }

    /// Terminating a session removes it from the active set.
    #[tokio::test]
    async fn test_session_termination() {
        let manager = default_manager();

        let _ = manager.get_or_create_session("client-1".to_string(), "http".to_string());
        assert!(manager.get_session("client-1").is_some());

        let terminated = manager.terminate_session("client-1");
        assert!(terminated);
        assert!(manager.get_session("client-1").is_none());

        let analytics = manager.get_analytics();
        assert_eq!(analytics.active_sessions, 0);
    }

    /// Sensitive keys are masked while ordinary keys pass through untouched.
    #[tokio::test]
    async fn test_parameter_sanitization() {
        let manager = default_manager();

        let sensitive_params = serde_json::json!({
            "username": "testuser",
            "password": "secret123",
            "api_key": "key456",
            "data": "normal_data"
        });

        let sanitized = manager.sanitize_parameters(sensitive_params);
        let obj = sanitized.as_object().unwrap();

        assert_eq!(obj["username"], json_string("testuser"));
        assert_eq!(obj["password"], json_string("[REDACTED]"));
        assert_eq!(obj["api_key"], json_string("[REDACTED]"));
        assert_eq!(obj["data"], json_string("normal_data"));
    }
}
783}