Skip to main content

rustant_plugins/
hooks.rs

//! Hook system — 7 hook points for plugin interception.
//!
//! Hooks allow plugins to intercept and influence agent behavior at defined points in its lifecycle.
//! Each hook returns a `HookResult` indicating whether execution should continue (optionally
//! flagged as modified) or be blocked.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// The 7 hook points in the agent lifecycle.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub enum HookPoint {
12    /// Before a tool is executed.
13    BeforeToolExecution,
14    /// After a tool has executed.
15    AfterToolExecution,
16    /// Before an LLM request is sent.
17    BeforeLlmRequest,
18    /// After an LLM response is received.
19    AfterLlmResponse,
20    /// When a new session starts.
21    OnSessionStart,
22    /// When a session ends.
23    OnSessionEnd,
24    /// When an error occurs.
25    OnError,
26}
27
28/// Context passed to hooks.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct HookContext {
31    /// The hook point being triggered.
32    pub point: HookPoint,
33    /// Name of the relevant tool/provider (if applicable).
34    pub name: Option<String>,
35    /// Input data (tool args, LLM prompt, etc.).
36    pub input: Option<serde_json::Value>,
37    /// Output data (tool result, LLM response, etc.).
38    pub output: Option<serde_json::Value>,
39    /// Error message (for OnError).
40    pub error: Option<String>,
41    /// Session ID.
42    pub session_id: Option<String>,
43}
44
45impl HookContext {
46    /// Create a context for a tool execution hook.
47    pub fn tool(point: HookPoint, name: &str, args: serde_json::Value) -> Self {
48        Self {
49            point,
50            name: Some(name.into()),
51            input: Some(args),
52            output: None,
53            error: None,
54            session_id: None,
55        }
56    }
57
58    /// Create a context for an LLM request/response hook.
59    pub fn llm(point: HookPoint, provider: &str, data: serde_json::Value) -> Self {
60        let (input, output) = if point == HookPoint::BeforeLlmRequest {
61            (Some(data), None)
62        } else {
63            (None, Some(data))
64        };
65        Self {
66            point,
67            name: Some(provider.into()),
68            input,
69            output,
70            error: None,
71            session_id: None,
72        }
73    }
74
75    /// Create a context for session lifecycle hooks.
76    pub fn session(point: HookPoint, session_id: &str) -> Self {
77        Self {
78            point,
79            name: None,
80            input: None,
81            output: None,
82            error: None,
83            session_id: Some(session_id.into()),
84        }
85    }
86
87    /// Create a context for error hooks.
88    pub fn error(error_message: &str) -> Self {
89        Self {
90            point: HookPoint::OnError,
91            name: None,
92            input: None,
93            output: None,
94            error: Some(error_message.into()),
95            session_id: None,
96        }
97    }
98}
99
/// Outcome of running a hook, or of a whole chain of hooks via `HookManager::fire`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookResult {
    /// Allow execution to continue.
    Continue,
    /// Block execution; the string explains why.
    Block(String),
    /// Allow execution to continue, signalling that the hook changed something.
    ///
    /// This variant carries no data, and [`Hook::execute`] only receives `&HookContext`,
    /// so a hook cannot alter the context itself through this API. Callers decide what
    /// "modified" means for them.
    Modified,
}

impl HookResult {
    /// Returns `true` if this result blocks execution.
    pub fn is_blocked(&self) -> bool {
        self.block_reason().is_some()
    }

