swiftide_agents/
hooks.rs

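//! Hooks are functions that are called at specific points in the agent lifecycle.
//!
//! Every hook must return a boxed, pinned future itself (usually an `async` block wrapped in
//! `Box::pin`), because that is the return type the hook traits in this module require; a plain
//! `async` closure does not satisfy them.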
6//!
7//! # Example
8//!
9//! ```no_run
10//! # use swiftide_core::{AgentContext, chat_completion::ChatMessage};
11//! # use swiftide_agents::Agent;
12//! # fn test() {
13//! # let mut agent = swiftide_agents::Agent::builder();
14//! agent.before_all(move |agent: &Agent| {
15//!     Box::pin(async move {
16//!         agent.context().add_message(ChatMessage::new_user("Hello, world")).await;
17//!         Ok(())
18//!     })
19//! });
20//! # }
21//! ```
//!
//! Rust has a long-standing issue where a returned `impl Trait` captures outer lifetimes when the
//! function also has lifetime parameters; see
//! [rust-lang/rust#42940](https://github.com/rust-lang/rust/issues/42940).
//!
//! This can bite when you write a method like `fn return_hook(&self) -> impl HookFn` on a type
//! that itself has a lifetime parameter. The workaround is to give both `self` and the returned
//! hook explicit lifetimes, and to require that `self` outlives the hook.
29//! # Example
30//!
31//! ```no_run
32//! # use swiftide_core::{AgentContext};
33//! # use swiftide_agents::hooks::BeforeAllFn;
34//! # use swiftide_agents::Agent;
//! struct SomeHook<'thing> {
//!     thing: &'thing str,
//! }
//!
//! impl<'thing> SomeHook<'thing> {
//!     fn return_hook<'tool>(&'thing self) -> impl BeforeAllFn + 'tool
//!     where
//!         'thing: 'tool,
//!     {
//!         move |_: &Agent| Box::pin(async move { Ok(()) })
//!     }
//! }
//! ```
46use anyhow::Result;
47use std::{future::Future, pin::Pin};
48
49use dyn_clone::DynClone;
50use swiftide_core::chat_completion::{
51    errors::ToolError, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall,
52    ToolOutput,
53};
54
55use crate::{errors::AgentError, state::StopReason, Agent};
56
57pub trait BeforeAllFn:
58    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
59    + Send
60    + Sync
61    + DynClone
62{
63}
64
65dyn_clone::clone_trait_object!(BeforeAllFn);
66
67pub trait AfterEachFn:
68    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
69    + Send
70    + Sync
71    + DynClone
72{
73}
74
75dyn_clone::clone_trait_object!(AfterEachFn);
76
77pub trait BeforeCompletionFn:
78    for<'a> Fn(
79        &'a Agent,
80        &mut ChatCompletionRequest,
81    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
82    + Send
83    + Sync
84    + DynClone
85{
86}
87
88dyn_clone::clone_trait_object!(BeforeCompletionFn);
89
90pub trait AfterCompletionFn:
91    for<'a> Fn(
92        &'a Agent,
93        &mut ChatCompletionResponse,
94    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
95    + Send
96    + Sync
97    + DynClone
98{
99}
100
101dyn_clone::clone_trait_object!(AfterCompletionFn);
102
103/// Hooks that are called after each tool
104pub trait AfterToolFn:
105    for<'tool> Fn(
106        &'tool Agent,
107        &ToolCall,
108        &'tool mut Result<ToolOutput, ToolError>,
109    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
110    + Send
111    + Sync
112    + DynClone
113{
114}
115
116dyn_clone::clone_trait_object!(AfterToolFn);
117
118/// Hooks that are called before each tool
119pub trait BeforeToolFn:
120    for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
121    + Send
122    + Sync
123    + DynClone
124{
125}
126
127dyn_clone::clone_trait_object!(BeforeToolFn);
128
129/// Hooks that are called when a new message is added to the `AgentContext`
130pub trait MessageHookFn:
131    for<'a> Fn(&'a Agent, &mut ChatMessage) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
132    + Send
133    + Sync
134    + DynClone
135{
136}
137
138dyn_clone::clone_trait_object!(MessageHookFn);
139
140/// Hooks that are called when the agent starts, either from pending or stopped
141pub trait OnStartFn:
142    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
143    + Send
144    + Sync
145    + DynClone
146{
147}
148
149dyn_clone::clone_trait_object!(OnStartFn);
150
151/// Hooks that are called when the agent stop
152pub trait OnStopFn:
153    for<'a> Fn(
154        &'a Agent,
155        StopReason,
156        Option<&AgentError>,
157    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
158    + Send
159    + Sync
160    + DynClone
161{
162}
163
164dyn_clone::clone_trait_object!(OnStopFn);
165
166/// Wrapper around the different types of hooks
167#[derive(Clone, strum_macros::EnumDiscriminants, strum_macros::Display)]
168#[strum_discriminants(name(HookTypes), derive(strum_macros::Display))]
169pub enum Hook {
170    /// Runs only once for the agent when it starts
171    BeforeAll(Box<dyn BeforeAllFn>),
172    /// Runs before every completion, yielding a mutable reference to the completion request
173    BeforeCompletion(Box<dyn BeforeCompletionFn>),
174    /// Runs after every completion, yielding a mutable reference to the completion response
175    AfterCompletion(Box<dyn AfterCompletionFn>),
176    /// Runs before every tool call, yielding a reference to the tool call
177    BeforeTool(Box<dyn BeforeToolFn>),
178    /// Runs after every tool call, yielding a reference to the tool call and a mutable result
179    AfterTool(Box<dyn AfterToolFn>),
180    /// Runs after all tools have completed and a single completion has been made
181    AfterEach(Box<dyn AfterEachFn>),
182    /// Runs when a new message is added to the `AgentContext`, yielding a mutable reference to the
183    /// message. This is only triggered when the message is added by the agent.
184    OnNewMessage(Box<dyn MessageHookFn>),
185    /// Runs when the agent starts, either from pending or stopped
186    OnStart(Box<dyn OnStartFn>),
187    /// Runs when the agent stops
188    OnStop(Box<dyn OnStopFn>),
189}
190
191impl<F> BeforeAllFn for F where
192    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
193        + Send
194        + Sync
195        + DynClone
196{
197}
198
199impl<F> AfterEachFn for F where
200    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
201        + Send
202        + Sync
203        + DynClone
204{
205}
206
207impl<F> BeforeCompletionFn for F where
208    F: for<'a> Fn(
209            &'a Agent,
210            &mut ChatCompletionRequest,
211        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
212        + Send
213        + Sync
214        + DynClone
215{
216}
217
218impl<F> AfterCompletionFn for F where
219    F: for<'a> Fn(
220            &'a Agent,
221            &mut ChatCompletionResponse,
222        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
223        + Send
224        + Sync
225        + DynClone
226{
227}
228
229impl<F> BeforeToolFn for F where
230    F: for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
231        + Send
232        + Sync
233        + DynClone
234{
235}
236impl<F> AfterToolFn for F where
237    F: for<'tool> Fn(
238            &'tool Agent,
239            &ToolCall,
240            &'tool mut Result<ToolOutput, ToolError>,
241        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
242        + Send
243        + Sync
244        + DynClone
245{
246}
247
248impl<F> MessageHookFn for F where
249    F: for<'a> Fn(
250            &'a Agent,
251            &mut ChatMessage,
252        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
253        + Send
254        + Sync
255        + DynClone
256{
257}
258
259impl<F> OnStartFn for F where
260    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
261        + Send
262        + Sync
263        + DynClone
264{
265}
266
267impl<F> OnStopFn for F where
268    F: for<'a> Fn(
269            &'a Agent,
270            StopReason,
271            Option<&AgentError>,
272        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
273        + Send
274        + Sync
275        + DynClone
276{
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    // Each constructor takes the closure through its hook trait bound, which lets the compiler
    // infer the higher-ranked closure signature. It then boxes the closure into the matching
    // `Hook` variant.
    fn before_all_hook(hook: impl BeforeAllFn + 'static) -> Hook {
        Hook::BeforeAll(Box::new(hook))
    }

    fn before_completion_hook(hook: impl BeforeCompletionFn + 'static) -> Hook {
        Hook::BeforeCompletion(Box::new(hook))
    }

    fn after_completion_hook(hook: impl AfterCompletionFn + 'static) -> Hook {
        Hook::AfterCompletion(Box::new(hook))
    }

    fn before_tool_hook(hook: impl BeforeToolFn + 'static) -> Hook {
        Hook::BeforeTool(Box::new(hook))
    }

    fn after_tool_hook(hook: impl AfterToolFn + 'static) -> Hook {
        Hook::AfterTool(Box::new(hook))
    }

    fn after_each_hook(hook: impl AfterEachFn + 'static) -> Hook {
        Hook::AfterEach(Box::new(hook))
    }

    fn on_new_message_hook(hook: impl MessageHookFn + 'static) -> Hook {
        Hook::OnNewMessage(Box::new(hook))
    }

    fn on_start_hook(hook: impl OnStartFn + 'static) -> Hook {
        Hook::OnStart(Box::new(hook))
    }

    fn on_stop_hook(hook: impl OnStopFn + 'static) -> Hook {
        Hook::OnStop(Box::new(hook))
    }

    #[test]
    fn test_hooks_compile_sync_and_async() {
        Agent::builder()
            .before_all(|_| Box::pin(async { Ok(()) }))
            .on_start(|_| Box::pin(async { Ok(()) }))
            .before_completion(|_, _| Box::pin(async { Ok(()) }))
            .before_tool(|_, _| Box::pin(async { Ok(()) }))
            .after_tool(|_, _, _| Box::pin(async { Ok(()) }))
            .after_completion(|_, _| Box::pin(async { Ok(()) }));
    }

    /// Every hook trait, including `AfterEachFn`, `MessageHookFn` and `OnStopFn`, which the
    /// builder test above does not cover, accepts a plain closure. Boxed hooks must survive a
    /// `dyn_clone` clone with their variant intact.
    #[test]
    fn test_every_hook_variant_is_constructible_and_cloneable() {
        let cases = [
            (before_all_hook(|_| Box::pin(async { Ok(()) })), "BeforeAll"),
            (
                before_completion_hook(|_, _| Box::pin(async { Ok(()) })),
                "BeforeCompletion",
            ),
            (
                after_completion_hook(|_, _| Box::pin(async { Ok(()) })),
                "AfterCompletion",
            ),
            (before_tool_hook(|_, _| Box::pin(async { Ok(()) })), "BeforeTool"),
            (after_tool_hook(|_, _, _| Box::pin(async { Ok(()) })), "AfterTool"),
            (after_each_hook(|_| Box::pin(async { Ok(()) })), "AfterEach"),
            (
                on_new_message_hook(|_, _| Box::pin(async { Ok(()) })),
                "OnNewMessage",
            ),
            (on_start_hook(|_| Box::pin(async { Ok(()) })), "OnStart"),
            (on_stop_hook(|_, _, _| Box::pin(async { Ok(()) })), "OnStop"),
        ];

        for (hook, name) in &cases {
            let cloned = hook.clone();
            assert_eq!(cloned.to_string(), *name);
            assert_eq!(HookTypes::from(hook), HookTypes::from(&cloned));
            assert_eq!(HookTypes::from(hook).to_string(), *name);
        }
    }
}