Skip to main content

walrus_daemon/hook/
mod.rs

1//! Stateful Hook implementation for the daemon.
2//!
//! [`DaemonHook`] composes memory, skill, MCP, OS, and task sub-hooks.
//! `on_build_agent` delegates to memory and then derives a per-agent tool
//! whitelist from the agent's configured skills, MCP servers, and members;
//! `on_register_tools` registers the tools of every sub-hook in sequence.
//! `dispatch_tool` routes every agent tool call by name — the single entry
//! point from `event.rs`.
7
8use crate::hook::{
9    mcp::McpHandler, memory::MemoryHook, os::PermissionConfig, skill::SkillHandler,
10    task::TaskRegistry,
11};
12use compact_str::CompactString;
13use std::{collections::BTreeMap, sync::Arc};
14use tokio::sync::Mutex;
15use wcore::{AgentConfig, AgentEvent, Hook, ToolRegistry};
16
17pub mod mcp;
18pub mod memory;
19pub mod os;
20pub mod skill;
21pub mod task;
22
23/// Stateful Hook implementation for the daemon.
24///
25/// Composes memory, skill, MCP, and OS sub-hooks. Each sub-hook
26/// self-registers its tools via `on_register_tools`. All tool dispatch
27/// is routed through `dispatch_tool`.
28/// Per-agent scope for dispatch enforcement. Empty vecs = unrestricted.
29#[derive(Default)]
30pub(crate) struct AgentScope {
31    pub(crate) tools: Vec<CompactString>,
32    pub(crate) members: Vec<String>,
33    pub(crate) skills: Vec<String>,
34    pub(crate) mcps: Vec<String>,
35}
36
37pub struct DaemonHook {
38    pub memory: MemoryHook,
39    pub skills: SkillHandler,
40    pub mcp: McpHandler,
41    pub tasks: Arc<Mutex<TaskRegistry>>,
42    pub permissions: PermissionConfig,
43    /// Whether the daemon is running as the `walrus` OS user (sandbox active).
44    pub sandboxed: bool,
45    /// Per-agent scope maps, populated during load_agents.
46    pub(crate) scopes: BTreeMap<CompactString, AgentScope>,
47}
48
/// Names of the built-in OS tools.
///
/// When the daemon is sandboxed (running as the `walrus` OS user) these
/// tools skip the permission check entirely.
const OS_TOOLS: &[&str] = &[
    "read",
    "write",
    "edit",
    "bash",
];

/// Tools granted to every scoped agent regardless of its configuration:
/// the memory tools first, then the OS tools.
const BASE_TOOLS: &[&str] = &[
    // Memory tools.
    "remember",
    "recall",
    "relate",
    "connections",
    "compact",
    "distill",
    "__journal__",
    // OS tools (same names as `OS_TOOLS`).
    "read",
    "write",
    "edit",
    "bash",
];

/// Skill discovery/loading tools, added when an agent declares skills.
const SKILL_TOOLS: &[&str] = &["search_skill", "load_skill"];

/// MCP discovery/call tools, added when an agent declares MCP servers.
const MCP_TOOLS: &[&str] = &["search_mcp", "call_mcp_tool"];

