1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
/// Lifecycle event that hooks are registered under and fired for.
///
/// `Copy`, `Eq` and `Hash` are derived so a point can be used directly as a
/// `HashMap` key; `Serialize`/`Deserialize` let it round-trip through serde.
///
/// NOTE(review): the per-variant descriptions below are read off the variant
/// names; the code that decides when each point fires is not in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum HookPoint {
    /// Presumably fired before a tool call runs.
    BeforeToolExecution,
    /// Presumably fired after a tool call has finished.
    AfterToolExecution,
    /// Before a request is sent to an LLM provider. `HookContext::llm` treats
    /// this point specially: its payload is stored as `input`.
    BeforeLlmRequest,
    /// Presumably fired once an LLM response has been received.
    AfterLlmResponse,
    /// Presumably fired when a session begins.
    OnSessionStart,
    /// Presumably fired when a session ends.
    OnSessionEnd,
    /// Error event; `HookContext::error` always builds contexts for this point.
    OnError,
}
27
/// Data handed to each hook when a [`HookPoint`] fires.
///
/// Every field except `point` is optional; which ones are populated depends
/// on how the context was built (see the `tool`, `llm`, `session` and `error`
/// constructors).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookContext {
    /// The event this context describes. `HookManager::fire` uses it to pick
    /// which registered hooks run.
    pub point: HookPoint,
    /// Tool name (set by `HookContext::tool`) or provider name (set by
    /// `HookContext::llm`); `None` for session and error contexts.
    pub name: Option<String>,
    /// Inbound payload: the tool arguments, or the LLM data when the point is
    /// `HookPoint::BeforeLlmRequest`.
    pub input: Option<serde_json::Value>,
    /// Outbound payload: the LLM data for any point other than
    /// `HookPoint::BeforeLlmRequest`. No other constructor sets it.
    pub output: Option<serde_json::Value>,
    /// Error message; only `HookContext::error` sets it.
    pub error: Option<String>,
    /// Session identifier; only `HookContext::session` sets it.
    pub session_id: Option<String>,
}
44
45impl HookContext {
46 pub fn tool(point: HookPoint, name: &str, args: serde_json::Value) -> Self {
48 Self {
49 point,
50 name: Some(name.into()),
51 input: Some(args),
52 output: None,
53 error: None,
54 session_id: None,
55 }
56 }
57
58 pub fn llm(point: HookPoint, provider: &str, data: serde_json::Value) -> Self {
60 let (input, output) = if point == HookPoint::BeforeLlmRequest {
61 (Some(data), None)
62 } else {
63 (None, Some(data))
64 };
65 Self {
66 point,
67 name: Some(provider.into()),
68 input,
69 output,
70 error: None,
71 session_id: None,
72 }
73 }
74
75 pub fn session(point: HookPoint, session_id: &str) -> Self {
77 Self {
78 point,
79 name: None,
80 input: None,
81 output: None,
82 error: None,
83 session_id: Some(session_id.into()),
84 }
85 }
86
87 pub fn error(error_message: &str) -> Self {
89 Self {
90 point: HookPoint::OnError,
91 name: None,
92 input: None,
93 output: None,
94 error: Some(error_message.into()),
95 session_id: None,
96 }
97 }
98}
99
/// Outcome reported by a single hook, and by `HookManager::fire` for the
/// whole chain of hooks at a point.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookResult {
    /// Nothing to report. `HookManager::fire` also returns this when no hooks
    /// are registered for the point.
    Continue,
    /// Stop the chain. `HookManager::fire` returns the first `Block` it sees
    /// without running the remaining hooks. The `String` is the hook's
    /// stated reason.
    Block(String),
    /// Presumably signals that the hook altered something; the variant carries
    /// no data. `HookManager::fire` returns it if any hook reported it and
    /// none blocked.
    Modified,
}
110
/// A callback that can be registered with a `HookManager`.
///
/// Implementors must be `Send + Sync`, so boxed hooks can be shared or moved
/// across threads.
pub trait Hook: Send + Sync {
    /// Runs the hook against `context` and reports the outcome.
    ///
    /// `HookManager::fire` calls this once per registered hook, in
    /// registration order, until one returns `HookResult::Block`.
    fn execute(&self, context: &HookContext) -> HookResult;

