1use std::fmt::Write as _;
5use std::sync::Arc;
6
7use zeph_memory::embedding_store::SearchFilter;
8use zeph_memory::semantic::SemanticMemory;
9use zeph_memory::types::ConversationId;
10use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
11use zeph_tools::registry::{InvocationHint, ToolDef};
12
13use zeph_sanitizer::memory_validation::MemoryWriteValidator;
14
// Arguments accepted by the `memory_search` tool call.
//
// NOTE: plain `//` comments are used here on purpose. `schemars` copies `///`
// doc comments into the generated JSON schema, and this type's schema is
// published through `tool_definitions`, so doc comments would alter it.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySearchParams {
    // Search query text. Required: there is no serde default for this field.
    query: String,
    // Requested number of results. Falls back to `default_limit()` when the
    // caller omits it; the executor clamps the value before using it.
    #[serde(default = "default_limit")]
    limit: u32,
}
23
/// Serde fallback for `MemorySearchParams::limit` when the caller omits it.
fn default_limit() -> u32 {
    const DEFAULT_SEARCH_LIMIT: u32 = 5;
    DEFAULT_SEARCH_LIMIT
}
27
// Arguments accepted by the `memory_save` tool call.
//
// NOTE: plain `//` comments are used here on purpose. `schemars` copies `///`
// doc comments into the generated JSON schema, and this type's schema is
// published through `tool_definitions`, so doc comments would alter it.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySaveParams {
    // Text to store. Required; the executor rejects empty or over-long content
    // and runs the write validator before anything is persisted.
    content: String,
    // Role label handed to `SemanticMemory::remember` together with the
    // content. Falls back to `default_role()` when the caller omits it.
    #[serde(default = "default_role")]
    role: String,
}
36
/// Serde fallback for `MemorySaveParams::role` when the caller omits it.
fn default_role() -> String {
    String::from("assistant")
}
40
/// Tool executor that exposes the `memory_search` and `memory_save` tools on
/// top of a shared [`SemanticMemory`].
pub struct MemoryToolExecutor {
    /// Shared handle to the memory backend used for both searching and saving.
    memory: Arc<SemanticMemory>,
    /// Conversation this executor is bound to. It is passed to the recall
    /// filter, to the session-summary search, and to `remember` on save.
    conversation_id: ConversationId,
    /// Validator run on `memory_save` content before it is written.
    validator: MemoryWriteValidator,
    /// When `true`, the `memory_save` confirmation text states that the entry
    /// is session-only. In this file the flag only affects that wording.
    ephemeral: bool,
}
49
50impl MemoryToolExecutor {
51 #[must_use]
53 pub fn new(memory: Arc<SemanticMemory>, conversation_id: ConversationId) -> Self {
54 Self {
55 memory,
56 conversation_id,
57 validator: MemoryWriteValidator::new(
58 zeph_sanitizer::memory_validation::MemoryWriteValidationConfig::default(),
59 ),
60 ephemeral: false,
61 }
62 }
63
64 #[must_use]
66 pub fn with_validator(
67 memory: Arc<SemanticMemory>,
68 conversation_id: ConversationId,
69 validator: MemoryWriteValidator,
70 ) -> Self {
71 Self {
72 memory,
73 conversation_id,
74 validator,
75 ephemeral: false,
76 }
77 }
78
79 #[must_use]
84 pub fn ephemeral(mut self) -> Self {
85 self.ephemeral = true;
86 self
87 }
88}
89
90impl ToolExecutor for MemoryToolExecutor {
91 fn tool_definitions(&self) -> Vec<ToolDef> {
92 vec![
93 ToolDef {
94 id: "memory_search".into(),
95 description: "Search long-term memory for relevant past messages, facts, and session summaries. Use to recall facts, preferences, or information the user provided during this or previous conversations.\n\nParameters: query (string, required) - natural language search query; limit (integer, optional) - max results 1-20 (default: 5)\nReturns: ranked list of memory entries with similarity scores and timestamps\nErrors: Execution on database failure\nExample: {\"query\": \"user preference for output format\", \"limit\": 5}".into(),
96 schema: schemars::schema_for!(MemorySearchParams),
97 invocation: InvocationHint::ToolCall,
98 output_schema: None,
99 },
100 ToolDef {
101 id: "memory_save".into(),
102 description: "Save a fact or note to long-term memory for cross-session recall. Use sparingly for key decisions, user preferences, or critical context worth remembering across sessions.\n\nParameters: content (string, required) - concise, self-contained fact or note; role (string, optional) - message role label (default: \"assistant\")\nReturns: confirmation with saved entry ID\nErrors: Execution on database failure; InvalidParams if content is empty\nExample: {\"content\": \"User prefers JSON output over YAML\", \"role\": \"assistant\"}".into(),
103 schema: schemars::schema_for!(MemorySaveParams),
104 invocation: InvocationHint::ToolCall,
105 output_schema: None,
106 },
107 ]
108 }
109
110 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
111 Ok(None)
112 }
113
114 #[allow(clippy::too_many_lines)] async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
116 match call.tool_id.as_str() {
117 "memory_search" => {
118 let params: MemorySearchParams = deserialize_params(&call.params)?;
119 let limit = params.limit.clamp(1, 20) as usize;
120
121 let filter = Some(SearchFilter {
122 conversation_id: Some(self.conversation_id),
123 role: None,
124 category: None,
125 });
126
127 let recalled = self
128 .memory
129 .recall(¶ms.query, limit, filter)
130 .await
131 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
132
133 let key_facts = self
134 .memory
135 .search_key_facts(¶ms.query, limit)
136 .await
137 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
138
139 let summaries = self
140 .memory
141 .search_session_summaries(¶ms.query, limit, Some(self.conversation_id))
142 .await
143 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
144
145 let mut output = String::new();
146
147 let _ = writeln!(output, "## Recalled Messages ({} results)", recalled.len());
148 for r in &recalled {
149 let role = match r.message.role {
150 zeph_llm::provider::Role::User => "user",
151 zeph_llm::provider::Role::Assistant => "assistant",
152 zeph_llm::provider::Role::System => "system",
153 };
154 let content = r.message.content.trim();
155 let _ = writeln!(output, "[score: {:.2}] {role}: {content}", r.score);
156 }
157
158 let _ = writeln!(output);
159 let _ = writeln!(output, "## Key Facts ({} results)", key_facts.len());
160 for fact in &key_facts {
161 let _ = writeln!(output, "- {fact}");
162 }
163
164 let _ = writeln!(output);
165 let _ = writeln!(output, "## Session Summaries ({} results)", summaries.len());
166 for s in &summaries {
167 let _ = writeln!(
168 output,
169 "[conv #{}, score: {:.2}] {}",
170 s.conversation_id, s.score, s.summary_text
171 );
172 }
173
174 Ok(Some(ToolOutput {
175 tool_name: zeph_common::ToolName::new("memory_search"),
176 summary: output,
177 blocks_executed: 1,
178 filter_stats: None,
179 diff: None,
180 streamed: false,
181 terminal_id: None,
182 locations: None,
183 raw_response: None,
184 claim_source: Some(zeph_tools::ClaimSource::Memory),
185 }))
186 }
187 "memory_save" => {
188 let params: MemorySaveParams = deserialize_params(&call.params)?;
189
190 if params.content.is_empty() {
191 return Err(ToolError::InvalidParams {
192 message: "content must not be empty".to_owned(),
193 });
194 }
195 if params.content.len() > 4096 {
196 return Err(ToolError::InvalidParams {
197 message: "content exceeds maximum length of 4096 characters".to_owned(),
198 });
199 }
200
201 if let Err(e) = self.validator.validate_memory_save(¶ms.content) {
203 return Err(ToolError::InvalidParams {
204 message: format!("memory write rejected: {e}"),
205 });
206 }
207
208 let role = params.role.as_str();
209
210 let message_id_opt = self
212 .memory
213 .remember(self.conversation_id, role, ¶ms.content, None)
214 .await
215 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
216
217 let summary = match message_id_opt {
218 Some(message_id) => {
219 if self.ephemeral {
220 format!(
221 "Saved to session memory (message_id: {message_id}, conversation: {}). Ephemeral — not available after session ends.",
222 self.conversation_id
223 )
224 } else {
225 format!(
226 "Saved to memory (message_id: {message_id}, conversation: {}). Content will be available for future recall.",
227 self.conversation_id
228 )
229 }
230 }
231 None => "Memory admission rejected: message did not meet quality threshold."
232 .to_owned(),
233 };
234
235 Ok(Some(ToolOutput {
236 tool_name: zeph_common::ToolName::new("memory_save"),
237 summary,
238 blocks_executed: 1,
239 filter_stats: None,
240 diff: None,
241 streamed: false,
242 terminal_id: None,
243 locations: None,
244 raw_response: None,
245 claim_source: Some(zeph_tools::ClaimSource::Memory),
246 }))
247 }
248 _ => Ok(None),
249 }
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;
    use zeph_memory::semantic::SemanticMemory;

    /// Builds an in-memory SQLite-backed `SemanticMemory` using the mock provider.
    async fn make_memory() -> SemanticMemory {
        let provider = AnyProvider::Mock(MockProvider::default());
        SemanticMemory::with_sqlite_backend(":memory:", provider, "test-model", 0.7, 0.3)
            .await
            .unwrap()
    }

    /// Wraps `memory` in an executor bound to conversation id 1.
    fn make_executor(memory: SemanticMemory) -> MemoryToolExecutor {
        MemoryToolExecutor::new(Arc::new(memory), ConversationId(1))
    }

    /// Builds a memory together with a conversation created in its SQLite store.
    async fn make_memory_with_conversation() -> (SemanticMemory, ConversationId) {
        let memory = make_memory().await;
        let sqlite = memory.sqlite().clone();
        let cid = sqlite.create_conversation().await.unwrap();
        (memory, cid)
    }

    /// Builds a `ToolCall` for `tool` carrying the given parameter map.
    fn call_with_params(
        tool: &'static str,
        params: serde_json::Map<String, serde_json::Value>,
    ) -> ToolCall {
        ToolCall {
            tool_id: zeph_common::ToolName::new(tool),
            params,
            caller_id: None,
            context: None,
            tool_call_id: String::new(),
        }
    }

    /// Builds a `ToolCall` for `tool` with one string parameter `key = value`.
    fn call_with_string(tool: &'static str, key: &str, value: String) -> ToolCall {
        let mut params = serde_json::Map::new();
        params.insert(key.into(), serde_json::Value::String(value));
        call_with_params(tool, params)
    }

    #[tokio::test]
    async fn tool_definitions_returns_two_tools() {
        let executor = make_executor(make_memory().await);
        let defs = executor.tool_definitions();
        assert_eq!(defs.len(), 2);
        for (def, expected) in defs.iter().zip(["memory_search", "memory_save"]) {
            assert_eq!(def.id.as_ref(), expected);
        }
    }

    #[tokio::test]
    async fn execute_always_returns_none() {
        let executor = make_executor(make_memory().await);
        let outcome = executor.execute("any response").await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn execute_tool_call_unknown_returns_none() {
        let executor = make_executor(make_memory().await);
        let call = call_with_params("unknown_tool", serde_json::Map::new());
        let outcome = executor.execute_tool_call(&call).await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn memory_search_returns_output() {
        let executor = make_executor(make_memory().await);
        let call = call_with_string("memory_search", "query", "test query".into());
        let outcome = executor.execute_tool_call(&call).await.unwrap();
        assert!(outcome.is_some());
        let output = outcome.unwrap();
        assert_eq!(output.tool_name, "memory_search");
        for heading in ["Recalled Messages", "Key Facts", "Session Summaries"] {
            assert!(output.summary.contains(heading));
        }
    }

    #[tokio::test]
    async fn memory_save_stores_and_returns_confirmation() {
        let (memory, cid) = make_memory_with_conversation().await;
        let executor = MemoryToolExecutor::new(Arc::new(memory), cid);

        let call = call_with_string("memory_save", "content", "User prefers dark mode".into());
        let outcome = executor.execute_tool_call(&call).await.unwrap();
        assert!(outcome.is_some());
        let output = outcome.unwrap();
        assert!(output.summary.contains("Saved to memory"));
        assert!(output.summary.contains("message_id:"));
    }

    #[tokio::test]
    async fn memory_save_empty_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let call = call_with_string("memory_save", "content", String::new());
        let outcome = executor.execute_tool_call(&call).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn memory_save_oversized_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let call = call_with_string("memory_save", "content", "x".repeat(4097));
        let outcome = executor.execute_tool_call(&call).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn memory_save_ephemeral_returns_session_only_message() {
        let (memory, cid) = make_memory_with_conversation().await;
        let executor = MemoryToolExecutor::new(Arc::new(memory), cid).ephemeral();

        let call = call_with_string("memory_save", "content", "temp fact".into());
        let output = executor.execute_tool_call(&call).await.unwrap().unwrap();
        assert!(
            output.summary.contains("Ephemeral"),
            "bare-mode save must mention ephemeral semantics; got: {}",
            output.summary
        );
        assert!(
            !output.summary.contains("available for future recall"),
            "bare-mode save must not claim cross-session persistence; got: {}",
            output.summary
        );
    }

    #[tokio::test]
    async fn memory_search_description_mentions_user_provided_facts() {
        let executor = make_executor(make_memory().await);
        let defs = executor.tool_definitions();
        let memory_search = defs
            .iter()
            .find(|d| d.id.as_ref() == "memory_search")
            .unwrap();
        let mentions_phrase = memory_search
            .description
            .contains("user provided during this or previous conversations");
        assert!(
            mentions_phrase,
            "memory_search description must contain disambiguation phrase; got: {}",
            memory_search.description
        );
    }
}