rustchain/core/
memory.rs

1use crate::core::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, VecDeque};
4use std::time::{SystemTime, UNIX_EPOCH};
5
/// Minimal key/value storage interface implemented by the memory backends in this file.
///
/// Keys and values are plain strings. Implementors must be `Send + Sync`.
pub trait MemoryStore: Send + Sync {
    /// Saves `value` under `key`.
    fn store(&mut self, key: &str, value: &str) -> Result<()>;
    /// Looks up `key`; `Ok(None)` means the backend has no value to return for it.
    fn retrieve(&self, key: &str) -> Result<Option<String>>;
    /// Returns the keys the backend currently reports.
    fn list_keys(&self) -> Result<Vec<String>>;
}
11
/// Enhanced memory entry with TTL support
///
/// Both timestamps are nanoseconds since the Unix epoch.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MemoryEntry {
    /// The stored payload.
    value: String,
    /// Creation time, in nanoseconds since the Unix epoch.
    created_at: u64,
    /// Expiry deadline, in nanoseconds since the Unix epoch; `None` means the entry never expires.
    expires_at: Option<u64>,
}
19
20impl MemoryEntry {
21    fn new(value: String, ttl_seconds: Option<u64>) -> Self {
22        let now_duration = SystemTime::now()
23            .duration_since(UNIX_EPOCH)
24            .unwrap();
25        let now = now_duration.as_nanos() as u64;
26
27        Self {
28            value,
29            created_at: now,
30            expires_at: ttl_seconds.map(|ttl| now + (ttl * 1_000_000_000)), // Convert seconds to nanoseconds
31        }
32    }
33
34    fn is_expired(&self) -> bool {
35        if let Some(expires_at) = self.expires_at {
36            let now = SystemTime::now()
37                .duration_since(UNIX_EPOCH)
38                .unwrap()
39                .as_nanos() as u64;
40            now > expires_at
41        } else {
42            false
43        }
44    }
45}
46
/// Enhanced in-memory store with TTL, cleanup, and additional operations
pub struct InMemoryStore {
    /// Entries by key. Expired entries can remain here until `cleanup` removes them;
    /// the read paths filter them out.
    data: HashMap<String, MemoryEntry>,
    /// TTL in seconds given to every entry created by `store`; `None` disables expiry.
    default_ttl: Option<u64>,
    /// Entry limit consulted when storing; `None` means unbounded.
    max_entries: Option<usize>,
}
53
54impl InMemoryStore {
55    /// Create a new in-memory store with default settings
56    pub fn new() -> Self {
57        Self {
58            data: HashMap::new(),
59            default_ttl: None,
60            max_entries: None,
61        }
62    }
63
64    /// Create a new in-memory store with TTL (overloaded for tests)
65    pub fn with_ttl(ttl_seconds: u64) -> Self {
66        Self {
67            data: HashMap::new(),
68            default_ttl: Some(ttl_seconds),
69            max_entries: None,
70        }
71    }
72
73    /// Create a new in-memory store with capacity limit
74    pub fn with_capacity(max_entries: usize) -> Self {
75        Self {
76            data: HashMap::new(),
77            default_ttl: None,
78            max_entries: Some(max_entries),
79        }
80    }
81
82    /// Create a new in-memory store with both TTL and capacity limit
83    pub fn with_ttl_and_capacity(ttl_seconds: u64, max_entries: usize) -> Self {
84        Self {
85            data: HashMap::new(),
86            default_ttl: Some(ttl_seconds),
87            max_entries: Some(max_entries),
88        }
89    }
90
91    /// Clean up expired entries
92    pub fn cleanup(&mut self) -> Result<()> {
93        let expired_keys: Vec<String> = self
94            .data
95            .iter()
96            .filter(|(_, entry)| entry.is_expired())
97            .map(|(key, _)| key.clone())
98            .collect();
99
100        for key in expired_keys {
101            self.data.remove(&key);
102        }
103
104        Ok(())
105    }
106
107    /// Clear all entries
108    pub fn clear(&mut self) -> Result<()> {
109        self.data.clear();
110        Ok(())
111    }
112
113    /// Get summary of memory store
114    pub fn summarize(&self) -> Result<String> {
115        let total_entries = self.data.len();
116        let expired_entries = self
117            .data
118            .values()
119            .filter(|entry| entry.is_expired())
120            .count();
121        let active_entries = total_entries - expired_entries;
122
123        let total_size: usize = self.data.values().map(|entry| entry.value.len()).sum();
124
125        Ok(format!(
126            "Memory Store Summary: {} entries ({} active, {} expired), {} bytes total",
127            total_entries, active_entries, expired_entries, total_size
128        ))
129    }
130
131    /// Check if an entry exists and is not expired
132    pub fn contains_key(&self, key: &str) -> bool {
133        if let Some(entry) = self.data.get(key) {
134            !entry.is_expired()
135        } else {
136            false
137        }
138    }
139
140    /// Get memory usage statistics
141    pub fn stats(&self) -> MemoryStats {
142        let total_entries = self.data.len();
143        let expired_entries = self
144            .data
145            .values()
146            .filter(|entry| entry.is_expired())
147            .count();
148        let total_size: usize = self.data.values().map(|entry| entry.value.len()).sum();
149
150        MemoryStats {
151            total_entries,
152            active_entries: total_entries - expired_entries,
153            expired_entries,
154            total_size_bytes: total_size,
155            max_entries: self.max_entries,
156            default_ttl: self.default_ttl,
157        }
158    }
159
160    fn ensure_capacity(&mut self) -> Result<()> {
161        if let Some(max_entries) = self.max_entries {
162            // First try cleanup to free space
163            self.cleanup()?;
164
165            // If would exceed capacity after adding new entry, make room by removing oldest
166            while self.data.len() >= max_entries {
167                if let Some(oldest_key) = self
168                    .data
169                    .iter()
170                    .min_by_key(|(_, entry)| entry.created_at)
171                    .map(|(key, _)| key.clone())
172                {
173                    self.data.remove(&oldest_key);
174                } else {
175                    break;
176                }
177            }
178        }
179        Ok(())
180    }
181}
182
183impl MemoryStore for InMemoryStore {
184    fn store(&mut self, key: &str, value: &str) -> Result<()> {
185        // Ensure we don't exceed capacity
186        self.ensure_capacity()?;
187
188        let entry = MemoryEntry::new(value.to_string(), self.default_ttl);
189        self.data.insert(key.to_string(), entry);
190        Ok(())
191    }
192
193    fn retrieve(&self, key: &str) -> Result<Option<String>> {
194        if let Some(entry) = self.data.get(key) {
195            if entry.is_expired() {
196                Ok(None)
197            } else {
198                Ok(Some(entry.value.clone()))
199            }
200        } else {
201            Ok(None)
202        }
203    }
204
205    fn list_keys(&self) -> Result<Vec<String>> {
206        Ok(self
207            .data
208            .iter()
209            .filter(|(_, entry)| !entry.is_expired())
210            .map(|(key, _)| key.clone())
211            .collect())
212    }
213}
214
/// Memory usage statistics
///
/// Snapshot produced by `InMemoryStore::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Number of entries held in the map, including expired ones not yet cleaned up.
    pub total_entries: usize,
    /// `total_entries` minus `expired_entries`.
    pub active_entries: usize,
    /// Entries whose expiry deadline had passed when the snapshot was taken.
    pub expired_entries: usize,
    /// Sum of the byte lengths of all stored values (keys are not counted).
    pub total_size_bytes: usize,
    /// Configured capacity limit, if any.
    pub max_entries: Option<usize>,
    /// Configured default TTL in seconds, if any.
    pub default_ttl: Option<u64>,
}
225
/// Conversation-specific memory for storing and managing chat history
///
/// A bounded FIFO: once `max_messages` is reached, adding a message drops the oldest one.
#[derive(Debug, Clone)]
pub struct ConversationMemory {
    /// Messages in arrival order, oldest at the front.
    messages: VecDeque<ConversationMessage>,
    /// Maximum number of messages retained; `0` means nothing is ever stored.
    max_messages: usize,
}
232
/// A single message recorded in a `ConversationMemory`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationMessage {
    /// Free-form label for who produced the message.
    pub role: String,
    /// The message text.
    pub content: String,
    /// Creation time in whole seconds since the Unix epoch.
    pub timestamp: u64,
}
239
240impl ConversationMessage {
241    fn new(role: &str, content: &str) -> Self {
242        Self {
243            role: role.to_string(),
244            content: content.to_string(),
245            timestamp: SystemTime::now()
246                .duration_since(UNIX_EPOCH)
247                .unwrap()
248                .as_secs(),
249        }
250    }
251}
252
253impl ConversationMemory {
254    /// Create a new conversation memory with specified capacity
255    pub fn new(max_messages: usize) -> Self {
256        Self {
257            messages: VecDeque::new(),
258            max_messages,
259        }
260    }
261
262    /// Add a message to the conversation
263    pub fn add_message(&mut self, role: &str, content: &str) -> Result<()> {
264        // If max_messages is 0, don't store anything
265        if self.max_messages == 0 {
266            return Ok(());
267        }
268
269        let message = ConversationMessage::new(role, content);
270
271        // Remove oldest message if at capacity
272        if self.messages.len() >= self.max_messages {
273            self.messages.pop_front();
274        }
275
276        self.messages.push_back(message);
277        Ok(())
278    }
279
280    /// Get the entire conversation as formatted strings
281    pub fn get_conversation(&self) -> Result<Vec<String>> {
282        Ok(self
283            .messages
284            .iter()
285            .map(|msg| format!("{}: {}", msg.role, msg.content))
286            .collect())
287    }
288
289    /// Get the most recent N messages
290    pub fn get_recent(&self, count: usize) -> Result<Vec<String>> {
291        Ok(self
292            .messages
293            .iter()
294            .rev()
295            .take(count)
296            .rev()
297            .map(|msg| format!("{}: {}", msg.role, msg.content))
298            .collect())
299    }
300
301    /// Search for messages containing a specific term
302    pub fn search(&self, term: &str) -> Result<Vec<String>> {
303        let term_lower = term.to_lowercase();
304        Ok(self
305            .messages
306            .iter()
307            .filter(|msg| {
308                msg.content.to_lowercase().contains(&term_lower)
309                    || msg.role.to_lowercase().contains(&term_lower)
310            })
311            .map(|msg| format!("{}: {}", msg.role, msg.content))
312            .collect())
313    }
314
315    /// Clear all messages
316    pub fn clear(&mut self) -> Result<()> {
317        self.messages.clear();
318        Ok(())
319    }
320
321    /// Get summary of the conversation
322    pub fn summarize(&self) -> Result<String> {
323        let total_messages = self.messages.len();
324        let roles: std::collections::HashSet<String> =
325            self.messages.iter().map(|msg| msg.role.clone()).collect();
326
327        Ok(format!(
328            "Conversation summary: {} messages from {} participants",
329            total_messages,
330            roles.len()
331        ))
332    }
333
334    /// Get conversation statistics
335    pub fn stats(&self) -> ConversationStats {
336        let mut role_counts: HashMap<String, usize> = HashMap::new();
337        let mut total_chars = 0;
338
339        for msg in &self.messages {
340            *role_counts.entry(msg.role.clone()).or_insert(0) += 1;
341            total_chars += msg.content.len();
342        }
343
344        ConversationStats {
345            total_messages: self.messages.len(),
346            role_counts,
347            total_characters: total_chars,
348            max_capacity: self.max_messages,
349        }
350    }
351}
352
/// Conversation statistics
///
/// Snapshot produced by `ConversationMemory::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationStats {
    /// Number of messages currently retained.
    pub total_messages: usize,
    /// Number of retained messages per role.
    pub role_counts: HashMap<String, usize>,
    /// Combined size of all retained message contents.
    pub total_characters: usize,
    /// The memory's configured `max_messages`.
    pub max_capacity: usize,
}
361
/// ContextLite Memory Store - Persistent storage backend
///
/// Talks to a ContextLite service over HTTP. Only compiled when the
/// `contextlite` feature is enabled.
#[cfg(feature = "contextlite")]
pub struct ContextLiteStore {
    /// Base URL prepended to every request path.
    _endpoint: String,
    /// Agent identifier interpolated into the `/api/v1/agents/{id}/memory` routes.
    _agent_id: String,
    /// HTTP client used for every request issued by this store.
    _client: reqwest::Client,
}
369
#[cfg(feature = "contextlite")]
impl ContextLiteStore {
    /// Creates a store that sends requests for `agent_id` to the service at
    /// `endpoint`, using a freshly constructed HTTP client.
    pub fn new(endpoint: String, agent_id: String) -> Self {
        let client = reqwest::Client::new();

        Self {
            _client: client,
            _agent_id: agent_id,
            _endpoint: endpoint,
        }
    }
}
380
#[cfg(feature = "contextlite")]
impl ContextLiteStore {
    /// Drives `future` to completion from synchronous code on the ambient Tokio runtime.
    ///
    /// # Panics
    ///
    /// Panics when called outside a Tokio runtime (`Handle::current`) or from a
    /// `current_thread` runtime (`block_in_place`).
    fn run_blocking<F: std::future::Future>(future: F) -> F::Output {
        tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(future))
    }

    /// Builds the crate error reported for every failed ContextLite operation.
    fn memory_error(operation: String, store_type: String) -> crate::core::error::RustChainError {
        crate::core::error::RustChainError::Memory(
            crate::core::error::MemoryError::InvalidOperation {
                operation,
                store_type,
            },
        )
    }

    /// Reads `item` as a bare string key.
    fn plain_key(item: &serde_json::Value) -> Option<String> {
        item.as_str().map(String::from)
    }

    /// Reads the string stored in `item`'s `"key"` field.
    fn key_field(item: &serde_json::Value) -> Option<String> {
        item.get("key")
            .and_then(|key| key.as_str())
            .map(String::from)
    }

    /// Extracts key names from the accepted list payload shapes:
    /// `{"keys": ["a", ...]}`, `{"data": [{"key": "a"}, ...]}`, or a top-level
    /// array mixing bare strings and `{"key": ...}` objects.
    ///
    /// The first matching shape wins; items of an unexpected type are skipped.
    fn extract_keys(payload: &serde_json::Value) -> Vec<String> {
        if let Some(keys) = payload.get("keys") {
            keys.as_array()
                .map(|items| items.iter().filter_map(Self::plain_key).collect())
                .unwrap_or_default()
        } else if let Some(data) = payload.get("data") {
            data.as_array()
                .map(|items| items.iter().filter_map(Self::key_field).collect())
                .unwrap_or_default()
        } else if let Some(items) = payload.as_array() {
            items
                .iter()
                .filter_map(|item| Self::plain_key(item).or_else(|| Self::key_field(item)))
                .collect()
        } else {
            Vec::new()
        }
    }
}

