sh_layer3/builtin_tools/
memory_tools.rs1use crate::builtin_tools::BuiltinTool;
6use crate::memory_system::{MemoryStore, WorkingMemory};
7use crate::types::{Layer3Result, MemoryEntry, MemoryQuery, MemoryTier, ToolCategory};
8use async_trait::async_trait;
9use chrono::Utc;
10use sh_layer2::generate_short_id;
11use std::sync::Arc;
12
13pub struct SaveMemoryTool {
15 store: Arc<WorkingMemory>,
16}
17
18impl SaveMemoryTool {
19 pub fn new() -> Self {
20 Self {
21 store: Arc::new(WorkingMemory::default()),
22 }
23 }
24
25 pub fn with_store(store: Arc<WorkingMemory>) -> Self {
27 Self { store }
28 }
29}
30
31impl Default for SaveMemoryTool {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37#[async_trait]
38impl BuiltinTool for SaveMemoryTool {
39 fn name(&self) -> &str {
40 "save_memory"
41 }
42
43 fn description(&self) -> &str {
44 "Save a memory entry to the memory system."
45 }
46
47 fn parameters_schema(&self) -> serde_json::Value {
48 serde_json::json!({
49 "type": "object",
50 "properties": {
51 "content": {
52 "type": "string",
53 "description": "The content to remember"
54 },
55 "tier": {
56 "type": "string",
57 "enum": ["working", "session", "project", "long_term"],
58 "description": "Memory tier to store in (default: working)"
59 },
60 "metadata": {
61 "type": "object",
62 "description": "Optional: additional metadata"
63 }
64 },
65 "required": ["content"]
66 })
67 }
68
69 fn category(&self) -> ToolCategory {
70 ToolCategory::Memory
71 }
72
73 async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
74 let content = args["content"]
75 .as_str()
76 .ok_or_else(|| anyhow::anyhow!("Missing content parameter"))?;
77
78 let tier_str = args["tier"].as_str().unwrap_or("working");
79 let tier = match tier_str {
80 "working" => MemoryTier::Working,
81 "session" => MemoryTier::Session,
82 "project" => MemoryTier::Project,
83 "long_term" => MemoryTier::LongTerm,
84 _ => MemoryTier::Working,
85 };
86
87 let metadata = if let Some(obj) = args["metadata"].as_object() {
89 obj.clone()
90 } else {
91 serde_json::Map::new()
92 };
93
94 let entry = MemoryEntry {
96 id: generate_short_id(),
97 content: content.to_string(),
98 tier,
99 created_at: Utc::now(),
100 last_accessed: Utc::now(),
101 importance: 0.5,
102 metadata,
103 access_count: 0,
104 };
105
106 let id = self.store.store(entry).await?;
108
109 Ok(format!("Memory saved to {} tier with ID: {}", tier_str, id))
110 }
111}
112
113pub struct QueryMemoryTool {
115 store: Arc<WorkingMemory>,
116}
117
118impl QueryMemoryTool {
119 pub fn new() -> Self {
120 Self {
121 store: Arc::new(WorkingMemory::default()),
122 }
123 }
124
125 pub fn with_store(store: Arc<WorkingMemory>) -> Self {
127 Self { store }
128 }
129}
130
131impl Default for QueryMemoryTool {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137#[async_trait]
138impl BuiltinTool for QueryMemoryTool {
139 fn name(&self) -> &str {
140 "query_memory"
141 }
142
143 fn description(&self) -> &str {
144 "Query the memory system for relevant memories."
145 }
146
147 fn parameters_schema(&self) -> serde_json::Value {
148 serde_json::json!({
149 "type": "object",
150 "properties": {
151 "query": {
152 "type": "string",
153 "description": "The query text"
154 },
155 "tier": {
156 "type": "string",
157 "enum": ["working", "session", "project", "long_term"],
158 "description": "Optional: limit to specific tier"
159 },
160 "limit": {
161 "type": "integer",
162 "description": "Optional: maximum number of results (default: 10)"
163 }
164 },
165 "required": ["query"]
166 })
167 }
168
169 fn category(&self) -> ToolCategory {
170 ToolCategory::Memory
171 }
172
173 async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
174 let query_text = args["query"]
175 .as_str()
176 .ok_or_else(|| anyhow::anyhow!("Missing query parameter"))?;
177
178 let limit = args["limit"].as_u64().map(|l| l as usize);
179 let tier = args["tier"].as_str().and_then(|t| match t {
180 "working" => Some(MemoryTier::Working),
181 "session" => Some(MemoryTier::Session),
182 "project" => Some(MemoryTier::Project),
183 "long_term" => Some(MemoryTier::LongTerm),
184 _ => None,
185 });
186
187 let query = MemoryQuery {
188 query: query_text.to_string(),
189 tier,
190 limit,
191 time_range: None,
192 };
193
194 let results = self.store.query(&query).await?;
196
197 if results.is_empty() {
198 Ok("(no memories found)".to_string())
199 } else {
200 let output: Vec<String> = results
201 .iter()
202 .take(limit.unwrap_or(10))
203 .map(|e| {
204 let preview = if e.content.len() > 200 {
205 format!("{}...", &e.content[..200])
206 } else {
207 e.content.clone()
208 };
209 format!("{}: {}", e.id, preview)
210 })
211 .collect();
212 Ok(output.join("\n"))
213 }
214 }
215}
216
217pub struct ClearMemoryTool {
219 store: Arc<WorkingMemory>,
220}
221
222impl ClearMemoryTool {
223 pub fn new() -> Self {
224 Self {
225 store: Arc::new(WorkingMemory::default()),
226 }
227 }
228
229 pub fn with_store(store: Arc<WorkingMemory>) -> Self {
231 Self { store }
232 }
233}
234
235impl Default for ClearMemoryTool {
236 fn default() -> Self {
237 Self::new()
238 }
239}
240
241#[async_trait]
242impl BuiltinTool for ClearMemoryTool {
243 fn name(&self) -> &str {
244 "clear_memory"
245 }
246
247 fn description(&self) -> &str {
248 "Clear all memories from a specific tier."
249 }
250
251 fn parameters_schema(&self) -> serde_json::Value {
252 serde_json::json!({
253 "type": "object",
254 "properties": {
255 "tier": {
256 "type": "string",
257 "enum": ["working", "session", "project", "long_term"],
258 "description": "Memory tier to clear (default: working)"
259 }
260 },
261 "required": []
262 })
263 }
264
265 fn category(&self) -> ToolCategory {
266 ToolCategory::Memory
267 }
268
269 async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
270 let tier_str = args["tier"].as_str().unwrap_or("working");
271
272 let count = self.store.clear().await?;
274
275 Ok(format!("Cleared {} memories from {} tier", count, tier_str))
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The save tool reports the memory category.
    #[test]
    fn test_memory_tool_category() {
        let save_tool = SaveMemoryTool::new();
        assert_eq!(save_tool.category(), ToolCategory::Memory);
    }

    /// The query tool reports the memory category.
    #[test]
    fn test_query_memory_tool_category() {
        let query_tool = QueryMemoryTool::new();
        assert_eq!(query_tool.category(), ToolCategory::Memory);
    }

    /// Saving with only `content` succeeds and returns a confirmation.
    #[tokio::test]
    async fn test_save_memory() {
        let save_tool = SaveMemoryTool::new();
        let message = save_tool
            .execute(json!({"content": "test memory"}))
            .await
            .expect("saving into a fresh store should succeed");
        assert!(message.contains("Memory saved"));
    }

    /// Querying a fresh store succeeds and reports that nothing was found.
    #[tokio::test]
    async fn test_query_memory_empty() {
        let query_tool = QueryMemoryTool::new();
        let message = query_tool
            .execute(json!({"query": "nonexistent"}))
            .await
            .expect("querying a fresh store should succeed");
        assert!(message.contains("no memories"));
    }

    /// An entry saved through one tool is found by a query tool that shares
    /// the same store.
    #[tokio::test]
    async fn test_save_and_query_memory() {
        let shared = Arc::new(WorkingMemory::default());
        let save_tool = SaveMemoryTool::with_store(Arc::clone(&shared));
        let query_tool = QueryMemoryTool::with_store(shared);

        save_tool
            .execute(json!({"content": "important fact: the sky is blue"}))
            .await
            .unwrap();

        let found = query_tool.execute(json!({"query": "sky"})).await.unwrap();
        assert!(found.contains("sky is blue"));
    }
}