Skip to main content

systemprompt_models/services/
hooks.rs

1use std::fmt;
2use std::str::FromStr;
3
4use serde::{Deserialize, Serialize};
5use systemprompt_identifiers::HookId;
6
7use crate::errors::{ConfigValidationError, ParseEnumError};
8
/// File name used for hook configuration files.
pub const HOOK_CONFIG_FILENAME: &str = "config.yaml";

/// Serde default for boolean flags that are on unless explicitly disabled
/// (used for `DiskHookConfig::enabled`).
const fn default_true() -> bool {
    true
}

/// Serde default for `DiskHookConfig::version` when the key is absent.
fn default_version() -> String {
    String::from("1.0.0")
}

/// Serde default for `DiskHookConfig::matcher` when the key is absent.
fn default_matcher() -> String {
    String::from("*")
}
22
23fn default_hook_id() -> HookId {
24    HookId::new("")
25}
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
28#[serde(rename_all = "PascalCase")]
29pub enum HookEvent {
30    PreToolUse,
31    PostToolUse,
32    PostToolUseFailure,
33    SessionStart,
34    SessionEnd,
35    UserPromptSubmit,
36    Notification,
37    Stop,
38    SubagentStart,
39    SubagentStop,
40}
41
42impl HookEvent {
43    pub const ALL_VARIANTS: &'static [Self] = &[
44        Self::PreToolUse,
45        Self::PostToolUse,
46        Self::PostToolUseFailure,
47        Self::SessionStart,
48        Self::SessionEnd,
49        Self::UserPromptSubmit,
50        Self::Notification,
51        Self::Stop,
52        Self::SubagentStart,
53        Self::SubagentStop,
54    ];
55
56    pub const fn as_str(&self) -> &'static str {
57        match self {
58            Self::PreToolUse => "PreToolUse",
59            Self::PostToolUse => "PostToolUse",
60            Self::PostToolUseFailure => "PostToolUseFailure",
61            Self::SessionStart => "SessionStart",
62            Self::SessionEnd => "SessionEnd",
63            Self::UserPromptSubmit => "UserPromptSubmit",
64            Self::Notification => "Notification",
65            Self::Stop => "Stop",
66            Self::SubagentStart => "SubagentStart",
67            Self::SubagentStop => "SubagentStop",
68        }
69    }
70}
71
72impl fmt::Display for HookEvent {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        write!(f, "{}", self.as_str())
75    }
76}
77
78impl FromStr for HookEvent {
79    type Err = ParseEnumError;
80
81    fn from_str(s: &str) -> Result<Self, Self::Err> {
82        match s {
83            "PreToolUse" => Ok(Self::PreToolUse),
84            "PostToolUse" => Ok(Self::PostToolUse),
85            "PostToolUseFailure" => Ok(Self::PostToolUseFailure),
86            "SessionStart" => Ok(Self::SessionStart),
87            "SessionEnd" => Ok(Self::SessionEnd),
88            "UserPromptSubmit" => Ok(Self::UserPromptSubmit),
89            "Notification" => Ok(Self::Notification),
90            "Stop" => Ok(Self::Stop),
91            "SubagentStart" => Ok(Self::SubagentStart),
92            "SubagentStop" => Ok(Self::SubagentStop),
93            _ => Err(ParseEnumError::new("hook_event", s)),
94        }
95    }
96}
97
98#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
99#[serde(rename_all = "lowercase")]
100pub enum HookCategory {
101    System,
102    #[default]
103    Custom,
104}
105
106impl HookCategory {
107    pub const fn as_str(&self) -> &'static str {
108        match self {
109            Self::System => "system",
110            Self::Custom => "custom",
111        }
112    }
113}
114
115impl fmt::Display for HookCategory {
116    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117        write!(f, "{}", self.as_str())
118    }
119}
120
121impl FromStr for HookCategory {
122    type Err = ParseEnumError;
123
124    fn from_str(s: &str) -> Result<Self, Self::Err> {
125        match s {
126            "system" => Ok(Self::System),
127            "custom" => Ok(Self::Custom),
128            _ => Err(ParseEnumError::new("hook_category", s)),
129        }
130    }
131}
132
133#[derive(Debug, Clone, Deserialize)]
134pub struct DiskHookConfig {
135    #[serde(default = "default_hook_id")]
136    pub id: HookId,
137    #[serde(default)]
138    pub name: String,
139    #[serde(default)]
140    pub description: String,
141    #[serde(default = "default_version")]
142    pub version: String,
143    #[serde(default = "default_true")]
144    pub enabled: bool,
145    pub event: HookEvent,
146    #[serde(default = "default_matcher")]
147    pub matcher: String,
148    #[serde(default)]
149    pub command: String,
150    #[serde(default, rename = "async")]
151    pub is_async: bool,
152    #[serde(default)]
153    pub category: HookCategory,
154    #[serde(default)]
155    pub tags: Vec<String>,
156    #[serde(default)]
157    pub visible_to: Vec<String>,
158}
159
160#[derive(Debug, Clone, Default, Serialize, Deserialize)]
161#[serde(rename_all = "PascalCase")]
162pub struct HookEventsConfig {
163    #[serde(default, skip_serializing_if = "Vec::is_empty")]
164    pub pre_tool_use: Vec<HookMatcher>,
165    #[serde(default, skip_serializing_if = "Vec::is_empty")]
166    pub post_tool_use: Vec<HookMatcher>,
167    #[serde(default, skip_serializing_if = "Vec::is_empty")]
168    pub post_tool_use_failure: Vec<HookMatcher>,
169    #[serde(default, skip_serializing_if = "Vec::is_empty")]
170    pub session_start: Vec<HookMatcher>,
171    #[serde(default, skip_serializing_if = "Vec::is_empty")]
172    pub session_end: Vec<HookMatcher>,
173    #[serde(default, skip_serializing_if = "Vec::is_empty")]
174    pub user_prompt_submit: Vec<HookMatcher>,
175    #[serde(default, skip_serializing_if = "Vec::is_empty")]
176    pub notification: Vec<HookMatcher>,
177    #[serde(default, skip_serializing_if = "Vec::is_empty")]
178    pub stop: Vec<HookMatcher>,
179    #[serde(default, skip_serializing_if = "Vec::is_empty")]
180    pub subagent_start: Vec<HookMatcher>,
181    #[serde(default, skip_serializing_if = "Vec::is_empty")]
182    pub subagent_stop: Vec<HookMatcher>,
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct HookMatcher {
187    pub matcher: String,
188    pub hooks: Vec<HookAction>,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct HookAction {
193    #[serde(rename = "type")]
194    pub hook_type: HookType,
195    #[serde(skip_serializing_if = "Option::is_none")]
196    pub command: Option<String>,
197    #[serde(skip_serializing_if = "Option::is_none")]
198    pub prompt: Option<String>,
199    #[serde(default, rename = "async")]
200    pub r#async: bool,
201    #[serde(skip_serializing_if = "Option::is_none")]
202    pub timeout: Option<u32>,
203    #[serde(skip_serializing_if = "Option::is_none", rename = "statusMessage")]
204    pub status_message: Option<String>,
205}
206
207#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
208#[serde(rename_all = "lowercase")]
209pub enum HookType {
210    Command,
211    Prompt,
212    Agent,
213}
214
215impl HookEventsConfig {
216    pub fn is_empty(&self) -> bool {
217        self.pre_tool_use.is_empty()
218            && self.post_tool_use.is_empty()
219            && self.post_tool_use_failure.is_empty()
220            && self.session_start.is_empty()
221            && self.session_end.is_empty()
222            && self.user_prompt_submit.is_empty()
223            && self.notification.is_empty()
224            && self.stop.is_empty()
225            && self.subagent_start.is_empty()
226            && self.subagent_stop.is_empty()
227    }
228
229    pub fn matchers_for_event(&self, event: HookEvent) -> &[HookMatcher] {
230        match event {
231            HookEvent::PreToolUse => &self.pre_tool_use,
232            HookEvent::PostToolUse => &self.post_tool_use,
233            HookEvent::PostToolUseFailure => &self.post_tool_use_failure,
234            HookEvent::SessionStart => &self.session_start,
235            HookEvent::SessionEnd => &self.session_end,
236            HookEvent::UserPromptSubmit => &self.user_prompt_submit,
237            HookEvent::Notification => &self.notification,
238            HookEvent::Stop => &self.stop,
239            HookEvent::SubagentStart => &self.subagent_start,
240            HookEvent::SubagentStop => &self.subagent_stop,
241        }
242    }
243
244    pub fn validate(&self) -> Result<(), ConfigValidationError> {
245        for event in HookEvent::ALL_VARIANTS {
246            for matcher in self.matchers_for_event(*event) {
247                for action in &matcher.hooks {
248                    match action.hook_type {
249                        HookType::Command => {
250                            if action.command.is_none() {
251                                return Err(ConfigValidationError::required(format!(
252                                    "Hook matcher '{}': command hook requires a 'command' field",
253                                    matcher.matcher
254                                )));
255                            }
256                        },
257                        HookType::Prompt => {
258                            if action.prompt.is_none() {
259                                return Err(ConfigValidationError::required(format!(
260                                    "Hook matcher '{}': prompt hook requires a 'prompt' field",
261                                    matcher.matcher
262                                )));
263                            }
264                        },
265                        HookType::Agent => {},
266                    }
267                }
268            }
269        }
270
271        Ok(())
272    }
273}