/// Task delegation tools, added when an agent declares members.
const TASK_TOOLS: &[&str] = &[
    "spawn_task", "check_tasks", "create_task", "ask_user", "await_tasks",
];
81
82impl DaemonHook {
83    /// Create a new DaemonHook with the given backends.
84    pub fn new(
85        memory: MemoryHook,
86        skills: SkillHandler,
87        mcp: McpHandler,
88        tasks: Arc<Mutex<TaskRegistry>>,
89        permissions: PermissionConfig,
90        sandboxed: bool,
91    ) -> Self {
92        Self {
93            memory,
94            skills,
95            mcp,
96            tasks,
97            permissions,
98            sandboxed,
99            scopes: BTreeMap::new(),
100        }
101    }
102
103    /// Register an agent's scope for dispatch enforcement.
104    pub(crate) fn register_scope(&mut self, name: CompactString, config: &AgentConfig) {
105        self.scopes.insert(
106            name,
107            AgentScope {
108                tools: config.tools.clone(),
109                members: config.members.clone(),
110                skills: config.skills.clone(),
111                mcps: config.mcps.clone(),
112            },
113        );
114    }
115
116    /// Check tool permission. Returns `Some(denied_message)` if denied,
117    /// `None` if allowed.
118    async fn check_perm(
119        &self,
120        name: &str,
121        args: &str,
122        agent: &str,
123        task_id: Option<u64>,
124    ) -> Option<String> {
125        // OS tools bypass permission when running in sandbox mode.
126        if self.sandboxed && OS_TOOLS.contains(&name) {
127            return None;
128        }
129        use crate::hook::os::ToolPermission;
130        match self.permissions.resolve(agent, name) {
131            ToolPermission::Deny => Some(format!("permission denied: {name}")),
132            ToolPermission::Ask => {
133                if let Some(tid) = task_id {
134                    let summary = if args.len() > 200 {
135                        format!("{}…", &args[..200])
136                    } else {
137                        args.to_string()
138                    };
139                    let question = format!("{name}: {summary}");
140                    let rx = self.tasks.lock().await.block(tid, question);
141                    if let Some(rx) = rx {
142                        match rx.await {
143                            Ok(resp) if resp == "denied" => {
144                                return Some(format!("permission denied: {name}"));
145                            }
146                            Err(_) => {
147                                return Some(format!("permission denied: {name} (inbox dropped)"));
148                            }
149                            _ => {} // approved → proceed
150                        }
151                    }
152                }
153                // No task_id → can't block, treat as Allow.
154                None
155            }
156            ToolPermission::Allow => None,
157        }
158    }
159
160    /// Route a tool call by name to the appropriate handler.
161    ///
162    /// This is the single dispatch entry point — `event.rs` calls this
163    /// and never matches on tool names itself. Unrecognised names are
164    /// forwarded to the MCP bridge after a warn-level log.
165    pub async fn dispatch_tool(
166        &self,
167        name: &str,
168        args: &str,
169        agent: &str,
170        task_id: Option<u64>,
171        sender: &str,
172    ) -> String {
173        if let Some(denied) = self.check_perm(name, args, agent, task_id).await {
174            return denied;
175        }
176        // Dispatch enforcement: reject tools not in the agent's whitelist.
177        if let Some(scope) = self.scopes.get(agent)
178            && !scope.tools.is_empty()
179            && !scope.tools.iter().any(|t| t.as_str() == name)
180        {
181            return format!("tool not available: {name}");
182        }
183        match name {
184            "remember" => self.memory.dispatch_remember(args, agent, sender).await,
185            "recall" => self.memory.dispatch_recall(args, agent, sender).await,
186            "relate" => self.memory.dispatch_relate(args, agent, sender).await,
187            "connections" => self.memory.dispatch_connections(args, agent, sender).await,
188            "compact" => self.memory.dispatch_compact(agent).await,
189            "__journal__" => self.memory.dispatch_journal(args, agent).await,
190            "distill" => self.memory.dispatch_distill(args, agent).await,
191            "search_mcp" => self.dispatch_search_mcp(args, agent).await,
192            "call_mcp_tool" => self.dispatch_call_mcp_tool(args, agent).await,
193            "search_skill" => self.dispatch_search_skill(args, agent).await,
194            "load_skill" => self.dispatch_load_skill(args, agent).await,
195            "read" => self.dispatch_read(args).await,
196            "write" => self.dispatch_write(args).await,
197            "edit" => self.dispatch_edit(args).await,
198            "bash" => self.dispatch_bash(args).await,
199            "spawn_task" => self.dispatch_spawn_task(args, agent, task_id).await,
200            "check_tasks" => self.dispatch_check_tasks(args).await,
201            "create_task" => self.dispatch_create_task(args, agent).await,
202            "ask_user" => self.dispatch_ask_user(args, task_id).await,
203            "await_tasks" => self.dispatch_await_tasks(args, task_id).await,
204            name => {
205                tracing::debug!(tool = name, "forwarding tool to MCP bridge");
206                let bridge = self.mcp.bridge().await;
207                bridge.call(name, args).await
208            }
209        }
210    }
211}
212
213impl Hook for DaemonHook {
214    fn on_build_agent(&self, config: AgentConfig) -> AgentConfig {
215        let mut config = self.memory.on_build_agent(config);
216
217        // Walrus agent (empty scoping) gets all tools, no scope injection.
218        let has_scoping =
219            !config.skills.is_empty() || !config.mcps.is_empty() || !config.members.is_empty();
220        if !has_scoping {
221            return config;
222        }
223
224        // Compute tool whitelist — base tools always included.
225        let mut whitelist: Vec<CompactString> =
226            BASE_TOOLS.iter().map(|&s| CompactString::from(s)).collect();
227        let mut scope_lines = Vec::new();
228
229        // Skill tools if skills non-empty.
230        if !config.skills.is_empty() {
231            for &t in SKILL_TOOLS {
232                whitelist.push(CompactString::from(t));
233            }
234            scope_lines.push(format!("skills: {}", config.skills.join(", ")));
235        }
236
237        // MCP tools if mcps non-empty.
238        if !config.mcps.is_empty() {
239            for &t in MCP_TOOLS {
240                whitelist.push(CompactString::from(t));
241            }
242            // Also include tools from named MCP servers.
243            let mcp_servers = tokio::task::block_in_place(|| {
244                tokio::runtime::Handle::current().block_on(self.mcp.list())
245            });
246            let mut mcp_info = Vec::new();
247            for (server_name, tool_names) in &mcp_servers {
248                if config.mcps.iter().any(|m| m == server_name.as_str()) {
249                    for tn in tool_names {
250                        whitelist.push(tn.clone());
251                    }
252                    mcp_info.push(format!(
253                        "  - {}: {}",
254                        server_name,
255                        tool_names
256                            .iter()
257                            .map(|t| t.as_str())
258                            .collect::<Vec<_>>()
259                            .join(", ")
260                    ));
261                }
262            }
263            if !mcp_info.is_empty() {
264                scope_lines.push(format!("mcp servers:\n{}", mcp_info.join("\n")));
265            }
266        }
267
268        // Task tools if members non-empty.
269        if !config.members.is_empty() {
270            for &t in TASK_TOOLS {
271                whitelist.push(CompactString::from(t));
272            }
273            scope_lines.push(format!("members: {}", config.members.join(", ")));
274        }
275
276        // Inject scope info into system prompt.
277        if !scope_lines.is_empty() {
278            let scope_block = format!("\n\n<scope>\n{}\n</scope>", scope_lines.join("\n"));
279            config.system_prompt.push_str(&scope_block);
280        }
281
282        config.tools = whitelist;
283        config
284    }
285
286    fn on_compact(&self, prompt: &mut String) {
287        self.memory.on_compact(prompt);
288    }
289
290    async fn on_register_tools(&self, tools: &mut ToolRegistry) {
291        self.memory.on_register_tools(tools).await;
292        self.mcp.on_register_tools(tools).await;
293        tools.insert_all(os::tool::tools());
294        tools.insert_all(skill::tool::tools());
295        tools.insert_all(task::tool::tools());
296    }
297
298    fn on_event(&self, agent: &str, event: &AgentEvent) {
299        match event {
300            AgentEvent::TextDelta(text) => {
301                tracing::trace!(%agent, text_len = text.len(), "agent text delta");
302            }
303            AgentEvent::ThinkingDelta(text) => {
304                tracing::trace!(%agent, text_len = text.len(), "agent thinking delta");
305            }
306            AgentEvent::ToolCallsStart(calls) => {
307                tracing::debug!(%agent, count = calls.len(), "agent tool calls started");
308            }
309            AgentEvent::ToolResult { call_id, .. } => {
310                tracing::debug!(%agent, %call_id, "agent tool result");
311            }
312            AgentEvent::ToolCallsComplete => {
313                tracing::debug!(%agent, "agent tool calls complete");
314            }
315            AgentEvent::Done(response) => {
316                tracing::info!(
317                    %agent,
318                    iterations = response.iterations,
319                    stop_reason = ?response.stop_reason,
320                    "agent run complete"
321                );
322                // Track token usage on the active task for this agent.
323                let (prompt, completion) = response.steps.iter().fold((0u64, 0u64), |(p, c), s| {
324                    (
325                        p + u64::from(s.response.usage.prompt_tokens),
326                        c + u64::from(s.response.usage.completion_tokens),
327                    )
328                });
329                if (prompt > 0 || completion > 0)
330                    && let Ok(mut registry) = self.tasks.try_lock()
331                {
332                    let tid = registry
333                        .list(Some(agent), Some(task::TaskStatus::InProgress), None)
334                        .first()
335                        .map(|t| t.id);
336                    if let Some(tid) = tid {
337                        registry.add_tokens(tid, prompt, completion);
338                    }
339                }
340            }
341        }
342    }
343}