stakpak_shared/hooks/
mod.rs

1use serde::Serialize;
2use std::{collections::HashMap, fmt::Debug, fmt::Display};
3use tokio::task::JoinSet;
4use uuid::Uuid;
5
/// Hook errors
///
/// Error type returned by `Hook::execute` and by background tasks spawned through
/// `HookContext::spawn_task`. Every variant wraps a free-form message; `thiserror`
/// derives `Display` from the `#[error(...)]` attributes, prefixing the message with
/// a category label.
#[derive(Debug, thiserror::Error)]
pub enum HookError {
    /// Database-related failure; displayed as `Database error: <message>`.
    #[error("Database error: {0}")]
    DatabaseError(String),
    /// Serialization failure; displayed as `Serialization error: <message>`.
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// General hook failure; displayed as `Hook execution failed: <message>`.
    /// (The `define_hook!` usage example wraps file I/O errors in this variant.)
    #[error("Hook execution failed: {0}")]
    ExecutionError(String),
}
16
/// Context handed (mutably) to every hook's `execute` call.
///
/// Serializing a context emits `session_id`, `new_checkpoint_id`, `request_id` and
/// `state`, in that order; the background task set is skipped.
#[derive(Debug, Serialize)]
pub struct HookContext<State: Clone + Serialize> {
    /// Session id, if known; can be filled in later with `set_session_id`.
    pub session_id: Option<Uuid>,
    /// Checkpoint id recorded with `set_new_checkpoint_id`; `None` until set.
    /// NOTE(review): presumably the checkpoint created during this request — confirm.
    pub new_checkpoint_id: Option<Uuid>,
    /// Random (v4) id generated by `HookContext::new`.
    pub request_id: Uuid,
    /// Caller-defined state shared with the hooks.
    pub state: State,

    /// Tasks started through `spawn_task`. Excluded from serialization and not carried
    /// over by `Clone`. Per Tokio's `JoinSet` docs, dropping the set aborts every task
    /// still in it, so unfinished tasks are cancelled when the context is dropped.
    #[serde(skip)]
    background_tasks: JoinSet<Result<(), HookError>>,
}
27
28impl<State: Clone + Serialize> Clone for HookContext<State> {
29    fn clone(&self) -> Self {
30        Self {
31            session_id: self.session_id,
32            new_checkpoint_id: self.new_checkpoint_id,
33            request_id: self.request_id,
34            state: self.state.clone(),
35            background_tasks: JoinSet::new(),
36        }
37    }
38}
39
40impl<State: Clone + Serialize> HookContext<State> {
41    pub fn new(session_id: Option<Uuid>, state: State) -> Self {
42        Self {
43            session_id,
44            new_checkpoint_id: None,
45            request_id: Uuid::new_v4(),
46            state,
47            background_tasks: JoinSet::new(),
48        }
49    }
50
51    pub fn set_session_id(&mut self, session_id: Uuid) {
52        self.session_id = Some(session_id);
53    }
54
55    pub fn set_new_checkpoint_id(&mut self, new_checkpoint_id: Uuid) {
56        self.new_checkpoint_id = Some(new_checkpoint_id);
57    }
58}
59
60impl<State: Clone + Serialize> HookContext<State> {
61    pub fn spawn_task<F>(&mut self, task: F)
62    where
63        F: Future<Output = Result<(), HookError>> + Send + 'static,
64    {
65        self.background_tasks.spawn(task);
66    }
67}
68
/// Points in the lifecycle at which registered hooks can run.
///
/// Used as the key under which `HookRegistry` stores hooks, hence `Eq + Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LifecycleEvent {
    // Request lifecycle
    BeforeRequest,
    AfterRequest,

    // LLM interaction
    BeforeInference,
    AfterInference,

    // Tool lifecycle
    ToolCallRequested,
    BeforeToolExecution,
    AfterToolExecution,
    ToolCallAborted,

    // Errors
    Error,
}

