1use crate::core::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, VecDeque};
4use std::time::{SystemTime, UNIX_EPOCH};
5
6pub trait MemoryStore: Send + Sync {
7 fn store(&mut self, key: &str, value: &str) -> Result<()>;
8 fn retrieve(&self, key: &str) -> Result<Option<String>>;
9 fn list_keys(&self) -> Result<Vec<String>>;
10}
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14struct MemoryEntry {
15 value: String,
16 created_at: u64,
17 expires_at: Option<u64>,
18}
19
20impl MemoryEntry {
21 fn new(value: String, ttl_seconds: Option<u64>) -> Self {
22 let now_duration = SystemTime::now()
23 .duration_since(UNIX_EPOCH)
24 .unwrap();
25 let now = now_duration.as_nanos() as u64;
26
27 Self {
28 value,
29 created_at: now,
30 expires_at: ttl_seconds.map(|ttl| now + (ttl * 1_000_000_000)), }
32 }
33
34 fn is_expired(&self) -> bool {
35 if let Some(expires_at) = self.expires_at {
36 let now = SystemTime::now()
37 .duration_since(UNIX_EPOCH)
38 .unwrap()
39 .as_nanos() as u64;
40 now > expires_at
41 } else {
42 false
43 }
44 }
45}
46
47pub struct InMemoryStore {
49 data: HashMap<String, MemoryEntry>,
50 default_ttl: Option<u64>,
51 max_entries: Option<usize>,
52}
53
54impl InMemoryStore {
55 pub fn new() -> Self {
57 Self {
58 data: HashMap::new(),
59 default_ttl: None,
60 max_entries: None,
61 }
62 }
63
64 pub fn with_ttl(ttl_seconds: u64) -> Self {
66 Self {
67 data: HashMap::new(),
68 default_ttl: Some(ttl_seconds),
69 max_entries: None,
70 }
71 }
72
73 pub fn with_capacity(max_entries: usize) -> Self {
75 Self {
76 data: HashMap::new(),
77 default_ttl: None,
78 max_entries: Some(max_entries),
79 }
80 }
81
82 pub fn with_ttl_and_capacity(ttl_seconds: u64, max_entries: usize) -> Self {
84 Self {
85 data: HashMap::new(),
86 default_ttl: Some(ttl_seconds),
87 max_entries: Some(max_entries),
88 }
89 }
90
91 pub fn cleanup(&mut self) -> Result<()> {
93 let expired_keys: Vec<String> = self
94 .data
95 .iter()
96 .filter(|(_, entry)| entry.is_expired())
97 .map(|(key, _)| key.clone())
98 .collect();
99
100 for key in expired_keys {
101 self.data.remove(&key);
102 }
103
104 Ok(())
105 }
106
107 pub fn clear(&mut self) -> Result<()> {
109 self.data.clear();
110 Ok(())
111 }
112
113 pub fn summarize(&self) -> Result<String> {
115 let total_entries = self.data.len();
116 let expired_entries = self
117 .data
118 .values()
119 .filter(|entry| entry.is_expired())
120 .count();
121 let active_entries = total_entries - expired_entries;
122
123 let total_size: usize = self.data.values().map(|entry| entry.value.len()).sum();
124
125 Ok(format!(
126 "Memory Store Summary: {} entries ({} active, {} expired), {} bytes total",
127 total_entries, active_entries, expired_entries, total_size
128 ))
129 }
130
131 pub fn contains_key(&self, key: &str) -> bool {
133 if let Some(entry) = self.data.get(key) {
134 !entry.is_expired()
135 } else {
136 false
137 }
138 }
139
140 pub fn stats(&self) -> MemoryStats {
142 let total_entries = self.data.len();
143 let expired_entries = self
144 .data
145 .values()
146 .filter(|entry| entry.is_expired())
147 .count();
148 let total_size: usize = self.data.values().map(|entry| entry.value.len()).sum();
149
150 MemoryStats {
151 total_entries,
152 active_entries: total_entries - expired_entries,
153 expired_entries,
154 total_size_bytes: total_size,
155 max_entries: self.max_entries,
156 default_ttl: self.default_ttl,
157 }
158 }
159
160 fn ensure_capacity(&mut self) -> Result<()> {
161 if let Some(max_entries) = self.max_entries {
162 self.cleanup()?;
164
165 while self.data.len() >= max_entries {
167 if let Some(oldest_key) = self
168 .data
169 .iter()
170 .min_by_key(|(_, entry)| entry.created_at)
171 .map(|(key, _)| key.clone())
172 {
173 self.data.remove(&oldest_key);
174 } else {
175 break;
176 }
177 }
178 }
179 Ok(())
180 }
181}
182
183impl MemoryStore for InMemoryStore {
184 fn store(&mut self, key: &str, value: &str) -> Result<()> {
185 self.ensure_capacity()?;
187
188 let entry = MemoryEntry::new(value.to_string(), self.default_ttl);
189 self.data.insert(key.to_string(), entry);
190 Ok(())
191 }
192
193 fn retrieve(&self, key: &str) -> Result<Option<String>> {
194 if let Some(entry) = self.data.get(key) {
195 if entry.is_expired() {
196 Ok(None)
197 } else {
198 Ok(Some(entry.value.clone()))
199 }
200 } else {
201 Ok(None)
202 }
203 }
204
205 fn list_keys(&self) -> Result<Vec<String>> {
206 Ok(self
207 .data
208 .iter()
209 .filter(|(_, entry)| !entry.is_expired())
210 .map(|(key, _)| key.clone())
211 .collect())
212 }
213}
214
215#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct MemoryStats {
218 pub total_entries: usize,
219 pub active_entries: usize,
220 pub expired_entries: usize,
221 pub total_size_bytes: usize,
222 pub max_entries: Option<usize>,
223 pub default_ttl: Option<u64>,
224}
225
226#[derive(Debug, Clone)]
228pub struct ConversationMemory {
229 messages: VecDeque<ConversationMessage>,
230 max_messages: usize,
231}
232
233#[derive(Debug, Clone, Serialize, Deserialize)]
234pub struct ConversationMessage {
235 pub role: String,
236 pub content: String,
237 pub timestamp: u64,
238}
239
240impl ConversationMessage {
241 fn new(role: &str, content: &str) -> Self {
242 Self {
243 role: role.to_string(),
244 content: content.to_string(),
245 timestamp: SystemTime::now()
246 .duration_since(UNIX_EPOCH)
247 .unwrap()
248 .as_secs(),
249 }
250 }
251}
252
253impl ConversationMemory {
254 pub fn new(max_messages: usize) -> Self {
256 Self {
257 messages: VecDeque::new(),
258 max_messages,
259 }
260 }
261
262 pub fn add_message(&mut self, role: &str, content: &str) -> Result<()> {
264 if self.max_messages == 0 {
266 return Ok(());
267 }
268
269 let message = ConversationMessage::new(role, content);
270
271 if self.messages.len() >= self.max_messages {
273 self.messages.pop_front();
274 }
275
276 self.messages.push_back(message);
277 Ok(())
278 }
279
280 pub fn get_conversation(&self) -> Result<Vec<String>> {
282 Ok(self
283 .messages
284 .iter()
285 .map(|msg| format!("{}: {}", msg.role, msg.content))
286 .collect())
287 }
288
289 pub fn get_recent(&self, count: usize) -> Result<Vec<String>> {
291 Ok(self
292 .messages
293 .iter()
294 .rev()
295 .take(count)
296 .rev()
297 .map(|msg| format!("{}: {}", msg.role, msg.content))
298 .collect())
299 }
300
301 pub fn search(&self, term: &str) -> Result<Vec<String>> {
303 let term_lower = term.to_lowercase();
304 Ok(self
305 .messages
306 .iter()
307 .filter(|msg| {
308 msg.content.to_lowercase().contains(&term_lower)
309 || msg.role.to_lowercase().contains(&term_lower)
310 })
311 .map(|msg| format!("{}: {}", msg.role, msg.content))
312 .collect())
313 }
314
315 pub fn clear(&mut self) -> Result<()> {
317 self.messages.clear();
318 Ok(())
319 }
320
321 pub fn summarize(&self) -> Result<String> {
323 let total_messages = self.messages.len();
324 let roles: std::collections::HashSet<String> =
325 self.messages.iter().map(|msg| msg.role.clone()).collect();
326
327 Ok(format!(
328 "Conversation summary: {} messages from {} participants",
329 total_messages,
330 roles.len()
331 ))
332 }
333
334 pub fn stats(&self) -> ConversationStats {
336 let mut role_counts: HashMap<String, usize> = HashMap::new();
337 let mut total_chars = 0;
338
339 for msg in &self.messages {
340 *role_counts.entry(msg.role.clone()).or_insert(0) += 1;
341 total_chars += msg.content.len();
342 }
343
344 ConversationStats {
345 total_messages: self.messages.len(),
346 role_counts,
347 total_characters: total_chars,
348 max_capacity: self.max_messages,
349 }
350 }
351}
352
353#[derive(Debug, Clone, Serialize, Deserialize)]
355pub struct ConversationStats {
356 pub total_messages: usize,
357 pub role_counts: HashMap<String, usize>,
358 pub total_characters: usize,
359 pub max_capacity: usize,
360}
361
/// Remote [`MemoryStore`] backed by a ContextLite HTTP service and scoped to a
/// single agent. Only compiled with the `contextlite` feature.
#[cfg(feature = "contextlite")]
pub struct ContextLiteStore {
    /// Identifier of the agent whose memory this store addresses.
    _agent_id: String,
    /// Base URL of the ContextLite service.
    _endpoint: String,
    /// HTTP client reused for every request issued by this store.
    _client: reqwest::Client,
}