    /// Returns the reason given by a `Block`, or `None` if execution may continue.
    pub fn block_reason(&self) -> Option<&str> {
        match self {
            HookResult::Block(reason) => Some(reason.as_str()),
            HookResult::Continue | HookResult::Modified => None,
        }
    }
}
110
111/// Trait for hook implementations.
112pub trait Hook: Send + Sync {
113    /// Execute the hook and return a result.
114    fn execute(&self, context: &HookContext) -> HookResult;
115
116    /// Display name for logging.
117    fn name(&self) -> &str;
118}
119
120/// Manages hook registration and firing.
121pub struct HookManager {
122    hooks: HashMap<HookPoint, Vec<Box<dyn Hook>>>,
123}
124
125impl HookManager {
126    /// Create a new empty hook manager.
127    pub fn new() -> Self {
128        Self {
129            hooks: HashMap::new(),
130        }
131    }
132
133    /// Register a hook at a specific point.
134    pub fn register(&mut self, point: HookPoint, hook: Box<dyn Hook>) {
135        self.hooks.entry(point).or_default().push(hook);
136    }
137
138    /// Fire all hooks at a given point.
139    /// Returns `HookResult::Continue` if all hooks allow, or the first `Block` result.
140    pub fn fire(&self, context: &HookContext) -> HookResult {
141        let hooks = match self.hooks.get(&context.point) {
142            Some(hooks) => hooks,
143            None => return HookResult::Continue,
144        };
145
146        let mut result = HookResult::Continue;
147        for hook in hooks {
148            let hook_result = hook.execute(context);
149            match &hook_result {
150                HookResult::Block(_) => return hook_result,
151                HookResult::Modified => result = HookResult::Modified,
152                HookResult::Continue => {}
153            }
154        }
155        result
156    }
157
158    /// Number of hooks registered at a specific point.
159    pub fn count_at(&self, point: HookPoint) -> usize {
160        self.hooks.get(&point).map(|v| v.len()).unwrap_or(0)
161    }
162
163    /// Total number of registered hooks.
164    pub fn total_hooks(&self) -> usize {
165        self.hooks.values().map(|v| v.len()).sum()
166    }
167
168    /// Clear all hooks.
169    pub fn clear(&mut self) {
170        self.hooks.clear();
171    }
172}
173
174impl Default for HookManager {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared, ordered record of which hooks ran.
    type Log = Arc<Mutex<Vec<String>>>;

    /// Always lets execution continue.
    struct AllowHook;
    impl Hook for AllowHook {
        fn execute(&self, _context: &HookContext) -> HookResult {
            HookResult::Continue
        }
        fn name(&self) -> &str {
            "allow"
        }
    }

    /// Always blocks execution with a fixed reason.
    struct BlockHook {
        reason: String,
    }
    impl BlockHook {
        fn new(reason: &str) -> Self {
            Self {
                reason: reason.into(),
            }
        }
    }
    impl Hook for BlockHook {
        fn execute(&self, _context: &HookContext) -> HookResult {
            HookResult::Block(self.reason.clone())
        }
        fn name(&self) -> &str {
            "block"
        }
    }

    /// Always reports `Modified`.
    struct ModifyHook;
    impl Hook for ModifyHook {
        fn execute(&self, _context: &HookContext) -> HookResult {
            HookResult::Modified
        }
        fn name(&self) -> &str {
            "modify"
        }
    }

    /// Appends its name to a shared log every time it runs, then returns a fixed result.
    ///
    /// Makes it observable which hooks actually executed, and in what order.
    struct RecordingHook {
        name: String,
        result: HookResult,
        log: Log,
    }
    impl RecordingHook {
        fn new(name: &str, result: HookResult, log: &Log) -> Self {
            Self {
                name: name.into(),
                result,
                log: Arc::clone(log),
            }
        }
    }
    impl Hook for RecordingHook {
        fn execute(&self, _context: &HookContext) -> HookResult {
            self.log.lock().unwrap().push(self.name.clone());
            self.result.clone()
        }
        fn name(&self) -> &str {
            &self.name
        }
    }

    fn new_log() -> Log {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// Snapshot of the hook names recorded so far, in execution order.
    fn recorded(log: &Log) -> Vec<String> {
        log.lock().unwrap().clone()
    }

    #[test]
    fn test_hook_manager_register_and_fire() {
        let mut mgr = HookManager::new();
        mgr.register(HookPoint::BeforeToolExecution, Box::new(AllowHook));

        let ctx = HookContext::tool(
            HookPoint::BeforeToolExecution,
            "shell_exec",
            serde_json::json!({"cmd": "ls"}),
        );
        let result = mgr.fire(&ctx);
        assert_eq!(result, HookResult::Continue);
    }

    #[test]
    fn test_hook_manager_block() {
        let mut mgr = HookManager::new();
        mgr.register(
            HookPoint::BeforeToolExecution,
            Box::new(BlockHook::new("dangerous")),
        );

        let ctx = HookContext::tool(HookPoint::BeforeToolExecution, "rm", serde_json::json!({}));
        let result = mgr.fire(&ctx);
        assert_eq!(result, HookResult::Block("dangerous".into()));
    }

    #[test]
    fn test_hook_ordering_block_stops_chain() {
        let log = new_log();
        let mut mgr = HookManager::new();
        mgr.register(
            HookPoint::BeforeToolExecution,
            Box::new(RecordingHook::new("before", HookResult::Continue, &log)),
        );
        mgr.register(
            HookPoint::BeforeToolExecution,
            Box::new(RecordingHook::new(
                "blocker",
                HookResult::Block("blocked".into()),
                &log,
            )),
        );
        mgr.register(
            HookPoint::BeforeToolExecution,
            Box::new(RecordingHook::new("after", HookResult::Continue, &log)),
        );

        let ctx = HookContext::tool(
            HookPoint::BeforeToolExecution,
            "test",
            serde_json::json!({}),
        );
        assert_eq!(mgr.fire(&ctx), HookResult::Block("blocked".into()));
        // The hook registered after the blocker must never run.
        assert_eq!(recorded(&log), vec!["before", "blocker"]);
    }

    #[test]
    fn test_block_after_modified_still_blocks() {
        let mut mgr = HookManager::new();
        mgr.register(HookPoint::AfterToolExecution, Box::new(ModifyHook));
        mgr.register(
            HookPoint::AfterToolExecution,
            Box::new(BlockHook::new("late veto")),
        );

        let ctx = HookContext::tool(
            HookPoint::AfterToolExecution,
            "test",
            serde_json::json!({}),
        );
        assert_eq!(mgr.fire(&ctx), HookResult::Block("late veto".into()));
    }

    #[test]
    fn test_hook_modified_result() {
        let mut mgr = HookManager::new();
        mgr.register(HookPoint::AfterLlmResponse, Box::new(ModifyHook));
        mgr.register(HookPoint::AfterLlmResponse, Box::new(AllowHook));

        let ctx = HookContext::llm(
            HookPoint::AfterLlmResponse,
            "openai",
            serde_json::json!({"text": "hello"}),
        );
        let result = mgr.fire(&ctx);
        assert_eq!(result, HookResult::Modified);
    }

    #[test]
    fn test_hook_fire_no_hooks() {
        let mgr = HookManager::new();
        let ctx = HookContext::session(HookPoint::OnSessionStart, "session-1");
        let result = mgr.fire(&ctx);
        assert_eq!(result, HookResult::Continue);
    }

    #[test]
    fn test_fire_ignores_hooks_at_other_points() {
        let log = new_log();
        let mut mgr = HookManager::new();
        mgr.register(
            HookPoint::OnError,
            Box::new(RecordingHook::new(
                "on_error",
                HookResult::Block("nope".into()),
                &log,
            )),
        );

        let ctx = HookContext::session(HookPoint::OnSessionStart, "session-1");
        assert_eq!(mgr.fire(&ctx), HookResult::Continue);
        assert!(recorded(&log).is_empty());
    }

    #[test]
    fn test_hook_manager_count() {
        let mut mgr = HookManager::new();
        mgr.register(HookPoint::BeforeToolExecution, Box::new(AllowHook));
        mgr.register(HookPoint::BeforeToolExecution, Box::new(AllowHook));
        mgr.register(HookPoint::OnError, Box::new(AllowHook));

        assert_eq!(mgr.count_at(HookPoint::BeforeToolExecution), 2);
        assert_eq!(mgr.count_at(HookPoint::OnError), 1);
        assert_eq!(mgr.count_at(HookPoint::OnSessionEnd), 0);
        assert_eq!(mgr.total_hooks(), 3);
    }

    #[test]
    fn test_hook_manager_clear() {
        let mut mgr = HookManager::new();
        mgr.register(HookPoint::BeforeToolExecution, Box::new(AllowHook));
        mgr.register(HookPoint::OnError, Box::new(AllowHook));
        assert_eq!(mgr.total_hooks(), 2);

        mgr.clear();
        assert_eq!(mgr.total_hooks(), 0);
    }

    #[test]
    fn test_hook_context_tool() {
        let ctx = HookContext::tool(
            HookPoint::BeforeToolExecution,
            "shell_exec",
            serde_json::json!({"cmd": "ls"}),
        );
        assert_eq!(ctx.point, HookPoint::BeforeToolExecution);
        assert_eq!(ctx.name.as_deref(), Some("shell_exec"));
        assert_eq!(ctx.input, Some(serde_json::json!({"cmd": "ls"})));
        assert!(ctx.output.is_none());
        assert!(ctx.error.is_none());
        assert!(ctx.session_id.is_none());
    }

    #[test]
    fn test_hook_context_llm_request_and_response() {
        let req = HookContext::llm(
            HookPoint::BeforeLlmRequest,
            "openai",
            serde_json::json!({"prompt": "hi"}),
        );
        assert_eq!(req.name.as_deref(), Some("openai"));
        assert_eq!(req.input, Some(serde_json::json!({"prompt": "hi"})));
        assert!(req.output.is_none());

        let resp = HookContext::llm(
            HookPoint::AfterLlmResponse,
            "openai",
            serde_json::json!({"text": "hello"}),
        );
        assert!(resp.input.is_none());
        assert_eq!(resp.output, Some(serde_json::json!({"text": "hello"})));
    }

    #[test]
    fn test_hook_context_session() {
        let ctx = HookContext::session(HookPoint::OnSessionEnd, "session-42");
        assert_eq!(ctx.point, HookPoint::OnSessionEnd);
        assert_eq!(ctx.session_id.as_deref(), Some("session-42"));
        assert!(ctx.name.is_none());
        assert!(ctx.input.is_none());
        assert!(ctx.output.is_none());
        assert!(ctx.error.is_none());
    }

    #[test]
    fn test_hook_context_error() {
        let ctx = HookContext::error("something went wrong");
        assert_eq!(ctx.point, HookPoint::OnError);
        assert_eq!(ctx.error.as_deref(), Some("something went wrong"));
    }

    #[test]
    fn test_multiple_hooks_fire_in_order() {
        let log = new_log();
        let mut mgr = HookManager::new();
        for name in ["first", "second", "third"] {
            mgr.register(
                HookPoint::BeforeToolExecution,
                Box::new(RecordingHook::new(name, HookResult::Continue, &log)),
            );
        }

        // All continue, so result should be Continue and every hook runs in registration order.
        let ctx = HookContext::tool(
            HookPoint::BeforeToolExecution,
            "test",
            serde_json::json!({}),
        );
        assert_eq!(mgr.fire(&ctx), HookResult::Continue);
        assert_eq!(recorded(&log), vec!["first", "second", "third"]);
        assert_eq!(mgr.count_at(HookPoint::BeforeToolExecution), 3);
    }

    #[test]
    fn test_hook_point_serialization() {
        let point = HookPoint::BeforeToolExecution;
        let json = serde_json::to_string(&point).unwrap();
        let restored: HookPoint = serde_json::from_str(&json).unwrap();
        assert_eq!(restored, HookPoint::BeforeToolExecution);
    }

    #[test]
    fn test_all_seven_hook_points() {
        let points = vec![
            HookPoint::BeforeToolExecution,
            HookPoint::AfterToolExecution,
            HookPoint::BeforeLlmRequest,
            HookPoint::AfterLlmResponse,
            HookPoint::OnSessionStart,
            HookPoint::OnSessionEnd,
            HookPoint::OnError,
        ];
        assert_eq!(points.len(), 7);

        // Each can be used as a hash key
        let mut mgr = HookManager::new();
        for point in &points {
            mgr.register(*point, Box::new(AllowHook));
        }
        assert_eq!(mgr.total_hooks(), 7);
    }
}