1use anyhow::Result;
47use std::{future::Future, pin::Pin};
48
49use dyn_clone::DynClone;
50use swiftide_core::chat_completion::{
51 errors::ToolError, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall,
52 ToolOutput,
53};
54
55use crate::{errors::AgentError, state::StopReason, Agent};
56
/// Signature of the hook stored in [`Hook::BeforeAll`].
///
/// The hook receives a shared reference to the [`Agent`] and returns a boxed,
/// `Send` future that resolves to `anyhow::Result<()>`. The future may borrow
/// the agent for the lifetime `'a`.
///
/// `Send + Sync` make the boxed hook safe to move and share between threads.
/// `DynClone`, together with the `clone_trait_object!` invocation below, makes
/// `Box<dyn BeforeAllFn>` cloneable, which `Hook`'s `#[derive(Clone)]` needs.
///
/// Any closure or function with this signature implements the trait through
/// the blanket implementation in this module.
pub trait BeforeAllFn:
    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
    + Send
    + Sync
    + DynClone
{
}

// Generates `impl Clone for Box<dyn BeforeAllFn>`.
dyn_clone::clone_trait_object!(BeforeAllFn);
66
/// Signature of the hook stored in [`Hook::AfterEach`].
///
/// The hook receives a shared reference to the [`Agent`] and returns a boxed,
/// `Send` future that resolves to `anyhow::Result<()>`. The future may borrow
/// the agent for the lifetime `'a`.
///
/// `Send + Sync` make the boxed hook safe to move and share between threads.
/// `DynClone`, together with the `clone_trait_object!` invocation below, makes
/// `Box<dyn AfterEachFn>` cloneable, which `Hook`'s `#[derive(Clone)]` needs.
pub trait AfterEachFn:
    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
    + Send
    + Sync
    + DynClone
{
}

// Generates `impl Clone for Box<dyn AfterEachFn>`.
dyn_clone::clone_trait_object!(AfterEachFn);
76
/// Signature of the hook stored in [`Hook::BeforeCompletion`].
///
/// The hook receives the [`Agent`] and a mutable reference to a
/// [`ChatCompletionRequest`]. It returns a boxed, `Send` future that resolves
/// to `anyhow::Result<()>`.
///
/// Only the agent borrow is tied to `'a`. The `&mut ChatCompletionRequest` has
/// its own independent (elided) lifetime, so the returned future cannot
/// capture the request. Any modification of the request must happen
/// synchronously in the closure body, before the future is built.
pub trait BeforeCompletionFn:
    for<'a> Fn(
        &'a Agent,
        &mut ChatCompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
    + Send
    + Sync
    + DynClone
{
}

// Generates `impl Clone for Box<dyn BeforeCompletionFn>`.
dyn_clone::clone_trait_object!(BeforeCompletionFn);
89
/// Signature of the hook stored in [`Hook::AfterCompletion`].
///
/// The hook receives the [`Agent`] and a mutable reference to a
/// [`ChatCompletionResponse`]. It returns a boxed, `Send` future that resolves
/// to `anyhow::Result<()>`.
///
/// Only the agent borrow is tied to `'a`. The `&mut ChatCompletionResponse`
/// has its own independent (elided) lifetime, so the returned future cannot
/// capture the response. Any modification of the response must happen
/// synchronously in the closure body, before the future is built.
pub trait AfterCompletionFn:
    for<'a> Fn(
        &'a Agent,
        &mut ChatCompletionResponse,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
    + Send
    + Sync
    + DynClone
{
}

// Generates `impl Clone for Box<dyn AfterCompletionFn>`.
dyn_clone::clone_trait_object!(AfterCompletionFn);
102
/// Signature of the hook stored in [`Hook::AfterTool`].
///
/// The hook receives:
/// - the [`Agent`],
/// - a [`ToolCall`],
/// - a mutable reference to a `Result<ToolOutput, ToolError>`.
///
/// It returns a boxed, `Send` future that resolves to `anyhow::Result<()>`.
///
/// Both the agent and the result borrow are tied to `'tool`, so the returned
/// future may read and modify the result across `.await` points. The
/// `&ToolCall` is *not* tied to `'tool`, so the future cannot capture it; copy
/// or clone anything needed from it before building the future.
pub trait AfterToolFn:
    for<'tool> Fn(
        &'tool Agent,
        &ToolCall,
        &'tool mut Result<ToolOutput, ToolError>,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
    + Send
    + Sync
    + DynClone
{
}

