1use std::fmt::Write as _;
5use std::sync::Arc;
6
7use zeph_memory::embedding_store::SearchFilter;
8use zeph_memory::semantic::SemanticMemory;
9use zeph_memory::types::ConversationId;
10use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
11use zeph_tools::registry::{InvocationHint, ToolDef};
12
13use zeph_sanitizer::memory_validation::MemoryWriteValidator;
14
15#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
16struct MemorySearchParams {
17 query: String,
19 #[serde(default = "default_limit")]
21 limit: u32,
22}
23
/// Serde default for `MemorySearchParams::limit` when the caller omits it.
fn default_limit() -> u32 {
    const DEFAULT_RESULTS: u32 = 5;
    DEFAULT_RESULTS
}
27
28#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
29struct MemorySaveParams {
30 content: String,
32 #[serde(default = "default_role")]
34 role: String,
35}
36
/// Serde default for `MemorySaveParams::role` when the caller omits it.
fn default_role() -> String {
    String::from("assistant")
}
40
41pub struct MemoryToolExecutor {
42 memory: Arc<SemanticMemory>,
43 conversation_id: ConversationId,
44 validator: MemoryWriteValidator,
45}
46
47impl MemoryToolExecutor {
48 #[must_use]
49 pub fn new(memory: Arc<SemanticMemory>, conversation_id: ConversationId) -> Self {
50 Self {
51 memory,
52 conversation_id,
53 validator: MemoryWriteValidator::new(
54 zeph_sanitizer::memory_validation::MemoryWriteValidationConfig::default(),
55 ),
56 }
57 }
58
59 #[must_use]
61 pub fn with_validator(
62 memory: Arc<SemanticMemory>,
63 conversation_id: ConversationId,
64 validator: MemoryWriteValidator,
65 ) -> Self {
66 Self {
67 memory,
68 conversation_id,
69 validator,
70 }
71 }
72}
73
74impl ToolExecutor for MemoryToolExecutor {
75 fn tool_definitions(&self) -> Vec<ToolDef> {
76 vec![
77 ToolDef {
78 id: "memory_search".into(),
79 description: "Search long-term memory for relevant past messages, facts, and session summaries. Use when the user references past conversations or you need historical context.\n\nParameters: query (string, required) - natural language search query; limit (integer, optional) - max results 1-20 (default: 5)\nReturns: ranked list of memory entries with similarity scores and timestamps\nErrors: Execution on database failure\nExample: {\"query\": \"user preference for output format\", \"limit\": 5}".into(),
80 schema: schemars::schema_for!(MemorySearchParams),
81 invocation: InvocationHint::ToolCall,
82 },
83 ToolDef {
84 id: "memory_save".into(),
85 description: "Save a fact or note to long-term memory for cross-session recall. Use sparingly for key decisions, user preferences, or critical context worth remembering across sessions.\n\nParameters: content (string, required) - concise, self-contained fact or note; role (string, optional) - message role label (default: \"assistant\")\nReturns: confirmation with saved entry ID\nErrors: Execution on database failure; InvalidParams if content is empty\nExample: {\"content\": \"User prefers JSON output over YAML\", \"role\": \"assistant\"}".into(),
86 schema: schemars::schema_for!(MemorySaveParams),
87 invocation: InvocationHint::ToolCall,
88 },
89 ]
90 }
91
92 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
93 Ok(None)
94 }
95
96 #[allow(clippy::too_many_lines)] async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
98 match call.tool_id.as_str() {
99 "memory_search" => {
100 let params: MemorySearchParams = deserialize_params(&call.params)?;
101 let limit = params.limit.clamp(1, 20) as usize;
102
103 let filter = Some(SearchFilter {
104 conversation_id: Some(self.conversation_id),
105 role: None,
106 });
107
108 let recalled = self
109 .memory
110 .recall(¶ms.query, limit, filter)
111 .await
112 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
113
114 let key_facts = self
115 .memory
116 .search_key_facts(¶ms.query, limit)
117 .await
118 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
119
120 let summaries = self
121 .memory
122 .search_session_summaries(¶ms.query, limit, Some(self.conversation_id))
123 .await
124 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
125
126 let mut output = String::new();
127
128 let _ = writeln!(output, "## Recalled Messages ({} results)", recalled.len());
129 for r in &recalled {
130 let role = match r.message.role {
131 zeph_llm::provider::Role::User => "user",
132 zeph_llm::provider::Role::Assistant => "assistant",
133 zeph_llm::provider::Role::System => "system",
134 };
135 let content = r.message.content.trim();
136 let _ = writeln!(output, "[score: {:.2}] {role}: {content}", r.score);
137 }
138
139 let _ = writeln!(output);
140 let _ = writeln!(output, "## Key Facts ({} results)", key_facts.len());
141 for fact in &key_facts {
142 let _ = writeln!(output, "- {fact}");
143 }
144
145 let _ = writeln!(output);
146 let _ = writeln!(output, "## Session Summaries ({} results)", summaries.len());
147 for s in &summaries {
148 let _ = writeln!(
149 output,
150 "[conv #{}, score: {:.2}] {}",
151 s.conversation_id, s.score, s.summary_text
152 );
153 }
154
155 Ok(Some(ToolOutput {
156 tool_name: "memory_search".to_owned(),
157 summary: output,
158 blocks_executed: 1,
159 filter_stats: None,
160 diff: None,
161 streamed: false,
162 terminal_id: None,
163 locations: None,
164 raw_response: None,
165 }))
166 }
167 "memory_save" => {
168 let params: MemorySaveParams = deserialize_params(&call.params)?;
169
170 if params.content.is_empty() {
171 return Err(ToolError::InvalidParams {
172 message: "content must not be empty".to_owned(),
173 });
174 }
175 if params.content.len() > 4096 {
176 return Err(ToolError::InvalidParams {
177 message: "content exceeds maximum length of 4096 characters".to_owned(),
178 });
179 }
180
181 if let Err(e) = self.validator.validate_memory_save(¶ms.content) {
183 return Err(ToolError::InvalidParams {
184 message: format!("memory write rejected: {e}"),
185 });
186 }
187
188 let role = params.role.as_str();
189
190 let message_id = self
191 .memory
192 .remember(self.conversation_id, role, ¶ms.content)
193 .await
194 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
195
196 Ok(Some(ToolOutput {
197 tool_name: "memory_save".to_owned(),
198 summary: format!(
199 "Saved to memory (message_id: {message_id}, conversation: {}). Content will be available for future recall.",
200 self.conversation_id
201 ),
202 blocks_executed: 1,
203 filter_stats: None,
204 diff: None,
205 streamed: false,
206 terminal_id: None,
207 locations: None,
208 raw_response: None,
209 }))
210 }
211 _ => Ok(None),
212 }
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;
    use zeph_memory::semantic::SemanticMemory;

    /// In-memory SQLite-backed semantic memory driven by a mock LLM provider.
    async fn make_memory() -> SemanticMemory {
        SemanticMemory::with_sqlite_backend(
            ":memory:",
            AnyProvider::Mock(MockProvider::default()),
            "test-model",
            0.7,
            0.3,
        )
        .await
        .unwrap()
    }

    /// Wraps `memory` in an executor bound to conversation 1.
    fn make_executor(memory: SemanticMemory) -> MemoryToolExecutor {
        MemoryToolExecutor::new(Arc::new(memory), ConversationId(1))
    }

    /// A fresh executor over a fresh in-memory store.
    async fn fresh_executor() -> MemoryToolExecutor {
        make_executor(make_memory().await)
    }

    /// Builds a `ToolCall` for `tool_id` whose params are the given string-valued pairs.
    fn string_call(tool_id: &str, pairs: &[(&str, String)]) -> ToolCall {
        let params = pairs
            .iter()
            .map(|(key, value)| ((*key).to_owned(), serde_json::Value::String(value.clone())))
            .collect::<serde_json::Map<String, serde_json::Value>>();
        ToolCall {
            tool_id: tool_id.to_owned(),
            params,
        }
    }

    #[tokio::test]
    async fn tool_definitions_returns_two_tools() {
        let defs = fresh_executor().await.tool_definitions();
        let [search, save] = defs.as_slice() else {
            panic!("expected exactly two tool definitions, got {}", defs.len());
        };
        assert_eq!(search.id.as_ref(), "memory_search");
        assert_eq!(save.id.as_ref(), "memory_save");
    }

    #[tokio::test]
    async fn execute_always_returns_none() {
        let executor = fresh_executor().await;
        assert!(executor.execute("any response").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn execute_tool_call_unknown_returns_none() {
        let executor = fresh_executor().await;
        let call = string_call("unknown_tool", &[]);
        assert!(executor.execute_tool_call(&call).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn memory_search_returns_output() {
        let executor = fresh_executor().await;
        let call = string_call("memory_search", &[("query", "test query".to_owned())]);
        let output = executor
            .execute_tool_call(&call)
            .await
            .unwrap()
            .expect("memory_search should produce output");
        assert_eq!(output.tool_name, "memory_search");
        for heading in ["Recalled Messages", "Key Facts", "Session Summaries"] {
            assert!(output.summary.contains(heading), "missing section: {heading}");
        }
    }

    #[tokio::test]
    async fn memory_save_stores_and_returns_confirmation() {
        let memory = make_memory().await;
        let sqlite = memory.sqlite().clone();
        let cid = sqlite.create_conversation().await.unwrap();
        let executor = MemoryToolExecutor::new(Arc::new(memory), cid);

        let call = string_call(
            "memory_save",
            &[("content", "User prefers dark mode".to_owned())],
        );
        let output = executor
            .execute_tool_call(&call)
            .await
            .unwrap()
            .expect("memory_save should produce output");
        assert!(output.summary.contains("Saved to memory"));
        assert!(output.summary.contains("message_id:"));
    }

    #[tokio::test]
    async fn memory_save_empty_content_returns_error() {
        let executor = fresh_executor().await;
        let call = string_call("memory_save", &[("content", String::new())]);
        assert!(executor.execute_tool_call(&call).await.is_err());
    }

    #[tokio::test]
    async fn memory_save_oversized_content_returns_error() {
        let executor = fresh_executor().await;
        let call = string_call("memory_save", &[("content", "x".repeat(4097))]);
        assert!(executor.execute_tool_call(&call).await.is_err());
    }
}