1use dashmap::DashMap;
15use parking_lot::RwLock;
16use serde::{Deserialize, Serialize};
17use std::sync::Arc;
18
/// Upper bound on how many visited URLs are retained; the oldest are dropped first.
const MAX_URL_HISTORY: usize = 100;
/// Upper bound on how many action descriptions are retained; the oldest are dropped first.
const MAX_ACTION_HISTORY: usize = 50;
/// Upper bound on how many extraction results are retained; the oldest are dropped first.
const MAX_EXTRACTIONS: usize = 50;
25
26#[derive(Debug, Clone, Default)]
53pub struct AgentMemory {
54 data: Arc<DashMap<String, serde_json::Value>>,
56 visited_urls: Arc<RwLock<Vec<String>>>,
58 action_history: Arc<RwLock<Vec<String>>>,
60 extractions: Arc<RwLock<Vec<serde_json::Value>>>,
62}
63
64impl AgentMemory {
65 pub fn new() -> Self {
67 Self {
68 data: Arc::new(DashMap::new()),
69 visited_urls: Arc::new(RwLock::new(Vec::new())),
70 action_history: Arc::new(RwLock::new(Vec::new())),
71 extractions: Arc::new(RwLock::new(Vec::new())),
72 }
73 }
74
75 pub fn with_capacity(capacity: usize) -> Self {
77 Self {
78 data: Arc::new(DashMap::with_capacity(capacity)),
79 visited_urls: Arc::new(RwLock::new(Vec::with_capacity(MAX_URL_HISTORY))),
80 action_history: Arc::new(RwLock::new(Vec::with_capacity(MAX_ACTION_HISTORY))),
81 extractions: Arc::new(RwLock::new(Vec::with_capacity(MAX_EXTRACTIONS))),
82 }
83 }
84
85 pub fn get(&self, key: &str) -> Option<serde_json::Value> {
91 self.data.get(key).map(|v| v.value().clone())
92 }
93
94 pub fn set(&self, key: impl Into<String>, value: serde_json::Value) {
96 self.data.insert(key.into(), value);
97 }
98
99 pub fn remove(&self, key: &str) -> Option<serde_json::Value> {
101 self.data.remove(key).map(|(_, v)| v)
102 }
103
104 pub fn clear(&self) {
106 self.data.clear();
107 }
108
109 pub fn contains(&self, key: &str) -> bool {
111 self.data.contains_key(key)
112 }
113
114 pub fn len(&self) -> usize {
116 self.data.len()
117 }
118
119 pub fn is_empty(&self) -> bool {
121 self.data.is_empty()
122 }
123
124 pub fn get_as<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
126 self.data
127 .get(key)
128 .and_then(|v| serde_json::from_value(v.value().clone()).ok())
129 }
130
131 pub fn set_value<T: Serialize>(&self, key: impl Into<String>, value: &T) {
133 if let Ok(json) = serde_json::to_value(value) {
134 self.data.insert(key.into(), json);
135 }
136 }
137
138 pub fn update<F>(&self, key: impl Into<String>, f: F)
142 where
143 F: FnOnce(Option<&serde_json::Value>) -> serde_json::Value,
144 {
145 let key = key.into();
146 let new_value = f(self.data.get(&key).as_deref());
147 self.data.insert(key, new_value);
148 }
149
150 pub fn get_or_insert(
152 &self,
153 key: impl Into<String>,
154 default: serde_json::Value,
155 ) -> serde_json::Value {
156 self.data
157 .entry(key.into())
158 .or_insert(default)
159 .value()
160 .clone()
161 }
162
163 pub fn add_visited_url(&self, url: impl Into<String>) {
169 let mut urls = self.visited_urls.write();
170 urls.push(url.into());
171 if urls.len() > MAX_URL_HISTORY {
172 urls.remove(0);
173 }
174 }
175
176 pub fn visited_urls(&self) -> Vec<String> {
178 self.visited_urls.read().clone()
179 }
180
181 pub fn recent_urls(&self, n: usize) -> Vec<String> {
183 let urls = self.visited_urls.read();
184 urls.iter().rev().take(n).cloned().collect()
185 }
186
187 pub fn has_visited(&self, url: &str) -> bool {
189 self.visited_urls.read().iter().any(|u| u == url)
190 }
191
192 pub fn clear_urls(&self) {
194 self.visited_urls.write().clear();
195 }
196
197 pub fn add_action(&self, action: impl Into<String>) {
203 let mut actions = self.action_history.write();
204 actions.push(action.into());
205 if actions.len() > MAX_ACTION_HISTORY {
206 actions.remove(0);
207 }
208 }
209
210 pub fn action_history(&self) -> Vec<String> {
212 self.action_history.read().clone()
213 }
214
215 pub fn recent_actions(&self, n: usize) -> Vec<String> {
217 let actions = self.action_history.read();
218 actions.iter().rev().take(n).cloned().collect()
219 }
220
221 pub fn clear_actions(&self) {
223 self.action_history.write().clear();
224 }
225
226 pub fn add_extraction(&self, data: serde_json::Value) {
232 let mut extractions = self.extractions.write();
233 extractions.push(data);
234 if extractions.len() > MAX_EXTRACTIONS {
235 extractions.remove(0);
236 }
237 }
238
239 pub fn extractions(&self) -> Vec<serde_json::Value> {
241 self.extractions.read().clone()
242 }
243
244 pub fn recent_extractions(&self, n: usize) -> Vec<serde_json::Value> {
246 let extractions = self.extractions.read();
247 extractions.iter().rev().take(n).cloned().collect()
248 }
249
250 pub fn clear_extractions(&self) {
252 self.extractions.write().clear();
253 }
254
255 pub fn clear_history(&self) {
259 self.visited_urls.write().clear();
260 self.action_history.write().clear();
261 self.extractions.write().clear();
262 }
263
264 pub fn clear_all(&self) {
266 self.data.clear();
267 self.visited_urls.write().clear();
268 self.action_history.write().clear();
269 self.extractions.write().clear();
270 }
271
272 pub fn is_all_empty(&self) -> bool {
274 self.data.is_empty()
275 && self.visited_urls.read().is_empty()
276 && self.action_history.read().is_empty()
277 && self.extractions.read().is_empty()
278 }
279
280 pub fn to_context_string(&self) -> String {
290 if self.is_all_empty() {
291 return String::new();
292 }
293
294 let mut parts = Vec::new();
295
296 if !self.data.is_empty() {
298 let store: std::collections::HashMap<_, _> = self
299 .data
300 .iter()
301 .map(|r| (r.key().clone(), r.value().clone()))
302 .collect();
303 if let Ok(json) = serde_json::to_string_pretty(&store) {
304 parts.push(format!("## Memory Store\n```json\n{}\n```", json));
305 }
306 }
307
308 let urls = self.visited_urls.read();
310 if !urls.is_empty() {
311 let recent: Vec<_> = urls.iter().rev().take(10).collect();
312 let url_list: String = recent
313 .iter()
314 .rev()
315 .enumerate()
316 .map(|(i, u)| format!("{}. {}", i + 1, u))
317 .collect::<Vec<_>>()
318 .join("\n");
319 parts.push(format!(
320 "## Recent URLs (last {})\n{}",
321 recent.len(),
322 url_list
323 ));
324 }
325 drop(urls);
326
327 let extractions = self.extractions.read();
329 if !extractions.is_empty() {
330 let recent: Vec<_> = extractions.iter().rev().take(5).collect();
331 let json_strs: Vec<_> = recent
332 .iter()
333 .rev()
334 .filter_map(|v| serde_json::to_string(v).ok())
335 .collect();
336 parts.push(format!(
337 "## Recent Extractions (last {})\n{}",
338 json_strs.len(),
339 json_strs.join("\n")
340 ));
341 }
342 drop(extractions);
343
344 let actions = self.action_history.read();
346 if !actions.is_empty() {
347 let recent: Vec<_> = actions.iter().rev().take(10).collect();
348 let action_list: String = recent
349 .iter()
350 .rev()
351 .enumerate()
352 .map(|(i, a)| format!("{}. {}", i + 1, a))
353 .collect::<Vec<_>>()
354 .join("\n");
355 parts.push(format!(
356 "## Recent Actions (last {})\n{}",
357 recent.len(),
358 action_list
359 ));
360 }
361 drop(actions);
362
363 parts.join("\n\n")
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_basic() {
        let memory = AgentMemory::new();

        memory.set("key1", serde_json::json!("value1"));
        memory.set("key2", serde_json::json!(42));

        assert_eq!(memory.get("key1"), Some(serde_json::json!("value1")));
        assert_eq!(memory.get("key2"), Some(serde_json::json!(42)));
        assert_eq!(memory.get("key3"), None);
        assert_eq!(memory.len(), 2);
    }

    #[test]
    fn test_memory_typed() {
        let memory = AgentMemory::new();

        memory.set_value("name", &"Alice".to_string());
        memory.set_value("age", &30u32);

        assert_eq!(memory.get_as::<String>("name"), Some("Alice".to_string()));
        assert_eq!(memory.get_as::<u32>("age"), Some(30));
        // A stored value of the wrong shape yields None rather than panicking.
        assert_eq!(memory.get_as::<u32>("name"), None);
    }

    #[test]
    fn test_memory_clear() {
        let memory = AgentMemory::new();

        memory.set("key1", serde_json::json!("value1"));
        memory.set("key2", serde_json::json!("value2"));

        assert_eq!(memory.len(), 2);

        memory.clear();

        assert!(memory.is_empty());
    }

    #[test]
    fn test_memory_update() {
        let memory = AgentMemory::new();

        memory.set("counter", serde_json::json!(0));

        memory.update("counter", |v| {
            let current = v.and_then(|v| v.as_i64()).unwrap_or(0);
            serde_json::json!(current + 1)
        });

        assert_eq!(memory.get("counter"), Some(serde_json::json!(1)));
    }

    #[test]
    fn test_memory_update_missing_key() {
        let memory = AgentMemory::new();

        // The closure must see `None` for an absent key, and its result is stored.
        memory.update("fresh", |v| {
            assert!(v.is_none());
            serde_json::json!("created")
        });

        assert_eq!(memory.get("fresh"), Some(serde_json::json!("created")));
    }

    #[test]
    fn test_memory_get_or_insert() {
        let memory = AgentMemory::new();

        let value = memory.get_or_insert("key", serde_json::json!("default"));
        assert_eq!(value, serde_json::json!("default"));

        memory.set("key", serde_json::json!("updated"));
        let value = memory.get_or_insert("key", serde_json::json!("other"));
        assert_eq!(value, serde_json::json!("updated"));
    }

    #[test]
    fn test_memory_concurrent_clone() {
        let memory = AgentMemory::new();
        let memory2 = memory.clone();

        memory.set("key", serde_json::json!("value"));

        // Clones share state: a write through one handle is visible through the other.
        assert_eq!(memory2.get("key"), Some(serde_json::json!("value")));
    }

    #[test]
    fn test_memory_shared_across_threads() {
        let memory = AgentMemory::new();

        let handles: Vec<_> = (0..4usize)
            .map(|t| {
                let handle = memory.clone();
                std::thread::spawn(move || {
                    handle.set(format!("k{}", t), serde_json::json!(t));
                    handle.add_action(format!("thread {}", t));
                })
            })
            .collect();
        for h in handles {
            h.join().expect("worker thread panicked");
        }

        assert_eq!(memory.len(), 4);
        assert_eq!(memory.action_history().len(), 4);
    }

    #[test]
    fn test_memory_url_history() {
        let memory = AgentMemory::new();

        memory.add_visited_url("https://example.com");
        memory.add_visited_url("https://example.com/page1");
        memory.add_visited_url("https://example.com/page2");

        assert!(memory.has_visited("https://example.com"));
        assert!(!memory.has_visited("https://other.com"));

        let recent = memory.recent_urls(2);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0], "https://example.com/page2");
        assert_eq!(recent[1], "https://example.com/page1");

        // Asking for more than exist returns everything available.
        assert_eq!(memory.recent_urls(10).len(), 3);

        let all = memory.visited_urls();
        assert_eq!(all.len(), 3);
    }

    #[test]
    fn test_memory_url_history_is_bounded() {
        let memory = AgentMemory::new();

        for i in 0..=MAX_URL_HISTORY {
            memory.add_visited_url(format!("https://example.com/{}", i));
        }

        let urls = memory.visited_urls();
        assert_eq!(urls.len(), MAX_URL_HISTORY);
        // The oldest entry was evicted; the newest is at the end.
        assert!(!memory.has_visited("https://example.com/0"));
        assert_eq!(urls[0], "https://example.com/1");
        assert_eq!(
            urls[MAX_URL_HISTORY - 1],
            format!("https://example.com/{}", MAX_URL_HISTORY)
        );
    }

    #[test]
    fn test_memory_action_history() {
        let memory = AgentMemory::new();

        memory.add_action("Searched for 'rust'");
        memory.add_action("Clicked search button");
        memory.add_action("Extracted results");

        let recent = memory.recent_actions(2);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0], "Extracted results");
        assert_eq!(recent[1], "Clicked search button");
    }

    #[test]
    fn test_memory_action_and_extraction_history_are_bounded() {
        let memory = AgentMemory::new();

        for i in 0..MAX_ACTION_HISTORY + 5 {
            memory.add_action(format!("action {}", i));
        }
        let actions = memory.action_history();
        assert_eq!(actions.len(), MAX_ACTION_HISTORY);
        assert_eq!(actions[0], "action 5");

        for i in 0..MAX_EXTRACTIONS + 3 {
            memory.add_extraction(serde_json::json!({ "i": i }));
        }
        let extractions = memory.extractions();
        assert_eq!(extractions.len(), MAX_EXTRACTIONS);
        assert_eq!(extractions[0]["i"].as_u64(), Some(3));
    }

    #[test]
    fn test_memory_extractions() {
        let memory = AgentMemory::new();

        memory.add_extraction(serde_json::json!({"title": "Page 1"}));
        memory.add_extraction(serde_json::json!({"title": "Page 2"}));

        let extractions = memory.extractions();
        assert_eq!(extractions.len(), 2);

        let recent = memory.recent_extractions(1);
        assert_eq!(recent[0]["title"], "Page 2");
    }

    #[test]
    fn test_memory_clear_history_keeps_store() {
        let memory = AgentMemory::new();

        memory.set("key", serde_json::json!("value"));
        memory.add_visited_url("https://example.com");
        memory.add_action("Test action");
        memory.add_extraction(serde_json::json!({"data": "test"}));

        memory.clear_history();

        assert!(memory.visited_urls().is_empty());
        assert!(memory.action_history().is_empty());
        assert!(memory.extractions().is_empty());
        assert_eq!(memory.get("key"), Some(serde_json::json!("value")));
        assert!(!memory.is_all_empty());
    }

    #[test]
    fn test_memory_clear_all() {
        let memory = AgentMemory::new();

        memory.set("key", serde_json::json!("value"));
        memory.add_visited_url("https://example.com");
        memory.add_action("Test action");
        memory.add_extraction(serde_json::json!({"data": "test"}));

        assert!(!memory.is_all_empty());

        memory.clear_all();

        assert!(memory.is_all_empty());
    }

    #[test]
    fn test_memory_context_string() {
        let memory = AgentMemory::new();

        memory.set("user_id", serde_json::json!("123"));
        memory.add_visited_url("https://example.com");
        memory.add_action("Logged in");

        let context = memory.to_context_string();

        assert!(context.contains("Memory Store"));
        assert!(context.contains("user_id"));
        assert!(context.contains("Recent URLs"));
        assert!(context.contains("example.com"));
        assert!(context.contains("Recent Actions"));
        assert!(context.contains("Logged in"));
    }

    #[test]
    fn test_memory_context_string_empty() {
        assert_eq!(AgentMemory::new().to_context_string(), "");
    }
}