impl Display for LifecycleEvent {
    /// Writes the bare variant name (e.g. `BeforeRequest`), identical to the `Debug` output.
    /// Width/alignment flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            LifecycleEvent::BeforeRequest => "BeforeRequest",
            LifecycleEvent::AfterRequest => "AfterRequest",
            LifecycleEvent::BeforeInference => "BeforeInference",
            LifecycleEvent::AfterInference => "AfterInference",
            LifecycleEvent::ToolCallRequested => "ToolCallRequested",
            LifecycleEvent::BeforeToolExecution => "BeforeToolExecution",
            LifecycleEvent::AfterToolExecution => "AfterToolExecution",
            LifecycleEvent::ToolCallAborted => "ToolCallAborted",
            LifecycleEvent::Error => "Error",
        };
        f.write_str(label)
    }
}
94
/// Control flow decisions from hooks
///
/// Returned by `Hook::execute`; `HookRegistry::execute_hooks` keeps running hooks only
/// while they return `Continue`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Proceed to the next hook (the default action).
    #[default]
    Continue,
    /// Skip remaining hooks for this event
    Skip,
    /// Abort the current operation
    Abort { reason: String },
}

impl HookAction {
    /// Convert hook action to Err on Abort
    ///
    /// `Continue` and `Skip` both map to `Ok(())`; `Abort { reason }` maps to
    /// `Err(reason)`.
    pub fn ok(self) -> Result<(), String> {
        // Variants are listed explicitly (no `_`) so a new variant forces a decision here.
        match self {
            HookAction::Continue | HookAction::Skip => Ok(()),
            HookAction::Abort { reason } => Err(reason),
        }
    }
}
115
/// A unit of behaviour that a `HookRegistry` runs for one or more `LifecycleEvent`s.
///
/// Implementations must be `Send + Sync`; the registry stores them as
/// `Box<dyn Hook<State>>`. The trait is declared through `async_trait`, so
/// implementations write `execute` as an `async fn` under the same attribute
/// (see `define_hook!`).
#[async_trait::async_trait]
pub trait Hook<State: Clone + Serialize>: Send + Sync {
    /// Name of this hook; `HookRegistry`'s `Debug` output lists hooks by this name.
    fn name(&self) -> &str;

    /// Execution priority (lower = earlier execution)
    ///
    /// Defaults to `50`. `HookRegistry::register` re-sorts an event's hooks by this value
    /// with a stable sort, so hooks of equal priority keep their registration order.
    fn priority(&self) -> u8 {
        50
    }

