1use std::fmt::Write as _;
5use std::sync::Arc;
6
7use zeph_memory::embedding_store::SearchFilter;
8use zeph_memory::semantic::SemanticMemory;
9use zeph_memory::types::ConversationId;
10use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
11use zeph_tools::registry::{InvocationHint, ToolDef};
12
13use zeph_sanitizer::memory_validation::MemoryWriteValidator;
14
// Arguments accepted by the `memory_search` tool call.
//
// Plain `//` comments are used here on purpose: `///` doc attributes on a type that
// derives `schemars::JsonSchema` are emitted as schema descriptions, which would change
// the schema produced by `schemars::schema_for!(MemorySearchParams)`.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySearchParams {
    // Search text; required (no serde default is declared for it).
    query: String,
    // Requested number of results; `default_limit()` supplies the value when the
    // field is omitted from the call parameters.
    #[serde(default = "default_limit")]
    limit: u32,
}
23
/// Serde fallback for `MemorySearchParams::limit` when the caller omits the field.
fn default_limit() -> u32 {
    const DEFAULT_SEARCH_LIMIT: u32 = 5;
    DEFAULT_SEARCH_LIMIT
}
27
// Arguments accepted by the `memory_save` tool call.
//
// Plain `//` comments are used here on purpose: `///` doc attributes on a type that
// derives `schemars::JsonSchema` are emitted as schema descriptions, which would change
// the schema produced by `schemars::schema_for!(MemorySaveParams)`.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySaveParams {
    // Text to store; required (no serde default is declared for it).
    content: String,
    // Role label stored with the entry; `default_role()` supplies the value when the
    // field is omitted from the call parameters.
    #[serde(default = "default_role")]
    role: String,
}
36
/// Serde fallback for `MemorySaveParams::role` when the caller omits the field.
fn default_role() -> String {
    String::from("assistant")
}
40
/// Tool executor that serves the `memory_search` and `memory_save` tools.
pub struct MemoryToolExecutor {
    /// Shared handle to the semantic memory that the tools query and write to.
    memory: Arc<SemanticMemory>,
    /// Conversation used to filter message recall and session-summary search, and
    /// passed to the memory store when an entry is saved.
    conversation_id: ConversationId,
    /// Validator run against `memory_save` content before it is stored.
    validator: MemoryWriteValidator,
}
46
47impl MemoryToolExecutor {
48 #[must_use]
49 pub fn new(memory: Arc<SemanticMemory>, conversation_id: ConversationId) -> Self {
50 Self {
51 memory,
52 conversation_id,
53 validator: MemoryWriteValidator::new(
54 zeph_sanitizer::memory_validation::MemoryWriteValidationConfig::default(),
55 ),
56 }
57 }
58
59 #[must_use]
61 pub fn with_validator(
62 memory: Arc<SemanticMemory>,
63 conversation_id: ConversationId,
64 validator: MemoryWriteValidator,
65 ) -> Self {
66 Self {
67 memory,
68 conversation_id,
69 validator,
70 }
71 }
72}
73
74impl ToolExecutor for MemoryToolExecutor {
75 fn tool_definitions(&self) -> Vec<ToolDef> {
76 vec![
77 ToolDef {
78 id: "memory_search".into(),
79 description: "Search long-term memory for relevant past messages, facts, and session summaries. Use to recall facts, preferences, or information the user provided during this or previous conversations.\n\nParameters: query (string, required) - natural language search query; limit (integer, optional) - max results 1-20 (default: 5)\nReturns: ranked list of memory entries with similarity scores and timestamps\nErrors: Execution on database failure\nExample: {\"query\": \"user preference for output format\", \"limit\": 5}".into(),
80 schema: schemars::schema_for!(MemorySearchParams),
81 invocation: InvocationHint::ToolCall,
82 },
83 ToolDef {
84 id: "memory_save".into(),
85 description: "Save a fact or note to long-term memory for cross-session recall. Use sparingly for key decisions, user preferences, or critical context worth remembering across sessions.\n\nParameters: content (string, required) - concise, self-contained fact or note; role (string, optional) - message role label (default: \"assistant\")\nReturns: confirmation with saved entry ID\nErrors: Execution on database failure; InvalidParams if content is empty\nExample: {\"content\": \"User prefers JSON output over YAML\", \"role\": \"assistant\"}".into(),
86 schema: schemars::schema_for!(MemorySaveParams),
87 invocation: InvocationHint::ToolCall,
88 },
89 ]
90 }
91
92 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
93 Ok(None)
94 }
95
96 #[allow(clippy::too_many_lines)] async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
98 match call.tool_id.as_str() {
99 "memory_search" => {
100 let params: MemorySearchParams = deserialize_params(&call.params)?;
101 let limit = params.limit.clamp(1, 20) as usize;
102
103 let filter = Some(SearchFilter {
104 conversation_id: Some(self.conversation_id),
105 role: None,
106 });
107
108 let recalled = self
109 .memory
110 .recall(¶ms.query, limit, filter)
111 .await
112 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
113
114 let key_facts = self
115 .memory
116 .search_key_facts(¶ms.query, limit)
117 .await
118 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
119
120 let summaries = self
121 .memory
122 .search_session_summaries(¶ms.query, limit, Some(self.conversation_id))
123 .await
124 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
125
126 let mut output = String::new();
127
128 let _ = writeln!(output, "## Recalled Messages ({} results)", recalled.len());
129 for r in &recalled {
130 let role = match r.message.role {
131 zeph_llm::provider::Role::User => "user",
132 zeph_llm::provider::Role::Assistant => "assistant",
133 zeph_llm::provider::Role::System => "system",
134 };
135 let content = r.message.content.trim();
136 let _ = writeln!(output, "[score: {:.2}] {role}: {content}", r.score);
137 }
138
139 let _ = writeln!(output);
140 let _ = writeln!(output, "## Key Facts ({} results)", key_facts.len());
141 for fact in &key_facts {
142 let _ = writeln!(output, "- {fact}");
143 }
144
145 let _ = writeln!(output);
146 let _ = writeln!(output, "## Session Summaries ({} results)", summaries.len());
147 for s in &summaries {
148 let _ = writeln!(
149 output,
150 "[conv #{}, score: {:.2}] {}",
151 s.conversation_id, s.score, s.summary_text
152 );
153 }
154
155 Ok(Some(ToolOutput {
156 tool_name: "memory_search".to_owned(),
157 summary: output,
158 blocks_executed: 1,
159 filter_stats: None,
160 diff: None,
161 streamed: false,
162 terminal_id: None,
163 locations: None,
164 raw_response: None,
165 claim_source: Some(zeph_tools::ClaimSource::Memory),
166 }))
167 }
168 "memory_save" => {
169 let params: MemorySaveParams = deserialize_params(&call.params)?;
170
171 if params.content.is_empty() {
172 return Err(ToolError::InvalidParams {
173 message: "content must not be empty".to_owned(),
174 });
175 }
176 if params.content.len() > 4096 {
177 return Err(ToolError::InvalidParams {
178 message: "content exceeds maximum length of 4096 characters".to_owned(),
179 });
180 }
181
182 if let Err(e) = self.validator.validate_memory_save(¶ms.content) {
184 return Err(ToolError::InvalidParams {
185 message: format!("memory write rejected: {e}"),
186 });
187 }
188
189 let role = params.role.as_str();
190
191 let message_id_opt = self
193 .memory
194 .remember(self.conversation_id, role, ¶ms.content, None)
195 .await
196 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
197
198 let summary = match message_id_opt {
199 Some(message_id) => format!(
200 "Saved to memory (message_id: {message_id}, conversation: {}). Content will be available for future recall.",
201 self.conversation_id
202 ),
203 None => "Memory admission rejected: message did not meet quality threshold."
204 .to_owned(),
205 };
206
207 Ok(Some(ToolOutput {
208 tool_name: "memory_save".to_owned(),
209 summary,
210 blocks_executed: 1,
211 filter_stats: None,
212 diff: None,
213 streamed: false,
214 terminal_id: None,
215 locations: None,
216 raw_response: None,
217 claim_source: Some(zeph_tools::ClaimSource::Memory),
218 }))
219 }
220 _ => Ok(None),
221 }
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;
    use zeph_memory::semantic::SemanticMemory;

    /// Builds a `SemanticMemory` on an in-memory SQLite database with the mock provider.
    async fn make_memory() -> SemanticMemory {
        SemanticMemory::with_sqlite_backend(
            ":memory:",
            AnyProvider::Mock(MockProvider::default()),
            "test-model",
            0.7,
            0.3,
        )
        .await
        .unwrap()
    }

    /// Wraps `memory` in an executor bound to `ConversationId(1)`; no conversation row
    /// is created here, so tests that must persist data create their own.
    fn make_executor(memory: SemanticMemory) -> MemoryToolExecutor {
        MemoryToolExecutor::new(Arc::new(memory), ConversationId(1))
    }

    /// Builds a tool call for `tool_id` carrying a single string parameter.
    fn single_param_call(tool_id: &str, key: &str, value: impl Into<String>) -> ToolCall {
        let mut params = serde_json::Map::new();
        params.insert(key.to_owned(), serde_json::Value::String(value.into()));
        ToolCall {
            tool_id: tool_id.to_owned(),
            params,
        }
    }

    #[tokio::test]
    async fn tool_definitions_returns_two_tools() {
        let executor = make_executor(make_memory().await);
        let defs = executor.tool_definitions();
        assert_eq!(defs.len(), 2);
        assert_eq!(defs[0].id.as_ref(), "memory_search");
        assert_eq!(defs[1].id.as_ref(), "memory_save");
    }

    #[tokio::test]
    async fn execute_always_returns_none() {
        let executor = make_executor(make_memory().await);
        let result = executor.execute("any response").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn execute_tool_call_unknown_returns_none() {
        let executor = make_executor(make_memory().await);
        let call = ToolCall {
            tool_id: "unknown_tool".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = executor.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn memory_search_returns_output() {
        let executor = make_executor(make_memory().await);
        let call = single_param_call("memory_search", "query", "test query");
        let output = executor
            .execute_tool_call(&call)
            .await
            .unwrap()
            .expect("memory_search must produce output");
        assert_eq!(output.tool_name, "memory_search");
        assert!(output.summary.contains("Recalled Messages"));
        assert!(output.summary.contains("Key Facts"));
        assert!(output.summary.contains("Session Summaries"));
    }

    #[tokio::test]
    async fn memory_save_stores_and_returns_confirmation() {
        let memory = make_memory().await;
        let sqlite = memory.sqlite().clone();
        // Saving needs a real conversation, so create one instead of using `make_executor`.
        let cid = sqlite.create_conversation().await.unwrap();
        let executor = MemoryToolExecutor::new(Arc::new(memory), cid);

        let call = single_param_call("memory_save", "content", "User prefers dark mode");
        let output = executor
            .execute_tool_call(&call)
            .await
            .unwrap()
            .expect("memory_save must produce output");
        assert!(output.summary.contains("Saved to memory"));
        assert!(output.summary.contains("message_id:"));
    }

    #[tokio::test]
    async fn memory_save_empty_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let call = single_param_call("memory_save", "content", String::new());
        let result = executor.execute_tool_call(&call).await;
        // Pin the variant: a bare `is_err()` would also pass on an unrelated storage
        // failure, because this executor's conversation was never created.
        assert!(
            matches!(result, Err(ToolError::InvalidParams { .. })),
            "empty content must be rejected as InvalidParams"
        );
    }

    #[tokio::test]
    async fn memory_save_oversized_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let call = single_param_call("memory_save", "content", "x".repeat(4097));
        let result = executor.execute_tool_call(&call).await;
        // Pin the variant so the test proves the size guard fired, not a later failure.
        assert!(
            matches!(result, Err(ToolError::InvalidParams { .. })),
            "oversized content must be rejected as InvalidParams"
        );
    }

    #[tokio::test]
    async fn memory_search_description_mentions_user_provided_facts() {
        let executor = make_executor(make_memory().await);
        let defs = executor.tool_definitions();
        let memory_search = defs
            .iter()
            .find(|d| d.id.as_ref() == "memory_search")
            .unwrap();
        assert!(
            memory_search
                .description
                .contains("user provided during this or previous conversations"),
            "memory_search description must contain disambiguation phrase; got: {}",
            memory_search.description
        );
    }
}