Skip to main content

walrus_memory/
dispatch.rs

//! MemoryService — standalone graph-based memory service owning LanceDB
//! entity, relation, and journal storage with candle embeddings.
3
4use crate::{
5    config::MemoryConfig,
6    embedder::Embedder,
7    lance::{Direction, EntityRow, LanceStore, RelationRow},
8};
9use std::{path::Path, sync::Mutex};
10const MEMORY_PROMPT: &str = include_str!("../prompts/memory.md");
11
12/// Graph-based memory service owning LanceDB entity, relation, and journal storage.
13pub struct MemoryService {
14    pub lance: LanceStore,
15    pub embedder: Mutex<Embedder>,
16    pub auto_recall: bool,
17}
18
19impl MemoryService {
20    /// Create a new MemoryService, opening or creating the LanceDB database.
21    pub async fn open(memory_dir: impl AsRef<Path>, config: &MemoryConfig) -> anyhow::Result<Self> {
22        let memory_dir = memory_dir.as_ref();
23        tokio::fs::create_dir_all(memory_dir).await?;
24
25        // Load embedder first — needed for entity vector backfill during open.
26        let cache_dir = wcore::paths::CONFIG_DIR.join(".cache").join("huggingface");
27        let embedder = tokio::task::spawn_blocking(move || Embedder::load(&cache_dir)).await??;
28
29        let lance_dir = memory_dir.join("lance");
30        let embed_mutex = Mutex::new(embedder);
31        let lance = LanceStore::open(&lance_dir, |text| {
32            let mut emb = embed_mutex
33                .lock()
34                .map_err(|e| anyhow::anyhow!("embedder lock poisoned: {e}"))?;
35            emb.embed(text)
36        })
37        .await?;
38
39        Ok(Self {
40            lance,
41            embedder: embed_mutex,
42            auto_recall: config.auto_recall,
43        })
44    }
45
46    /// Generate an embedding vector for text. Runs candle inference in a blocking task.
47    pub async fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
48        let text = text.to_owned();
49        tokio::task::block_in_place(|| {
50            let mut embedder = self
51                .embedder
52                .lock()
53                .map_err(|e| anyhow::anyhow!("embedder lock poisoned: {e}"))?;
54            embedder.embed(&text)
55        })
56    }
57
58    /// Return the memory prompt to append to agent system prompts.
59    pub fn memory_prompt() -> &'static str {
60        MEMORY_PROMPT
61    }
62
63    // ── Tool dispatch methods ────────────────────────────────────────
64
65    /// Unified search: embed query → semantic entity search → 1-hop graph on top-3 → format.
66    ///
67    /// Shared by `dispatch_recall` (per query) and `handle_before_run` (auto-recall).
68    /// Returns `None` when no results are found.
69    pub async fn unified_search(&self, query: &str, limit: usize) -> Option<String> {
70        let vector = match self.embed(query).await {
71            Ok(v) => v,
72            Err(e) => {
73                tracing::warn!("embed failed for search: {e}");
74                return None;
75            }
76        };
77
78        let mut lines = Vec::new();
79
80        // Semantic entity search.
81        let entities = self
82            .lance
83            .search_entities_semantic(&vector, None, limit)
84            .await
85            .unwrap_or_default();
86        for e in &entities {
87            lines.push(format!("[{}] {}: {}", e.entity_type, e.key, e.value));
88        }
89
90        // 1-hop connections for top-3 matched entities.
91        for e in entities.iter().take(3) {
92            if let Ok(rels) = self
93                .lance
94                .find_connections(&e.id, None, Direction::Both, 5)
95                .await
96            {
97                for r in &rels {
98                    let line = format!("{} -[{}]-> {}", r.source, r.relation, r.target);
99                    if !lines.contains(&line) {
100                        lines.push(line);
101                    }
102                }
103            }
104        }
105
106        if lines.is_empty() {
107            None
108        } else {
109            Some(lines.join("\n"))
110        }
111    }
112
113    /// Dispatch the `recall` tool call — batch queries via `unified_search`.
114    pub async fn dispatch_recall(&self, args: &str) -> String {
115        let input: crate::tool::Recall = match serde_json::from_str(args) {
116            Ok(v) => v,
117            Err(e) => return format!("invalid arguments: {e}"),
118        };
119        if input.queries.is_empty() {
120            return "missing required field: queries".to_owned();
121        }
122        let limit = input.limit.unwrap_or(5) as usize;
123
124        let mut sections = Vec::new();
125        for query in &input.queries {
126            if query.is_empty() {
127                continue;
128            }
129            if let Some(result) = self.unified_search(query, limit).await {
130                sections.push(format!("## {query}\n{result}"));
131            }
132        }
133
134        if sections.is_empty() {
135            "no results found".to_owned()
136        } else {
137            sections.join("\n\n")
138        }
139    }
140
141    /// Dispatch the `extract` tool call — batch upsert entities + relations.
142    pub async fn dispatch_extract(&self, args: &str) -> String {
143        let input: crate::tool::Extract = match serde_json::from_str(args) {
144            Ok(v) => v,
145            Err(e) => return format!("invalid arguments: {e}"),
146        };
147
148        let mut results = Vec::new();
149
150        // Upsert entities.
151        for entity in &input.entities {
152            if entity.key.is_empty() {
153                results.push("skipped entity: empty key".to_owned());
154                continue;
155            }
156            let entity_type = entity.entity_type.as_deref().unwrap_or("fact");
157            let id = entity_id(entity_type, &entity.key);
158            let text = format!("{} {}", entity.key, entity.value);
159            let vector = match self.embed(&text).await {
160                Ok(v) => v,
161                Err(e) => {
162                    results.push(format!("failed to embed '{}': {e}", entity.key));
163                    continue;
164                }
165            };
166            let row = EntityRow {
167                id: &id,
168                entity_type,
169                key: &entity.key,
170                value: &entity.value,
171                vector,
172            };
173            match self.lance.upsert_entity(&row).await {
174                Ok(()) => results.push(format!("stored [{}] {}", entity_type, entity.key)),
175                Err(e) => results.push(format!("failed '{}': {e}", entity.key)),
176            }
177        }
178
179        // Upsert relations.
180        for rel in &input.relations {
181            if rel.source.is_empty() || rel.target.is_empty() || rel.relation.is_empty() {
182                results.push("skipped relation: empty field".to_owned());
183                continue;
184            }
185
186            // Look up source entity.
187            let source = match self.lance.find_entity_by_key(&rel.source).await {
188                Ok(Some(e)) => e,
189                Ok(None) => {
190                    results.push(format!("source not found: '{}'", rel.source));
191                    continue;
192                }
193                Err(e) => {
194                    results.push(format!("source lookup failed: {e}"));
195                    continue;
196                }
197            };
198
199            // Look up target entity.
200            let target = match self.lance.find_entity_by_key(&rel.target).await {
201                Ok(Some(e)) => e,
202                Ok(None) => {
203                    results.push(format!("target not found: '{}'", rel.target));
204                    continue;
205                }
206                Err(e) => {
207                    results.push(format!("target lookup failed: {e}"));
208                    continue;
209                }
210            };
211
212            let row = RelationRow {
213                source: &source.id,
214                relation: &rel.relation,
215                target: &target.id,
216            };
217            match self.lance.upsert_relation(&row).await {
218                Ok(()) => results.push(format!(
219                    "related: {} -[{}]-> {}",
220                    rel.source, rel.relation, rel.target
221                )),
222                Err(e) => results.push(format!("relation failed: {e}")),
223            }
224        }
225
226        if results.is_empty() {
227            "nothing to extract".to_owned()
228        } else {
229            results.join("\n")
230        }
231    }
232
233    /// Internal dispatch for storing a journal entry.
234    ///
235    /// Called by the agent loop after compaction — `args` is the raw summary text.
236    pub async fn dispatch_journal(&self, args: &str, agent: &str) -> String {
237        if args.is_empty() {
238            return "empty journal entry".to_owned();
239        }
240
241        let vector = match self.embed(args).await {
242            Ok(v) => v,
243            Err(e) => return format!("failed to embed journal: {e}"),
244        };
245
246        match self.lance.insert_journal(agent, args, vector).await {
247            Ok(()) => "journal entry stored".to_owned(),
248            Err(e) => format!("failed to store journal: {e}"),
249        }
250    }
251}
252
/// Compose an entity's storage ID in the form `{entity_type}:{key}`.
///
/// The result is used as the `id` of an `EntityRow`.
fn entity_id(entity_type: &str, key: &str) -> String {
    [entity_type, key].join(":")
}