Skip to main content

systemprompt_models/services/
hooks.rs

1use std::fmt;
2use std::str::FromStr;
3
4use anyhow::anyhow;
5use serde::{Deserialize, Serialize};
6
/// File name used for hook configuration files.
pub const HOOK_CONFIG_FILENAME: &str = "config.yaml";
8
/// Serde default provider that yields `true`, so boolean fields using it are
/// enabled when the key is omitted.
const fn default_true() -> bool {
    let enabled = true;
    enabled
}
12
/// Serde default provider for version fields: `"1.0.0"`.
fn default_version() -> String {
    String::from("1.0.0")
}
16
/// Serde default provider for matcher patterns: the single-character pattern `"*"`.
fn default_matcher() -> String {
    "*".to_owned()
}
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
22#[serde(rename_all = "PascalCase")]
23pub enum HookEvent {
24    PreToolUse,
25    PostToolUse,
26    PostToolUseFailure,
27    SessionStart,
28    SessionEnd,
29    UserPromptSubmit,
30    Notification,
31    Stop,
32    SubagentStart,
33    SubagentStop,
34}
35
36impl HookEvent {
37    pub const ALL_VARIANTS: &'static [Self] = &[
38        Self::PreToolUse,
39        Self::PostToolUse,
40        Self::PostToolUseFailure,
41        Self::SessionStart,
42        Self::SessionEnd,
43        Self::UserPromptSubmit,
44        Self::Notification,
45        Self::Stop,
46        Self::SubagentStart,
47        Self::SubagentStop,
48    ];
49
50    pub const fn as_str(&self) -> &'static str {
51        match self {
52            Self::PreToolUse => "PreToolUse",
53            Self::PostToolUse => "PostToolUse",
54            Self::PostToolUseFailure => "PostToolUseFailure",
55            Self::SessionStart => "SessionStart",
56            Self::SessionEnd => "SessionEnd",
57            Self::UserPromptSubmit => "UserPromptSubmit",
58            Self::Notification => "Notification",
59            Self::Stop => "Stop",
60            Self::SubagentStart => "SubagentStart",
61            Self::SubagentStop => "SubagentStop",
62        }
63    }
64}
65
66impl fmt::Display for HookEvent {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        write!(f, "{}", self.as_str())
69    }
70}
71
72impl FromStr for HookEvent {
73    type Err = anyhow::Error;
74
75    fn from_str(s: &str) -> anyhow::Result<Self> {
76        match s {
77            "PreToolUse" => Ok(Self::PreToolUse),
78            "PostToolUse" => Ok(Self::PostToolUse),
79            "PostToolUseFailure" => Ok(Self::PostToolUseFailure),
80            "SessionStart" => Ok(Self::SessionStart),
81            "SessionEnd" => Ok(Self::SessionEnd),
82            "UserPromptSubmit" => Ok(Self::UserPromptSubmit),
83            "Notification" => Ok(Self::Notification),
84            "Stop" => Ok(Self::Stop),
85            "SubagentStart" => Ok(Self::SubagentStart),
86            "SubagentStop" => Ok(Self::SubagentStop),
87            _ => Err(anyhow!("Invalid hook event: {s}")),
88        }
89    }
90}
91
92#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
93#[serde(rename_all = "lowercase")]
94pub enum HookCategory {
95    System,
96    #[default]
97    Custom,
98}
99
100impl HookCategory {
101    pub const fn as_str(&self) -> &'static str {
102        match self {
103            Self::System => "system",
104            Self::Custom => "custom",
105        }
106    }
107}
108
109impl fmt::Display for HookCategory {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        write!(f, "{}", self.as_str())
112    }
113}
114
115impl FromStr for HookCategory {
116    type Err = anyhow::Error;
117
118    fn from_str(s: &str) -> anyhow::Result<Self> {
119        match s {
120            "system" => Ok(Self::System),
121            "custom" => Ok(Self::Custom),
122            _ => Err(anyhow!("Invalid hook category: {s}")),
123        }
124    }
125}
126
127#[derive(Debug, Clone, Deserialize)]
128pub struct DiskHookConfig {
129    #[serde(default)]
130    pub id: String,
131    #[serde(default)]
132    pub name: String,
133    #[serde(default)]
134    pub description: String,
135    #[serde(default = "default_version")]
136    pub version: String,
137    #[serde(default = "default_true")]
138    pub enabled: bool,
139    pub event: HookEvent,
140    #[serde(default = "default_matcher")]
141    pub matcher: String,
142    #[serde(default)]
143    pub command: String,
144    #[serde(default, rename = "async")]
145    pub is_async: bool,
146    #[serde(default)]
147    pub category: HookCategory,
148    #[serde(default)]
149    pub tags: Vec<String>,
150    #[serde(default)]
151    pub visible_to: Vec<String>,
152}
153
154#[derive(Debug, Clone, Default, Serialize, Deserialize)]
155#[serde(rename_all = "PascalCase")]
156pub struct HookEventsConfig {
157    #[serde(default, skip_serializing_if = "Vec::is_empty")]
158    pub pre_tool_use: Vec<HookMatcher>,
159    #[serde(default, skip_serializing_if = "Vec::is_empty")]
160    pub post_tool_use: Vec<HookMatcher>,
161    #[serde(default, skip_serializing_if = "Vec::is_empty")]
162    pub post_tool_use_failure: Vec<HookMatcher>,
163    #[serde(default, skip_serializing_if = "Vec::is_empty")]
164    pub session_start: Vec<HookMatcher>,
165    #[serde(default, skip_serializing_if = "Vec::is_empty")]
166    pub session_end: Vec<HookMatcher>,
167    #[serde(default, skip_serializing_if = "Vec::is_empty")]
168    pub user_prompt_submit: Vec<HookMatcher>,
169    #[serde(default, skip_serializing_if = "Vec::is_empty")]
170    pub notification: Vec<HookMatcher>,
171    #[serde(default, skip_serializing_if = "Vec::is_empty")]
172    pub stop: Vec<HookMatcher>,
173    #[serde(default, skip_serializing_if = "Vec::is_empty")]
174    pub subagent_start: Vec<HookMatcher>,
175    #[serde(default, skip_serializing_if = "Vec::is_empty")]
176    pub subagent_stop: Vec<HookMatcher>,
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
180pub struct HookMatcher {
181    pub matcher: String,
182    pub hooks: Vec<HookAction>,
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct HookAction {
187    #[serde(rename = "type")]
188    pub hook_type: HookType,
189    #[serde(skip_serializing_if = "Option::is_none")]
190    pub command: Option<String>,
191    #[serde(skip_serializing_if = "Option::is_none")]
192    pub prompt: Option<String>,
193    #[serde(default, rename = "async")]
194    pub r#async: bool,
195    #[serde(skip_serializing_if = "Option::is_none")]
196    pub timeout: Option<u32>,
197    #[serde(skip_serializing_if = "Option::is_none", rename = "statusMessage")]
198    pub status_message: Option<String>,
199}
200
201#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
202#[serde(rename_all = "lowercase")]
203pub enum HookType {
204    Command,
205    Prompt,
206    Agent,
207}
208
209impl HookEventsConfig {
210    pub fn is_empty(&self) -> bool {
211        self.pre_tool_use.is_empty()
212            && self.post_tool_use.is_empty()
213            && self.post_tool_use_failure.is_empty()
214            && self.session_start.is_empty()
215            && self.session_end.is_empty()
216            && self.user_prompt_submit.is_empty()
217            && self.notification.is_empty()
218            && self.stop.is_empty()
219            && self.subagent_start.is_empty()
220            && self.subagent_stop.is_empty()
221    }
222
223    pub fn matchers_for_event(&self, event: HookEvent) -> &[HookMatcher] {
224        match event {
225            HookEvent::PreToolUse => &self.pre_tool_use,
226            HookEvent::PostToolUse => &self.post_tool_use,
227            HookEvent::PostToolUseFailure => &self.post_tool_use_failure,
228            HookEvent::SessionStart => &self.session_start,
229            HookEvent::SessionEnd => &self.session_end,
230            HookEvent::UserPromptSubmit => &self.user_prompt_submit,
231            HookEvent::Notification => &self.notification,
232            HookEvent::Stop => &self.stop,
233            HookEvent::SubagentStart => &self.subagent_start,
234            HookEvent::SubagentStop => &self.subagent_stop,
235        }
236    }
237
238    pub fn validate(&self) -> anyhow::Result<()> {
239        for event in HookEvent::ALL_VARIANTS {
240            for matcher in self.matchers_for_event(*event) {
241                for action in &matcher.hooks {
242                    match action.hook_type {
243                        HookType::Command => {
244                            if action.command.is_none() {
245                                anyhow::bail!(
246                                    "Hook matcher '{}': command hook requires a 'command' field",
247                                    matcher.matcher
248                                );
249                            }
250                        },
251                        HookType::Prompt => {
252                            if action.prompt.is_none() {
253                                anyhow::bail!(
254                                    "Hook matcher '{}': prompt hook requires a 'prompt' field",
255                                    matcher.matcher
256                                );
257                            }
258                        },
259                        HookType::Agent => {},
260                    }
261                }
262            }
263        }
264
265        Ok(())
266    }
267}