1use std::fmt::Write as _;
5use std::sync::Arc;
6
7use zeph_memory::embedding_store::SearchFilter;
8use zeph_memory::semantic::SemanticMemory;
9use zeph_memory::types::ConversationId;
10use zeph_tools::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
11use zeph_tools::registry::{InvocationHint, ToolDef};
12
/// Parameters accepted by the `memory_search` tool call.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySearchParams {
    // Natural language search query (required).
    query: String,
    // Max results per category; defaults to 5 and is clamped to 1..=20 at execution time.
    #[serde(default = "default_limit")]
    limit: u32,
}
21
/// Serde default for `MemorySearchParams::limit` when the caller omits it.
fn default_limit() -> u32 { 5 }
25
/// Parameters accepted by the `memory_save` tool call.
#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
struct MemorySaveParams {
    // Fact or note to persist (required); validated non-empty and length-limited at execution time.
    content: String,
    // Role label stored with the memory; defaults to "assistant".
    #[serde(default = "default_role")]
    role: String,
}
34
/// Serde default for `MemorySaveParams::role` when the caller omits it.
fn default_role() -> String {
    String::from("assistant")
}
38
/// Tool executor exposing long-term memory operations (`memory_search`,
/// `memory_save`) to the tool-calling layer.
pub struct MemoryToolExecutor {
    // Shared handle to the semantic memory backend.
    memory: Arc<SemanticMemory>,
    // Conversation used to scope recall and as the target of saves.
    conversation_id: ConversationId,
}
43
44impl MemoryToolExecutor {
45 #[must_use]
46 pub fn new(memory: Arc<SemanticMemory>, conversation_id: ConversationId) -> Self {
47 Self {
48 memory,
49 conversation_id,
50 }
51 }
52}
53
54impl ToolExecutor for MemoryToolExecutor {
55 fn tool_definitions(&self) -> Vec<ToolDef> {
56 vec![
57 ToolDef {
58 id: "memory_search".into(),
59 description: "Search long-term memory for relevant past messages, facts, and session summaries. Use when the user references past conversations or you need historical context.\n\nParameters: query (string, required) - natural language search query; limit (integer, optional) - max results 1-20 (default: 5)\nReturns: ranked list of memory entries with similarity scores and timestamps\nErrors: Execution on database failure\nExample: {\"query\": \"user preference for output format\", \"limit\": 5}".into(),
60 schema: schemars::schema_for!(MemorySearchParams),
61 invocation: InvocationHint::ToolCall,
62 },
63 ToolDef {
64 id: "memory_save".into(),
65 description: "Save a fact or note to long-term memory for cross-session recall. Use sparingly for key decisions, user preferences, or critical context worth remembering across sessions.\n\nParameters: content (string, required) - concise, self-contained fact or note; role (string, optional) - message role label (default: \"assistant\")\nReturns: confirmation with saved entry ID\nErrors: Execution on database failure; InvalidParams if content is empty\nExample: {\"content\": \"User prefers JSON output over YAML\", \"role\": \"assistant\"}".into(),
66 schema: schemars::schema_for!(MemorySaveParams),
67 invocation: InvocationHint::ToolCall,
68 },
69 ]
70 }
71
72 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
73 Ok(None)
74 }
75
76 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
77 match call.tool_id.as_str() {
78 "memory_search" => {
79 let params: MemorySearchParams = deserialize_params(&call.params)?;
80 let limit = params.limit.clamp(1, 20) as usize;
81
82 let filter = Some(SearchFilter {
83 conversation_id: Some(self.conversation_id),
84 role: None,
85 });
86
87 let recalled = self
88 .memory
89 .recall(¶ms.query, limit, filter)
90 .await
91 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
92
93 let key_facts = self
94 .memory
95 .search_key_facts(¶ms.query, limit)
96 .await
97 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
98
99 let summaries = self
100 .memory
101 .search_session_summaries(¶ms.query, limit, Some(self.conversation_id))
102 .await
103 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
104
105 let mut output = String::new();
106
107 let _ = writeln!(output, "## Recalled Messages ({} results)", recalled.len());
108 for r in &recalled {
109 let role = match r.message.role {
110 zeph_llm::provider::Role::User => "user",
111 zeph_llm::provider::Role::Assistant => "assistant",
112 zeph_llm::provider::Role::System => "system",
113 };
114 let content = r.message.content.trim();
115 let _ = writeln!(output, "[score: {:.2}] {role}: {content}", r.score);
116 }
117
118 let _ = writeln!(output);
119 let _ = writeln!(output, "## Key Facts ({} results)", key_facts.len());
120 for fact in &key_facts {
121 let _ = writeln!(output, "- {fact}");
122 }
123
124 let _ = writeln!(output);
125 let _ = writeln!(output, "## Session Summaries ({} results)", summaries.len());
126 for s in &summaries {
127 let _ = writeln!(
128 output,
129 "[conv #{}, score: {:.2}] {}",
130 s.conversation_id, s.score, s.summary_text
131 );
132 }
133
134 Ok(Some(ToolOutput {
135 tool_name: "memory_search".to_owned(),
136 summary: output,
137 blocks_executed: 1,
138 filter_stats: None,
139 diff: None,
140 streamed: false,
141 terminal_id: None,
142 locations: None,
143 raw_response: None,
144 }))
145 }
146 "memory_save" => {
147 let params: MemorySaveParams = deserialize_params(&call.params)?;
148
149 if params.content.is_empty() {
150 return Err(ToolError::InvalidParams {
151 message: "content must not be empty".to_owned(),
152 });
153 }
154 if params.content.len() > 4096 {
155 return Err(ToolError::InvalidParams {
156 message: "content exceeds maximum length of 4096 characters".to_owned(),
157 });
158 }
159
160 let role = params.role.as_str();
161
162 let message_id = self
163 .memory
164 .remember(self.conversation_id, role, ¶ms.content)
165 .await
166 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
167
168 Ok(Some(ToolOutput {
169 tool_name: "memory_save".to_owned(),
170 summary: format!(
171 "Saved to memory (message_id: {message_id}, conversation: {}). Content will be available for future recall.",
172 self.conversation_id
173 ),
174 blocks_executed: 1,
175 filter_stats: None,
176 diff: None,
177 streamed: false,
178 terminal_id: None,
179 locations: None,
180 raw_response: None,
181 }))
182 }
183 _ => Ok(None),
184 }
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;
    use zeph_memory::semantic::SemanticMemory;

    /// Builds an in-memory semantic store backed by the mock provider.
    async fn make_memory() -> SemanticMemory {
        SemanticMemory::with_sqlite_backend(
            ":memory:",
            AnyProvider::Mock(MockProvider::default()),
            "test-model",
            0.7,
            0.3,
        )
        .await
        .unwrap()
    }

    /// Wraps `memory` in an executor pinned to conversation 1.
    fn make_executor(memory: SemanticMemory) -> MemoryToolExecutor {
        MemoryToolExecutor::new(Arc::new(memory), ConversationId(1))
    }

    /// Builds a `ToolCall` whose params are the given string key/value pairs.
    fn string_call(tool_id: &str, entries: &[(&str, &str)]) -> ToolCall {
        let params = entries
            .iter()
            .map(|(k, v)| ((*k).to_owned(), serde_json::Value::String((*v).to_owned())))
            .collect();
        ToolCall {
            tool_id: tool_id.to_owned(),
            params,
        }
    }

    #[tokio::test]
    async fn tool_definitions_returns_two_tools() {
        let executor = make_executor(make_memory().await);
        let ids: Vec<&str> = executor
            .tool_definitions()
            .iter()
            .map(|d| d.id.as_ref())
            .collect::<Vec<_>>();
        assert_eq!(ids, ["memory_search", "memory_save"]);
    }

    #[tokio::test]
    async fn execute_always_returns_none() {
        let executor = make_executor(make_memory().await);
        assert!(executor.execute("any response").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn execute_tool_call_unknown_returns_none() {
        let executor = make_executor(make_memory().await);
        let call = string_call("unknown_tool", &[]);
        assert!(executor.execute_tool_call(&call).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn memory_search_returns_output() {
        let executor = make_executor(make_memory().await);
        let call = string_call("memory_search", &[("query", "test query")]);
        let output = executor
            .execute_tool_call(&call)
            .await
            .unwrap()
            .expect("search should produce output");
        assert_eq!(output.tool_name, "memory_search");
        for heading in ["Recalled Messages", "Key Facts", "Session Summaries"] {
            assert!(output.summary.contains(heading));
        }
    }

    #[tokio::test]
    async fn memory_save_stores_and_returns_confirmation() {
        let memory = make_memory().await;
        let cid = memory.sqlite().clone().create_conversation().await.unwrap();
        let executor = MemoryToolExecutor::new(Arc::new(memory), cid);

        let call = string_call("memory_save", &[("content", "User prefers dark mode")]);
        let output = executor
            .execute_tool_call(&call)
            .await
            .unwrap()
            .expect("save should produce output");
        assert!(output.summary.contains("Saved to memory"));
        assert!(output.summary.contains("message_id:"));
    }

    #[tokio::test]
    async fn memory_save_empty_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let call = string_call("memory_save", &[("content", "")]);
        assert!(executor.execute_tool_call(&call).await.is_err());
    }

    #[tokio::test]
    async fn memory_save_oversized_content_returns_error() {
        let executor = make_executor(make_memory().await);
        let oversized = "x".repeat(4097);
        let call = string_call("memory_save", &[("content", &oversized)]);
        assert!(executor.execute_tool_call(&call).await.is_err());
    }
}