swiftide_agents/
agent.rs

1#![allow(dead_code)]
2use crate::{
3    default_context::DefaultContext,
4    hooks::{
5        AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
6        Hook, HookTypes, MessageHookFn, OnStartFn,
7    },
8    state,
9    system_prompt::SystemPrompt,
10    tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
11};
12use std::{
13    collections::{HashMap, HashSet},
14    hash::{DefaultHasher, Hash as _, Hasher as _},
15    sync::Arc,
16};
17
18use anyhow::Result;
19use derive_builder::Builder;
20use swiftide_core::{
21    chat_completion::{
22        ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
23    },
24    prompt::Prompt,
25    AgentContext,
26};
27use tracing::{debug, Instrument};
28
29/// Agents are the main interface for building agentic systems.
30///
31/// Construct agents by calling the builder, setting an llm, configure hooks, tools and other
32/// customizations.
33///
34/// # Important defaults
35///
36/// - The default context is the `DefaultContext`, executing tools locally with the `LocalExecutor`.
37/// - A default `stop` tool is provided for agents to explicitly stop if needed
38/// - The default `SystemPrompt` instructs the agent with chain of thought and some common
39///   safeguards, but is otherwise quite bare. In a lot of cases this can be sufficient.
40#[derive(Clone, Builder)]
41pub struct Agent {
42    /// Hooks are functions that are called at specific points in the agent's lifecycle.
43    #[builder(default, setter(into))]
44    pub(crate) hooks: Vec<Hook>,
45    /// The context in which the agent operates, by default this is the `DefaultContext`.
46    #[builder(
47        setter(custom),
48        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
49    )]
50    pub(crate) context: Arc<dyn AgentContext>,
51    /// Tools the agent can use
52    #[builder(default = Agent::default_tools(), setter(custom))]
53    pub(crate) tools: HashSet<Box<dyn Tool>>,
54
55    /// The language model that the agent uses for completion.
56    #[builder(setter(custom))]
57    pub(crate) llm: Box<dyn ChatCompletion>,
58
59    /// System prompt for the agent when it starts
60    ///
61    /// Some agents profit significantly from a tailored prompt. But it is not always needed.
62    ///
63    /// See [`SystemPrompt`] for an opiniated, customizable system prompt.
64    ///
65    /// Swiftide provides a default system prompt for all agents.
66    ///
67    /// # Example
68    ///
69    /// ```no_run
70    /// # use swiftide_agents::system_prompt::SystemPrompt;
71    /// # use swiftide_agents::Agent;
72    /// Agent::builder()
73    ///     .system_prompt(
74    ///         SystemPrompt::builder().role("You are an expert engineer")
75    ///         .build().unwrap())
76    ///     .build().unwrap();
77    /// ```
78    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
79    pub(crate) system_prompt: Option<Prompt>,
80
81    /// Initial state of the agent
82    #[builder(private, default = state::State::default())]
83    pub(crate) state: state::State,
84
85    /// Optional limit on the amount of loops the agent can run.
86    /// The counter is reset when the agent is stopped.
87    #[builder(default, setter(strip_option))]
88    pub(crate) limit: Option<usize>,
89
90    /// The maximum amount of times the failed output of a tool will be send
91    /// to an LLM before the agent stops. Defaults to 3.
92    ///
93    /// LLMs sometimes send missing arguments, or a tool might actually fail, but retrying could be
94    /// worth while. If the limit is not reached, the agent will send the formatted error back to
95    /// the LLM.
96    ///
97    /// The limit is hashed based on the tool call name and arguments, so the limit is per tool
98    /// call.
99    ///
100    /// This limit is _not_ reset when the agent is stopped.
101    #[builder(default = 3)]
102    pub(crate) tool_retry_limit: usize,
103
104    /// Internally tracks the amount of times a tool has been retried. The key is a hash based on
105    /// the name and args of the tool.
106    #[builder(private, default)]
107    pub(crate) tool_retries_counter: HashMap<u64, usize>,
108}
109
110impl std::fmt::Debug for Agent {
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        f.debug_struct("Agent")
113            // display hooks as a list of type: number of hooks
114            .field(
115                "hooks",
116                &self
117                    .hooks
118                    .iter()
119                    .map(std::string::ToString::to_string)
120                    .collect::<Vec<_>>(),
121            )
122            .field(
123                "tools",
124                &self
125                    .tools
126                    .iter()
127                    .map(swiftide_core::Tool::name)
128                    .collect::<Vec<_>>(),
129            )
130            .field("llm", &"Box<dyn ChatCompletion>")
131            .field("state", &self.state)
132            .finish()
133    }
134}
135
136impl AgentBuilder {
137    /// The context in which the agent operates, by default this is the `DefaultContext`.
138    pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
139    where
140        Self: Clone,
141    {
142        self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
143        self
144    }
145
146    /// Disable the system prompt.
147    pub fn no_system_prompt(&mut self) -> &mut Self {
148        self.system_prompt = Some(None);
149
150        self
151    }
152
153    /// Add a hook to the agent.
154    pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
155        let hooks = self.hooks.get_or_insert_with(Vec::new);
156        hooks.push(hook);
157
158        self
159    }
160
161    /// Add a hook that runs once, before all completions. Even if the agent is paused and resumed,
162    /// before all will not trigger again.
163    pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
164        self.add_hook(Hook::BeforeAll(Box::new(hook)))
165    }
166
167    /// Add a hook that runs once, when the agent starts. This hook also runs if the agent stopped
168    /// and then starts again. The hook runs after any `before_all` hooks and before the
169    /// `before_completion` hooks.
170    pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
171        self.add_hook(Hook::OnStart(Box::new(hook)))
172    }
173
174    /// Add a hook that runs before each completion.
175    pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
176        self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
177    }
178
179    /// Add a hook that runs after each tool. The `Result<ToolOutput, ToolError>` is provided
180    /// as mut, so the tool output can be fully modified.
181    ///
182    /// The `ToolOutput` also references the original `ToolCall`, allowing you to match at runtime
183    /// what tool to interact with.
184    pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
185        self.add_hook(Hook::AfterTool(Box::new(hook)))
186    }
187
188    /// Add a hook that runs before each tool. Yields an immutable reference to the `ToolCall`.
189    pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
190        self.add_hook(Hook::BeforeTool(Box::new(hook)))
191    }
192
193    /// Add a hook that runs after each completion, before tool invocation and/or new messages.
194    pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
195        self.add_hook(Hook::AfterCompletion(Box::new(hook)))
196    }
197
198    /// Add a hook that runs after each completion, after tool invocations, right before a new loop
199    /// might start
200    pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
201        self.add_hook(Hook::AfterEach(Box::new(hook)))
202    }
203
204    /// Add a hook that runs when a new message is added to the context. Note that each tool adds a
205    /// separate message.
206    pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
207        self.add_hook(Hook::OnNewMessage(Box::new(hook)))
208    }
209
210    /// Set the LLM for the agent. An LLM must implement the `ChatCompletion` trait.
211    pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
212        let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
213
214        self.llm = Some(boxed);
215        self
216    }
217
218    /// Define the available tools for the agent. Tools must implement the `Tool` trait.
219    ///
220    /// See the [tool attribute macro](`swiftide_macros::tool`) and the [tool derive
221    /// macro](`swiftide_macros::Tool`) for easy ways to create (many) tools.
222    pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
223    where
224        TOOL: Into<Box<dyn Tool>>,
225    {
226        self.tools = Some(
227            tools
228                .into_iter()
229                .map(Into::into)
230                .chain(Agent::default_tools())
231                .collect(),
232        );
233        self
234    }
235}
236
237impl Agent {
238    /// Build a new agent
239    pub fn builder() -> AgentBuilder {
240        AgentBuilder::default()
241    }
242}
243
244impl Agent {
245    /// Default tools for the agent that it always includes
246    fn default_tools() -> HashSet<Box<dyn Tool>> {
247        HashSet::from([Box::new(Stop::default()) as Box<dyn Tool>])
248    }
249
250    /// Run the agent with a user message. The agent will loop completions, make tool calls, until
251    /// no new messages are available.
252    #[tracing::instrument(skip_all, name = "agent.query")]
253    pub async fn query(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
254        self.run_agent(Some(query.into()), false).await
255    }
256
257    /// Run the agent with a user message once.
258    #[tracing::instrument(skip_all, name = "agent.query_once")]
259    pub async fn query_once(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
260        self.run_agent(Some(query.into()), true).await
261    }
262
263    /// Run the agent with without user message. The agent will loop completions, make tool calls,
264    /// until no new messages are available.
265    #[tracing::instrument(skip_all, name = "agent.run")]
266    pub async fn run(&mut self) -> Result<()> {
267        self.run_agent(None, false).await
268    }
269
270    /// Run the agent with without user message. The agent will loop completions, make tool calls,
271    /// until
272    #[tracing::instrument(skip_all, name = "agent.run_once")]
273    pub async fn run_once(&mut self) -> Result<()> {
274        self.run_agent(None, true).await
275    }
276
277    /// Retrieve the message history of the agent
278    pub async fn history(&self) -> Vec<ChatMessage> {
279        self.context.history().await
280    }
281
282    async fn run_agent(&mut self, maybe_query: Option<String>, just_once: bool) -> Result<()> {
283        if self.state.is_running() {
284            anyhow::bail!("Agent is already running");
285        }
286
287        if self.state.is_pending() {
288            if let Some(system_prompt) = &self.system_prompt {
289                self.context
290                    .add_messages(vec![ChatMessage::System(system_prompt.render().await?)])
291                    .await;
292            }
293            for hook in self.hooks_by_type(HookTypes::BeforeAll) {
294                if let Hook::BeforeAll(hook) = hook {
295                    let span = tracing::info_span!(
296                        "hook",
297                        "otel.name" = format!("hook.{}", HookTypes::AfterTool)
298                    );
299                    tracing::info!("Calling {} hook", HookTypes::AfterTool);
300                    hook(self).instrument(span.or_current()).await?;
301                }
302            }
303        }
304
305        for hook in self.hooks_by_type(HookTypes::OnStart) {
306            if let Hook::OnStart(hook) = hook {
307                let span = tracing::info_span!(
308                    "hook",
309                    "otel.name" = format!("hook.{}", HookTypes::OnStart)
310                );
311                tracing::info!("Calling {} hook", HookTypes::OnStart);
312                hook(self).instrument(span.or_current()).await?;
313            }
314        }
315
316        self.state = state::State::Running;
317
318        if let Some(query) = maybe_query {
319            self.context.add_message(ChatMessage::User(query)).await;
320        }
321
322        let mut loop_counter = 0;
323
324        while let Some(messages) = self.context.next_completion().await {
325            if let Some(limit) = self.limit {
326                if loop_counter >= limit {
327                    tracing::warn!("Agent loop limit reached");
328                    break;
329                }
330            }
331            let result = self.run_completions(&messages).await;
332
333            if let Err(err) = result {
334                self.stop();
335                tracing::error!(error = ?err, "Agent stopped with error {err}");
336                return Err(err);
337            }
338
339            if just_once || self.state.is_stopped() {
340                break;
341            }
342            loop_counter += 1;
343        }
344
345        // If there are no new messages, ensure we update our state
346        self.stop();
347
348        Ok(())
349    }
350
351    #[tracing::instrument(skip_all, err)]
352    async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<()> {
353        debug!(
354            "Running completion for agent with {} messages",
355            messages.len()
356        );
357
358        let mut chat_completion_request = ChatCompletionRequest::builder()
359            .messages(messages)
360            .tools_spec(
361                self.tools
362                    .iter()
363                    .map(swiftide_core::Tool::tool_spec)
364                    .collect::<HashSet<_>>(),
365            )
366            .build()?;
367
368        for hook in self.hooks_by_type(HookTypes::BeforeCompletion) {
369            if let Hook::BeforeCompletion(hook) = hook {
370                let span = tracing::info_span!(
371                    "hook",
372                    "otel.name" = format!("hook.{}", HookTypes::BeforeCompletion)
373                );
374                tracing::info!("Calling {} hook", HookTypes::BeforeCompletion);
375                hook(&*self, &mut chat_completion_request)
376                    .instrument(span.or_current())
377                    .await?;
378            }
379        }
380
381        debug!(
382            "Calling LLM with the following new messages:\n {}",
383            self.context
384                .current_new_messages()
385                .await
386                .iter()
387                .map(ToString::to_string)
388                .collect::<Vec<_>>()
389                .join(",\n")
390        );
391
392        let mut response = self.llm.complete(&chat_completion_request).await?;
393
394        for hook in self.hooks_by_type(HookTypes::AfterCompletion) {
395            if let Hook::AfterCompletion(hook) = hook {
396                let span = tracing::info_span!(
397                    "hook",
398                    "otel.name" = format!("hook.{}", HookTypes::AfterCompletion)
399                );
400                tracing::info!("Calling {} hook", HookTypes::AfterCompletion);
401                hook(&*self, &mut response)
402                    .instrument(span.or_current())
403                    .await?;
404            }
405        }
406        self.add_message(ChatMessage::Assistant(
407            response.message,
408            response.tool_calls.clone(),
409        ))
410        .await?;
411
412        if let Some(tool_calls) = response.tool_calls {
413            self.invoke_tools(tool_calls).await?;
414        }
415
416        for hook in self.hooks_by_type(HookTypes::AfterEach) {
417            if let Hook::AfterEach(hook) = hook {
418                let span = tracing::info_span!(
419                    "hook",
420                    "otel.name" = format!("hook.{}", HookTypes::AfterEach)
421                );
422                tracing::info!("Calling {} hook", HookTypes::AfterEach);
423                hook(&*self).instrument(span.or_current()).await?;
424            }
425        }
426
427        Ok(())
428    }
429
430    async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<()> {
431        debug!("LLM returned tool calls: {:?}", tool_calls);
432
433        let mut handles = vec![];
434        for tool_call in tool_calls {
435            let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
436                tracing::warn!("Tool {} not found", tool_call.name());
437                continue;
438            };
439            tracing::info!("Calling tool `{}`", tool_call.name());
440
441            let tool_args = tool_call.args().map(String::from);
442            let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
443
444            for hook in self.hooks_by_type(HookTypes::BeforeTool) {
445                if let Hook::BeforeTool(hook) = hook {
446                    let span = tracing::info_span!(
447                        "hook",
448                        "otel.name" = format!("hook.{}", HookTypes::BeforeTool)
449                    );
450                    tracing::info!("Calling {} hook", HookTypes::BeforeTool);
451                    hook(&*self, &tool_call)
452                        .instrument(span.or_current())
453                        .await?;
454                }
455            }
456
457            let tool_span = tracing::info_span!(
458                "tool",
459                "otel.name" = format!("tool.{}", tool.name().as_ref())
460            );
461
462            let handle = tokio::spawn(async move {
463                    let tool_args = ArgPreprocessor::preprocess(tool_args.as_deref());
464                    let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
465
466                    tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name = tool.name().as_ref(), "Completed tool call");
467
468                    Ok(output)
469                }.instrument(tool_span.or_current()));
470
471            handles.push((handle, tool_call));
472        }
473
474        for (handle, tool_call) in handles {
475            let mut output = handle.await?;
476
477            // Invoking hooks feels too verbose and repetitive
478            for hook in self.hooks_by_type(HookTypes::AfterTool) {
479                if let Hook::AfterTool(hook) = hook {
480                    let span = tracing::info_span!(
481                        "hook",
482                        "otel.name" = format!("hook.{}", HookTypes::AfterTool)
483                    );
484                    tracing::info!("Calling {} hook", HookTypes::AfterTool);
485                    hook(&*self, &tool_call, &mut output)
486                        .instrument(span.or_current())
487                        .await?;
488                }
489            }
490
491            if let Err(error) = output {
492                let stop = self.tool_calls_over_limit(&tool_call);
493                if stop {
494                    tracing::error!(
495                        "Tool call failed, retry limit reached, stopping agent: {err}",
496                        err = error
497                    );
498                } else {
499                    tracing::warn!(
500                        error = error.to_string(),
501                        tool_call = ?tool_call,
502                        "Tool call failed, retrying",
503                    );
504                }
505                self.add_message(ChatMessage::ToolOutput(
506                    tool_call,
507                    ToolOutput::Fail(error.to_string()),
508                ))
509                .await?;
510                if stop {
511                    self.stop();
512                    return Err(error.into());
513                }
514                continue;
515            }
516
517            let output = output?;
518            self.handle_control_tools(&output);
519            self.add_message(ChatMessage::ToolOutput(tool_call, output))
520                .await?;
521        }
522
523        Ok(())
524    }
525
526    fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
527        self.hooks
528            .iter()
529            .filter(|h| hook_type == (*h).into())
530            .collect()
531    }
532
533    fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
534        self.tools
535            .iter()
536            .find(|tool| tool.name() == tool_name)
537            .cloned()
538    }
539
540    // Handle any tool specific output (e.g. stop)
541    fn handle_control_tools(&mut self, output: &ToolOutput) {
542        if let ToolOutput::Stop = output {
543            tracing::warn!("Stop tool called, stopping agent");
544            self.stop();
545        }
546    }
547
548    fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
549        let mut s = DefaultHasher::new();
550        tool_call.hash(&mut s);
551        let hash = s.finish();
552
553        if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
554            let val = *retries >= self.tool_retry_limit;
555            *retries += 1;
556            val
557        } else {
558            self.tool_retries_counter.insert(hash, 1);
559            false
560        }
561    }
562
563    /// Add a message to the agent's context
564    ///
565    /// This will trigger a `OnNewMessage` hook if its present.
566    ///
567    /// If you want to add a message without triggering the hook, use the context directly.
568    #[tracing::instrument(skip_all, fields(message = message.to_string()))]
569    pub async fn add_message(&self, mut message: ChatMessage) -> Result<()> {
570        for hook in self.hooks_by_type(HookTypes::OnNewMessage) {
571            if let Hook::OnNewMessage(hook) = hook {
572                let span = tracing::info_span!(
573                    "hook",
574                    "otel.name" = format!("hook.{}", HookTypes::OnNewMessage)
575                );
576                if let Err(err) = hook(self, &mut message).instrument(span.or_current()).await {
577                    tracing::error!(
578                        "Error in {hooktype} hook: {err}",
579                        hooktype = HookTypes::OnNewMessage,
580                    );
581                }
582            }
583        }
584        self.context.add_message(message).await;
585        Ok(())
586    }
587
588    /// Tell the agent to stop. It will finish it's current loop and then stop.
589    pub fn stop(&mut self) {
590        self.state = state::State::Stopped;
591    }
592
593    /// Access the agent's context
594    pub fn context(&self) -> &dyn AgentContext {
595        &self.context
596    }
597
598    /// The agent is still running
599    pub fn is_running(&self) -> bool {
600        self.state.is_running()
601    }
602
603    /// The agent stopped
604    pub fn is_stopped(&self) -> bool {
605        self.state.is_stopped()
606    }
607
608    /// The agent has not (ever) started
609    pub fn is_pending(&self) -> bool {
610        self.state.is_pending()
611    }
612}
613
614#[cfg(test)]
615mod tests {
616
617    use serde::ser::Error;
618    use swiftide_core::chat_completion::errors::ToolError;
619    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
620    use swiftide_core::test_utils::MockChatCompletion;
621
622    use super::*;
623    use crate::{
624        assistant, chat_request, chat_response, summary, system, tool_failed, tool_output, user,
625    };
626
627    use crate::test_utils::{MockHook, MockTool};
628
629    #[test_log::test(tokio::test)]
630    async fn test_agent_builder_defaults() {
631        // Create a prompt
632        let mock_llm = MockChatCompletion::new();
633
634        // Build the agent
635        let agent = Agent::builder().llm(&mock_llm).build().unwrap();
636
637        // Check that the context is the default context
638
639        // Check that the default tools are added
640        assert!(agent.find_tool_by_name("stop").is_some());
641
642        // Check it does not allow duplicates
643        let agent = Agent::builder()
644            .tools([Stop::default(), Stop::default()])
645            .llm(&mock_llm)
646            .build()
647            .unwrap();
648
649        assert_eq!(agent.tools.len(), 1);
650
651        // It should include the default tool if a different tool is provided
652        let agent = Agent::builder()
653            .tools([MockTool::new("mock_tool")])
654            .llm(&mock_llm)
655            .build()
656            .unwrap();
657
658        assert_eq!(agent.tools.len(), 2);
659        assert!(agent.find_tool_by_name("mock_tool").is_some());
660        assert!(agent.find_tool_by_name("stop").is_some());
661
662        assert!(agent.context().history().await.is_empty());
663    }
664
665    #[test_log::test(tokio::test)]
666    async fn test_agent_tool_calling_loop() {
667        let prompt = "Write a poem";
668        let mock_llm = MockChatCompletion::new();
669        let mock_tool = MockTool::new("mock_tool");
670
671        let chat_request = chat_request! {
672            user!("Write a poem");
673
674            tools = [mock_tool.clone()]
675        };
676
677        let mock_tool_response = chat_response! {
678            "Roses are red";
679            tool_calls = ["mock_tool"]
680
681        };
682
683        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
684
685        let chat_request = chat_request! {
686            user!("Write a poem"),
687            assistant!("Roses are red", ["mock_tool"]),
688            tool_output!("mock_tool", "Great!");
689
690            tools = [mock_tool.clone()]
691        };
692
693        let stop_response = chat_response! {
694            "Roses are red";
695            tool_calls = ["stop"]
696        };
697
698        mock_llm.expect_complete(chat_request, Ok(stop_response));
699        mock_tool.expect_invoke_ok("Great!".into(), None);
700
701        let mut agent = Agent::builder()
702            .tools([mock_tool])
703            .llm(&mock_llm)
704            .no_system_prompt()
705            .build()
706            .unwrap();
707
708        agent.query(prompt).await.unwrap();
709    }
710
711    #[test_log::test(tokio::test)]
712    async fn test_agent_tool_run_once() {
713        let prompt = "Write a poem";
714        let mock_llm = MockChatCompletion::new();
715        let mock_tool = MockTool::default();
716
717        let chat_request = chat_request! {
718            system!("My system prompt"),
719            user!("Write a poem");
720
721            tools = [mock_tool.clone()]
722        };
723
724        let mock_tool_response = chat_response! {
725            "Roses are red";
726            tool_calls = ["mock_tool"]
727
728        };
729
730        mock_tool.expect_invoke_ok("Great!".into(), None);
731        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
732
733        let mut agent = Agent::builder()
734            .tools([mock_tool])
735            .system_prompt("My system prompt")
736            .llm(&mock_llm)
737            .build()
738            .unwrap();
739
740        agent.query_once(prompt).await.unwrap();
741    }
742
743    #[test_log::test(tokio::test(flavor = "multi_thread"))]
744    async fn test_multiple_tool_calls() {
745        let prompt = "Write a poem";
746        let mock_llm = MockChatCompletion::new();
747        let mock_tool = MockTool::new("mock_tool1");
748        let mock_tool2 = MockTool::new("mock_tool2");
749
750        let chat_request = chat_request! {
751            system!("My system prompt"),
752            user!("Write a poem");
753
754
755
756            tools = [mock_tool.clone(), mock_tool2.clone()]
757        };
758
759        let mock_tool_response = chat_response! {
760            "Roses are red";
761
762            tool_calls = ["mock_tool1", "mock_tool2"]
763
764        };
765
766        dbg!(&chat_request);
767        mock_tool.expect_invoke_ok("Great!".into(), None);
768        mock_tool2.expect_invoke_ok("Great!".into(), None);
769        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
770
771        let chat_request = chat_request! {
772            system!("My system prompt"),
773            user!("Write a poem"),
774            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
775            tool_output!("mock_tool1", "Great!"),
776            tool_output!("mock_tool2", "Great!");
777
778            tools = [mock_tool.clone(), mock_tool2.clone()]
779        };
780
781        let mock_tool_response = chat_response! {
782            "Ok!";
783
784            tool_calls = ["stop"]
785
786        };
787
788        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
789
790        let mut agent = Agent::builder()
791            .tools([mock_tool, mock_tool2])
792            .system_prompt("My system prompt")
793            .llm(&mock_llm)
794            .build()
795            .unwrap();
796
797        agent.query(prompt).await.unwrap();
798    }
799
800    #[test_log::test(tokio::test)]
801    async fn test_agent_state_machine() {
802        let prompt = "Write a poem";
803        let mock_llm = MockChatCompletion::new();
804
805        let chat_request = chat_request! {
806            user!("Write a poem");
807            tools = []
808        };
809        let mock_tool_response = chat_response! {
810            "Roses are red";
811            tool_calls = []
812        };
813
814        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
815        let mut agent = Agent::builder()
816            .llm(&mock_llm)
817            .no_system_prompt()
818            .build()
819            .unwrap();
820
821        // Agent has never run and is pending
822        assert!(agent.state.is_pending());
823        agent.query_once(prompt).await.unwrap();
824
825        // Agent is stopped, there might be more messages
826        assert!(agent.state.is_stopped());
827    }
828
829    #[test_log::test(tokio::test)]
830    async fn test_summary() {
831        let prompt = "Write a poem";
832        let mock_llm = MockChatCompletion::new();
833
834        let mock_tool_response = chat_response! {
835            "Roses are red";
836            tool_calls = []
837
838        };
839
840        let expected_chat_request = chat_request! {
841            system!("My system prompt"),
842            user!("Write a poem");
843
844            tools = []
845        };
846
847        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
848
849        let mut agent = Agent::builder()
850            .system_prompt("My system prompt")
851            .llm(&mock_llm)
852            .build()
853            .unwrap();
854
855        agent.query_once(prompt).await.unwrap();
856
857        agent
858            .context
859            .add_message(ChatMessage::new_summary("Summary"))
860            .await;
861
862        let expected_chat_request = chat_request! {
863            system!("My system prompt"),
864            summary!("Summary"),
865            user!("Write another poem");
866            tools = []
867        };
868        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
869
870        agent.query_once("Write another poem").await.unwrap();
871
872        agent
873            .context
874            .add_message(ChatMessage::new_summary("Summary 2"))
875            .await;
876
877        let expected_chat_request = chat_request! {
878            system!("My system prompt"),
879            summary!("Summary 2"),
880            user!("Write a third poem");
881            tools = []
882        };
883        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));
884
885        agent.query_once("Write a third poem").await.unwrap();
886    }
887
888    #[test_log::test(tokio::test)]
889    async fn test_agent_hooks() {
890        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
891        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
892        let mock_before_completion = MockHook::new("before_completion")
893            .expect_calls(2)
894            .to_owned();
895        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
896        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
897        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
898
899        // Once for mock tool and once for stop
900        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
901        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();
902
903        let prompt = "Write a poem";
904        let mock_llm = MockChatCompletion::new();
905        let mock_tool = MockTool::default();
906
907        let chat_request = chat_request! {
908            user!("Write a poem");
909
910            tools = [mock_tool.clone()]
911        };
912
913        let mock_tool_response = chat_response! {
914            "Roses are red";
915            tool_calls = ["mock_tool"]
916
917        };
918
919        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
920
921        let chat_request = chat_request! {
922            user!("Write a poem"),
923            assistant!("Roses are red", ["mock_tool"]),
924            tool_output!("mock_tool", "Great!");
925
926            tools = [mock_tool.clone()]
927        };
928
929        let stop_response = chat_response! {
930            "Roses are red";
931            tool_calls = ["stop"]
932        };
933
934        mock_llm.expect_complete(chat_request, Ok(stop_response));
935        mock_tool.expect_invoke_ok("Great!".into(), None);
936
937        let mut agent = Agent::builder()
938            .tools([mock_tool])
939            .llm(&mock_llm)
940            .no_system_prompt()
941            .before_all(mock_before_all.hook_fn())
942            .on_start(mock_on_start_fn.on_start_fn())
943            .before_completion(mock_before_completion.before_completion_fn())
944            .before_tool(mock_before_tool.before_tool_fn())
945            .after_completion(mock_after_completion.after_completion_fn())
946            .after_tool(mock_after_tool.after_tool_fn())
947            .after_each(mock_after_each.hook_fn())
948            .on_new_message(mock_on_message.message_hook_fn())
949            .build()
950            .unwrap();
951
952        agent.query(prompt).await.unwrap();
953    }
954
955    #[test_log::test(tokio::test)]
956    async fn test_agent_loop_limit() {
957        let prompt = "Generate content"; // Example prompt
958        let mock_llm = MockChatCompletion::new();
959        let mock_tool = MockTool::new("mock_tool");
960
961        let chat_request = chat_request! {
962            user!(prompt);
963            tools = [mock_tool.clone()]
964        };
965        mock_tool.expect_invoke_ok("Great!".into(), None);
966
967        let mock_tool_response = chat_response! {
968            "Some response";
969            tool_calls = ["mock_tool"]
970        };
971
972        // Set expectations for the mock LLM responses
973        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));
974
975        // // Response for terminating the loop
976        let stop_response = chat_response! {
977            "Final response";
978            tool_calls = ["stop"]
979        };
980
981        mock_llm.expect_complete(chat_request, Ok(stop_response));
982
983        let mut agent = Agent::builder()
984            .tools([mock_tool])
985            .llm(&mock_llm)
986            .no_system_prompt()
987            .limit(1) // Setting the loop limit to 1
988            .build()
989            .unwrap();
990
991        // Run the agent
992        agent.query(prompt).await.unwrap();
993
994        // Assert that the remaining message is still in the queue
995        let remaining = mock_llm.expectations.lock().unwrap().pop();
996        assert!(remaining.is_some());
997
998        // Assert that the agent is stopped after reaching the loop limit
999        assert!(agent.is_stopped());
1000    }
1001
1002    #[test_log::test(tokio::test)]
1003    async fn test_tool_retry_mechanism() {
1004        let prompt = "Execute my tool";
1005        let mock_llm = MockChatCompletion::new();
1006        let mock_tool = MockTool::new("retry_tool");
1007
1008        // Configure mock tool to fail twice. First time is fed back to the LLM, second time is an
1009        // error
1010        mock_tool.expect_invoke(
1011            Err(ToolError::WrongArguments(serde_json::Error::custom(
1012                "missing `query`",
1013            ))),
1014            None,
1015        );
1016        mock_tool.expect_invoke(
1017            Err(ToolError::WrongArguments(serde_json::Error::custom(
1018                "missing `query`",
1019            ))),
1020            None,
1021        );
1022
1023        let chat_request = chat_request! {
1024            user!(prompt);
1025            tools = [mock_tool.clone()]
1026        };
1027        let retry_response = chat_response! {
1028            "First failing attempt";
1029            tool_calls = ["retry_tool"]
1030        };
1031        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
1032
1033        let chat_request = chat_request! {
1034            user!(prompt),
1035            assistant!("First failing attempt", ["retry_tool"]),
1036            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
1037
1038            tools = [mock_tool.clone()]
1039        };
1040        let will_fail_response = chat_response! {
1041            "Finished execution";
1042            tool_calls = ["retry_tool"]
1043        };
1044        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
1045
1046        let mut agent = Agent::builder()
1047            .tools([mock_tool])
1048            .llm(&mock_llm)
1049            .no_system_prompt()
1050            .tool_retry_limit(1) // The test relies on a limit of 2 retries.
1051            .build()
1052            .unwrap();
1053
1054        // Run the agent
1055        let result = agent.query(prompt).await;
1056
1057        assert!(result.is_err());
1058        assert!(result.unwrap_err().to_string().contains("missing `query`"));
1059        assert!(agent.is_stopped());
1060    }
1061}