1use std::collections::{HashMap, VecDeque};
7use std::sync::Arc;
8use std::time::Duration;
9
10use anyhow::Result;
11use chrono::{DateTime, Utc};
12use dashmap::DashMap;
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15use tracing::{debug, instrument};
16use uuid::Uuid;
17
18use crate::{
19 config::{ContextConfig, StorageBackend},
20 error::Error,
21 message::{Message, Response},
22};
23
/// Conversation state tracked under a single context id.
///
/// Bundles the message history with user details, free-form variables and
/// bookkeeping metadata. The type derives `Serialize`/`Deserialize`, so a
/// whole context can be round-tripped through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Context {
    /// Identifier supplied to [`Context::new`].
    pub id: String,

    /// Messages in arrival order: new entries are pushed at the back and
    /// trimming pops from the front.
    pub history: VecDeque<ContextMessage>,

    /// Details about the user this context belongs to.
    pub user: UserContext,

    /// Arbitrary JSON values keyed by name; see [`Context::get_variable`]
    /// and [`Context::set_variable`].
    pub variables: HashMap<String, serde_json::Value>,

    /// Timestamps and running counters for this context.
    pub metadata: ContextMetadata,

    /// Running sum of the estimated token counts of the entries currently in
    /// `history`; adjusted when entries are added, trimmed or cleared.
    pub token_count: usize,
}
45
46impl Context {
47 #[must_use]
49 pub fn new(id: impl Into<String>) -> Self {
50 Self {
51 id: id.into(),
52 history: VecDeque::new(),
53 user: UserContext::default(),
54 variables: HashMap::new(),
55 metadata: ContextMetadata::new(),
56 token_count: 0,
57 }
58 }
59
60 pub fn add_message(&mut self, message: &Message) {
62 let context_msg = ContextMessage::from_message(message);
63 self.token_count += context_msg.estimated_tokens();
64 self.history.push_back(context_msg);
65 self.metadata.last_activity = Utc::now();
66 self.metadata.message_count += 1;
67 }
68
69 pub fn add_response(&mut self, response: &Response) {
71 let context_msg = ContextMessage::from_response(response);
72 self.token_count += context_msg.estimated_tokens();
73 self.history.push_back(context_msg);
74 self.metadata.last_activity = Utc::now();
75 self.metadata.message_count += 1;
76
77 if let Some(usage) = &response.usage {
78 self.metadata.total_tokens += usage.total_tokens;
79 self.metadata.total_cost += usage.estimated_cost;
80 }
81 }
82
83 pub fn trim_to_token_limit(&mut self, max_tokens: usize) {
85 while self.token_count > max_tokens && !self.history.is_empty() {
86 if let Some(removed) = self.history.pop_front() {
87 self.token_count = self.token_count.saturating_sub(removed.estimated_tokens());
88 }
89 }
90 }
91
92 pub fn get_variable(&self, key: &str) -> Option<&serde_json::Value> {
94 self.variables.get(key)
95 }
96
97 pub fn set_variable(&mut self, key: impl Into<String>, value: serde_json::Value) {
99 self.variables.insert(key.into(), value);
100 }
101
102 pub fn clear_history(&mut self) {
104 self.history.clear();
105 self.token_count = 0;
106 self.metadata.message_count = 0;
107 }
108
109 #[must_use]
111 pub fn age(&self) -> Duration {
112 let now = Utc::now();
113 (now - self.metadata.created_at)
114 .to_std()
115 .unwrap_or(Duration::ZERO)
116 }
117
118 #[must_use]
120 pub fn is_expired(&self, ttl: Duration) -> bool {
121 self.age() > ttl
122 }
123
124 #[must_use]
126 pub fn summary(&self) -> String {
127 format!(
128 "Context {} - Messages: {}, Tokens: {}, Age: {:?}",
129 self.id,
130 self.metadata.message_count,
131 self.token_count,
132 self.age()
133 )
134 }
135}
136
/// A single entry in a context's history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextMessage {
    /// Who produced the entry.
    pub role: MessageRole,
    /// Text of the entry.
    pub content: String,
    /// Timestamp copied from the source message/response, or the creation
    /// time for entries built with [`ContextMessage::system`].
    pub timestamp: DateTime<Utc>,
    /// Id of the originating message or response; `None` for system entries.
    pub message_id: Option<Uuid>,
}
149
150impl ContextMessage {
151 pub fn from_message(message: &Message) -> Self {
153 Self {
154 role: MessageRole::User,
155 content: message.content.clone(),
156 timestamp: message.timestamp,
157 message_id: Some(message.id),
158 }
159 }
160
161 pub fn from_response(response: &Response) -> Self {
163 Self {
164 role: MessageRole::Assistant,
165 content: response.content.clone(),
166 timestamp: response.timestamp,
167 message_id: Some(response.id),
168 }
169 }
170
171 pub fn system(content: impl Into<String>) -> Self {
173 Self {
174 role: MessageRole::System,
175 content: content.into(),
176 timestamp: Utc::now(),
177 message_id: None,
178 }
179 }
180
181 const fn estimated_tokens(&self) -> usize {
183 self.content.len() / 4
185 }
186}
187
/// Author of a history entry.
///
/// Serialized in lowercase (`"system"`, `"user"`, `"assistant"`) because of
/// `#[serde(rename_all = "lowercase")]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Entry created through [`ContextMessage::system`].
    System,
    /// Entry created from an incoming `Message`.
    User,
    /// Entry created from a `Response`.
    Assistant,
}
199
/// Information about the user a context belongs to.
///
/// Derives `Default`, so a fresh value has no id or name and empty maps.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UserContext {
    /// User identifier, when known.
    pub id: Option<String>,
    /// User name, when known.
    pub name: Option<String>,
    /// JSON-valued preferences keyed by name.
    pub preferences: HashMap<String, serde_json::Value>,
    /// String-valued attributes keyed by name.
    pub attributes: HashMap<String, String>,
}
212
/// Timestamps and running counters kept alongside a [`Context`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextMetadata {
    /// When the metadata (and hence the context) was created; the basis for
    /// `Context::age`.
    pub created_at: DateTime<Utc>,
    /// Updated each time a message or response is added to the context.
    pub last_activity: DateTime<Utc>,
    /// Incremented for every added message/response and reset by
    /// `Context::clear_history`; trimming does not decrement it.
    pub message_count: usize,
    /// Sum of `usage.total_tokens` over the responses that reported usage.
    pub total_tokens: usize,
    /// Sum of `usage.estimated_cost` over the responses that reported usage.
    pub total_cost: f64,
    /// Free-form tags; starts empty and is not modified by code in this file.
    pub tags: Vec<String>,
}
229
230impl ContextMetadata {
231 fn new() -> Self {
232 let now = Utc::now();
233 Self {
234 created_at: now,
235 last_activity: now,
236 message_count: 0,
237 total_tokens: 0,
238 total_cost: 0.0,
239 tags: Vec::new(),
240 }
241 }
242}
243
/// Hands out shared, lockable [`Context`] values by id, backed by an
/// in-process cache and a pluggable [`ContextStore`].
pub struct ContextManager {
    /// Settings read by the manager: storage backend, TTL, token limit and
    /// whether contexts are persisted to the store.
    config: ContextConfig,
    /// Backing store selected from `config.storage_backend`.
    store: Arc<dyn ContextStore>,
    /// Live contexts keyed by id; each value is shared behind an `RwLock`.
    cache: Arc<DashMap<String, Arc<RwLock<Context>>>>,
}
250
251impl ContextManager {
252 #[instrument(skip(config))]
258 pub async fn new(config: ContextConfig) -> Result<Self> {
259 debug!("Creating context manager with config: {:?}", config);
260
261 let store: Arc<dyn ContextStore> = match &config.storage_backend {
262 StorageBackend::Memory => Arc::new(MemoryContextStore::new()),
263 StorageBackend::Redis { url: _ } => {
264 return Err(Error::new("Redis store not yet implemented").into());
266 }
267 StorageBackend::Postgres { url: _ } => {
268 return Err(Error::new("Postgres store not yet implemented").into());
270 }
271 StorageBackend::Sqlite { path: _ } => {
272 return Err(Error::new("SQLite store not yet implemented").into());
274 }
275 };
276
277 Ok(Self {
278 config,
279 store,
280 cache: Arc::new(DashMap::new()),
281 })
282 }
283
284 #[instrument(skip(self))]
290 pub async fn get_or_create(&self, id: &str) -> Result<Arc<RwLock<Context>>> {
291 if let Some(context) = self.cache.get(id) {
293 let ctx = context.clone();
294
295 if ctx.read().is_expired(self.config.context_ttl) {
297 debug!("Context {} is expired, removing", id);
298 self.cache.remove(id);
299 } else {
300 debug!("Found context {} in cache", id);
301 return Ok(ctx);
302 }
303 }
304
305 if let Some(context) = self.store.get(id).await? {
307 if !context.is_expired(self.config.context_ttl) {
308 debug!("Loaded context {} from store", id);
309 let ctx = Arc::new(RwLock::new(context));
310 self.cache.insert(id.to_string(), ctx.clone());
311 return Ok(ctx);
312 }
313 }
314
315 debug!("Creating new context {}", id);
317 let context = Context::new(id);
318 let ctx = Arc::new(RwLock::new(context));
319 self.cache.insert(id.to_string(), ctx.clone());
320
321 if self.config.persist_context {
323 let context = ctx.read().clone();
324 self.store.set(id, context, self.config.context_ttl).await?;
325 }
326
327 Ok(ctx)
328 }
329
330 #[instrument(skip(self, context))]
336 pub async fn update(&self, id: &str, context: Arc<RwLock<Context>>) -> Result<()> {
337 {
339 let mut ctx = context.write();
340 ctx.trim_to_token_limit(self.config.max_context_tokens);
341 }
342
343 self.cache.insert(id.to_string(), context.clone());
345
346 if self.config.persist_context {
348 let ctx = context.read().clone();
349 self.store.set(id, ctx, self.config.context_ttl).await?;
350 }
351
352 Ok(())
353 }
354
355 #[instrument(skip(self))]
361 pub async fn delete(&self, id: &str) -> Result<()> {
362 debug!("Deleting context {}", id);
363 self.cache.remove(id);
364 self.store.delete(id).await?;
365 Ok(())
366 }
367
368 #[instrument(skip(self))]
374 pub async fn clear_expired(&self) -> Result<usize> {
375 let mut removed = 0;
376 let expired_keys: Vec<String> = self
377 .cache
378 .iter()
379 .filter(|entry| entry.value().read().is_expired(self.config.context_ttl))
380 .map(|entry| entry.key().clone())
381 .collect();
382
383 for key in expired_keys {
384 self.cache.remove(&key);
385 self.store.delete(&key).await?;
386 removed += 1;
387 }
388
389 debug!("Removed {} expired contexts", removed);
390 Ok(removed)
391 }
392
393 #[must_use]
395 pub fn stats(&self) -> ContextStats {
396 let total = self.cache.len();
397 let mut total_tokens = 0;
398 let mut total_messages = 0;
399
400 for entry in self.cache.iter() {
401 let ctx = entry.value().read();
402 total_tokens += ctx.token_count;
403 total_messages += ctx.metadata.message_count;
404 }
405
406 ContextStats {
407 total_contexts: total,
408 total_tokens,
409 total_messages,
410 cache_size: total,
411 }
412 }
413}
414
/// Persistence backend for [`Context`] values.
///
/// Implementations must be `Send + Sync`; [`ContextManager`] holds them as
/// `Arc<dyn ContextStore>`.
#[async_trait::async_trait]
pub trait ContextStore: Send + Sync {
    /// Fetches the context stored under `key`, or `None` when there is none.
    async fn get(&self, key: &str) -> Result<Option<Context>>;

