// walrus_daemon/hook/memory/mod.rs

//! Graph-based memory hook — owns a LanceDB database with entity, relation,
//! and journal tables. Registers the `remember`, `recall`, `relate`,
//! `connections`, `compact`, and `distill` tool schemas. Journals store
//! compaction summaries alongside vector embeddings (computed with candle using
//! all-MiniLM-L6-v2) so they can be searched semantically.
5
6pub use config::MemoryConfig;
7use embedder::Embedder;
8use lance::LanceStore;
9use std::path::Path;
10use std::sync::Mutex;
11use wcore::{AgentConfig, Hook, ToolRegistry, agent::AsTool, model::Tool, paths::CONFIG_DIR};
12
13pub mod config;
14pub(crate) mod dispatch;
15pub(crate) mod embedder;
16pub(crate) mod lance;
17pub(crate) mod tool;
18
/// Static memory instructions, compiled in from `prompts/memory.md`.
///
/// `on_build_agent` appends this, after a blank line, to the end of every
/// agent's system prompt.
const MEMORY_PROMPT: &str = include_str!("../../../prompts/memory.md");
20
/// Entity types the framework always allows.
///
/// Extra types from the memory config are appended after these (see
/// `merge_defaults`).
const DEFAULT_ENTITIES: &[&str] = &[
    "fact", "preference", "person", "event", "concept", "identity", "profile",
];

