Skip to main content

walrus_daemon/hook/
mod.rs

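//! Stateful Hook implementation for the daemon.
//!
//! [`DaemonHook`] composes memory, skill, MCP, OS, search, and task sub-hooks.
//! `on_build_agent` delegates to the memory sub-hook and then, for scoped
//! agents, computes a tool whitelist and appends a `<scope>` block to the
//! system prompt. `on_register_tools` delegates to the memory and MCP
//! sub-hooks and registers the OS, search, skill, and task tool sets.
//! `dispatch_tool` routes every agent tool call by name — the single entry
//! point from `event.rs`.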
7
8use crate::{
9    ext::hub::DownloadRegistry,
10    hook::{
11        mcp::McpHandler, memory::MemoryHook, os::PermissionConfig, skill::SkillHandler,
12        task::TaskRegistry,
13    },
14};
15use compact_str::CompactString;
16use std::{collections::BTreeMap, sync::Arc};
17use tokio::sync::Mutex;
18use wcore::{AgentConfig, AgentEvent, Hook, ToolRegistry, model::Message};
19
20pub mod mcp;
21pub mod memory;
22pub mod os;
23pub mod search;
24pub mod skill;
25pub mod task;
26
27/// Stateful Hook implementation for the daemon.
28///
29/// Composes memory, skill, MCP, and OS sub-hooks. Each sub-hook
30/// self-registers its tools via `on_register_tools`. All tool dispatch
31/// is routed through `dispatch_tool`.
32/// Per-agent scope for dispatch enforcement. Empty vecs = unrestricted.
33#[derive(Default)]
34pub(crate) struct AgentScope {
35    pub(crate) tools: Vec<CompactString>,
36    pub(crate) members: Vec<String>,
37    pub(crate) skills: Vec<String>,
38    pub(crate) mcps: Vec<String>,
39}
40
41pub struct DaemonHook {
42    pub memory: MemoryHook,
43    pub skills: SkillHandler,
44    pub mcp: McpHandler,
45    pub tasks: Arc<Mutex<TaskRegistry>>,
46    pub downloads: Arc<Mutex<DownloadRegistry>>,
47    pub permissions: PermissionConfig,
48    /// Whether the daemon is running as the `walrus` OS user (sandbox active).
49    pub sandboxed: bool,
50    /// Per-agent scope maps, populated during load_agents.
51    pub(crate) scopes: BTreeMap<CompactString, AgentScope>,
52    pub(crate) aggregator: wsearch::aggregator::Aggregator,
53    pub(crate) fetch_client: reqwest::Client,
54}
55
/// OS tool names — bypass permission check when running in sandbox mode.
const OS_TOOLS: &[&str] = &["read", "write", "edit", "bash"];

/// Base tools included in every *scoped* agent's whitelist.
///
/// Covers the memory tools (including the internal `__journal__` tool), the
/// OS tools (a superset of [`OS_TOOLS`]), and web search/fetch. Unscoped
/// agents get no whitelist at all, so this list does not restrict them.
const BASE_TOOLS: &[&str] = &[
    // memory
    "remember",
    "recall",
    "relate",
    "connections",
    "compact",
    "distill",
    "__journal__",
    // OS
    "read",
    "write",
    "edit",
    "bash",
    // web
    "web_search",
    "web_fetch",
];

/// Skill discovery/loading tools, whitelisted when an agent lists skills.
const SKILL_TOOLS: &[&str] = &["search_skill", "load_skill"];

/// MCP discovery/call tools, whitelisted when an agent lists MCP servers.
const MCP_TOOLS: &[&str] = &["search_mcp", "call_mcp_tool"];

/// Task delegation tools, whitelisted when an agent lists members.
const TASK_TOOLS: &[&str] = &[
    "spawn_task",
    "check_tasks",
    "create_task",
    "ask_user",
    "await_tasks",
];
90
91impl DaemonHook {
92    /// Create a new DaemonHook with the given backends.
93    #[allow(clippy::too_many_arguments)]
94    pub fn new(
95        memory: MemoryHook,
96        skills: SkillHandler,
97        mcp: McpHandler,
98        tasks: Arc<Mutex<TaskRegistry>>,
99        downloads: Arc<Mutex<DownloadRegistry>>,
100        permissions: PermissionConfig,
101        sandboxed: bool,
102        aggregator: wsearch::aggregator::Aggregator,
103        fetch_client: reqwest::Client,
104    ) -> Self {
105        Self {
106            memory,
107            skills,
108            mcp,
109            tasks,
110            downloads,
111            permissions,
112            sandboxed,
113            scopes: BTreeMap::new(),
114            aggregator,
115            fetch_client,
116        }
117    }
118
119    /// Register an agent's scope for dispatch enforcement.
120    pub(crate) fn register_scope(&mut self, name: CompactString, config: &AgentConfig) {
121        self.scopes.insert(
122            name,
123            AgentScope {
124                tools: config.tools.clone(),
125                members: config.members.clone(),
126                skills: config.skills.clone(),
127                mcps: config.mcps.clone(),
128            },
129        );
130    }
131
132    /// Check tool permission. Returns `Some(denied_message)` if denied,
133    /// `None` if allowed.
134    async fn check_perm(
135        &self,
136        name: &str,
137        args: &str,
138        agent: &str,
139        task_id: Option<u64>,
140    ) -> Option<String> {
141        // OS tools bypass permission when running in sandbox mode.
142        if self.sandboxed && OS_TOOLS.contains(&name) {
143            return None;
144        }
145        use crate::hook::os::ToolPermission;
146        match self.permissions.resolve(agent, name) {
147            ToolPermission::Deny => Some(format!("permission denied: {name}")),
148            ToolPermission::Ask => {
149                if let Some(tid) = task_id {
150                    let summary = if args.len() > 200 {
151                        format!("{}…", &args[..200])
152                    } else {
153                        args.to_string()
154                    };
155                    let question = format!("{name}: {summary}");
156                    let rx = self.tasks.lock().await.block(tid, question);
157                    if let Some(rx) = rx {
158                        match rx.await {
159                            Ok(resp) if resp == "denied" => {
160                                return Some(format!("permission denied: {name}"));
161                            }
162                            Err(_) => {
163                                return Some(format!("permission denied: {name} (inbox dropped)"));
164                            }
165                            _ => {} // approved → proceed
166                        }
167                    }
168                }
169                // No task_id → can't block, treat as Allow.
170                None
171            }
172            ToolPermission::Allow => None,
173        }
174    }
175
176    /// Route a tool call by name to the appropriate handler.
177    ///
178    /// This is the single dispatch entry point — `event.rs` calls this
179    /// and never matches on tool names itself. Unrecognised names are
180    /// forwarded to the MCP bridge after a warn-level log.
181    pub async fn dispatch_tool(
182        &self,
183        name: &str,
184        args: &str,
185        agent: &str,
186        task_id: Option<u64>,
187        _sender: &str,
188    ) -> String {
189        if let Some(denied) = self.check_perm(name, args, agent, task_id).await {
190            return denied;
191        }
192        // Dispatch enforcement: reject tools not in the agent's whitelist.
193        if let Some(scope) = self.scopes.get(agent)
194            && !scope.tools.is_empty()
195            && !scope.tools.iter().any(|t| t.as_str() == name)
196        {
197            return format!("tool not available: {name}");
198        }
199        match name {
200            "remember" => self.memory.dispatch_remember(args).await,
201            "recall" => self.memory.dispatch_recall(args).await,
202            "relate" => self.memory.dispatch_relate(args).await,
203            "connections" => self.memory.dispatch_connections(args).await,
204            "compact" => self.memory.dispatch_compact(agent).await,
205            "__journal__" => self.memory.dispatch_journal(args, agent).await,
206            "distill" => self.memory.dispatch_distill(args, agent).await,
207            "search_mcp" => self.dispatch_search_mcp(args, agent).await,
208            "call_mcp_tool" => self.dispatch_call_mcp_tool(args, agent).await,
209            "search_skill" => self.dispatch_search_skill(args, agent).await,
210            "load_skill" => self.dispatch_load_skill(args, agent).await,
211            "read" => self.dispatch_read(args).await,
212            "write" => self.dispatch_write(args).await,
213            "edit" => self.dispatch_edit(args).await,
214            "bash" => self.dispatch_bash(args).await,
215            "spawn_task" => self.dispatch_spawn_task(args, agent, task_id).await,
216            "check_tasks" => self.dispatch_check_tasks(args).await,
217            "create_task" => self.dispatch_create_task(args, agent).await,
218            "ask_user" => self.dispatch_ask_user(args, task_id).await,
219            "await_tasks" => self.dispatch_await_tasks(args, task_id).await,
220            "web_search" => self.dispatch_web_search(args).await,
221            "web_fetch" => self.dispatch_web_fetch(args).await,
222            name => {
223                tracing::debug!(tool = name, "forwarding tool to MCP bridge");
224                let bridge = self.mcp.bridge().await;
225                bridge.call(name, args).await
226            }
227        }
228    }
229}
230
231impl Hook for DaemonHook {
232    fn on_build_agent(&self, config: AgentConfig) -> AgentConfig {
233        let mut config = self.memory.on_build_agent(config);
234
235        // Walrus agent (empty scoping) gets all tools, no scope injection.
236        let has_scoping =
237            !config.skills.is_empty() || !config.mcps.is_empty() || !config.members.is_empty();
238        if !has_scoping {
239            return config;
240        }
241
242        // Compute tool whitelist — base tools always included.
243        let mut whitelist: Vec<CompactString> =
244            BASE_TOOLS.iter().map(|&s| CompactString::from(s)).collect();
245        let mut scope_lines = Vec::new();
246
247        // Skill tools if skills non-empty.
248        if !config.skills.is_empty() {
249            for &t in SKILL_TOOLS {
250                whitelist.push(CompactString::from(t));
251            }
252            scope_lines.push(format!("skills: {}", config.skills.join(", ")));
253        }
254
255        // MCP tools if mcps non-empty.
256        if !config.mcps.is_empty() {
257            for &t in MCP_TOOLS {
258                whitelist.push(CompactString::from(t));
259            }
260            // Also include tools from named MCP servers.
261            let mcp_servers = tokio::task::block_in_place(|| {
262                tokio::runtime::Handle::current().block_on(self.mcp.list())
263            });
264            let mut mcp_info = Vec::new();
265            for (server_name, tool_names) in &mcp_servers {
266                if config.mcps.iter().any(|m| m == server_name.as_str()) {
267                    for tn in tool_names {
268                        whitelist.push(tn.clone());
269                    }
270                    mcp_info.push(format!(
271                        "  - {}: {}",
272                        server_name,
273                        tool_names
274                            .iter()
275                            .map(|t| t.as_str())
276                            .collect::<Vec<_>>()
277                            .join(", ")
278                    ));
279                }
280            }
281            if !mcp_info.is_empty() {
282                scope_lines.push(format!("mcp servers:\n{}", mcp_info.join("\n")));
283            }
284        }
285
286        // Task tools if members non-empty.
287        if !config.members.is_empty() {
288            for &t in TASK_TOOLS {
289                whitelist.push(CompactString::from(t));
290            }
291            scope_lines.push(format!("members: {}", config.members.join(", ")));
292        }
293
294        // Inject scope info into system prompt.
295        if !scope_lines.is_empty() {
296            let scope_block = format!("\n\n<scope>\n{}\n</scope>", scope_lines.join("\n"));
297            config.system_prompt.push_str(&scope_block);
298        }
299
300        config.tools = whitelist;
301        config
302    }
303
304    fn on_compact(&self, prompt: &mut String) {
305        self.memory.on_compact(prompt);
306    }
307
308    fn on_before_run(&self, agent: &str, history: &[Message]) -> Vec<Message> {
309        self.memory.on_before_run(agent, history)
310    }
311
312    async fn on_register_tools(&self, tools: &mut ToolRegistry) {
313        self.memory.on_register_tools(tools).await;
314        self.mcp.on_register_tools(tools).await;
315        tools.insert_all(os::tool::tools());
316        tools.insert_all(search::tool::tools());
317        tools.insert_all(skill::tool::tools());
318        tools.insert_all(task::tool::tools());
319    }
320
321    fn on_event(&self, agent: &str, event: &AgentEvent) {
322        match event {
323            AgentEvent::TextDelta(text) => {
324                tracing::trace!(%agent, text_len = text.len(), "agent text delta");
325            }
326            AgentEvent::ThinkingDelta(text) => {
327                tracing::trace!(%agent, text_len = text.len(), "agent thinking delta");
328            }
329            AgentEvent::ToolCallsStart(calls) => {
330                tracing::debug!(%agent, count = calls.len(), "agent tool calls started");
331            }
332            AgentEvent::ToolResult { call_id, .. } => {
333                tracing::debug!(%agent, %call_id, "agent tool result");
334            }
335            AgentEvent::ToolCallsComplete => {
336                tracing::debug!(%agent, "agent tool calls complete");
337            }
338            AgentEvent::Done(response) => {
339                tracing::info!(
340                    %agent,
341                    iterations = response.iterations,
342                    stop_reason = ?response.stop_reason,
343                    "agent run complete"
344                );
345                // Track token usage on the active task for this agent.
346                let (prompt, completion) = response.steps.iter().fold((0u64, 0u64), |(p, c), s| {
347                    (
348                        p + u64::from(s.response.usage.prompt_tokens),
349                        c + u64::from(s.response.usage.completion_tokens),
350                    )
351                });
352                if (prompt > 0 || completion > 0)
353                    && let Ok(mut registry) = self.tasks.try_lock()
354                {
355                    let tid = registry
356                        .list(Some(agent), Some(task::TaskStatus::InProgress), None)
357                        .first()
358                        .map(|t| t.id);
359                    if let Some(tid) = tid {
360                        registry.add_tokens(tid, prompt, completion);
361                    }
362                }
363            }
364        }
365    }
366}