#[cfg(feature = "contextlite")]
impl ContextLiteStore {
    /// Creates a store that talks to `endpoint` on behalf of `agent_id`, with a
    /// freshly constructed `reqwest::Client`.
    pub fn new(endpoint: String, agent_id: String) -> Self {
        let client = reqwest::Client::new();
        Self {
            _agent_id: agent_id,
            _endpoint: endpoint,
            _client: client,
        }
    }
}
380
#[cfg(feature = "contextlite")]
impl MemoryStore for ContextLiteStore {
    /// Stores `value` under `key` for this agent via
    /// `POST {endpoint}/api/v1/agents/{agent_id}/memory`.
    ///
    /// # Errors
    ///
    /// Returns a `RustChainError::Memory` error in three cases: no usable
    /// (multi-threaded) Tokio runtime is available, the HTTP request fails, or
    /// ContextLite answers with a non-success status.
    fn store(&mut self, key: &str, value: &str) -> Result<()> {
        use tracing::{debug, error};

        debug!("Storing key '{}' in ContextLite (agent: {})", key, self._agent_id);

        let url = contextlite_memory_url(&self._endpoint, &self._agent_id, None);
        let client = &self._client;
        let payload = serde_json::json!({
            "key": key,
            "value": value,
            "metadata": {
                "timestamp": std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs(),
                "source": "rustchain"
            }
        });

        contextlite_block_on(&format!("store key '{}'", key), async move {
            let response = match client
                .post(&url)
                .header("Content-Type", "application/json")
                .json(&payload)
                .timeout(std::time::Duration::from_millis(5000))
                .send()
                .await
            {
                Ok(response) => response,
                Err(e) => {
                    error!("HTTP request to ContextLite failed: {}", e);
                    return Err(contextlite_error(
                        "HTTP request to ContextLite".to_string(),
                        format!("ContextLite (error: {})", e),
                    ));
                }
            };

            let status = response.status();
            if status.is_success() {
                debug!("Successfully stored key '{}' in ContextLite", key);
                return Ok(());
            }

            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            error!("ContextLite store failed with status {}: {}", status, error_text);
            Err(contextlite_error(
                format!("store key '{}'", key),
                format!("ContextLite (status: {}, error: {})", status, error_text),
            ))
        })
    }

