1use anyhow::Result;
47use std::{future::Future, pin::Pin};
48
49use dyn_clone::DynClone;
50use swiftide_core::chat_completion::{
51    ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall, ToolOutput,
52    errors::ToolError,
53};
54
55use crate::{Agent, errors::AgentError, state::StopReason};
56
57pub trait BeforeAllFn:
58    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
59    + Send
60    + Sync
61    + DynClone
62{
63}
64
65dyn_clone::clone_trait_object!(BeforeAllFn);
66
67pub trait AfterEachFn:
68    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
69    + Send
70    + Sync
71    + DynClone
72{
73}
74
75dyn_clone::clone_trait_object!(AfterEachFn);
76
77pub trait BeforeCompletionFn:
78    for<'a> Fn(
79        &'a Agent,
80        &mut ChatCompletionRequest,
81    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
82    + Send
83    + Sync
84    + DynClone
85{
86}
87
88dyn_clone::clone_trait_object!(BeforeCompletionFn);
89
90pub trait AfterCompletionFn:
91    for<'a> Fn(
92        &'a Agent,
93        &mut ChatCompletionResponse,
94    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
95    + Send
96    + Sync
97    + DynClone
98{
99}
100
101dyn_clone::clone_trait_object!(AfterCompletionFn);
102
103pub trait AfterToolFn:
105    for<'tool> Fn(
106        &'tool Agent,
107        &ToolCall,
108        &'tool mut Result<ToolOutput, ToolError>,
109    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
110    + Send
111    + Sync
112    + DynClone
113{
114}
115
116dyn_clone::clone_trait_object!(AfterToolFn);
117
118pub trait BeforeToolFn:
120    for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
121    + Send
122    + Sync
123    + DynClone
124{
125}
126
127dyn_clone::clone_trait_object!(BeforeToolFn);
128
129pub trait MessageHookFn:
131    for<'a> Fn(&'a Agent, &mut ChatMessage) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
132    + Send
133    + Sync
134    + DynClone
135{
136}
137
138dyn_clone::clone_trait_object!(MessageHookFn);
139
140pub trait OnStartFn:
142    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
143    + Send
144    + Sync
145    + DynClone
146{
147}
148
149dyn_clone::clone_trait_object!(OnStartFn);
150
151pub trait OnStopFn:
153    for<'a> Fn(
154        &'a Agent,
155        StopReason,
156        Option<&AgentError>,
157    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
158    + Send
159    + Sync
160    + DynClone
161{
162}
163
164dyn_clone::clone_trait_object!(OnStopFn);
165
166pub trait OnStreamFn:
167    for<'a> Fn(
168        &'a Agent,
169        &ChatCompletionResponse,
170    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
171    + Send
172    + Sync
173    + DynClone
174{
175}
176
177dyn_clone::clone_trait_object!(OnStreamFn);
178
179#[derive(Clone, strum_macros::EnumDiscriminants, strum_macros::Display)]
181#[strum_discriminants(name(HookTypes), derive(strum_macros::Display))]
182pub enum Hook {
183    BeforeAll(Box<dyn BeforeAllFn>),
185    BeforeCompletion(Box<dyn BeforeCompletionFn>),
187    AfterCompletion(Box<dyn AfterCompletionFn>),
189    BeforeTool(Box<dyn BeforeToolFn>),
191    AfterTool(Box<dyn AfterToolFn>),
193    AfterEach(Box<dyn AfterEachFn>),
195    OnNewMessage(Box<dyn MessageHookFn>),
198    OnStart(Box<dyn OnStartFn>),
200    OnStop(Box<dyn OnStopFn>),
202    OnStream(Box<dyn OnStreamFn>),
204}
205
206impl<F> BeforeAllFn for F where
207    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
208        + Send
209        + Sync
210        + DynClone
211{
212}
213
214impl<F> AfterEachFn for F where
215    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
216        + Send
217        + Sync
218        + DynClone
219{
220}
221
222impl<F> BeforeCompletionFn for F where
223    F: for<'a> Fn(
224            &'a Agent,
225            &mut ChatCompletionRequest,
226        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
227        + Send
228        + Sync
229        + DynClone
230{
231}
232
233impl<F> AfterCompletionFn for F where
234    F: for<'a> Fn(
235            &'a Agent,
236            &mut ChatCompletionResponse,
237        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
238        + Send
239        + Sync
240        + DynClone
241{
242}
243
244impl<F> BeforeToolFn for F where
245    F: for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
246        + Send
247        + Sync
248        + DynClone
249{
250}
251impl<F> AfterToolFn for F where
252    F: for<'tool> Fn(
253            &'tool Agent,
254            &ToolCall,
255            &'tool mut Result<ToolOutput, ToolError>,
256        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
257        + Send
258        + Sync
259        + DynClone
260{
261}
262
263impl<F> MessageHookFn for F where
264    F: for<'a> Fn(
265            &'a Agent,
266            &mut ChatMessage,
267        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
268        + Send
269        + Sync
270        + DynClone
271{
272}
273
274impl<F> OnStartFn for F where
275    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
276        + Send
277        + Sync
278        + DynClone
279{
280}
281
282impl<F> OnStopFn for F where
283    F: for<'a> Fn(
284            &'a Agent,
285            StopReason,
286            Option<&AgentError>,
287        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
288        + Send
289        + Sync
290        + DynClone
291{
292}
293
294impl<F> OnStreamFn for F where
295    F: for<'a> Fn(
296            &'a Agent,
297            &ChatCompletionResponse,
298        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
299        + Send
300        + Sync
301        + DynClone
302{
303}
304
#[cfg(test)]
mod tests {
    use crate::Agent;

    /// Compile-time smoke test: plain closures returning a boxed async block must be accepted
    /// by each of the builder's hook setters used here. The builder is never built and no
    /// hook is run, so there is nothing to assert at runtime.
    #[test]
    fn test_hooks_compile_sync_and_async() {
        Agent::builder()
            .before_all(|_agent| Box::pin(async { Ok(()) }))
            .on_start(|_agent| Box::pin(async { Ok(()) }))
            .before_completion(|_agent, _request| Box::pin(async { Ok(()) }))
            .before_tool(|_agent, _tool_call| Box::pin(async { Ok(()) }))
            .after_tool(|_agent, _tool_call, _output| Box::pin(async { Ok(()) }))
            .after_completion(|_agent, _response| Box::pin(async { Ok(()) }));
    }
}