Skip to main content

walrus_daemon/hook/memory/
mod.rs

1//! Graph-based memory hook — owns LanceDB with entities, relations, and
2//! journals tables. Registers `remember`, `recall`, `relate`, `connections`,
3//! `compact`, and `distill` tool schemas. Journals store compaction summaries
4//! with vector embeddings for semantic search via candle (all-MiniLM-L6-v2).
5
6pub use config::MemoryConfig;
7use embedder::Embedder;
8use lance::LanceStore;
9use std::path::Path;
10use std::sync::Mutex;
11use wcore::{
12    AgentConfig, Hook, ToolRegistry,
13    agent::AsTool,
14    model::{Message, Role, Tool},
15    paths::CONFIG_DIR,
16};
17
18pub mod config;
19pub(crate) mod dispatch;
20pub(crate) mod embedder;
21pub(crate) mod lance;
22pub(crate) mod tool;
23
24const MEMORY_PROMPT: &str = include_str!("../../../prompts/memory.md");
25
/// Entity types that are always accepted, regardless of configuration.
///
/// Extra types from `MemoryConfig::entities` are appended after these (see
/// `merge_defaults`). The order is preserved in the `remember` tool
/// description. `identity` and `profile` are the two types that
/// `on_build_agent` injects into the system prompt.
const DEFAULT_ENTITIES: &[&str] = &[
    "fact",
    "preference",
    "person",
    "event",
    "concept",
    "identity",
    "profile",
];
36
/// Relation types that are always accepted, regardless of configuration.
///
/// Extra types from `MemoryConfig::relations` are appended after these (see
/// `merge_defaults`). The order is preserved in the `relate` tool description.
const DEFAULT_RELATIONS: &[&str] = &[
    "knows",
    "prefers",
    "related_to",
    "caused_by",
    "part_of",
    "depends_on",
    "tagged_with",
];
47
48/// Graph-based memory hook owning LanceDB entity, relation, and journal storage.
49pub struct MemoryHook {
50    pub(crate) lance: LanceStore,
51    pub(crate) embedder: Mutex<Embedder>,
52    pub(crate) allowed_entities: Vec<String>,
53    pub(crate) allowed_relations: Vec<String>,
54    pub(crate) connection_limit: usize,
55    pub(crate) auto_recall: bool,
56}
57
58impl MemoryHook {
59    /// Create a new MemoryHook, opening or creating the LanceDB database.
60    pub async fn open(memory_dir: impl AsRef<Path>, config: &MemoryConfig) -> anyhow::Result<Self> {
61        let memory_dir = memory_dir.as_ref();
62        tokio::fs::create_dir_all(memory_dir).await?;
63
64        // Load embedder first — needed for entity vector backfill during open.
65        let cache_dir = CONFIG_DIR.join(".cache").join("huggingface");
66        let embedder = tokio::task::spawn_blocking(move || Embedder::load(&cache_dir)).await??;
67
68        let lance_dir = memory_dir.join("lance");
69        let embed_mutex = Mutex::new(embedder);
70        let lance = LanceStore::open(&lance_dir, |text| {
71            let mut emb = embed_mutex
72                .lock()
73                .map_err(|e| anyhow::anyhow!("embedder lock poisoned: {e}"))?;
74            emb.embed(text)
75        })
76        .await?;
77
78        let allowed_entities = merge_defaults(DEFAULT_ENTITIES, &config.entities);
79        let allowed_relations = merge_defaults(DEFAULT_RELATIONS, &config.relations);
80        let connection_limit = config.connections.clamp(1, 100);
81
82        Ok(Self {
83            lance,
84            embedder: embed_mutex,
85            allowed_entities,
86            allowed_relations,
87            connection_limit,
88            auto_recall: config.auto_recall,
89        })
90    }
91
92    /// Check if an entity type is allowed.
93    pub(crate) fn is_valid_entity(&self, entity_type: &str) -> bool {
94        self.allowed_entities.iter().any(|t| t == entity_type)
95    }
96
97    /// Check if a relation type is allowed.
98    pub(crate) fn is_valid_relation(&self, relation: &str) -> bool {
99        self.allowed_relations.iter().any(|r| r == relation)
100    }
101
102    /// Generate an embedding vector for text. Runs candle inference in a blocking task.
103    pub(crate) async fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
104        let text = text.to_owned();
105        tokio::task::block_in_place(|| {
106            let mut embedder = self
107                .embedder
108                .lock()
109                .map_err(|e| anyhow::anyhow!("embedder lock poisoned: {e}"))?;
110            embedder.embed(&text)
111        })
112    }
113}
114
/// Shorten `s` to at most `max_bytes` bytes of its own content without
/// splitting a UTF-8 character.
///
/// Strings that already fit are returned unchanged. Otherwise the cut point
/// is moved back to the nearest character boundary at or below `max_bytes`
/// and `"..."` is appended (the ellipsis is not counted against the limit).
fn truncate_utf8(s: &str, max_bytes: usize) -> String {
    if max_bytes >= s.len() {
        return s.to_owned();
    }

    // Offset 0 is always a char boundary, so the search always succeeds.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    let mut shortened = String::with_capacity(cut + 3);
    shortened.push_str(&s[..cut]);
    shortened.push_str("...");
    shortened
}
127
/// Build the list of allowed type names: every entry of `defaults` in order,
/// followed by each entry of `extras` that is not already in the list.
///
/// Duplicates within `extras` are dropped; `defaults` is copied verbatim.
fn merge_defaults(defaults: &[&str], extras: &[String]) -> Vec<String> {
    let mut combined = Vec::with_capacity(defaults.len() + extras.len());
    combined.extend(defaults.iter().map(|name| String::from(*name)));

    for extra in extras {
        let already_listed = combined.iter().any(|known| known == extra);
        if !already_listed {
            combined.push(extra.to_owned());
        }
    }

    combined
}
137
138impl Hook for MemoryHook {
139    fn on_build_agent(&self, mut config: AgentConfig) -> AgentConfig {
140        // Entity injection from LanceDB happens synchronously via a blocking
141        // read. We use tokio::task::block_in_place to avoid deadlocks since
142        // Hook::on_build_agent is not async.
143        let agent_name = config.name.to_string();
144        let lance = &self.lance;
145
146        // Inject <self> block — agent's static birth identity from config.
147        let mut self_block = String::from("\n\n<self>\n");
148        self_block.push_str(&format!("name: {}\n", config.name));
149        if !config.description.is_empty() {
150            self_block.push_str(&format!("description: {}\n", config.description));
151        }
152        self_block.push_str("</self>");
153
154        let extra = tokio::task::block_in_place(|| {
155            tokio::runtime::Handle::current().block_on(async {
156                let mut buf = self_block;
157
158                // Inject identity entities (shared across all agents).
159                if let Ok(identities) = lance.query_by_type("identity", 50).await
160                    && !identities.is_empty()
161                {
162                    buf.push_str("\n\n<identity>\n");
163                    for e in &identities {
164                        buf.push_str(&format!("- **{}**: {}\n", e.key, e.value));
165                    }
166                    buf.push_str("</identity>");
167                }
168
169                // Inject profile entities (shared across all agents).
170                if let Ok(profiles) = lance.query_by_type("profile", 50).await
171                    && !profiles.is_empty()
172                {
173                    buf.push_str("\n\n<profile>\n");
174                    for e in &profiles {
175                        buf.push_str(&format!("- **{}**: {}\n", e.key, e.value));
176                    }
177                    buf.push_str("</profile>");
178                }
179
180                // Inject recent journal entries (agent-scoped).
181                if let Ok(journals) = lance.recent_journals(&agent_name, 3).await
182                    && !journals.is_empty()
183                {
184                    buf.push_str("\n\n<journal>\n");
185                    for j in &journals {
186                        let ts = chrono::DateTime::from_timestamp(j.created_at as i64, 0)
187                            .map(|dt| dt.format("%Y-%m-%d %H:%M").to_string())
188                            .unwrap_or_else(|| j.created_at.to_string());
189                        // Truncate summary to avoid bloating the system prompt.
190                        let summary = truncate_utf8(&j.summary, 500);
191                        buf.push_str(&format!("- **{ts}**: {summary}\n"));
192                    }
193                    buf.push_str("</journal>");
194                }
195
196                buf
197            })
198        });
199
200        if !extra.is_empty() {
201            config.system_prompt = format!("{}{extra}", config.system_prompt);
202        }
203        config.system_prompt = format!("{}\n\n{MEMORY_PROMPT}", config.system_prompt);
204        config
205    }
206
207    fn on_before_run(&self, agent: &str, history: &[Message]) -> Vec<Message> {
208        if !self.auto_recall {
209            return Vec::new();
210        }
211
212        // Extract the last user message as the recall query.
213        let query = match history.iter().rev().find(|m| m.role == Role::User) {
214            Some(m) if m.content.len() >= 10 => &m.content,
215            _ => return Vec::new(),
216        };
217
218        let lance = &self.lance;
219        let agent = agent.to_owned();
220        let query = query.clone();
221
222        tokio::task::block_in_place(|| {
223            tokio::runtime::Handle::current().block_on(async {
224                let mut lines = Vec::new();
225
226                // Embed the user message once; reuse for entities + journals.
227                let vector = match self.embed(&query).await {
228                    Ok(v) => v,
229                    Err(e) => {
230                        tracing::warn!("auto-recall embed failed: {e}");
231                        return Vec::new();
232                    }
233                };
234
235                // Semantic entity search.
236                let entities = lance
237                    .search_entities_semantic(&vector, None, 5)
238                    .await
239                    .unwrap_or_default();
240                for e in &entities {
241                    lines.push(format!("[{}] {}: {}", e.entity_type, e.key, e.value));
242                }
243
244                // 1-hop connections for top-3 matched entities.
245                for e in entities.iter().take(3) {
246                    if let Ok(rels) = lance
247                        .find_connections(&e.id, None, lance::Direction::Both, 5)
248                        .await
249                    {
250                        for r in &rels {
251                            let line = format!("{} -[{}]-> {}", r.source, r.relation, r.target);
252                            if !lines.contains(&line) {
253                                lines.push(line);
254                            }
255                        }
256                    }
257                }
258
259                // Semantic journal search (reuse same embedding vector).
260                if let Ok(journals) = lance.search_journals(&vector, &agent, 2).await {
261                    for j in &journals {
262                        let ts = chrono::DateTime::from_timestamp(j.created_at as i64, 0)
263                            .map(|dt| dt.format("%Y-%m-%d %H:%M").to_string())
264                            .unwrap_or_else(|| j.created_at.to_string());
265                        let summary = truncate_utf8(&j.summary, 300);
266                        lines.push(format!("[journal {ts}] {summary}"));
267                    }
268                }
269
270                if lines.is_empty() {
271                    return Vec::new();
272                }
273
274                let block = format!("<recall>\n{}\n</recall>", lines.join("\n"));
275                vec![Message::user(block)]
276            })
277        })
278    }
279
280    fn on_compact(&self, _prompt: &mut String) {
281        // This hook is unused. Identity context is passed directly in
282        // Agent::compact() which inserts the agent's system_prompt (containing
283        // <self>, <identity>, <profile>, <journal> blocks) as a user message
284        // before conversation history.
285    }
286
287    async fn on_register_tools(&self, tools: &mut ToolRegistry) {
288        // remember and relate have dynamic descriptions (inject allowed types).
289        tools.insert(Tool {
290            description: format!(
291                "Store a memory entity. Types: {}.",
292                self.allowed_entities.join(", ")
293            )
294            .into(),
295            ..tool::Remember::as_tool()
296        });
297        tools.insert(tool::Recall::as_tool());
298        tools.insert(Tool {
299            description: format!(
300                "Create a directed relation between two entities by key. Relations: {}.",
301                self.allowed_relations.join(", ")
302            )
303            .into(),
304            ..tool::Relate::as_tool()
305        });
306        tools.insert(tool::Connections::as_tool());
307        tools.insert(tool::Compact::as_tool());
308        tools.insert(tool::Distill::as_tool());
309    }
310}