Skip to main content

zeph_core/
memory_tools.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::fmt::Write as _;
5use std::sync::Arc;
6
7use zeph_memory::embedding_store::SearchFilter;
8use zeph_memory::semantic::SemanticMemory;
9use zeph_memory::types::ConversationId;
10use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
11use zeph_tools::registry::{InvocationHint, ToolDef};
12
// Parameters accepted by the `memory_search` tool.
//
// NOTE: the `///` field comments below are emitted by schemars as JSON-schema `description`s
// that the model sees, so edit them as user-facing text, not as internal documentation.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySearchParams {
    /// Natural language query to search memory for relevant past messages and facts.
    query: String,
    // Omitted `limit` falls back to `default_limit()`; the executor later clamps it to 1..=20.
    /// Maximum number of results to return (default: 5, max: 20).
    #[serde(default = "default_limit")]
    limit: u32,
}
21
/// Serde fallback for `MemorySearchParams::limit` when the caller omits it: five results.
const fn default_limit() -> u32 {
    5_u32
}
25
26#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
27struct MemorySaveParams {
28    /// The content to save to long-term memory. Should be a concise, self-contained fact or note.
29    content: String,
30    /// Role label for the saved message (default: "assistant").
31    #[serde(default = "default_role")]
32    role: String,
33}
34
/// Serde fallback for `MemorySaveParams::role` when the caller omits it.
fn default_role() -> String {
    String::from("assistant")
}
38
/// Tool executor that exposes the `memory_search` and `memory_save` tools over a shared
/// [`SemanticMemory`].
///
/// Saved notes are written into the conversation supplied at construction, and message recall
/// is filtered to that same conversation.
pub struct MemoryToolExecutor {
    // Shared handle to the semantic memory store used for both recall and persistence.
    memory: Arc<SemanticMemory>,
    // Target conversation for `memory_save` and the recall filter for `memory_search`.
    conversation_id: ConversationId,
}
43
44impl MemoryToolExecutor {
45    #[must_use]
46    pub fn new(memory: Arc<SemanticMemory>, conversation_id: ConversationId) -> Self {
47        Self {
48            memory,
49            conversation_id,
50        }
51    }
52}
53
54impl ToolExecutor for MemoryToolExecutor {
55    fn tool_definitions(&self) -> Vec<ToolDef> {
56        vec![
57            ToolDef {
58                id: "memory_search".into(),
59                description: "Search long-term memory for relevant past messages, facts, and session summaries. Use when the user references past conversations or you need historical context.\n\nParameters: query (string, required) - natural language search query; limit (integer, optional) - max results 1-20 (default: 5)\nReturns: ranked list of memory entries with similarity scores and timestamps\nErrors: Execution on database failure\nExample: {\"query\": \"user preference for output format\", \"limit\": 5}".into(),
60                schema: schemars::schema_for!(MemorySearchParams),
61                invocation: InvocationHint::ToolCall,
62            },
63            ToolDef {
64                id: "memory_save".into(),
65                description: "Save a fact or note to long-term memory for cross-session recall. Use sparingly for key decisions, user preferences, or critical context worth remembering across sessions.\n\nParameters: content (string, required) - concise, self-contained fact or note; role (string, optional) - message role label (default: \"assistant\")\nReturns: confirmation with saved entry ID\nErrors: Execution on database failure; InvalidParams if content is empty\nExample: {\"content\": \"User prefers JSON output over YAML\", \"role\": \"assistant\"}".into(),
66                schema: schemars::schema_for!(MemorySaveParams),
67                invocation: InvocationHint::ToolCall,
68            },
69        ]
70    }
71
72    async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
73        Ok(None)
74    }
75
76    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
77        match call.tool_id.as_str() {
78            "memory_search" => {
79                let params: MemorySearchParams = deserialize_params(&call.params)?;
80                let limit = params.limit.clamp(1, 20) as usize;
81
82                let filter = Some(SearchFilter {
83                    conversation_id: Some(self.conversation_id),
84                    role: None,
85                });
86
87                let recalled = self
88                    .memory
89                    .recall(&params.query, limit, filter)
90                    .await
91                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
92
93                let key_facts = self
94                    .memory
95                    .search_key_facts(&params.query, limit)
96                    .await
97                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
98
99                let summaries = self
100                    .memory
101                    .search_session_summaries(&params.query, limit, Some(self.conversation_id))
102                    .await
103                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
104
105                let mut output = String::new();
106
107                let _ = writeln!(output, "## Recalled Messages ({} results)", recalled.len());
108                for r in &recalled {
109                    let role = match r.message.role {
110                        zeph_llm::provider::Role::User => "user",
111                        zeph_llm::provider::Role::Assistant => "assistant",
112                        zeph_llm::provider::Role::System => "system",
113                    };
114                    let content = r.message.content.trim();
115                    let _ = writeln!(output, "[score: {:.2}] {role}: {content}", r.score);
116                }
117
118                let _ = writeln!(output);
119                let _ = writeln!(output, "## Key Facts ({} results)", key_facts.len());
120                for fact in &key_facts {
121                    let _ = writeln!(output, "- {fact}");
122                }
123
124                let _ = writeln!(output);
125                let _ = writeln!(output, "## Session Summaries ({} results)", summaries.len());
126                for s in &summaries {
127                    let _ = writeln!(
128                        output,
129                        "[conv #{}, score: {:.2}] {}",
130                        s.conversation_id, s.score, s.summary_text
131                    );
132                }
133
134                Ok(Some(ToolOutput {
135                    tool_name: "memory_search".to_owned(),
136                    summary: output,
137                    blocks_executed: 1,
138                    filter_stats: None,
139                    diff: None,
140                    streamed: false,
141                    terminal_id: None,
142                    locations: None,
143                    raw_response: None,
144                }))
145            }
146            "memory_save" => {
147                let params: MemorySaveParams = deserialize_params(&call.params)?;
148
149                if params.content.is_empty() {
150                    return Err(ToolError::InvalidParams {
151                        message: "content must not be empty".to_owned(),
152                    });
153                }
154                if params.content.len() > 4096 {
155                    return Err(ToolError::InvalidParams {
156                        message: "content exceeds maximum length of 4096 characters".to_owned(),
157                    });
158                }
159
160                let role = params.role.as_str();
161
162                let message_id = self
163                    .memory
164                    .remember(self.conversation_id, role, &params.content)
165                    .await
166                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
167
168                Ok(Some(ToolOutput {
169                    tool_name: "memory_save".to_owned(),
170                    summary: format!(
171                        "Saved to memory (message_id: {message_id}, conversation: {}). Content will be available for future recall.",
172                        self.conversation_id
173                    ),
174                    blocks_executed: 1,
175                    filter_stats: None,
176                    diff: None,
177                    streamed: false,
178                    terminal_id: None,
179                    locations: None,
180                    raw_response: None,
181                }))
182            }
183            _ => Ok(None),
184        }
185    }
186}
187
// Unit tests for `MemoryToolExecutor`, run against an in-memory SQLite backend with a mock
// LLM provider so no network or on-disk database is needed.
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;
    use zeph_memory::semantic::SemanticMemory;

    // Builds a fresh, isolated `SemanticMemory` backed by an in-memory SQLite database.
    async fn make_memory() -> SemanticMemory {
        SemanticMemory::with_sqlite_backend(
            ":memory:",
            AnyProvider::Mock(MockProvider::default()),
            "test-model",
            0.7,
            0.3,
        )
        .await
        .unwrap()
    }

    // Wraps `memory` in an executor bound to conversation 1. The conversation row is not
    // created here, so use this only for tests that do not need to persist messages.
    fn make_executor(memory: SemanticMemory) -> MemoryToolExecutor {
        MemoryToolExecutor::new(Arc::new(memory), ConversationId(1))
    }

    // Both tools are advertised, in a stable order.
    #[tokio::test]
    async fn tool_definitions_returns_two_tools() {
        let memory = make_memory().await;
        let executor = make_executor(memory);
        let defs = executor.tool_definitions();
        assert_eq!(defs.len(), 2);
        assert_eq!(defs[0].id.as_ref(), "memory_search");
        assert_eq!(defs[1].id.as_ref(), "memory_save");
    }

    // Free-text responses are never handled by this executor.
    #[tokio::test]
    async fn execute_always_returns_none() {
        let memory = make_memory().await;
        let executor = make_executor(memory);
        let result = executor.execute("any response").await.unwrap();
        assert!(result.is_none());
    }

    // Unknown tool ids are declined with `Ok(None)` rather than an error.
    #[tokio::test]
    async fn execute_tool_call_unknown_returns_none() {
        let memory = make_memory().await;
        let executor = make_executor(memory);
        let call = ToolCall {
            tool_id: "unknown_tool".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = executor.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    // Even against an empty store, all three report sections are rendered.
    #[tokio::test]
    async fn memory_search_returns_output() {
        let memory = make_memory().await;
        let executor = make_executor(memory);
        let mut params = serde_json::Map::new();
        params.insert(
            "query".into(),
            serde_json::Value::String("test query".into()),
        );
        let call = ToolCall {
            tool_id: "memory_search".to_owned(),
            params,
        };
        let result = executor.execute_tool_call(&call).await.unwrap();
        assert!(result.is_some());
        let output = result.unwrap();
        assert_eq!(output.tool_name, "memory_search");
        assert!(output.summary.contains("Recalled Messages"));
        assert!(output.summary.contains("Key Facts"));
        assert!(output.summary.contains("Session Summaries"));
    }

    // Saving requires an existing conversation, so one is created before the executor is
    // built; the confirmation must echo the new message id.
    #[tokio::test]
    async fn memory_save_stores_and_returns_confirmation() {
        let memory = make_memory().await;
        let sqlite = memory.sqlite().clone();
        // Create conversation first
        let cid = sqlite.create_conversation().await.unwrap();
        let executor = MemoryToolExecutor::new(Arc::new(memory), cid);

        let mut params = serde_json::Map::new();
        params.insert(
            "content".into(),
            serde_json::Value::String("User prefers dark mode".into()),
        );
        let call = ToolCall {
            tool_id: "memory_save".to_owned(),
            params,
        };
        let result = executor.execute_tool_call(&call).await.unwrap();
        assert!(result.is_some());
        let output = result.unwrap();
        assert!(output.summary.contains("Saved to memory"));
        assert!(output.summary.contains("message_id:"));
    }

    // Empty content is rejected before touching the store.
    #[tokio::test]
    async fn memory_save_empty_content_returns_error() {
        let memory = make_memory().await;
        let executor = make_executor(memory);
        let mut params = serde_json::Map::new();
        params.insert("content".into(), serde_json::Value::String(String::new()));
        let call = ToolCall {
            tool_id: "memory_save".to_owned(),
            params,
        };
        let result = executor.execute_tool_call(&call).await;
        assert!(result.is_err());
    }

    // One character over the 4096 limit is rejected before touching the store.
    #[tokio::test]
    async fn memory_save_oversized_content_returns_error() {
        let memory = make_memory().await;
        let executor = make_executor(memory);
        let mut params = serde_json::Map::new();
        params.insert(
            "content".into(),
            serde_json::Value::String("x".repeat(4097)),
        );
        let call = ToolCall {
            tool_id: "memory_save".to_owned(),
            params,
        };
        let result = executor.execute_tool_call(&call).await;
        assert!(result.is_err());
    }
}