    /// Fetches the value for `key` via
    /// `GET {endpoint}/api/v1/agents/{agent_id}/memory/{key}`.
    ///
    /// Returns `Ok(None)` on 404, on a JSON `"value": null`, and when the service
    /// is unreachable. Treating an unreachable service as a miss is a deliberate
    /// best-effort fallback.
    ///
    /// # Errors
    ///
    /// Returns a `RustChainError::Memory` error in three cases: no usable Tokio
    /// runtime is available, a non-404 error status is returned, or the response
    /// body cannot be read.
    fn retrieve(&self, key: &str) -> Result<Option<String>> {
        use tracing::{debug, warn};

        debug!("Retrieving key '{}' from ContextLite (agent: {})", key, self._agent_id);

        let url = contextlite_memory_url(&self._endpoint, &self._agent_id, Some(key));
        let client = &self._client;

        contextlite_block_on(&format!("retrieve key '{}'", key), async move {
            let response = match client
                .get(&url)
                .header("Accept", "application/json")
                .timeout(std::time::Duration::from_millis(5000))
                .send()
                .await
            {
                Ok(response) => response,
                Err(e) => {
                    // Best-effort: connectivity problems are reported as a cache miss.
                    warn!("ContextLite connectivity issue, returning None: {}", e);
                    return Ok(None);
                }
            };

            let status = response.status();
            if status == reqwest::StatusCode::NOT_FOUND {
                debug!("Key '{}' not found in ContextLite", key);
                return Ok(None);
            }
            if !status.is_success() {
                let error_text = response
                    .text()
                    .await
                    .unwrap_or_else(|_| "Unknown error".to_string());
                warn!("ContextLite retrieve failed with status {}: {}", status, error_text);
                return Err(contextlite_error(
                    format!("retrieve key '{}'", key),
                    format!("ContextLite (status: {}, error: {})", status, error_text),
                ));
            }

            // A failed body read is an error, not an empty stored value.
            let response_text = match response.text().await {
                Ok(text) => text,
                Err(e) => {
                    warn!("Failed to read ContextLite response body: {}", e);
                    return Err(contextlite_error(
                        format!("retrieve key '{}'", key),
                        format!("ContextLite (error: failed to read response body: {})", e),
                    ));
                }
            };

            let parsed = serde_json::from_str::<serde_json::Value>(&response_text);
            match parsed {
                Ok(json) => match json.get("value") {
                    Some(serde_json::Value::String(value)) => {
                        debug!("Successfully retrieved key '{}' from ContextLite", key);
                        Ok(Some(value.clone()))
                    }
                    // An explicit null means nothing is stored under this key.
                    Some(serde_json::Value::Null) => Ok(None),
                    // Non-string values are returned as their own JSON text,
                    // not the whole response envelope.
                    Some(other) => Ok(Some(other.to_string())),
                    // No `value` field: the body itself is taken as the value.
                    None => Ok(Some(response_text)),
                },
                // Not JSON: the raw body is taken as the value.
                Err(_) => Ok(Some(response_text)),
            }
        })
    }

