swiftide_agents/
hooks.rs

//! Hooks are functions that are called at specific points in the agent lifecycle.
//!
//! Hooks are plain closures rather than async closures, so each hook has to return a boxed,
//! pinned async block itself, typically `Box::pin(async move { ... })`.
6//!
7//! # Example
8//!
9//! ```no_run
10//! # use swiftide_core::{AgentContext, chat_completion::ChatMessage};
11//! # use swiftide_agents::Agent;
12//! # fn test() {
13//! # let mut agent = swiftide_agents::Agent::builder();
14//! agent.before_all(move |agent: &Agent| {
15//!     Box::pin(async move {
16//!         agent.context().add_message(ChatMessage::new_user("Hello, world")).await;
17//!         Ok(())
18//!     })
19//! });
20//! # }
21//! ```
//!
//! Rust has a long-standing issue where returning an `impl Trait` that has lifetimes also
//! captures outer lifetimes; see
//! [rust-lang/rust#42940](https://github.com/rust-lang/rust/issues/42940).
//!
//! This can happen if you write a method like `fn return_hook(&self) -> impl BeforeAllFn` (or any
//! other hook trait), where the owner also has a lifetime. The workaround is to give both `self`
//! and the returned hook explicit lifetimes, and to require that `self` outlives the hook
//! (`'thing: 'tool` in the example below).
//!
29//! # Example
30//!
31//! ```no_run
32//! # use swiftide_core::{AgentContext};
33//! # use swiftide_agents::hooks::BeforeAllFn;
34//! # use swiftide_agents::Agent;
35//! struct SomeHook<'thing> {
36//!    thing: &'thing str
37//! }
38//!
39//! impl<'thing> SomeHook<'thing> {
40//!    fn return_hook<'tool>(&'thing self) -> impl BeforeAllFn + 'tool where 'thing: 'tool {
41//!     move |_: &Agent| {
42//!      Box::pin(async move {{ Ok(())}})
43//!     }
44//!   }
//! }
//! ```
46use anyhow::Result;
47use std::{future::Future, pin::Pin};
48
49use dyn_clone::DynClone;
50use swiftide_core::chat_completion::{
51    errors::ToolError, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall,
52    ToolOutput,
53};
54
55use crate::Agent;
56
57pub trait BeforeAllFn:
58    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
59    + Send
60    + Sync
61    + DynClone
62{
63}
64
65dyn_clone::clone_trait_object!(BeforeAllFn);
66
67pub trait AfterEachFn:
68    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
69    + Send
70    + Sync
71    + DynClone
72{
73}
74
75dyn_clone::clone_trait_object!(AfterEachFn);
76
77pub trait BeforeCompletionFn:
78    for<'a> Fn(
79        &'a Agent,
80        &mut ChatCompletionRequest,
81    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
82    + Send
83    + Sync
84    + DynClone
85{
86}
87
88dyn_clone::clone_trait_object!(BeforeCompletionFn);
89
90pub trait AfterCompletionFn:
91    for<'a> Fn(
92        &'a Agent,
93        &mut ChatCompletionResponse,
94    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
95    + Send
96    + Sync
97    + DynClone
98{
99}
100
101dyn_clone::clone_trait_object!(AfterCompletionFn);
102
103/// Hooks that are called after each tool
104pub trait AfterToolFn:
105    for<'tool> Fn(
106        &'tool Agent,
107        &ToolCall,
108        &'tool mut Result<ToolOutput, ToolError>,
109    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
110    + Send
111    + Sync
112    + DynClone
113{
114}
115
116dyn_clone::clone_trait_object!(AfterToolFn);
117
118/// Hooks that are called before each tool
119pub trait BeforeToolFn:
120    for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
121    + Send
122    + Sync
123    + DynClone
124{
125}
126
127dyn_clone::clone_trait_object!(BeforeToolFn);
128
129/// Hooks that are called when a new message is added to the `AgentContext`
130pub trait MessageHookFn:
131    for<'a> Fn(&'a Agent, &mut ChatMessage) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
132    + Send
133    + Sync
134    + DynClone
135{
136}
137
138dyn_clone::clone_trait_object!(MessageHookFn);
139
140/// Hooks that are called when the agent starts, either from pending or stopped
141pub trait OnStartFn:
142    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
143    + Send
144    + Sync
145    + DynClone
146{
147}
148
149dyn_clone::clone_trait_object!(OnStartFn);
150
/// Wrapper around the different types of hooks.
///
/// Each variant boxes the trait object for one hook kind. `Clone` can be derived because every
/// boxed hook trait object in this module is made cloneable with `dyn_clone::clone_trait_object!`.
/// `strum` also generates the `HookTypes` discriminant enum (one variant per hook kind, deriving
/// `Display`) alongside a `Display` implementation for `Hook` itself.
#[derive(Clone, strum_macros::EnumDiscriminants, strum_macros::Display)]
#[strum_discriminants(name(HookTypes), derive(strum_macros::Display))]
pub enum Hook {
    /// Runs only once for the agent when it starts.
    BeforeAll(Box<dyn BeforeAllFn>),
    /// Runs before every completion, yielding a mutable reference to the completion request.
    BeforeCompletion(Box<dyn BeforeCompletionFn>),
    /// Runs after every completion, yielding a mutable reference to the completion response.
    AfterCompletion(Box<dyn AfterCompletionFn>),
    /// Runs before every tool call, yielding a shared reference to the tool call.
    BeforeTool(Box<dyn BeforeToolFn>),
    /// Runs after every tool call, yielding a shared reference to the tool call and a mutable
    /// reference to the tool's result.
    AfterTool(Box<dyn AfterToolFn>),
    /// Runs after all tools have completed and a single completion has been made.
    AfterEach(Box<dyn AfterEachFn>),
    /// Runs when a new message is added to the `AgentContext`, yielding a mutable reference to the
    /// message. This is only triggered when the message is added by the agent.
    OnNewMessage(Box<dyn MessageHookFn>),
    /// Runs when the agent starts, either from pending or from stopped.
    OnStart(Box<dyn OnStartFn>),
}
173
174impl<F> BeforeAllFn for F where
175    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
176        + Send
177        + Sync
178        + DynClone
179{
180}
181
182impl<F> AfterEachFn for F where
183    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
184        + Send
185        + Sync
186        + DynClone
187{
188}
189
190impl<F> BeforeCompletionFn for F where
191    F: for<'a> Fn(
192            &'a Agent,
193            &mut ChatCompletionRequest,
194        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
195        + Send
196        + Sync
197        + DynClone
198{
199}
200
201impl<F> AfterCompletionFn for F where
202    F: for<'a> Fn(
203            &'a Agent,
204            &mut ChatCompletionResponse,
205        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
206        + Send
207        + Sync
208        + DynClone
209{
210}
211
212impl<F> BeforeToolFn for F where
213    F: for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
214        + Send
215        + Sync
216        + DynClone
217{
218}
219impl<F> AfterToolFn for F where
220    F: for<'tool> Fn(
221            &'tool Agent,
222            &ToolCall,
223            &'tool mut Result<ToolOutput, ToolError>,
224        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
225        + Send
226        + Sync
227        + DynClone
228{
229}
230
231impl<F> MessageHookFn for F where
232    F: for<'a> Fn(
233            &'a Agent,
234            &mut ChatMessage,
235        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
236        + Send
237        + Sync
238        + DynClone
239{
240}
241
242impl<F> OnStartFn for F where
243    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
244        + Send
245        + Sync
246        + DynClone
247{
248}
249
#[cfg(test)]
mod tests {
    use crate::Agent;

    /// Compile-time check that every hook kind can be registered on the builder with a plain
    /// closure returning a boxed, pinned async block.
    #[test]
    fn test_hooks_compile_sync_and_async() {
        Agent::builder()
            .before_all(|_| Box::pin(async { Ok(()) }))
            .on_start(|_| Box::pin(async { Ok(()) }))
            .before_completion(|_, _| Box::pin(async { Ok(()) }))
            .before_tool(|_, _| Box::pin(async { Ok(()) }))
            .after_tool(|_, _, _| Box::pin(async { Ok(()) }))
            .after_completion(|_, _| Box::pin(async { Ok(()) }))
            .after_each(|_| Box::pin(async { Ok(()) }))
            .on_new_message(|_, _| Box::pin(async { Ok(()) }));
    }

    /// The `AfterTool` signature ties the mutable result borrow to the returned future's
    /// lifetime, so the future itself may hold on to (and mutate) the tool result.
    #[test]
    fn test_after_tool_future_can_hold_result() {
        Agent::builder().after_tool(|_, _, result| {
            Box::pin(async move {
                // Moving the borrow into the future only compiles because of the shared lifetime.
                let _result = result;
                Ok(())
            })
        });
    }
}