1use std::fmt::Write as _;
5use std::sync::Arc;
6
7use zeph_memory::embedding_store::SearchFilter;
8use zeph_memory::semantic::SemanticMemory;
9use zeph_memory::types::ConversationId;
10use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
11use zeph_tools::registry::{InvocationHint, ToolDef};
12
13use zeph_sanitizer::memory_validation::MemoryWriteValidator;
14
// Arguments accepted by the `memory_search` tool. The tool's JSON schema is generated
// from this struct via `schemars::schema_for!(MemorySearchParams)`.
// NOTE(review): plain `//` comments are used here on purpose — `///` doc comments would
// presumably be picked up by the `JsonSchema` derive as schema descriptions; confirm
// before converting.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySearchParams {
    // Search text; handed unchanged to the message recall, key-fact and
    // session-summary lookups.
    query: String,
    // Maximum number of results requested per lookup. Falls back to `default_limit()`
    // when the caller omits it; the executor clamps it to 1..=20 before use.
    #[serde(default = "default_limit")]
    limit: u32,
}
23
/// Serde fallback for `MemorySearchParams::limit` when the caller omits the field.
fn default_limit() -> u32 {
    const DEFAULT_SEARCH_LIMIT: u32 = 5;
    DEFAULT_SEARCH_LIMIT
}
27
// Arguments accepted by the `memory_save` tool. The tool's JSON schema is generated
// from this struct via `schemars::schema_for!(MemorySaveParams)`.
// NOTE(review): plain `//` comments are used here on purpose — `///` doc comments would
// presumably be picked up by the `JsonSchema` derive as schema descriptions; confirm
// before converting.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySaveParams {
    // Text to store. The executor rejects it when it is empty or longer than 4096 bytes,
    // and runs it through the write validator before saving.
    content: String,
    // Role label stored with the entry. Falls back to `default_role()` ("assistant")
    // when the caller omits it.
    #[serde(default = "default_role")]
    role: String,
}
36
/// Serde fallback for `MemorySaveParams::role` when the caller omits the field.
fn default_role() -> String {
    String::from("assistant")
}
40
/// Tool executor that exposes the `memory_search` and `memory_save` tools on top of a
/// shared [`SemanticMemory`] instance.
pub struct MemoryToolExecutor {
    /// Shared memory backend that every search and save goes through.
    memory: Arc<SemanticMemory>,
    /// Conversation id used when saving entries and passed along to the message recall
    /// filter and the session-summary search.
    conversation_id: ConversationId,
    /// Validator applied to `memory_save` content before anything is written.
    validator: MemoryWriteValidator,
}
46
47impl MemoryToolExecutor {
48 #[must_use]
49 pub fn new(memory: Arc<SemanticMemory>, conversation_id: ConversationId) -> Self {
50 Self {
51 memory,
52 conversation_id,
53 validator: MemoryWriteValidator::new(
54 zeph_sanitizer::memory_validation::MemoryWriteValidationConfig::default(),
55 ),
56 }
57 }
58
59 #[must_use]
61 pub fn with_validator(
62 memory: Arc<SemanticMemory>,
63 conversation_id: ConversationId,
64 validator: MemoryWriteValidator,
65 ) -> Self {
66 Self {
67 memory,
68 conversation_id,
69 validator,
70 }
71 }
72}
73
74impl ToolExecutor for MemoryToolExecutor {
75 fn tool_definitions(&self) -> Vec<ToolDef> {
76 vec![
77 ToolDef {
78 id: "memory_search".into(),
79 description: "Search long-term memory for relevant past messages, facts, and session summaries. Use to recall facts, preferences, or information the user provided during this or previous conversations.\n\nParameters: query (string, required) - natural language search query; limit (integer, optional) - max results 1-20 (default: 5)\nReturns: ranked list of memory entries with similarity scores and timestamps\nErrors: Execution on database failure\nExample: {\"query\": \"user preference for output format\", \"limit\": 5}".into(),
80 schema: schemars::schema_for!(MemorySearchParams),
81 invocation: InvocationHint::ToolCall,
82 },
83 ToolDef {
84 id: "memory_save".into(),
85 description: "Save a fact or note to long-term memory for cross-session recall. Use sparingly for key decisions, user preferences, or critical context worth remembering across sessions.\n\nParameters: content (string, required) - concise, self-contained fact or note; role (string, optional) - message role label (default: \"assistant\")\nReturns: confirmation with saved entry ID\nErrors: Execution on database failure; InvalidParams if content is empty\nExample: {\"content\": \"User prefers JSON output over YAML\", \"role\": \"assistant\"}".into(),
86 schema: schemars::schema_for!(MemorySaveParams),
87 invocation: InvocationHint::ToolCall,
88 },
89 ]
90 }
91
92 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
93 Ok(None)
94 }
95
96 #[allow(clippy::too_many_lines)] async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
98 match call.tool_id.as_str() {
99 "memory_search" => {
100 let params: MemorySearchParams = deserialize_params(&call.params)?;
101 let limit = params.limit.clamp(1, 20) as usize;
102
103 let filter = Some(SearchFilter {
104 conversation_id: Some(self.conversation_id),
105 role: None,
106 category: None,
107 });
108
109 let recalled = self
110 .memory
111 .recall(¶ms.query, limit, filter)
112 .await
113 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
114
115 let key_facts = self
116 .memory
117 .search_key_facts(¶ms.query, limit)
118 .await
119 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
120
121 let summaries = self
122 .memory
123 .search_session_summaries(¶ms.query, limit, Some(self.conversation_id))
124 .await
125 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
126
127 let mut output = String::new();
128
129 let _ = writeln!(output, "## Recalled Messages ({} results)", recalled.len());
130 for r in &recalled {
131 let role = match r.message.role {
132 zeph_llm::provider::Role::User => "user",
133 zeph_llm::provider::Role::Assistant => "assistant",
134 zeph_llm::provider::Role::System => "system",
135 };
136 let content = r.message.content.trim();
137 let _ = writeln!(output, "[score: {:.2}] {role}: {content}", r.score);
138 }
139
140 let _ = writeln!(output);
141 let _ = writeln!(output, "## Key Facts ({} results)", key_facts.len());
142 for fact in &key_facts {
143 let _ = writeln!(output, "- {fact}");
144 }
145
146 let _ = writeln!(output);
147 let _ = writeln!(output, "## Session Summaries ({} results)", summaries.len());
148 for s in &summaries {
149 let _ = writeln!(
150 output,
151 "[conv #{}, score: {:.2}] {}",
152 s.conversation_id, s.score, s.summary_text
153 );
154 }
155
156 Ok(Some(ToolOutput {
157 tool_name: zeph_common::ToolName::new("memory_search"),
158 summary: output,
159 blocks_executed: 1,
160 filter_stats: None,
161 diff: None,
162 streamed: false,
163 terminal_id: None,
164 locations: None,
165 raw_response: None,
166 claim_source: Some(zeph_tools::ClaimSource::Memory),
167 }))
168 }
169 "memory_save" => {
170 let params: MemorySaveParams = deserialize_params(&call.params)?;
171
172 if params.content.is_empty() {
173 return Err(ToolError::InvalidParams {
174 message: "content must not be empty".to_owned(),
175 });
176 }
177 if params.content.len() > 4096 {
178 return Err(ToolError::InvalidParams {
179 message: "content exceeds maximum length of 4096 characters".to_owned(),
180 });
181 }
182
183 if let Err(e) = self.validator.validate_memory_save(¶ms.content) {
185 return Err(ToolError::InvalidParams {
186 message: format!("memory write rejected: {e}"),
187 });
188 }
189
190 let role = params.role.as_str();
191
192 let message_id_opt = self
194 .memory
195 .remember(self.conversation_id, role, ¶ms.content, None)
196 .await
197 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
198
199 let summary = match message_id_opt {
200 Some(message_id) => format!(
201 "Saved to memory (message_id: {message_id}, conversation: {}). Content will be available for future recall.",
202 self.conversation_id
203 ),
204 None => "Memory admission rejected: message did not meet quality threshold."
205 .to_owned(),
206 };
207
208 Ok(Some(ToolOutput {
209 tool_name: zeph_common::ToolName::new("memory_save"),
210 summary,
211 blocks_executed: 1,
212 filter_stats: None,
213 diff: None,
214 streamed: false,
215 terminal_id: None,
216 locations: None,
217 raw_response: None,
218 claim_source: Some(zeph_tools::ClaimSource::Memory),
219 }))
220 }
221 _ => Ok(None),
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;
    use zeph_memory::semantic::SemanticMemory;

    /// Builds a `SemanticMemory` on an in-memory SQLite database with a mock provider.
    async fn make_memory() -> SemanticMemory {
        SemanticMemory::with_sqlite_backend(
            ":memory:",
            AnyProvider::Mock(MockProvider::default()),
            "test-model",
            0.7,
            0.3,
        )
        .await
        .unwrap()
    }

    /// Wraps `memory` in an executor bound to conversation id 1.
    fn make_executor(memory: SemanticMemory) -> MemoryToolExecutor {
        MemoryToolExecutor::new(Arc::new(memory), ConversationId(1))
    }

    /// Builds a `ToolCall` for `tool` whose params are the given string key/value pairs.
    fn string_call(tool: &'static str, args: &[(&str, &str)]) -> ToolCall {
        let params = args
            .iter()
            .map(|&(key, value)| (key.to_owned(), serde_json::Value::String(value.to_owned())))
            .collect();
        ToolCall {
            tool_id: zeph_common::ToolName::new(tool),
            params,
            caller_id: None,
        }
    }

    #[tokio::test]
    async fn tool_definitions_returns_two_tools() {
        let executor = make_executor(make_memory().await);
        let defs = executor.tool_definitions();
        assert_eq!(defs.len(), 2);
        assert_eq!(defs[0].id.as_ref(), "memory_search");
        assert_eq!(defs[1].id.as_ref(), "memory_save");
    }

    #[tokio::test]
    async fn execute_always_returns_none() {
        let executor = make_executor(make_memory().await);
        let result = executor.execute("any response").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn execute_tool_call_unknown_returns_none() {
        let executor = make_executor(make_memory().await);
        let call = string_call("unknown_tool", &[]);
        let result = executor.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn memory_search_returns_output() {
        let executor = make_executor(make_memory().await);
        let call = string_call("memory_search", &[("query", "test query")]);
        let output = executor
            .execute_tool_call(&call)
            .await
            .unwrap()
            .expect("memory_search must produce output");
        assert_eq!(output.tool_name, "memory_search");
        // All three section headers are rendered even when nothing matches.
        assert!(output.summary.contains("Recalled Messages"));
        assert!(output.summary.contains("Key Facts"));
        assert!(output.summary.contains("Session Summaries"));
    }

    #[tokio::test]
    async fn memory_save_stores_and_returns_confirmation() {
        let memory = make_memory().await;
        let sqlite = memory.sqlite().clone();
        // Save against a conversation created through the store rather than a made-up id.
        let cid = sqlite.create_conversation().await.unwrap();
        let executor = MemoryToolExecutor::new(Arc::new(memory), cid);

        let call = string_call("memory_save", &[("content", "User prefers dark mode")]);
        let output = executor
            .execute_tool_call(&call)
            .await
            .unwrap()
            .expect("memory_save must produce output");
        assert!(output.summary.contains("Saved to memory"));
        assert!(output.summary.contains("message_id:"));
    }

    #[tokio::test]
    async fn memory_save_empty_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let call = string_call("memory_save", &[("content", "")]);
        let result = executor.execute_tool_call(&call).await;
        // Pin the variant so an unrelated failure cannot satisfy this test.
        assert!(
            matches!(result, Err(ToolError::InvalidParams { .. })),
            "empty content must be rejected as InvalidParams"
        );
    }

    #[tokio::test]
    async fn memory_save_oversized_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let oversized = "x".repeat(4097);
        let call = string_call("memory_save", &[("content", oversized.as_str())]);
        let result = executor.execute_tool_call(&call).await;
        // Pin the variant so an unrelated failure cannot satisfy this test.
        assert!(
            matches!(result, Err(ToolError::InvalidParams { .. })),
            "content over the size limit must be rejected as InvalidParams"
        );
    }

    #[tokio::test]
    async fn memory_search_description_mentions_user_provided_facts() {
        let executor = make_executor(make_memory().await);
        let defs = executor.tool_definitions();
        let memory_search = defs
            .iter()
            .find(|d| d.id.as_ref() == "memory_search")
            .unwrap();
        assert!(
            memory_search
                .description
                .contains("user provided during this or previous conversations"),
            "memory_search description must contain disambiguation phrase; got: {}",
            memory_search.description
        );
    }
}