1use anyhow::Result;
47use std::{future::Future, pin::Pin};
48
49use dyn_clone::DynClone;
50use swiftide_core::chat_completion::{
51 ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall, ToolOutput,
52 errors::ToolError,
53};
54
55use crate::{Agent, errors::AgentError, state::StopReason};
56
57pub trait BeforeAllFn:
58 for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
59 + Send
60 + Sync
61 + DynClone
62{
63}
64
65dyn_clone::clone_trait_object!(BeforeAllFn);
66
67pub trait AfterEachFn:
68 for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
69 + Send
70 + Sync
71 + DynClone
72{
73}
74
75dyn_clone::clone_trait_object!(AfterEachFn);
76
77pub trait BeforeCompletionFn:
78 for<'a> Fn(
79 &'a Agent,
80 &mut ChatCompletionRequest,
81 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
82 + Send
83 + Sync
84 + DynClone
85{
86}
87
88dyn_clone::clone_trait_object!(BeforeCompletionFn);
89
90pub trait AfterCompletionFn:
91 for<'a> Fn(
92 &'a Agent,
93 &mut ChatCompletionResponse,
94 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
95 + Send
96 + Sync
97 + DynClone
98{
99}
100
101dyn_clone::clone_trait_object!(AfterCompletionFn);
102
103pub trait AfterToolFn:
105 for<'tool> Fn(
106 &'tool Agent,
107 &ToolCall,
108 &'tool mut Result<ToolOutput, ToolError>,
109 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
110 + Send
111 + Sync
112 + DynClone
113{
114}
115
116dyn_clone::clone_trait_object!(AfterToolFn);
117
118pub trait BeforeToolFn:
120 for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
121 + Send
122 + Sync
123 + DynClone
124{
125}
126
127dyn_clone::clone_trait_object!(BeforeToolFn);
128
129pub trait MessageHookFn:
131 for<'a> Fn(&'a Agent, &mut ChatMessage) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
132 + Send
133 + Sync
134 + DynClone
135{
136}
137
138dyn_clone::clone_trait_object!(MessageHookFn);
139
140pub trait OnStartFn:
142 for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
143 + Send
144 + Sync
145 + DynClone
146{
147}
148
149dyn_clone::clone_trait_object!(OnStartFn);
150
151pub trait OnStopFn:
153 for<'a> Fn(
154 &'a Agent,
155 StopReason,
156 Option<&AgentError>,
157 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
158 + Send
159 + Sync
160 + DynClone
161{
162}
163
164dyn_clone::clone_trait_object!(OnStopFn);
165
166pub trait OnStreamFn:
167 for<'a> Fn(
168 &'a Agent,
169 &ChatCompletionResponse,
170 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
171 + Send
172 + Sync
173 + DynClone
174{
175}
176
177dyn_clone::clone_trait_object!(OnStreamFn);
178
/// A single registered agent hook.
///
/// Each variant boxes one of the async hook traits declared in this module, so hooks of
/// different kinds can be stored side by side. `Clone` can be derived because every boxed
/// trait object is made cloneable with `dyn_clone::clone_trait_object!`.
///
/// `strum_macros::EnumDiscriminants` generates a fieldless companion enum named `HookTypes`,
/// which also derives `Display`. `strum_macros::Display` formats a `Hook` itself, by default
/// as its variant name. No per-variant `strum` attributes are visible here; verify none were
/// added elsewhere.
#[derive(Clone, strum_macros::EnumDiscriminants, strum_macros::Display)]
#[strum_discriminants(name(HookTypes), derive(strum_macros::Display))]
pub enum Hook {
    /// Wraps a [`BeforeAllFn`]; the hook receives only the agent.
    BeforeAll(Box<dyn BeforeAllFn>),
    /// Wraps a [`BeforeCompletionFn`]; the hook may mutate the outgoing request.
    BeforeCompletion(Box<dyn BeforeCompletionFn>),
    /// Wraps an [`AfterCompletionFn`]; the hook may mutate the completion response.
    AfterCompletion(Box<dyn AfterCompletionFn>),
    /// Wraps a [`BeforeToolFn`]; the hook sees the tool call read-only.
    BeforeTool(Box<dyn BeforeToolFn>),
    /// Wraps an [`AfterToolFn`]; the hook may rewrite the tool's `Result`.
    AfterTool(Box<dyn AfterToolFn>),
    /// Wraps an [`AfterEachFn`]; the hook receives only the agent.
    AfterEach(Box<dyn AfterEachFn>),
    /// Wraps a [`MessageHookFn`]; the hook may mutate the chat message.
    OnNewMessage(Box<dyn MessageHookFn>),
    /// Wraps an [`OnStartFn`]; the hook receives only the agent.
    OnStart(Box<dyn OnStartFn>),
    /// Wraps an [`OnStopFn`]; the hook gets the stop reason and, optionally, an error.
    OnStop(Box<dyn OnStopFn>),
    /// Wraps an [`OnStreamFn`]; the hook sees a completion response read-only.
    OnStream(Box<dyn OnStreamFn>),
}
205
206impl<F> BeforeAllFn for F where
207 F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
208 + Send
209 + Sync
210 + DynClone
211{
212}
213
214impl<F> AfterEachFn for F where
215 F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
216 + Send
217 + Sync
218 + DynClone
219{
220}
221
222impl<F> BeforeCompletionFn for F where
223 F: for<'a> Fn(
224 &'a Agent,
225 &mut ChatCompletionRequest,
226 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
227 + Send
228 + Sync
229 + DynClone
230{
231}
232
233impl<F> AfterCompletionFn for F where
234 F: for<'a> Fn(
235 &'a Agent,
236 &mut ChatCompletionResponse,
237 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
238 + Send
239 + Sync
240 + DynClone
241{
242}
243
244impl<F> BeforeToolFn for F where
245 F: for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
246 + Send
247 + Sync
248 + DynClone
249{
250}
251impl<F> AfterToolFn for F where
252 F: for<'tool> Fn(
253 &'tool Agent,
254 &ToolCall,
255 &'tool mut Result<ToolOutput, ToolError>,
256 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
257 + Send
258 + Sync
259 + DynClone
260{
261}
262
263impl<F> MessageHookFn for F where
264 F: for<'a> Fn(
265 &'a Agent,
266 &mut ChatMessage,
267 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
268 + Send
269 + Sync
270 + DynClone
271{
272}
273
274impl<F> OnStartFn for F where
275 F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
276 + Send
277 + Sync
278 + DynClone
279{
280}
281
282impl<F> OnStopFn for F where
283 F: for<'a> Fn(
284 &'a Agent,
285 StopReason,
286 Option<&AgentError>,
287 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
288 + Send
289 + Sync
290 + DynClone
291{
292}
293
294impl<F> OnStreamFn for F where
295 F: for<'a> Fn(
296 &'a Agent,
297 &ChatCompletionResponse,
298 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
299 + Send
300 + Sync
301 + DynClone
302{
303}
304
#[cfg(test)]
mod tests {
    use super::{AfterEachFn, Hook, HookTypes, MessageHookFn, OnStopFn, OnStreamFn};
    use crate::Agent;

