1pub mod events;
9
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
14use futures::future::join_all;
15use tokio::sync::RwLock;
16
17use self::events::{HookEvent, HookKind, HookResult};
18use crate::extensions::manifest::HookMatcher;
19use crate::extensions::permissions::PermissionSet;
20
/// Upper bound on how long a single extension handler may take to answer one
/// hook event; the bus logs a timeout and moves on once it elapses.
const HANDLER_TIMEOUT: Duration = Duration::from_millis(5_000);
23
/// Reports whether per-handler trace logging is switched on.
///
/// Reads `SYNAPS_EXTENSIONS_TRACE` on every call. Leading and trailing
/// whitespace is ignored, and `1`, `true`, `yes` and `on` (any case) enable
/// tracing. Any other value, an unset variable, or a non-UTF-8 value disables it.
fn extensions_trace_enabled() -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    match std::env::var("SYNAPS_EXTENSIONS_TRACE") {
        Ok(raw) => {
            let flag = raw.trim().to_ascii_lowercase();
            TRUTHY.contains(&flag.as_str())
        }
        Err(_) => false,
    }
}
32
33fn hook_result_action(result: &HookResult) -> &'static str {
34 match result {
35 HookResult::Continue => "continue",
36 HookResult::Block { .. } => "block",
37 HookResult::Inject { .. } => "inject",
38 HookResult::Confirm { .. } => "confirm",
39 HookResult::Modify { .. } => "modify",
40 }
41}
42
/// One extension handler's subscription to a single hook kind, as stored by
/// [`HookBus`].
#[derive(Clone)]
pub struct HandlerRegistration {
    /// The extension handler that receives matching events.
    pub handler: Arc<dyn crate::extensions::runtime::ExtensionHandler>,
    /// When set, the handler only sees events whose API tool name or runtime
    /// tool name equals this value; events carrying neither name are skipped.
    pub tool_filter: Option<String>,
    /// Optional content predicate; events it does not match are skipped.
    pub matcher: Option<HookMatcher>,
    /// Permission set passed to `HookBus::subscribe` (checked there against
    /// the hook kind's required permission).
    pub permissions: PermissionSet,
}
55
/// Dispatches hook events to the extension handlers subscribed to them.
pub struct HookBus {
    /// Registrations grouped by hook kind. New subscriptions are appended, so
    /// each list is in subscription order, which is also the order `emit` runs them.
    handlers: RwLock<HashMap<HookKind, Vec<HandlerRegistration>>>,
}
63
64impl HookBus {
65 pub fn new() -> Self {
67 Self {
68 handlers: RwLock::new(HashMap::new()),
69 }
70 }
71
72 pub async fn subscribe(
77 &self,
78 kind: HookKind,
79 handler: Arc<dyn crate::extensions::runtime::ExtensionHandler>,
80 tool_filter: Option<String>,
81 matcher: Option<HookMatcher>,
82 permissions: PermissionSet,
83 ) -> Result<(), String> {
84 if !permissions.allows_hook(kind) {
86 return Err(format!(
87 "Extension '{}' lacks permission '{}' required for hook '{}'",
88 handler.id(),
89 kind.required_permission().as_str(),
90 kind.as_str(),
91 ));
92 }
93
94 let reg = HandlerRegistration {
95 handler,
96 tool_filter,
97 matcher,
98 permissions,
99 };
100
101 let mut handlers = self.handlers.write().await;
102 handlers.entry(kind).or_default().push(reg);
103 Ok(())
104 }
105
106 pub async fn emit(&self, event: &HookEvent) -> HookResult {
114 let registrations = {
119 let handlers = self.handlers.read().await;
120 match handlers.get(&event.kind) {
121 Some(regs) if !regs.is_empty() => regs.clone(),
122 _ => return HookResult::Continue, }
124 }; let mut injections: Vec<String> = Vec::new();
128
129 for reg in ®istrations {
130 if let Some(ref filter) = reg.tool_filter {
134 let matches = match (&event.tool_name, &event.tool_runtime_name) {
135 (Some(api), Some(runtime)) => filter == api || filter == runtime,
136 (Some(api), None) => filter == api,
137 (None, Some(runtime)) => filter == runtime,
138 (None, None) => false,
139 };
140 if !matches {
141 continue;
142 }
143 }
144
145 if let Some(ref matcher) = reg.matcher {
146 if !matcher.matches(event) {
147 continue;
148 }
149 }
150
151 let handler = reg.handler.clone();
153 let event_clone = event.clone();
154 let trace_enabled = extensions_trace_enabled();
155 let started_at = trace_enabled.then(Instant::now);
156 let result = tokio::time::timeout(
157 HANDLER_TIMEOUT,
158 handler.handle(&event_clone),
159 )
160 .await;
161
162 if trace_enabled {
163 let health = reg.handler.health().await;
164 let health = health.as_str();
165 let restart_count = reg.handler.restart_count().await;
166 let duration_ms = started_at
167 .map(|start| start.elapsed().as_millis() as u64)
168 .unwrap_or(0);
169 match &result {
170 Ok(hook_result) => {
171 let action = hook_result_action(hook_result);
172 tracing::info!(
173 extension_trace = true,
174 hook = %event.kind.as_str(),
175 extension = %reg.handler.id(),
176 action = action,
177 duration_ms = duration_ms,
178 health = health,
179 restart_count = restart_count,
180 "Extension hook trace"
181 );
182 }
183 Err(_) => {
184 tracing::warn!(
185 extension_trace = true,
186 hook = %event.kind.as_str(),
187 extension = %reg.handler.id(),
188 action = "timeout",
189 duration_ms = duration_ms,
190 timeout_secs = HANDLER_TIMEOUT.as_secs(),
191 health = health,
192 restart_count = restart_count,
193 "Extension hook trace"
194 );
195 }
196 }
197 }
198
199 match result {
200 Ok(result) if !event.kind.allows_result(&result) => {
201 tracing::warn!(
202 hook = %event.kind.as_str(),
203 extension = %reg.handler.id(),
204 action = hook_result_action(&result),
205 "Extension returned action not allowed for hook — ignoring"
206 );
207 continue;
208 }
209 Ok(HookResult::Block { reason }) => {
210 tracing::info!(
211 hook = %event.kind.as_str(),
212 extension = %reg.handler.id(),
213 reason = %reason,
214 "Hook blocked by extension"
215 );
216 return HookResult::Block { reason };
217 }
218 Ok(HookResult::Continue) => {}
219 Ok(HookResult::Inject { content }) => {
220 tracing::debug!(
221 hook = %event.kind.as_str(),
222 extension = %reg.handler.id(),
223 len = content.len(),
224 "Extension injected context"
225 );
226 injections.push(content);
228 }
229 Ok(HookResult::Modify { input }) => {
230 tracing::info!(
231 hook = %event.kind.as_str(),
232 extension = %reg.handler.id(),
233 "Hook modified tool input by extension"
234 );
235 return HookResult::Modify { input };
236 }
237 Ok(HookResult::Confirm { message }) => {
238 tracing::info!(
239 hook = %event.kind.as_str(),
240 extension = %reg.handler.id(),
241 "Hook requested confirmation by extension"
242 );
243 return HookResult::Confirm { message };
244 }
245 Err(_timeout) => {
246 tracing::warn!(
247 hook = %event.kind.as_str(),
248 extension = %reg.handler.id(),
249 timeout_secs = HANDLER_TIMEOUT.as_secs(),
250 "Hook handler timed out — skipping"
251 );
252 }
254 }
255 }
256
257 if !injections.is_empty() {
259 HookResult::Inject {
260 content: injections.join("\n\n"),
261 }
262 } else {
263 HookResult::Continue
264 }
265 }
266
267 pub async fn emit_concurrent(&self, event: &HookEvent) -> HookResult {
289 let registrations = {
291 let handlers = self.handlers.read().await;
292 match handlers.get(&event.kind) {
293 Some(regs) if !regs.is_empty() => regs.clone(),
294 _ => return HookResult::Continue, }
296 };
297
298 let futures: Vec<_> = registrations
300 .iter()
301 .filter(|reg| {
302 if let Some(ref filter) = reg.tool_filter {
304 match (&event.tool_name, &event.tool_runtime_name) {
305 (Some(api), Some(runtime)) => filter == api || filter == runtime,
306 (Some(api), None) => filter == api,
307 (None, Some(runtime)) => filter == runtime,
308 (None, None) => false,
309 }
310 } else {
311 true
312 }
313 })
314 .filter(|reg| {
315 reg.matcher.as_ref().map_or(true, |m| m.matches(event))
316 })
317 .map(|reg| {
318 let handler = reg.handler.clone();
319 let event_clone = event.clone();
320 async move {
321 tokio::time::timeout(HANDLER_TIMEOUT, handler.handle(&event_clone)).await
322 }
323 })
324 .collect();
325
326 let results = join_all(futures).await;
327
328 let mut injections: Vec<String> = Vec::new();
329 for result in results {
330 match result {
331 Ok(HookResult::Continue) => {}
332 Ok(HookResult::Block { reason }) => {
333 return HookResult::Block { reason };
334 }
335 Ok(HookResult::Inject { content }) => {
336 injections.push(content);
337 }
338 Ok(HookResult::Modify { input }) => {
339 return HookResult::Modify { input };
340 }
341 Ok(HookResult::Confirm { message }) => {
342 return HookResult::Confirm { message };
343 }
344 Err(_timeout) => {
345 tracing::warn!(
346 hook = %event.kind.as_str(),
347 timeout_secs = HANDLER_TIMEOUT.as_secs(),
348 "Hook handler timed out in concurrent emit — skipping"
349 );
350 }
351 }
352 }
353
354 if !injections.is_empty() {
355 HookResult::Inject {
356 content: injections.join("\n\n"),
357 }
358 } else {
359 HookResult::Continue
360 }
361 }
362
363 pub async fn unsubscribe_all(&self, extension_id: &str) {
365 let mut handlers = self.handlers.write().await;
366 for regs in handlers.values_mut() {
367 regs.retain(|r| r.handler.id() != extension_id);
368 }
369 }
370
371 pub async fn handler_count(&self) -> usize {
373 let handlers = self.handlers.read().await;
374 handlers.values().map(|v| v.len()).sum()
375 }
376
377 pub async fn is_empty(&self) -> bool {
379 let handlers = self.handlers.read().await;
380 handlers.values().all(|v| v.is_empty())
381 }
382
383 pub async fn subscriptions_for(&self, extension_id: &str) -> Vec<(HookKind, Option<String>)> {
386 let handlers = self.handlers.read().await;
387 let mut out: Vec<(HookKind, Option<String>)> = Vec::new();
388 for (kind, regs) in handlers.iter() {
389 for reg in regs {
390 if reg.handler.id() == extension_id {
391 out.push((*kind, reg.tool_filter.clone()));
392 }
393 }
394 }
395 out.sort_by(|a, b| {
396 a.0.as_str()
397 .cmp(b.0.as_str())
398 .then_with(|| a.1.cmp(&b.1))
399 });
400 out
401 }
402}
403
404impl Default for HookBus {
405 fn default() -> Self {
406 Self::new()
407 }
408}
409
/// Unit tests for the hook bus: subscription permissions, filter/matcher
/// routing, chain-stopping results and the trace helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::hooks::events::HookEvent;
    use crate::extensions::permissions::Permission;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Handler stub that always returns a fixed `result` and counts how many
    /// times `handle` was called.
    struct TestHandler {
        id: String,
        call_count: AtomicUsize,
        result: HookResult,
    }

    impl TestHandler {
        fn new(id: &str, result: HookResult) -> Arc<Self> {
            Arc::new(Self {
                id: id.to_string(),
                call_count: AtomicUsize::new(0),
                result,
            })
        }

        /// Number of `handle` calls seen so far.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::Relaxed)
        }
    }

    #[async_trait]
    impl crate::extensions::runtime::ExtensionHandler for TestHandler {
        fn id(&self) -> &str {
            &self.id
        }

        async fn handle(&self, _event: &HookEvent) -> HookResult {
            self.call_count.fetch_add(1, Ordering::Relaxed);
            self.result.clone()
        }

        async fn shutdown(&self) {}
    }

    /// Builds a `PermissionSet` granting exactly `perms`.
    fn perms_with(perms: &[Permission]) -> PermissionSet {
        let mut set = PermissionSet::new();
        for p in perms {
            set.grant(*p);
        }
        set
    }

    // NOTE(review): this test mutates the process-wide environment while other
    // tests (run on parallel threads by default) call `emit`, which reads the
    // same variable. At worst this briefly enables trace logging elsewhere, but
    // concurrent set_var/getenv is unsound on some platforms — consider
    // serializing env-touching tests.
    #[test]
    fn trace_env_value_parser_accepts_common_truthy_values() {
        for value in ["1", "true", "TRUE", "yes", "on"] {
            std::env::set_var("SYNAPS_EXTENSIONS_TRACE", value);
            assert!(extensions_trace_enabled(), "{value} should enable trace mode");
        }

        for value in ["", "0", "false", "off", "no"] {
            std::env::set_var("SYNAPS_EXTENSIONS_TRACE", value);
            assert!(!extensions_trace_enabled(), "{value:?} should not enable trace mode");
        }
        std::env::remove_var("SYNAPS_EXTENSIONS_TRACE");
    }

    #[tokio::test]
    async fn matcher_skips_handler_when_input_does_not_contain_value() {
        let bus = HookBus::new();
        let handler = TestHandler::new("matcher", HookResult::Block { reason: "matched".into() });
        let mut perms = PermissionSet::new();
        perms.grant(Permission::ToolsIntercept);
        bus.subscribe(
            HookKind::BeforeToolCall,
            handler.clone(),
            None,
            Some(HookMatcher {
                input_contains: Some("danger".to_string()),
                input_equals: None,
            }),
            perms,
        ).await.unwrap();

        // The matcher rejects this input, so the blocking handler is skipped.
        let safe = HookEvent::before_tool_call("bash", serde_json::json!({"command": "echo safe"}));
        assert!(matches!(bus.emit(&safe).await, HookResult::Continue));

        let danger = HookEvent::before_tool_call("bash", serde_json::json!({"command": "echo danger"}));
        assert!(matches!(bus.emit(&danger).await, HookResult::Block { .. }));
    }

    #[test]
    fn hook_result_action_names_are_stable_for_trace_logs() {
        assert_eq!(hook_result_action(&HookResult::Continue), "continue");
        assert_eq!(
            hook_result_action(&HookResult::Block {
                reason: "stop".into(),
            }),
            "block"
        );
        assert_eq!(
            hook_result_action(&HookResult::Inject {
                content: "context".into(),
            }),
            "inject"
        );
        assert_eq!(
            hook_result_action(&HookResult::Confirm {
                message: "Proceed?".into(),
            }),
            "confirm"
        );
        assert_eq!(
            hook_result_action(&HookResult::Modify {
                input: serde_json::json!({"command": "echo safe"}),
            }),
            "modify"
        );
    }

    #[tokio::test]
    async fn empty_bus_returns_continue() {
        let bus = HookBus::new();
        let event = HookEvent::before_tool_call("bash", serde_json::json!({}));
        let result = bus.emit(&event).await;
        assert!(matches!(result, HookResult::Continue));
    }

    #[tokio::test]
    async fn handler_receives_events() {
        let bus = HookBus::new();
        let handler = TestHandler::new("test-ext", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(HookKind::BeforeToolCall, handler.clone(), None, None, perms)
            .await
            .unwrap();

        let event = HookEvent::before_tool_call("bash", serde_json::json!({"command": "ls"}));
        bus.emit(&event).await;

        assert_eq!(handler.calls(), 1);
    }

    #[tokio::test]
    async fn confirm_stops_chain_for_before_tool_call() {
        let bus = HookBus::new();
        let confirmer = TestHandler::new("confirmer", HookResult::Confirm {
            message: "Run this command?".into(),
        });
        let after = TestHandler::new("after", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(HookKind::BeforeToolCall, confirmer.clone(), None, None, perms.clone())
            .await
            .unwrap();
        bus.subscribe(HookKind::BeforeToolCall, after.clone(), None, None, perms)
            .await
            .unwrap();

        let event = HookEvent::before_tool_call("bash", serde_json::json!({}));
        let result = bus.emit(&event).await;

        // Sequential emit returns on the first Confirm; the later handler never runs.
        assert!(matches!(result, HookResult::Confirm { .. }));
        assert_eq!(confirmer.calls(), 1);
        assert_eq!(after.calls(), 0);
    }

    #[tokio::test]
    async fn confirm_is_ignored_for_non_tool_hooks() {
        let bus = HookBus::new();
        let confirmer = TestHandler::new("confirmer", HookResult::Confirm {
            message: "Not allowed here".into(),
        });
        let perms = perms_with(&[Permission::LlmContent]);

        bus.subscribe(HookKind::BeforeMessage, confirmer.clone(), None, None, perms)
            .await
            .unwrap();

        let event = HookEvent::before_message("hello");
        let result = bus.emit(&event).await;

        // The handler is still called, but its disallowed Confirm is dropped.
        assert!(matches!(result, HookResult::Continue));
        assert_eq!(confirmer.calls(), 1);
    }

    #[tokio::test]
    async fn block_stops_chain() {
        let bus = HookBus::new();
        let blocker = TestHandler::new("blocker", HookResult::Block {
            reason: "dangerous".into(),
        });
        let after = TestHandler::new("after", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(HookKind::BeforeToolCall, blocker.clone(), None, None, perms.clone())
            .await
            .unwrap();
        bus.subscribe(HookKind::BeforeToolCall, after.clone(), None, None, perms)
            .await
            .unwrap();

        let event = HookEvent::before_tool_call("bash", serde_json::json!({}));
        let result = bus.emit(&event).await;

        assert!(matches!(result, HookResult::Block { .. }));
        assert_eq!(blocker.calls(), 1);
        assert_eq!(after.calls(), 0);
    }

    #[tokio::test]
    async fn modify_stops_chain() {
        let bus = HookBus::new();
        let modifier = TestHandler::new("modifier", HookResult::Modify {
            input: serde_json::json!({"command": "echo safe"}),
        });
        let after = TestHandler::new("after", HookResult::Block {
            reason: "should not run".into(),
        });
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(HookKind::BeforeToolCall, modifier.clone(), None, None, perms.clone())
            .await
            .unwrap();
        bus.subscribe(HookKind::BeforeToolCall, after.clone(), None, None, perms)
            .await
            .unwrap();

        let event = HookEvent::before_tool_call("bash", serde_json::json!({"command": "rm -rf /"}));
        let result = bus.emit(&event).await;

        // The modified input is returned verbatim, and the blocking handler after it never runs.
        assert!(matches!(result, HookResult::Modify { input } if input == serde_json::json!({"command": "echo safe"})));
        assert_eq!(modifier.calls(), 1);
        assert_eq!(after.calls(), 0);
    }

    #[tokio::test]
    async fn tool_filter_only_matches_specified_tool() {
        let bus = HookBus::new();
        let handler = TestHandler::new("bash-only", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(
            HookKind::AfterToolCall,
            handler.clone(),
            Some("bash".into()),
            None,
            perms,
        )
        .await
        .unwrap();

        // A different tool name: the filter excludes the handler.
        let event = HookEvent::after_tool_call("read", serde_json::json!({}), "content".into());
        bus.emit(&event).await;
        assert_eq!(handler.calls(), 0);

        let event = HookEvent::after_tool_call("bash", serde_json::json!({}), "output".into());
        bus.emit(&event).await;
        assert_eq!(handler.calls(), 1);
    }

    #[tokio::test]
    async fn permission_denied_rejects_subscribe() {
        let bus = HookBus::new();
        let handler = TestHandler::new("no-perms", HookResult::Continue);
        // An empty permission set cannot satisfy BeforeToolCall.
        let perms = PermissionSet::new();
        let result = bus
            .subscribe(HookKind::BeforeToolCall, handler, None, None, perms)
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("lacks permission"));
    }

    #[tokio::test]
    async fn unsubscribe_removes_handlers() {
        let bus = HookBus::new();
        let handler = TestHandler::new("removable", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(HookKind::BeforeToolCall, handler.clone(), None, None, perms)
            .await
            .unwrap();
        assert_eq!(bus.handler_count().await, 1);

        bus.unsubscribe_all("removable").await;
        assert_eq!(bus.handler_count().await, 0);
    }

    #[tokio::test]
    async fn subscriptions_for_lists_only_matching_extension() {
        let bus = HookBus::new();
        let alpha = TestHandler::new("alpha", HookResult::Continue);
        let beta = TestHandler::new("beta", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);

        bus.subscribe(HookKind::BeforeToolCall, alpha.clone(), Some("bash".into()), None, perms.clone())
            .await
            .unwrap();
        bus.subscribe(HookKind::AfterToolCall, alpha.clone(), None, None, perms.clone())
            .await
            .unwrap();
        bus.subscribe(HookKind::BeforeToolCall, beta.clone(), None, None, perms)
            .await
            .unwrap();

        // The result is sorted by `HookKind::as_str()`. AfterToolCall coming
        // first assumes its name sorts before BeforeToolCall's — TODO confirm
        // against `as_str`.
        let alpha_subs = bus.subscriptions_for("alpha").await;
        assert_eq!(alpha_subs.len(), 2);
        assert_eq!(alpha_subs[0].0, HookKind::AfterToolCall);
        assert_eq!(alpha_subs[0].1, None);
        assert_eq!(alpha_subs[1].0, HookKind::BeforeToolCall);
        assert_eq!(alpha_subs[1].1, Some("bash".to_string()));

        let beta_subs = bus.subscriptions_for("beta").await;
        assert_eq!(beta_subs, vec![(HookKind::BeforeToolCall, None)]);

        let none_subs = bus.subscriptions_for("ghost").await;
        assert!(none_subs.is_empty());
    }

    #[tokio::test]
    async fn is_empty_reflects_state() {
        let bus = HookBus::new();
        assert!(bus.is_empty().await);

        let handler = TestHandler::new("ext", HookResult::Continue);
        let perms = perms_with(&[Permission::ToolsIntercept]);
        bus.subscribe(HookKind::BeforeToolCall, handler, None, None, perms)
            .await
            .unwrap();
        assert!(!bus.is_empty().await);
    }
}