1use serde::{Deserialize, Serialize};
12use serde_json::Value;
13
14use crate::extensions::permissions::Permission;
15
/// Lifecycle point at which an extension hook runs.
///
/// Serialized as a snake_case string (e.g. `"before_tool_call"`); the same
/// strings are returned by `HookKind::as_str` and accepted by
/// `HookKind::from_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookKind {
    /// Before a tool call; the event carries the tool name and input.
    /// Hooks may `continue`, `block`, `confirm`, or `modify` (supply new input).
    BeforeToolCall,
    /// After a tool call; the event carries the tool input and output (the
    /// `after_tool_call` constructor caps the output at 32 KiB).
    /// Hooks may only `continue`.
    AfterToolCall,
    /// Before a message; the event carries the message text.
    /// Hooks may `continue` or `inject` content.
    BeforeMessage,
    /// When a message completes; the event carries the message text and
    /// caller-supplied `data`. Hooks may only `continue`.
    OnMessageComplete,
    /// On session compaction; the event carries the summary, the new session id,
    /// and old/new ids plus message count in `data`. Hooks may only `continue`.
    OnCompaction,
    /// When a session starts; the event carries the session id.
    /// Hooks may only `continue`.
    OnSessionStart,
    /// When a session ends; the event carries the session id and optionally a
    /// transcript. Hooks may only `continue`.
    OnSessionEnd,
}
41
42impl HookKind {
43 pub fn as_str(&self) -> &'static str {
46 match self {
47 Self::BeforeToolCall => "before_tool_call",
48 Self::AfterToolCall => "after_tool_call",
49 Self::BeforeMessage => "before_message",
50 Self::OnMessageComplete => "on_message_complete",
51 Self::OnCompaction => "on_compaction",
52 Self::OnSessionStart => "on_session_start",
53 Self::OnSessionEnd => "on_session_end",
54 }
55 }
56
57 pub fn from_str(s: &str) -> Option<Self> {
62 match s {
63 "before_tool_call" => Some(Self::BeforeToolCall),
64 "after_tool_call" => Some(Self::AfterToolCall),
65 "before_message" => Some(Self::BeforeMessage),
66 "on_message_complete" => Some(Self::OnMessageComplete),
67 "on_compaction" => Some(Self::OnCompaction),
68 "on_session_start" => Some(Self::OnSessionStart),
69 "on_session_end" => Some(Self::OnSessionEnd),
70 _ => None,
71 }
72 }
73
74 pub fn allowed_action_names(&self) -> &'static [&'static str] {
76 match self {
77 Self::BeforeToolCall => &["continue", "block", "confirm", "modify"],
78 Self::AfterToolCall => &["continue"],
79 Self::BeforeMessage => &["continue", "inject"],
80 Self::OnMessageComplete | Self::OnCompaction | Self::OnSessionStart | Self::OnSessionEnd => &["continue"],
81 }
82 }
83
84 pub fn allows_tool_filter(&self) -> bool {
86 matches!(self, Self::BeforeToolCall | Self::AfterToolCall)
87 }
88
89 pub fn allows_result(&self, result: &HookResult) -> bool {
91 match (self, result) {
92 (_, HookResult::Continue) => true,
93 (Self::BeforeToolCall, HookResult::Block { .. }) => true,
94 (Self::BeforeToolCall, HookResult::Confirm { .. }) => true,
95 (Self::BeforeToolCall, HookResult::Modify { .. }) => true,
96 (Self::BeforeMessage, HookResult::Inject { .. }) => true,
97 _ => false,
98 }
99 }
100
101 pub fn required_permission(&self) -> Permission {
107 match self {
108 Self::BeforeToolCall | Self::AfterToolCall => Permission::ToolsIntercept,
109 Self::BeforeMessage | Self::OnMessageComplete | Self::OnCompaction => Permission::LlmContent,
110 Self::OnSessionStart | Self::OnSessionEnd => Permission::SessionLifecycle,
111 }
112 }
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct HookEvent {
135 pub kind: HookKind,
137 pub tool_name: Option<String>,
140 #[serde(default)]
143 pub tool_runtime_name: Option<String>,
144 pub tool_input: Option<Value>,
146 pub tool_output: Option<String>,
148 pub message: Option<String>,
150 pub session_id: Option<String>,
152 #[serde(default)]
156 pub transcript: Option<Vec<Value>>,
157 pub data: Value,
159}
160
161impl HookEvent {
162 pub fn before_tool_call(tool_name: &str, input: Value) -> Self {
164 Self {
165 kind: HookKind::BeforeToolCall,
166 tool_name: Some(tool_name.to_string()),
167 tool_input: Some(input),
168 tool_output: None,
169 message: None,
170 session_id: None,
171 tool_runtime_name: None,
172 transcript: None,
173 data: Value::Null,
174 }
175 }
176
177 pub fn after_tool_call(tool_name: &str, input: Value, output: String) -> Self {
181 const MAX_HOOK_OUTPUT: usize = 32 * 1024; let truncated_output = if output.len() > MAX_HOOK_OUTPUT {
183 let boundary = output
184 .char_indices()
185 .map(|(idx, _)| idx)
186 .take_while(|idx| *idx <= MAX_HOOK_OUTPUT)
187 .last()
188 .unwrap_or(0);
189 format!(
190 "{}…[truncated, {} total bytes]",
191 &output[..boundary],
192 output.len()
193 )
194 } else {
195 output
196 };
197 Self {
198 kind: HookKind::AfterToolCall,
199 tool_name: Some(tool_name.to_string()),
200 tool_input: Some(input),
201 tool_output: Some(truncated_output),
202 message: None,
203 session_id: None,
204 tool_runtime_name: None,
205 transcript: None,
206 data: Value::Null,
207 }
208 }
209
210 pub fn before_message(message: &str) -> Self {
212 Self {
213 kind: HookKind::BeforeMessage,
214 tool_name: None,
215 tool_input: None,
216 tool_output: None,
217 message: Some(message.to_string()),
218 session_id: None,
219 tool_runtime_name: None,
220 transcript: None,
221 data: Value::Null,
222 }
223 }
224
225 pub fn on_message_complete(message: &str, data: Value) -> Self {
227 Self {
228 kind: HookKind::OnMessageComplete,
229 tool_name: None,
230 tool_input: None,
231 tool_output: None,
232 message: Some(message.to_string()),
233 session_id: None,
234 tool_runtime_name: None,
235 transcript: None,
236 data,
237 }
238 }
239
240 pub fn on_compaction(
242 old_session_id: &str,
243 new_session_id: &str,
244 summary: &str,
245 message_count: usize,
246 mut data: Value,
247 ) -> Self {
248 if !data.is_object() {
249 data = Value::Object(Default::default());
250 }
251 if let Some(object) = data.as_object_mut() {
252 object.insert("old_session_id".to_string(), Value::String(old_session_id.to_string()));
253 object.insert("new_session_id".to_string(), Value::String(new_session_id.to_string()));
254 object.insert("message_count".to_string(), Value::Number(message_count.into()));
255 }
256 Self {
257 kind: HookKind::OnCompaction,
258 tool_name: None,
259 tool_input: None,
260 tool_output: None,
261 message: Some(summary.to_string()),
262 session_id: Some(new_session_id.to_string()),
263 tool_runtime_name: None,
264 transcript: None,
265 data,
266 }
267 }
268
269 pub fn on_session_start(session_id: &str) -> Self {
271 Self {
272 kind: HookKind::OnSessionStart,
273 tool_name: None,
274 tool_input: None,
275 tool_output: None,
276 message: None,
277 session_id: Some(session_id.to_string()),
278 tool_runtime_name: None,
279 transcript: None,
280 data: Value::Null,
281 }
282 }
283
284 pub fn on_session_end(session_id: &str, transcript: Option<Vec<Value>>) -> Self {
286 Self {
287 kind: HookKind::OnSessionEnd,
288 tool_name: None,
289 tool_input: None,
290 tool_output: None,
291 message: None,
292 session_id: Some(session_id.to_string()),
293 tool_runtime_name: None,
294 transcript,
295 data: Value::Null,
296 }
297 }
298}
299
/// Action returned by a hook.
///
/// Serialized as an internally tagged JSON object whose `action` field holds
/// the snake_case variant name. Examples are `{"action":"continue"}` and
/// `{"action":"confirm","message":"..."}`. Which actions a hook may return
/// depends on its kind; see `HookKind::allows_result` and
/// `HookKind::allowed_action_names`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "action", rename_all = "snake_case")]
pub enum HookResult {
    /// Proceed without changes. Valid for every hook kind; this is the `Default`.
    Continue,
    /// Block the pending tool call, with a human-readable `reason`.
    /// Only valid for `before_tool_call`.
    Block { reason: String },
    /// Inject `content`. Only valid for `before_message`.
    Inject { content: String },
    /// Ask for confirmation, showing `message`. Only valid for `before_tool_call`.
    Confirm { message: String },
    /// Supply a replacement tool `input`. Only valid for `before_tool_call`.
    Modify { input: Value },
}
324
325impl Default for HookResult {
326 fn default() -> Self {
327 Self::Continue
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Every `HookKind` variant, so table-driven tests cover them all.
    const ALL_KINDS: [HookKind; 7] = [
        HookKind::BeforeToolCall,
        HookKind::AfterToolCall,
        HookKind::BeforeMessage,
        HookKind::OnMessageComplete,
        HookKind::OnCompaction,
        HookKind::OnSessionStart,
        HookKind::OnSessionEnd,
    ];

    /// Byte cap applied by `HookEvent::after_tool_call` (mirrors its constant).
    const MAX_HOOK_OUTPUT: usize = 32 * 1024;

    #[test]
    fn hook_kind_as_str_roundtrip() {
        for kind in ALL_KINDS {
            let s = kind.as_str();
            assert_eq!(HookKind::from_str(s), Some(kind), "round-trip failed for {s}");
        }
    }

    #[test]
    fn hook_kind_as_str_matches_serde() {
        // `as_str` and the serde representation must never drift apart.
        for kind in ALL_KINDS {
            assert_eq!(serde_json::to_value(kind).unwrap(), json!(kind.as_str()));
        }
    }

    #[test]
    fn hook_kind_from_str_unknown_returns_none() {
        assert_eq!(HookKind::from_str(""), None);
        assert_eq!(HookKind::from_str("BeforeToolCall"), None);
        assert_eq!(HookKind::from_str("on_crash"), None);
    }

    #[test]
    fn hook_kind_serde_snake_case() {
        let serialized = serde_json::to_string(&HookKind::BeforeToolCall).unwrap();
        assert_eq!(serialized, r#""before_tool_call""#);

        let back: HookKind = serde_json::from_str(r#""on_session_end""#).unwrap();
        assert_eq!(back, HookKind::OnSessionEnd);
    }

    #[test]
    fn hook_kind_required_permission() {
        assert_eq!(
            HookKind::BeforeToolCall.required_permission(),
            Permission::ToolsIntercept
        );
        assert_eq!(
            HookKind::AfterToolCall.required_permission(),
            Permission::ToolsIntercept
        );
        assert_eq!(
            HookKind::BeforeMessage.required_permission(),
            Permission::LlmContent
        );
        assert_eq!(
            HookKind::OnMessageComplete.required_permission(),
            Permission::LlmContent
        );
        assert_eq!(
            HookKind::OnCompaction.required_permission(),
            Permission::LlmContent
        );
        assert_eq!(
            HookKind::OnSessionStart.required_permission(),
            Permission::SessionLifecycle
        );
        assert_eq!(
            HookKind::OnSessionEnd.required_permission(),
            Permission::SessionLifecycle
        );
    }

    #[test]
    fn hook_kind_allows_tool_filter_only_for_tool_hooks() {
        assert!(HookKind::BeforeToolCall.allows_tool_filter());
        assert!(HookKind::AfterToolCall.allows_tool_filter());
        assert!(!HookKind::BeforeMessage.allows_tool_filter());
        assert!(!HookKind::OnMessageComplete.allows_tool_filter());
        assert!(!HookKind::OnCompaction.allows_tool_filter());
        assert!(!HookKind::OnSessionStart.allows_tool_filter());
        assert!(!HookKind::OnSessionEnd.allows_tool_filter());
    }

    #[test]
    fn hook_kind_allows_result_agrees_with_allowed_action_names() {
        // The serialized `action` tag is what hooks emit, so the two policy
        // methods must agree with it for every (kind, result) pair.
        let results = [
            HookResult::Continue,
            HookResult::Block { reason: "r".to_string() },
            HookResult::Inject { content: "c".to_string() },
            HookResult::Confirm { message: "m".to_string() },
            HookResult::Modify { input: json!({}) },
        ];
        for kind in ALL_KINDS {
            for result in &results {
                let value = serde_json::to_value(result).unwrap();
                let action = value["action"].as_str().unwrap();
                assert_eq!(
                    kind.allows_result(result),
                    kind.allowed_action_names().contains(&action),
                    "{} / {}",
                    kind.as_str(),
                    action
                );
            }
        }
    }

    #[test]
    fn hook_kind_allows_result_policy() {
        let block = HookResult::Block { reason: "x".to_string() };
        let inject = HookResult::Inject { content: "x".to_string() };
        for kind in ALL_KINDS {
            assert!(kind.allows_result(&HookResult::Continue));
        }
        assert!(HookKind::BeforeToolCall.allows_result(&block));
        assert!(!HookKind::AfterToolCall.allows_result(&block));
        assert!(HookKind::BeforeMessage.allows_result(&inject));
        assert!(!HookKind::BeforeToolCall.allows_result(&inject));
    }

    #[test]
    fn hook_event_before_tool_call() {
        let input = json!({"path": "/tmp/foo"});
        let ev = HookEvent::before_tool_call("read_file", input.clone());

        assert_eq!(ev.kind, HookKind::BeforeToolCall);
        assert_eq!(ev.tool_name.as_deref(), Some("read_file"));
        assert_eq!(ev.tool_input.as_ref(), Some(&input));
        assert!(ev.tool_output.is_none());
        assert!(ev.message.is_none());
        assert!(ev.session_id.is_none());
        assert_eq!(ev.data, Value::Null);
    }

    #[test]
    fn hook_event_after_tool_call() {
        let input = json!({"query": "select 1"});
        let ev = HookEvent::after_tool_call("sql_query", input.clone(), "1 row".to_string());

        assert_eq!(ev.kind, HookKind::AfterToolCall);
        assert_eq!(ev.tool_name.as_deref(), Some("sql_query"));
        assert_eq!(ev.tool_input.as_ref(), Some(&input));
        assert_eq!(ev.tool_output.as_deref(), Some("1 row"));
        assert!(ev.message.is_none());
        assert!(ev.session_id.is_none());
    }

    #[test]
    fn hook_event_after_tool_call_keeps_output_at_limit() {
        let output = "b".repeat(MAX_HOOK_OUTPUT);
        let ev = HookEvent::after_tool_call("bash", json!({}), output.clone());
        assert_eq!(ev.tool_output, Some(output));
    }

    #[test]
    fn hook_event_after_tool_call_truncates_long_output() {
        let output = "a".repeat(MAX_HOOK_OUTPUT + 10);
        let ev = HookEvent::after_tool_call("bash", json!({}), output);
        let expected = format!(
            "{}…[truncated, {} total bytes]",
            "a".repeat(MAX_HOOK_OUTPUT),
            MAX_HOOK_OUTPUT + 10
        );
        assert_eq!(ev.tool_output.as_deref(), Some(expected.as_str()));
    }

    #[test]
    fn hook_event_after_tool_call_truncates_on_char_boundary() {
        // A leading ASCII byte puts every two-byte 'é' at an odd offset, so the
        // (even) cap lands mid-character and must back off by one byte.
        let output = format!("x{}", "é".repeat(20_000));
        let ev = HookEvent::after_tool_call("bash", json!({}), output.clone());
        let expected = format!(
            "{}…[truncated, {} total bytes]",
            &output[..MAX_HOOK_OUTPUT - 1],
            output.len()
        );
        assert_eq!(ev.tool_output.as_deref(), Some(expected.as_str()));
    }

    #[test]
    fn hook_event_before_message() {
        let ev = HookEvent::before_message("Hello, LLM");

        assert_eq!(ev.kind, HookKind::BeforeMessage);
        assert!(ev.tool_name.is_none());
        assert!(ev.tool_input.is_none());
        assert!(ev.tool_output.is_none());
        assert_eq!(ev.message.as_deref(), Some("Hello, LLM"));
        assert!(ev.session_id.is_none());
    }

    #[test]
    fn hook_event_on_message_complete() {
        let ev = HookEvent::on_message_complete("Done", json!({"content_block_count": 1}));

        assert_eq!(ev.kind, HookKind::OnMessageComplete);
        assert!(ev.tool_name.is_none());
        assert!(ev.tool_input.is_none());
        assert!(ev.tool_output.is_none());
        assert_eq!(ev.message.as_deref(), Some("Done"));
        assert_eq!(ev.data["content_block_count"], 1);
        assert!(ev.session_id.is_none());
    }

    #[test]
    fn hook_event_on_compaction() {
        let ev = HookEvent::on_compaction(
            "old-session",
            "new-session",
            "Summary",
            7,
            json!({"source": "manual"}),
        );

        assert_eq!(ev.kind, HookKind::OnCompaction);
        assert_eq!(ev.message.as_deref(), Some("Summary"));
        assert_eq!(ev.session_id.as_deref(), Some("new-session"));
        assert_eq!(ev.data["old_session_id"], "old-session");
        assert_eq!(ev.data["new_session_id"], "new-session");
        assert_eq!(ev.data["message_count"], 7);
        assert_eq!(ev.data["source"], "manual");
        assert!(ev.transcript.is_none());
    }

    #[test]
    fn hook_event_on_compaction_overrides_caller_keys() {
        let ev = HookEvent::on_compaction(
            "old",
            "new",
            "S",
            7,
            json!({"message_count": 99, "source": "auto"}),
        );
        assert_eq!(ev.data["message_count"], 7);
        assert_eq!(ev.data["source"], "auto");
    }

    #[test]
    fn hook_event_on_compaction_replaces_non_object_data() {
        let ev = HookEvent::on_compaction("old", "new", "S", 2, json!(["not", "an", "object"]));
        assert_eq!(
            ev.data,
            json!({"old_session_id": "old", "new_session_id": "new", "message_count": 2})
        );
    }

    #[test]
    fn hook_event_on_session_start() {
        let ev = HookEvent::on_session_start("sess-abc-123");

        assert_eq!(ev.kind, HookKind::OnSessionStart);
        assert_eq!(ev.session_id.as_deref(), Some("sess-abc-123"));
        assert!(ev.tool_name.is_none());
        assert!(ev.message.is_none());
    }

    #[test]
    fn hook_event_on_session_end() {
        let ev = HookEvent::on_session_end("sess-abc-123", None);

        assert_eq!(ev.kind, HookKind::OnSessionEnd);
        assert_eq!(ev.session_id.as_deref(), Some("sess-abc-123"));
        assert!(ev.tool_name.is_none());
        assert!(ev.message.is_none());
    }

    #[test]
    fn hook_event_on_session_end_with_transcript() {
        let transcript = vec![json!({"role": "user", "content": "hi"})];
        let ev = HookEvent::on_session_end("sess-1", Some(transcript.clone()));
        assert_eq!(ev.transcript, Some(transcript));
    }

    #[test]
    fn hook_event_with_runtime_name() {
        let ev = HookEvent::before_tool_call("bash", json!({})).with_runtime_name("rt_bash");
        assert_eq!(ev.tool_runtime_name.as_deref(), Some("rt_bash"));
        assert_eq!(ev.tool_name.as_deref(), Some("bash"));
    }

    #[test]
    fn hook_event_serde_roundtrip() {
        let ev = HookEvent::before_tool_call("bash", json!({"cmd": "ls"}));
        let json = serde_json::to_string(&ev).unwrap();
        let back: HookEvent = serde_json::from_str(&json).unwrap();

        assert_eq!(back.kind, ev.kind);
        assert_eq!(back.tool_name, ev.tool_name);
        assert_eq!(back.tool_input, ev.tool_input);
    }

    #[test]
    fn hook_event_deserializes_without_runtime_name_or_transcript() {
        let ev: HookEvent = serde_json::from_value(json!({
            "kind": "on_session_start",
            "tool_name": null,
            "tool_input": null,
            "tool_output": null,
            "message": null,
            "session_id": "sess-1",
            "data": null
        }))
        .unwrap();
        assert_eq!(ev.kind, HookKind::OnSessionStart);
        assert_eq!(ev.session_id.as_deref(), Some("sess-1"));
        assert!(ev.tool_runtime_name.is_none());
        assert!(ev.transcript.is_none());
    }

    #[test]
    fn hook_result_default_is_continue() {
        assert!(matches!(HookResult::default(), HookResult::Continue));
    }

    #[test]
    fn hook_result_block_serde() {
        let r = HookResult::Block {
            reason: "denied by policy".to_string(),
        };
        let json = serde_json::to_string(&r).unwrap();
        assert!(json.contains(r#""action":"block""#));
        assert!(json.contains("denied by policy"));

        let back: HookResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(back, HookResult::Block { reason } if reason == "denied by policy"));
    }

    #[test]
    fn hook_result_confirm_serde() {
        let r = HookResult::Confirm {
            message: "Run this command?".to_string(),
        };
        let json = serde_json::to_string(&r).unwrap();
        assert_eq!(json, r#"{"action":"confirm","message":"Run this command?"}"#);

        let back: HookResult = serde_json::from_str(&json).unwrap();
        assert!(matches!(back, HookResult::Confirm { message } if message == "Run this command?"));
    }

    #[test]
    fn hook_result_modify_serde() {
        let r = HookResult::Modify {
            input: json!({"command": "echo safe"}),
        };
        let json = serde_json::to_string(&r).unwrap();
        assert_eq!(json, r#"{"action":"modify","input":{"command":"echo safe"}}"#);

        let back: HookResult = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(back, HookResult::Modify { input } if input == json!({"command": "echo safe"}))
        );
    }

    #[test]
    fn hook_result_continue_serde() {
        let json = serde_json::to_string(&HookResult::Continue).unwrap();
        assert_eq!(json, r#"{"action":"continue"}"#);
    }

    #[test]
    fn hook_result_inject_deserializes_from_hook_output() {
        let r: HookResult =
            serde_json::from_str(r#"{"action":"inject","content":"extra context"}"#).unwrap();
        assert!(matches!(r, HookResult::Inject { content } if content == "extra context"));
    }

    #[test]
    fn hook_result_rejects_unknown_or_missing_action() {
        assert!(serde_json::from_str::<HookResult>(r#"{"action":"explode"}"#).is_err());
        assert!(serde_json::from_str::<HookResult>(r#"{"reason":"no tag"}"#).is_err());
    }
}
574
575impl HookEvent {
576 pub fn with_runtime_name(mut self, name: &str) -> Self {
578 self.tool_runtime_name = Some(name.to_string());
579 self
580 }
581}