universal_bot_core/
context.rs

1//! Context management for conversation state
2//!
3//! This module provides context tracking and management for maintaining
4//! conversation state across multiple interactions.
5
6use std::collections::{HashMap, VecDeque};
7use std::sync::Arc;
8use std::time::Duration;
9
10use anyhow::Result;
11use chrono::{DateTime, Utc};
12use dashmap::DashMap;
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15use tracing::{debug, instrument};
16use uuid::Uuid;
17
18use crate::{
19    config::{ContextConfig, StorageBackend},
20    error::Error,
21    message::{Message, Response},
22};
23
24/// Conversation context containing state and history
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct Context {
27    /// Unique context ID
28    pub id: String,
29
30    /// Conversation history
31    pub history: VecDeque<ContextMessage>,
32
33    /// User information
34    pub user: UserContext,
35
36    /// Session variables
37    pub variables: HashMap<String, serde_json::Value>,
38
39    /// Context metadata
40    pub metadata: ContextMetadata,
41
42    /// Token count for the context
43    pub token_count: usize,
44}
45
46impl Context {
47    /// Create a new context
48    #[must_use]
49    pub fn new(id: impl Into<String>) -> Self {
50        Self {
51            id: id.into(),
52            history: VecDeque::new(),
53            user: UserContext::default(),
54            variables: HashMap::new(),
55            metadata: ContextMetadata::new(),
56            token_count: 0,
57        }
58    }
59
60    /// Add a message to the history
61    pub fn add_message(&mut self, message: &Message) {
62        let context_msg = ContextMessage::from_message(message);
63        self.token_count += context_msg.estimated_tokens();
64        self.history.push_back(context_msg);
65        self.metadata.last_activity = Utc::now();
66        self.metadata.message_count += 1;
67    }
68
69    /// Add a response to the history
70    pub fn add_response(&mut self, response: &Response) {
71        let context_msg = ContextMessage::from_response(response);
72        self.token_count += context_msg.estimated_tokens();
73        self.history.push_back(context_msg);
74        self.metadata.last_activity = Utc::now();
75        self.metadata.message_count += 1;
76
77        if let Some(usage) = &response.usage {
78            self.metadata.total_tokens += usage.total_tokens;
79            self.metadata.total_cost += usage.estimated_cost;
80        }
81    }
82
83    /// Trim history to fit within token limit
84    pub fn trim_to_token_limit(&mut self, max_tokens: usize) {
85        while self.token_count > max_tokens && !self.history.is_empty() {
86            if let Some(removed) = self.history.pop_front() {
87                self.token_count = self.token_count.saturating_sub(removed.estimated_tokens());
88            }
89        }
90    }
91
92    /// Get a variable value
93    pub fn get_variable(&self, key: &str) -> Option<&serde_json::Value> {
94        self.variables.get(key)
95    }
96
97    /// Set a variable value
98    pub fn set_variable(&mut self, key: impl Into<String>, value: serde_json::Value) {
99        self.variables.insert(key.into(), value);
100    }
101
102    /// Clear all history
103    pub fn clear_history(&mut self) {
104        self.history.clear();
105        self.token_count = 0;
106        self.metadata.message_count = 0;
107    }
108
109    /// Get the age of the context
110    #[must_use]
111    pub fn age(&self) -> Duration {
112        let now = Utc::now();
113        (now - self.metadata.created_at)
114            .to_std()
115            .unwrap_or(Duration::ZERO)
116    }
117
118    /// Check if the context is expired
119    #[must_use]
120    pub fn is_expired(&self, ttl: Duration) -> bool {
121        self.age() > ttl
122    }
123
124    /// Get a summary of the context
125    #[must_use]
126    pub fn summary(&self) -> String {
127        format!(
128            "Context {} - Messages: {}, Tokens: {}, Age: {:?}",
129            self.id,
130            self.metadata.message_count,
131            self.token_count,
132            self.age()
133        )
134    }
135}
136
137/// A message in the context history
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct ContextMessage {
140    /// Message role
141    pub role: MessageRole,
142    /// Message content
143    pub content: String,
144    /// Timestamp
145    pub timestamp: DateTime<Utc>,
146    /// Optional message ID
147    pub message_id: Option<Uuid>,
148}
149
150impl ContextMessage {
151    /// Create from a user message
152    pub fn from_message(message: &Message) -> Self {
153        Self {
154            role: MessageRole::User,
155            content: message.content.clone(),
156            timestamp: message.timestamp,
157            message_id: Some(message.id),
158        }
159    }
160
161    /// Create from a bot response
162    pub fn from_response(response: &Response) -> Self {
163        Self {
164            role: MessageRole::Assistant,
165            content: response.content.clone(),
166            timestamp: response.timestamp,
167            message_id: Some(response.id),
168        }
169    }
170
171    /// Create a system message
172    pub fn system(content: impl Into<String>) -> Self {
173        Self {
174            role: MessageRole::System,
175            content: content.into(),
176            timestamp: Utc::now(),
177            message_id: None,
178        }
179    }
180
181    /// Estimate token count (rough approximation)
182    const fn estimated_tokens(&self) -> usize {
183        // Rough estimate: 1 token per 4 characters
184        self.content.len() / 4
185    }
186}
187
188/// Message role in conversation
189#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
190#[serde(rename_all = "lowercase")]
191pub enum MessageRole {
192    /// System message
193    System,
194    /// User message
195    User,
196    /// Assistant message
197    Assistant,
198}
199
200/// User context information
201#[derive(Debug, Clone, Default, Serialize, Deserialize)]
202pub struct UserContext {
203    /// User ID
204    pub id: Option<String>,
205    /// User name
206    pub name: Option<String>,
207    /// User preferences
208    pub preferences: HashMap<String, serde_json::Value>,
209    /// User attributes
210    pub attributes: HashMap<String, String>,
211}
212
213/// Context metadata
214#[derive(Debug, Clone, Serialize, Deserialize)]
215pub struct ContextMetadata {
216    /// When the context was created
217    pub created_at: DateTime<Utc>,
218    /// Last activity timestamp
219    pub last_activity: DateTime<Utc>,
220    /// Total message count
221    pub message_count: usize,
222    /// Total tokens used
223    pub total_tokens: usize,
224    /// Total cost incurred
225    pub total_cost: f64,
226    /// Custom tags
227    pub tags: Vec<String>,
228}
229
230impl ContextMetadata {
231    fn new() -> Self {
232        let now = Utc::now();
233        Self {
234            created_at: now,
235            last_activity: now,
236            message_count: 0,
237            total_tokens: 0,
238            total_cost: 0.0,
239            tags: Vec::new(),
240        }
241    }
242}
243
244/// Context manager for handling multiple conversation contexts
245pub struct ContextManager {
246    config: ContextConfig,
247    store: Arc<dyn ContextStore>,
248    cache: Arc<DashMap<String, Arc<RwLock<Context>>>>,
249}
250
251impl ContextManager {
252    /// Create a new context manager
253    ///
254    /// # Errors
255    ///
256    /// Returns an error if store initialization fails.
257    #[instrument(skip(config))]
258    pub async fn new(config: ContextConfig) -> Result<Self> {
259        debug!("Creating context manager with config: {:?}", config);
260
261        let store: Arc<dyn ContextStore> = match &config.storage_backend {
262            StorageBackend::Memory => Arc::new(MemoryContextStore::new()),
263            StorageBackend::Redis { url: _ } => {
264                // Would initialize Redis store here
265                return Err(Error::new("Redis store not yet implemented").into());
266            }
267            StorageBackend::Postgres { url: _ } => {
268                // Would initialize Postgres store here
269                return Err(Error::new("Postgres store not yet implemented").into());
270            }
271            StorageBackend::Sqlite { path: _ } => {
272                // Would initialize SQLite store here
273                return Err(Error::new("SQLite store not yet implemented").into());
274            }
275        };
276
277        Ok(Self {
278            config,
279            store,
280            cache: Arc::new(DashMap::new()),
281        })
282    }
283
284    /// Get or create a context
285    ///
286    /// # Errors
287    ///
288    /// Returns an error if context creation or retrieval fails
289    #[instrument(skip(self))]
290    pub async fn get_or_create(&self, id: &str) -> Result<Arc<RwLock<Context>>> {
291        // Check cache first
292        if let Some(context) = self.cache.get(id) {
293            let ctx = context.clone();
294
295            // Check if expired
296            if ctx.read().is_expired(self.config.context_ttl) {
297                debug!("Context {} is expired, removing", id);
298                self.cache.remove(id);
299            } else {
300                debug!("Found context {} in cache", id);
301                return Ok(ctx);
302            }
303        }
304
305        // Try to load from store
306        if let Some(context) = self.store.get(id).await? {
307            if !context.is_expired(self.config.context_ttl) {
308                debug!("Loaded context {} from store", id);
309                let ctx = Arc::new(RwLock::new(context));
310                self.cache.insert(id.to_string(), ctx.clone());
311                return Ok(ctx);
312            }
313        }
314
315        // Create new context
316        debug!("Creating new context {}", id);
317        let context = Context::new(id);
318        let ctx = Arc::new(RwLock::new(context));
319        self.cache.insert(id.to_string(), ctx.clone());
320
321        // Persist if configured
322        if self.config.persist_context {
323            let context = ctx.read().clone();
324            self.store.set(id, context, self.config.context_ttl).await?;
325        }
326
327        Ok(ctx)
328    }
329
330    /// Update a context
331    ///
332    /// # Errors
333    ///
334    /// Returns an error if the update operation fails
335    #[instrument(skip(self, context))]
336    pub async fn update(&self, id: &str, context: Arc<RwLock<Context>>) -> Result<()> {
337        // Trim to token limit
338        {
339            let mut ctx = context.write();
340            ctx.trim_to_token_limit(self.config.max_context_tokens);
341        }
342
343        // Update cache
344        self.cache.insert(id.to_string(), context.clone());
345
346        // Persist if configured
347        if self.config.persist_context {
348            let ctx = context.read().clone();
349            self.store.set(id, ctx, self.config.context_ttl).await?;
350        }
351
352        Ok(())
353    }
354
355    /// Delete a context
356    ///
357    /// # Errors
358    ///
359    /// Returns an error if the deletion fails
360    #[instrument(skip(self))]
361    pub async fn delete(&self, id: &str) -> Result<()> {
362        debug!("Deleting context {}", id);
363        self.cache.remove(id);
364        self.store.delete(id).await?;
365        Ok(())
366    }
367
368    /// Clear expired contexts
369    ///
370    /// # Errors
371    ///
372    /// Returns an error if clearing expired contexts fails
373    #[instrument(skip(self))]
374    pub async fn clear_expired(&self) -> Result<usize> {
375        let mut removed = 0;
376        let expired_keys: Vec<String> = self
377            .cache
378            .iter()
379            .filter(|entry| entry.value().read().is_expired(self.config.context_ttl))
380            .map(|entry| entry.key().clone())
381            .collect();
382
383        for key in expired_keys {
384            self.cache.remove(&key);
385            self.store.delete(&key).await?;
386            removed += 1;
387        }
388
389        debug!("Removed {} expired contexts", removed);
390        Ok(removed)
391    }
392
393    /// Get statistics about managed contexts
394    #[must_use]
395    pub fn stats(&self) -> ContextStats {
396        let total = self.cache.len();
397        let mut total_tokens = 0;
398        let mut total_messages = 0;
399
400        for entry in self.cache.iter() {
401            let ctx = entry.value().read();
402            total_tokens += ctx.token_count;
403            total_messages += ctx.metadata.message_count;
404        }
405
406        ContextStats {
407            total_contexts: total,
408            total_tokens,
409            total_messages,
410            cache_size: total,
411        }
412    }
413}
414
415/// Context store trait for persistence
416#[async_trait::async_trait]
417pub trait ContextStore: Send + Sync {
418    /// Get a context by ID
419    async fn get(&self, key: &str) -> Result<Option<Context>>;
420
421    /// Set a context with TTL
422    async fn set(&self, key: &str, context: Context, ttl: Duration) -> Result<()>;
423
424    /// Delete a context
425    async fn delete(&self, key: &str) -> Result<()>;
426
427    /// List all context keys
428    async fn list_keys(&self, pattern: &str) -> Result<Vec<String>>;
429}
430
431/// In-memory context store implementation
432struct MemoryContextStore {
433    data: Arc<DashMap<String, (Context, DateTime<Utc>)>>,
434}
435
436impl MemoryContextStore {
437    fn new() -> Self {
438        Self {
439            data: Arc::new(DashMap::new()),
440        }
441    }
442}
443
444#[async_trait::async_trait]
445impl ContextStore for MemoryContextStore {
446    async fn get(&self, key: &str) -> Result<Option<Context>> {
447        Ok(self.data.get(key).map(|entry| entry.0.clone()))
448    }
449
450    async fn set(&self, key: &str, context: Context, ttl: Duration) -> Result<()> {
451        let expiry = Utc::now() + chrono::Duration::from_std(ttl)?;
452        self.data.insert(key.to_string(), (context, expiry));
453        Ok(())
454    }
455
456    async fn delete(&self, key: &str) -> Result<()> {
457        self.data.remove(key);
458        Ok(())
459    }
460
461    async fn list_keys(&self, pattern: &str) -> Result<Vec<String>> {
462        let keys = self
463            .data
464            .iter()
465            .filter(|entry| entry.key().contains(pattern))
466            .map(|entry| entry.key().clone())
467            .collect();
468        Ok(keys)
469    }
470}
471
/// Aggregate figures describing the contexts a manager holds
#[derive(Debug, Clone)]
pub struct ContextStats {
    /// How many contexts were counted
    pub total_contexts: usize,
    /// Token estimates of all counted contexts added together
    pub total_tokens: usize,
    /// Message counts of all counted contexts added together
    pub total_messages: usize,
    /// How many contexts the cache holds
    pub cache_size: usize,
}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// A new context carries the given id and starts with no history or tokens.
    #[test]
    fn test_context_creation() {
        let context = Context::new("test-123");
        assert_eq!(context.id, "test-123");
        assert!(context.history.is_empty());
        assert_eq!(context.token_count, 0);
    }

    /// Adding a message grows the history, the token estimate and the counter.
    #[test]
    fn test_context_message_addition() {
        let mut context = Context::new("test");
        let message = Message::text("Hello");

        context.add_message(&message);
        assert_eq!(context.history.len(), 1);
        assert!(context.token_count > 0);
        assert_eq!(context.metadata.message_count, 1);
    }

    /// Trimming drops old messages until the token estimate fits the limit.
    #[test]
    fn test_context_trimming() {
        let mut context = Context::new("test");

        // Add multiple messages
        for i in 0..10 {
            let msg = Message::text(format!("Message {i}"));
            context.add_message(&msg);
        }

        let original_count = context.history.len();
        context.trim_to_token_limit(10); // Very low limit

        assert!(context.history.len() < original_count);
        assert!(context.token_count <= 10);
    }

    /// Variables can be stored and read back; unknown keys yield `None`.
    #[test]
    fn test_context_variables() {
        let mut context = Context::new("test");

        context.set_variable("key", serde_json::json!("value"));
        assert_eq!(
            context.get_variable("key"),
            Some(&serde_json::json!("value"))
        );
        assert_eq!(context.get_variable("missing"), None);
    }

    /// A fresh context is not expired; one created two hours ago is expired
    /// under a one-hour TTL.
    #[test]
    fn test_context_expiry() {
        let context = Context::new("test");
        assert!(!context.is_expired(Duration::from_secs(3600)));

        // No clock mocking needed: the creation timestamp is a public field,
        // so it can simply be moved into the past.
        let mut stale = Context::new("stale");
        let two_hours = chrono::Duration::from_std(Duration::from_secs(7200)).unwrap();
        stale.metadata.created_at = Utc::now() - two_hours;
        assert!(stale.is_expired(Duration::from_secs(3600)));
        assert!(!stale.is_expired(Duration::from_secs(4 * 3600)));
    }

    /// Asking twice for the same id yields the very same shared context.
    #[tokio::test]
    async fn test_context_manager() {
        let config = ContextConfig::default();
        let manager = ContextManager::new(config).await.unwrap();

        let ctx1 = manager.get_or_create("test-1").await.unwrap();
        let ctx2 = manager.get_or_create("test-1").await.unwrap();

        // Comparing ids alone would also pass for two separate contexts
        // created with the same id, so check the handles are the same
        // allocation.
        assert!(Arc::ptr_eq(&ctx1, &ctx2));
        assert_eq!(ctx1.read().id, "test-1");
    }

    /// The in-memory store returns what was stored and forgets deleted keys.
    #[tokio::test]
    async fn test_memory_store() {
        let store = MemoryContextStore::new();
        let context = Context::new("test");

        store
            .set("test", context.clone(), Duration::from_secs(60))
            .await
            .unwrap();

        let loaded = store.get("test").await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().id, "test");

        store.delete("test").await.unwrap();
        let deleted = store.get("test").await.unwrap();
        assert!(deleted.is_none());
    }
}