// spider_agent/memory.rs

//! Session memory for spider_agent.
//!
//! Uses `DashMap` for concurrent key-value access and `parking_lot::RwLock`
//! for the bounded history buffers.
//!
//! # Features
//! - **Key-Value Store**: Concurrent storage for arbitrary JSON values
//! - **URL History**: Track visited URLs for navigation context
//! - **Action History**: Record actions taken for debugging and context
//! - **Extraction History**: Accumulate extracted data across pages
//!
//! Each history keeps only its most recent entries; the oldest are evicted
//! once a cap is reached. Compatible with spider's AutomationMemory patterns
//! while using DashMap for concurrent performance.
13
14use dashmap::DashMap;
15use parking_lot::RwLock;
16use serde::{Deserialize, Serialize};
17use std::sync::Arc;
18
/// Cap on retained visited URLs; beyond it the oldest URL is evicted.
const MAX_URL_HISTORY: usize = 100;
/// Cap on retained action summaries; beyond it the oldest summary is evicted.
const MAX_ACTION_HISTORY: usize = 50;
/// Cap on retained extraction records; beyond it the oldest record is evicted.
const MAX_EXTRACTIONS: usize = 50;
25
26/// Session memory for storing state across operations.
27///
28/// Uses DashMap internally for lock-free concurrent reads and writes.
29/// This is optimal for high-concurrency scenarios.
30///
31/// # Example
32/// ```
33/// use spider_agent::AgentMemory;
34///
35/// let memory = AgentMemory::new();
36///
37/// // Key-value storage
38/// memory.set("user_id", serde_json::json!("12345"));
39///
40/// // URL tracking
41/// memory.add_visited_url("https://example.com");
42///
43/// // Action history
44/// memory.add_action("Searched for 'rust frameworks'");
45///
46/// // Extraction history
47/// memory.add_extraction(serde_json::json!({"title": "Example"}));
48///
49/// // Generate context for LLM
50/// let context = memory.to_context_string();
51/// ```
52#[derive(Debug, Clone, Default)]
53pub struct AgentMemory {
54    /// Lock-free concurrent key-value store.
55    data: Arc<DashMap<String, serde_json::Value>>,
56    /// History of visited URLs (most recent last).
57    visited_urls: Arc<RwLock<Vec<String>>>,
58    /// Brief summary of recent actions (most recent last).
59    action_history: Arc<RwLock<Vec<String>>>,
60    /// History of extracted data from pages (most recent last).
61    extractions: Arc<RwLock<Vec<serde_json::Value>>>,
62}
63
64impl AgentMemory {
65    /// Create a new empty memory.
66    pub fn new() -> Self {
67        Self {
68            data: Arc::new(DashMap::new()),
69            visited_urls: Arc::new(RwLock::new(Vec::new())),
70            action_history: Arc::new(RwLock::new(Vec::new())),
71            extractions: Arc::new(RwLock::new(Vec::new())),
72        }
73    }
74
75    /// Create memory with pre-allocated capacity.
76    pub fn with_capacity(capacity: usize) -> Self {
77        Self {
78            data: Arc::new(DashMap::with_capacity(capacity)),
79            visited_urls: Arc::new(RwLock::new(Vec::with_capacity(MAX_URL_HISTORY))),
80            action_history: Arc::new(RwLock::new(Vec::with_capacity(MAX_ACTION_HISTORY))),
81            extractions: Arc::new(RwLock::new(Vec::with_capacity(MAX_EXTRACTIONS))),
82        }
83    }
84
85    // ========== Key-Value Store ==========
86
87    /// Get a value from memory.
88    ///
89    /// Returns a clone of the value to avoid holding refs across await points.
90    pub fn get(&self, key: &str) -> Option<serde_json::Value> {
91        self.data.get(key).map(|v| v.value().clone())
92    }
93
94    /// Set a value in memory.
95    pub fn set(&self, key: impl Into<String>, value: serde_json::Value) {
96        self.data.insert(key.into(), value);
97    }
98
99    /// Remove a value from memory.
100    pub fn remove(&self, key: &str) -> Option<serde_json::Value> {
101        self.data.remove(key).map(|(_, v)| v)
102    }
103
104    /// Clear all key-value data.
105    pub fn clear(&self) {
106        self.data.clear();
107    }
108
109    /// Check if memory contains a key.
110    pub fn contains(&self, key: &str) -> bool {
111        self.data.contains_key(key)
112    }
113
114    /// Get number of key-value entries.
115    pub fn len(&self) -> usize {
116        self.data.len()
117    }
118
119    /// Check if key-value store is empty.
120    pub fn is_empty(&self) -> bool {
121        self.data.is_empty()
122    }
123
124    /// Get a typed value from memory.
125    pub fn get_as<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
126        self.data
127            .get(key)
128            .and_then(|v| serde_json::from_value(v.value().clone()).ok())
129    }
130
131    /// Set a typed value in memory.
132    pub fn set_value<T: Serialize>(&self, key: impl Into<String>, value: &T) {
133        if let Ok(json) = serde_json::to_value(value) {
134            self.data.insert(key.into(), json);
135        }
136    }
137
138    /// Update a value atomically using a closure.
139    ///
140    /// The closure receives the current value (if any) and returns the new value.
141    pub fn update<F>(&self, key: impl Into<String>, f: F)
142    where
143        F: FnOnce(Option<&serde_json::Value>) -> serde_json::Value,
144    {
145        let key = key.into();
146        let new_value = f(self.data.get(&key).as_deref());
147        self.data.insert(key, new_value);
148    }
149
150    /// Get or insert a value.
151    pub fn get_or_insert(
152        &self,
153        key: impl Into<String>,
154        default: serde_json::Value,
155    ) -> serde_json::Value {
156        self.data
157            .entry(key.into())
158            .or_insert(default)
159            .value()
160            .clone()
161    }
162
163    // ========== URL History ==========
164
165    /// Record a visited URL.
166    ///
167    /// Keeps the most recent URLs up to the limit.
168    pub fn add_visited_url(&self, url: impl Into<String>) {
169        let mut urls = self.visited_urls.write();
170        urls.push(url.into());
171        if urls.len() > MAX_URL_HISTORY {
172            urls.remove(0);
173        }
174    }
175
176    /// Get the list of visited URLs.
177    pub fn visited_urls(&self) -> Vec<String> {
178        self.visited_urls.read().clone()
179    }
180
181    /// Get the last N visited URLs.
182    pub fn recent_urls(&self, n: usize) -> Vec<String> {
183        let urls = self.visited_urls.read();
184        urls.iter().rev().take(n).cloned().collect()
185    }
186
187    /// Check if a URL has been visited.
188    pub fn has_visited(&self, url: &str) -> bool {
189        self.visited_urls.read().iter().any(|u| u == url)
190    }
191
192    /// Clear URL history.
193    pub fn clear_urls(&self) {
194        self.visited_urls.write().clear();
195    }
196
197    // ========== Action History ==========
198
199    /// Record an action summary.
200    ///
201    /// Keeps the most recent actions up to the limit.
202    pub fn add_action(&self, action: impl Into<String>) {
203        let mut actions = self.action_history.write();
204        actions.push(action.into());
205        if actions.len() > MAX_ACTION_HISTORY {
206            actions.remove(0);
207        }
208    }
209
210    /// Get the list of actions.
211    pub fn action_history(&self) -> Vec<String> {
212        self.action_history.read().clone()
213    }
214
215    /// Get the last N actions.
216    pub fn recent_actions(&self, n: usize) -> Vec<String> {
217        let actions = self.action_history.read();
218        actions.iter().rev().take(n).cloned().collect()
219    }
220
221    /// Clear action history.
222    pub fn clear_actions(&self) {
223        self.action_history.write().clear();
224    }
225
226    // ========== Extraction History ==========
227
228    /// Add an extracted value to history.
229    ///
230    /// Keeps the most recent extractions up to the limit.
231    pub fn add_extraction(&self, data: serde_json::Value) {
232        let mut extractions = self.extractions.write();
233        extractions.push(data);
234        if extractions.len() > MAX_EXTRACTIONS {
235            extractions.remove(0);
236        }
237    }
238
239    /// Get all extractions.
240    pub fn extractions(&self) -> Vec<serde_json::Value> {
241        self.extractions.read().clone()
242    }
243
244    /// Get the last N extractions.
245    pub fn recent_extractions(&self, n: usize) -> Vec<serde_json::Value> {
246        let extractions = self.extractions.read();
247        extractions.iter().rev().take(n).cloned().collect()
248    }
249
250    /// Clear extraction history.
251    pub fn clear_extractions(&self) {
252        self.extractions.write().clear();
253    }
254
255    // ========== Bulk Operations ==========
256
257    /// Clear all history (URLs, actions, extractions) but keep key-value store.
258    pub fn clear_history(&self) {
259        self.visited_urls.write().clear();
260        self.action_history.write().clear();
261        self.extractions.write().clear();
262    }
263
264    /// Clear everything including key-value store and all history.
265    pub fn clear_all(&self) {
266        self.data.clear();
267        self.visited_urls.write().clear();
268        self.action_history.write().clear();
269        self.extractions.write().clear();
270    }
271
272    /// Check if all memory is empty (store + all history).
273    pub fn is_all_empty(&self) -> bool {
274        self.data.is_empty()
275            && self.visited_urls.read().is_empty()
276            && self.action_history.read().is_empty()
277            && self.extractions.read().is_empty()
278    }
279
280    // ========== Context Generation ==========
281
282    /// Generate a context string for inclusion in LLM prompts.
283    ///
284    /// This provides the LLM with session context including:
285    /// - Key-value store contents
286    /// - Recent URLs visited
287    /// - Recent actions taken
288    /// - Recent extractions
289    pub fn to_context_string(&self) -> String {
290        if self.is_all_empty() {
291            return String::new();
292        }
293
294        let mut parts = Vec::new();
295
296        // Key-value store
297        if !self.data.is_empty() {
298            let store: std::collections::HashMap<_, _> = self
299                .data
300                .iter()
301                .map(|r| (r.key().clone(), r.value().clone()))
302                .collect();
303            if let Ok(json) = serde_json::to_string_pretty(&store) {
304                parts.push(format!("## Memory Store\n```json\n{}\n```", json));
305            }
306        }
307
308        // Recent URLs
309        let urls = self.visited_urls.read();
310        if !urls.is_empty() {
311            let recent: Vec<_> = urls.iter().rev().take(10).collect();
312            let url_list: String = recent
313                .iter()
314                .rev()
315                .enumerate()
316                .map(|(i, u)| format!("{}. {}", i + 1, u))
317                .collect::<Vec<_>>()
318                .join("\n");
319            parts.push(format!(
320                "## Recent URLs (last {})\n{}",
321                recent.len(),
322                url_list
323            ));
324        }
325        drop(urls);
326
327        // Recent extractions
328        let extractions = self.extractions.read();
329        if !extractions.is_empty() {
330            let recent: Vec<_> = extractions.iter().rev().take(5).collect();
331            let json_strs: Vec<_> = recent
332                .iter()
333                .rev()
334                .filter_map(|v| serde_json::to_string(v).ok())
335                .collect();
336            parts.push(format!(
337                "## Recent Extractions (last {})\n{}",
338                json_strs.len(),
339                json_strs.join("\n")
340            ));
341        }
342        drop(extractions);
343
344        // Recent actions
345        let actions = self.action_history.read();
346        if !actions.is_empty() {
347            let recent: Vec<_> = actions.iter().rev().take(10).collect();
348            let action_list: String = recent
349                .iter()
350                .rev()
351                .enumerate()
352                .map(|(i, a)| format!("{}. {}", i + 1, a))
353                .collect::<Vec<_>>()
354                .join("\n");
355            parts.push(format!(
356                "## Recent Actions (last {})\n{}",
357                recent.len(),
358                action_list
359            ));
360        }
361        drop(actions);
362
363        parts.join("\n\n")
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_basic() {
        let memory = AgentMemory::new();

        memory.set("key1", serde_json::json!("value1"));
        memory.set("key2", serde_json::json!(42));

        assert_eq!(memory.get("key1"), Some(serde_json::json!("value1")));
        assert_eq!(memory.get("key2"), Some(serde_json::json!(42)));
        assert_eq!(memory.get("key3"), None);
        assert_eq!(memory.len(), 2);
    }

    #[test]
    fn test_memory_typed() {
        let memory = AgentMemory::new();

        memory.set_value("name", &"Alice".to_string());
        memory.set_value("age", &30u32);

        assert_eq!(memory.get_as::<String>("name"), Some("Alice".to_string()));
        assert_eq!(memory.get_as::<u32>("age"), Some(30));
        // Type mismatch and missing key both yield None.
        assert_eq!(memory.get_as::<u32>("name"), None);
        assert_eq!(memory.get_as::<u32>("missing"), None);
    }

    #[test]
    fn test_memory_clear() {
        let memory = AgentMemory::new();

        memory.set("key1", serde_json::json!("value1"));
        memory.set("key2", serde_json::json!("value2"));

        assert_eq!(memory.len(), 2);

        memory.clear();

        assert!(memory.is_empty());
    }

    #[test]
    fn test_memory_update() {
        let memory = AgentMemory::new();

        memory.set("counter", serde_json::json!(0));

        memory.update("counter", |v| {
            let current = v.and_then(|v| v.as_i64()).unwrap_or(0);
            serde_json::json!(current + 1)
        });

        assert_eq!(memory.get("counter"), Some(serde_json::json!(1)));
    }

    #[test]
    fn test_memory_update_missing_key() {
        let memory = AgentMemory::new();

        memory.update("fresh", |v| {
            assert!(v.is_none());
            serde_json::json!("created")
        });

        assert_eq!(memory.get("fresh"), Some(serde_json::json!("created")));
    }

    #[test]
    fn test_memory_get_or_insert() {
        let memory = AgentMemory::new();

        let value = memory.get_or_insert("key", serde_json::json!("default"));
        assert_eq!(value, serde_json::json!("default"));

        // Should return existing value
        memory.set("key", serde_json::json!("updated"));
        let value = memory.get_or_insert("key", serde_json::json!("other"));
        assert_eq!(value, serde_json::json!("updated"));
    }

    #[test]
    fn test_memory_concurrent_clone() {
        let memory = AgentMemory::new();
        let memory2 = memory.clone();

        memory.set("key", serde_json::json!("value"));

        // Clone shares the same underlying data
        assert_eq!(memory2.get("key"), Some(serde_json::json!("value")));
    }

    #[test]
    fn test_memory_url_history() {
        let memory = AgentMemory::new();

        memory.add_visited_url("https://example.com");
        memory.add_visited_url("https://example.com/page1");
        memory.add_visited_url("https://example.com/page2");

        assert!(memory.has_visited("https://example.com"));
        assert!(!memory.has_visited("https://other.com"));

        let recent = memory.recent_urls(2);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0], "https://example.com/page2");
        assert_eq!(recent[1], "https://example.com/page1");

        // Asking for more than is stored returns everything, newest first.
        let recent = memory.recent_urls(10);
        assert_eq!(recent.len(), 3);
        assert_eq!(recent[2], "https://example.com");

        let all = memory.visited_urls();
        assert_eq!(all.len(), 3);
    }

    #[test]
    fn test_url_history_is_bounded() {
        let memory = AgentMemory::new();

        for i in 0..MAX_URL_HISTORY + 5 {
            memory.add_visited_url(format!("https://example.com/{}", i));
        }

        let urls = memory.visited_urls();
        assert_eq!(urls.len(), MAX_URL_HISTORY);
        // The five oldest entries were evicted; remaining order is preserved.
        assert_eq!(urls[0], "https://example.com/5");
        assert!(!memory.has_visited("https://example.com/0"));
        assert!(memory.has_visited(&format!(
            "https://example.com/{}",
            MAX_URL_HISTORY + 4
        )));
    }

    #[test]
    fn test_memory_action_history() {
        let memory = AgentMemory::new();

        memory.add_action("Searched for 'rust'");
        memory.add_action("Clicked search button");
        memory.add_action("Extracted results");

        let recent = memory.recent_actions(2);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0], "Extracted results");
        assert_eq!(recent[1], "Clicked search button");
    }

    #[test]
    fn test_action_history_is_bounded() {
        let memory = AgentMemory::new();

        for i in 0..MAX_ACTION_HISTORY + 3 {
            memory.add_action(format!("action {}", i));
        }

        let all = memory.action_history();
        assert_eq!(all.len(), MAX_ACTION_HISTORY);
        assert_eq!(all[0], "action 3");
        assert_eq!(
            memory.recent_actions(1)[0],
            format!("action {}", MAX_ACTION_HISTORY + 2)
        );
    }

    #[test]
    fn test_memory_extractions() {
        let memory = AgentMemory::new();

        memory.add_extraction(serde_json::json!({"title": "Page 1"}));
        memory.add_extraction(serde_json::json!({"title": "Page 2"}));

        let extractions = memory.extractions();
        assert_eq!(extractions.len(), 2);

        let recent = memory.recent_extractions(1);
        assert_eq!(recent[0]["title"], "Page 2");
    }

    #[test]
    fn test_extractions_are_bounded() {
        let memory = AgentMemory::new();

        for i in 0..MAX_EXTRACTIONS + 3 {
            memory.add_extraction(serde_json::json!({ "i": i }));
        }

        let all = memory.extractions();
        assert_eq!(all.len(), MAX_EXTRACTIONS);
        assert_eq!(all[0], serde_json::json!({"i": 3}));
    }

    #[test]
    fn test_memory_clear_all() {
        let memory = AgentMemory::new();

        memory.set("key", serde_json::json!("value"));
        memory.add_visited_url("https://example.com");
        memory.add_action("Test action");
        memory.add_extraction(serde_json::json!({"data": "test"}));

        assert!(!memory.is_all_empty());

        memory.clear_all();

        assert!(memory.is_all_empty());
    }

    #[test]
    fn test_clear_history_keeps_store() {
        let memory = AgentMemory::new();

        memory.set("key", serde_json::json!("value"));
        memory.add_visited_url("https://example.com");
        memory.add_action("Test action");
        memory.add_extraction(serde_json::json!({"data": "test"}));

        memory.clear_history();

        assert_eq!(memory.get("key"), Some(serde_json::json!("value")));
        assert!(memory.visited_urls().is_empty());
        assert!(memory.action_history().is_empty());
        assert!(memory.extractions().is_empty());
        assert!(!memory.is_all_empty());
    }

    #[test]
    fn test_memory_context_string() {
        let memory = AgentMemory::new();

        memory.set("user_id", serde_json::json!("123"));
        memory.add_visited_url("https://example.com");
        memory.add_action("Logged in");

        let context = memory.to_context_string();

        assert!(context.contains("Memory Store"));
        assert!(context.contains("user_id"));
        assert!(context.contains("Recent URLs"));
        assert!(context.contains("example.com"));
        assert!(context.contains("Recent Actions"));
        assert!(context.contains("Logged in"));
    }

    #[test]
    fn test_empty_context_string() {
        assert!(AgentMemory::new().to_context_string().is_empty());
    }

    #[test]
    fn test_context_string_extractions() {
        let memory = AgentMemory::new();

        memory.add_extraction(serde_json::json!({"title": "Example"}));

        let context = memory.to_context_string();
        assert!(context.contains("## Recent Extractions (last 1)"));
        assert!(context.contains(r#"{"title":"Example"}"#));
    }

    #[test]
    fn test_context_string_truncates_urls() {
        let memory = AgentMemory::new();

        for i in 0..15 {
            memory.add_visited_url(format!("https://example.com/{}", i));
        }

        let context = memory.to_context_string();
        // Only the last 10 URLs are rendered, oldest first.
        assert!(context.contains("## Recent URLs (last 10)"));
        assert!(context.contains("1. https://example.com/5"));
        assert!(context.contains("10. https://example.com/14"));
    }
}