// Generates `impl Clone for Box<dyn AfterToolFn>`.
dyn_clone::clone_trait_object!(AfterToolFn);
117
/// Signature of the hook stored in [`Hook::BeforeTool`].
///
/// The hook receives the [`Agent`] and a shared reference to a [`ToolCall`].
/// It returns a boxed, `Send` future that resolves to `anyhow::Result<()>`.
///
/// Only the agent borrow is tied to `'a`. The `&ToolCall` has its own
/// independent (elided) lifetime, so the future cannot capture it; copy or
/// clone anything needed from it before building the future.
pub trait BeforeToolFn:
    for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
    + Send
    + Sync
    + DynClone
{
}

// Generates `impl Clone for Box<dyn BeforeToolFn>`.
dyn_clone::clone_trait_object!(BeforeToolFn);
128
/// Signature of the hook stored in [`Hook::OnNewMessage`].
///
/// The hook receives the [`Agent`] and a mutable reference to a
/// [`ChatMessage`]. It returns a boxed, `Send` future that resolves to
/// `anyhow::Result<()>`.
///
/// Only the agent borrow is tied to `'a`. The `&mut ChatMessage` has its own
/// independent (elided) lifetime, so the returned future cannot capture the
/// message. Any modification must happen synchronously in the closure body,
/// before the future is built.
pub trait MessageHookFn:
    for<'a> Fn(&'a Agent, &mut ChatMessage) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
    + Send
    + Sync
    + DynClone
{
}

// Generates `impl Clone for Box<dyn MessageHookFn>`.
dyn_clone::clone_trait_object!(MessageHookFn);
139
/// Signature of the hook stored in [`Hook::OnStart`].
///
/// The hook receives a shared reference to the [`Agent`] and returns a boxed,
/// `Send` future that resolves to `anyhow::Result<()>`. The future may borrow
/// the agent for the lifetime `'a`.
///
/// `Send + Sync` make the boxed hook safe to move and share between threads.
/// `DynClone`, together with the `clone_trait_object!` invocation below, makes
/// `Box<dyn OnStartFn>` cloneable, which `Hook`'s `#[derive(Clone)]` needs.
pub trait OnStartFn:
    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
    + Send
    + Sync
    + DynClone
{
}

// Generates `impl Clone for Box<dyn OnStartFn>`.
dyn_clone::clone_trait_object!(OnStartFn);
150
/// Signature of the hook stored in [`Hook::OnStop`].
///
/// The hook receives:
/// - the [`Agent`],
/// - a [`StopReason`] by value,
/// - an optional reference to an [`AgentError`].
///
/// It returns a boxed, `Send` future that resolves to `anyhow::Result<()>`.
///
/// The `StopReason` is owned, so it can be moved into the future. The
/// `Option<&AgentError>` borrow is not tied to `'a`, so the future cannot
/// capture the error; extract what is needed before building the future.
pub trait OnStopFn:
    for<'a> Fn(
        &'a Agent,
        StopReason,
        Option<&AgentError>,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
    + Send
    + Sync
    + DynClone
{
}

