// synaps_cli/extensions/hooks/mod.rs

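//! HookBus — the central dispatcher for extension hooks.
//!
//! The HookBus holds registered handlers and dispatches typed events to them.
//! When no handlers are registered for an event's hook kind, `emit()` takes a
//! no-op fast path (<1µs) and returns `Continue` without calling anything.
//!
//! Handlers may be restricted to a specific tool (matched against either the
//! API tool name or the runtime tool name) and to payloads accepted by a
//! `HookMatcher`; non-matching handlers are skipped before dispatch.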
7
8pub mod events;
9
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
14use tokio::sync::RwLock;
15
16use self::events::{HookEvent, HookKind, HookResult};
17use crate::extensions::manifest::HookMatcher;
18use crate::extensions::permissions::PermissionSet;
19
/// Upper bound on a single hook handler call.
///
/// A handler that has not produced a result within this window is skipped
/// and dispatch moves on to the next handler (fail-open). Nothing in this
/// module overrides the value per extension or per hook.
const HANDLER_TIMEOUT: Duration = Duration::from_secs(5);
22
/// Whether opt-in extension trace logging is enabled.
///
/// Reads `SYNAPS_EXTENSIONS_TRACE` on every call. After trimming surrounding
/// whitespace, the values `1`, `true`, `yes` and `on` (case-insensitive)
/// enable tracing. Any other value, or an unset / non-UTF-8 variable,
/// disables it.
fn extensions_trace_enabled() -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];

    match std::env::var("SYNAPS_EXTENSIONS_TRACE") {
        Ok(raw) => {
            let flag = raw.trim().to_ascii_lowercase();
            TRUTHY.contains(&flag.as_str())
        }
        Err(_) => false,
    }
}
31
32fn hook_result_action(result: &HookResult) -> &'static str {
33    match result {
34        HookResult::Continue => "continue",
35        HookResult::Block { .. } => "block",
36        HookResult::Inject { .. } => "inject",
37        HookResult::Confirm { .. } => "confirm",
38        HookResult::Modify { .. } => "modify",
39    }
40}
41
/// A registered hook handler with its metadata.
///
/// `HookBus::emit` clones the whole handler list for a hook kind before
/// dispatching; the handler itself is shared through its `Arc`.
#[derive(Clone)]
pub struct HandlerRegistration {
    /// The extension handler.
    pub handler: Arc<dyn crate::extensions::runtime::ExtensionHandler>,
    /// Optional tool name filter (None = all tools). When set, it is compared
    /// against both the event's API tool name and its runtime tool name.
    pub tool_filter: Option<String>,
    /// Optional matcher for event payloads; when set, events it rejects are
    /// not delivered to this handler.
    pub matcher: Option<HookMatcher>,
    /// Permissions granted to this handler's extension, checked against the
    /// hook kind at subscription time.
    pub permissions: PermissionSet,
}
54
/// The central hook dispatcher.
///
/// Thread-safe: uses `RwLock` so multiple concurrent emitters can read
/// the handler list, and registration takes a write lock only briefly.
/// `emit` clones the relevant handler list and releases the read lock
/// before invoking any handler, so slow handlers do not block
/// subscription changes.
pub struct HookBus {
    /// Registrations grouped by hook kind. Each list is kept in subscription
    /// order, which is the order `emit` calls handlers in.
    handlers: RwLock<HashMap<HookKind, Vec<HandlerRegistration>>>,
}
62
63impl HookBus {
64    /// Create an empty HookBus with no handlers.
65    pub fn new() -> Self {
66        Self {
67            handlers: RwLock::new(HashMap::new()),
68        }
69    }
70
71    /// Register a handler for a specific hook kind.
72    ///
73    /// Returns an error if the handler's permissions don't allow
74    /// subscribing to this hook kind.
75    pub async fn subscribe(
76        &self,
77        kind: HookKind,
78        handler: Arc<dyn crate::extensions::runtime::ExtensionHandler>,
79        tool_filter: Option<String>,
80        matcher: Option<HookMatcher>,
81        permissions: PermissionSet,
82    ) -> Result<(), String> {
83        // Permission check
84        if !permissions.allows_hook(kind) {
85            return Err(format!(
86                "Extension '{}' lacks permission '{}' required for hook '{}'",
87                handler.id(),
88                kind.required_permission().as_str(),
89                kind.as_str(),
90            ));
91        }
92
93        let reg = HandlerRegistration {
94            handler,
95            tool_filter,
96            matcher,
97            permissions,
98        };
99
100        let mut handlers = self.handlers.write().await;
101        handlers.entry(kind).or_default().push(reg);
102        Ok(())
103    }
104
105    /// Emit a hook event to all registered handlers.
106    ///
107    /// Returns the first `Block` result if any handler blocks, otherwise
108    /// returns `Continue`. Handlers are called in registration order.
109    ///
110    /// If no handlers are registered for this hook, returns immediately
111    /// (the no-extensions fast path).
112    pub async fn emit(&self, event: &HookEvent) -> HookResult {
113        // Snapshot the handler list and drop the lock immediately.
114        // This prevents holding the RwLock across async handler calls
115        // (which could block subscribe/unsubscribe for the entire
116        // duration of IPC round-trips to extension processes).
117        let registrations = {
118            let handlers = self.handlers.read().await;
119            match handlers.get(&event.kind) {
120                Some(regs) if !regs.is_empty() => regs.clone(),
121                _ => return HookResult::Continue, // fast path: no handlers
122            }
123        }; // lock dropped here
124
125        // Collect injections from all handlers rather than returning on first
126        let mut injections: Vec<String> = Vec::new();
127
128        for reg in &registrations {
129            // Tool-specific filter: skip handlers that don't match.
130            // Check both API name and runtime name so MCP tools with
131            // sanitized names (slashes→underscores) still match.
132            if let Some(ref filter) = reg.tool_filter {
133                let matches = match (&event.tool_name, &event.tool_runtime_name) {
134                    (Some(api), Some(runtime)) => filter == api || filter == runtime,
135                    (Some(api), None) => filter == api,
136                    (None, Some(runtime)) => filter == runtime,
137                    (None, None) => false,
138                };
139                if !matches {
140                    continue;
141                }
142            }
143
144            if let Some(ref matcher) = reg.matcher {
145                if !matcher.matches(event) {
146                    continue;
147                }
148            }
149
150            // Call handler with timeout
151            let handler = reg.handler.clone();
152            let event_clone = event.clone();
153            let trace_enabled = extensions_trace_enabled();
154            let started_at = trace_enabled.then(Instant::now);
155            let result = tokio::time::timeout(
156                HANDLER_TIMEOUT,
157                handler.handle(&event_clone),
158            )
159            .await;
160
161            if trace_enabled {
162                let health = reg.handler.health().await;
163                let health = health.as_str();
164                let restart_count = reg.handler.restart_count().await;
165                let duration_ms = started_at
166                    .map(|start| start.elapsed().as_millis() as u64)
167                    .unwrap_or(0);
168                match &result {
169                    Ok(hook_result) => {
170                        let action = hook_result_action(hook_result);
171                        tracing::info!(
172                            extension_trace = true,
173                            hook = %event.kind.as_str(),
174                            extension = %reg.handler.id(),
175                            action = action,
176                            duration_ms = duration_ms,
177                            health = health,
178                            restart_count = restart_count,
179                            "Extension hook trace"
180                        );
181                    }
182                    Err(_) => {
183                        tracing::warn!(
184                            extension_trace = true,
185                            hook = %event.kind.as_str(),
186                            extension = %reg.handler.id(),
187                            action = "timeout",
188                            duration_ms = duration_ms,
189                            timeout_secs = HANDLER_TIMEOUT.as_secs(),
190                            health = health,
191                            restart_count = restart_count,
192                            "Extension hook trace"
193                        );
194                    }
195                }
196            }
197
198            match result {
199                Ok(result) if !event.kind.allows_result(&result) => {
200                    tracing::warn!(
201                        hook = %event.kind.as_str(),
202                        extension = %reg.handler.id(),
203                        action = hook_result_action(&result),
204                        "Extension returned action not allowed for hook — ignoring"
205                    );
206                    continue;
207                }
208                Ok(HookResult::Block { reason }) => {
209                    tracing::info!(
210                        hook = %event.kind.as_str(),
211                        extension = %reg.handler.id(),
212                        reason = %reason,
213                        "Hook blocked by extension"
214                    );
215                    return HookResult::Block { reason };
216                }
217                Ok(HookResult::Continue) => {}
218                Ok(HookResult::Inject { content }) => {
219                    tracing::debug!(
220                        hook = %event.kind.as_str(),
221                        extension = %reg.handler.id(),
222                        len = content.len(),
223                        "Extension injected context"
224                    );
225                    // Accumulate — don't early-return. Multiple extensions can inject.
226                    injections.push(content);
227                }
228                Ok(HookResult::Modify { input }) => {
229                    tracing::info!(
230                        hook = %event.kind.as_str(),
231                        extension = %reg.handler.id(),
232                        "Hook modified tool input by extension"
233                    );
234                    return HookResult::Modify { input };
235                }
236                Ok(HookResult::Confirm { message }) => {
237                    tracing::info!(
238                        hook = %event.kind.as_str(),
239                        extension = %reg.handler.id(),
240                        "Hook requested confirmation by extension"
241                    );
242                    return HookResult::Confirm { message };
243                }
244                Err(_timeout) => {
245                    tracing::warn!(
246                        hook = %event.kind.as_str(),
247                        extension = %reg.handler.id(),
248                        timeout_secs = HANDLER_TIMEOUT.as_secs(),
249                        "Hook handler timed out — skipping"
250                    );
251                    // Fail-open: timeout = continue
252                }
253            }
254        }
255
256        // Merge accumulated injections from all handlers
257        if !injections.is_empty() {
258            HookResult::Inject {
259                content: injections.join("\n\n"),
260            }
261        } else {
262            HookResult::Continue
263        }
264    }
265
266    /// Remove all handlers for a given extension ID.
267    pub async fn unsubscribe_all(&self, extension_id: &str) {
268        let mut handlers = self.handlers.write().await;
269        for regs in handlers.values_mut() {
270            regs.retain(|r| r.handler.id() != extension_id);
271        }
272    }
273
274    /// Number of registered handlers across all hooks.
275    pub async fn handler_count(&self) -> usize {
276        let handlers = self.handlers.read().await;
277        handlers.values().map(|v| v.len()).sum()
278    }
279
280    /// Check if any handlers are registered (for fast-path decisions).
281    pub async fn is_empty(&self) -> bool {
282        let handlers = self.handlers.read().await;
283        handlers.values().all(|v| v.is_empty())
284    }
285
286    /// Return all (kind, tool_filter) pairs subscribed by the given extension id.
287    /// Sorted by kind name, then by tool_filter (None first), for stable output.
288    pub async fn subscriptions_for(&self, extension_id: &str) -> Vec<(HookKind, Option<String>)> {
289        let handlers = self.handlers.read().await;
290        let mut out: Vec<(HookKind, Option<String>)> = Vec::new();
291        for (kind, regs) in handlers.iter() {
292            for reg in regs {
293                if reg.handler.id() == extension_id {
294                    out.push((*kind, reg.tool_filter.clone()));
295                }
296            }
297        }
298        out.sort_by(|a, b| {
299            a.0.as_str()
300                .cmp(b.0.as_str())
301                .then_with(|| a.1.cmp(&b.1))
302        });
303        out
304    }
305}
306
307impl Default for HookBus {
308    fn default() -> Self {
309        Self::new()
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::hooks::events::HookEvent;
    use crate::extensions::permissions::Permission;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test handler that counts calls and returns a configurable result.
    struct TestHandler {
        id: String,
        // Number of times `handle` has been invoked.
        call_count: AtomicUsize,
        // Cloned and returned from every `handle` call.
        result: HookResult,
    }

    impl TestHandler {
        /// Build a shared handler with the given extension id and fixed result.
        fn new(id: &str, result: HookResult) -> Arc<Self> {
            Arc::new(Self {
                id: id.to_string(),
                call_count: AtomicUsize::new(0),
                result,
            })
        }

        /// How many times the bus has called `handle` on this handler.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::Relaxed)
        }
    }

    // Only `id`, `handle` and `shutdown` are implemented here; any other
    // `ExtensionHandler` methods used by `emit` (e.g. in the trace path)
    // presumably have trait defaults — confirm in `runtime`.
    #[async_trait]
    impl crate::extensions::runtime::ExtensionHandler for TestHandler {
        fn id(&self) -> &str {
            &self.id
        }

        async fn handle(&self, _event: &HookEvent) -> HookResult {
            self.call_count.fetch_add(1, Ordering::Relaxed);
            self.result.clone()
        }

        async fn shutdown(&self) {}
    }

    /// Build a `PermissionSet` granting exactly the listed permissions.
    fn perms_with(perms: &[Permission]) -> PermissionSet {
        let mut set = PermissionSet::new();
        for p in perms {
            set.grant(*p);
        }
        set
    }

    /// Truthy values enable tracing; falsy/empty values do not.
    ///
    /// NOTE(review): this mutates process-wide environment while other tests
    /// may run concurrently. In `emit`, only the trace-logging path reads this
    /// flag, so concurrent tests see extra logging at most.
    #[test]
    fn trace_env_value_parser_accepts_common_truthy_values() {
        for value in ["1", "true", "TRUE", "yes", "on"] {
            std::env::set_var("SYNAPS_EXTENSIONS_TRACE", value);
            assert!(extensions_trace_enabled(), "{value} should enable trace mode");
        }

        for value in ["", "0", "false", "off", "no"] {
            std::env::set_var("SYNAPS_EXTENSIONS_TRACE", value);
            assert!(!extensions_trace_enabled(), "{value:?} should not enable trace mode");
        }
        std::env::remove_var("SYNAPS_EXTENSIONS_TRACE");
    }

    /// A payload matcher gates delivery: the blocking handler only fires when
    /// the tool input contains the matcher's substring.
    #[tokio::test]
    async fn matcher_skips_handler_when_input_does_not_contain_value() {
        let bus = HookBus::new();
        let handler = TestHandler::new("matcher", HookResult::Block { reason: "matched".into() });
        let mut perms = PermissionSet::new();
        perms.grant(Permission::ToolsIntercept);
        bus.subscribe(
            HookKind::BeforeToolCall,
            handler.clone(),
            None,
            Some(HookMatcher {
                input_contains: Some("danger".to_string()),
                input_equals: None,
            }),
            perms,
        ).await.unwrap();

        // Input lacks "danger": handler skipped, so the bus falls through to Continue.
        let safe = HookEvent::before_tool_call("bash", serde_json::json!({"command": "echo safe"}));
        assert!(matches!(bus.emit(&safe).await, HookResult::Continue));

        // Input contains "danger": handler runs and blocks.
        let danger = HookEvent::before_tool_call("bash", serde_json::json!({"command": "echo danger"}));
        assert!(matches!(bus.emit(&danger).await, HookResult::Block { .. }));
    }

    /// Pins the exact action labels emitted in trace/warn log fields.
    #[test]
    fn hook_result_action_names_are_stable_for_trace_logs() {
        assert_eq!(hook_result_action(&HookResult::Continue), "continue");
        assert_eq!(
            hook_result_action(&HookResult::Block {
                reason: "stop".into(),
            }),
            "block"
        );
        assert_eq!(
            hook_result_action(&HookResult::Inject {
                content: "context".into(),
            }),
            "inject"
        );
        assert_eq!(
            hook_result_action(&HookResult::Confirm {
                message: "Proceed?".into(),
            }),
            "confirm"
        );
        assert_eq!(
            hook_result_action(&HookResult::Modify {
                input: serde_json::json!({"command": "echo safe"}),
            }),
            "modify"
        );
    }

    /// With no subscriptions, `emit` takes the fast path and returns Continue.
    #[tokio::test]
    async fn empty_bus_returns_continue() {
        let bus = HookBus::new();
        let event = HookEvent::before_tool_call("bash", serde_json::json!({}));
        let result = bus.emit(&event).await;
        assert!(matches!(result, HookResult::Continue));
    }

    /// An unfiltered subscriber is called exactly once per emitted event.
    #[tokio::test]
    async fn handler_receives_events() {
        let bus = HookBus::new();
        let handler = TestHandler::new("test-ext", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(HookKind::BeforeToolCall, handler.clone(), None, None, perms)
            .await
            .unwrap();

        let event = HookEvent::before_tool_call("bash", serde_json::json!({"command": "ls"}));
        bus.emit(&event).await;

        assert_eq!(handler.calls(), 1);
    }

    /// `Confirm` on a tool hook short-circuits: later handlers never run.
    #[tokio::test]
    async fn confirm_stops_chain_for_before_tool_call() {
        let bus = HookBus::new();
        let confirmer = TestHandler::new("confirmer", HookResult::Confirm {
            message: "Run this command?".into(),
        });
        let after = TestHandler::new("after", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(HookKind::BeforeToolCall, confirmer.clone(), None, None, perms.clone())
            .await
            .unwrap();
        bus.subscribe(HookKind::BeforeToolCall, after.clone(), None, None, perms)
            .await
            .unwrap();

        let event = HookEvent::before_tool_call("bash", serde_json::json!({}));
        let result = bus.emit(&event).await;

        assert!(matches!(result, HookResult::Confirm { .. }));
        assert_eq!(confirmer.calls(), 1);
        assert_eq!(after.calls(), 0);
    }

    /// `Confirm` from a `BeforeMessage` handler is expected to be rejected by
    /// `allows_result` and ignored. The handler still runs, but the bus
    /// reports Continue.
    #[tokio::test]
    async fn confirm_is_ignored_for_non_tool_hooks() {
        let bus = HookBus::new();
        let confirmer = TestHandler::new("confirmer", HookResult::Confirm {
            message: "Not allowed here".into(),
        });
        let perms = perms_with(&[Permission::LlmContent]);

        bus.subscribe(HookKind::BeforeMessage, confirmer.clone(), None, None, perms)
            .await
            .unwrap();

        let event = HookEvent::before_message("hello");
        let result = bus.emit(&event).await;

        assert!(matches!(result, HookResult::Continue));
        assert_eq!(confirmer.calls(), 1);
    }

    /// `Block` short-circuits: handlers registered after the blocker never run.
    #[tokio::test]
    async fn block_stops_chain() {
        let bus = HookBus::new();
        let blocker = TestHandler::new("blocker", HookResult::Block {
            reason: "dangerous".into(),
        });
        let after = TestHandler::new("after", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(HookKind::BeforeToolCall, blocker.clone(), None, None, perms.clone())
            .await
            .unwrap();
        bus.subscribe(HookKind::BeforeToolCall, after.clone(), None, None, perms)
            .await
            .unwrap();

        let event = HookEvent::before_tool_call("bash", serde_json::json!({}));
        let result = bus.emit(&event).await;

        assert!(matches!(result, HookResult::Block { .. }));
        assert_eq!(blocker.calls(), 1);
        assert_eq!(after.calls(), 0); // never reached
    }

    /// `Modify` short-circuits with the replacement input. A later handler
    /// that would block is never consulted.
    #[tokio::test]
    async fn modify_stops_chain() {
        let bus = HookBus::new();
        let modifier = TestHandler::new("modifier", HookResult::Modify {
            input: serde_json::json!({"command": "echo safe"}),
        });
        let after = TestHandler::new("after", HookResult::Block {
            reason: "should not run".into(),
        });
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(HookKind::BeforeToolCall, modifier.clone(), None, None, perms.clone())
            .await
            .unwrap();
        bus.subscribe(HookKind::BeforeToolCall, after.clone(), None, None, perms)
            .await
            .unwrap();

        let event = HookEvent::before_tool_call("bash", serde_json::json!({"command": "rm -rf /"}));
        let result = bus.emit(&event).await;

        assert!(matches!(result, HookResult::Modify { input } if input == serde_json::json!({"command": "echo safe"})));
        assert_eq!(modifier.calls(), 1);
        assert_eq!(after.calls(), 0); // never reached
    }

    /// A tool filter restricts delivery to events for that tool name.
    #[tokio::test]
    async fn tool_filter_only_matches_specified_tool() {
        let bus = HookBus::new();
        let handler = TestHandler::new("bash-only", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(
            HookKind::AfterToolCall,
            handler.clone(),
            Some("bash".into()),
            None,
            perms,
        )
        .await
        .unwrap();

        // Should NOT fire for 'read' tool
        let event = HookEvent::after_tool_call("read", serde_json::json!({}), "content".into());
        bus.emit(&event).await;
        assert_eq!(handler.calls(), 0);

        // SHOULD fire for 'bash' tool
        let event = HookEvent::after_tool_call("bash", serde_json::json!({}), "output".into());
        bus.emit(&event).await;
        assert_eq!(handler.calls(), 1);
    }

    /// Subscribing without the hook's required permission fails with a
    /// "lacks permission" message.
    #[tokio::test]
    async fn permission_denied_rejects_subscribe() {
        let bus = HookBus::new();
        let handler = TestHandler::new("no-perms", HookResult::Continue);
        let perms = PermissionSet::new(); // empty — no permissions

        let result = bus
            .subscribe(HookKind::BeforeToolCall, handler, None, None, perms)
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("lacks permission"));
    }

    /// `unsubscribe_all` removes every registration for the extension id.
    #[tokio::test]
    async fn unsubscribe_removes_handlers() {
        let bus = HookBus::new();
        let handler = TestHandler::new("removable", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(HookKind::BeforeToolCall, handler.clone(), None, None, perms)
            .await
            .unwrap();
        assert_eq!(bus.handler_count().await, 1);

        bus.unsubscribe_all("removable").await;
        assert_eq!(bus.handler_count().await, 0);
    }

    /// `subscriptions_for` returns only the given extension's subscriptions,
    /// sorted by kind name then tool filter.
    #[tokio::test]
    async fn subscriptions_for_lists_only_matching_extension() {
        let bus = HookBus::new();
        let alpha = TestHandler::new("alpha", HookResult::Continue);
        let beta = TestHandler::new("beta", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(HookKind::BeforeToolCall, alpha.clone(), Some("bash".into()), None, perms.clone())
            .await
            .unwrap();
        bus.subscribe(HookKind::AfterToolCall, alpha.clone(), None, None, perms.clone())
            .await
            .unwrap();
        bus.subscribe(HookKind::BeforeToolCall, beta.clone(), None, None, perms)
            .await
            .unwrap();

        let alpha_subs = bus.subscriptions_for("alpha").await;
        assert_eq!(alpha_subs.len(), 2);
        // sorted by kind name then by tool_filter (None first)
        assert_eq!(alpha_subs[0].0, HookKind::AfterToolCall);
        assert_eq!(alpha_subs[0].1, None);
        assert_eq!(alpha_subs[1].0, HookKind::BeforeToolCall);
        assert_eq!(alpha_subs[1].1, Some("bash".to_string()));

        let beta_subs = bus.subscriptions_for("beta").await;
        assert_eq!(beta_subs, vec![(HookKind::BeforeToolCall, None)]);

        // Unknown extension ids yield an empty list rather than an error.
        let none_subs = bus.subscriptions_for("ghost").await;
        assert!(none_subs.is_empty());
    }

    /// `is_empty` flips from true to false after the first subscription.
    #[tokio::test]
    async fn is_empty_reflects_state() {
        let bus = HookBus::new();
        assert!(bus.is_empty().await);

        let handler = TestHandler::new("ext", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);
        bus.subscribe(HookKind::BeforeToolCall, handler, None, None, perms)
            .await
            .unwrap();
        assert!(!bus.is_empty().await);
    }
}