Skip to main content

walrus_memory/cmd/
serve.rs

1//! Extension serve command — run walrus-memory as an extension service over UDS.
2
3use crate::{config::MemoryConfig, dispatch::MemoryService, tool};
4use std::path::Path;
5use wcore::protocol::{
6    PROTOCOL_VERSION,
7    codec::{read_message, write_message},
8    ext::{
9        AfterCompactCap, AfterRunCap, BeforeRunCap, BuildAgentCap, Capability, CompactCap,
10        EventObserverCap, ExtAfterCompactResult, ExtAfterRunResult, ExtBeforeRunResult,
11        ExtBuildAgentResult, ExtCompactResult, ExtConfigured, ExtError, ExtInferRequest, ExtReady,
12        ExtRequest, ExtResponse, ExtServiceQueryResult, ExtToolResult, ExtToolSchemas, InferCap,
13        QueryCap, SimpleMessage, ToolsList, capability, ext_request, ext_response,
14    },
15};
16
17const EXTRACT_PROMPT: &str = include_str!("../../prompts/extract.md");
18
19pub async fn run(socket: &Path) -> anyhow::Result<()> {
20    // Clean up stale socket from a previous run.
21    if socket.exists() {
22        let _ = std::fs::remove_file(socket);
23    }
24
25    let listener = tokio::net::UnixListener::bind(socket)?;
26    tracing::info!("memory service listening on {}", socket.display());
27
28    let (stream, _) = listener.accept().await?;
29    let (mut reader, mut writer) = stream.into_split();
30
31    // ── Hello → Ready ────────────────────────────────────────────────
32    let hello: ExtRequest = read_message(&mut reader).await?;
33    match hello.msg {
34        Some(ext_request::Msg::Hello(_)) => {}
35        other => anyhow::bail!("expected Hello, got {other:?}"),
36    }
37
38    let tool_names = vec!["recall".to_owned(), "extract".to_owned()];
39
40    let ready = ExtResponse {
41        msg: Some(ext_response::Msg::Ready(ExtReady {
42            version: PROTOCOL_VERSION.to_owned(),
43            service: "memory".to_owned(),
44            capabilities: vec![
45                Capability {
46                    cap: Some(capability::Cap::Tools(ToolsList { names: tool_names })),
47                },
48                Capability {
49                    cap: Some(capability::Cap::BuildAgent(BuildAgentCap {})),
50                },
51                Capability {
52                    cap: Some(capability::Cap::BeforeRun(BeforeRunCap {})),
53                },
54                Capability {
55                    cap: Some(capability::Cap::Compact(CompactCap {})),
56                },
57                Capability {
58                    cap: Some(capability::Cap::Query(QueryCap {})),
59                },
60                Capability {
61                    cap: Some(capability::Cap::EventObserver(EventObserverCap {})),
62                },
63                Capability {
64                    cap: Some(capability::Cap::AfterRun(AfterRunCap {})),
65                },
66                Capability {
67                    cap: Some(capability::Cap::AfterCompact(AfterCompactCap {})),
68                },
69                Capability {
70                    cap: Some(capability::Cap::Infer(InferCap {})),
71                },
72            ],
73        })),
74    };
75    write_message(&mut writer, &ready).await?;
76
77    // ── Configure → Configured ───────────────────────────────────────
78    let configure: ExtRequest = read_message(&mut reader).await?;
79    let config = match configure.msg {
80        Some(ext_request::Msg::Configure(c)) => {
81            if c.config.is_empty() {
82                MemoryConfig::default()
83            } else {
84                serde_json::from_str(&c.config).unwrap_or_else(|e| {
85                    tracing::warn!("invalid config, using defaults: {e}");
86                    MemoryConfig::default()
87                })
88            }
89        }
90        other => anyhow::bail!("expected Configure, got {other:?}"),
91    };
92    let configured = ExtResponse {
93        msg: Some(ext_response::Msg::Configured(ExtConfigured {})),
94    };
95    write_message(&mut writer, &configured).await?;
96
97    // ── RegisterTools → ToolSchemas ──────────────────────────────────
98    let register: ExtRequest = read_message(&mut reader).await?;
99    match register.msg {
100        Some(ext_request::Msg::RegisterTools(_)) => {}
101        other => anyhow::bail!("expected RegisterTools, got {other:?}"),
102    }
103
104    // Build the memory service before constructing dynamic tool schemas.
105    let memory_dir = wcore::paths::CONFIG_DIR.join("memory");
106    let svc = MemoryService::open(&memory_dir, &config).await?;
107
108    // All tools including internal `extract` (needed by infer_fulfill).
109    // Agent-visible filtering happens via BuildAgent response (tool_defs).
110    let tools = tool::all_tool_defs();
111    let schemas = ExtResponse {
112        msg: Some(ext_response::Msg::ToolSchemas(ExtToolSchemas { tools })),
113    };
114    write_message(&mut writer, &schemas).await?;
115    tracing::info!("handshake complete");
116
117    // ── Dispatch loop ────────────────────────────────────────────────
118    let mut clean_exit = false;
119    loop {
120        let req: ExtRequest = match read_message(&mut reader).await {
121            Ok(r) => r,
122            Err(wcore::protocol::codec::FrameError::ConnectionClosed) => {
123                tracing::warn!("daemon connection closed");
124                break;
125            }
126            Err(e) => {
127                tracing::error!("read error: {e}");
128                break;
129            }
130        };
131
132        let resp = match req.msg {
133            Some(ext_request::Msg::ToolCall(call)) => {
134                let result = dispatch_tool(&svc, &call.name, &call.args, &call.agent).await;
135                ExtResponse {
136                    msg: Some(ext_response::Msg::ToolResult(ExtToolResult { result })),
137                }
138            }
139            Some(ext_request::Msg::BuildAgent(ba)) => {
140                let result =
141                    handle_build_agent(&svc, &ba.name, &ba.description, &ba.system_prompt).await;
142                ExtResponse {
143                    msg: Some(ext_response::Msg::BuildAgentResult(result)),
144                }
145            }
146            Some(ext_request::Msg::BeforeRun(br)) => {
147                let result = handle_before_run(&svc, &br.history).await;
148                ExtResponse {
149                    msg: Some(ext_response::Msg::BeforeRunResult(result)),
150                }
151            }
152            Some(ext_request::Msg::AfterRun(ar)) => {
153                let conversation = build_conversation_summary(&ar.history);
154                // Store a journal entry — extraction moved to on_after_compact.
155                let _ = svc.dispatch_journal(&conversation, &ar.agent).await;
156                ExtResponse {
157                    msg: Some(ext_response::Msg::AfterRunResult(ExtAfterRunResult {})),
158                }
159            }
160            Some(ext_request::Msg::AfterCompact(ac)) => {
161                // Store journal from compact summary, then request extraction LLM loop.
162                let _ = svc.dispatch_journal(&ac.summary, &ac.agent).await;
163                let messages = extraction_messages_from(&ac.summary);
164                ExtResponse {
165                    msg: Some(ext_response::Msg::InferRequest(ExtInferRequest {
166                        messages,
167                    })),
168                }
169            }
170            Some(ext_request::Msg::InferResult(_)) => {
171                // Infer complete — extraction tool calls already dispatched.
172                ExtResponse {
173                    msg: Some(ext_response::Msg::AfterCompactResult(
174                        ExtAfterCompactResult {},
175                    )),
176                }
177            }
178            Some(ext_request::Msg::Compact(c)) => {
179                let addition = handle_compact(&svc, &c.agent).await;
180                ExtResponse {
181                    msg: Some(ext_response::Msg::CompactResult(ExtCompactResult {
182                        addition,
183                    })),
184                }
185            }
186            Some(ext_request::Msg::ServiceQuery(sq)) => {
187                let result = handle_service_query(&svc, &sq.query).await;
188                ExtResponse {
189                    msg: Some(ext_response::Msg::ServiceQueryResult(
190                        ExtServiceQueryResult { result },
191                    )),
192                }
193            }
194            Some(ext_request::Msg::Event(_)) => {
195                // Fire-and-forget — no response expected.
196                continue;
197            }
198            Some(ext_request::Msg::GetSchema(_)) => ExtResponse {
199                msg: Some(ext_response::Msg::Error(ExtError {
200                    message: "schema not yet implemented".into(),
201                })),
202            },
203            Some(ext_request::Msg::Shutdown(_)) => {
204                tracing::info!("shutdown requested");
205                clean_exit = true;
206                break;
207            }
208            other => ExtResponse {
209                msg: Some(ext_response::Msg::Error(ExtError {
210                    message: format!("unexpected request: {other:?}"),
211                })),
212            },
213        };
214
215        if let Err(e) = write_message(&mut writer, &resp).await {
216            tracing::error!("write error: {e}");
217            break;
218        }
219    }
220
221    // Clean up socket file.
222    let _ = std::fs::remove_file(socket);
223    if clean_exit {
224        Ok(())
225    } else {
226        anyhow::bail!("connection lost")
227    }
228}
229
230/// Dispatch a tool call to the appropriate MemoryService method.
231async fn dispatch_tool(svc: &MemoryService, name: &str, args: &str, _agent: &str) -> String {
232    match name {
233        "recall" => svc.dispatch_recall(args).await,
234        "extract" => svc.dispatch_extract(args).await,
235        _ => format!("unknown tool: {name}"),
236    }
237}
238
239/// Handle the BuildAgent lifecycle event.
240///
241/// Builds prompt additions: `<self>`, `<identity>`, `<profile>` blocks
242/// plus the memory prompt. Returns agent-visible tools only.
243async fn handle_build_agent(
244    svc: &MemoryService,
245    name: &str,
246    description: &str,
247    _system_prompt: &str,
248) -> ExtBuildAgentResult {
249    let lance = &svc.lance;
250
251    // Inject <self> block.
252    let mut buf = String::from("\n\n<self>\n");
253    buf.push_str(&format!("name: {name}\n"));
254    if !description.is_empty() {
255        buf.push_str(&format!("description: {description}\n"));
256    }
257    buf.push_str("</self>");
258
259    // Inject identity entities (shared across all agents).
260    if let Ok(identities) = lance.query_by_type("identity", 50).await
261        && !identities.is_empty()
262    {
263        buf.push_str("\n\n<identity>\n");
264        for e in &identities {
265            buf.push_str(&format!("- **{}**: {}\n", e.key, e.value));
266        }
267        buf.push_str("</identity>");
268    }
269
270    // Inject profile entities (shared across all agents).
271    if let Ok(profiles) = lance.query_by_type("profile", 50).await
272        && !profiles.is_empty()
273    {
274        buf.push_str("\n\n<profile>\n");
275        for e in &profiles {
276            buf.push_str(&format!("- **{}**: {}\n", e.key, e.value));
277        }
278        buf.push_str("</profile>");
279    }
280
281    // Append memory prompt.
282    buf.push_str(&format!("\n\n{}", MemoryService::memory_prompt()));
283
284    ExtBuildAgentResult {
285        prompt_addition: buf,
286        tools: tool::tool_defs(),
287    }
288}
289
290/// Handle the BeforeRun lifecycle event.
291///
292/// Auto-recalls relevant entities and graph connections based on
293/// the last user message via unified semantic search.
294async fn handle_before_run(svc: &MemoryService, history: &[SimpleMessage]) -> ExtBeforeRunResult {
295    if !svc.auto_recall {
296        return ExtBeforeRunResult {
297            messages: Vec::new(),
298        };
299    }
300
301    // Extract the last user message as the recall query.
302    let query = match history
303        .iter()
304        .rev()
305        .find(|m| m.role == "user")
306        .map(|m| &m.content)
307    {
308        Some(q) if q.len() >= 10 => q.clone(),
309        _ => {
310            return ExtBeforeRunResult {
311                messages: Vec::new(),
312            };
313        }
314    };
315
316    let result = match svc.unified_search(&query, 5).await {
317        Some(r) => r,
318        None => {
319            return ExtBeforeRunResult {
320                messages: Vec::new(),
321            };
322        }
323    };
324
325    let block = format!("<recall>\n{result}\n</recall>");
326    ExtBeforeRunResult {
327        messages: vec![SimpleMessage {
328            role: "user".to_owned(),
329            content: block,
330        }],
331    }
332}
333
334/// Handle a ServiceQuery — JSON-encoded query for list/search operations.
335///
336/// Supported query types (JSON):
337/// - `{"op": "entities", "entity_type": "...", "limit": N}`
338/// - `{"op": "relations", "entity_id": "...", "limit": N}`
339/// - `{"op": "journals", "agent": "...", "limit": N}`
340/// - `{"op": "search", "query": "...", "entity_type": "...", "limit": N}`
341async fn handle_service_query(svc: &MemoryService, query: &str) -> String {
342    let parsed: serde_json::Value = match serde_json::from_str(query) {
343        Ok(v) => v,
344        Err(e) => return format!("invalid query JSON: {e}"),
345    };
346
347    let op = parsed["op"].as_str().unwrap_or("");
348    let default_limit = 50usize;
349
350    match op {
351        "entities" => {
352            let entity_type = parsed["entity_type"].as_str();
353            let limit = parsed["limit"]
354                .as_u64()
355                .map(|l| l as usize)
356                .unwrap_or(default_limit);
357            match svc.lance.list_entities(entity_type, limit).await {
358                Ok(entities) => {
359                    let items: Vec<serde_json::Value> = entities
360                        .iter()
361                        .map(|e| {
362                            serde_json::json!({
363                                "entity_type": e.entity_type,
364                                "key": e.key,
365                                "value": e.value,
366                                "created_at": e.created_at,
367                            })
368                        })
369                        .collect();
370                    serde_json::to_string(&items)
371                        .unwrap_or_else(|e| format!("serialize error: {e}"))
372                }
373                Err(e) => format!("entities query failed: {e}"),
374            }
375        }
376        "relations" => {
377            let entity_id = parsed["entity_id"].as_str();
378            let limit = parsed["limit"]
379                .as_u64()
380                .map(|l| l as usize)
381                .unwrap_or(default_limit);
382            match svc.lance.list_relations(entity_id, limit).await {
383                Ok(relations) => {
384                    let items: Vec<serde_json::Value> = relations
385                        .iter()
386                        .map(|r| {
387                            serde_json::json!({
388                                "source": r.source,
389                                "relation": r.relation,
390                                "target": r.target,
391                                "created_at": r.created_at,
392                            })
393                        })
394                        .collect();
395                    serde_json::to_string(&items)
396                        .unwrap_or_else(|e| format!("serialize error: {e}"))
397                }
398                Err(e) => format!("relations query failed: {e}"),
399            }
400        }
401        "journals" => {
402            let agent = parsed["agent"].as_str();
403            let limit = parsed["limit"]
404                .as_u64()
405                .map(|l| l as usize)
406                .unwrap_or(default_limit);
407            match svc.lance.list_journals(agent, limit).await {
408                Ok(journals) => {
409                    let items: Vec<serde_json::Value> = journals
410                        .iter()
411                        .map(|j| {
412                            serde_json::json!({
413                                "summary": j.summary,
414                                "agent": j.agent,
415                                "created_at": j.created_at,
416                            })
417                        })
418                        .collect();
419                    serde_json::to_string(&items)
420                        .unwrap_or_else(|e| format!("serialize error: {e}"))
421                }
422                Err(e) => format!("journals query failed: {e}"),
423            }
424        }
425        "search" => {
426            let query_str = parsed["query"].as_str().unwrap_or("");
427            let entity_type = parsed["entity_type"].as_str();
428            let limit = parsed["limit"]
429                .as_u64()
430                .map(|l| l as usize)
431                .unwrap_or(default_limit);
432            match svc
433                .lance
434                .search_entities(query_str, entity_type, limit)
435                .await
436            {
437                Ok(entities) => {
438                    let items: Vec<serde_json::Value> = entities
439                        .iter()
440                        .map(|e| {
441                            serde_json::json!({
442                                "entity_type": e.entity_type,
443                                "key": e.key,
444                                "value": e.value,
445                                "created_at": e.created_at,
446                            })
447                        })
448                        .collect();
449                    serde_json::to_string(&items)
450                        .unwrap_or_else(|e| format!("serialize error: {e}"))
451                }
452                Err(e) => format!("search query failed: {e}"),
453            }
454        }
455        _ => format!("unknown op: '{op}'. supported: entities, relations, journals, search"),
456    }
457}
458
459/// Handle the Compact lifecycle event — inject recent journals into the prompt.
460async fn handle_compact(svc: &MemoryService, agent: &str) -> String {
461    let mut addition = String::new();
462    if let Ok(journals) = svc.lance.recent_journals(agent, 3).await
463        && !journals.is_empty()
464    {
465        addition.push_str("\n\nRecent conversation journals (preserve key context):\n");
466        for j in &journals {
467            let ts = chrono::DateTime::from_timestamp(j.created_at as i64, 0)
468                .map(|dt| dt.format("%Y-%m-%d %H:%M").to_string())
469                .unwrap_or_else(|| j.created_at.to_string());
470            addition.push_str(&format!("- [{ts}] {}\n", j.summary));
471        }
472    }
473    addition
474}
475
476/// Build a condensed conversation summary from history, skipping recall
477/// injections and tool messages.
478fn build_conversation_summary(history: &[SimpleMessage]) -> String {
479    let mut conversation = String::new();
480    for msg in history {
481        let role = msg.role.as_str();
482        if msg.content.starts_with("<recall>") || role == "tool" {
483            continue;
484        }
485        conversation.push_str(&format!("[{role}] {}\n\n", msg.content));
486    }
487    conversation
488}
489
490/// Wrap a conversation summary into extraction messages for the Infer LLM.
491fn extraction_messages_from(conversation: &str) -> Vec<SimpleMessage> {
492    vec![
493        SimpleMessage {
494            role: "system".to_owned(),
495            content: EXTRACT_PROMPT.to_owned(),
496        },
497        SimpleMessage {
498            role: "user".to_owned(),
499            content: format!("Extract memories from this conversation:\n\n{conversation}"),
500        },
501    ]
502}