    /// Stores `context` under `key`, replacing any existing entry. `ttl` is
    /// the intended lifetime of the entry; how it is applied is up to the
    /// implementation.
    async fn set(&self, key: &str, context: Context, ttl: Duration) -> Result<()>;

    /// Removes the entry stored under `key`, if any.
    async fn delete(&self, key: &str) -> Result<()>;

    /// Lists the stored keys matching `pattern`. The in-memory implementation
    /// in this file treats `pattern` as a plain substring.
    async fn list_keys(&self, pattern: &str) -> Result<Vec<String>>;
}
430
/// In-process [`ContextStore`] backed by a concurrent map.
struct MemoryContextStore {
    /// Stored contexts keyed by id, each paired with the expiry timestamp
    /// computed when it was written.
    data: Arc<DashMap<String, (Context, DateTime<Utc>)>>,
}
435
436impl MemoryContextStore {
437 fn new() -> Self {
438 Self {
439 data: Arc::new(DashMap::new()),
440 }
441 }
442}
443
444#[async_trait::async_trait]
445impl ContextStore for MemoryContextStore {
446 async fn get(&self, key: &str) -> Result<Option<Context>> {
447 Ok(self.data.get(key).map(|entry| entry.0.clone()))
448 }
449
450 async fn set(&self, key: &str, context: Context, ttl: Duration) -> Result<()> {
451 let expiry = Utc::now() + chrono::Duration::from_std(ttl)?;
452 self.data.insert(key.to_string(), (context, expiry));
453 Ok(())
454 }
455
456 async fn delete(&self, key: &str) -> Result<()> {
457 self.data.remove(key);
458 Ok(())
459 }
460
461 async fn list_keys(&self, pattern: &str) -> Result<Vec<String>> {
462 let keys = self
463 .data
464 .iter()
465 .filter(|entry| entry.key().contains(pattern))
466 .map(|entry| entry.key().clone())
467 .collect();
468 Ok(keys)
469 }
470}
471
/// Snapshot of aggregate figures produced by [`ContextManager::stats`].
#[derive(Debug, Clone)]
pub struct ContextStats {
    /// Number of contexts in the manager's cache.
    pub total_contexts: usize,
    /// Sum of `token_count` over the cached contexts.
    pub total_tokens: usize,
    /// Sum of `metadata.message_count` over the cached contexts.
    pub total_messages: usize,
    /// Number of cache entries; `stats` sets it to the same value as
    /// `total_contexts`.
    pub cache_size: usize,
}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// A new context carries its id and starts empty.
    #[test]
    fn test_context_creation() {
        let context = Context::new("test-123");
        assert_eq!(context.id, "test-123");
        assert!(context.history.is_empty());
        assert_eq!(context.token_count, 0);
    }

    /// Adding a message updates history, token estimate and counter.
    #[test]
    fn test_context_message_addition() {
        let mut context = Context::new("test");
        let message = Message::text("Hello");

        context.add_message(&message);
        assert_eq!(context.history.len(), 1);
        assert!(context.token_count > 0);
        assert_eq!(context.metadata.message_count, 1);
    }

    /// Trimming drops the oldest entries until the token budget is met.
    #[test]
    fn test_context_trimming() {
        let mut context = Context::new("test");

        for i in 0..10 {
            let msg = Message::text(format!("Message {i}"));
            context.add_message(&msg);
        }

        let original_count = context.history.len();
        context.trim_to_token_limit(10);

        assert!(context.history.len() < original_count);
        assert!(context.token_count <= 10);
    }

    /// Variables can be stored and read back; unknown keys yield `None`.
    #[test]
    fn test_context_variables() {
        let mut context = Context::new("test");

        context.set_variable("key", serde_json::json!("value"));
        assert_eq!(
            context.get_variable("key"),
            Some(&serde_json::json!("value"))
        );
        assert_eq!(context.get_variable("missing"), None);
    }

    /// A fresh context is not expired; one created before the TTL window is.
    #[test]
    fn test_context_expiry() {
        let mut context = Context::new("test");
        assert!(!context.is_expired(Duration::from_secs(3600)));

        // Backdate the creation time instead of sleeping so the expired
        // branch is covered without slowing the suite down.
        context.metadata.created_at = Utc::now() - chrono::Duration::hours(2);
        assert!(context.is_expired(Duration::from_secs(3600)));
    }

    /// Asking twice for the same id yields a context with that id both times.
    #[tokio::test]
    async fn test_context_manager() {
        let config = ContextConfig::default();
        let manager = ContextManager::new(config).await.unwrap();

        let ctx1 = manager.get_or_create("test-1").await.unwrap();
        let ctx2 = manager.get_or_create("test-1").await.unwrap();

        assert_eq!(ctx1.read().id, ctx2.read().id);
    }

    /// The in-memory store round-trips a context and forgets it on delete.
    #[tokio::test]
    async fn test_memory_store() {
        let store = MemoryContextStore::new();
        let context = Context::new("test");

        store
            .set("test", context.clone(), Duration::from_secs(60))
            .await
            .unwrap();

        let loaded = store.get("test").await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().id, "test");

        store.delete("test").await.unwrap();
        let deleted = store.get("test").await.unwrap();
        assert!(deleted.is_none());
    }
}