1use std::collections::HashMap;
76use std::hash::BuildHasher;
77use std::time::Duration;
78
79use serde::Serialize;
80use thiserror::Error;
81use tokio::io::AsyncWriteExt as _;
82use tokio::process::Command;
83use tokio::time::timeout;
84
85pub use zeph_config::{HookAction, HookDef, HookMatcher, SubagentHooks};
86
/// Output extracted from a single hook invocation.
#[derive(Debug, Default)]
pub struct HookOutput {
    /// Replacement text for the tool output, taken from the string value of
    /// `hookSpecificOutput.updatedToolOutput` in the hook's JSON stdout.
    /// `None` when the hook requested no substitution.
    pub updated_tool_output: Option<String>,
}
100
/// Aggregated result of running a list of hooks with `fire_hooks`.
#[derive(Debug, Default)]
pub struct HookRunResult {
    /// Combined hook output. When several hooks supply a replacement, the last one wins.
    pub output: HookOutput,
}
110
/// JSON-serializable description of a completed tool call, intended for `PostToolUse` hooks.
///
/// NOTE(review): presumably serialized with `serde_json` and passed to `fire_hooks` as
/// `stdin_json`. The serialization site is not visible here, so confirm against callers.
/// Field order determines the order of keys in the serialized JSON.
#[derive(Debug, Serialize)]
pub struct PostToolUseHookInput<'a> {
    /// Name of the tool that was invoked.
    pub tool_name: &'a str,
    /// Arguments the tool was invoked with.
    pub tool_args: &'a serde_json::Value,
    /// Session identifier; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<&'a str>,
    /// Duration of the tool call. Milliseconds per the field name; confirm at the call site.
    pub duration_ms: u64,
    /// Tool output text; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_output: Option<&'a str>,
    /// Tool error message; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_error: Option<&'a str>,
}
136
/// Maximum number of hook stdout bytes (1 MiB) considered for output substitution.
///
/// Stdout longer than this is ignored entirely rather than truncated (see `parse_hook_stdout`).
const HOOK_STDOUT_CAP: usize = 1024 * 1024;

/// Object-safe abstraction over an MCP client, used to execute `mcp_tool` hooks.
pub trait McpDispatch: Send + Sync {
    /// Invokes `tool` on MCP `server` with `args`.
    ///
    /// An `Err(reason)` is surfaced to hook callers as `HookError::McpToolFailed`.
    /// The `Ok` value is currently ignored by hook dispatch.
    fn call_tool<'a>(
        &'a self,
        server: &'a str,
        tool: &'a str,
        args: serde_json::Value,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<serde_json::Value, String>> + Send + 'a>,
    >;
}
163
/// Errors produced while dispatching a single hook.
///
/// Whether an error aborts the run or is only logged depends on the hook's `fail_closed` flag
/// (see `fire_hooks`).
#[derive(Debug, Error)]
pub enum HookError {
    /// The shell command exited unsuccessfully. `code` is `-1` when the process reported no exit
    /// code (`ExitStatus::code()` returned `None`, e.g. termination by a signal on Unix).
    #[error("hook command failed (exit code {code}): {command}")]
    NonZeroExit { command: String, code: i32 },

    /// The hook did not finish within its `timeout_secs`. For `mcp_tool` hooks, `command` is
    /// rendered as `mcp_tool:{server}/{tool}`.
    #[error("hook command timed out after {timeout_secs}s: {command}")]
    Timeout { command: String, timeout_secs: u64 },

    /// Spawning or waiting on the hook process failed.
    #[error("hook I/O error for command '{command}': {source}")]
    Io {
        command: String,
        #[source]
        source: std::io::Error,
    },