// Generates `impl Clone for Box<dyn OnStopFn>`.
dyn_clone::clone_trait_object!(OnStopFn);
165
/// A single agent hook: each variant wraps one boxed hook function.
///
/// `Clone` can be derived because every boxed payload is cloneable through the
/// `dyn_clone::clone_trait_object!` invocations in this module.
/// `strum_macros::Display` renders the variant name. `EnumDiscriminants`
/// generates the fieldless [`HookTypes`] enum, with one variant per `Hook`
/// variant, which also derives `Display`.
#[derive(Clone, strum_macros::EnumDiscriminants, strum_macros::Display)]
#[strum_discriminants(name(HookTypes), derive(strum_macros::Display))]
pub enum Hook {
    /// Wraps a [`BeforeAllFn`].
    BeforeAll(Box<dyn BeforeAllFn>),
    /// Wraps a [`BeforeCompletionFn`]; it can mutate the completion request.
    BeforeCompletion(Box<dyn BeforeCompletionFn>),
    /// Wraps an [`AfterCompletionFn`]; it can mutate the completion response.
    AfterCompletion(Box<dyn AfterCompletionFn>),
    /// Wraps a [`BeforeToolFn`]; it receives the tool call.
    BeforeTool(Box<dyn BeforeToolFn>),
    /// Wraps an [`AfterToolFn`]; it can mutate the tool's result.
    AfterTool(Box<dyn AfterToolFn>),
    /// Wraps an [`AfterEachFn`].
    AfterEach(Box<dyn AfterEachFn>),
    /// Wraps a [`MessageHookFn`]; it can mutate a chat message.
    OnNewMessage(Box<dyn MessageHookFn>),
    /// Wraps an [`OnStartFn`].
    OnStart(Box<dyn OnStartFn>),
    /// Wraps an [`OnStopFn`]; it receives the stop reason and optional error.
    OnStop(Box<dyn OnStopFn>),
}
190
191impl<F> BeforeAllFn for F where
192 F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
193 + Send
194 + Sync
195 + DynClone
196{
197}
198
199impl<F> AfterEachFn for F where
200 F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
201 + Send
202 + Sync
203 + DynClone
204{
205}
206
207impl<F> BeforeCompletionFn for F where
208 F: for<'a> Fn(
209 &'a Agent,
210 &mut ChatCompletionRequest,
211 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
212 + Send
213 + Sync
214 + DynClone
215{
216}
217
218impl<F> AfterCompletionFn for F where
219 F: for<'a> Fn(
220 &'a Agent,
221 &mut ChatCompletionResponse,
222 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
223 + Send
224 + Sync
225 + DynClone
226{
227}
228
229impl<F> BeforeToolFn for F where
230 F: for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
231 + Send
232 + Sync
233 + DynClone
234{
235}
236impl<F> AfterToolFn for F where
237 F: for<'tool> Fn(
238 &'tool Agent,
239 &ToolCall,
240 &'tool mut Result<ToolOutput, ToolError>,
241 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
242 + Send
243 + Sync
244 + DynClone
245{
246}
247
248impl<F> MessageHookFn for F where
249 F: for<'a> Fn(
250 &'a Agent,
251 &mut ChatMessage,
252 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
253 + Send
254 + Sync
255 + DynClone
256{
257}
258
259impl<F> OnStartFn for F where
260 F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
261 + Send
262 + Sync
263 + DynClone
264{
265}
266
267impl<F> OnStopFn for F where
268 F: for<'a> Fn(
269 &'a Agent,
270 StopReason,
271 Option<&AgentError>,
272 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
273 + Send
274 + Sync
275 + DynClone
276{
277}
278
#[cfg(test)]
mod tests {
    use crate::Agent;

    /// Compile-time check that inline closures returning boxed futures satisfy
    /// the higher-ranked hook bounds accepted by the builder methods used here.
    ///
    /// The closures are deliberately written inline at each call. Closure
    /// signature inference relies on the expected bound at the call site, so
    /// binding them to a `let` first would generally break the higher-ranked
    /// lifetime inference.
    ///
    /// NOTE(review): despite the name, only boxed-future (async) closures are
    /// exercised. Hooks for the `AfterEach`, `OnNewMessage` and `OnStop`
    /// variants are not covered here.
    #[test]
    fn test_hooks_compile_sync_and_async() {
        Agent::builder()
            .before_all(|_| Box::pin(async { Ok(()) }))
            .on_start(|_| Box::pin(async { Ok(()) }))
            .before_completion(|_, _| Box::pin(async { Ok(()) }))
            .before_tool(|_, _| Box::pin(async { Ok(()) }))
            .after_tool(|_, _, _| Box::pin(async { Ok(()) }))
            .after_completion(|_, _| Box::pin(async { Ok(()) }));
    }
}