#[cfg(feature = "contextlite")]
impl MemoryStore for ContextLiteStore {
    /// POSTs `key`/`value` (plus a timestamp and source tag) to the agent's memory
    /// endpoint, blocking until the request completes or the 5 s timeout elapses.
    ///
    /// # Errors
    ///
    /// Returns a `MemoryError::InvalidOperation` when the request cannot be sent
    /// or the service answers with a non-success status.
    ///
    /// # Panics
    ///
    /// Panics outside a Tokio runtime or on a `current_thread` runtime.
    fn store(&mut self, key: &str, value: &str) -> Result<()> {
        use tracing::{debug, error};

        debug!("Storing key '{}' in ContextLite (agent: {})", key, self._agent_id);

        // NOTE(review): `_agent_id` is interpolated into the path unescaped, unlike
        // the key in `retrieve` — confirm agent ids never need percent-encoding.
        let url = format!("{}/api/v1/agents/{}/memory", self._endpoint, self._agent_id);

        let payload = serde_json::json!({
            "key": key,
            "value": value,
            "metadata": {
                "timestamp": std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs(),
                "source": "rustchain"
            }
        });

        let request = self
            ._client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&payload)
            .timeout(std::time::Duration::from_millis(5000));

        Self::run_blocking(async move {
            let resp = match request.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    error!("HTTP request to ContextLite failed: {}", e);
                    return Err(Self::memory_error(
                        "HTTP request to ContextLite".to_string(),
                        format!("ContextLite (error: {})", e),
                    ));
                }
            };

