1use anyhow::Result;
47use std::{future::Future, pin::Pin};
48
49use dyn_clone::DynClone;
50use swiftide_core::chat_completion::{
51 errors::ToolError, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall,
52 ToolOutput,
53};
54
55use crate::Agent;
56
57pub trait BeforeAllFn:
58 for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
59 + Send
60 + Sync
61 + DynClone
62{
63}
64
65dyn_clone::clone_trait_object!(BeforeAllFn);
66
67pub trait AfterEachFn:
68 for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
69 + Send
70 + Sync
71 + DynClone
72{
73}
74
75dyn_clone::clone_trait_object!(AfterEachFn);
76
77pub trait BeforeCompletionFn:
78 for<'a> Fn(
79 &'a Agent,
80 &mut ChatCompletionRequest,
81 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
82 + Send
83 + Sync
84 + DynClone
85{
86}
87
88dyn_clone::clone_trait_object!(BeforeCompletionFn);
89
90pub trait AfterCompletionFn:
91 for<'a> Fn(
92 &'a Agent,
93 &mut ChatCompletionResponse,
94 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
95 + Send
96 + Sync
97 + DynClone
98{
99}
100
101dyn_clone::clone_trait_object!(AfterCompletionFn);
102
103pub trait AfterToolFn:
106 for<'tool> Fn(
107 &'tool Agent,
108 &ToolCall,
109 &'tool mut Result<ToolOutput, ToolError>,
110 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
111 + Send
112 + Sync
113 + DynClone
114{
115}
116
117dyn_clone::clone_trait_object!(AfterToolFn);
118
119pub trait BeforeToolFn:
121 for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
122 + Send
123 + Sync
124 + DynClone
125{
126}
127
128dyn_clone::clone_trait_object!(BeforeToolFn);
129
130pub trait MessageHookFn:
132 for<'a> Fn(&'a Agent, &mut ChatMessage) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
133 + Send
134 + Sync
135 + DynClone
136{
137}
138
139dyn_clone::clone_trait_object!(MessageHookFn);
140
141pub trait OnStartFn:
143 for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
144 + Send
145 + Sync
146 + DynClone
147{
148}
149
150dyn_clone::clone_trait_object!(OnStartFn);
151
152#[derive(Clone, strum_macros::EnumDiscriminants, strum_macros::Display)]
154#[strum_discriminants(name(HookTypes), derive(strum_macros::Display))]
155pub enum Hook {
156 BeforeAll(Box<dyn BeforeAllFn>),
158 BeforeCompletion(Box<dyn BeforeCompletionFn>),
160 AfterCompletion(Box<dyn AfterCompletionFn>),
162 BeforeTool(Box<dyn BeforeToolFn>),
164 AfterTool(Box<dyn AfterToolFn>),
166 AfterEach(Box<dyn AfterEachFn>),
168 OnNewMessage(Box<dyn MessageHookFn>),
171 OnStart(Box<dyn OnStartFn>),
173}
174
175impl<F> BeforeAllFn for F where
176 F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
177 + Send
178 + Sync
179 + DynClone
180{
181}
182
183impl<F> AfterEachFn for F where
184 F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
185 + Send
186 + Sync
187 + DynClone
188{
189}
190
191impl<F> BeforeCompletionFn for F where
192 F: for<'a> Fn(
193 &'a Agent,
194 &mut ChatCompletionRequest,
195 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
196 + Send
197 + Sync
198 + DynClone
199{
200}
201
202impl<F> AfterCompletionFn for F where
203 F: for<'a> Fn(
204 &'a Agent,
205 &mut ChatCompletionResponse,
206 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
207 + Send
208 + Sync
209 + DynClone
210{
211}
212
213impl<F> BeforeToolFn for F where
214 F: for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
215 + Send
216 + Sync
217 + DynClone
218{
219}
220impl<F> AfterToolFn for F where
221 F: for<'tool> Fn(
222 &'tool Agent,
223 &ToolCall,
224 &'tool mut Result<ToolOutput, ToolError>,
225 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
226 + Send
227 + Sync
228 + DynClone
229{
230}
231
232impl<F> MessageHookFn for F where
233 F: for<'a> Fn(
234 &'a Agent,
235 &mut ChatMessage,
236 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
237 + Send
238 + Sync
239 + DynClone
240{
241}
242
243impl<F> OnStartFn for F where
244 F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
245 + Send
246 + Sync
247 + DynClone
248{
249}
250
#[cfg(test)]
mod tests {
    use crate::Agent;

    /// Compile-only check: plain closures returning a boxed async block must be
    /// accepted by the builder's hook registration methods. Nothing is awaited
    /// or asserted; the test passes as long as this type-checks.
    #[test]
    fn test_hooks_compile_sync_and_async() {
        Agent::builder()
            .before_all(|_agent| Box::pin(async { Ok(()) }))
            .on_start(|_agent| Box::pin(async { Ok(()) }))
            .before_completion(|_agent, _request| Box::pin(async { Ok(()) }))
            .before_tool(|_agent, _tool_call| Box::pin(async { Ok(()) }))
            .after_tool(|_agent, _tool_call, _tool_result| Box::pin(async { Ok(()) }))
            .after_completion(|_agent, _response| Box::pin(async { Ok(()) }));
    }
}