Skip to main content

zeph_core/
memory_tools.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::fmt::Write as _;
5use std::sync::Arc;
6
7use zeph_memory::embedding_store::SearchFilter;
8use zeph_memory::semantic::SemanticMemory;
9use zeph_memory::types::ConversationId;
10use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
11use zeph_tools::registry::{InvocationHint, ToolDef};
12
13use zeph_sanitizer::memory_validation::MemoryWriteValidator;
14
// Arguments accepted by the `memory_search` tool.
//
// NOTE: the `///` field docs below are not just rustdoc — the `schemars::JsonSchema` derive
// turns them into the `description` entries of the tool schema that is sent to the model
// (see `schema_for!` in `tool_definitions`). Edit them as user-facing prompt text.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySearchParams {
    /// Natural language query to search memory for relevant past messages and facts.
    query: String,
    // Omitted values fall back to `default_limit()`. Out-of-range values are not rejected;
    // the executor clamps them to 1..=20 before searching.
    /// Maximum number of results to return (default: 5, max: 20).
    #[serde(default = "default_limit")]
    limit: u32,
}
23
/// Serde fallback for the `limit` field of `MemorySearchParams` when the caller omits it.
fn default_limit() -> u32 {
    const DEFAULT_SEARCH_LIMIT: u32 = 5;
    DEFAULT_SEARCH_LIMIT
}
27
// Arguments accepted by the `memory_save` tool.
//
// NOTE: as with `MemorySearchParams`, the `///` field docs are emitted by the
// `schemars::JsonSchema` derive into the tool schema shown to the model.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySaveParams {
    /// The content to save to long-term memory. Should be a concise, self-contained fact or note.
    content: String,
    // Passed through unchanged to `SemanticMemory::remember`. NOTE(review): this value is
    // chosen by the model and is not restricted here (e.g. "system" is accepted) — confirm
    // that is intended.
    /// Role label for the saved message (default: "assistant").
    #[serde(default = "default_role")]
    role: String,
}
36
/// Serde fallback for the `role` field of `MemorySaveParams`: saves are attributed to the
/// assistant unless the caller says otherwise.
fn default_role() -> String {
    String::from("assistant")
}
40
41/// Executes `memory_search` and `memory_save` tool calls on behalf of the agent.
42pub struct MemoryToolExecutor {
43    memory: Arc<SemanticMemory>,
44    conversation_id: ConversationId,
45    validator: MemoryWriteValidator,
46    /// When `true` the backing store is in-memory (bare mode) and saves do not persist across sessions.
47    ephemeral: bool,
48}
49
50impl MemoryToolExecutor {
51    /// Create with default validator and persistent (non-ephemeral) semantics.
52    #[must_use]
53    pub fn new(memory: Arc<SemanticMemory>, conversation_id: ConversationId) -> Self {
54        Self {
55            memory,
56            conversation_id,
57            validator: MemoryWriteValidator::new(
58                zeph_sanitizer::memory_validation::MemoryWriteValidationConfig::default(),
59            ),
60            ephemeral: false,
61        }
62    }
63
64    /// Create with a custom validator (used when security config is loaded).
65    #[must_use]
66    pub fn with_validator(
67        memory: Arc<SemanticMemory>,
68        conversation_id: ConversationId,
69        validator: MemoryWriteValidator,
70    ) -> Self {
71        Self {
72            memory,
73            conversation_id,
74            validator,
75            ephemeral: false,
76        }
77    }
78
79    /// Mark this executor as ephemeral (bare mode).
80    ///
81    /// When set, `memory_save` reports that the content is session-only and will not be
82    /// available after the session ends.
83    #[must_use]
84    pub fn ephemeral(mut self) -> Self {
85        self.ephemeral = true;
86        self
87    }
88}
89
90impl ToolExecutor for MemoryToolExecutor {
91    fn tool_definitions(&self) -> Vec<ToolDef> {
92        vec![
93            ToolDef {
94                id: "memory_search".into(),
95                description: "Search long-term memory for relevant past messages, facts, and session summaries. Use to recall facts, preferences, or information the user provided during this or previous conversations.\n\nParameters: query (string, required) - natural language search query; limit (integer, optional) - max results 1-20 (default: 5)\nReturns: ranked list of memory entries with similarity scores and timestamps\nErrors: Execution on database failure\nExample: {\"query\": \"user preference for output format\", \"limit\": 5}".into(),
96                schema: schemars::schema_for!(MemorySearchParams),
97                invocation: InvocationHint::ToolCall,
98                output_schema: None,
99            },
100            ToolDef {
101                id: "memory_save".into(),
102                description: "Save a fact or note to long-term memory for cross-session recall. Use sparingly for key decisions, user preferences, or critical context worth remembering across sessions.\n\nParameters: content (string, required) - concise, self-contained fact or note; role (string, optional) - message role label (default: \"assistant\")\nReturns: confirmation with saved entry ID\nErrors: Execution on database failure; InvalidParams if content is empty\nExample: {\"content\": \"User prefers JSON output over YAML\", \"role\": \"assistant\"}".into(),
103                schema: schemars::schema_for!(MemorySaveParams),
104                invocation: InvocationHint::ToolCall,
105                output_schema: None,
106            },
107        ]
108    }
109
110    async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
111        Ok(None)
112    }
113
114    #[allow(clippy::too_many_lines)] // two tools with validation, search, and multi-source aggregation
115    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
116        match call.tool_id.as_str() {
117            "memory_search" => {
118                let params: MemorySearchParams = deserialize_params(&call.params)?;
119                let limit = params.limit.clamp(1, 20) as usize;
120
121                let filter = Some(SearchFilter {
122                    conversation_id: Some(self.conversation_id),
123                    role: None,
124                    category: None,
125                });
126
127                let recalled = self
128                    .memory
129                    .recall(&params.query, limit, filter)
130                    .await
131                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
132
133                let key_facts = self
134                    .memory
135                    .search_key_facts(&params.query, limit)
136                    .await
137                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
138
139                let summaries = self
140                    .memory
141                    .search_session_summaries(&params.query, limit, Some(self.conversation_id))
142                    .await
143                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
144
145                let mut output = String::new();
146
147                let _ = writeln!(output, "## Recalled Messages ({} results)", recalled.len());
148                for r in &recalled {
149                    let role = match r.message.role {
150                        zeph_llm::provider::Role::Assistant => "assistant",
151                        zeph_llm::provider::Role::System => "system",
152                        zeph_llm::provider::Role::User | _ => "user",
153                    };
154                    let content = r.message.content.trim();
155                    let _ = writeln!(output, "[score: {:.2}] {role}: {content}", r.score);
156                }
157
158                let _ = writeln!(output);
159                let _ = writeln!(output, "## Key Facts ({} results)", key_facts.len());
160                for fact in &key_facts {
161                    let _ = writeln!(output, "- {fact}");
162                }
163
164                let _ = writeln!(output);
165                let _ = writeln!(output, "## Session Summaries ({} results)", summaries.len());
166                for s in &summaries {
167                    let _ = writeln!(
168                        output,
169                        "[conv #{}, score: {:.2}] {}",
170                        s.conversation_id, s.score, s.summary_text
171                    );
172                }
173
174                Ok(Some(ToolOutput {
175                    tool_name: zeph_common::ToolName::new("memory_search"),
176                    summary: output,
177                    blocks_executed: 1,
178                    filter_stats: None,
179                    diff: None,
180                    streamed: false,
181                    terminal_id: None,
182                    locations: None,
183                    raw_response: None,
184                    claim_source: Some(zeph_tools::ClaimSource::Memory),
185                }))
186            }
187            "memory_save" => {
188                let params: MemorySaveParams = deserialize_params(&call.params)?;
189
190                if params.content.is_empty() {
191                    return Err(ToolError::InvalidParams {
192                        message: "content must not be empty".to_owned(),
193                    });
194                }
195                if params.content.len() > 4096 {
196                    return Err(ToolError::InvalidParams {
197                        message: "content exceeds maximum length of 4096 characters".to_owned(),
198                    });
199                }
200
201                // Schema validation: check content before writing to memory.
202                if let Err(e) = self.validator.validate_memory_save(&params.content) {
203                    return Err(ToolError::InvalidParams {
204                        message: format!("memory write rejected: {e}"),
205                    });
206                }
207
208                let role = params.role.as_str();
209
210                // Explicit user-directed saves bypass goal-conditioned scoring (goal_text = None).
211                let message_id_opt = self
212                    .memory
213                    .remember(self.conversation_id, role, &params.content, None)
214                    .await
215                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
216
217                let summary = match message_id_opt {
218                    Some(message_id) => {
219                        if self.ephemeral {
220                            format!(
221                                "Saved to session memory (message_id: {message_id}, conversation: {}). Ephemeral — not available after session ends.",
222                                self.conversation_id
223                            )
224                        } else {
225                            format!(
226                                "Saved to memory (message_id: {message_id}, conversation: {}). Content will be available for future recall.",
227                                self.conversation_id
228                            )
229                        }
230                    }
231                    None => "Memory admission rejected: message did not meet quality threshold."
232                        .to_owned(),
233                };
234
235                Ok(Some(ToolOutput {
236                    tool_name: zeph_common::ToolName::new("memory_save"),
237                    summary,
238                    blocks_executed: 1,
239                    filter_stats: None,
240                    diff: None,
241                    streamed: false,
242                    terminal_id: None,
243                    locations: None,
244                    raw_response: None,
245                    claim_source: Some(zeph_tools::ClaimSource::Memory),
246                }))
247            }
248            _ => Ok(None),
249        }
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;
    use zeph_memory::semantic::SemanticMemory;

    /// Parameter map type carried by `ToolCall::params`.
    type Params = serde_json::Map<String, serde_json::Value>;

    /// Fresh `SemanticMemory` on an in-memory SQLite database, driven by a mock provider.
    async fn make_memory() -> SemanticMemory {
        let provider = AnyProvider::Mock(MockProvider::default());
        SemanticMemory::with_sqlite_backend(":memory:", provider, "test-model", 0.7, 0.3)
            .await
            .unwrap()
    }

    /// Executor bound to conversation 1 with the default validator.
    fn make_executor(memory: SemanticMemory) -> MemoryToolExecutor {
        let shared = Arc::new(memory);
        MemoryToolExecutor::new(shared, ConversationId(1))
    }

    /// Parameter map holding exactly one string-valued entry.
    fn single_param(key: &str, value: impl Into<String>) -> Params {
        let mut params = Params::new();
        params.insert(key.to_owned(), serde_json::Value::String(value.into()));
        params
    }

    /// `ToolCall` for `tool` carrying `params`, with every optional field left empty.
    fn tool_call(tool: &'static str, params: Params) -> ToolCall {
        ToolCall {
            tool_id: zeph_common::ToolName::new(tool),
            params,
            caller_id: None,
            context: None,
            tool_call_id: String::new(),
            skill_name: None,
        }
    }

    #[tokio::test]
    async fn tool_definitions_returns_two_tools() {
        let defs = make_executor(make_memory().await).tool_definitions();
        assert_eq!(defs.len(), 2);
        assert_eq!(defs[0].id.as_ref(), "memory_search");
        assert_eq!(defs[1].id.as_ref(), "memory_save");
    }

    #[tokio::test]
    async fn execute_always_returns_none() {
        let executor = make_executor(make_memory().await);
        let result = executor.execute("any response").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn execute_tool_call_unknown_returns_none() {
        let executor = make_executor(make_memory().await);
        let call = tool_call("unknown_tool", Params::new());
        assert!(executor.execute_tool_call(&call).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn memory_search_returns_output() {
        let executor = make_executor(make_memory().await);
        let call = tool_call("memory_search", single_param("query", "test query"));
        let output = executor
            .execute_tool_call(&call)
            .await
            .unwrap()
            .expect("memory_search must produce output");
        assert_eq!(output.tool_name, "memory_search");
        for section in ["Recalled Messages", "Key Facts", "Session Summaries"] {
            assert!(
                output.summary.contains(section),
                "missing section {section}"
            );
        }
    }

    #[tokio::test]
    async fn memory_save_stores_and_returns_confirmation() {
        let memory = make_memory().await;
        // The conversation row must exist before a message can be stored in it.
        let cid = memory.sqlite().clone().create_conversation().await.unwrap();
        let executor = MemoryToolExecutor::new(Arc::new(memory), cid);

        let call = tool_call(
            "memory_save",
            single_param("content", "User prefers dark mode"),
        );
        let output = executor
            .execute_tool_call(&call)
            .await
            .unwrap()
            .expect("memory_save must produce output");
        assert!(output.summary.contains("Saved to memory"));
        assert!(output.summary.contains("message_id:"));
    }

    #[tokio::test]
    async fn memory_save_empty_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let call = tool_call("memory_save", single_param("content", String::new()));
        assert!(executor.execute_tool_call(&call).await.is_err());
    }

    #[tokio::test]
    async fn memory_save_oversized_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let call = tool_call("memory_save", single_param("content", "x".repeat(4097)));
        assert!(executor.execute_tool_call(&call).await.is_err());
    }

    #[tokio::test]
    async fn memory_save_ephemeral_returns_session_only_message() {
        let memory = make_memory().await;
        let cid = memory.sqlite().clone().create_conversation().await.unwrap();
        let executor = MemoryToolExecutor::new(Arc::new(memory), cid).ephemeral();

        let call = tool_call("memory_save", single_param("content", "temp fact"));
        let output = executor.execute_tool_call(&call).await.unwrap().unwrap();
        assert!(
            output.summary.contains("Ephemeral"),
            "bare-mode save must mention ephemeral semantics; got: {}",
            output.summary
        );
        assert!(
            !output.summary.contains("available for future recall"),
            "bare-mode save must not claim cross-session persistence; got: {}",
            output.summary
        );
    }

    /// `memory_search` description must mention user-provided facts so the model
    /// prefers it over `search_code` for recalling information from conversation (#2475).
    #[tokio::test]
    async fn memory_search_description_mentions_user_provided_facts() {
        let defs = make_executor(make_memory().await).tool_definitions();
        let memory_search = defs
            .iter()
            .find(|d| d.id.as_ref() == "memory_search")
            .expect("memory_search must be defined");
        assert!(
            memory_search
                .description
                .contains("user provided during this or previous conversations"),
            "memory_search description must contain disambiguation phrase; got: {}",
            memory_search.description
        );
    }
}