walrus_daemon/hook/memory/
mod.rs1pub use config::MemoryConfig;
7use embedder::Embedder;
8use lance::LanceStore;
9use std::path::Path;
10use std::sync::Mutex;
11use wcore::{AgentConfig, Hook, ToolRegistry, agent::AsTool, model::Tool, paths::CONFIG_DIR};
12
13pub mod config;
14pub(crate) mod dispatch;
15pub(crate) mod embedder;
16pub(crate) mod lance;
17pub(crate) mod tool;
18
19const MEMORY_PROMPT: &str = include_str!("../../../prompts/memory.md");
20
/// Entity types that are always accepted, regardless of configuration.
///
/// `MemoryHook::open` merges these with the extra types listed in
/// `MemoryConfig::entities`. `on_build_agent` queries the `"identity"` and
/// `"profile"` types by name when it builds the system prompt.
const DEFAULT_ENTITIES: &[&str] = &[
    "fact",
    "preference",
    "person",
    "event",
    "concept",
    "identity",
    "profile",
];
31
/// Relation names that are always accepted, regardless of configuration.
///
/// `MemoryHook::open` merges these with the extra relations listed in
/// `MemoryConfig::relations`.
const DEFAULT_RELATIONS: &[&str] = &[
    "knows",
    "prefers",
    "related_to",
    "caused_by",
    "part_of",
    "depends_on",
    "tagged_with",
];
42
43pub struct MemoryHook {
45 pub(crate) lance: LanceStore,
46 pub(crate) embedder: Mutex<Embedder>,
47 pub(crate) allowed_entities: Vec<String>,
48 pub(crate) allowed_relations: Vec<String>,
49 pub(crate) connection_limit: usize,
50}
51
52impl MemoryHook {
53 pub async fn open(memory_dir: impl AsRef<Path>, config: &MemoryConfig) -> anyhow::Result<Self> {
55 let memory_dir = memory_dir.as_ref();
56 tokio::fs::create_dir_all(memory_dir).await?;
57 let lance_dir = memory_dir.join("lance");
58 let lance = LanceStore::open(&lance_dir).await?;
59
60 let cache_dir = CONFIG_DIR.join(".cache").join("huggingface");
61 let embedder = tokio::task::spawn_blocking(move || Embedder::load(&cache_dir)).await??;
62
63 let allowed_entities = merge_defaults(DEFAULT_ENTITIES, &config.entities);
64 let allowed_relations = merge_defaults(DEFAULT_RELATIONS, &config.relations);
65 let connection_limit = config.connections.clamp(1, 100);
66
67 Ok(Self {
68 lance,
69 embedder: Mutex::new(embedder),
70 allowed_entities,
71 allowed_relations,
72 connection_limit,
73 })
74 }
75
76 pub(crate) fn is_valid_entity(&self, entity_type: &str) -> bool {
78 self.allowed_entities.iter().any(|t| t == entity_type)
79 }
80
81 pub(crate) fn is_valid_relation(&self, relation: &str) -> bool {
83 self.allowed_relations.iter().any(|r| r == relation)
84 }
85
86 pub(crate) async fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
88 let text = text.to_owned();
89 tokio::task::block_in_place(|| {
90 let mut embedder = self
91 .embedder
92 .lock()
93 .map_err(|e| anyhow::anyhow!("embedder lock poisoned: {e}"))?;
94 embedder.embed(&text)
95 })
96 }
97}
98
/// Builds the effective allow-list: every entry of `defaults` in order,
/// followed by each entry of `extras` that is not already present.
///
/// Duplicates inside `extras` are dropped; `defaults` is copied verbatim.
fn merge_defaults(defaults: &[&str], extras: &[String]) -> Vec<String> {
    let mut names: Vec<String> = Vec::with_capacity(defaults.len() + extras.len());
    names.extend(defaults.iter().map(|name| name.to_string()));

    for candidate in extras {
        let already_listed = names.iter().any(|known| known == candidate);
        if !already_listed {
            names.push(candidate.to_owned());
        }
    }

    names
}
108
109impl Hook for MemoryHook {
110 fn on_build_agent(&self, mut config: AgentConfig) -> AgentConfig {
111 let agent_name = config.name.to_string();
115 let lance = &self.lance;
116
117 let mut self_block = String::from("\n\n<self>\n");
119 self_block.push_str(&format!("name: {}\n", config.name));
120 if !config.description.is_empty() {
121 self_block.push_str(&format!("description: {}\n", config.description));
122 }
123 self_block.push_str("</self>");
124
125 let extra = tokio::task::block_in_place(|| {
126 tokio::runtime::Handle::current().block_on(async {
127 let mut buf = self_block;
128
129 if let Ok(identities) = lance.query_by_type(&agent_name, "identity", 50).await
131 && !identities.is_empty()
132 {
133 buf.push_str("\n\n<identity>\n");
134 for e in &identities {
135 buf.push_str(&format!("- **{}**: {}\n", e.key, e.value));
136 }
137 buf.push_str("</identity>");
138 }
139
140 if let Ok(profiles) = lance.query_by_type(&agent_name, "profile", 50).await
142 && !profiles.is_empty()
143 {
144 buf.push_str("\n\n<profile>\n");
145 for e in &profiles {
146 buf.push_str(&format!("- **{}**: {}\n", e.key, e.value));
147 }
148 buf.push_str("</profile>");
149 }
150
151 if let Ok(journals) = lance.recent_journals(&agent_name, 3).await
153 && !journals.is_empty()
154 {
155 buf.push_str("\n\n<journal>\n");
156 for j in &journals {
157 let ts = chrono::DateTime::from_timestamp(j.created_at as i64, 0)
158 .map(|dt| dt.format("%Y-%m-%d %H:%M").to_string())
159 .unwrap_or_else(|| j.created_at.to_string());
160 let summary = if j.summary.len() > 500 {
162 format!("{}...", &j.summary[..500])
163 } else {
164 j.summary.clone()
165 };
166 buf.push_str(&format!("- **{ts}**: {summary}\n"));
167 }
168 buf.push_str("</journal>");
169 }
170
171 buf
172 })
173 });
174
175 if !extra.is_empty() {
176 config.system_prompt = format!("{}{extra}", config.system_prompt);
177 }
178 config.system_prompt = format!("{}\n\n{MEMORY_PROMPT}", config.system_prompt);
179 config
180 }
181
182 fn on_compact(&self, _prompt: &mut String) {
183 }
188
189 async fn on_register_tools(&self, tools: &mut ToolRegistry) {
190 tools.insert(Tool {
192 description: format!(
193 "Store a memory entity. Types: {}.",
194 self.allowed_entities.join(", ")
195 )
196 .into(),
197 ..tool::Remember::as_tool()
198 });
199 tools.insert(tool::Recall::as_tool());
200 tools.insert(Tool {
201 description: format!(
202 "Create a directed relation between two entities by key. Relations: {}.",
203 self.allowed_relations.join(", ")
204 )
205 .into(),
206 ..tool::Relate::as_tool()
207 });
208 tools.insert(tool::Connections::as_tool());
209 tools.insert(tool::Compact::as_tool());
210 tools.insert(tool::Distill::as_tool());
211 }
212}