1use crate::{config::MemoryConfig, dispatch::MemoryService, tool};
4use std::path::Path;
5use wcore::protocol::{
6 PROTOCOL_VERSION,
7 codec::{read_message, write_message},
8 ext::{
9 AfterCompactCap, AfterRunCap, BeforeRunCap, BuildAgentCap, Capability, CompactCap,
10 EventObserverCap, ExtAfterCompactResult, ExtAfterRunResult, ExtBeforeRunResult,
11 ExtBuildAgentResult, ExtCompactResult, ExtConfigured, ExtError, ExtInferRequest, ExtReady,
12 ExtRequest, ExtResponse, ExtServiceQueryResult, ExtToolResult, ExtToolSchemas, InferCap,
13 QueryCap, SimpleMessage, ToolsList, capability, ext_request, ext_response,
14 },
15};
16
/// System prompt for memory extraction, embedded at compile time from
/// `prompts/extract.md`. Used as the system message built by
/// `extraction_messages_from` for the post-compaction `InferRequest`.
const EXTRACT_PROMPT: &str = include_str!("../../prompts/extract.md");
18
19pub async fn run(socket: &Path) -> anyhow::Result<()> {
20 if socket.exists() {
22 let _ = std::fs::remove_file(socket);
23 }
24
25 let listener = tokio::net::UnixListener::bind(socket)?;
26 tracing::info!("memory service listening on {}", socket.display());
27
28 let (stream, _) = listener.accept().await?;
29 let (mut reader, mut writer) = stream.into_split();
30
31 let hello: ExtRequest = read_message(&mut reader).await?;
33 match hello.msg {
34 Some(ext_request::Msg::Hello(_)) => {}
35 other => anyhow::bail!("expected Hello, got {other:?}"),
36 }
37
38 let tool_names = vec!["recall".to_owned(), "extract".to_owned()];
39
40 let ready = ExtResponse {
41 msg: Some(ext_response::Msg::Ready(ExtReady {
42 version: PROTOCOL_VERSION.to_owned(),
43 service: "memory".to_owned(),
44 capabilities: vec![
45 Capability {
46 cap: Some(capability::Cap::Tools(ToolsList { names: tool_names })),
47 },
48 Capability {
49 cap: Some(capability::Cap::BuildAgent(BuildAgentCap {})),
50 },
51 Capability {
52 cap: Some(capability::Cap::BeforeRun(BeforeRunCap {})),
53 },
54 Capability {
55 cap: Some(capability::Cap::Compact(CompactCap {})),
56 },
57 Capability {
58 cap: Some(capability::Cap::Query(QueryCap {})),
59 },
60 Capability {
61 cap: Some(capability::Cap::EventObserver(EventObserverCap {})),
62 },
63 Capability {
64 cap: Some(capability::Cap::AfterRun(AfterRunCap {})),
65 },
66 Capability {
67 cap: Some(capability::Cap::AfterCompact(AfterCompactCap {})),
68 },
69 Capability {
70 cap: Some(capability::Cap::Infer(InferCap {})),
71 },
72 ],
73 })),
74 };
75 write_message(&mut writer, &ready).await?;
76
77 let configure: ExtRequest = read_message(&mut reader).await?;
79 let config = match configure.msg {
80 Some(ext_request::Msg::Configure(c)) => {
81 if c.config.is_empty() {
82 MemoryConfig::default()
83 } else {
84 serde_json::from_str(&c.config).unwrap_or_else(|e| {
85 tracing::warn!("invalid config, using defaults: {e}");
86 MemoryConfig::default()
87 })
88 }
89 }
90 other => anyhow::bail!("expected Configure, got {other:?}"),
91 };
92 let configured = ExtResponse {
93 msg: Some(ext_response::Msg::Configured(ExtConfigured {})),
94 };
95 write_message(&mut writer, &configured).await?;
96
97 let register: ExtRequest = read_message(&mut reader).await?;
99 match register.msg {
100 Some(ext_request::Msg::RegisterTools(_)) => {}
101 other => anyhow::bail!("expected RegisterTools, got {other:?}"),
102 }
103
104 let memory_dir = wcore::paths::CONFIG_DIR.join("memory");
106 let svc = MemoryService::open(&memory_dir, &config).await?;
107
108 let tools = tool::all_tool_defs();
111 let schemas = ExtResponse {
112 msg: Some(ext_response::Msg::ToolSchemas(ExtToolSchemas { tools })),
113 };
114 write_message(&mut writer, &schemas).await?;
115 tracing::info!("handshake complete");
116
117 let mut clean_exit = false;
119 loop {
120 let req: ExtRequest = match read_message(&mut reader).await {
121 Ok(r) => r,
122 Err(wcore::protocol::codec::FrameError::ConnectionClosed) => {
123 tracing::warn!("daemon connection closed");
124 break;
125 }
126 Err(e) => {
127 tracing::error!("read error: {e}");
128 break;
129 }
130 };
131
132 let resp = match req.msg {
133 Some(ext_request::Msg::ToolCall(call)) => {
134 let result = dispatch_tool(&svc, &call.name, &call.args, &call.agent).await;
135 ExtResponse {
136 msg: Some(ext_response::Msg::ToolResult(ExtToolResult { result })),
137 }
138 }
139 Some(ext_request::Msg::BuildAgent(ba)) => {
140 let result =
141 handle_build_agent(&svc, &ba.name, &ba.description, &ba.system_prompt).await;
142 ExtResponse {
143 msg: Some(ext_response::Msg::BuildAgentResult(result)),
144 }
145 }
146 Some(ext_request::Msg::BeforeRun(br)) => {
147 let result = handle_before_run(&svc, &br.history).await;
148 ExtResponse {
149 msg: Some(ext_response::Msg::BeforeRunResult(result)),
150 }
151 }
152 Some(ext_request::Msg::AfterRun(ar)) => {
153 let conversation = build_conversation_summary(&ar.history);
154 let _ = svc.dispatch_journal(&conversation, &ar.agent).await;
156 ExtResponse {
157 msg: Some(ext_response::Msg::AfterRunResult(ExtAfterRunResult {})),
158 }
159 }
160 Some(ext_request::Msg::AfterCompact(ac)) => {
161 let _ = svc.dispatch_journal(&ac.summary, &ac.agent).await;
163 let messages = extraction_messages_from(&ac.summary);
164 ExtResponse {
165 msg: Some(ext_response::Msg::InferRequest(ExtInferRequest {
166 messages,
167 })),
168 }
169 }
170 Some(ext_request::Msg::InferResult(_)) => {
171 ExtResponse {
173 msg: Some(ext_response::Msg::AfterCompactResult(
174 ExtAfterCompactResult {},
175 )),
176 }
177 }
178 Some(ext_request::Msg::Compact(c)) => {
179 let addition = handle_compact(&svc, &c.agent).await;
180 ExtResponse {
181 msg: Some(ext_response::Msg::CompactResult(ExtCompactResult {
182 addition,
183 })),
184 }
185 }
186 Some(ext_request::Msg::ServiceQuery(sq)) => {
187 let result = handle_service_query(&svc, &sq.query).await;
188 ExtResponse {
189 msg: Some(ext_response::Msg::ServiceQueryResult(
190 ExtServiceQueryResult { result },
191 )),
192 }
193 }
194 Some(ext_request::Msg::Event(_)) => {
195 continue;
197 }
198 Some(ext_request::Msg::GetSchema(_)) => ExtResponse {
199 msg: Some(ext_response::Msg::Error(ExtError {
200 message: "schema not yet implemented".into(),
201 })),
202 },
203 Some(ext_request::Msg::Shutdown(_)) => {
204 tracing::info!("shutdown requested");
205 clean_exit = true;
206 break;
207 }
208 other => ExtResponse {
209 msg: Some(ext_response::Msg::Error(ExtError {
210 message: format!("unexpected request: {other:?}"),
211 })),
212 },
213 };
214
215 if let Err(e) = write_message(&mut writer, &resp).await {
216 tracing::error!("write error: {e}");
217 break;
218 }
219 }
220
221 let _ = std::fs::remove_file(socket);
223 if clean_exit {
224 Ok(())
225 } else {
226 anyhow::bail!("connection lost")
227 }
228}
229
230async fn dispatch_tool(svc: &MemoryService, name: &str, args: &str, _agent: &str) -> String {
232 match name {
233 "recall" => svc.dispatch_recall(args).await,
234 "extract" => svc.dispatch_extract(args).await,
235 _ => format!("unknown tool: {name}"),
236 }
237}
238
239async fn handle_build_agent(
244 svc: &MemoryService,
245 name: &str,
246 description: &str,
247 _system_prompt: &str,
248) -> ExtBuildAgentResult {
249 let lance = &svc.lance;
250
251 let mut buf = String::from("\n\n<self>\n");
253 buf.push_str(&format!("name: {name}\n"));
254 if !description.is_empty() {
255 buf.push_str(&format!("description: {description}\n"));
256 }
257 buf.push_str("</self>");
258
259 if let Ok(identities) = lance.query_by_type("identity", 50).await
261 && !identities.is_empty()
262 {
263 buf.push_str("\n\n<identity>\n");
264 for e in &identities {
265 buf.push_str(&format!("- **{}**: {}\n", e.key, e.value));
266 }
267 buf.push_str("</identity>");
268 }
269
270 if let Ok(profiles) = lance.query_by_type("profile", 50).await
272 && !profiles.is_empty()
273 {
274 buf.push_str("\n\n<profile>\n");
275 for e in &profiles {
276 buf.push_str(&format!("- **{}**: {}\n", e.key, e.value));
277 }
278 buf.push_str("</profile>");
279 }
280
281 buf.push_str(&format!("\n\n{}", MemoryService::memory_prompt()));
283
284 ExtBuildAgentResult {
285 prompt_addition: buf,
286 tools: tool::tool_defs(),
287 }
288}
289
290async fn handle_before_run(svc: &MemoryService, history: &[SimpleMessage]) -> ExtBeforeRunResult {
295 if !svc.auto_recall {
296 return ExtBeforeRunResult {
297 messages: Vec::new(),
298 };
299 }
300
301 let query = match history
303 .iter()
304 .rev()
305 .find(|m| m.role == "user")
306 .map(|m| &m.content)
307 {
308 Some(q) if q.len() >= 10 => q.clone(),
309 _ => {
310 return ExtBeforeRunResult {
311 messages: Vec::new(),
312 };
313 }
314 };
315
316 let result = match svc.unified_search(&query, 5).await {
317 Some(r) => r,
318 None => {
319 return ExtBeforeRunResult {
320 messages: Vec::new(),
321 };
322 }
323 };
324
325 let block = format!("<recall>\n{result}\n</recall>");
326 ExtBeforeRunResult {
327 messages: vec![SimpleMessage {
328 role: "user".to_owned(),
329 content: block,
330 }],
331 }
332}
333
334async fn handle_service_query(svc: &MemoryService, query: &str) -> String {
342 let parsed: serde_json::Value = match serde_json::from_str(query) {
343 Ok(v) => v,
344 Err(e) => return format!("invalid query JSON: {e}"),
345 };
346
347 let op = parsed["op"].as_str().unwrap_or("");
348 let default_limit = 50usize;
349
350 match op {
351 "entities" => {
352 let entity_type = parsed["entity_type"].as_str();
353 let limit = parsed["limit"]
354 .as_u64()
355 .map(|l| l as usize)
356 .unwrap_or(default_limit);
357 match svc.lance.list_entities(entity_type, limit).await {
358 Ok(entities) => {
359 let items: Vec<serde_json::Value> = entities
360 .iter()
361 .map(|e| {
362 serde_json::json!({
363 "entity_type": e.entity_type,
364 "key": e.key,
365 "value": e.value,
366 "created_at": e.created_at,
367 })
368 })
369 .collect();
370 serde_json::to_string(&items)
371 .unwrap_or_else(|e| format!("serialize error: {e}"))
372 }
373 Err(e) => format!("entities query failed: {e}"),
374 }
375 }
376 "relations" => {
377 let entity_id = parsed["entity_id"].as_str();
378 let limit = parsed["limit"]
379 .as_u64()
380 .map(|l| l as usize)
381 .unwrap_or(default_limit);
382 match svc.lance.list_relations(entity_id, limit).await {
383 Ok(relations) => {
384 let items: Vec<serde_json::Value> = relations
385 .iter()
386 .map(|r| {
387 serde_json::json!({
388 "source": r.source,
389 "relation": r.relation,
390 "target": r.target,
391 "created_at": r.created_at,
392 })
393 })
394 .collect();
395 serde_json::to_string(&items)
396 .unwrap_or_else(|e| format!("serialize error: {e}"))
397 }
398 Err(e) => format!("relations query failed: {e}"),
399 }
400 }
401 "journals" => {
402 let agent = parsed["agent"].as_str();
403 let limit = parsed["limit"]
404 .as_u64()
405 .map(|l| l as usize)
406 .unwrap_or(default_limit);
407 match svc.lance.list_journals(agent, limit).await {
408 Ok(journals) => {
409 let items: Vec<serde_json::Value> = journals
410 .iter()
411 .map(|j| {
412 serde_json::json!({
413 "summary": j.summary,
414 "agent": j.agent,
415 "created_at": j.created_at,
416 })
417 })
418 .collect();
419 serde_json::to_string(&items)
420 .unwrap_or_else(|e| format!("serialize error: {e}"))
421 }
422 Err(e) => format!("journals query failed: {e}"),
423 }
424 }
425 "search" => {
426 let query_str = parsed["query"].as_str().unwrap_or("");
427 let entity_type = parsed["entity_type"].as_str();
428 let limit = parsed["limit"]
429 .as_u64()
430 .map(|l| l as usize)
431 .unwrap_or(default_limit);
432 match svc
433 .lance
434 .search_entities(query_str, entity_type, limit)
435 .await
436 {
437 Ok(entities) => {
438 let items: Vec<serde_json::Value> = entities
439 .iter()
440 .map(|e| {
441 serde_json::json!({
442 "entity_type": e.entity_type,
443 "key": e.key,
444 "value": e.value,
445 "created_at": e.created_at,
446 })
447 })
448 .collect();
449 serde_json::to_string(&items)
450 .unwrap_or_else(|e| format!("serialize error: {e}"))
451 }
452 Err(e) => format!("search query failed: {e}"),
453 }
454 }
455 _ => format!("unknown op: '{op}'. supported: entities, relations, journals, search"),
456 }
457}
458
459async fn handle_compact(svc: &MemoryService, agent: &str) -> String {
461 let mut addition = String::new();
462 if let Ok(journals) = svc.lance.recent_journals(agent, 3).await
463 && !journals.is_empty()
464 {
465 addition.push_str("\n\nRecent conversation journals (preserve key context):\n");
466 for j in &journals {
467 let ts = chrono::DateTime::from_timestamp(j.created_at as i64, 0)
468 .map(|dt| dt.format("%Y-%m-%d %H:%M").to_string())
469 .unwrap_or_else(|| j.created_at.to_string());
470 addition.push_str(&format!("- [{ts}] {}\n", j.summary));
471 }
472 }
473 addition
474}
475
476fn build_conversation_summary(history: &[SimpleMessage]) -> String {
479 let mut conversation = String::new();
480 for msg in history {
481 let role = msg.role.as_str();
482 if msg.content.starts_with("<recall>") || role == "tool" {
483 continue;
484 }
485 conversation.push_str(&format!("[{role}] {}\n\n", msg.content));
486 }
487 conversation
488}
489
490fn extraction_messages_from(conversation: &str) -> Vec<SimpleMessage> {
492 vec![
493 SimpleMessage {
494 role: "system".to_owned(),
495 content: EXTRACT_PROMPT.to_owned(),
496 },
497 SimpleMessage {
498 role: "user".to_owned(),
499 content: format!("Extract memories from this conversation:\n\n{conversation}"),
500 },
501 ]
502}