Skip to main content

walrus_core/memory/
tools.rs

1//! Memory tool schemas and handlers for agent tool registration.
2
3use crate::{Handler, Memory, RecallOptions, model::Tool};
4use std::sync::Arc;
5
6/// Build the `remember` tool + handler for the given memory backend.
7pub fn remember<M: Memory + 'static>(mem: Arc<M>) -> (Tool, Handler) {
8    let schema = serde_json::json!({
9        "type": "object",
10        "properties": {
11            "key": { "type": "string", "description": "Memory key" },
12            "value": { "type": "string", "description": "Value to remember" }
13        },
14        "required": ["key", "value"]
15    });
16    let tool = Tool {
17        name: "remember".into(),
18        description: "Store a key-value pair in memory.".into(),
19        parameters: serde_json::from_value(schema).unwrap(),
20        strict: false,
21    };
22    let handler: Handler = Arc::new(move |args| {
23        let mem = Arc::clone(&mem);
24        Box::pin(async move {
25            let parsed: serde_json::Value = match serde_json::from_str(&args) {
26                Ok(v) => v,
27                Err(e) => return format!("invalid arguments: {e}"),
28            };
29            let key = parsed["key"].as_str().unwrap_or("");
30            let value = parsed["value"].as_str().unwrap_or("");
31            match mem.store(key.to_owned(), value.to_owned()).await {
32                Ok(()) => format!("remembered: {key}"),
33                Err(e) => format!("failed to store: {e}"),
34            }
35        })
36    });
37    (tool, handler)
38}
39
40/// Build the `recall` tool + handler for the given memory backend.
41pub fn recall<M: Memory + 'static>(mem: Arc<M>) -> (Tool, Handler) {
42    let schema = serde_json::json!({
43        "type": "object",
44        "properties": {
45            "query": { "type": "string", "description": "Search query for relevant memories" },
46            "limit": { "type": "integer", "description": "Maximum number of results (default: 10)" }
47        },
48        "required": ["query"]
49    });
50    let tool = Tool {
51        name: "recall".into(),
52        description: "Search memory for entries relevant to a query.".into(),
53        parameters: serde_json::from_value(schema).unwrap(),
54        strict: false,
55    };
56    let handler: Handler = Arc::new(move |args| {
57        let mem = Arc::clone(&mem);
58        Box::pin(async move {
59            let parsed: serde_json::Value = match serde_json::from_str(&args) {
60                Ok(v) => v,
61                Err(e) => return format!("invalid arguments: {e}"),
62            };
63            let query = parsed["query"].as_str().unwrap_or("");
64            let limit = parsed["limit"].as_u64().unwrap_or(10) as usize;
65            let options = RecallOptions {
66                limit,
67                ..Default::default()
68            };
69            match mem.recall(query, options).await {
70                Ok(entries) if entries.is_empty() => "no memories found".to_owned(),
71                Ok(entries) => {
72                    let mut out = String::new();
73                    for entry in &entries {
74                        out.push_str(&format!("{}: {}\n", entry.key, entry.value));
75                    }
76                    out
77                }
78                Err(e) => format!("recall failed: {e}"),
79            }
80        })
81    });
82    (tool, handler)
83}