Skip to main content

sh_layer3/builtin_tools/
memory_tools.rs

1//! # Memory Tools
2//!
3//! 记忆操作工具集,使用分层记忆系统。
4
5use crate::builtin_tools::BuiltinTool;
6use crate::memory_system::{MemoryStore, WorkingMemory};
7use crate::types::{Layer3Result, MemoryEntry, MemoryQuery, MemoryTier, ToolCategory};
8use async_trait::async_trait;
9use chrono::Utc;
10use sh_layer2::generate_short_id;
11use std::sync::Arc;
12
13/// Save Memory Tool
14pub struct SaveMemoryTool {
15    store: Arc<WorkingMemory>,
16}
17
18impl SaveMemoryTool {
19    pub fn new() -> Self {
20        Self {
21            store: Arc::new(WorkingMemory::default()),
22        }
23    }
24
25    /// 使用指定的 store 创建
26    pub fn with_store(store: Arc<WorkingMemory>) -> Self {
27        Self { store }
28    }
29}
30
31impl Default for SaveMemoryTool {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37#[async_trait]
38impl BuiltinTool for SaveMemoryTool {
39    fn name(&self) -> &str {
40        "save_memory"
41    }
42
43    fn description(&self) -> &str {
44        "Save a memory entry to the memory system."
45    }
46
47    fn parameters_schema(&self) -> serde_json::Value {
48        serde_json::json!({
49            "type": "object",
50            "properties": {
51                "content": {
52                    "type": "string",
53                    "description": "The content to remember"
54                },
55                "tier": {
56                    "type": "string",
57                    "enum": ["working", "session", "project", "long_term"],
58                    "description": "Memory tier to store in (default: working)"
59                },
60                "metadata": {
61                    "type": "object",
62                    "description": "Optional: additional metadata"
63                }
64            },
65            "required": ["content"]
66        })
67    }
68
69    fn category(&self) -> ToolCategory {
70        ToolCategory::Memory
71    }
72
73    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
74        let content = args["content"]
75            .as_str()
76            .ok_or_else(|| anyhow::anyhow!("Missing content parameter"))?;
77
78        let tier_str = args["tier"].as_str().unwrap_or("working");
79        let tier = match tier_str {
80            "working" => MemoryTier::Working,
81            "session" => MemoryTier::Session,
82            "project" => MemoryTier::Project,
83            "long_term" => MemoryTier::LongTerm,
84            _ => MemoryTier::Working,
85        };
86
87        // Extract metadata as Map
88        let metadata = if let Some(obj) = args["metadata"].as_object() {
89            obj.clone()
90        } else {
91            serde_json::Map::new()
92        };
93
94        // Create memory entry
95        let entry = MemoryEntry {
96            id: generate_short_id(),
97            content: content.to_string(),
98            tier,
99            created_at: Utc::now(),
100            last_accessed: Utc::now(),
101            importance: 0.5,
102            metadata,
103            access_count: 0,
104        };
105
106        // Store in working memory
107        let id = self.store.store(entry).await?;
108
109        Ok(format!("Memory saved to {} tier with ID: {}", tier_str, id))
110    }
111}
112
113/// Query Memory Tool
114pub struct QueryMemoryTool {
115    store: Arc<WorkingMemory>,
116}
117
118impl QueryMemoryTool {
119    pub fn new() -> Self {
120        Self {
121            store: Arc::new(WorkingMemory::default()),
122        }
123    }
124
125    /// 使用指定的 store 创建
126    pub fn with_store(store: Arc<WorkingMemory>) -> Self {
127        Self { store }
128    }
129}
130
131impl Default for QueryMemoryTool {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137#[async_trait]
138impl BuiltinTool for QueryMemoryTool {
139    fn name(&self) -> &str {
140        "query_memory"
141    }
142
143    fn description(&self) -> &str {
144        "Query the memory system for relevant memories."
145    }
146
147    fn parameters_schema(&self) -> serde_json::Value {
148        serde_json::json!({
149            "type": "object",
150            "properties": {
151                "query": {
152                    "type": "string",
153                    "description": "The query text"
154                },
155                "tier": {
156                    "type": "string",
157                    "enum": ["working", "session", "project", "long_term"],
158                    "description": "Optional: limit to specific tier"
159                },
160                "limit": {
161                    "type": "integer",
162                    "description": "Optional: maximum number of results (default: 10)"
163                }
164            },
165            "required": ["query"]
166        })
167    }
168
169    fn category(&self) -> ToolCategory {
170        ToolCategory::Memory
171    }
172
173    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
174        let query_text = args["query"]
175            .as_str()
176            .ok_or_else(|| anyhow::anyhow!("Missing query parameter"))?;
177
178        let limit = args["limit"].as_u64().map(|l| l as usize);
179        let tier = args["tier"].as_str().and_then(|t| match t {
180            "working" => Some(MemoryTier::Working),
181            "session" => Some(MemoryTier::Session),
182            "project" => Some(MemoryTier::Project),
183            "long_term" => Some(MemoryTier::LongTerm),
184            _ => None,
185        });
186
187        let query = MemoryQuery {
188            query: query_text.to_string(),
189            tier,
190            limit,
191            time_range: None,
192        };
193
194        // Query working memory
195        let results = self.store.query(&query).await?;
196
197        if results.is_empty() {
198            Ok("(no memories found)".to_string())
199        } else {
200            let output: Vec<String> = results
201                .iter()
202                .take(limit.unwrap_or(10))
203                .map(|e| {
204                    let preview = if e.content.len() > 200 {
205                        format!("{}...", &e.content[..200])
206                    } else {
207                        e.content.clone()
208                    };
209                    format!("{}: {}", e.id, preview)
210                })
211                .collect();
212            Ok(output.join("\n"))
213        }
214    }
215}
216
217/// Clear Memory Tool
218pub struct ClearMemoryTool {
219    store: Arc<WorkingMemory>,
220}
221
222impl ClearMemoryTool {
223    pub fn new() -> Self {
224        Self {
225            store: Arc::new(WorkingMemory::default()),
226        }
227    }
228
229    /// 使用指定的 store 创建
230    pub fn with_store(store: Arc<WorkingMemory>) -> Self {
231        Self { store }
232    }
233}
234
235impl Default for ClearMemoryTool {
236    fn default() -> Self {
237        Self::new()
238    }
239}
240
241#[async_trait]
242impl BuiltinTool for ClearMemoryTool {
243    fn name(&self) -> &str {
244        "clear_memory"
245    }
246
247    fn description(&self) -> &str {
248        "Clear all memories from a specific tier."
249    }
250
251    fn parameters_schema(&self) -> serde_json::Value {
252        serde_json::json!({
253            "type": "object",
254            "properties": {
255                "tier": {
256                    "type": "string",
257                    "enum": ["working", "session", "project", "long_term"],
258                    "description": "Memory tier to clear (default: working)"
259                }
260            },
261            "required": []
262        })
263    }
264
265    fn category(&self) -> ToolCategory {
266        ToolCategory::Memory
267    }
268
269    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
270        let tier_str = args["tier"].as_str().unwrap_or("working");
271
272        // Clear working memory
273        let count = self.store.clear().await?;
274
275        Ok(format!("Cleared {} memories from {} tier", count, tier_str))
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The save tool reports the memory category.
    #[test]
    fn test_memory_tool_category() {
        assert_eq!(SaveMemoryTool::new().category(), ToolCategory::Memory);
    }

    /// The query tool reports the memory category.
    #[test]
    fn test_query_memory_tool_category() {
        assert_eq!(QueryMemoryTool::new().category(), ToolCategory::Memory);
    }

    /// Saving with only `content` succeeds and yields the confirmation text.
    #[tokio::test]
    async fn test_save_memory() {
        let saver = SaveMemoryTool::new();
        let message = saver
            .execute(json!({"content": "test memory"}))
            .await
            .expect("saving a memory should succeed");
        assert!(message.contains("Memory saved"));
    }

    /// Querying a freshly created store yields the "no memories" placeholder.
    #[tokio::test]
    async fn test_query_memory_empty() {
        let searcher = QueryMemoryTool::new();
        let message = searcher
            .execute(json!({"query": "nonexistent"}))
            .await
            .expect("querying an empty store should succeed");
        assert!(message.contains("no memories"));
    }

    /// An entry saved through one tool is found by a query tool sharing the store.
    #[tokio::test]
    async fn test_save_and_query_memory() {
        let shared = Arc::new(WorkingMemory::default());

        let saver = SaveMemoryTool::with_store(Arc::clone(&shared));
        let saved = saver
            .execute(json!({"content": "important fact: the sky is blue"}))
            .await;
        saved.unwrap();

        let searcher = QueryMemoryTool::with_store(shared);
        let found = searcher.execute(json!({"query": "sky"})).await.unwrap();
        assert!(found.contains("sky is blue"));
    }
}