/// Relation types the framework always allows.
///
/// Extra relations from the memory config are appended after these (see
/// `merge_defaults`).
const DEFAULT_RELATIONS: &[&str] = &[
    "knows", "prefers", "related_to", "caused_by", "part_of", "depends_on", "tagged_with",
];
42
/// Graph-based memory hook owning LanceDB entity, relation, and journal storage.
///
/// Constructed via `MemoryHook::open`.
pub struct MemoryHook {
    /// LanceDB-backed store queried for entities and journals.
    pub(crate) lance: LanceStore,
    /// Embedding model; locked for the duration of each `MemoryHook::embed` call.
    pub(crate) embedder: Mutex<Embedder>,
    /// Entity types accepted by `is_valid_entity`: framework defaults followed by
    /// config extras, without duplicates.
    pub(crate) allowed_entities: Vec<String>,
    /// Relation types accepted by `is_valid_relation`: framework defaults followed
    /// by config extras, without duplicates.
    pub(crate) allowed_relations: Vec<String>,
    /// `config.connections` clamped to `1..=100`.
    /// NOTE(review): presumably the result cap for the `connections` tool — confirm
    /// in the dispatch module.
    pub(crate) connection_limit: usize,
}
51
52impl MemoryHook {
53    /// Create a new MemoryHook, opening or creating the LanceDB database.
54    pub async fn open(memory_dir: impl AsRef<Path>, config: &MemoryConfig) -> anyhow::Result<Self> {
55        let memory_dir = memory_dir.as_ref();
56        tokio::fs::create_dir_all(memory_dir).await?;
57        let lance_dir = memory_dir.join("lance");
58        let lance = LanceStore::open(&lance_dir).await?;
59
60        let cache_dir = CONFIG_DIR.join(".cache").join("huggingface");
61        let embedder = tokio::task::spawn_blocking(move || Embedder::load(&cache_dir)).await??;
62
63        let allowed_entities = merge_defaults(DEFAULT_ENTITIES, &config.entities);
64        let allowed_relations = merge_defaults(DEFAULT_RELATIONS, &config.relations);
65        let connection_limit = config.connections.clamp(1, 100);
66
67        Ok(Self {
68            lance,
69            embedder: Mutex::new(embedder),
70            allowed_entities,
71            allowed_relations,
72            connection_limit,
73        })
74    }
75
76    /// Check if an entity type is allowed.
77    pub(crate) fn is_valid_entity(&self, entity_type: &str) -> bool {
78        self.allowed_entities.iter().any(|t| t == entity_type)
79    }
80
81    /// Check if a relation type is allowed.
82    pub(crate) fn is_valid_relation(&self, relation: &str) -> bool {
83        self.allowed_relations.iter().any(|r| r == relation)
84    }
85
86    /// Generate an embedding vector for text. Runs candle inference in a blocking task.
87    pub(crate) async fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
88        let text = text.to_owned();
89        tokio::task::block_in_place(|| {
90            let mut embedder = self
91                .embedder
92                .lock()
93                .map_err(|e| anyhow::anyhow!("embedder lock poisoned: {e}"))?;
94            embedder.embed(&text)
95        })
96    }
97}
98
/// Build an allow-list from `defaults` followed by each entry of `extras` that
/// is not already present.
///
/// All defaults are kept in their original order. An extra is appended only if
/// an equal string is not already in the list, so repeated extras appear once.
fn merge_defaults(defaults: &[&str], extras: &[String]) -> Vec<String> {
    let mut allowed = Vec::with_capacity(defaults.len() + extras.len());
    allowed.extend(defaults.iter().map(|name| name.to_string()));

    for candidate in extras {
        let already_listed = allowed.iter().any(|existing| existing == candidate);
        if !already_listed {
            allowed.push(candidate.to_owned());
        }
    }

    allowed
}
108
109impl Hook for MemoryHook {
110    fn on_build_agent(&self, mut config: AgentConfig) -> AgentConfig {
111        // Entity injection from LanceDB happens synchronously via a blocking
112        // read. We use tokio::task::block_in_place to avoid deadlocks since
113        // Hook::on_build_agent is not async.
114        let agent_name = config.name.to_string();
115        let lance = &self.lance;
116
117        // Inject <self> block — agent's static birth identity from config.
118        let mut self_block = String::from("\n\n<self>\n");
119        self_block.push_str(&format!("name: {}\n", config.name));
120        if !config.description.is_empty() {
121            self_block.push_str(&format!("description: {}\n", config.description));
122        }
123        self_block.push_str("</self>");
124
125        let extra = tokio::task::block_in_place(|| {
126            tokio::runtime::Handle::current().block_on(async {
127                let mut buf = self_block;
128
129                // Inject identity entities.
130                if let Ok(identities) = lance.query_by_type(&agent_name, "identity", 50).await
131                    && !identities.is_empty()
132                {
133                    buf.push_str("\n\n<identity>\n");
134                    for e in &identities {
135                        buf.push_str(&format!("- **{}**: {}\n", e.key, e.value));
136                    }
137                    buf.push_str("</identity>");
138                }
139
140                // Inject profile entities.
141                if let Ok(profiles) = lance.query_by_type(&agent_name, "profile", 50).await
142                    && !profiles.is_empty()
143                {
144                    buf.push_str("\n\n<profile>\n");
145                    for e in &profiles {
146                        buf.push_str(&format!("- **{}**: {}\n", e.key, e.value));
147                    }
148                    buf.push_str("</profile>");
149                }
150
151                // Inject recent journal entries.
152                if let Ok(journals) = lance.recent_journals(&agent_name, 3).await
153                    && !journals.is_empty()
154                {
155                    buf.push_str("\n\n<journal>\n");
156                    for j in &journals {
157                        let ts = chrono::DateTime::from_timestamp(j.created_at as i64, 0)
158                            .map(|dt| dt.format("%Y-%m-%d %H:%M").to_string())
159                            .unwrap_or_else(|| j.created_at.to_string());
160                        // Truncate summary to avoid bloating the system prompt.
161                        let summary = if j.summary.len() > 500 {
162                            format!("{}...", &j.summary[..500])
163                        } else {
164                            j.summary.clone()
165                        };
166                        buf.push_str(&format!("- **{ts}**: {summary}\n"));
167                    }
168                    buf.push_str("</journal>");
169                }
170
171                buf
172            })
173        });
174
175        if !extra.is_empty() {
176            config.system_prompt = format!("{}{extra}", config.system_prompt);
177        }
178        config.system_prompt = format!("{}\n\n{MEMORY_PROMPT}", config.system_prompt);
179        config
180    }
181
182    fn on_compact(&self, _prompt: &mut String) {
183        // This hook is unused. Identity context is passed directly in
184        // Agent::compact() which inserts the agent's system_prompt (containing
185        // <self>, <identity>, <profile>, <journal> blocks) as a user message
186        // before conversation history.
187    }
188
189    async fn on_register_tools(&self, tools: &mut ToolRegistry) {
190        // remember and relate have dynamic descriptions (inject allowed types).
191        tools.insert(Tool {
192            description: format!(
193                "Store a memory entity. Types: {}.",
194                self.allowed_entities.join(", ")
195            )
196            .into(),
197            ..tool::Remember::as_tool()
198        });
199        tools.insert(tool::Recall::as_tool());
200        tools.insert(Tool {
201            description: format!(
202                "Create a directed relation between two entities by key. Relations: {}.",
203                self.allowed_relations.join(", ")
204            )
205            .into(),
206            ..tool::Relate::as_tool()
207        });
208        tools.insert(tool::Connections::as_tool());
209        tools.insert(tool::Compact::as_tool());
210        tools.insert(tool::Distill::as_tool());
211    }
212}