// File: wraith_runtime/hooks.rs

1use std::ffi::OsStr;
2use std::process::Command;
3
4use serde_json::json;
5
6use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
7
/// Points in a tool call's lifecycle at which configured hook commands are run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookEvent {
    /// Dispatched before the tool runs; no tool output exists yet.
    PreToolUse,
    /// Dispatched after the tool has produced its output.
    PostToolUse,
}

impl HookEvent {
    /// Stable textual name of the event.
    ///
    /// Used as `hook_event_name` in the JSON payload, as the `HOOK_EVENT` environment variable,
    /// and in diagnostic messages.
    fn as_str(self) -> &'static str {
        match self {
            HookEvent::PreToolUse => "PreToolUse",
            HookEvent::PostToolUse => "PostToolUse",
        }
    }
}
22
/// Combined outcome of running every hook command registered for one event.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HookRunResult {
    /// Set when a hook vetoed the tool call.
    denied: bool,
    /// Messages emitted by the hooks, in the order the hooks ran.
    messages: Vec<String>,
}

impl HookRunResult {
    /// Builds a result that lets the tool call proceed, carrying `messages`.
    #[must_use]
    pub fn allow(messages: Vec<String>) -> Self {
        Self {
            messages,
            denied: false,
        }
    }

    /// Reports whether any hook denied the tool call.
    #[must_use]
    pub fn is_denied(&self) -> bool {
        let Self { denied, .. } = self;
        *denied
    }

