1use anyhow::Result;
47use std::{future::Future, pin::Pin};
48
49use dyn_clone::DynClone;
50use swiftide_core::chat_completion::{
51 errors::ToolError, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall,
52 ToolOutput,
53};
54
55use crate::Agent;
56
57pub trait BeforeAllFn:
58 for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
59 + Send
60 + Sync
61 + DynClone
62{
63}
64
65dyn_clone::clone_trait_object!(BeforeAllFn);
66
67pub trait AfterEachFn:
68 for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
69 + Send
70 + Sync
71 + DynClone
72{
73}
74
75dyn_clone::clone_trait_object!(AfterEachFn);
76
77pub trait BeforeCompletionFn:
78 for<'a> Fn(
79 &'a Agent,
80 &mut ChatCompletionRequest,
81 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
82 + Send
83 + Sync
84 + DynClone
85{
86}
87
88dyn_clone::clone_trait_object!(BeforeCompletionFn);
89
90pub trait AfterCompletionFn:
91 for<'a> Fn(
92 &'a Agent,
93 &mut ChatCompletionResponse,
94 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
95 + Send
96 + Sync
97 + DynClone
98{
99}
100
101dyn_clone::clone_trait_object!(AfterCompletionFn);
102
103pub trait AfterToolFn:
105 for<'tool> Fn(
106 &'tool Agent,
107 &ToolCall,
108 &'tool mut Result<ToolOutput, ToolError>,
109 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
110 + Send
111 + Sync
112 + DynClone
113{
114}
115
116dyn_clone::clone_trait_object!(AfterToolFn);
117
118pub trait BeforeToolFn:
120 for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
121 + Send
122 + Sync
123 + DynClone
124{
125}
126
127dyn_clone::clone_trait_object!(BeforeToolFn);
128
129pub trait MessageHookFn:
131 for<'a> Fn(&'a Agent, &mut ChatMessage) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
132 + Send
133 + Sync
134 + DynClone
135{
136}
137
138dyn_clone::clone_trait_object!(MessageHookFn);
139
140pub trait OnStartFn:
142 for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
143 + Send
144 + Sync
145 + DynClone
146{
147}
148
149dyn_clone::clone_trait_object!(OnStartFn);
150
/// A type-erased, cloneable hook callback.
///
/// Each variant boxes a callback of the matching signature trait defined in this module.
/// `#[derive(Clone)]` works because every boxed trait object is made cloneable through
/// `dyn_clone::clone_trait_object!`.
///
/// `strum_macros::Display` formats a hook as its variant name (e.g. `BeforeAll`).
/// `strum_macros::EnumDiscriminants` generates the payload-free [`HookTypes`] enum, which
/// also derives `Display`.
///
/// NOTE(review): when each variant is invoked is decided by the agent loop, which lives
/// outside this module.
#[derive(Clone, strum_macros::EnumDiscriminants, strum_macros::Display)]
#[strum_discriminants(name(HookTypes), derive(strum_macros::Display))]
pub enum Hook {
    /// Wraps a [`BeforeAllFn`]; the callback receives only the agent.
    BeforeAll(Box<dyn BeforeAllFn>),
    /// Wraps a [`BeforeCompletionFn`]; the callback may mutate the
    /// [`ChatCompletionRequest`] it is given.
    BeforeCompletion(Box<dyn BeforeCompletionFn>),
    /// Wraps an [`AfterCompletionFn`]; the callback may mutate the
    /// [`ChatCompletionResponse`] it is given.
    AfterCompletion(Box<dyn AfterCompletionFn>),
    /// Wraps a [`BeforeToolFn`]; the callback receives the [`ToolCall`] by shared reference.
    BeforeTool(Box<dyn BeforeToolFn>),
    /// Wraps an [`AfterToolFn`]; the callback receives the [`ToolCall`] and may mutate the
    /// tool's `Result<ToolOutput, ToolError>`.
    AfterTool(Box<dyn AfterToolFn>),
    /// Wraps an [`AfterEachFn`]; the callback receives only the agent.
    AfterEach(Box<dyn AfterEachFn>),
    /// Wraps a [`MessageHookFn`]; the callback may mutate the [`ChatMessage`] it is given.
    OnNewMessage(Box<dyn MessageHookFn>),
    /// Wraps an [`OnStartFn`]; the callback receives only the agent.
    OnStart(Box<dyn OnStartFn>),
}
173
174impl<F> BeforeAllFn for F where
175 F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
176 + Send
177 + Sync
178 + DynClone
179{
180}
181
182impl<F> AfterEachFn for F where
183 F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
184 + Send
185 + Sync
186 + DynClone
187{
188}
189
190impl<F> BeforeCompletionFn for F where
191 F: for<'a> Fn(
192 &'a Agent,
193 &mut ChatCompletionRequest,
194 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
195 + Send
196 + Sync
197 + DynClone
198{
199}
200
201impl<F> AfterCompletionFn for F where
202 F: for<'a> Fn(
203 &'a Agent,
204 &mut ChatCompletionResponse,
205 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
206 + Send
207 + Sync
208 + DynClone
209{
210}
211
212impl<F> BeforeToolFn for F where
213 F: for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
214 + Send
215 + Sync
216 + DynClone
217{
218}
219impl<F> AfterToolFn for F where
220 F: for<'tool> Fn(
221 &'tool Agent,
222 &ToolCall,
223 &'tool mut Result<ToolOutput, ToolError>,
224 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
225 + Send
226 + Sync
227 + DynClone
228{
229}
230
231impl<F> MessageHookFn for F where
232 F: for<'a> Fn(
233 &'a Agent,
234 &mut ChatMessage,
235 ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
236 + Send
237 + Sync
238 + DynClone
239{
240}
241
242impl<F> OnStartFn for F where
243 F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
244 + Send
245 + Sync
246 + DynClone
247{
248}
249
#[cfg(test)]
mod tests {
    use crate::Agent;

    /// Compile-time guard for the hook signatures: registers one closure of each hook kind
    /// (all eight `Hook` variants), each returning `Box::pin(async { .. })`. The hooks are
    /// registered but never invoked; a signature regression shows up as a build failure.
    #[test]
    fn test_all_hook_kinds_compile() {
        Agent::builder()
            .before_all(|_| Box::pin(async { Ok(()) }))
            .on_start(|_| Box::pin(async { Ok(()) }))
            .before_completion(|_, _| Box::pin(async { Ok(()) }))
            .before_tool(|_, _| Box::pin(async { Ok(()) }))
            .after_tool(|_, _, _| Box::pin(async { Ok(()) }))
            .after_completion(|_, _| Box::pin(async { Ok(()) }))
            .after_each(|_| Box::pin(async { Ok(()) }))
            .on_new_message(|_, _| Box::pin(async { Ok(()) }));
    }
}