    /// An `mcp_tool` hook was configured, but no MCP dispatcher was supplied.
    #[error(
        "mcp_tool hook requires an MCP manager but none was provided (server={server}, tool={tool})"
    )]
    McpUnavailable { server: String, tool: String },

    /// The MCP dispatcher reported a failure; `reason` is the message it returned.
    #[error("mcp_tool hook failed (server={server}, tool={tool}): {reason}")]
    McpToolFailed {
        server: String,
        tool: String,
        reason: String,
    },
}
199
200#[must_use]
221pub fn matching_hooks<'a>(matchers: &'a [HookMatcher], tool_name: &str) -> Vec<&'a HookDef> {
222 let mut result = Vec::new();
223 for m in matchers {
224 let matched = m
225 .matcher
226 .split('|')
227 .filter(|token| !token.is_empty())
228 .any(|token| tool_name.contains(token));
229 if matched {
230 result.extend(m.hooks.iter());
231 }
232 }
233 result
234}
235
/// Byte limit for the `ZEPH_TOOL_ARGS_JSON` value built by `make_base_hook_env`.
///
/// Longer payloads are cut at the nearest UTF-8 boundary at or below this limit and suffixed
/// with `…`. The result is therefore no longer valid JSON.
pub const TOOL_ARGS_JSON_LIMIT: usize = 64 * 1024;
242
243#[must_use]
262pub fn make_base_hook_env(
263 tool_name: &str,
264 tool_input: &serde_json::Value,
265) -> HashMap<String, String> {
266 let mut env = HashMap::new();
267 env.insert("ZEPH_TOOL_NAME".to_owned(), tool_name.to_owned());
268
269 let raw = serde_json::to_string(tool_input).unwrap_or_default();
270 let args_json = if raw.len() > TOOL_ARGS_JSON_LIMIT {
271 tracing::warn!(
272 tool = tool_name,
273 len = raw.len(),
274 limit = TOOL_ARGS_JSON_LIMIT,
275 "ZEPH_TOOL_ARGS_JSON truncated for hook dispatch"
276 );
277 let limit = raw.floor_char_boundary(TOOL_ARGS_JSON_LIMIT);
278 format!("{}…", &raw[..limit])
279 } else {
280 raw
281 };
282 env.insert("ZEPH_TOOL_ARGS_JSON".to_owned(), args_json);
283
284 env
285}
286
287pub async fn fire_hooks<S: BuildHasher>(
314 hooks: &[HookDef],
315 env: &HashMap<String, String, S>,
316 mcp: Option<&dyn McpDispatch>,
317 stdin_json: Option<&[u8]>,
318) -> Result<HookRunResult, HookError> {
319 let mut run_result = HookRunResult::default();
320 for hook in hooks {
321 let effective_stdin = run_result
324 .output
325 .updated_tool_output
326 .as_deref()
327 .map(str::as_bytes)
328 .or(stdin_json);
329 let result = fire_single_hook(hook, env, mcp, effective_stdin).await;
330 match result {
331 Ok(hook_output) => {
332 if hook_output.updated_tool_output.is_some() {
333 run_result.output.updated_tool_output = hook_output.updated_tool_output;
334 }
335 }
336 Err(e) if hook.fail_closed => {
337 tracing::error!(
338 error = %e,
339 "fail-closed hook failed — aborting"
340 );
341 return Err(e);
342 }
343 Err(e) => {
344 tracing::warn!(
345 error = %e,
346 "hook failed (fail_open) — continuing"
347 );
348 }
349 }
350 }
351 Ok(run_result)
352}
353
354async fn fire_single_hook<S: BuildHasher>(
355 hook: &HookDef,
356 env: &HashMap<String, String, S>,
357 mcp: Option<&dyn McpDispatch>,
358 stdin_json: Option<&[u8]>,
359) -> Result<HookOutput, HookError> {
360 match &hook.action {
361 HookAction::Command { command } => {
362 fire_shell_hook(command, hook.timeout_secs, env, stdin_json).await
363 }
364 HookAction::McpTool { server, tool, args } => {
365 let dispatcher = mcp.ok_or_else(|| HookError::McpUnavailable {
366 server: server.clone(),
367 tool: tool.clone(),
368 })?;
369 let call_fut = dispatcher.call_tool(server, tool, args.clone());
370 match timeout(Duration::from_secs(hook.timeout_secs), call_fut).await {
371 Ok(Ok(_)) => {
372 Ok(HookOutput::default())
374 }
375 Ok(Err(reason)) => Err(HookError::McpToolFailed {
376 server: server.clone(),
377 tool: tool.clone(),
378 reason,
379 }),
380 Err(_) => Err(HookError::Timeout {
381 command: format!("mcp_tool:{server}/{tool}"),
382 timeout_secs: hook.timeout_secs,
383 }),
384 }
385 }
386 }
387}
388
389async fn fire_shell_hook<S: BuildHasher>(
390 command: &str,
391 timeout_secs: u64,
392 env: &HashMap<String, String, S>,
393 stdin_json: Option<&[u8]>,
394) -> Result<HookOutput, HookError> {
395 use std::process::Stdio;
396 use tokio::io::AsyncReadExt as _;
397
398 let mut cmd = Command::new("sh");
399 cmd.arg("-c").arg(command);
400 cmd.env_clear();
402 if let Ok(path) = std::env::var("PATH") {
404 cmd.env("PATH", path);
405 }
406 for (k, v) in env {
407 cmd.env(k, v);
408 }
409 cmd.stdin(if stdin_json.is_some() {
410 Stdio::piped()
411 } else {
412 Stdio::null()
413 });
414 cmd.stdout(Stdio::piped());
416 cmd.stderr(Stdio::null());
417
418 let mut child = cmd.spawn().map_err(|e| HookError::Io {
419 command: command.to_owned(),
420 source: e,
421 })?;
422
423 if let Some(bytes) = stdin_json
426 && let Some(mut stdin_handle) = child.stdin.take()
427 && let Err(e) = stdin_handle.write_all(bytes).await
428 {
429 tracing::warn!(
430 command,
431 error = %e,
432 "failed to write stdin to hook — continuing without stdin data"
433 );
434 }
435
436 let stdout_handle = child.stdout.take();
439 match timeout(Duration::from_secs(timeout_secs), child.wait()).await {
440 Ok(Ok(status)) => {
441 let mut stdout_bytes = Vec::new();
442 if let Some(handle) = stdout_handle {
443 let mut limited = handle.take(HOOK_STDOUT_CAP as u64 + 1);
444 let _ = limited.read_to_end(&mut stdout_bytes).await;
445 }
446 if status.success() {
447 Ok(parse_hook_stdout(command, &stdout_bytes))
448 } else {
449 Err(HookError::NonZeroExit {
450 command: command.to_owned(),
451 code: status.code().unwrap_or(-1),
452 })
453 }
454 }
455 Ok(Err(e)) => Err(HookError::Io {
456 command: command.to_owned(),
457 source: e,
458 }),
459 Err(_) => {
460 let _ = child.kill().await;
462 Err(HookError::Timeout {
463 command: command.to_owned(),
464 timeout_secs,
465 })
466 }
467 }
468}
469
470fn parse_hook_stdout(command: &str, bytes: &[u8]) -> HookOutput {
475 if bytes.is_empty() {
476 return HookOutput::default();
477 }
478 if bytes.len() > HOOK_STDOUT_CAP {
479 tracing::warn!(
480 command,
481 bytes = bytes.len(),
482 cap = HOOK_STDOUT_CAP,
483 "hook stdout exceeds 1 MiB cap — treating as no substitution"
484 );
485 return HookOutput::default();
486 }
487 let Ok(text) = std::str::from_utf8(bytes) else {
488 tracing::warn!(command, "hook stdout is not valid UTF-8 — no substitution");
489 return HookOutput::default();
490 };
491 let Ok(json) = serde_json::from_str::<serde_json::Value>(text) else {
493 return HookOutput::default();
494 };
495 let updated = json
496 .get("hookSpecificOutput")
497 .and_then(|h| h.get("updatedToolOutput"));
498
499 match updated {
500 None | Some(serde_json::Value::Null) => HookOutput::default(),
501 Some(serde_json::Value::String(s)) => HookOutput {
502 updated_tool_output: Some(s.clone()),
503 },
504 Some(other) => {
505 tracing::warn!(
506 command,
507 kind = other
508 .is_object()
509 .then_some("object")
510 .or_else(|| other.is_array().then_some("array"))
511 .or_else(|| other.is_number().then_some("number"))
512 .or_else(|| other.is_boolean().then_some("boolean"))
513 .unwrap_or("unknown"),
514 "hookSpecificOutput.updatedToolOutput has unexpected type — no substitution"
515 );
516 HookOutput::default()
517 }
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a shell-command hook with the given failure policy and timeout.
    fn cmd_hook(command: &str, fail_closed: bool, timeout_secs: u64) -> HookDef {
        HookDef {
            action: HookAction::Command {
                command: command.to_owned(),
            },
            timeout_secs,
            fail_closed,
        }
    }

    /// Builds a matcher that attaches `hooks` to tool names matching `matcher`.
    fn make_matcher(matcher: &str, hooks: Vec<HookDef>) -> HookMatcher {
        HookMatcher {
            matcher: matcher.to_owned(),
            hooks,
        }
    }

    // --- matching_hooks ------------------------------------------------------------------------

    #[test]
    fn matching_hooks_exact_name() {
        let hook = cmd_hook("echo hi", false, 30);
        let matchers = vec![make_matcher("Edit", vec![hook])];
        let result = matching_hooks(&matchers, "Edit");
        assert_eq!(result.len(), 1);
        assert!(
            matches!(&result[0].action, HookAction::Command { command } if command == "echo hi")
        );
    }

    #[test]
    fn matching_hooks_substring() {
        let hook = cmd_hook("echo sub", false, 30);
        let matchers = vec![make_matcher("Edit", vec![hook])];
        let result = matching_hooks(&matchers, "EditFile");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn matching_hooks_pipe_separated() {
        let h1 = cmd_hook("echo e", false, 30);
        let h2 = cmd_hook("echo w", false, 30);
        let matchers = vec![
            make_matcher("Edit|Write", vec![h1]),
            make_matcher("Shell", vec![h2]),
        ];
        let result_edit = matching_hooks(&matchers, "Edit");
        assert_eq!(result_edit.len(), 1);

        let result_shell = matching_hooks(&matchers, "Shell");
        assert_eq!(result_shell.len(), 1);

        let result_none = matching_hooks(&matchers, "Read");
        assert!(result_none.is_empty());
    }

    #[test]
    fn matching_hooks_no_match() {
        let hook = cmd_hook("echo nope", false, 30);
        let matchers = vec![make_matcher("Edit", vec![hook])];
        let result = matching_hooks(&matchers, "Shell");
        assert!(result.is_empty());
    }

    #[test]
    fn matching_hooks_empty_token_ignored() {
        let hook = cmd_hook("echo empty", false, 30);
        let matchers = vec![make_matcher("|Edit|", vec![hook])];
        let result = matching_hooks(&matchers, "Edit");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn matching_hooks_multiple_matchers_both_match() {
        let h1 = cmd_hook("echo 1", false, 30);
        let h2 = cmd_hook("echo 2", false, 30);
        let matchers = vec![
            make_matcher("Shell", vec![h1]),
            make_matcher("Shell", vec![h2]),
        ];
        let result = matching_hooks(&matchers, "Shell");
        assert_eq!(result.len(), 2);
    }

    // --- fire_hooks: command hooks -------------------------------------------------------------

    #[tokio::test]
    async fn fire_hooks_success() {
        let hooks = vec![cmd_hook("true", false, 5)];
        let env = HashMap::new();
        assert!(fire_hooks(&hooks, &env, None, None).await.is_ok());
    }

    #[tokio::test]
    async fn fire_hooks_fail_open_continues() {
        let hooks = vec![
            // Fails, but fail_open: the error is logged and the next hook still runs.
            cmd_hook("false", false, 5),
            cmd_hook("true", false, 5),
        ];
        let env = HashMap::new();
        assert!(fire_hooks(&hooks, &env, None, None).await.is_ok());
    }

    #[tokio::test]
    async fn fire_hooks_fail_closed_returns_err() {
        let hooks = vec![cmd_hook("false", true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, HookError::NonZeroExit { .. }));
    }

    #[tokio::test]
    async fn fire_hooks_timeout() {
        let hooks = vec![cmd_hook("sleep 10", true, 1)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, HookError::Timeout { .. }));
    }

    #[tokio::test]
    async fn fire_hooks_env_passed() {
        let hooks = vec![cmd_hook(r#"test "$ZEPH_TEST_VAR" = "hello""#, true, 5)];
        let mut env = HashMap::new();
        env.insert("ZEPH_TEST_VAR".to_owned(), "hello".to_owned());
        assert!(fire_hooks(&hooks, &env, None, None).await.is_ok());
    }

    #[tokio::test]
    async fn fire_hooks_empty_list_ok() {
        let env = HashMap::new();
        assert!(fire_hooks(&[], &env, None, None).await.is_ok());
    }

    // --- fire_hooks: mcp_tool hooks ------------------------------------------------------------

    #[tokio::test]
    async fn fire_hooks_mcp_unavailable_fail_open() {
        let hooks = vec![HookDef {
            action: HookAction::McpTool {
                server: "srv".into(),
                tool: "t".into(),
                args: serde_json::Value::Null,
            },
            timeout_secs: 5,
            fail_closed: false,
        }];
        let env = HashMap::new();
        // No dispatcher, but fail_open: the McpUnavailable error must not propagate.
        assert!(fire_hooks(&hooks, &env, None, None).await.is_ok());
    }

    #[tokio::test]
    async fn fire_hooks_mcp_unavailable_fail_closed() {
        let hooks = vec![HookDef {
            action: HookAction::McpTool {
                server: "srv".into(),
                tool: "t".into(),
                args: serde_json::Value::Null,
            },
            timeout_secs: 5,
            fail_closed: true,
        }];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await;
        assert!(matches!(result, Err(HookError::McpUnavailable { .. })));
    }

    /// Test dispatcher that counts calls and always succeeds with `null`.
    struct CountingDispatch(std::sync::Arc<std::sync::atomic::AtomicU32>);

    impl McpDispatch for CountingDispatch {
        fn call_tool<'a>(
            &'a self,
            _server: &'a str,
            _tool: &'a str,
            _args: serde_json::Value,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<serde_json::Value, String>> + Send + 'a>,
        > {
            self.0.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Box::pin(std::future::ready(Ok(serde_json::Value::Null)))
        }
    }

    #[tokio::test]
    async fn fire_hooks_mcp_dispatch_called_when_provided() {
        let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let dispatch = CountingDispatch(std::sync::Arc::clone(&call_count));

        let hooks = vec![HookDef {
            action: HookAction::McpTool {
                server: "srv".into(),
                tool: "t".into(),
                args: serde_json::Value::Null,
            },
            timeout_secs: 5,
            fail_closed: true,
        }];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, Some(&dispatch), None).await;
        assert!(
            result.is_ok(),
            "fire_hooks should succeed with mcp dispatch"
        );
        assert_eq!(
            call_count.load(std::sync::atomic::Ordering::SeqCst),
            1,
            "MCP dispatch should have been called exactly once"
        );
    }

    // --- fire_hooks: stdout / stdin protocol ---------------------------------------------------

    #[tokio::test]
    async fn fire_hooks_stdout_replacement_json() {
        let cmd = r#"printf '{"hookSpecificOutput":{"updatedToolOutput":"replaced"}}'"#;
        let hooks = vec![cmd_hook(cmd, true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert_eq!(
            result.output.updated_tool_output.as_deref(),
            Some("replaced")
        );
    }

    #[tokio::test]
    async fn fire_hooks_stdout_empty_no_replacement() {
        let hooks = vec![cmd_hook("true", true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert!(result.output.updated_tool_output.is_none());
    }

    #[tokio::test]
    async fn fire_hooks_stdout_non_json_no_replacement() {
        let hooks = vec![cmd_hook("echo hello", true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert!(result.output.updated_tool_output.is_none());
    }

    #[tokio::test]
    async fn fire_hooks_stdout_null_updatedtooloutput_no_replacement() {
        let cmd = r#"printf '{"hookSpecificOutput":{"updatedToolOutput":null}}'"#;
        let hooks = vec![cmd_hook(cmd, true, 5)];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert!(result.output.updated_tool_output.is_none());
    }

    #[tokio::test]
    async fn fire_hooks_stdin_passed_to_hook() {
        // `grep -q` exits 0 only if the payload read from stdin contains the key. With stdin set
        // to /dev/null it would exit 1. A POSIX utility keeps the test independent of python3.
        let cmd = r#"grep -q '"duration_ms"'"#;
        let hooks = vec![cmd_hook(cmd, true, 10)];
        let env = HashMap::new();
        let stdin = br#"{"tool_name":"Shell","tool_args":{},"duration_ms":42}"#;
        let result = fire_hooks(&hooks, &env, None, Some(stdin)).await;
        assert!(
            result.is_ok(),
            "hook should succeed when stdin has duration_ms"
        );
    }

    #[tokio::test]
    async fn fire_hooks_chaining_last_replacement_wins() {
        let h1 = cmd_hook(
            r#"printf '{"hookSpecificOutput":{"updatedToolOutput":"first"}}'"#,
            false,
            5,
        );
        let h2 = cmd_hook(
            r#"printf '{"hookSpecificOutput":{"updatedToolOutput":"second"}}'"#,
            false,
            5,
        );
        let hooks = vec![h1, h2];
        let env = HashMap::new();
        let result = fire_hooks(&hooks, &env, None, None).await.unwrap();
        assert_eq!(result.output.updated_tool_output.as_deref(), Some("second"));
    }

    // --- SubagentHooks deserialization ---------------------------------------------------------

    #[test]
    fn subagent_hooks_parses_from_yaml() {
        let yaml = r#"
PreToolUse:
  - matcher: "Edit|Write"
    hooks:
      - type: command
        command: "echo pre"
        timeout_secs: 10
        fail_closed: false
PostToolUse:
  - matcher: "Shell"
    hooks:
      - type: command
        command: "echo post"
"#;
        let hooks: SubagentHooks = serde_norway::from_str(yaml).unwrap();
        assert_eq!(hooks.pre_tool_use.len(), 1);
        assert_eq!(hooks.pre_tool_use[0].matcher, "Edit|Write");
        assert_eq!(hooks.pre_tool_use[0].hooks.len(), 1);
        assert!(
            matches!(&hooks.pre_tool_use[0].hooks[0].action, HookAction::Command { command } if command == "echo pre")
        );
        assert_eq!(hooks.post_tool_use.len(), 1);
    }

    #[test]
    fn subagent_hooks_defaults_timeout() {
        let yaml = r#"
PreToolUse:
  - matcher: "Edit"
    hooks:
      - type: command
        command: "echo hi"
"#;
        let hooks: SubagentHooks = serde_norway::from_str(yaml).unwrap();
        assert_eq!(hooks.pre_tool_use[0].hooks[0].timeout_secs, 30);
        assert!(!hooks.pre_tool_use[0].hooks[0].fail_closed);
    }

    #[test]
    fn subagent_hooks_empty_default() {
        let hooks = SubagentHooks::default();
        assert!(hooks.pre_tool_use.is_empty());
        assert!(hooks.post_tool_use.is_empty());
    }

    // --- regressions ---------------------------------------------------------------------------

    /// A hook that prints and then sleeps past its deadline must yield `Timeout` promptly.
    /// Reading its stdout must not block on the still-open pipe (regression #4011).
    #[tokio::test]
    async fn fire_shell_hook_timeout_with_stdout_does_not_deadlock() {
        let cmd = r#"echo "some output"; sleep 60"#;
        let hooks = vec![cmd_hook(cmd, true, 1)];
        let env = HashMap::new();

        // Outer guard: fail the test instead of hanging the suite if the deadlock regresses.
        let result = tokio::time::timeout(
            std::time::Duration::from_secs(5),
            fire_hooks(&hooks, &env, None, None),
        )
        .await
        .expect("fire_hooks must return within 5 s — deadlock regression #4011");

        assert!(
            matches!(result, Err(HookError::Timeout { .. })),
            "expected HookError::Timeout, got: {result:?}"
        );
    }
}