1use crate::llm::core::{LLMError, LLMMessage, LLMResult};
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use uuid::Uuid;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct SessionConfig {
14 pub max_history_length: usize,
16 pub default_context_strategy: ContextStrategy,
18 pub session_timeout_seconds: u64,
20}
21
22impl Default for SessionConfig {
23 fn default() -> Self {
24 Self {
25 max_history_length: 100,
26 default_context_strategy: ContextStrategy::SlidingWindow { window_size: 20 },
27 session_timeout_seconds: 3600, }
29 }
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub enum ContextStrategy {
35 FullHistory,
37 SlidingWindow { window_size: usize },
39 Summarized {
41 summary_threshold: usize,
42 keep_recent: usize,
43 },
44 Smart {
46 max_tokens: usize,
47 relevance_threshold: f64,
48 },
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct SessionMetadata {
54 pub user_id: String,
56 pub tags: Vec<String>,
58 pub custom: HashMap<String, serde_json::Value>,
60 pub priority: i32,
62 pub language: Option<String>,
64}
65
66impl SessionMetadata {
67 pub fn new(user_id: impl Into<String>) -> Self {
69 Self {
70 user_id: user_id.into(),
71 tags: Vec::new(),
72 custom: HashMap::new(),
73 priority: 0,
74 language: None,
75 }
76 }
77
78 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
80 self.tags.push(tag.into());
81 self
82 }
83
84 pub fn with_custom(mut self, key: String, value: serde_json::Value) -> Self {
86 self.custom.insert(key, value);
87 self
88 }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ConversationSession {
94 pub id: String,
96 pub metadata: SessionMetadata,
98 pub messages: Vec<LLMMessage>,
100 pub context_strategy: ContextStrategy,
102 pub created_at: DateTime<Utc>,
104 pub last_activity: DateTime<Utc>,
106 pub status: SessionStatus,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
112pub enum SessionStatus {
113 Active,
115 Paused,
117 Expired,
119 Closed,
121}
122
123impl ConversationSession {
124 pub fn new(metadata: SessionMetadata, context_strategy: ContextStrategy) -> Self {
126 let now = Utc::now();
127 Self {
128 id: Uuid::new_v4().to_string(),
129 metadata,
130 messages: Vec::new(),
131 context_strategy,
132 created_at: now,
133 last_activity: now,
134 status: SessionStatus::Active,
135 }
136 }
137
138 pub fn add_message(&mut self, message: LLMMessage) {
140 self.messages.push(message);
141 self.last_activity = Utc::now();
142 }
143
144 pub fn get_active_messages(&self) -> Vec<LLMMessage> {
146 match &self.context_strategy {
147 ContextStrategy::FullHistory => self.messages.clone(),
148 ContextStrategy::SlidingWindow { window_size } => {
149 let start_idx = self.messages.len().saturating_sub(*window_size);
150 self.messages[start_idx..].to_vec()
151 }
152 ContextStrategy::Summarized { keep_recent, .. } => {
153 let start_idx = self.messages.len().saturating_sub(*keep_recent);
154 self.messages[start_idx..].to_vec()
155 }
156 ContextStrategy::Smart { .. } => {
157 let window_size = 20;
160 let start_idx = self.messages.len().saturating_sub(window_size);
161 self.messages[start_idx..].to_vec()
162 }
163 }
164 }
165
166 pub fn is_expired(&self, timeout_seconds: u64) -> bool {
168 let now = Utc::now();
169 let timeout = chrono::Duration::seconds(timeout_seconds as i64);
170 now.signed_duration_since(self.last_activity) > timeout
171 }
172
173 pub fn duration(&self) -> chrono::Duration {
175 self.last_activity.signed_duration_since(self.created_at)
176 }
177
178 pub fn message_count(&self) -> usize {
180 self.messages.len()
181 }
182
183 pub fn pause(&mut self) {
185 self.status = SessionStatus::Paused;
186 }
187
188 pub fn resume(&mut self) {
190 self.status = SessionStatus::Active;
191 self.last_activity = Utc::now();
192 }
193
194 pub fn close(&mut self) {
196 self.status = SessionStatus::Closed;
197 }
198}
199
200#[derive(Debug)]
202pub struct SessionManager {
203 sessions: HashMap<String, ConversationSession>,
204 config: SessionConfig,
205}
206
207impl SessionManager {
208 pub fn new(config: SessionConfig) -> Self {
210 Self {
211 sessions: HashMap::new(),
212 config,
213 }
214 }
215
216 pub fn create_session(
218 &mut self,
219 metadata: SessionMetadata,
220 context_strategy: Option<ContextStrategy>,
221 ) -> String {
222 let strategy = context_strategy.unwrap_or(self.config.default_context_strategy.clone());
223 let session = ConversationSession::new(metadata, strategy);
224 let session_id = session.id.clone();
225
226 self.sessions.insert(session_id.clone(), session);
227 session_id
228 }
229
230 pub fn get_session(&self, session_id: &str) -> Option<&ConversationSession> {
232 self.sessions.get(session_id)
233 }
234
235 pub fn get_session_mut(&mut self, session_id: &str) -> Option<&mut ConversationSession> {
237 self.sessions.get_mut(session_id)
238 }
239
240 pub fn add_message(&mut self, session_id: &str, message: LLMMessage) -> LLMResult<()> {
242 let session = self
243 .sessions
244 .get_mut(session_id)
245 .ok_or_else(|| LLMError::session(format!("Session not found: {}", session_id)))?;
246
247 if session.status != SessionStatus::Active {
248 return Err(LLMError::session("Session is not active"));
249 }
250
251 session.add_message(message);
252
253 if session.messages.len() > self.config.max_history_length {
255 let excess = session.messages.len() - self.config.max_history_length;
256 session.messages.drain(0..excess);
257 }
258
259 Ok(())
260 }
261
262 pub fn get_active_messages(&self, session_id: &str) -> LLMResult<Vec<LLMMessage>> {
264 let session = self
265 .sessions
266 .get(session_id)
267 .ok_or_else(|| LLMError::session(format!("Session not found: {}", session_id)))?;
268
269 Ok(session.get_active_messages())
270 }
271
272 pub fn list_sessions(&self) -> Vec<String> {
274 self.sessions.keys().cloned().collect()
275 }
276
277 pub fn get_user_sessions(&self, user_id: &str) -> Vec<String> {
279 self.sessions
280 .iter()
281 .filter_map(|(id, session)| {
282 if session.metadata.user_id == user_id {
283 Some(id.clone())
284 } else {
285 None
286 }
287 })
288 .collect()
289 }
290
291 pub fn cleanup_expired(&mut self) -> usize {
293 let timeout = self.config.session_timeout_seconds;
294 let expired_ids: Vec<_> = self
295 .sessions
296 .iter()
297 .filter_map(|(id, session)| {
298 if session.is_expired(timeout) {
299 Some(id.clone())
300 } else {
301 None
302 }
303 })
304 .collect();
305
306 let count = expired_ids.len();
307 for id in expired_ids {
308 if let Some(mut session) = self.sessions.remove(&id) {
309 session.status = SessionStatus::Expired;
310 }
311 }
312
313 count
314 }
315
316 pub fn pause_session(&mut self, session_id: &str) -> LLMResult<()> {
318 let session = self
319 .sessions
320 .get_mut(session_id)
321 .ok_or_else(|| LLMError::session(format!("Session not found: {}", session_id)))?;
322
323 session.pause();
324 Ok(())
325 }
326
327 pub fn resume_session(&mut self, session_id: &str) -> LLMResult<()> {
329 let session = self
330 .sessions
331 .get_mut(session_id)
332 .ok_or_else(|| LLMError::session(format!("Session not found: {}", session_id)))?;
333
334 session.resume();
335 Ok(())
336 }
337
338 pub fn close_session(&mut self, session_id: &str) -> LLMResult<()> {
340 let session = self
341 .sessions
342 .get_mut(session_id)
343 .ok_or_else(|| LLMError::session(format!("Session not found: {}", session_id)))?;
344
345 session.close();
346 Ok(())
347 }
348
349 pub fn remove_session(&mut self, session_id: &str) -> Option<ConversationSession> {
351 self.sessions.remove(session_id)
352 }
353
354 pub fn get_stats(&self) -> SessionStats {
356 let total_sessions = self.sessions.len();
357 let active_sessions = self
358 .sessions
359 .values()
360 .filter(|s| s.status == SessionStatus::Active)
361 .count();
362 let paused_sessions = self
363 .sessions
364 .values()
365 .filter(|s| s.status == SessionStatus::Paused)
366 .count();
367 let total_messages: usize = self.sessions.values().map(|s| s.message_count()).sum();
368
369 SessionStats {
370 total_sessions,
371 active_sessions,
372 paused_sessions,
373 total_messages,
374 }
375 }
376}
377
378#[derive(Debug, Clone, Serialize, Deserialize)]
380pub struct SessionStats {
381 pub total_sessions: usize,
383 pub active_sessions: usize,
385 pub paused_sessions: usize,
387 pub total_messages: usize,
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::core::LLMMessage;

    /// Builds a default-configured manager holding one session for `user123`
    /// that uses the default context strategy.
    fn manager_with_session() -> (SessionManager, String) {
        let mut manager = SessionManager::new(SessionConfig::default());
        let session_id = manager.create_session(SessionMetadata::new("user123"), None);
        (manager, session_id)
    }

    /// A freshly built session is active, empty, and carries its metadata.
    #[test]
    fn test_session_creation() {
        let strategy = ContextStrategy::SlidingWindow { window_size: 10 };
        let metadata = SessionMetadata::new("user123")
            .with_tag("test")
            .with_custom("priority".to_string(), serde_json::json!(1));

        let session = ConversationSession::new(metadata, strategy);

        assert!(!session.id.is_empty());
        assert_eq!(session.metadata.user_id, "user123");
        assert!(session.metadata.tags.contains(&"test".to_string()));
        assert_eq!(session.status, SessionStatus::Active);
        assert_eq!(session.message_count(), 0);
    }

    /// With `FullHistory`, every added message is stored and exposed.
    #[test]
    fn test_session_messages() {
        let mut session =
            ConversationSession::new(SessionMetadata::new("user123"), ContextStrategy::FullHistory);

        session.add_message(LLMMessage::user("Hello"));
        session.add_message(LLMMessage::assistant("Hi there!"));

        assert_eq!(session.message_count(), 2);
        assert_eq!(session.get_active_messages().len(), 2);
    }

    /// A window of two exposes only the last two of four messages.
    #[test]
    fn test_sliding_window_context() {
        let strategy = ContextStrategy::SlidingWindow { window_size: 2 };
        let mut session = ConversationSession::new(SessionMetadata::new("user123"), strategy);

        session.add_message(LLMMessage::user("Message 1"));
        session.add_message(LLMMessage::assistant("Response 1"));
        session.add_message(LLMMessage::user("Message 2"));
        session.add_message(LLMMessage::assistant("Response 2"));

        let window = session.get_active_messages();
        assert_eq!(window.len(), 2);
        assert_eq!(window[0].content.as_text(), Some("Message 2"));
        assert_eq!(window[1].content.as_text(), Some("Response 2"));
    }

    /// The manager registers sessions, stores messages, and reports stats.
    #[test]
    fn test_session_manager() {
        let (mut manager, session_id) = manager_with_session();

        assert!(manager.get_session(&session_id).is_some());
        assert_eq!(manager.list_sessions().len(), 1);

        manager
            .add_message(&session_id, LLMMessage::user("Hello"))
            .unwrap();

        let active = manager.get_active_messages(&session_id).unwrap();
        assert_eq!(active.len(), 1);

        let stats = manager.get_stats();
        assert_eq!(stats.total_sessions, 1);
        assert_eq!(stats.active_sessions, 1);
        assert_eq!(stats.total_messages, 1);
    }

    /// Pause, resume, and close each leave the session in the matching status.
    #[test]
    fn test_session_status_management() {
        let (mut manager, session_id) = manager_with_session();

        manager.pause_session(&session_id).unwrap();
        assert_eq!(
            manager.get_session(&session_id).unwrap().status,
            SessionStatus::Paused
        );

        manager.resume_session(&session_id).unwrap();
        assert_eq!(
            manager.get_session(&session_id).unwrap().status,
            SessionStatus::Active
        );

        manager.close_session(&session_id).unwrap();
        assert_eq!(
            manager.get_session(&session_id).unwrap().status,
            SessionStatus::Closed
        );
    }
}