turbomcp_client/llm/
session.rs

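//! Session and conversation management.
//!
//! Stores per-session conversation history, applies context strategies that decide
//! which stored messages form the active context, and tracks many sessions at once
//! through [`SessionManager`].
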
5use crate::llm::core::{LLMError, LLMMessage, LLMResult};
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use uuid::Uuid;
10
11/// Configuration for session management
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct SessionConfig {
14    /// Maximum conversation history length
15    pub max_history_length: usize,
16    /// Default context strategy
17    pub default_context_strategy: ContextStrategy,
18    /// Session timeout in seconds
19    pub session_timeout_seconds: u64,
20}
21
22impl Default for SessionConfig {
23    fn default() -> Self {
24        Self {
25            max_history_length: 100,
26            default_context_strategy: ContextStrategy::SlidingWindow { window_size: 20 },
27            session_timeout_seconds: 3600, // 1 hour
28        }
29    }
30}
31
32/// Strategies for managing conversation context
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub enum ContextStrategy {
35    /// Keep full conversation history
36    FullHistory,
37    /// Keep a sliding window of recent messages
38    SlidingWindow { window_size: usize },
39    /// Summarize old messages and keep recent ones
40    Summarized {
41        summary_threshold: usize,
42        keep_recent: usize,
43    },
44    /// Smart context management based on relevance
45    Smart {
46        max_tokens: usize,
47        relevance_threshold: f64,
48    },
49}
50
51/// Metadata associated with a conversation session
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct SessionMetadata {
54    /// User identifier
55    pub user_id: String,
56    /// Session tags
57    pub tags: Vec<String>,
58    /// Custom metadata
59    pub custom: HashMap<String, serde_json::Value>,
60    /// Session priority
61    pub priority: i32,
62    /// Language preference
63    pub language: Option<String>,
64}
65
66impl SessionMetadata {
67    /// Create new session metadata
68    pub fn new(user_id: impl Into<String>) -> Self {
69        Self {
70            user_id: user_id.into(),
71            tags: Vec::new(),
72            custom: HashMap::new(),
73            priority: 0,
74            language: None,
75        }
76    }
77
78    /// Add a tag
79    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
80        self.tags.push(tag.into());
81        self
82    }
83
84    /// Set custom metadata
85    pub fn with_custom(mut self, key: String, value: serde_json::Value) -> Self {
86        self.custom.insert(key, value);
87        self
88    }
89}
90
91/// A conversation session with history and metadata
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ConversationSession {
94    /// Unique session identifier
95    pub id: String,
96    /// Session metadata
97    pub metadata: SessionMetadata,
98    /// Conversation messages
99    pub messages: Vec<LLMMessage>,
100    /// Context management strategy
101    pub context_strategy: ContextStrategy,
102    /// Session creation time
103    pub created_at: DateTime<Utc>,
104    /// Last activity time
105    pub last_activity: DateTime<Utc>,
106    /// Session status
107    pub status: SessionStatus,
108}
109
110/// Status of a conversation session
111#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
112pub enum SessionStatus {
113    /// Session is active
114    Active,
115    /// Session is paused
116    Paused,
117    /// Session has expired
118    Expired,
119    /// Session was manually closed
120    Closed,
121}
122
123impl ConversationSession {
124    /// Create a new conversation session
125    pub fn new(metadata: SessionMetadata, context_strategy: ContextStrategy) -> Self {
126        let now = Utc::now();
127        Self {
128            id: Uuid::new_v4().to_string(),
129            metadata,
130            messages: Vec::new(),
131            context_strategy,
132            created_at: now,
133            last_activity: now,
134            status: SessionStatus::Active,
135        }
136    }
137
138    /// Add a message to the session
139    pub fn add_message(&mut self, message: LLMMessage) {
140        self.messages.push(message);
141        self.last_activity = Utc::now();
142    }
143
144    /// Get active messages based on context strategy
145    pub fn get_active_messages(&self) -> Vec<LLMMessage> {
146        match &self.context_strategy {
147            ContextStrategy::FullHistory => self.messages.clone(),
148            ContextStrategy::SlidingWindow { window_size } => {
149                let start_idx = self.messages.len().saturating_sub(*window_size);
150                self.messages[start_idx..].to_vec()
151            }
152            ContextStrategy::Summarized { keep_recent, .. } => {
153                let start_idx = self.messages.len().saturating_sub(*keep_recent);
154                self.messages[start_idx..].to_vec()
155            }
156            ContextStrategy::Smart { .. } => {
157                // TODO: Implement smart context selection
158                // For now, fall back to sliding window
159                let window_size = 20;
160                let start_idx = self.messages.len().saturating_sub(window_size);
161                self.messages[start_idx..].to_vec()
162            }
163        }
164    }
165
166    /// Check if session has expired
167    pub fn is_expired(&self, timeout_seconds: u64) -> bool {
168        let now = Utc::now();
169        let timeout = chrono::Duration::seconds(timeout_seconds as i64);
170        now.signed_duration_since(self.last_activity) > timeout
171    }
172
173    /// Get session duration
174    pub fn duration(&self) -> chrono::Duration {
175        self.last_activity.signed_duration_since(self.created_at)
176    }
177
178    /// Get message count
179    pub fn message_count(&self) -> usize {
180        self.messages.len()
181    }
182
183    /// Pause the session
184    pub fn pause(&mut self) {
185        self.status = SessionStatus::Paused;
186    }
187
188    /// Resume the session
189    pub fn resume(&mut self) {
190        self.status = SessionStatus::Active;
191        self.last_activity = Utc::now();
192    }
193
194    /// Close the session
195    pub fn close(&mut self) {
196        self.status = SessionStatus::Closed;
197    }
198}
199
200/// Session manager for handling multiple conversation sessions
201#[derive(Debug)]
202pub struct SessionManager {
203    sessions: HashMap<String, ConversationSession>,
204    config: SessionConfig,
205}
206
207impl SessionManager {
208    /// Create a new session manager
209    pub fn new(config: SessionConfig) -> Self {
210        Self {
211            sessions: HashMap::new(),
212            config,
213        }
214    }
215
216    /// Create a new session
217    pub fn create_session(
218        &mut self,
219        metadata: SessionMetadata,
220        context_strategy: Option<ContextStrategy>,
221    ) -> String {
222        let strategy = context_strategy.unwrap_or(self.config.default_context_strategy.clone());
223        let session = ConversationSession::new(metadata, strategy);
224        let session_id = session.id.clone();
225
226        self.sessions.insert(session_id.clone(), session);
227        session_id
228    }
229
230    /// Get a session by ID
231    pub fn get_session(&self, session_id: &str) -> Option<&ConversationSession> {
232        self.sessions.get(session_id)
233    }
234
235    /// Get a mutable session by ID
236    pub fn get_session_mut(&mut self, session_id: &str) -> Option<&mut ConversationSession> {
237        self.sessions.get_mut(session_id)
238    }
239
240    /// Add a message to a session
241    pub fn add_message(&mut self, session_id: &str, message: LLMMessage) -> LLMResult<()> {
242        let session = self
243            .sessions
244            .get_mut(session_id)
245            .ok_or_else(|| LLMError::session(format!("Session not found: {}", session_id)))?;
246
247        if session.status != SessionStatus::Active {
248            return Err(LLMError::session("Session is not active"));
249        }
250
251        session.add_message(message);
252
253        // Trim history if needed
254        if session.messages.len() > self.config.max_history_length {
255            let excess = session.messages.len() - self.config.max_history_length;
256            session.messages.drain(0..excess);
257        }
258
259        Ok(())
260    }
261
262    /// Get active messages for a session
263    pub fn get_active_messages(&self, session_id: &str) -> LLMResult<Vec<LLMMessage>> {
264        let session = self
265            .sessions
266            .get(session_id)
267            .ok_or_else(|| LLMError::session(format!("Session not found: {}", session_id)))?;
268
269        Ok(session.get_active_messages())
270    }
271
272    /// List all session IDs
273    pub fn list_sessions(&self) -> Vec<String> {
274        self.sessions.keys().cloned().collect()
275    }
276
277    /// Get sessions by user ID
278    pub fn get_user_sessions(&self, user_id: &str) -> Vec<String> {
279        self.sessions
280            .iter()
281            .filter_map(|(id, session)| {
282                if session.metadata.user_id == user_id {
283                    Some(id.clone())
284                } else {
285                    None
286                }
287            })
288            .collect()
289    }
290
291    /// Clean up expired sessions
292    pub fn cleanup_expired(&mut self) -> usize {
293        let timeout = self.config.session_timeout_seconds;
294        let expired_ids: Vec<_> = self
295            .sessions
296            .iter()
297            .filter_map(|(id, session)| {
298                if session.is_expired(timeout) {
299                    Some(id.clone())
300                } else {
301                    None
302                }
303            })
304            .collect();
305
306        let count = expired_ids.len();
307        for id in expired_ids {
308            if let Some(mut session) = self.sessions.remove(&id) {
309                session.status = SessionStatus::Expired;
310            }
311        }
312
313        count
314    }
315
316    /// Pause a session
317    pub fn pause_session(&mut self, session_id: &str) -> LLMResult<()> {
318        let session = self
319            .sessions
320            .get_mut(session_id)
321            .ok_or_else(|| LLMError::session(format!("Session not found: {}", session_id)))?;
322
323        session.pause();
324        Ok(())
325    }
326
327    /// Resume a session
328    pub fn resume_session(&mut self, session_id: &str) -> LLMResult<()> {
329        let session = self
330            .sessions
331            .get_mut(session_id)
332            .ok_or_else(|| LLMError::session(format!("Session not found: {}", session_id)))?;
333
334        session.resume();
335        Ok(())
336    }
337
338    /// Close a session
339    pub fn close_session(&mut self, session_id: &str) -> LLMResult<()> {
340        let session = self
341            .sessions
342            .get_mut(session_id)
343            .ok_or_else(|| LLMError::session(format!("Session not found: {}", session_id)))?;
344
345        session.close();
346        Ok(())
347    }
348
349    /// Remove a session completely
350    pub fn remove_session(&mut self, session_id: &str) -> Option<ConversationSession> {
351        self.sessions.remove(session_id)
352    }
353
354    /// Get session statistics
355    pub fn get_stats(&self) -> SessionStats {
356        let total_sessions = self.sessions.len();
357        let active_sessions = self
358            .sessions
359            .values()
360            .filter(|s| s.status == SessionStatus::Active)
361            .count();
362        let paused_sessions = self
363            .sessions
364            .values()
365            .filter(|s| s.status == SessionStatus::Paused)
366            .count();
367        let total_messages: usize = self.sessions.values().map(|s| s.message_count()).sum();
368
369        SessionStats {
370            total_sessions,
371            active_sessions,
372            paused_sessions,
373            total_messages,
374        }
375    }
376}
377
378/// Statistics about session manager
379#[derive(Debug, Clone, Serialize, Deserialize)]
380pub struct SessionStats {
381    /// Total number of sessions
382    pub total_sessions: usize,
383    /// Number of active sessions
384    pub active_sessions: usize,
385    /// Number of paused sessions
386    pub paused_sessions: usize,
387    /// Total messages across all sessions
388    pub total_messages: usize,
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::core::LLMMessage;

    /// Shorthand for a manager using the default configuration (3600 s timeout).
    fn default_manager() -> SessionManager {
        SessionManager::new(SessionConfig::default())
    }

    #[test]
    fn test_session_creation() {
        let metadata = SessionMetadata::new("user123")
            .with_tag("test")
            .with_custom("priority".to_string(), serde_json::json!(1));

        let session =
            ConversationSession::new(metadata, ContextStrategy::SlidingWindow { window_size: 10 });

        assert!(!session.id.is_empty());
        assert_eq!(session.metadata.user_id, "user123");
        assert!(session.metadata.tags.contains(&"test".to_string()));
        assert_eq!(session.status, SessionStatus::Active);
        assert_eq!(session.message_count(), 0);
    }

    #[test]
    fn test_session_messages() {
        let metadata = SessionMetadata::new("user123");
        let mut session = ConversationSession::new(metadata, ContextStrategy::FullHistory);

        session.add_message(LLMMessage::user("Hello"));
        session.add_message(LLMMessage::assistant("Hi there!"));

        assert_eq!(session.message_count(), 2);

        let active_messages = session.get_active_messages();
        assert_eq!(active_messages.len(), 2);
    }

    #[test]
    fn test_sliding_window_context() {
        let metadata = SessionMetadata::new("user123");
        let mut session =
            ConversationSession::new(metadata, ContextStrategy::SlidingWindow { window_size: 2 });

        session.add_message(LLMMessage::user("Message 1"));
        session.add_message(LLMMessage::assistant("Response 1"));
        session.add_message(LLMMessage::user("Message 2"));
        session.add_message(LLMMessage::assistant("Response 2"));

        let active_messages = session.get_active_messages();
        assert_eq!(active_messages.len(), 2); // Only last 2 messages
        assert_eq!(active_messages[0].content.as_text(), Some("Message 2"));
        assert_eq!(active_messages[1].content.as_text(), Some("Response 2"));
    }

    #[test]
    fn test_summarized_context_keeps_recent() {
        let mut session = ConversationSession::new(
            SessionMetadata::new("user123"),
            ContextStrategy::Summarized {
                summary_threshold: 10,
                keep_recent: 1,
            },
        );
        session.add_message(LLMMessage::user("Old"));
        session.add_message(LLMMessage::assistant("Newest"));

        let active = session.get_active_messages();
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].content.as_text(), Some("Newest"));
    }

    #[test]
    fn test_smart_context_falls_back_to_recent_window() {
        let mut session = ConversationSession::new(
            SessionMetadata::new("user123"),
            ContextStrategy::Smart {
                max_tokens: 1000,
                relevance_threshold: 0.5,
            },
        );
        for _ in 0..25 {
            session.add_message(LLMMessage::user("Message"));
        }
        assert_eq!(session.get_active_messages().len(), 20);
    }

    #[test]
    fn test_is_expired_respects_timeout() {
        let mut session =
            ConversationSession::new(SessionMetadata::new("user123"), ContextStrategy::FullHistory);
        session.last_activity = chrono::Utc::now() - chrono::Duration::seconds(10);

        assert!(session.is_expired(5));
        assert!(!session.is_expired(3600));
    }

    #[test]
    fn test_session_manager() {
        let config = SessionConfig::default();
        let mut manager = SessionManager::new(config);

        let metadata = SessionMetadata::new("user123");
        let session_id = manager.create_session(metadata, None);

        assert!(manager.get_session(&session_id).is_some());
        assert_eq!(manager.list_sessions().len(), 1);

        manager
            .add_message(&session_id, LLMMessage::user("Hello"))
            .unwrap();

        let active_messages = manager.get_active_messages(&session_id).unwrap();
        assert_eq!(active_messages.len(), 1);

        let stats = manager.get_stats();
        assert_eq!(stats.total_sessions, 1);
        assert_eq!(stats.active_sessions, 1);
        assert_eq!(stats.total_messages, 1);
    }

    #[test]
    fn test_history_is_trimmed_to_max_length() {
        let config = SessionConfig {
            max_history_length: 3,
            ..SessionConfig::default()
        };
        let mut manager = SessionManager::new(config);
        let session_id = manager.create_session(
            SessionMetadata::new("user123"),
            Some(ContextStrategy::FullHistory),
        );

        for text in ["m0", "m1", "m2", "m3", "m4"] {
            manager
                .add_message(&session_id, LLMMessage::user(text))
                .unwrap();
        }

        let session = manager.get_session(&session_id).unwrap();
        assert_eq!(session.message_count(), 3);
        assert_eq!(session.messages[0].content.as_text(), Some("m2"));
        assert_eq!(session.messages[2].content.as_text(), Some("m4"));
    }

    #[test]
    fn test_inactive_session_rejects_messages() {
        let mut manager = default_manager();
        let session_id = manager.create_session(SessionMetadata::new("user123"), None);

        manager.pause_session(&session_id).unwrap();
        assert!(manager
            .add_message(&session_id, LLMMessage::user("Hello"))
            .is_err());

        manager.resume_session(&session_id).unwrap();
        assert!(manager
            .add_message(&session_id, LLMMessage::user("Hello"))
            .is_ok());
    }

    #[test]
    fn test_unknown_session_errors() {
        let mut manager = default_manager();

        assert!(manager
            .add_message("missing", LLMMessage::user("Hello"))
            .is_err());
        assert!(manager.get_active_messages("missing").is_err());
        assert!(manager.pause_session("missing").is_err());
        assert!(manager.resume_session("missing").is_err());
        assert!(manager.close_session("missing").is_err());
        assert!(manager.remove_session("missing").is_none());
    }

    #[test]
    fn test_get_user_sessions() {
        let mut manager = default_manager();
        let first = manager.create_session(SessionMetadata::new("alice"), None);
        let second = manager.create_session(SessionMetadata::new("alice"), None);
        manager.create_session(SessionMetadata::new("bob"), None);

        let mut alice = manager.get_user_sessions("alice");
        alice.sort();
        let mut expected = vec![first, second];
        expected.sort();

        assert_eq!(alice, expected);
        assert!(manager.get_user_sessions("carol").is_empty());
    }

    #[test]
    fn test_cleanup_expired_removes_stale_sessions() {
        let mut manager = default_manager();
        let stale = manager.create_session(SessionMetadata::new("user123"), None);
        let fresh = manager.create_session(SessionMetadata::new("user123"), None);

        // Push the stale session well past the default one-hour timeout.
        manager.get_session_mut(&stale).unwrap().last_activity =
            chrono::Utc::now() - chrono::Duration::seconds(7200);

        assert_eq!(manager.cleanup_expired(), 1);
        assert!(manager.get_session(&stale).is_none());
        assert!(manager.get_session(&fresh).is_some());
    }

    #[test]
    fn test_stats_count_by_status() {
        let mut manager = default_manager();
        let active = manager.create_session(SessionMetadata::new("user123"), None);
        let paused = manager.create_session(SessionMetadata::new("user123"), None);
        let closed = manager.create_session(SessionMetadata::new("user123"), None);

        manager.add_message(&active, LLMMessage::user("a")).unwrap();
        manager.add_message(&paused, LLMMessage::user("b")).unwrap();
        manager.pause_session(&paused).unwrap();
        manager.close_session(&closed).unwrap();

        let stats = manager.get_stats();
        assert_eq!(stats.total_sessions, 3);
        assert_eq!(stats.active_sessions, 1);
        assert_eq!(stats.paused_sessions, 1);
        assert_eq!(stats.total_messages, 2);
    }

    #[test]
    fn test_session_status_management() {
        let config = SessionConfig::default();
        let mut manager = SessionManager::new(config);

        let metadata = SessionMetadata::new("user123");
        let session_id = manager.create_session(metadata, None);

        manager.pause_session(&session_id).unwrap();
        let session = manager.get_session(&session_id).unwrap();
        assert_eq!(session.status, SessionStatus::Paused);

        manager.resume_session(&session_id).unwrap();
        let session = manager.get_session(&session_id).unwrap();
        assert_eq!(session.status, SessionStatus::Active);

        manager.close_session(&session_id).unwrap();
        let session = manager.get_session(&session_id).unwrap();
        assert_eq!(session.status, SessionStatus::Closed);
    }
}