Skip to main content

zeph_core/
memory_tools.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::fmt::Write as _;
5use std::sync::Arc;
6
7use zeph_memory::embedding_store::SearchFilter;
8use zeph_memory::semantic::SemanticMemory;
9use zeph_memory::types::ConversationId;
10use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
11use zeph_tools::registry::{InvocationHint, ToolDef};
12
13use zeph_sanitizer::memory_validation::MemoryWriteValidator;
14
// Arguments of the `memory_search` tool, deserialized from the tool call's params.
// NOTE: the `///` comments on the fields are model-facing text, not just docs.
// schemars turns them into the JSON-schema `description`s produced by
// `schemars::schema_for!(MemorySearchParams)` in `tool_definitions`. Edit them as prompt text.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySearchParams {
    /// Natural language query to search memory for relevant past messages and facts.
    query: String,
    /// Maximum number of results to return (default: 5, max: 20).
    // Omitted values fall back to `default_limit()`. Out-of-range values are clamped
    // to 1..=20 by the executor rather than rejected.
    #[serde(default = "default_limit")]
    limit: u32,
}
23
/// Serde default for `MemorySearchParams::limit` when the caller omits it.
fn default_limit() -> u32 {
    const FALLBACK_LIMIT: u32 = 5;
    FALLBACK_LIMIT
}
27
// Arguments of the `memory_save` tool, deserialized from the tool call's params.
// NOTE: the `///` comments on the fields are model-facing text, not just docs.
// schemars turns them into the JSON-schema `description`s produced by
// `schemars::schema_for!(MemorySaveParams)` in `tool_definitions`. Edit them as prompt text.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySaveParams {
    /// The content to save to long-term memory. Should be a concise, self-contained fact or note.
    content: String,
    /// Role label for the saved message (default: "assistant").
    // Omitted values fall back to `default_role()`. Whether a given label is accepted
    // is decided by the executor, not by deserialization.
    #[serde(default = "default_role")]
    role: String,
}
36
/// Serde default for `MemorySaveParams::role` when the caller omits it.
fn default_role() -> String {
    String::from("assistant")
}
40
41pub struct MemoryToolExecutor {
42    memory: Arc<SemanticMemory>,
43    conversation_id: ConversationId,
44    validator: MemoryWriteValidator,
45}
46
47impl MemoryToolExecutor {
48    #[must_use]
49    pub fn new(memory: Arc<SemanticMemory>, conversation_id: ConversationId) -> Self {
50        Self {
51            memory,
52            conversation_id,
53            validator: MemoryWriteValidator::new(
54                zeph_sanitizer::memory_validation::MemoryWriteValidationConfig::default(),
55            ),
56        }
57    }
58
59    /// Create with a custom validator (used when security config is loaded).
60    #[must_use]
61    pub fn with_validator(
62        memory: Arc<SemanticMemory>,
63        conversation_id: ConversationId,
64        validator: MemoryWriteValidator,
65    ) -> Self {
66        Self {
67            memory,
68            conversation_id,
69            validator,
70        }
71    }
72}
73
74impl ToolExecutor for MemoryToolExecutor {
75    fn tool_definitions(&self) -> Vec<ToolDef> {
76        vec![
77            ToolDef {
78                id: "memory_search".into(),
79                description: "Search long-term memory for relevant past messages, facts, and session summaries. Use when the user references past conversations or you need historical context.\n\nParameters: query (string, required) - natural language search query; limit (integer, optional) - max results 1-20 (default: 5)\nReturns: ranked list of memory entries with similarity scores and timestamps\nErrors: Execution on database failure\nExample: {\"query\": \"user preference for output format\", \"limit\": 5}".into(),
80                schema: schemars::schema_for!(MemorySearchParams),
81                invocation: InvocationHint::ToolCall,
82            },
83            ToolDef {
84                id: "memory_save".into(),
85                description: "Save a fact or note to long-term memory for cross-session recall. Use sparingly for key decisions, user preferences, or critical context worth remembering across sessions.\n\nParameters: content (string, required) - concise, self-contained fact or note; role (string, optional) - message role label (default: \"assistant\")\nReturns: confirmation with saved entry ID\nErrors: Execution on database failure; InvalidParams if content is empty\nExample: {\"content\": \"User prefers JSON output over YAML\", \"role\": \"assistant\"}".into(),
86                schema: schemars::schema_for!(MemorySaveParams),
87                invocation: InvocationHint::ToolCall,
88            },
89        ]
90    }
91
92    async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
93        Ok(None)
94    }
95
96    #[allow(clippy::too_many_lines)] // two tools with validation, search, and multi-source aggregation
97    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
98        match call.tool_id.as_str() {
99            "memory_search" => {
100                let params: MemorySearchParams = deserialize_params(&call.params)?;
101                let limit = params.limit.clamp(1, 20) as usize;
102
103                let filter = Some(SearchFilter {
104                    conversation_id: Some(self.conversation_id),
105                    role: None,
106                });
107
108                let recalled = self
109                    .memory
110                    .recall(&params.query, limit, filter)
111                    .await
112                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
113
114                let key_facts = self
115                    .memory
116                    .search_key_facts(&params.query, limit)
117                    .await
118                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
119
120                let summaries = self
121                    .memory
122                    .search_session_summaries(&params.query, limit, Some(self.conversation_id))
123                    .await
124                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
125
126                let mut output = String::new();
127
128                let _ = writeln!(output, "## Recalled Messages ({} results)", recalled.len());
129                for r in &recalled {
130                    let role = match r.message.role {
131                        zeph_llm::provider::Role::User => "user",
132                        zeph_llm::provider::Role::Assistant => "assistant",
133                        zeph_llm::provider::Role::System => "system",
134                    };
135                    let content = r.message.content.trim();
136                    let _ = writeln!(output, "[score: {:.2}] {role}: {content}", r.score);
137                }
138
139                let _ = writeln!(output);
140                let _ = writeln!(output, "## Key Facts ({} results)", key_facts.len());
141                for fact in &key_facts {
142                    let _ = writeln!(output, "- {fact}");
143                }
144
145                let _ = writeln!(output);
146                let _ = writeln!(output, "## Session Summaries ({} results)", summaries.len());
147                for s in &summaries {
148                    let _ = writeln!(
149                        output,
150                        "[conv #{}, score: {:.2}] {}",
151                        s.conversation_id, s.score, s.summary_text
152                    );
153                }
154
155                Ok(Some(ToolOutput {
156                    tool_name: "memory_search".to_owned(),
157                    summary: output,
158                    blocks_executed: 1,
159                    filter_stats: None,
160                    diff: None,
161                    streamed: false,
162                    terminal_id: None,
163                    locations: None,
164                    raw_response: None,
165                }))
166            }
167            "memory_save" => {
168                let params: MemorySaveParams = deserialize_params(&call.params)?;
169
170                if params.content.is_empty() {
171                    return Err(ToolError::InvalidParams {
172                        message: "content must not be empty".to_owned(),
173                    });
174                }
175                if params.content.len() > 4096 {
176                    return Err(ToolError::InvalidParams {
177                        message: "content exceeds maximum length of 4096 characters".to_owned(),
178                    });
179                }
180
181                // Schema validation: check content before writing to memory.
182                if let Err(e) = self.validator.validate_memory_save(&params.content) {
183                    return Err(ToolError::InvalidParams {
184                        message: format!("memory write rejected: {e}"),
185                    });
186                }
187
188                let role = params.role.as_str();
189
190                let message_id = self
191                    .memory
192                    .remember(self.conversation_id, role, &params.content)
193                    .await
194                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
195
196                Ok(Some(ToolOutput {
197                    tool_name: "memory_save".to_owned(),
198                    summary: format!(
199                        "Saved to memory (message_id: {message_id}, conversation: {}). Content will be available for future recall.",
200                        self.conversation_id
201                    ),
202                    blocks_executed: 1,
203                    filter_stats: None,
204                    diff: None,
205                    streamed: false,
206                    terminal_id: None,
207                    locations: None,
208                    raw_response: None,
209                }))
210            }
211            _ => Ok(None),
212        }
213    }
214}
215
216#[cfg(test)]
217mod tests {
218    use super::*;
219    use zeph_llm::any::AnyProvider;
220    use zeph_llm::mock::MockProvider;
221    use zeph_memory::semantic::SemanticMemory;
222
223    async fn make_memory() -> SemanticMemory {
224        SemanticMemory::with_sqlite_backend(
225            ":memory:",
226            AnyProvider::Mock(MockProvider::default()),
227            "test-model",
228            0.7,
229            0.3,
230        )
231        .await
232        .unwrap()
233    }
234
235    fn make_executor(memory: SemanticMemory) -> MemoryToolExecutor {
236        MemoryToolExecutor::new(Arc::new(memory), ConversationId(1))
237    }
238
239    #[tokio::test]
240    async fn tool_definitions_returns_two_tools() {
241        let memory = make_memory().await;
242        let executor = make_executor(memory);
243        let defs = executor.tool_definitions();
244        assert_eq!(defs.len(), 2);
245        assert_eq!(defs[0].id.as_ref(), "memory_search");
246        assert_eq!(defs[1].id.as_ref(), "memory_save");
247    }
248
249    #[tokio::test]
250    async fn execute_always_returns_none() {
251        let memory = make_memory().await;
252        let executor = make_executor(memory);
253        let result = executor.execute("any response").await.unwrap();
254        assert!(result.is_none());
255    }
256
257    #[tokio::test]
258    async fn execute_tool_call_unknown_returns_none() {
259        let memory = make_memory().await;
260        let executor = make_executor(memory);
261        let call = ToolCall {
262            tool_id: "unknown_tool".to_owned(),
263            params: serde_json::Map::new(),
264        };
265        let result = executor.execute_tool_call(&call).await.unwrap();
266        assert!(result.is_none());
267    }
268
269    #[tokio::test]
270    async fn memory_search_returns_output() {
271        let memory = make_memory().await;
272        let executor = make_executor(memory);
273        let mut params = serde_json::Map::new();
274        params.insert(
275            "query".into(),
276            serde_json::Value::String("test query".into()),
277        );
278        let call = ToolCall {
279            tool_id: "memory_search".to_owned(),
280            params,
281        };
282        let result = executor.execute_tool_call(&call).await.unwrap();
283        assert!(result.is_some());
284        let output = result.unwrap();
285        assert_eq!(output.tool_name, "memory_search");
286        assert!(output.summary.contains("Recalled Messages"));
287        assert!(output.summary.contains("Key Facts"));
288        assert!(output.summary.contains("Session Summaries"));
289    }
290
291    #[tokio::test]
292    async fn memory_save_stores_and_returns_confirmation() {
293        let memory = make_memory().await;
294        let sqlite = memory.sqlite().clone();
295        // Create conversation first
296        let cid = sqlite.create_conversation().await.unwrap();
297        let executor = MemoryToolExecutor::new(Arc::new(memory), cid);
298
299        let mut params = serde_json::Map::new();
300        params.insert(
301            "content".into(),
302            serde_json::Value::String("User prefers dark mode".into()),
303        );
304        let call = ToolCall {
305            tool_id: "memory_save".to_owned(),
306            params,
307        };
308        let result = executor.execute_tool_call(&call).await.unwrap();
309        assert!(result.is_some());
310        let output = result.unwrap();
311        assert!(output.summary.contains("Saved to memory"));
312        assert!(output.summary.contains("message_id:"));
313    }
314
315    #[tokio::test]
316    async fn memory_save_empty_content_returns_error() {
317        let memory = make_memory().await;
318        let executor = make_executor(memory);
319        let mut params = serde_json::Map::new();
320        params.insert("content".into(), serde_json::Value::String(String::new()));
321        let call = ToolCall {
322            tool_id: "memory_save".to_owned(),
323            params,
324        };
325        let result = executor.execute_tool_call(&call).await;
326        assert!(result.is_err());
327    }
328
329    #[tokio::test]
330    async fn memory_save_oversized_content_returns_error() {
331        let memory = make_memory().await;
332        let executor = make_executor(memory);
333        let mut params = serde_json::Map::new();
334        params.insert(
335            "content".into(),
336            serde_json::Value::String("x".repeat(4097)),
337        );
338        let call = ToolCall {
339            tool_id: "memory_save".to_owned(),
340            params,
341        };
342        let result = executor.execute_tool_call(&call).await;
343        assert!(result.is_err());
344    }
345}