walrus_daemon/hook/memory/
mod.rs1pub use config::MemoryConfig;
7use embedder::Embedder;
8use lance::LanceStore;
9use std::path::Path;
10use std::sync::Mutex;
11use wcore::{
12 AgentConfig, Hook, ToolRegistry,
13 agent::AsTool,
14 model::{Message, Role, Tool},
15 paths::CONFIG_DIR,
16};
17
18pub mod config;
19pub(crate) mod dispatch;
20pub(crate) mod embedder;
21pub(crate) mod lance;
22pub(crate) mod tool;
23
24const MEMORY_PROMPT: &str = include_str!("../../../prompts/memory.md");
25
/// Entity type names that are always accepted.
///
/// `MemoryHook::open` starts from this list and appends any additional types
/// supplied through `MemoryConfig::entities`. The `"identity"` and `"profile"`
/// types are also the ones queried when building the agent system prompt.
const DEFAULT_ENTITIES: &[&str] = &[
    "fact",
    "preference",
    "person",
    "event",
    "concept",
    "identity",
    "profile",
];
36
/// Relation names that are always accepted.
///
/// `MemoryHook::open` starts from this list and appends any additional
/// relations supplied through `MemoryConfig::relations`.
const DEFAULT_RELATIONS: &[&str] = &[
    "knows",
    "prefers",
    "related_to",
    "caused_by",
    "part_of",
    "depends_on",
    "tagged_with",
];
47
48pub struct MemoryHook {
50 pub(crate) lance: LanceStore,
51 pub(crate) embedder: Mutex<Embedder>,
52 pub(crate) allowed_entities: Vec<String>,
53 pub(crate) allowed_relations: Vec<String>,
54 pub(crate) connection_limit: usize,
55 pub(crate) auto_recall: bool,
56}
57
58impl MemoryHook {
59 pub async fn open(memory_dir: impl AsRef<Path>, config: &MemoryConfig) -> anyhow::Result<Self> {
61 let memory_dir = memory_dir.as_ref();
62 tokio::fs::create_dir_all(memory_dir).await?;
63
64 let cache_dir = CONFIG_DIR.join(".cache").join("huggingface");
66 let embedder = tokio::task::spawn_blocking(move || Embedder::load(&cache_dir)).await??;
67
68 let lance_dir = memory_dir.join("lance");
69 let embed_mutex = Mutex::new(embedder);
70 let lance = LanceStore::open(&lance_dir, |text| {
71 let mut emb = embed_mutex
72 .lock()
73 .map_err(|e| anyhow::anyhow!("embedder lock poisoned: {e}"))?;
74 emb.embed(text)
75 })
76 .await?;
77
78 let allowed_entities = merge_defaults(DEFAULT_ENTITIES, &config.entities);
79 let allowed_relations = merge_defaults(DEFAULT_RELATIONS, &config.relations);
80 let connection_limit = config.connections.clamp(1, 100);
81
82 Ok(Self {
83 lance,
84 embedder: embed_mutex,
85 allowed_entities,
86 allowed_relations,
87 connection_limit,
88 auto_recall: config.auto_recall,
89 })
90 }
91
92 pub(crate) fn is_valid_entity(&self, entity_type: &str) -> bool {
94 self.allowed_entities.iter().any(|t| t == entity_type)
95 }
96
97 pub(crate) fn is_valid_relation(&self, relation: &str) -> bool {
99 self.allowed_relations.iter().any(|r| r == relation)
100 }
101
102 pub(crate) async fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
104 let text = text.to_owned();
105 tokio::task::block_in_place(|| {
106 let mut embedder = self
107 .embedder
108 .lock()
109 .map_err(|e| anyhow::anyhow!("embedder lock poisoned: {e}"))?;
110 embedder.embed(&text)
111 })
112 }
113}
114
/// Shortens `s` to at most `max_bytes` bytes of original text and appends
/// `"..."` when anything was cut off.
///
/// The cut is moved back to the nearest UTF-8 character boundary so a
/// multi-byte character is never split. Strings that already fit are returned
/// unchanged, without the ellipsis. Note that a truncated result can be up to
/// three bytes longer than `max_bytes` because of the appended ellipsis.
fn truncate_utf8(s: &str, max_bytes: usize) -> String {
    if s.len() <= max_bytes {
        return s.to_owned();
    }

    // Index 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    let mut shortened = String::with_capacity(cut + 3);
    shortened.push_str(&s[..cut]);
    shortened.push_str("...");
    shortened
}
127
/// Builds an owned list containing every entry of `defaults` followed by each
/// entry of `extras` that is not already present.
///
/// Order is preserved: defaults first, then extras in their original order.
/// Duplicates among `extras` (or against `defaults`) are dropped; `defaults`
/// itself is copied as-is.
fn merge_defaults(defaults: &[&str], extras: &[String]) -> Vec<String> {
    let mut combined: Vec<String> = Vec::with_capacity(defaults.len() + extras.len());
    combined.extend(defaults.iter().map(|&name| name.to_owned()));

    for extra in extras {
        let already_present = combined.iter().any(|known| known == extra);
        if !already_present {
            combined.push(extra.clone());
        }
    }

    combined
}
137
138impl Hook for MemoryHook {
139 fn on_build_agent(&self, mut config: AgentConfig) -> AgentConfig {
140 let agent_name = config.name.to_string();
144 let lance = &self.lance;
145
146 let mut self_block = String::from("\n\n<self>\n");
148 self_block.push_str(&format!("name: {}\n", config.name));
149 if !config.description.is_empty() {
150 self_block.push_str(&format!("description: {}\n", config.description));
151 }
152 self_block.push_str("</self>");
153
154 let extra = tokio::task::block_in_place(|| {
155 tokio::runtime::Handle::current().block_on(async {
156 let mut buf = self_block;
157
158 if let Ok(identities) = lance.query_by_type("identity", 50).await
160 && !identities.is_empty()
161 {
162 buf.push_str("\n\n<identity>\n");
163 for e in &identities {
164 buf.push_str(&format!("- **{}**: {}\n", e.key, e.value));
165 }
166 buf.push_str("</identity>");
167 }
168
169 if let Ok(profiles) = lance.query_by_type("profile", 50).await
171 && !profiles.is_empty()
172 {
173 buf.push_str("\n\n<profile>\n");
174 for e in &profiles {
175 buf.push_str(&format!("- **{}**: {}\n", e.key, e.value));
176 }
177 buf.push_str("</profile>");
178 }
179
180 if let Ok(journals) = lance.recent_journals(&agent_name, 3).await
182 && !journals.is_empty()
183 {
184 buf.push_str("\n\n<journal>\n");
185 for j in &journals {
186 let ts = chrono::DateTime::from_timestamp(j.created_at as i64, 0)
187 .map(|dt| dt.format("%Y-%m-%d %H:%M").to_string())
188 .unwrap_or_else(|| j.created_at.to_string());
189 let summary = truncate_utf8(&j.summary, 500);
191 buf.push_str(&format!("- **{ts}**: {summary}\n"));
192 }
193 buf.push_str("</journal>");
194 }
195
196 buf
197 })
198 });
199
200 if !extra.is_empty() {
201 config.system_prompt = format!("{}{extra}", config.system_prompt);
202 }
203 config.system_prompt = format!("{}\n\n{MEMORY_PROMPT}", config.system_prompt);
204 config
205 }
206
207 fn on_before_run(&self, agent: &str, history: &[Message]) -> Vec<Message> {
208 if !self.auto_recall {
209 return Vec::new();
210 }
211
212 let query = match history.iter().rev().find(|m| m.role == Role::User) {
214 Some(m) if m.content.len() >= 10 => &m.content,
215 _ => return Vec::new(),
216 };
217
218 let lance = &self.lance;
219 let agent = agent.to_owned();
220 let query = query.clone();
221
222 tokio::task::block_in_place(|| {
223 tokio::runtime::Handle::current().block_on(async {
224 let mut lines = Vec::new();
225
226 let vector = match self.embed(&query).await {
228 Ok(v) => v,
229 Err(e) => {
230 tracing::warn!("auto-recall embed failed: {e}");
231 return Vec::new();
232 }
233 };
234
235 let entities = lance
237 .search_entities_semantic(&vector, None, 5)
238 .await
239 .unwrap_or_default();
240 for e in &entities {
241 lines.push(format!("[{}] {}: {}", e.entity_type, e.key, e.value));
242 }
243
244 for e in entities.iter().take(3) {
246 if let Ok(rels) = lance
247 .find_connections(&e.id, None, lance::Direction::Both, 5)
248 .await
249 {
250 for r in &rels {
251 let line = format!("{} -[{}]-> {}", r.source, r.relation, r.target);
252 if !lines.contains(&line) {
253 lines.push(line);
254 }
255 }
256 }
257 }
258
259 if let Ok(journals) = lance.search_journals(&vector, &agent, 2).await {
261 for j in &journals {
262 let ts = chrono::DateTime::from_timestamp(j.created_at as i64, 0)
263 .map(|dt| dt.format("%Y-%m-%d %H:%M").to_string())
264 .unwrap_or_else(|| j.created_at.to_string());
265 let summary = truncate_utf8(&j.summary, 300);
266 lines.push(format!("[journal {ts}] {summary}"));
267 }
268 }
269
270 if lines.is_empty() {
271 return Vec::new();
272 }
273
274 let block = format!("<recall>\n{}\n</recall>", lines.join("\n"));
275 vec![Message::user(block)]
276 })
277 })
278 }
279
280 fn on_compact(&self, _prompt: &mut String) {
281 }
286
287 async fn on_register_tools(&self, tools: &mut ToolRegistry) {
288 tools.insert(Tool {
290 description: format!(
291 "Store a memory entity. Types: {}.",
292 self.allowed_entities.join(", ")
293 )
294 .into(),
295 ..tool::Remember::as_tool()
296 });
297 tools.insert(tool::Recall::as_tool());
298 tools.insert(Tool {
299 description: format!(
300 "Create a directed relation between two entities by key. Relations: {}.",
301 self.allowed_relations.join(", ")
302 )
303 .into(),
304 ..tool::Relate::as_tool()
305 });
306 tools.insert(tool::Connections::as_tool());
307 tools.insert(tool::Compact::as_tool());
308 tools.insert(tool::Distill::as_tool());
309 }
310}