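//! Lifecycle hook system for `stakpak_shared` (`hooks/mod.rs`).
//!
//! Provides `HookContext` (the context handed to every hook), the `Hook`
//! trait, `HookRegistry` (priority-ordered dispatch per `LifecycleEvent`),
//! and the `define_hook!` macro.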
1use serde::Serialize;
2use std::{collections::HashMap, fmt::Debug, fmt::Display};
3use tokio::task::JoinSet;
4use uuid::Uuid;
5
/// Errors a hook can return from `Hook::execute` (or from a task it spawns).
///
/// Every variant carries a preformatted message; `Display` (derived through
/// `thiserror`) prefixes it with the category given in its `#[error]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum HookError {
    /// A database operation failed.
    #[error("Database error: {0}")]
    DatabaseError(String),
    /// Data could not be serialized or deserialized.
    #[error("Serialization error: {0}")]
    SerializationError(String),
    /// Generic failure while a hook runs (e.g. I/O errors mapped by the hook).
    #[error("Hook execution failed: {0}")]
    ExecutionError(String),
}
16
/// Mutable context passed to every `Hook::execute` call.
///
/// Serializes as its id fields plus `state`; the background task set is
/// excluded via `#[serde(skip)]`.
#[derive(Debug, Serialize)]
pub struct HookContext<State: Clone + Serialize> {
    /// Optional session id; can be filled in later with `set_session_id`.
    pub session_id: Option<Uuid>,
    /// Optional checkpoint id; starts as `None` and is set with `set_new_checkpoint_id`.
    pub new_checkpoint_id: Option<Uuid>,
    /// Random (v4) id generated when the context is created with `new`.
    pub request_id: Uuid,
    /// Caller-defined state visible to (and mutable by) hooks.
    pub state: State,

    /// Tasks started with `spawn_task`. Not serialized, and not carried over
    /// by `Clone` (a clone starts with an empty set).
    /// NOTE(review): tokio's `JoinSet` aborts still-running tasks when it is
    /// dropped — confirm callers await them before the context goes away.
    #[serde(skip)]
    background_tasks: JoinSet<Result<(), HookError>>,
}
27
28impl<State: Clone + Serialize> Clone for HookContext<State> {
29    fn clone(&self) -> Self {
30        Self {
31            session_id: self.session_id,
32            new_checkpoint_id: self.new_checkpoint_id,
33            request_id: self.request_id,
34            state: self.state.clone(),
35            background_tasks: JoinSet::new(),
36        }
37    }
38}
39
40impl<State: Clone + Serialize> HookContext<State> {
41    pub fn new(session_id: Option<Uuid>, state: State) -> Self {
42        Self {
43            session_id,
44            new_checkpoint_id: None,
45            request_id: Uuid::new_v4(),
46            state,
47            background_tasks: JoinSet::new(),
48        }
49    }
50
51    pub fn set_session_id(&mut self, session_id: Uuid) {
52        self.session_id = Some(session_id);
53    }
54
55    pub fn set_new_checkpoint_id(&mut self, new_checkpoint_id: Uuid) {
56        self.new_checkpoint_id = Some(new_checkpoint_id);
57    }
58}
59
60impl<State: Clone + Serialize> HookContext<State> {
61    pub fn spawn_task<F>(&mut self, task: F)
62    where
63        F: Future<Output = Result<(), HookError>> + Send + 'static,
64    {
65        self.background_tasks.spawn(task);
66    }
67}
68
/// Points in a request's lifecycle at which registered hooks are run.
///
/// Used as the key of `HookRegistry`'s hook map. The enum is fieldless, so it
/// is `Copy` and can be passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LifecycleEvent {
    // Request lifecycle
    BeforeRequest,
    AfterRequest,

    // LLM interaction
    BeforeInference,
    AfterInference,

    // Tool lifecycle
    ToolCallRequested,
    BeforeToolExecution,
    AfterToolExecution,
    ToolCallAborted,

    // Errors
    Error,
}
88
89impl Display for LifecycleEvent {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        write!(f, "{:?}", self)
92    }
93}
94
/// Control flow decisions from hooks
///
/// Returned by `Hook::execute` and by `HookRegistry::execute_hooks`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Proceed to the next hook registered for the event.
    #[default]
    Continue,
    /// Skip remaining hooks for this event
    Skip,
    /// Abort the current operation
    Abort {
        /// Name of the aborting hook; when `None`, `HookRegistry::execute_hooks`
        /// fills in the hook's `name()`.
        name: Option<String>,
        /// Human-readable reason for the abort.
        reason: String,
    },
}
108
109impl HookAction {
110    /// Convert hook action to Err on Abort
111    pub fn ok(self) -> Result<(), String> {
112        match self {
113            HookAction::Abort { name, reason } => Err(format!(
114                "[{}:hook_abort] {}",
115                name.unwrap_or_default(),
116                reason
117            )),
118            _ => Ok(()),
119        }
120    }
121}
122
/// A pluggable unit of behavior that `HookRegistry` runs at a `LifecycleEvent`.
///
/// Implementors must be `Send + Sync` (supertraits); the registry stores them as
/// `Box<dyn Hook<State>>`. Implementations can be written by hand (with
/// `#[async_trait::async_trait]`) or generated with the `define_hook!` macro.
#[async_trait::async_trait]
pub trait Hook<State: Clone + Serialize>: Send + Sync {
    /// Identifier for this hook.
    ///
    /// Shown in `HookRegistry`'s `Debug` output, and used by
    /// `HookRegistry::execute_hooks` as the `name` of an `Abort` that the hook
    /// returned without one.
    fn name(&self) -> &str;

    /// Execution priority (lower = earlier execution)
    ///
    /// Defaults to `50`. `HookRegistry::register` sorts an event's hooks by this
    /// value only at registration time, so it should not change afterwards.
    fn priority(&self) -> u8 {
        50
    }

    /// Runs the hook for `event`.
    ///
    /// `ctx` is borrowed mutably and the same context is passed to each hook in
    /// turn, so changes made here are visible to later hooks for the event.
    ///
    /// # Errors
    ///
    /// An `Err` is propagated by `HookRegistry::execute_hooks` immediately, so
    /// the remaining hooks for that event do not run.
    async fn execute(
        &self,
        ctx: &mut HookContext<State>,
        event: &LifecycleEvent,
    ) -> Result<HookAction, HookError>;
}
138
139#[derive(Default)]
140pub struct HookRegistry<State> {
141    hooks: HashMap<LifecycleEvent, Vec<Box<dyn Hook<State>>>>,
142}
143impl<State: Clone + Serialize> std::fmt::Debug for HookRegistry<State> {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        let mut map = f.debug_map();
146        for (event, hooks) in &self.hooks {
147            let hook_names: Vec<&str> = hooks.iter().map(|hook| hook.name()).collect();
148            map.entry(event, &hook_names);
149        }
150        map.finish()
151    }
152}
153
154impl<State: Clone + Serialize> HookRegistry<State> {
155    pub fn register(&mut self, event: LifecycleEvent, hook: Box<dyn Hook<State>>) {
156        let hooks = self.hooks.entry(event).or_default();
157        hooks.push(hook);
158
159        // Sort by priority (lower = earlier execution)
160        hooks.sort_by_key(|h| h.priority());
161    }
162
163    pub async fn execute_hooks(
164        &self,
165        ctx: &mut HookContext<State>,
166        event: &LifecycleEvent,
167    ) -> Result<HookAction, HookError> {
168        let Some(hooks) = self.hooks.get(event) else {
169            return Ok(HookAction::Continue);
170        };
171
172        for hook in hooks {
173            match hook.execute(ctx, event).await? {
174                HookAction::Continue => continue,
175                HookAction::Skip => return Ok(HookAction::Skip),
176                HookAction::Abort { name, reason } => {
177                    return Ok(HookAction::Abort {
178                        name: Some(name.unwrap_or(hook.name().to_string())),
179                        reason,
180                    });
181                }
182            }
183        }
184
185        Ok(HookAction::Continue)
186    }
187}
188
/**
Implements the `Hook` trait for an existing type from a closure-like body.

The string literal becomes `Hook::name` and the body becomes `Hook::execute`.
The generated impl refers to the hook types through `$crate::hooks`, so the
caller only needs `HookAction` / `HookError` in scope when the body names them.
The invoking crate must depend on `async_trait`, because the generated impl is
annotated with `#[async_trait::async_trait]`.

An optional `priority = <u8 expression>` argument (placed after the name)
overrides `Hook::priority`, which otherwise defaults to `50`. Lower values run
earlier:

```ignore
define_hook!(
    AuditHook,
    "audit",
    priority = 10,
    async |&self, ctx: &mut HookContext<State>, event: &LifecycleEvent| {
        Ok(HookAction::Continue)
    }
);
```

# Examples

```rust
use stakpak_shared::define_hook;
use stakpak_shared::hooks::{HookAction, HookError};
use chrono::{DateTime, Local};
use tokio::fs::OpenOptions;
use tokio::io::AsyncWriteExt;
use serde::Serialize;

#[derive(Debug, Clone, Serialize)]
pub struct State;

pub struct LoggerHook;

impl LoggerHook {
    pub fn new() -> Self {
        Self
    }
}

define_hook!(
    LoggerHook,
    "logger",
    async |&self, ctx: &mut HookContext<State>, event: &LifecycleEvent| {
        let timestamp: DateTime<Local> = Local::now();
        let log_message = format!(
            "[{}] LoggerHook event: {:?}, {}\n",
            timestamp.format("%Y-%m-%d %H:%M:%S%.3f"),
            event,
            serde_json::to_string(&ctx).unwrap_or_default(),
        );

        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open("hook_events.log")
            .await
            .map_err(|e| HookError::ExecutionError(e.to_string()))?;

        file.write_all(log_message.as_bytes())
            .await
            .map_err(|e| HookError::ExecutionError(e.to_string()))?;

        Ok(HookAction::Continue)
    }
);
```
*/
#[macro_export]
macro_rules! define_hook {
    // Default priority: rely on the trait's provided `priority()`.
    (
        $name:ident,
        $hook_name:expr,
        async |&$self:ident, $ctx:ident: &mut HookContext<$state:ty>, $event:ident: &LifecycleEvent| $body:block
    ) => {
        $crate::define_hook!(@impl $name, $hook_name, $self, $ctx, $state, $event, $body, {});
    };
    // Explicit priority: emit a `priority()` override.
    (
        $name:ident,
        $hook_name:expr,
        priority = $priority:expr,
        async |&$self:ident, $ctx:ident: &mut HookContext<$state:ty>, $event:ident: &LifecycleEvent| $body:block
    ) => {
        $crate::define_hook!(@impl $name, $hook_name, $self, $ctx, $state, $event, $body, {
            fn priority(&self) -> u8 {
                $priority
            }
        });
    };
    // Internal arm shared by both public forms. Paths are absolute so the
    // expansion does not depend on what the caller has imported (or on a
    // caller-defined `Result` alias).
    (@impl $name:ident, $hook_name:expr, $self:ident, $ctx:ident, $state:ty, $event:ident, $body:block, { $($extra:tt)* }) => {
        #[::async_trait::async_trait]
        impl $crate::hooks::Hook<$state> for $name {
            fn name(&self) -> &str {
                $hook_name
            }

            $($extra)*

            async fn execute(
                &$self,
                $ctx: &mut $crate::hooks::HookContext<$state>,
                $event: &$crate::hooks::LifecycleEvent,
            ) -> ::std::result::Result<$crate::hooks::HookAction, $crate::hooks::HookError> {
                $body
            }
        }
    };
}