    /// Messages collected from the hooks, in execution order.
    #[must_use]
    pub fn messages(&self) -> &[String] {
        self.messages.as_slice()
    }
}
48
49#[derive(Debug, Clone, PartialEq, Eq, Default)]
50pub struct HookRunner {
51    config: RuntimeHookConfig,
52}
53
54#[derive(Debug, Clone, Copy)]
55struct HookCommandRequest<'a> {
56    event: HookEvent,
57    tool_name: &'a str,
58    tool_input: &'a str,
59    tool_output: Option<&'a str>,
60    is_error: bool,
61    payload: &'a str,
62}
63
64impl HookRunner {
65    #[must_use]
66    pub fn new(config: RuntimeHookConfig) -> Self {
67        Self { config }
68    }
69
70    #[must_use]
71    pub fn from_feature_config(feature_config: &RuntimeFeatureConfig) -> Self {
72        Self::new(feature_config.hooks().clone())
73    }
74
75    #[must_use]
76    pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
77        Self::run_commands(
78            HookEvent::PreToolUse,
79            self.config.pre_tool_use(),
80            tool_name,
81            tool_input,
82            None,
83            false,
84        )
85    }
86
87    #[must_use]
88    pub fn run_post_tool_use(
89        &self,
90        tool_name: &str,
91        tool_input: &str,
92        tool_output: &str,
93        is_error: bool,
94    ) -> HookRunResult {
95        Self::run_commands(
96            HookEvent::PostToolUse,
97            self.config.post_tool_use(),
98            tool_name,
99            tool_input,
100            Some(tool_output),
101            is_error,
102        )
103    }
104
105    fn run_commands(
106        event: HookEvent,
107        commands: &[String],
108        tool_name: &str,
109        tool_input: &str,
110        tool_output: Option<&str>,
111        is_error: bool,
112    ) -> HookRunResult {
113        if commands.is_empty() {
114            return HookRunResult::allow(Vec::new());
115        }
116
117        let payload = json!({
118            "hook_event_name": event.as_str(),
119            "tool_name": tool_name,
120            "tool_input": parse_tool_input(tool_input),
121            "tool_input_json": tool_input,
122            "tool_output": tool_output,
123            "tool_result_is_error": is_error,
124        })
125        .to_string();
126
127        let mut messages = Vec::new();
128
129        for command in commands {
130            match Self::run_command(
131                command,
132                HookCommandRequest {
133                    event,
134                    tool_name,
135                    tool_input,
136                    tool_output,
137                    is_error,
138                    payload: &payload,
139                },
140            ) {
141                HookCommandOutcome::Allow { message } => {
142                    if let Some(message) = message {
143                        messages.push(message);
144                    }
145                }
146                HookCommandOutcome::Deny { message } => {
147                    let message = message.unwrap_or_else(|| {
148                        format!("{} hook denied tool `{tool_name}`", event.as_str())
149                    });
150                    messages.push(message);
151                    return HookRunResult {
152                        denied: true,
153                        messages,
154                    };
155                }
156                HookCommandOutcome::Warn { message } => messages.push(message),
157            }
158        }
159
160        HookRunResult::allow(messages)
161    }
162
163    fn run_command(command: &str, request: HookCommandRequest<'_>) -> HookCommandOutcome {
164        let mut child = shell_command(command);
165        child.stdin(std::process::Stdio::piped());
166        child.stdout(std::process::Stdio::piped());
167        child.stderr(std::process::Stdio::piped());
168        child.env("HOOK_EVENT", request.event.as_str());
169        child.env("HOOK_TOOL_NAME", request.tool_name);
170        child.env("HOOK_TOOL_INPUT", request.tool_input);
171        child.env(
172            "HOOK_TOOL_IS_ERROR",
173            if request.is_error { "1" } else { "0" },
174        );
175        if let Some(tool_output) = request.tool_output {
176            child.env("HOOK_TOOL_OUTPUT", tool_output);
177        }
178
179        match child.output_with_stdin(request.payload.as_bytes()) {
180            Ok(output) => {
181                let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
182                let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
183                let message = (!stdout.is_empty()).then_some(stdout);
184                match output.status.code() {
185                    Some(0) => HookCommandOutcome::Allow { message },
186                    Some(2) => HookCommandOutcome::Deny { message },
187                    Some(code) => HookCommandOutcome::Warn {
188                        message: format_hook_warning(
189                            command,
190                            code,
191                            message.as_deref(),
192                            stderr.as_str(),
193                        ),
194                    },
195                    None => HookCommandOutcome::Warn {
196                        message: format!(
197                            "{} hook `{command}` terminated by signal while handling `{}`",
198                            request.event.as_str(),
199                            request.tool_name
200                        ),
201                    },
202                }
203            }
204            Err(error) => HookCommandOutcome::Warn {
205                message: format!(
206                    "{} hook `{command}` failed to start for `{}`: {error}",
207                    request.event.as_str(),
208                    request.tool_name
209                ),
210            },
211        }
212    }
213}
214
/// How a single hook command's run is interpreted.
enum HookCommandOutcome {
    /// The tool call may proceed; carries the hook's message when it produced one.
    Allow {
        message: Option<String>,
    },
    /// The tool call is vetoed; carries the hook's reason when it gave one.
    Deny {
        message: Option<String>,
    },
    /// The hook misbehaved but the tool call proceeds; carries a diagnostic.
    Warn {
        message: String,
    },
}
220
221fn parse_tool_input(tool_input: &str) -> serde_json::Value {
222    serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input }))
223}
224
/// Builds the warning text for a hook that exited with status `code`.
///
/// The hook's `stdout` is appended as detail when present and non-empty. Otherwise `stderr` is
/// appended, unless it is also empty, in which case the message has no detail suffix.
fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String {
    let headline =
        format!("Hook `{command}` exited with status {code}; allowing tool execution to continue");
    let detail = match stdout {
        Some(out) if !out.is_empty() => out,
        _ => stderr,
    };
    if detail.is_empty() {
        headline
    } else {
        format!("{headline}: {detail}")
    }
}
237
238fn shell_command(command: &str) -> CommandWithStdin {
239    #[cfg(windows)]
240    let mut command_builder = {
241        let mut command_builder = Command::new("cmd");
242        command_builder.arg("/C").arg(command);
243        CommandWithStdin::new(command_builder)
244    };
245
246    #[cfg(not(windows))]
247    let command_builder = {
248        let mut command_builder = Command::new("sh");
249        command_builder.arg("-lc").arg(command);
250        CommandWithStdin::new(command_builder)
251    };
252
253    command_builder
254}
255
/// A [`Command`] wrapper that feeds a byte payload to the child's stdin while collecting its
/// stdout and stderr.
struct CommandWithStdin {
    command: Command,
}

impl CommandWithStdin {
    /// Wraps a command whose program and arguments are already set.
    fn new(command: Command) -> Self {
        Self { command }
    }

    /// Configures the child's stdin (see [`Command::stdin`]).
    fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self {
        self.command.stdin(cfg);
        self
    }