    /// Returns the hook's name. Nothing in this file reads it besides the
    /// implementors themselves.
    fn name(&self) -> &str;
}
119
/// Registry of hooks grouped by the [`HookPoint`] they are attached to.
pub struct HookManager {
    // Hooks per point, kept in the order they were registered (`register`
    // appends with `push`), which is also the order `fire` runs them in.
    hooks: HashMap<HookPoint, Vec<Box<dyn Hook>>>,
}
124
125impl HookManager {
126 pub fn new() -> Self {
128 Self {
129 hooks: HashMap::new(),
130 }
131 }
132
133 pub fn register(&mut self, point: HookPoint, hook: Box<dyn Hook>) {
135 self.hooks.entry(point).or_default().push(hook);
136 }
137
138 pub fn fire(&self, context: &HookContext) -> HookResult {
141 let hooks = match self.hooks.get(&context.point) {
142 Some(hooks) => hooks,
143 None => return HookResult::Continue,
144 };
145
146 let mut result = HookResult::Continue;
147 for hook in hooks {
148 let hook_result = hook.execute(context);
149 match &hook_result {
150 HookResult::Block(_) => return hook_result,
151 HookResult::Modified => result = HookResult::Modified,
152 HookResult::Continue => {}
153 }
154 }
155 result
156 }
157
158 pub fn count_at(&self, point: HookPoint) -> usize {
160 self.hooks.get(&point).map(|v| v.len()).unwrap_or(0)
161 }
162
163 pub fn total_hooks(&self) -> usize {
165 self.hooks.values().map(|v| v.len()).sum()
166 }
167
168 pub fn clear(&mut self) {
170 self.hooks.clear();
171 }
172}
173
174impl Default for HookManager {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared log of hook names, appended to in the order hooks execute.
    type CallLog = Arc<Mutex<Vec<String>>>;

    /// Returns a snapshot of the names recorded so far.
    fn recorded(calls: &CallLog) -> Vec<String> {
        calls.lock().expect("call log mutex poisoned").clone()
    }

    /// Minimal tool context for tests that only care about which hooks run.
    fn tool_ctx() -> HookContext {
        HookContext::tool(
            HookPoint::BeforeToolExecution,
            "test",
            serde_json::json!({}),
        )
    }

    /// Hook that always lets the chain continue.
    struct AllowHook;
    impl Hook for AllowHook {
        fn execute(&self, _context: &HookContext) -> HookResult {
            HookResult::Continue
        }
        fn name(&self) -> &str {
            "allow"
        }
    }

    /// Hook that always blocks with a fixed reason.
    struct BlockHook {
        reason: String,
    }
    impl BlockHook {
        fn new(reason: &str) -> Self {
            Self {
                reason: reason.into(),
            }
        }
    }
    impl Hook for BlockHook {
        fn execute(&self, _context: &HookContext) -> HookResult {
            HookResult::Block(self.reason.clone())
        }
        fn name(&self) -> &str {
            "block"
        }
    }

    /// Hook that always reports `Modified`.
    struct ModifyHook;
    impl Hook for ModifyHook {
        fn execute(&self, _context: &HookContext) -> HookResult {
            HookResult::Modified
        }
        fn name(&self) -> &str {
            "modify"
        }
    }

    /// Hook that records its own name in a shared log every time it runs, so
    /// tests can assert exactly which hooks executed and in what order.
    struct CountingHook {
        name: String,
        calls: CallLog,
    }
    impl CountingHook {
        fn new(name: &str, calls: &CallLog) -> Self {
            Self {
                name: name.into(),
                calls: Arc::clone(calls),
            }
        }
    }
    impl Hook for CountingHook {
        fn execute(&self, _context: &HookContext) -> HookResult {
            self.calls
                .lock()
                .expect("call log mutex poisoned")
                .push(self.name.clone());
            HookResult::Continue
        }
        fn name(&self) -> &str {
            &self.name
        }
    }

    #[test]
    fn test_hook_manager_register_and_fire() {
        let mut mgr = HookManager::new();
        mgr.register(HookPoint::BeforeToolExecution, Box::new(AllowHook));

        let ctx = HookContext::tool(
            HookPoint::BeforeToolExecution,
            "shell_exec",
            serde_json::json!({"cmd": "ls"}),
        );
        assert_eq!(mgr.fire(&ctx), HookResult::Continue);
    }

    #[test]
    fn test_hook_manager_block() {
        let mut mgr = HookManager::new();
        mgr.register(
            HookPoint::BeforeToolExecution,
            Box::new(BlockHook::new("dangerous")),
        );

        let ctx = HookContext::tool(HookPoint::BeforeToolExecution, "rm", serde_json::json!({}));
        assert_eq!(mgr.fire(&ctx), HookResult::Block("dangerous".into()));
    }

    #[test]
    fn test_hook_ordering_block_stops_chain() {
        let calls = CallLog::default();
        let mut mgr = HookManager::new();
        mgr.register(
            HookPoint::BeforeToolExecution,
            Box::new(CountingHook::new("before", &calls)),
        );
        mgr.register(
            HookPoint::BeforeToolExecution,
            Box::new(BlockHook::new("blocked")),
        );
        mgr.register(
            HookPoint::BeforeToolExecution,
            Box::new(CountingHook::new("after", &calls)),
        );

        assert_eq!(mgr.fire(&tool_ctx()), HookResult::Block("blocked".into()));
        // Only the hook ahead of the blocker may have run.
        assert_eq!(recorded(&calls), ["before"]);
    }

    #[test]
    fn test_hook_modified_result() {
        let mut mgr = HookManager::new();
        mgr.register(HookPoint::AfterLlmResponse, Box::new(ModifyHook));
        mgr.register(HookPoint::AfterLlmResponse, Box::new(AllowHook));

        let ctx = HookContext::llm(
            HookPoint::AfterLlmResponse,
            "openai",
            serde_json::json!({"text": "hello"}),
        );
        // A later `Continue` must not erase an earlier `Modified`.
        assert_eq!(mgr.fire(&ctx), HookResult::Modified);
    }

    #[test]
    fn test_block_overrides_earlier_modified() {
        let mut mgr = HookManager::new();
        mgr.register(HookPoint::BeforeToolExecution, Box::new(ModifyHook));
        mgr.register(
            HookPoint::BeforeToolExecution,
            Box::new(BlockHook::new("nope")),
        );

        assert_eq!(mgr.fire(&tool_ctx()), HookResult::Block("nope".into()));
    }

    #[test]
    fn test_hook_fire_no_hooks() {
        let mgr = HookManager::new();
        let ctx = HookContext::session(HookPoint::OnSessionStart, "session-1");
        assert_eq!(mgr.fire(&ctx), HookResult::Continue);
    }

    #[test]
    fn test_hook_fire_ignores_other_points() {
        let calls = CallLog::default();
        let mut mgr = HookManager::new();
        mgr.register(
            HookPoint::OnError,
            Box::new(CountingHook::new("on_error", &calls)),
        );

        assert_eq!(mgr.fire(&tool_ctx()), HookResult::Continue);
        assert!(recorded(&calls).is_empty());
    }

    #[test]
    fn test_hook_manager_count() {
        let mut mgr = HookManager::new();
        mgr.register(HookPoint::BeforeToolExecution, Box::new(AllowHook));
        mgr.register(HookPoint::BeforeToolExecution, Box::new(AllowHook));
        mgr.register(HookPoint::OnError, Box::new(AllowHook));

        assert_eq!(mgr.count_at(HookPoint::BeforeToolExecution), 2);
        assert_eq!(mgr.count_at(HookPoint::OnError), 1);
        assert_eq!(mgr.count_at(HookPoint::OnSessionEnd), 0);
        assert_eq!(mgr.total_hooks(), 3);
    }

    #[test]
    fn test_hook_manager_clear() {
        let mut mgr = HookManager::new();
        mgr.register(HookPoint::BeforeToolExecution, Box::new(AllowHook));
        mgr.register(HookPoint::OnError, Box::new(AllowHook));
        assert_eq!(mgr.total_hooks(), 2);

        mgr.clear();
        assert_eq!(mgr.total_hooks(), 0);
        assert_eq!(mgr.count_at(HookPoint::BeforeToolExecution), 0);
    }

    #[test]
    fn test_hook_manager_default_is_empty() {
        assert_eq!(HookManager::default().total_hooks(), 0);
    }

    #[test]
    fn test_hook_context_tool() {
        let args = serde_json::json!({"cmd": "ls"});
        let ctx = HookContext::tool(HookPoint::BeforeToolExecution, "shell_exec", args.clone());

        assert_eq!(ctx.point, HookPoint::BeforeToolExecution);
        assert_eq!(ctx.name.as_deref(), Some("shell_exec"));
        assert_eq!(ctx.input, Some(args));
        assert!(ctx.output.is_none());
        assert!(ctx.error.is_none());
        assert!(ctx.session_id.is_none());
    }

    #[test]
    fn test_hook_context_llm_routes_payload() {
        let payload = serde_json::json!({"text": "hello"});

        let request = HookContext::llm(HookPoint::BeforeLlmRequest, "openai", payload.clone());
        assert_eq!(request.name.as_deref(), Some("openai"));
        assert_eq!(request.input, Some(payload.clone()));
        assert!(request.output.is_none());

        let response = HookContext::llm(HookPoint::AfterLlmResponse, "openai", payload.clone());
        assert_eq!(response.name.as_deref(), Some("openai"));
        assert!(response.input.is_none());
        assert_eq!(response.output, Some(payload));
    }

    #[test]
    fn test_hook_context_session() {
        let ctx = HookContext::session(HookPoint::OnSessionStart, "session-1");
        assert_eq!(ctx.point, HookPoint::OnSessionStart);
        assert_eq!(ctx.session_id.as_deref(), Some("session-1"));
        assert!(ctx.name.is_none());
        assert!(ctx.input.is_none());
        assert!(ctx.output.is_none());
        assert!(ctx.error.is_none());
    }

    #[test]
    fn test_hook_context_error() {
        let ctx = HookContext::error("something went wrong");
        assert_eq!(ctx.point, HookPoint::OnError);
        assert_eq!(ctx.error.as_deref(), Some("something went wrong"));
    }

    #[test]
    fn test_hook_names() {
        let calls = CallLog::default();
        assert_eq!(AllowHook.name(), "allow");
        assert_eq!(BlockHook::new("x").name(), "block");
        assert_eq!(ModifyHook.name(), "modify");
        assert_eq!(CountingHook::new("first", &calls).name(), "first");
    }

    #[test]
    fn test_multiple_hooks_fire_in_order() {
        let calls = CallLog::default();
        let mut mgr = HookManager::new();
        for label in &["first", "second", "third"] {
            mgr.register(
                HookPoint::BeforeToolExecution,
                Box::new(CountingHook::new(label, &calls)),
            );
        }

        assert_eq!(mgr.fire(&tool_ctx()), HookResult::Continue);
        assert_eq!(mgr.count_at(HookPoint::BeforeToolExecution), 3);
        // Each hook ran exactly once, in registration order.
        assert_eq!(recorded(&calls), ["first", "second", "third"]);
    }

    #[test]
    fn test_hook_point_serialization() {
        let point = HookPoint::BeforeToolExecution;
        let json = serde_json::to_string(&point).unwrap();
        let restored: HookPoint = serde_json::from_str(&json).unwrap();
        assert_eq!(restored, HookPoint::BeforeToolExecution);
    }

    #[test]
    fn test_all_seven_hook_points() {
        let points = [
            HookPoint::BeforeToolExecution,
            HookPoint::AfterToolExecution,
            HookPoint::BeforeLlmRequest,
            HookPoint::AfterLlmResponse,
            HookPoint::OnSessionStart,
            HookPoint::OnSessionEnd,
            HookPoint::OnError,
        ];
        assert_eq!(points.len(), 7);

        let mut mgr = HookManager::new();
        for point in &points {
            mgr.register(*point, Box::new(AllowHook));
        }
        assert_eq!(mgr.total_hooks(), 7);
        // Every point is a distinct key holding exactly one hook.
        for point in &points {
            assert_eq!(mgr.count_at(*point), 1);
        }
    }
}