1use std::fmt::Write as _;
5use std::sync::Arc;
6
7use zeph_memory::embedding_store::SearchFilter;
8use zeph_memory::semantic::SemanticMemory;
9use zeph_memory::types::ConversationId;
10use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
11use zeph_tools::registry::{InvocationHint, ToolDef};
12
13use zeph_sanitizer::memory_validation::MemoryWriteValidator;
14
// Arguments accepted by the `memory_search` tool, deserialized from the tool
// call's JSON parameters. The schema derived from this struct is the one
// advertised for `memory_search` in `tool_definitions`.
//
// NOTE(review): plain `//` comments are used on purpose. `schemars` is believed
// to copy `///` doc comments into the generated schema as `description`
// entries, which would change the advertised schema — confirm before
// converting these to doc comments.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySearchParams {
    // Search query text; required (it has no serde default).
    query: String,
    // Requested number of results. Filled in by `default_limit()` when the
    // caller omits it; the executor clamps it to 1..=20 before use.
    #[serde(default = "default_limit")]
    limit: u32,
}
23
/// Serde default for `MemorySearchParams::limit`: the number of results
/// requested when the caller does not supply `limit`.
fn default_limit() -> u32 {
    const DEFAULT_SEARCH_LIMIT: u32 = 5;
    DEFAULT_SEARCH_LIMIT
}
27
// Arguments accepted by the `memory_save` tool, deserialized from the tool
// call's JSON parameters. The schema derived from this struct is the one
// advertised for `memory_save` in `tool_definitions`.
//
// NOTE(review): plain `//` comments are used on purpose. `schemars` is believed
// to copy `///` doc comments into the generated schema as `description`
// entries, which would change the advertised schema — confirm before
// converting these to doc comments.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySaveParams {
    // Text to store; required. The executor rejects empty content and content
    // over its 4096 length limit, then runs the write validator on it.
    content: String,
    // Role label handed to the memory backend together with the content.
    // Filled in by `default_role()` when the caller omits it.
    #[serde(default = "default_role")]
    role: String,
}
36
/// Serde default for `MemorySaveParams::role`: the role label applied when the
/// caller does not supply `role`.
fn default_role() -> String {
    String::from("assistant")
}
40
/// Tool executor that exposes the `memory_search` and `memory_save` tools on
/// top of a shared [`SemanticMemory`].
pub struct MemoryToolExecutor {
    /// Shared memory backend that both tools read from and write to.
    memory: Arc<SemanticMemory>,
    /// Conversation this executor is bound to: used in the recall filter,
    /// passed to the session-summary search, and used as the target
    /// conversation when saving.
    conversation_id: ConversationId,
    /// Validator applied to `memory_save` content before anything is written.
    validator: MemoryWriteValidator,
}
46
47impl MemoryToolExecutor {
48 #[must_use]
49 pub fn new(memory: Arc<SemanticMemory>, conversation_id: ConversationId) -> Self {
50 Self {
51 memory,
52 conversation_id,
53 validator: MemoryWriteValidator::new(
54 zeph_sanitizer::memory_validation::MemoryWriteValidationConfig::default(),
55 ),
56 }
57 }
58
59 #[must_use]
61 pub fn with_validator(
62 memory: Arc<SemanticMemory>,
63 conversation_id: ConversationId,
64 validator: MemoryWriteValidator,
65 ) -> Self {
66 Self {
67 memory,
68 conversation_id,
69 validator,
70 }
71 }
72}
73
74impl ToolExecutor for MemoryToolExecutor {
75 fn tool_definitions(&self) -> Vec<ToolDef> {
76 vec![
77 ToolDef {
78 id: "memory_search".into(),
79 description: "Search long-term memory for relevant past messages, facts, and session summaries. Use to recall facts, preferences, or information the user provided during this or previous conversations.\n\nParameters: query (string, required) - natural language search query; limit (integer, optional) - max results 1-20 (default: 5)\nReturns: ranked list of memory entries with similarity scores and timestamps\nErrors: Execution on database failure\nExample: {\"query\": \"user preference for output format\", \"limit\": 5}".into(),
80 schema: schemars::schema_for!(MemorySearchParams),
81 invocation: InvocationHint::ToolCall,
82 output_schema: None,
83 },
84 ToolDef {
85 id: "memory_save".into(),
86 description: "Save a fact or note to long-term memory for cross-session recall. Use sparingly for key decisions, user preferences, or critical context worth remembering across sessions.\n\nParameters: content (string, required) - concise, self-contained fact or note; role (string, optional) - message role label (default: \"assistant\")\nReturns: confirmation with saved entry ID\nErrors: Execution on database failure; InvalidParams if content is empty\nExample: {\"content\": \"User prefers JSON output over YAML\", \"role\": \"assistant\"}".into(),
87 schema: schemars::schema_for!(MemorySaveParams),
88 invocation: InvocationHint::ToolCall,
89 output_schema: None,
90 },
91 ]
92 }
93
94 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
95 Ok(None)
96 }
97
98 #[allow(clippy::too_many_lines)] async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
100 match call.tool_id.as_str() {
101 "memory_search" => {
102 let params: MemorySearchParams = deserialize_params(&call.params)?;
103 let limit = params.limit.clamp(1, 20) as usize;
104
105 let filter = Some(SearchFilter {
106 conversation_id: Some(self.conversation_id),
107 role: None,
108 category: None,
109 });
110
111 let recalled = self
112 .memory
113 .recall(¶ms.query, limit, filter)
114 .await
115 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
116
117 let key_facts = self
118 .memory
119 .search_key_facts(¶ms.query, limit)
120 .await
121 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
122
123 let summaries = self
124 .memory
125 .search_session_summaries(¶ms.query, limit, Some(self.conversation_id))
126 .await
127 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
128
129 let mut output = String::new();
130
131 let _ = writeln!(output, "## Recalled Messages ({} results)", recalled.len());
132 for r in &recalled {
133 let role = match r.message.role {
134 zeph_llm::provider::Role::User => "user",
135 zeph_llm::provider::Role::Assistant => "assistant",
136 zeph_llm::provider::Role::System => "system",
137 };
138 let content = r.message.content.trim();
139 let _ = writeln!(output, "[score: {:.2}] {role}: {content}", r.score);
140 }
141
142 let _ = writeln!(output);
143 let _ = writeln!(output, "## Key Facts ({} results)", key_facts.len());
144 for fact in &key_facts {
145 let _ = writeln!(output, "- {fact}");
146 }
147
148 let _ = writeln!(output);
149 let _ = writeln!(output, "## Session Summaries ({} results)", summaries.len());
150 for s in &summaries {
151 let _ = writeln!(
152 output,
153 "[conv #{}, score: {:.2}] {}",
154 s.conversation_id, s.score, s.summary_text
155 );
156 }
157
158 Ok(Some(ToolOutput {
159 tool_name: zeph_common::ToolName::new("memory_search"),
160 summary: output,
161 blocks_executed: 1,
162 filter_stats: None,
163 diff: None,
164 streamed: false,
165 terminal_id: None,
166 locations: None,
167 raw_response: None,
168 claim_source: Some(zeph_tools::ClaimSource::Memory),
169 }))
170 }
171 "memory_save" => {
172 let params: MemorySaveParams = deserialize_params(&call.params)?;
173
174 if params.content.is_empty() {
175 return Err(ToolError::InvalidParams {
176 message: "content must not be empty".to_owned(),
177 });
178 }
179 if params.content.len() > 4096 {
180 return Err(ToolError::InvalidParams {
181 message: "content exceeds maximum length of 4096 characters".to_owned(),
182 });
183 }
184
185 if let Err(e) = self.validator.validate_memory_save(¶ms.content) {
187 return Err(ToolError::InvalidParams {
188 message: format!("memory write rejected: {e}"),
189 });
190 }
191
192 let role = params.role.as_str();
193
194 let message_id_opt = self
196 .memory
197 .remember(self.conversation_id, role, ¶ms.content, None)
198 .await
199 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
200
201 let summary = match message_id_opt {
202 Some(message_id) => format!(
203 "Saved to memory (message_id: {message_id}, conversation: {}). Content will be available for future recall.",
204 self.conversation_id
205 ),
206 None => "Memory admission rejected: message did not meet quality threshold."
207 .to_owned(),
208 };
209
210 Ok(Some(ToolOutput {
211 tool_name: zeph_common::ToolName::new("memory_save"),
212 summary,
213 blocks_executed: 1,
214 filter_stats: None,
215 diff: None,
216 streamed: false,
217 terminal_id: None,
218 locations: None,
219 raw_response: None,
220 claim_source: Some(zeph_tools::ClaimSource::Memory),
221 }))
222 }
223 _ => Ok(None),
224 }
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;
    use zeph_memory::semantic::SemanticMemory;

    type JsonMap = serde_json::Map<String, serde_json::Value>;

    /// Semantic memory on an in-memory SQLite database, driven by the mock provider.
    async fn make_memory() -> SemanticMemory {
        let provider = AnyProvider::Mock(MockProvider::default());
        SemanticMemory::with_sqlite_backend(":memory:", provider, "test-model", 0.7, 0.3)
            .await
            .unwrap()
    }

    /// Executor with the default validator, bound to conversation 1.
    fn make_executor(memory: SemanticMemory) -> MemoryToolExecutor {
        MemoryToolExecutor::new(Arc::new(memory), ConversationId(1))
    }

    /// Parameter map holding a single string entry.
    fn string_param(key: &str, value: String) -> JsonMap {
        let mut map = JsonMap::new();
        map.insert(key.into(), serde_json::Value::String(value));
        map
    }

    /// Tool call for `tool` with `params` and no caller id or context.
    fn tool_call(tool: &'static str, params: JsonMap) -> ToolCall {
        ToolCall {
            tool_id: zeph_common::ToolName::new(tool),
            params,
            caller_id: None,
            context: None,
        }
    }

    #[tokio::test]
    async fn tool_definitions_returns_two_tools() {
        let executor = make_executor(make_memory().await);
        let defs = executor.tool_definitions();
        assert_eq!(defs.len(), 2);
        assert_eq!(defs[0].id.as_ref(), "memory_search");
        assert_eq!(defs[1].id.as_ref(), "memory_save");
    }

    #[tokio::test]
    async fn execute_always_returns_none() {
        let executor = make_executor(make_memory().await);
        let result = executor.execute("any response").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn execute_tool_call_unknown_returns_none() {
        let executor = make_executor(make_memory().await);
        let call = tool_call("unknown_tool", JsonMap::new());
        let result = executor.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn memory_search_returns_output() {
        let executor = make_executor(make_memory().await);
        let call = tool_call("memory_search", string_param("query", "test query".into()));
        let result = executor.execute_tool_call(&call).await.unwrap();
        assert!(result.is_some());
        let output = result.unwrap();
        assert_eq!(output.tool_name, "memory_search");
        for heading in ["Recalled Messages", "Key Facts", "Session Summaries"] {
            assert!(output.summary.contains(heading));
        }
    }

    #[tokio::test]
    async fn memory_save_stores_and_returns_confirmation() {
        let memory = make_memory().await;
        // Save into a conversation that was actually created in the database.
        let sqlite = memory.sqlite().clone();
        let cid = sqlite.create_conversation().await.unwrap();
        let executor = MemoryToolExecutor::new(Arc::new(memory), cid);

        let call = tool_call(
            "memory_save",
            string_param("content", "User prefers dark mode".into()),
        );
        let result = executor.execute_tool_call(&call).await.unwrap();
        assert!(result.is_some());
        let output = result.unwrap();
        assert!(output.summary.contains("Saved to memory"));
        assert!(output.summary.contains("message_id:"));
    }

    #[tokio::test]
    async fn memory_save_empty_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let call = tool_call("memory_save", string_param("content", String::new()));
        let result = executor.execute_tool_call(&call).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn memory_save_oversized_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let call = tool_call("memory_save", string_param("content", "x".repeat(4097)));
        let result = executor.execute_tool_call(&call).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn memory_search_description_mentions_user_provided_facts() {
        let executor = make_executor(make_memory().await);
        let defs = executor.tool_definitions();
        let memory_search = defs
            .iter()
            .find(|d| d.id.as_ref() == "memory_search")
            .unwrap();
        assert!(
            memory_search
                .description
                .contains("user provided during this or previous conversations"),
            "memory_search description must contain disambiguation phrase; got: {}",
            memory_search.description
        );
    }
}