    /// Configures the child's stdout (see [`Command::stdout`]).
    fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self {
        self.command.stdout(cfg);
        self
    }

    /// Configures the child's stderr (see [`Command::stderr`]).
    fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self {
        self.command.stderr(cfg);
        self
    }

    /// Adds or overrides an environment variable for the child (see [`Command::env`]).
    fn env<K, V>(&mut self, key: K, value: V) -> &mut Self
    where
        K: AsRef<OsStr>,
        V: AsRef<OsStr>,
    {
        self.command.env(key, value);
        self
    }

    /// Spawns the command, writes `stdin` to the child's standard input, and waits for the
    /// child to exit.
    ///
    /// Returns the exit status together with whatever the child wrote to stdout/stderr, when
    /// those are piped.
    ///
    /// The payload is written on a helper thread while this thread drains stdout and stderr.
    /// This prevents a deadlock on full pipe buffers when the child writes a lot of output
    /// before consuming its input. A child may also exit without reading all of its input,
    /// making the write fail with [`std::io::ErrorKind::BrokenPipe`]. That is not treated as an
    /// error: the child's exit status and output are still returned.
    ///
    /// # Errors
    ///
    /// Returns an error in three cases:
    /// * the process cannot be spawned;
    /// * collecting its output fails;
    /// * writing the payload fails for any reason other than a broken pipe.
    fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result<std::process::Output> {
        use std::io::Write;

        let mut child = self.command.spawn()?;

        // The writer thread owns a copy of the payload and the pipe. Dropping the pipe when the
        // thread finishes signals EOF to the child.
        let writer = child.stdin.take().map(|mut pipe| {
            let payload = stdin.to_vec();
            std::thread::spawn(move || pipe.write_all(&payload))
        });

        let output = child.wait_with_output()?;

        if let Some(writer) = writer {
            let written = writer
                .join()
                .unwrap_or_else(|panic| std::panic::resume_unwind(panic));
            match written {
                Ok(()) => {}
                // The child stopped reading (typically it exited); its result still stands.
                Err(error) if error.kind() == std::io::ErrorKind::BrokenPipe => {}
                Err(error) => return Err(error),
            }
        }

        Ok(output)
    }
}
298
#[cfg(test)]
mod tests {
    use super::{HookRunResult, HookRunner};
    use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};

    /// Runner whose only configured hook is the pre-tool-use command `script`.
    fn pre_hook_runner(script: &str) -> HookRunner {
        let pre_tool_use = vec![shell_snippet(script)];
        HookRunner::new(RuntimeHookConfig::new(pre_tool_use, Vec::new()))
    }

    #[test]
    fn allows_exit_code_zero_and_captures_stdout() {
        let runner = pre_hook_runner("printf 'pre ok'");

        let outcome = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#);

        let expected = HookRunResult::allow(vec!["pre ok".to_string()]);
        assert_eq!(outcome, expected);
    }

    #[test]
    fn denies_exit_code_two() {
        let outcome = pre_hook_runner("printf 'blocked by hook'; exit 2")
            .run_pre_tool_use("Bash", r#"{"command":"pwd"}"#);

        assert!(outcome.is_denied());
        assert_eq!(outcome.messages(), &["blocked by hook".to_string()]);
    }

    #[test]
    fn warns_for_other_non_zero_statuses() {
        let hooks = RuntimeHookConfig::new(
            vec![shell_snippet("printf 'warning hook'; exit 1")],
            Vec::new(),
        );
        let runner =
            HookRunner::from_feature_config(&RuntimeFeatureConfig::default().with_hooks(hooks));

        let outcome = runner.run_pre_tool_use("Edit", r#"{"file":"src/lib.rs"}"#);

        assert!(!outcome.is_denied());
        let warned = outcome
            .messages()
            .iter()
            .any(|message| message.contains("allowing tool execution to continue"));
        assert!(warned);
    }

    /// `cmd.exe` has no single-quote strings, so swap them for double quotes.
    #[cfg(windows)]
    fn shell_snippet(script: &str) -> String {
        script
            .chars()
            .map(|ch| if ch == '\'' { '"' } else { ch })
            .collect()
    }

    /// POSIX shells accept the snippet unchanged.
    #[cfg(not(windows))]
    fn shell_snippet(script: &str) -> String {
        script.to_owned()
    }
}