1use std::collections::{HashMap, VecDeque};
7use std::sync::Arc;
8use std::time::Duration as StdDuration;
9
10use chrono::{DateTime, Duration, Utc};
11use dashmap::DashMap;
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use tokio::time::{Interval, interval};
15
16use crate::context::{
17 ClientIdExtractor, ClientSession, CompletionContext, ElicitationContext, RequestInfo,
18};
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct SessionConfig {
23 pub max_sessions: usize,
25 pub session_timeout: Duration,
27 pub max_request_history: usize,
29 pub max_requests_per_session: Option<usize>,
31 pub cleanup_interval: StdDuration,
33 pub enable_analytics: bool,
35}
36
37impl Default for SessionConfig {
38 fn default() -> Self {
39 Self {
40 max_sessions: 1000,
41 session_timeout: Duration::hours(24),
42 max_request_history: 1000,
43 max_requests_per_session: None,
44 cleanup_interval: StdDuration::from_secs(300), enable_analytics: true,
46 }
47 }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct SessionAnalytics {
53 pub total_sessions: usize,
55 pub active_sessions: usize,
57 pub total_requests: usize,
59 pub successful_requests: usize,
61 pub failed_requests: usize,
63 pub avg_session_duration: Duration,
65 pub top_clients: Vec<(String, usize)>,
67 pub top_methods: Vec<(String, usize)>,
69 pub requests_per_minute: f64,
71}
72
73#[derive(Debug)]
75pub struct SessionManager {
76 config: SessionConfig,
78 sessions: Arc<DashMap<String, ClientSession>>,
80 client_extractor: Arc<ClientIdExtractor>,
82 request_history: Arc<RwLock<VecDeque<RequestInfo>>>,
84 session_history: Arc<RwLock<VecDeque<SessionEvent>>>,
86 cleanup_timer: Arc<RwLock<Option<Interval>>>,
88 stats: Arc<RwLock<SessionStats>>,
90 pending_elicitations: Arc<DashMap<String, Vec<ElicitationContext>>>,
92 active_completions: Arc<DashMap<String, Vec<CompletionContext>>>,
94}
95
96#[derive(Debug, Default)]
98struct SessionStats {
99 total_sessions: usize,
100 total_requests: usize,
101 successful_requests: usize,
102 failed_requests: usize,
103 total_session_duration: Duration,
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct SessionEvent {
109 pub timestamp: DateTime<Utc>,
111 pub client_id: String,
113 pub event_type: SessionEventType,
115 pub metadata: HashMap<String, serde_json::Value>,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
121pub enum SessionEventType {
122 Created,
124 Authenticated,
126 Updated,
128 Expired,
130 Terminated,
132}
133
134impl SessionManager {
135 #[must_use]
137 pub fn new(config: SessionConfig) -> Self {
138 Self {
139 config,
140 sessions: Arc::new(DashMap::new()),
141 client_extractor: Arc::new(ClientIdExtractor::new()),
142 request_history: Arc::new(RwLock::new(VecDeque::new())),
143 session_history: Arc::new(RwLock::new(VecDeque::new())),
144 cleanup_timer: Arc::new(RwLock::new(None)),
145 stats: Arc::new(RwLock::new(SessionStats::default())),
146 pending_elicitations: Arc::new(DashMap::new()),
147 active_completions: Arc::new(DashMap::new()),
148 }
149 }
150
151 pub fn start(&self) {
153 let mut timer_guard = self.cleanup_timer.write();
154 if timer_guard.is_none() {
155 *timer_guard = Some(interval(self.config.cleanup_interval));
156 }
157 drop(timer_guard);
158
159 let sessions = self.sessions.clone();
161 let config = self.config.clone();
162 let session_history = self.session_history.clone();
163 let stats = self.stats.clone();
164 let pending_elicitations = self.pending_elicitations.clone();
165 let active_completions = self.active_completions.clone();
166
167 tokio::spawn(async move {
168 let mut timer = interval(config.cleanup_interval);
169 loop {
170 timer.tick().await;
171 Self::cleanup_expired_sessions(
172 &sessions,
173 &config,
174 &session_history,
175 &stats,
176 &pending_elicitations,
177 &active_completions,
178 );
179 }
180 });
181 }
182
183 #[must_use]
185 pub fn get_or_create_session(
186 &self,
187 client_id: String,
188 transport_type: String,
189 ) -> ClientSession {
190 self.sessions.get(&client_id).map_or_else(
191 || {
192 self.enforce_capacity();
194
195 let session = ClientSession::new(client_id.clone(), transport_type);
196 self.sessions.insert(client_id.clone(), session.clone());
197
198 let mut stats = self.stats.write();
200 stats.total_sessions += 1;
201 drop(stats);
202
203 self.record_session_event(client_id, SessionEventType::Created, HashMap::new());
204
205 session
206 },
207 |session| session.clone(),
208 )
209 }
210
211 pub fn update_client_activity(&self, client_id: &str) {
213 if let Some(mut session) = self.sessions.get_mut(client_id) {
214 session.update_activity();
215
216 if let Some(cap) = self.config.max_requests_per_session
218 && session.request_count > cap
219 {
220 drop(session);
223 let _ = self.terminate_session(client_id);
224 }
225 }
226 }
227
228 #[must_use]
230 pub fn authenticate_client(
231 &self,
232 client_id: &str,
233 client_name: Option<String>,
234 token: Option<String>,
235 ) -> bool {
236 if let Some(mut session) = self.sessions.get_mut(client_id) {
237 session.authenticate(client_name.clone());
238
239 if let Some(token) = token {
240 self.client_extractor
241 .register_token(token, client_id.to_string());
242 }
243
244 let mut metadata = HashMap::new();
245 if let Some(name) = client_name {
246 metadata.insert("client_name".to_string(), serde_json::json!(name));
247 }
248
249 self.record_session_event(
250 client_id.to_string(),
251 SessionEventType::Authenticated,
252 metadata,
253 );
254
255 return true;
256 }
257 false
258 }
259
260 pub fn record_request(&self, mut request_info: RequestInfo) {
262 if !self.config.enable_analytics {
263 return;
264 }
265
266 self.update_client_activity(&request_info.client_id);
268
269 let mut stats = self.stats.write();
271 stats.total_requests += 1;
272 if request_info.success {
273 stats.successful_requests += 1;
274 } else {
275 stats.failed_requests += 1;
276 }
277 drop(stats);
278
279 let mut history = self.request_history.write();
281 if history.len() >= self.config.max_request_history {
282 history.pop_front();
283 }
284
285 request_info.parameters = self.sanitize_parameters(request_info.parameters);
287 history.push_back(request_info);
288 }
289
290 #[must_use]
292 pub fn get_analytics(&self) -> SessionAnalytics {
293 let sessions = self.sessions.clone();
294
295 let active_sessions = sessions.len();
297
298 let total_duration = sessions
300 .iter()
301 .map(|entry| entry.session_duration())
302 .reduce(|acc, dur| acc + dur)
303 .unwrap_or_else(Duration::zero);
304
305 let avg_session_duration = if active_sessions > 0 {
306 total_duration / active_sessions as i32
307 } else {
308 Duration::zero()
309 };
310
311 let mut client_requests: HashMap<String, usize> = HashMap::new();
313 let mut method_requests: HashMap<String, usize> = HashMap::new();
314
315 let (recent_requests, top_clients, top_methods) = {
316 let history = self.request_history.read();
317 for request in history.iter() {
318 *client_requests
319 .entry(request.client_id.clone())
320 .or_insert(0) += 1;
321 *method_requests
322 .entry(request.method_name.clone())
323 .or_insert(0) += 1;
324 }
325
326 let mut top_clients: Vec<(String, usize)> = client_requests.into_iter().collect();
327 top_clients.sort_by(|a, b| b.1.cmp(&a.1));
328 top_clients.truncate(10);
329
330 let mut top_methods: Vec<(String, usize)> = method_requests.into_iter().collect();
331 top_methods.sort_by(|a, b| b.1.cmp(&a.1));
332 top_methods.truncate(10);
333
334 let one_hour_ago = Utc::now() - Duration::hours(1);
336 let recent_requests = history
337 .iter()
338 .filter(|req| req.timestamp > one_hour_ago)
339 .count();
340 drop(history);
341
342 (recent_requests, top_clients, top_methods)
343 };
344 let requests_per_minute = recent_requests as f64 / 60.0;
345
346 let stats = self.stats.read();
347 SessionAnalytics {
348 total_sessions: stats.total_sessions,
349 active_sessions,
350 total_requests: stats.total_requests,
351 successful_requests: stats.successful_requests,
352 failed_requests: stats.failed_requests,
353 avg_session_duration,
354 top_clients,
355 top_methods,
356 requests_per_minute,
357 }
358 }
359
360 #[must_use]
362 pub fn get_active_sessions(&self) -> Vec<ClientSession> {
363 self.sessions
364 .iter()
365 .map(|entry| entry.value().clone())
366 .collect()
367 }
368
369 #[must_use]
371 pub fn get_session(&self, client_id: &str) -> Option<ClientSession> {
372 self.sessions.get(client_id).map(|session| session.clone())
373 }
374
375 #[must_use]
377 pub fn client_extractor(&self) -> Arc<ClientIdExtractor> {
378 self.client_extractor.clone()
379 }
380
381 #[must_use]
383 pub fn terminate_session(&self, client_id: &str) -> bool {
384 if let Some((_, session)) = self.sessions.remove(client_id) {
385 let mut stats = self.stats.write();
386 stats.total_session_duration += session.session_duration();
387 drop(stats);
388
389 self.pending_elicitations.remove(client_id);
391 self.active_completions.remove(client_id);
392
393 self.record_session_event(
394 client_id.to_string(),
395 SessionEventType::Terminated,
396 HashMap::new(),
397 );
398
399 true
400 } else {
401 false
402 }
403 }
404
405 #[must_use]
407 pub fn get_request_history(&self, limit: Option<usize>) -> Vec<RequestInfo> {
408 let history = self.request_history.read();
409 let limit = limit.unwrap_or(100);
410
411 history.iter().rev().take(limit).cloned().collect()
412 }
413
414 #[must_use]
416 pub fn get_session_events(&self, limit: Option<usize>) -> Vec<SessionEvent> {
417 let events = self.session_history.read();
418 let limit = limit.unwrap_or(100);
419
420 events.iter().rev().take(limit).cloned().collect()
421 }
422
423 pub fn add_pending_elicitation(&self, client_id: String, elicitation: ElicitationContext) {
427 self.pending_elicitations
428 .entry(client_id)
429 .or_default()
430 .push(elicitation);
431 }
432
433 #[must_use]
435 pub fn get_pending_elicitations(&self, client_id: &str) -> Vec<ElicitationContext> {
436 self.pending_elicitations
437 .get(client_id)
438 .map(|entry| entry.clone())
439 .unwrap_or_default()
440 }
441
442 pub fn update_elicitation_state(
444 &self,
445 client_id: &str,
446 elicitation_id: &str,
447 state: crate::context::ElicitationState,
448 ) -> bool {
449 if let Some(mut elicitations) = self.pending_elicitations.get_mut(client_id) {
450 for elicitation in elicitations.iter_mut() {
451 if elicitation.elicitation_id == elicitation_id {
452 elicitation.set_state(state);
453 return true;
454 }
455 }
456 }
457 false
458 }
459
460 pub fn remove_completed_elicitations(&self, client_id: &str) {
462 if let Some(mut elicitations) = self.pending_elicitations.get_mut(client_id) {
463 elicitations.retain(|e| !e.is_complete());
464 }
465 }
466
467 pub fn clear_elicitations(&self, client_id: &str) {
469 self.pending_elicitations.remove(client_id);
470 }
471
472 pub fn add_active_completion(&self, client_id: String, completion: CompletionContext) {
476 self.active_completions
477 .entry(client_id)
478 .or_default()
479 .push(completion);
480 }
481
482 #[must_use]
484 pub fn get_active_completions(&self, client_id: &str) -> Vec<CompletionContext> {
485 self.active_completions
486 .get(client_id)
487 .map(|entry| entry.clone())
488 .unwrap_or_default()
489 }
490
491 pub fn remove_completion(&self, client_id: &str, completion_id: &str) -> bool {
493 if let Some(mut completions) = self.active_completions.get_mut(client_id) {
494 let original_len = completions.len();
495 completions.retain(|c| c.completion_id != completion_id);
496 return completions.len() < original_len;
497 }
498 false
499 }
500
501 pub fn clear_completions(&self, client_id: &str) {
503 self.active_completions.remove(client_id);
504 }
505
506 #[must_use]
508 pub fn get_enhanced_analytics(&self) -> SessionAnalytics {
509 let analytics = self.get_analytics();
510
511 let mut _total_elicitations = 0;
513 let mut _pending_elicitations = 0;
514 let mut _total_completions = 0;
515
516 for entry in self.pending_elicitations.iter() {
517 let elicitations = entry.value();
518 _total_elicitations += elicitations.len();
519 _pending_elicitations += elicitations.iter().filter(|e| !e.is_complete()).count();
520 }
521
522 for entry in self.active_completions.iter() {
523 _total_completions += entry.value().len();
524 }
525
526 analytics
530 }
531
532 fn cleanup_expired_sessions(
535 sessions: &Arc<DashMap<String, ClientSession>>,
536 config: &SessionConfig,
537 session_history: &Arc<RwLock<VecDeque<SessionEvent>>>,
538 stats: &Arc<RwLock<SessionStats>>,
539 pending_elicitations: &Arc<DashMap<String, Vec<ElicitationContext>>>,
540 active_completions: &Arc<DashMap<String, Vec<CompletionContext>>>,
541 ) {
542 let cutoff_time = Utc::now() - config.session_timeout;
543 let mut expired_sessions = Vec::new();
544
545 for entry in sessions.iter() {
546 if entry.last_activity < cutoff_time {
547 expired_sessions.push(entry.client_id.clone());
548 }
549 }
550
551 for client_id in expired_sessions {
552 if let Some((_, session)) = sessions.remove(&client_id) {
553 let mut stats_guard = stats.write();
555 stats_guard.total_session_duration += session.session_duration();
556 drop(stats_guard);
557
558 pending_elicitations.remove(&client_id);
560 active_completions.remove(&client_id);
561
562 let event = SessionEvent {
564 timestamp: Utc::now(),
565 client_id,
566 event_type: SessionEventType::Expired,
567 metadata: HashMap::new(),
568 };
569
570 let mut history = session_history.write();
571 if history.len() >= 1000 {
572 history.pop_front();
573 }
574 history.push_back(event);
575 }
576 }
577 }
578
579 fn record_session_event(
580 &self,
581 client_id: String,
582 event_type: SessionEventType,
583 metadata: HashMap<String, serde_json::Value>,
584 ) {
585 let event = SessionEvent {
586 timestamp: Utc::now(),
587 client_id,
588 event_type,
589 metadata,
590 };
591
592 let mut history = self.session_history.write();
593 if history.len() >= 1000 {
594 history.pop_front();
595 }
596 history.push_back(event);
597 }
598
599 fn enforce_capacity(&self) {
602 let target = self.config.max_sessions;
603 if self.sessions.len() < target {
605 return;
606 }
607
608 let mut entries: Vec<_> = self
610 .sessions
611 .iter()
612 .map(|entry| (entry.key().clone(), entry.last_activity))
613 .collect();
614 entries.sort_by_key(|(_, ts)| *ts);
615
616 let mut to_evict = self.sessions.len().saturating_sub(target) + 1; for (client_id, _) in entries {
619 if to_evict == 0 {
620 break;
621 }
622 if let Some((_, session)) = self.sessions.remove(&client_id) {
623 let mut stats = self.stats.write();
624 stats.total_session_duration += session.session_duration();
625 drop(stats);
626
627 let event = SessionEvent {
629 timestamp: Utc::now(),
630 client_id: client_id.clone(),
631 event_type: SessionEventType::Terminated,
632 metadata: {
633 let mut m = HashMap::new();
634 m.insert("reason".to_string(), serde_json::json!("capacity_eviction"));
635 m
636 },
637 };
638 {
639 let mut history = self.session_history.write();
640 if history.len() >= 1000 {
641 history.pop_front();
642 }
643 history.push_back(event);
644 } to_evict = to_evict.saturating_sub(1);
646 }
647 }
648 }
649
650 fn sanitize_parameters(&self, mut params: serde_json::Value) -> serde_json::Value {
651 let _ = self; if let Some(obj) = params.as_object_mut() {
654 let sensitive_keys = &["password", "token", "api_key", "secret", "auth"];
655 for key in sensitive_keys {
656 if obj.contains_key(*key) {
657 obj.insert(
658 (*key).to_string(),
659 serde_json::Value::String("[REDACTED]".to_string()),
660 );
661 }
662 }
663 }
664 params
665 }
666}
667
668impl Default for SessionManager {
669 fn default() -> Self {
670 Self::new(SessionConfig::default())
671 }
672}
673
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager with the default configuration.
    fn default_manager() -> SessionManager {
        SessionManager::new(SessionConfig::default())
    }

    /// Shorthand for an owned `String`.
    fn s(value: &str) -> String {
        value.to_string()
    }

    #[tokio::test]
    async fn test_session_creation() {
        let manager = default_manager();

        let session = manager.get_or_create_session(s("client-1"), s("http"));

        assert_eq!(session.client_id, "client-1");
        assert_eq!(session.transport_type, "http");
        assert!(!session.authenticated);

        let analytics = manager.get_analytics();
        assert_eq!(analytics.total_sessions, 1);
        assert_eq!(analytics.active_sessions, 1);
    }

    #[tokio::test]
    async fn test_get_or_create_is_idempotent() {
        let manager = default_manager();

        let first = manager.get_or_create_session(s("client-1"), s("http"));
        let second = manager.get_or_create_session(s("client-1"), s("stdio"));

        assert_eq!(first.client_id, second.client_id);
        // The existing session is returned untouched.
        assert_eq!(second.transport_type, "http");

        let analytics = manager.get_analytics();
        assert_eq!(analytics.total_sessions, 1);
        assert_eq!(analytics.active_sessions, 1);
    }

    #[tokio::test]
    async fn test_session_authentication() {
        let manager = default_manager();

        let session = manager.get_or_create_session(s("client-1"), s("http"));
        assert!(!session.authenticated);

        let success = manager.authenticate_client(
            "client-1",
            Some(s("Test Client")),
            Some(s("token123")),
        );
        assert!(success);

        let updated_session = manager.get_session("client-1").unwrap();
        assert!(updated_session.authenticated);
        assert_eq!(updated_session.client_name, Some(s("Test Client")));

        assert!(!manager.authenticate_client("unknown", None, None));
    }

    #[tokio::test]
    async fn test_request_recording() {
        let manager = SessionManager::new(SessionConfig {
            enable_analytics: true,
            ..SessionConfig::default()
        });

        let request = RequestInfo::new(
            s("client-1"),
            s("test_method"),
            serde_json::json!({"param": "value"}),
        )
        .complete_success(100);

        manager.record_request(request);

        let analytics = manager.get_analytics();
        assert_eq!(analytics.total_requests, 1);
        assert_eq!(analytics.successful_requests, 1);
        assert_eq!(analytics.failed_requests, 0);

        let history = manager.get_request_history(Some(10));
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].method_name, "test_method");
    }

    #[tokio::test]
    async fn test_session_termination() {
        let manager = default_manager();

        let _ = manager.get_or_create_session(s("client-1"), s("http"));
        assert!(manager.get_session("client-1").is_some());

        assert!(manager.terminate_session("client-1"));
        assert!(manager.get_session("client-1").is_none());
        assert!(!manager.terminate_session("client-1"));

        assert_eq!(manager.get_analytics().active_sessions, 0);
    }

    #[tokio::test]
    async fn test_session_events_newest_first() {
        let manager = default_manager();

        let _ = manager.get_or_create_session(s("client-1"), s("http"));
        assert!(manager.terminate_session("client-1"));

        let events = manager.get_session_events(None);
        assert_eq!(events.len(), 2);
        assert!(matches!(events[0].event_type, SessionEventType::Terminated));
        assert!(matches!(events[1].event_type, SessionEventType::Created));
        assert!(events.iter().all(|e| e.client_id == "client-1"));
    }

    #[tokio::test]
    async fn test_capacity_eviction_bounds_session_count() {
        let manager = SessionManager::new(SessionConfig {
            max_sessions: 2,
            ..SessionConfig::default()
        });

        for id in ["a", "b", "c"] {
            let _ = manager.get_or_create_session(s(id), s("http"));
        }

        assert_eq!(manager.get_active_sessions().len(), 2);
        assert!(manager.get_session("c").is_some());
        assert_eq!(manager.get_analytics().total_sessions, 3);
    }

    #[tokio::test]
    async fn test_parameter_sanitization() {
        let manager = default_manager();

        let sensitive_params = serde_json::json!({
            "username": "testuser",
            "password": "secret123",
            "api_key": "key456",
            "data": "normal_data"
        });

        let sanitized = manager.sanitize_parameters(sensitive_params);
        let obj = sanitized.as_object().unwrap();
        let redacted = serde_json::Value::String(s("[REDACTED]"));

        assert_eq!(obj["username"], serde_json::Value::String(s("testuser")));
        assert_eq!(obj["password"], redacted);
        assert_eq!(obj["api_key"], redacted);
        assert_eq!(obj["data"], serde_json::Value::String(s("normal_data")));
    }
}