    /// Lists this agent's keys via `GET {endpoint}/api/v1/agents/{agent_id}/memory`.
    ///
    /// Returns an empty list on 404, on a non-JSON body, and when the service is
    /// unreachable (deliberate best-effort fallback).
    ///
    /// # Errors
    ///
    /// Returns a `RustChainError::Memory` error if no usable Tokio runtime is
    /// available or a non-404 error status is returned.
    fn list_keys(&self) -> Result<Vec<String>> {
        use tracing::{debug, warn};

        debug!("Listing keys from ContextLite (agent: {})", self._agent_id);

        let url = contextlite_memory_url(&self._endpoint, &self._agent_id, None);
        let client = &self._client;
        let agent_id = self._agent_id.as_str();

        contextlite_block_on("list_keys", async move {
            let response = match client
                .get(&url)
                .header("Accept", "application/json")
                .timeout(std::time::Duration::from_millis(10000))
                .send()
                .await
            {
                Ok(response) => response,
                Err(e) => {
                    warn!("ContextLite connectivity issue, returning empty list: {}", e);
                    return Ok(Vec::new());
                }
            };

            let status = response.status();
            if status == reqwest::StatusCode::NOT_FOUND {
                debug!("Agent '{}' not found in ContextLite, returning empty list", agent_id);
                return Ok(Vec::new());
            }
            if !status.is_success() {
                let error_text = response
                    .text()
                    .await
                    .unwrap_or_else(|_| "Unknown error".to_string());
                warn!("ContextLite list_keys failed with status {}: {}", status, error_text);
                return Err(contextlite_error(
                    "list_keys".to_string(),
                    format!("ContextLite (status: {}, error: {})", status, error_text),
                ));
            }

            let response_text = response.text().await.unwrap_or_default();
            match serde_json::from_str::<serde_json::Value>(&response_text) {
                Ok(json) => {
                    let keys = contextlite_extract_keys(&json);
                    debug!("Successfully listed {} keys from ContextLite", keys.len());
                    Ok(keys)
                }
                Err(_) => {
                    warn!("ContextLite list_keys returned non-JSON response");
                    Ok(Vec::new())
                }
            }
        })
    }
}

