swiftide_agents/
hooks.rs

//! Hooks are functions that are called at specific points in the agent lifecycle.
//!
//! Hooks are plain `Fn` closures (not async closures), so every hook has to return a boxed,
//! pinned async block itself, typically `Box::pin(async move { ... })`.
//!
//! # Example
//!
//! ```no_run
//! # use swiftide_core::{AgentContext, chat_completion::ChatMessage};
//! # use swiftide_agents::Agent;
//! # fn test() {
//! # let mut agent = swiftide_agents::Agent::builder();
//! agent.before_all(move |agent: &Agent| {
//!     Box::pin(async move {
//!         agent.context().add_message(ChatMessage::new_user("Hello, world")).await;
//!         Ok(())
//!     })
//! });
//! # }
//! ```
//!
//! Rust has a long-standing issue where returning an `impl Trait` that has lifetimes of its own
//! also captures the outer lifetimes; see [this issue](https://github.com/rust-lang/rust/issues/42940).
//!
//! This can bite when you write a method like `fn return_hook(&self) -> impl BeforeAllFn` on a
//! type that itself has a lifetime parameter.
//! The workaround is to give `self` and the returned hook explicit lifetimes, and require that
//! `self` outlives the hook.
//!
//! # Example
//!
//! ```no_run
//! # use swiftide_core::{AgentContext};
//! # use swiftide_agents::hooks::BeforeAllFn;
//! # use swiftide_agents::Agent;
//! struct SomeHook<'thing> {
//!     thing: &'thing str
//! }
//!
//! impl<'thing> SomeHook<'thing> {
//!     fn return_hook<'tool>(&'thing self) -> impl BeforeAllFn + 'tool where 'thing: 'tool {
//!         move |_: &Agent| {
//!             Box::pin(async move { Ok(()) })
//!         }
//!     }
//! }
//! ```
46use anyhow::Result;
47use std::{future::Future, pin::Pin};
48
49use dyn_clone::DynClone;
50use swiftide_core::chat_completion::{
51    errors::ToolError, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall,
52    ToolOutput,
53};
54
55use crate::Agent;
56
57pub trait BeforeAllFn:
58    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
59    + Send
60    + Sync
61    + DynClone
62{
63}
64
65dyn_clone::clone_trait_object!(BeforeAllFn);
66
67pub trait AfterEachFn:
68    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
69    + Send
70    + Sync
71    + DynClone
72{
73}
74
75dyn_clone::clone_trait_object!(AfterEachFn);
76
77pub trait BeforeCompletionFn:
78    for<'a> Fn(
79        &'a Agent,
80        &mut ChatCompletionRequest,
81    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
82    + Send
83    + Sync
84    + DynClone
85{
86}
87
88dyn_clone::clone_trait_object!(BeforeCompletionFn);
89
90pub trait AfterCompletionFn:
91    for<'a> Fn(
92        &'a Agent,
93        &mut ChatCompletionResponse,
94    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
95    + Send
96    + Sync
97    + DynClone
98{
99}
100
101dyn_clone::clone_trait_object!(AfterCompletionFn);
102
103/// Hooks that are called after each tool
104///
105pub trait AfterToolFn:
106    for<'tool> Fn(
107        &'tool Agent,
108        &ToolCall,
109        &'tool mut Result<ToolOutput, ToolError>,
110    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
111    + Send
112    + Sync
113    + DynClone
114{
115}
116
117dyn_clone::clone_trait_object!(AfterToolFn);
118
119/// Hooks that are called before each tool
120pub trait BeforeToolFn:
121    for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
122    + Send
123    + Sync
124    + DynClone
125{
126}
127
128dyn_clone::clone_trait_object!(BeforeToolFn);
129
130/// Hooks that are called when a new message is added to the `AgentContext`
131pub trait MessageHookFn:
132    for<'a> Fn(&'a Agent, &mut ChatMessage) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
133    + Send
134    + Sync
135    + DynClone
136{
137}
138
139dyn_clone::clone_trait_object!(MessageHookFn);
140
141/// Hooks that are called when the agent starts, either from pending or stopped
142pub trait OnStartFn:
143    for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
144    + Send
145    + Sync
146    + DynClone
147{
148}
149
150dyn_clone::clone_trait_object!(OnStartFn);
151
152/// Wrapper around the different types of hooks
153#[derive(Clone, strum_macros::EnumDiscriminants, strum_macros::Display)]
154#[strum_discriminants(name(HookTypes), derive(strum_macros::Display))]
155pub enum Hook {
156    /// Runs only once for the agent when it starts
157    BeforeAll(Box<dyn BeforeAllFn>),
158    /// Runs before every completion, yielding a mutable reference to the completion request
159    BeforeCompletion(Box<dyn BeforeCompletionFn>),
160    /// Runs after every completion, yielding a mutable reference to the completion response
161    AfterCompletion(Box<dyn AfterCompletionFn>),
162    /// Runs before every tool call, yielding a reference to the tool call
163    BeforeTool(Box<dyn BeforeToolFn>),
164    /// Runs after every tool call, yielding a reference to the tool call and a mutable result
165    AfterTool(Box<dyn AfterToolFn>),
166    /// Runs after all tools have completed and a single completion has been made
167    AfterEach(Box<dyn AfterEachFn>),
168    /// Runs when a new message is added to the `AgentContext`, yielding a mutable reference to the
169    /// message. This is only triggered when the message is added by the agent.
170    OnNewMessage(Box<dyn MessageHookFn>),
171    /// Runs when the agent starts, either from pending or stopped
172    OnStart(Box<dyn OnStartFn>),
173}
174
175impl<F> BeforeAllFn for F where
176    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
177        + Send
178        + Sync
179        + DynClone
180{
181}
182
183impl<F> AfterEachFn for F where
184    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
185        + Send
186        + Sync
187        + DynClone
188{
189}
190
191impl<F> BeforeCompletionFn for F where
192    F: for<'a> Fn(
193            &'a Agent,
194            &mut ChatCompletionRequest,
195        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
196        + Send
197        + Sync
198        + DynClone
199{
200}
201
202impl<F> AfterCompletionFn for F where
203    F: for<'a> Fn(
204            &'a Agent,
205            &mut ChatCompletionResponse,
206        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
207        + Send
208        + Sync
209        + DynClone
210{
211}
212
213impl<F> BeforeToolFn for F where
214    F: for<'a> Fn(&'a Agent, &ToolCall) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
215        + Send
216        + Sync
217        + DynClone
218{
219}
220impl<F> AfterToolFn for F where
221    F: for<'tool> Fn(
222            &'tool Agent,
223            &ToolCall,
224            &'tool mut Result<ToolOutput, ToolError>,
225        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'tool>>
226        + Send
227        + Sync
228        + DynClone
229{
230}
231
232impl<F> MessageHookFn for F where
233    F: for<'a> Fn(
234            &'a Agent,
235            &mut ChatMessage,
236        ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
237        + Send
238        + Sync
239        + DynClone
240{
241}
242
243impl<F> OnStartFn for F where
244    F: for<'a> Fn(&'a Agent) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
245        + Send
246        + Sync
247        + DynClone
248{
249}
250
#[cfg(test)]
mod tests {
    use crate::Agent;

    /// Compile-level check that closures returning `Box::pin(async { .. })` are accepted by the
    /// builder's hook methods.
    ///
    /// NOTE(review): despite the name, only async-block hooks are exercised, and the
    /// `AfterEach` / `OnNewMessage` hooks are not registered here.
    #[test]
    fn test_hooks_compile_sync_and_async() {
        Agent::builder()
            .before_all(|_agent| Box::pin(async { Ok(()) }))
            .on_start(|_agent| Box::pin(async { Ok(()) }))
            .before_completion(|_agent, _request| Box::pin(async { Ok(()) }))
            .before_tool(|_agent, _tool_call| Box::pin(async { Ok(()) }))
            .after_tool(|_agent, _tool_call, _result| Box::pin(async { Ok(()) }))
            .after_completion(|_agent, _response| Box::pin(async { Ok(()) }));
    }
}