            let status = resp.status();
            if status.is_success() {
                debug!("Successfully stored key '{}' in ContextLite", key);
                return Ok(());
            }

            let error_text = resp
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            error!("ContextLite store failed with status {}: {}", status, error_text);
            Err(Self::memory_error(
                format!("store key '{}'", key),
                format!("ContextLite (status: {}, error: {})", status, error_text),
            ))
        })
    }

    /// GETs the value stored under `key` (percent-encoded into the path), blocking
    /// until the request completes or the 5 s timeout elapses.
    ///
    /// A JSON body with a string `"value"` field yields that string; any other
    /// successful body is returned verbatim. A 404, a failure to send the request,
    /// or a failure to read the response body all degrade to `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Returns a `MemoryError::InvalidOperation` for any other non-success status.
    ///
    /// # Panics
    ///
    /// Panics outside a Tokio runtime or on a `current_thread` runtime.
    fn retrieve(&self, key: &str) -> Result<Option<String>> {
        use tracing::{debug, error, warn};

        debug!("Retrieving key '{}' from ContextLite (agent: {})", key, self._agent_id);

        let url = format!(
            "{}/api/v1/agents/{}/memory/{}",
            self._endpoint,
            self._agent_id,
            urlencoding::encode(key)
        );

        let request = self
            ._client
            .get(&url)
            .header("Accept", "application/json")
            .timeout(std::time::Duration::from_millis(5000));

        Self::run_blocking(async move {
            let resp = match request.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    error!("HTTP request to ContextLite failed: {}", e);
                    // Return None for connectivity issues to allow graceful degradation
                    warn!("ContextLite connectivity issue, returning None: {}", e);
                    return Ok(None);
                }
            };

            let status = resp.status();
            if status == reqwest::StatusCode::NOT_FOUND {
                debug!("Key '{}' not found in ContextLite", key);
                return Ok(None);
            }
            if !status.is_success() {
                let error_text = resp
                    .text()
                    .await
                    .unwrap_or_else(|_| "Unknown error".to_string());
                warn!("ContextLite retrieve failed with status {}: {}", status, error_text);
                return Err(Self::memory_error(
                    format!("retrieve key '{}'", key),
                    format!("ContextLite (status: {}, error: {})", status, error_text),
                ));
            }

            // A body that cannot be read is a transport failure, not an empty value:
            // degrade to None instead of fabricating `Some("")`.
            let body = match resp.text().await {
                Ok(body) => body,
                Err(e) => {
                    warn!("ContextLite response body could not be read, returning None: {}", e);
                    return Ok(None);
                }
            };

            match serde_json::from_str::<serde_json::Value>(&body) {
                Ok(json) => match json.get("value").and_then(|value| value.as_str()) {
                    Some(text) => {
                        debug!("Successfully retrieved key '{}' from ContextLite", key);
                        Ok(Some(text.to_string()))
                    }
                    // If no string "value" field, return the whole response as string
                    None => Ok(Some(body)),
                },
                // If not JSON, return raw text
                Err(_) => Ok(Some(body)),
            }
        })
    }

    /// GETs the agent's key list, blocking until the request completes or the 10 s
    /// timeout elapses.
    ///
    /// A 404, a failure to send the request, or a non-JSON body all degrade to an
    /// empty list. The accepted JSON shapes are described on `extract_keys`.
    ///
    /// # Errors
    ///
    /// Returns a `MemoryError::InvalidOperation` for any other non-success status.
    ///
    /// # Panics
    ///
    /// Panics outside a Tokio runtime or on a `current_thread` runtime.
    fn list_keys(&self) -> Result<Vec<String>> {
        use tracing::{debug, error, warn};

        debug!("Listing keys from ContextLite (agent: {})", self._agent_id);

        let url = format!("{}/api/v1/agents/{}/memory", self._endpoint, self._agent_id);
        let agent_id = self._agent_id.as_str();

        let request = self
            ._client
            .get(&url)
            .header("Accept", "application/json")
            .timeout(std::time::Duration::from_millis(10000)); // Longer timeout for list operations

        Self::run_blocking(async move {
            let resp = match request.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    error!("HTTP request to ContextLite failed: {}", e);
                    // Return empty list for connectivity issues to allow graceful degradation
                    warn!("ContextLite connectivity issue, returning empty list: {}", e);
                    return Ok(Vec::new());
                }
            };

            let status = resp.status();
            if status == reqwest::StatusCode::NOT_FOUND {
                debug!("Agent '{}' not found in ContextLite, returning empty list", agent_id);
                return Ok(Vec::new());
            }
            if !status.is_success() {
                let error_text = resp
                    .text()
                    .await
                    .unwrap_or_else(|_| "Unknown error".to_string());
                warn!("ContextLite list_keys failed with status {}: {}", status, error_text);
                return Err(Self::memory_error(
                    "list_keys".to_string(),
                    format!("ContextLite (status: {}, error: {})", status, error_text),
                ));
            }

            // An unreadable body parses as non-JSON below and yields an empty list.
            let body = resp.text().await.unwrap_or_default();

            match serde_json::from_str::<serde_json::Value>(&body) {
                Ok(payload) => {
                    let keys = Self::extract_keys(&payload);
                    debug!("Successfully listed {} keys from ContextLite", keys.len());
                    Ok(keys)
                }
                Err(_) => {
                    warn!("ContextLite list_keys returned non-JSON response");
                    Ok(Vec::new())
                }
            }
        })
    }
}
621
622// Include the tests module
623#[cfg(test)]
624mod tests;
625