    /// Runs this hook for `event`, with mutable access to `ctx`.
    ///
    /// `HookRegistry::execute_hooks` stops running later hooks for the event when this
    /// returns `Ok(HookAction::Skip)`, `Ok(HookAction::Abort { .. })`, or `Err(_)` (the
    /// error is propagated to its caller).
    async fn execute(
        &self,
        ctx: &mut HookContext<State>,
        event: &LifecycleEvent,
    ) -> Result<HookAction, HookError>;
}
131
132#[derive(Default)]
133pub struct HookRegistry<State> {
134    hooks: HashMap<LifecycleEvent, Vec<Box<dyn Hook<State>>>>,
135}
136impl<State: Clone + Serialize> std::fmt::Debug for HookRegistry<State> {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        let mut map = f.debug_map();
139        for (event, hooks) in &self.hooks {
140            let hook_names: Vec<&str> = hooks.iter().map(|hook| hook.name()).collect();
141            map.entry(event, &hook_names);
142        }
143        map.finish()
144    }
145}
146
147impl<State: Clone + Serialize> HookRegistry<State> {
148    pub fn register(&mut self, event: LifecycleEvent, hook: Box<dyn Hook<State>>) {
149        let hooks = self.hooks.entry(event).or_default();
150        hooks.push(hook);
151
152        // Sort by priority (lower = earlier execution)
153        hooks.sort_by_key(|h| h.priority());
154    }
155
156    pub async fn execute_hooks(
157        &self,
158        ctx: &mut HookContext<State>,
159        event: &LifecycleEvent,
160    ) -> Result<HookAction, HookError> {
161        let Some(hooks) = self.hooks.get(event) else {
162            return Ok(HookAction::Continue);
163        };
164
165        for hook in hooks {
166            match hook.execute(ctx, event).await? {
167                HookAction::Continue => continue,
168                HookAction::Skip => return Ok(HookAction::Skip),
169                HookAction::Abort { reason } => {
170                    return Ok(HookAction::Abort { reason });
171                }
172            }
173        }
174
175        Ok(HookAction::Continue)
176    }
177}
178
/// Implements the `Hook` trait for a type from a hook name and an async closure-like body.
///
/// Syntax: `define_hook!(Type, "name", [priority = <u8 expr>,] async |&self, ctx:
/// &mut HookContext<State>, event: &LifecycleEvent| { ... })`.
///
/// - The optional `priority = <expr>,` argument overrides `Hook::priority`; without it
///   the trait's default priority is used.
/// - The expansion refers to `Hook`, `HookContext`, `HookAction`, `HookError` and
///   `LifecycleEvent` by their bare names. It also applies `#[async_trait::async_trait]`.
///   The calling module must therefore have those names in scope, and the calling crate
///   must depend on `async_trait`.
///
/// # Examples
///
/// ```rust
/// use stakpak_shared::define_hook;
/// use stakpak_shared::hooks::{HookAction, HookContext, HookError, LifecycleEvent, Hook};
/// use chrono::{DateTime, Local};
/// use tokio::fs::OpenOptions;
/// use tokio::io::AsyncWriteExt;
/// use serde::Serialize;
/// use std::fmt::Debug;
///
/// #[derive(Debug, Clone, Serialize)]
/// pub struct State;
///
/// pub struct LoggerHook;
///
/// impl LoggerHook {
///     pub fn new() -> Self {
///         Self
///     }
/// }
///
/// define_hook!(
///     LoggerHook,
///     "logger",
///     async |&self, ctx: &mut HookContext<State>, event: &LifecycleEvent| {
///         let timestamp: DateTime<Local> = Local::now();
///         let log_message = format!(
///             "[{}] LoggerHook event: {:?}, {}\n",
///             timestamp.format("%Y-%m-%d %H:%M:%S%.3f"),
///             event,
///             serde_json::to_string(&ctx).unwrap_or_default(),
///         );
///
///         let mut file = OpenOptions::new()
///             .create(true)
///             .append(true)
///             .open("hook_events.log")
///             .await
///             .map_err(|e| HookError::ExecutionError(e.to_string()))?;
///
///         file.write_all(log_message.as_bytes())
///             .await
///             .map_err(|e| HookError::ExecutionError(e.to_string()))?;
///
///         Ok(HookAction::Continue)
///     }
/// );
/// ```
///
/// With an explicit priority:
///
/// ```rust
/// use stakpak_shared::define_hook;
/// use stakpak_shared::hooks::{Hook, HookAction, HookContext, HookError, LifecycleEvent};
/// use serde::Serialize;
///
/// #[derive(Debug, Clone, Serialize)]
/// pub struct State;
///
/// pub struct EarlyHook;
///
/// define_hook!(
///     EarlyHook,
///     "early",
///     priority = 10,
///     async |&self, _ctx: &mut HookContext<State>, _event: &LifecycleEvent| {
///         Ok(HookAction::Continue)
///     }
/// );
///
/// assert_eq!(EarlyHook.name(), "early");
/// assert_eq!(EarlyHook.priority(), 10);
/// ```
#[macro_export]
macro_rules! define_hook {
    (
        $name:ident,
        $hook_name:expr,
        $(priority = $priority:expr,)?
        async |&$self:ident, $ctx:ident: &mut HookContext<$state:ty>, $event:ident: &LifecycleEvent| $body:block
    ) => {
        #[async_trait::async_trait]
        impl Hook<$state> for $name {
            fn name(&self) -> &str {
                $hook_name
            }

            // Emitted only when `priority = ...` was given; otherwise the trait default applies.
            $(
                fn priority(&self) -> u8 {
                    $priority
                }
            )?

            async fn execute(
                &$self,
                $ctx: &mut HookContext<$state>,
                $event: &LifecycleEvent,
            ) -> Result<HookAction, HookError> {
                $body
            }
        }
    };
}