swiftide_agents/
hooks.rs

//! Hooks are functions that are called at specific points in the agent lifecycle.
//!
//! Hooks return a boxed, pinned future rather than being async closures, so each hook has to
//! wrap its body in `Box::pin(async move { .. })` itself.
//!
//! # Example
//!
//! ```no_run
//! # use swiftide_core::{AgentContext, chat_completion::ChatMessage};
//! # use swiftide_agents::Agent;
//! # fn test() {
//! # let mut agent = swiftide_agents::Agent::builder();
//! agent.before_all(move |agent: &Agent| {
//!     Box::pin(async move {
//!         agent.context().add_message(ChatMessage::new_user("Hello, world")).await;
//!         Ok(())
//!     })
//! });
//! # }
//! ```
//!
//! Rust has a long-standing issue where a returned `impl Trait` captures outer lifetimes, even
//! ones it does not use (see [this issue](https://github.com/rust-lang/rust/issues/42940)).
//!
//! This can bite when you write a method like `fn return_hook(&self) -> impl BeforeAllFn` on a
//! type that itself has a lifetime parameter. The workaround is to give both `self` and the
//! returned hook explicit lifetimes, and to require that `self` outlives the hook.
//!
//! # Example
//!
//! ```no_run
//! # use swiftide_core::{AgentContext};
//! # use swiftide_agents::hooks::BeforeAllFn;
//! # use swiftide_agents::Agent;
//! struct SomeHook<'thing> {
//!     thing: &'thing str,
//! }
//!
//! impl<'thing> SomeHook<'thing> {
//!     fn return_hook<'tool>(&'thing self) -> impl BeforeAllFn + 'tool
//!     where
//!         'thing: 'tool,
//!     {
//!         move |_: &Agent| Box::pin(async move { Ok(()) })
//!     }
//! }
//! ```
46use anyhow::Result;
47use std::{future::Future, pin::Pin};
48
49use dyn_clone::DynClone;
50use swiftide_core::chat_completion::{
51    ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall, ToolOutput,
52    errors::ToolError,
53};
54
55use crate::{Agent, errors::AgentError, state::StopReason};
56
57pub trait BeforeAllFn:
58    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
59    + Send
60    + Sync
61    + DynClone
62{
63}
64
65dyn_clone::clone_trait_object!(BeforeAllFn);
66
67pub trait AfterEachFn:
68    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
69    + Send
70    + Sync
71    + DynClone
72{
73}
74
75dyn_clone::clone_trait_object!(AfterEachFn);
76
77pub trait BeforeCompletionFn:
78    for<'a> Fn(
79        &'a Agent,
80        &mut ChatCompletionRequest,
81    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
82    + Send
83    + Sync
84    + DynClone
85{
86}
87
88dyn_clone::clone_trait_object!(BeforeCompletionFn);
89
90pub trait AfterCompletionFn:
91    for<'a> Fn(
92        &'a Agent,
93        &mut ChatCompletionResponse,
94    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
95    + Send
96    + Sync
97    + DynClone
98{
99}
100
101dyn_clone::clone_trait_object!(AfterCompletionFn);
102
103/// Hooks that are called after each tool
104pub trait AfterToolFn:
105    for<'tool> Fn(
106        &'tool Agent,
107        &ToolCall,
108        &'tool mut Result<ToolOutput, ToolError>,
109    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
110    + Send
111    + Sync
112    + DynClone
113{
114}
115
116dyn_clone::clone_trait_object!(AfterToolFn);
117
118/// Hooks that are called before each tool
119pub trait BeforeToolFn:
120    for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
121    + Send
122    + Sync
123    + DynClone
124{
125}
126
127dyn_clone::clone_trait_object!(BeforeToolFn);
128
129/// Hooks that are called when a new message is added to the `AgentContext`
130pub trait MessageHookFn:
131    for<'a> Fn(&'a Agent, &mut ChatMessage) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
132    + Send
133    + Sync
134    + DynClone
135{
136}
137
138dyn_clone::clone_trait_object!(MessageHookFn);
139
140/// Hooks that are called when the agent starts, either from pending or stopped
141pub trait OnStartFn:
142    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
143    + Send
144    + Sync
145    + DynClone
146{
147}
148
149dyn_clone::clone_trait_object!(OnStartFn);
150
151/// Hooks that are called when the agent stop
152pub trait OnStopFn:
153    for<'a> Fn(
154        &'a Agent,
155        StopReason,
156        Option<&AgentError>,
157    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
158    + Send
159    + Sync
160    + DynClone
161{
162}
163
164dyn_clone::clone_trait_object!(OnStopFn);
165
166pub trait OnStreamFn:
167    for<'a> Fn(
168        &'a Agent,
169        &ChatCompletionResponse,
170    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
171    + Send
172    + Sync
173    + DynClone
174{
175}
176
177dyn_clone::clone_trait_object!(OnStreamFn);
178
179/// Wrapper around the different types of hooks
180#[derive(Clone, strum_macros::EnumDiscriminants, strum_macros::Display)]
181#[strum_discriminants(name(HookTypes), derive(strum_macros::Display))]
182pub enum Hook {
183    /// Runs only once for the agent when it starts
184    BeforeAll(Box<dyn BeforeAllFn>),
185    /// Runs before every completion, yielding a mutable reference to the completion request
186    BeforeCompletion(Box<dyn BeforeCompletionFn>),
187    /// Runs after every completion, yielding a mutable reference to the completion response
188    AfterCompletion(Box<dyn AfterCompletionFn>),
189    /// Runs before every tool call, yielding a reference to the tool call
190    BeforeTool(Box<dyn BeforeToolFn>),
191    /// Runs after every tool call, yielding a reference to the tool call and a mutable result
192    AfterTool(Box<dyn AfterToolFn>),
193    /// Runs after all tools have completed and a single completion has been made
194    AfterEach(Box<dyn AfterEachFn>),
195    /// Runs when a new message is added to the `AgentContext`, yielding a mutable reference to the
196    /// message. This is only triggered when the message is added by the agent.
197    OnNewMessage(Box<dyn MessageHookFn>),
198    /// Runs when the agent starts, either from pending or stopped
199    OnStart(Box<dyn OnStartFn>),
200    /// Runs when the agent stops
201    OnStop(Box<dyn OnStopFn>),
202    /// Runs when the agent streams a response
203    OnStream(Box<dyn OnStreamFn>),
204}
205
206impl<F> BeforeAllFn for F where
207    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
208        + Send
209        + Sync
210        + DynClone
211{
212}
213
214impl<F> AfterEachFn for F where
215    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
216        + Send
217        + Sync
218        + DynClone
219{
220}
221
222impl<F> BeforeCompletionFn for F where
223    F: for<'a> Fn(
224            &'a Agent,
225            &mut ChatCompletionRequest,
226        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
227        + Send
228        + Sync
229        + DynClone
230{
231}
232
233impl<F> AfterCompletionFn for F where
234    F: for<'a> Fn(
235            &'a Agent,
236            &mut ChatCompletionResponse,
237        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
238        + Send
239        + Sync
240        + DynClone
241{
242}
243
244impl<F> BeforeToolFn for F where
245    F: for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
246        + Send
247        + Sync
248        + DynClone
249{
250}
251impl<F> AfterToolFn for F where
252    F: for<'tool> Fn(
253            &'tool Agent,
254            &ToolCall,
255            &'tool mut Result<ToolOutput, ToolError>,
256        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
257        + Send
258        + Sync
259        + DynClone
260{
261}
262
263impl<F> MessageHookFn for F where
264    F: for<'a> Fn(
265            &'a Agent,
266            &mut ChatMessage,
267        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
268        + Send
269        + Sync
270        + DynClone
271{
272}
273
274impl<F> OnStartFn for F where
275    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
276        + Send
277        + Sync
278        + DynClone
279{
280}
281
282impl<F> OnStopFn for F where
283    F: for<'a> Fn(
284            &'a Agent,
285            StopReason,
286            Option<&AgentError>,
287        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
288        + Send
289        + Sync
290        + DynClone
291{
292}
293
294impl<F> OnStreamFn for F where
295    F: for<'a> Fn(
296            &'a Agent,
297            &ChatCompletionResponse,
298        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
299        + Send
300        + Sync
301        + DynClone
302{
303}
304
#[cfg(test)]
mod tests {
    use crate::Agent;

    /// Registering a closure for each hook kind must type-check; reaching the end of the test
    /// means every closure satisfied its hook trait.
    #[test]
    fn test_hooks_compile_sync_and_async() {
        Agent::builder()
            .before_all(|_agent| Box::pin(async { Ok(()) }))
            .on_start(|_agent| Box::pin(async { Ok(()) }))
            .before_completion(|_agent, _request| Box::pin(async { Ok(()) }))
            .before_tool(|_agent, _tool_call| Box::pin(async { Ok(()) }))
            .after_tool(|_agent, _tool_call, _tool_result| Box::pin(async { Ok(()) }))
            .after_completion(|_agent, _response| Box::pin(async { Ok(()) }));
    }
}