Skip to main content

zeph_core/
memory_tools.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::fmt::Write as _;
5use std::sync::Arc;
6
7use zeph_memory::embedding_store::SearchFilter;
8use zeph_memory::semantic::SemanticMemory;
9use zeph_memory::types::ConversationId;
10use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
11use zeph_tools::registry::{InvocationHint, ToolDef};
12
13use zeph_sanitizer::memory_validation::MemoryWriteValidator;
14
// Arguments accepted by the `memory_search` tool call, deserialized from the call's
// JSON parameters.
//
// NOTE(review): the `JsonSchema` derive is expected to surface the `///` field comments
// below as schema descriptions, so their wording is left untouched and this header is a
// plain `//` comment — confirm before rewording either.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySearchParams {
    /// Natural language query to search memory for relevant past messages and facts.
    query: String,
    /// Maximum number of results to return (default: 5, max: 20).
    // Falls back to `default_limit()` when omitted. The bounds are not enforced during
    // deserialization; the `memory_search` handler clamps the value to 1..=20.
    #[serde(default = "default_limit")]
    limit: u32,
}
23
/// Fallback for `MemorySearchParams::limit` when the caller leaves the field out.
fn default_limit() -> u32 {
    const DEFAULT_SEARCH_LIMIT: u32 = 5;
    DEFAULT_SEARCH_LIMIT
}
27
// Arguments accepted by the `memory_save` tool call, deserialized from the call's
// JSON parameters.
//
// NOTE(review): the `JsonSchema` derive is expected to surface the `///` field comments
// below as schema descriptions, so their wording is left untouched and this header is a
// plain `//` comment — confirm before rewording either.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySaveParams {
    /// The content to save to long-term memory. Should be a concise, self-contained fact or note.
    // Emptiness and maximum length are checked by the `memory_save` handler, not here.
    content: String,
    /// Role label for the saved message (default: "assistant").
    // Falls back to `default_role()` when omitted; the handler forwards the string to
    // `SemanticMemory::remember` without further checks.
    #[serde(default = "default_role")]
    role: String,
}
36
/// Fallback for `MemorySaveParams::role` when the caller leaves the field out.
fn default_role() -> String {
    String::from("assistant")
}
40
/// Tool executor that exposes the `memory_search` and `memory_save` tools on top of a
/// shared [`SemanticMemory`].
pub struct MemoryToolExecutor {
    /// Shared memory backend that searches read from and saves write to.
    memory: Arc<SemanticMemory>,
    /// Conversation that saves are recorded under and that message recall and
    /// session-summary searches are given as their conversation argument.
    conversation_id: ConversationId,
    /// Checks `memory_save` content before anything is written to `memory`.
    validator: MemoryWriteValidator,
}
46
47impl MemoryToolExecutor {
48    #[must_use]
49    pub fn new(memory: Arc<SemanticMemory>, conversation_id: ConversationId) -> Self {
50        Self {
51            memory,
52            conversation_id,
53            validator: MemoryWriteValidator::new(
54                zeph_sanitizer::memory_validation::MemoryWriteValidationConfig::default(),
55            ),
56        }
57    }
58
59    /// Create with a custom validator (used when security config is loaded).
60    #[must_use]
61    pub fn with_validator(
62        memory: Arc<SemanticMemory>,
63        conversation_id: ConversationId,
64        validator: MemoryWriteValidator,
65    ) -> Self {
66        Self {
67            memory,
68            conversation_id,
69            validator,
70        }
71    }
72}
73
74impl ToolExecutor for MemoryToolExecutor {
75    fn tool_definitions(&self) -> Vec<ToolDef> {
76        vec![
77            ToolDef {
78                id: "memory_search".into(),
79                description: "Search long-term memory for relevant past messages, facts, and session summaries. Use to recall facts, preferences, or information the user provided during this or previous conversations.\n\nParameters: query (string, required) - natural language search query; limit (integer, optional) - max results 1-20 (default: 5)\nReturns: ranked list of memory entries with similarity scores and timestamps\nErrors: Execution on database failure\nExample: {\"query\": \"user preference for output format\", \"limit\": 5}".into(),
80                schema: schemars::schema_for!(MemorySearchParams),
81                invocation: InvocationHint::ToolCall,
82                output_schema: None,
83            },
84            ToolDef {
85                id: "memory_save".into(),
86                description: "Save a fact or note to long-term memory for cross-session recall. Use sparingly for key decisions, user preferences, or critical context worth remembering across sessions.\n\nParameters: content (string, required) - concise, self-contained fact or note; role (string, optional) - message role label (default: \"assistant\")\nReturns: confirmation with saved entry ID\nErrors: Execution on database failure; InvalidParams if content is empty\nExample: {\"content\": \"User prefers JSON output over YAML\", \"role\": \"assistant\"}".into(),
87                schema: schemars::schema_for!(MemorySaveParams),
88                invocation: InvocationHint::ToolCall,
89                output_schema: None,
90            },
91        ]
92    }
93
94    async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
95        Ok(None)
96    }
97
98    #[allow(clippy::too_many_lines)] // two tools with validation, search, and multi-source aggregation
99    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
100        match call.tool_id.as_str() {
101            "memory_search" => {
102                let params: MemorySearchParams = deserialize_params(&call.params)?;
103                let limit = params.limit.clamp(1, 20) as usize;
104
105                let filter = Some(SearchFilter {
106                    conversation_id: Some(self.conversation_id),
107                    role: None,
108                    category: None,
109                });
110
111                let recalled = self
112                    .memory
113                    .recall(&params.query, limit, filter)
114                    .await
115                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
116
117                let key_facts = self
118                    .memory
119                    .search_key_facts(&params.query, limit)
120                    .await
121                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
122
123                let summaries = self
124                    .memory
125                    .search_session_summaries(&params.query, limit, Some(self.conversation_id))
126                    .await
127                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
128
129                let mut output = String::new();
130
131                let _ = writeln!(output, "## Recalled Messages ({} results)", recalled.len());
132                for r in &recalled {
133                    let role = match r.message.role {
134                        zeph_llm::provider::Role::User => "user",
135                        zeph_llm::provider::Role::Assistant => "assistant",
136                        zeph_llm::provider::Role::System => "system",
137                    };
138                    let content = r.message.content.trim();
139                    let _ = writeln!(output, "[score: {:.2}] {role}: {content}", r.score);
140                }
141
142                let _ = writeln!(output);
143                let _ = writeln!(output, "## Key Facts ({} results)", key_facts.len());
144                for fact in &key_facts {
145                    let _ = writeln!(output, "- {fact}");
146                }
147
148                let _ = writeln!(output);
149                let _ = writeln!(output, "## Session Summaries ({} results)", summaries.len());
150                for s in &summaries {
151                    let _ = writeln!(
152                        output,
153                        "[conv #{}, score: {:.2}] {}",
154                        s.conversation_id, s.score, s.summary_text
155                    );
156                }
157
158                Ok(Some(ToolOutput {
159                    tool_name: zeph_common::ToolName::new("memory_search"),
160                    summary: output,
161                    blocks_executed: 1,
162                    filter_stats: None,
163                    diff: None,
164                    streamed: false,
165                    terminal_id: None,
166                    locations: None,
167                    raw_response: None,
168                    claim_source: Some(zeph_tools::ClaimSource::Memory),
169                }))
170            }
171            "memory_save" => {
172                let params: MemorySaveParams = deserialize_params(&call.params)?;
173
174                if params.content.is_empty() {
175                    return Err(ToolError::InvalidParams {
176                        message: "content must not be empty".to_owned(),
177                    });
178                }
179                if params.content.len() > 4096 {
180                    return Err(ToolError::InvalidParams {
181                        message: "content exceeds maximum length of 4096 characters".to_owned(),
182                    });
183                }
184
185                // Schema validation: check content before writing to memory.
186                if let Err(e) = self.validator.validate_memory_save(&params.content) {
187                    return Err(ToolError::InvalidParams {
188                        message: format!("memory write rejected: {e}"),
189                    });
190                }
191
192                let role = params.role.as_str();
193
194                // Explicit user-directed saves bypass goal-conditioned scoring (goal_text = None).
195                let message_id_opt = self
196                    .memory
197                    .remember(self.conversation_id, role, &params.content, None)
198                    .await
199                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
200
201                let summary = match message_id_opt {
202                    Some(message_id) => format!(
203                        "Saved to memory (message_id: {message_id}, conversation: {}). Content will be available for future recall.",
204                        self.conversation_id
205                    ),
206                    None => "Memory admission rejected: message did not meet quality threshold."
207                        .to_owned(),
208                };
209
210                Ok(Some(ToolOutput {
211                    tool_name: zeph_common::ToolName::new("memory_save"),
212                    summary,
213                    blocks_executed: 1,
214                    filter_stats: None,
215                    diff: None,
216                    streamed: false,
217                    terminal_id: None,
218                    locations: None,
219                    raw_response: None,
220                    claim_source: Some(zeph_tools::ClaimSource::Memory),
221                }))
222            }
223            _ => Ok(None),
224        }
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;
    use zeph_memory::semantic::SemanticMemory;

    /// Builds a `SemanticMemory` on an in-memory SQLite database with a mock provider.
    async fn make_memory() -> SemanticMemory {
        SemanticMemory::with_sqlite_backend(
            ":memory:",
            AnyProvider::Mock(MockProvider::default()),
            "test-model",
            0.7,
            0.3,
        )
        .await
        .unwrap()
    }

    /// Wraps `memory` in an executor with the default validator, bound to conversation 1.
    fn make_executor(memory: SemanticMemory) -> MemoryToolExecutor {
        MemoryToolExecutor::new(Arc::new(memory), ConversationId(1))
    }

    /// Builds a `ToolCall` for `tool` carrying `params`; the remaining fields are left
    /// unset or empty, exactly as every test here needs them.
    fn make_call(
        tool: &'static str,
        params: serde_json::Map<String, serde_json::Value>,
    ) -> ToolCall {
        ToolCall {
            tool_id: zeph_common::ToolName::new(tool),
            params,
            caller_id: None,
            context: None,
            tool_call_id: String::new(),
        }
    }

    /// Builds a parameter map holding a single string entry.
    fn string_param(
        key: &str,
        value: impl Into<String>,
    ) -> serde_json::Map<String, serde_json::Value> {
        let mut params = serde_json::Map::new();
        params.insert(key.to_owned(), serde_json::Value::String(value.into()));
        params
    }

    #[tokio::test]
    async fn tool_definitions_returns_two_tools() {
        let executor = make_executor(make_memory().await);
        let defs = executor.tool_definitions();
        assert_eq!(defs.len(), 2);
        assert_eq!(defs[0].id.as_ref(), "memory_search");
        assert_eq!(defs[1].id.as_ref(), "memory_save");
    }

    #[tokio::test]
    async fn execute_always_returns_none() {
        let executor = make_executor(make_memory().await);
        let result = executor.execute("any response").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn execute_tool_call_unknown_returns_none() {
        let executor = make_executor(make_memory().await);
        let call = make_call("unknown_tool", serde_json::Map::new());
        let result = executor.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn memory_search_returns_output() {
        let executor = make_executor(make_memory().await);
        let call = make_call("memory_search", string_param("query", "test query"));
        let result = executor.execute_tool_call(&call).await.unwrap();
        assert!(result.is_some());
        let output = result.unwrap();
        assert_eq!(output.tool_name, "memory_search");
        assert!(output.summary.contains("Recalled Messages"));
        assert!(output.summary.contains("Key Facts"));
        assert!(output.summary.contains("Session Summaries"));
    }

    #[tokio::test]
    async fn memory_save_stores_and_returns_confirmation() {
        let memory = make_memory().await;
        let sqlite = memory.sqlite().clone();
        // Create conversation first
        let cid = sqlite.create_conversation().await.unwrap();
        let executor = MemoryToolExecutor::new(Arc::new(memory), cid);

        let call = make_call(
            "memory_save",
            string_param("content", "User prefers dark mode"),
        );
        let result = executor.execute_tool_call(&call).await.unwrap();
        assert!(result.is_some());
        let output = result.unwrap();
        assert!(output.summary.contains("Saved to memory"));
        assert!(output.summary.contains("message_id:"));
    }

    #[tokio::test]
    async fn memory_save_empty_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let call = make_call("memory_save", string_param("content", String::new()));
        // Pin the variant: a bare `is_err()` would also pass on an unrelated failure.
        let err = executor.execute_tool_call(&call).await.err();
        assert!(
            matches!(err, Some(ToolError::InvalidParams { .. })),
            "expected InvalidParams for empty content; got: {err:?}"
        );
    }

    #[tokio::test]
    async fn memory_save_oversized_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let call = make_call("memory_save", string_param("content", "x".repeat(4097)));
        // Pin the variant: a bare `is_err()` would also pass on an unrelated failure.
        let err = executor.execute_tool_call(&call).await.err();
        assert!(
            matches!(err, Some(ToolError::InvalidParams { .. })),
            "expected InvalidParams for oversized content; got: {err:?}"
        );
    }

    /// `memory_search` description must mention user-provided facts so the model
    /// prefers it over `search_code` for recalling information from conversation (#2475).
    #[tokio::test]
    async fn memory_search_description_mentions_user_provided_facts() {
        let executor = make_executor(make_memory().await);
        let defs = executor.tool_definitions();
        let memory_search = defs
            .iter()
            .find(|d| d.id.as_ref() == "memory_search")
            .unwrap();
        assert!(
            memory_search
                .description
                .contains("user provided during this or previous conversations"),
            "memory_search description must contain disambiguation phrase; got: {}",
            memory_search.description
        );
    }
}