Skip to main content

walrus_memory/
tools.rs

1//! Memory tool schemas and handlers for agent tool registration.
2
3use crate::{Memory, RecallOptions};
4use std::sync::Arc;
5use wcore::{Handler, model::Tool};
6
7/// Tool schema + handler pair, ready to register on a hook.
8pub struct MemoryTool {
9    pub tool: Tool,
10    pub handler: Handler,
11}
12
13/// Build the `remember` tool + handler for the given memory backend.
14pub fn remember<M: Memory + 'static>(mem: Arc<M>) -> MemoryTool {
15    let schema = serde_json::json!({
16        "type": "object",
17        "properties": {
18            "key": { "type": "string", "description": "Memory key" },
19            "value": { "type": "string", "description": "Value to remember" }
20        },
21        "required": ["key", "value"]
22    });
23    let tool = Tool {
24        name: "remember".into(),
25        description: "Store a key-value pair in memory.".into(),
26        parameters: serde_json::from_value(schema).unwrap(),
27        strict: false,
28    };
29    let handler: Handler = Arc::new(move |args| {
30        let mem = Arc::clone(&mem);
31        Box::pin(async move {
32            let parsed: serde_json::Value = match serde_json::from_str(&args) {
33                Ok(v) => v,
34                Err(e) => return format!("invalid arguments: {e}"),
35            };
36            let key = parsed["key"].as_str().unwrap_or("");
37            let value = parsed["value"].as_str().unwrap_or("");
38            match mem.store(key.to_owned(), value.to_owned()).await {
39                Ok(()) => format!("remembered: {key}"),
40                Err(e) => format!("failed to store: {e}"),
41            }
42        })
43    });
44    MemoryTool { tool, handler }
45}
46
47/// Build the `recall` tool + handler for the given memory backend.
48pub fn recall<M: Memory + 'static>(mem: Arc<M>) -> MemoryTool {
49    let schema = serde_json::json!({
50        "type": "object",
51        "properties": {
52            "query": { "type": "string", "description": "Search query for relevant memories" },
53            "limit": { "type": "integer", "description": "Maximum number of results (default: 10)" }
54        },
55        "required": ["query"]
56    });
57    let tool = Tool {
58        name: "recall".into(),
59        description: "Search memory for entries relevant to a query.".into(),
60        parameters: serde_json::from_value(schema).unwrap(),
61        strict: false,
62    };
63    let handler: Handler = Arc::new(move |args| {
64        let mem = Arc::clone(&mem);
65        Box::pin(async move {
66            let parsed: serde_json::Value = match serde_json::from_str(&args) {
67                Ok(v) => v,
68                Err(e) => return format!("invalid arguments: {e}"),
69            };
70            let query = parsed["query"].as_str().unwrap_or("");
71            let limit = parsed["limit"].as_u64().unwrap_or(10) as usize;
72            let options = RecallOptions {
73                limit,
74                ..Default::default()
75            };
76            match mem.recall(query, options).await {
77                Ok(entries) if entries.is_empty() => "no memories found".to_owned(),
78                Ok(entries) => {
79                    let mut out = String::new();
80                    for entry in &entries {
81                        out.push_str(&format!("{}: {}\n", entry.key, entry.value));
82                    }
83                    out
84                }
85                Err(e) => format!("recall failed: {e}"),
86            }
87        })
88    });
89    MemoryTool { tool, handler }
90}