/// Builds the agent-memory URL, optionally addressing a single key.
///
/// The agent id and the key are percent-encoded. A trailing `/` on the endpoint
/// is dropped so no `//` is introduced before `api`.
#[cfg(feature = "contextlite")]
fn contextlite_memory_url(endpoint: &str, agent_id: &str, key: Option<&str>) -> String {
    let base = format!(
        "{}/api/v1/agents/{}/memory",
        endpoint.trim_end_matches('/'),
        urlencoding::encode(agent_id)
    );
    match key {
        Some(key) => format!("{}/{}", base, urlencoding::encode(key)),
        None => base,
    }
}

/// Wraps a failure description in the crate's memory error type.
#[cfg(feature = "contextlite")]
fn contextlite_error(operation: String, store_type: String) -> crate::core::error::RustChainError {
    crate::core::error::RustChainError::Memory(crate::core::error::MemoryError::InvalidOperation {
        operation,
        store_type,
    })
}

/// Extracts key names from a ContextLite list response.
///
/// Shapes are checked in this order:
/// 1. `{"keys": ["k", ...]}`
/// 2. `{"data": [{"key": "k"}, ...]}`
/// 3. A bare array of strings and/or `{"key": "k"}` objects.
///
/// Anything else yields an empty list.
#[cfg(feature = "contextlite")]
fn contextlite_extract_keys(json: &serde_json::Value) -> Vec<String> {
    // The `"key"` string of an object item, if present.
    fn object_key(item: &serde_json::Value) -> Option<String> {
        item.get("key")
            .and_then(serde_json::Value::as_str)
            .map(str::to_string)
    }

    // The item itself, if it is a JSON string.
    fn string_item(item: &serde_json::Value) -> Option<String> {
        item.as_str().map(str::to_string)
    }

    if let Some(keys) = json.get("keys") {
        keys.as_array()
            .map(|items| items.iter().filter_map(string_item).collect())
            .unwrap_or_default()
    } else if let Some(data) = json.get("data") {
        data.as_array()
            .map(|items| items.iter().filter_map(object_key).collect())
            .unwrap_or_default()
    } else if let Some(items) = json.as_array() {
        items
            .iter()
            .filter_map(|item| string_item(item).or_else(|| object_key(item)))
            .collect()
    } else {
        Vec::new()
    }
}

/// Drives `future` to completion from synchronous code on the ambient Tokio runtime.
///
/// This reports an error instead of panicking in two situations that would
/// otherwise abort the caller: `Handle::current()` outside a runtime, and
/// `block_in_place` on a current-thread runtime.
#[cfg(feature = "contextlite")]
fn contextlite_block_on<T, F>(operation: &str, future: F) -> Result<T>
where
    F: std::future::Future<Output = Result<T>>,
{
    let handle = match tokio::runtime::Handle::try_current() {
        Ok(handle) => handle,
        Err(e) => {
            return Err(contextlite_error(
                operation.to_string(),
                format!("ContextLite (error: no Tokio runtime available: {})", e),
            ));
        }
    };

    if matches!(
        handle.runtime_flavor(),
        tokio::runtime::RuntimeFlavor::CurrentThread
    ) {
        return Err(contextlite_error(
            operation.to_string(),
            "ContextLite (error: a multi-threaded Tokio runtime is required)".to_string(),
        ));
    }

    tokio::task::block_in_place(|| handle.block_on(future))
}
621
622#[cfg(test)]
624mod tests;
625