    // Compile-time check that the builder accepts async closures for these hook kinds.
    // The builder is only configured here, never built or run.
    #[test]
    fn test_builder_accepts_async_hooks() {
        Agent::builder()
            .before_all(|_| Box::pin(async { Ok(()) }))
            .on_start(|_| Box::pin(async { Ok(()) }))
            .before_completion(|_, _| Box::pin(async { Ok(()) }))
            .before_tool(|_, _| Box::pin(async { Ok(()) }))
            .after_tool(|_, _, _| Box::pin(async { Ok(()) }))
            .after_completion(|_, _| Box::pin(async { Ok(()) }));
    }

    // Each helper gives its closure argument an `impl <HookTrait>` bound. That bound lets
    // rustc infer the higher-ranked closure signature; a bare `Box::new(closure)` would not.

    fn after_each(hook: impl AfterEachFn + 'static) -> Hook {
        Hook::AfterEach(Box::new(hook))
    }

    fn on_new_message(hook: impl MessageHookFn + 'static) -> Hook {
        Hook::OnNewMessage(Box::new(hook))
    }

    fn on_stop(hook: impl OnStopFn + 'static) -> Hook {
        Hook::OnStop(Box::new(hook))
    }

    fn on_stream(hook: impl OnStreamFn + 'static) -> Hook {
        Hook::OnStream(Box::new(hook))
    }

    // Covers the hook kinds the builder test does not touch. Also checks that cloning,
    // which goes through `dyn_clone` for the boxed trait objects, keeps each variant.
    #[test]
    fn test_remaining_hook_kinds_box_and_clone() {
        let hooks = vec![
            after_each(|_| Box::pin(async { Ok(()) })),
            on_new_message(|_, _| Box::pin(async { Ok(()) })),
            on_stop(|_, _, _| Box::pin(async { Ok(()) })),
            on_stream(|_, _| Box::pin(async { Ok(()) })),
        ];
        let expected = [
            HookTypes::AfterEach,
            HookTypes::OnNewMessage,
            HookTypes::OnStop,
            HookTypes::OnStream,
        ];

        let kinds: Vec<HookTypes> = hooks.iter().map(HookTypes::from).collect();
        assert_eq!(kinds, expected);

        let cloned = hooks.clone();
        let cloned_kinds: Vec<HookTypes> = cloned.iter().map(HookTypes::from).collect();
        assert_eq!(cloned_kinds, expected);
    }
}