1use async_trait::async_trait;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::types::{HookEvent, Layer2Result, SessionId};
11
12pub type HookCallback = Arc<dyn Fn(&HookContext) -> Layer2Result<()> + Send + Sync>;
14
15#[derive(Debug, Clone)]
17pub struct HookContext {
18 pub session_id: SessionId,
19 pub event: HookEvent,
20 pub timestamp: chrono::DateTime<chrono::Utc>,
21 pub data: serde_json::Value,
22 pub metadata: HashMap<String, String>,
23}
24
25impl HookContext {
26 pub fn new(session_id: SessionId, event: HookEvent) -> Self {
27 Self {
28 session_id,
29 event,
30 timestamp: chrono::Utc::now(),
31 data: serde_json::Value::Null,
32 metadata: HashMap::new(),
33 }
34 }
35
36 pub fn with_data(mut self, data: serde_json::Value) -> Self {
37 self.data = data;
38 self
39 }
40
41 pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
42 self.metadata.insert(key.to_string(), value.to_string());
43 self
44 }
45}
46
47#[async_trait]
49pub trait HookSystemTrait: Send + Sync {
50 fn on_before(&self, event: HookEvent, callback: HookCallback);
52
53 fn on_after(&self, event: HookEvent, callback: HookCallback);
55
56 async fn trigger(&self, context: &HookContext) -> Layer2Result<()>;
58
59 fn remove(&self, event: HookEvent, is_before: bool);
61
62 fn clear(&self);
64
65 fn count(&self, event: HookEvent) -> (usize, usize);
67}
68
69type HookRegistry = HashMap<HookEvent, Vec<HookCallback>>;
71
72pub struct HookSystem {
74 before_hooks: RwLock<HookRegistry>,
75 after_hooks: RwLock<HookRegistry>,
76}
77
78impl HookSystem {
79 pub fn new() -> Self {
80 Self {
81 before_hooks: RwLock::new(HashMap::new()),
82 after_hooks: RwLock::new(HashMap::new()),
83 }
84 }
85}
86
87impl Default for HookSystem {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93#[async_trait]
94impl HookSystemTrait for HookSystem {
95 fn on_before(&self, event: HookEvent, callback: HookCallback) {
96 let mut hooks = self.before_hooks.write();
97 hooks.entry(event).or_default().push(callback);
98 }
99
100 fn on_after(&self, event: HookEvent, callback: HookCallback) {
101 let mut hooks = self.after_hooks.write();
102 hooks.entry(event).or_default().push(callback);
103 }
104
105 async fn trigger(&self, context: &HookContext) -> Layer2Result<()> {
106 {
108 let hooks = self.before_hooks.read();
109 if let Some(callbacks) = hooks.get(&context.event) {
110 for callback in callbacks {
111 callback(context)?;
112 }
113 }
114 }
115
116 {
118 let hooks = self.after_hooks.read();
119 if let Some(callbacks) = hooks.get(&context.event) {
120 for callback in callbacks {
121 callback(context)?;
122 }
123 }
124 }
125
126 Ok(())
127 }
128
129 fn remove(&self, event: HookEvent, is_before: bool) {
130 let hooks = if is_before {
131 &self.before_hooks
132 } else {
133 &self.after_hooks
134 };
135
136 let mut hooks = hooks.write();
137 hooks.remove(&event);
138 }
139
140 fn clear(&self) {
141 self.before_hooks.write().clear();
142 self.after_hooks.write().clear();
143 }
144
145 fn count(&self, event: HookEvent) -> (usize, usize) {
146 let before = self
147 .before_hooks
148 .read()
149 .get(&event)
150 .map(|v| v.len())
151 .unwrap_or(0);
152 let after = self
153 .after_hooks
154 .read()
155 .get(&event)
156 .map(|v| v.len())
157 .unwrap_or(0);
158 (before, after)
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a callback that always succeeds.
    fn noop_callback() -> HookCallback {
        let callback: HookCallback = Arc::new(|_| Ok(()));
        callback
    }

    #[test]
    fn test_hook_system_creation() {
        let hooks = HookSystem::new();
        let (before, after) = hooks.count(HookEvent::BeforeAgentStart);
        assert_eq!(before, 0);
        assert_eq!(after, 0);
    }

    #[test]
    fn test_hook_registration() {
        let hooks = HookSystem::new();
        hooks.on_before(HookEvent::BeforeAgentStart, noop_callback());

        let (before, _) = hooks.count(HookEvent::BeforeAgentStart);
        assert_eq!(before, 1);
    }

    #[test]
    fn test_after_registration_is_counted_separately() {
        let hooks = HookSystem::new();
        hooks.on_after(HookEvent::BeforeAgentStart, noop_callback());
        hooks.on_after(HookEvent::BeforeAgentStart, noop_callback());

        assert_eq!(hooks.count(HookEvent::BeforeAgentStart), (0, 2));
    }

    #[test]
    fn test_remove_only_affects_selected_phase() {
        let hooks = HookSystem::new();
        hooks.on_before(HookEvent::BeforeAgentStart, noop_callback());
        hooks.on_after(HookEvent::BeforeAgentStart, noop_callback());

        hooks.remove(HookEvent::BeforeAgentStart, true);
        assert_eq!(hooks.count(HookEvent::BeforeAgentStart), (0, 1));

        hooks.remove(HookEvent::BeforeAgentStart, false);
        assert_eq!(hooks.count(HookEvent::BeforeAgentStart), (0, 0));
    }

    #[test]
    fn test_clear_empties_both_phases() {
        let hooks = HookSystem::new();
        hooks.on_before(HookEvent::BeforeAgentStart, noop_callback());
        hooks.on_after(HookEvent::BeforeAgentStart, noop_callback());

        hooks.clear();
        assert_eq!(hooks.count(HookEvent::BeforeAgentStart), (0, 0));
    }
}