swiftide_agents/
agent.rs

1#![allow(dead_code)]
2use crate::{
3    default_context::DefaultContext,
4    hooks::{
5        AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
6        Hook, HookTypes, MessageHookFn, OnStartFn,
7    },
8    state,
9    system_prompt::SystemPrompt,
10    tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
11};
12use std::{
13    collections::{HashMap, HashSet},
14    hash::{DefaultHasher, Hash as _, Hasher as _},
15    sync::Arc,
16};
17
18use anyhow::Result;
19use derive_builder::Builder;
20use swiftide_core::{
21    chat_completion::{
22        ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
23    },
24    prompt::Prompt,
25    AgentContext,
26};
27use tracing::{debug, Instrument};
28
29/// Agents are the main interface for building agentic systems.
30///
31/// Construct agents by calling the builder, setting an llm, configure hooks, tools and other
32/// customizations.
33///
34/// # Important defaults
35///
36/// - The default context is the `DefaultContext`, executing tools locally with the `LocalExecutor`.
37/// - A default `stop` tool is provided for agents to explicitly stop if needed
38/// - The default `SystemPrompt` instructs the agent with chain of thought and some common safeguards, but is otherwise quite bare. In a lot of cases this can be sufficient.
39#[derive(Clone, Builder)]
40pub struct Agent {
41    /// Hooks are functions that are called at specific points in the agent's lifecycle.
42    #[builder(default, setter(into))]
43    pub(crate) hooks: Vec<Hook>,
44    /// The context in which the agent operates, by default this is the `DefaultContext`.
45    #[builder(
46        setter(custom),
47        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
48    )]
49    pub(crate) context: Arc<dyn AgentContext>,
50    /// Tools the agent can use
51    #[builder(default = Agent::default_tools(), setter(custom))]
52    pub(crate) tools: HashSet<Box<dyn Tool>>,
53
54    /// The language model that the agent uses for completion.
55    #[builder(setter(custom))]
56    pub(crate) llm: Box<dyn ChatCompletion>,
57
58    /// System prompt for the agent when it starts
59    ///
60    /// Some agents profit significantly from a tailored prompt. But it is not always needed.
61    ///
62    /// See [`SystemPrompt`] for an opiniated, customizable system prompt.
63    ///
64    /// Swiftide provides a default system prompt for all agents.
65    ///
66    /// # Example
67    ///
68    /// ```no_run
69    /// # use swiftide_agents::system_prompt::SystemPrompt;
70    /// # use swiftide_agents::Agent;
71    /// Agent::builder()
72    ///     .system_prompt(
73    ///         SystemPrompt::builder().role("You are an expert engineer")
74    ///         .build().unwrap())
75    ///     .build().unwrap();
76    /// ```
77    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
78    pub(crate) system_prompt: Option<Prompt>,
79
80    /// Initial state of the agent
81    #[builder(private, default = state::State::default())]
82    pub(crate) state: state::State,
83
84    /// Optional limit on the amount of loops the agent can run.
85    /// The counter is reset when the agent is stopped.
86    #[builder(default, setter(strip_option))]
87    pub(crate) limit: Option<usize>,
88
89    /// The maximum amount of times the failed output of a tool will be send
90    /// to an LLM before the agent stops. Defaults to 3.
91    ///
92    /// LLMs sometimes send missing arguments, or a tool might actually fail, but retrying could be
93    /// worth while. If the limit is not reached, the agent will send the formatted error back to
94    /// the LLM.
95    ///
96    /// The limit is hashed based on the tool call name and arguments, so the limit is per tool call.
97    ///
98    /// This limit is _not_ reset when the agent is stopped.
99    #[builder(default = 3)]
100    pub(crate) tool_retry_limit: usize,
101
102    /// Internally tracks the amount of times a tool has been retried. The key is a hash based on
103    /// the name and args of the tool.
104    #[builder(private, default)]
105    pub(crate) tool_retries_counter: HashMap<u64, usize>,
106}
107
108impl std::fmt::Debug for Agent {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        f.debug_struct("Agent")
111            //display hooks as a list of type: number of hooks
112            .field(
113                "hooks",
114                &self
115                    .hooks
116                    .iter()
117                    .map(std::string::ToString::to_string)
118                    .collect::<Vec<_>>(),
119            )
120            .field(
121                "tools",
122                &self
123                    .tools
124                    .iter()
125                    .map(swiftide_core::Tool::name)
126                    .collect::<Vec<_>>(),
127            )
128            .field("llm", &"Box<dyn ChatCompletion>")
129            .field("state", &self.state)
130            .finish()
131    }
132}
133
134impl AgentBuilder {
135    /// The context in which the agent operates, by default this is the `DefaultContext`.
136    pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
137    where
138        Self: Clone,
139    {
140        self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
141        self
142    }
143
144    /// Disable the system prompt.
145    pub fn no_system_prompt(&mut self) -> &mut Self {
146        self.system_prompt = Some(None);
147
148        self
149    }
150
151    /// Add a hook to the agent.
152    pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
153        let hooks = self.hooks.get_or_insert_with(Vec::new);
154        hooks.push(hook);
155
156        self
157    }
158
159    /// Add a hook that runs once, before all completions. Even if the agent is paused and resumed,
160    /// before all will not trigger again.
161    pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
162        self.add_hook(Hook::BeforeAll(Box::new(hook)))
163    }
164
165    /// Add a hook that runs once, when the agent starts. This hook also runs if the agent stopped
166    /// and then starts again. The hook runs after any `before_all` hooks and before the
167    /// `before_completion` hooks.
168    pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
169        self.add_hook(Hook::OnStart(Box::new(hook)))
170    }
171
172    /// Add a hook that runs before each completion.
173    pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
174        self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
175    }
176
177    /// Add a hook that runs after each tool. The `Result<ToolOutput, ToolError>` is provided
178    /// as mut, so the tool output can be fully modified.
179    ///
180    /// The `ToolOutput` also references the original `ToolCall`, allowing you to match at runtime
181    /// what tool to interact with.
182    pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
183        self.add_hook(Hook::AfterTool(Box::new(hook)))
184    }
185
186    /// Add a hook that runs before each tool. Yields an immutable reference to the `ToolCall`.
187    pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
188        self.add_hook(Hook::BeforeTool(Box::new(hook)))
189    }
190
191    /// Add a hook that runs after each completion, before tool invocation and/or new messages.
192    pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
193        self.add_hook(Hook::AfterCompletion(Box::new(hook)))
194    }
195
196    /// Add a hook that runs after each completion, after tool invocations, right before a new loop
197    /// might start
198    pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
199        self.add_hook(Hook::AfterEach(Box::new(hook)))
200    }
201
202    /// Add a hook that runs when a new message is added to the context. Note that each tool adds a
203    /// separate message.
204    pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
205        self.add_hook(Hook::OnNewMessage(Box::new(hook)))
206    }
207
208    /// Set the LLM for the agent. An LLM must implement the `ChatCompletion` trait.
209    pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
210        let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
211
212        self.llm = Some(boxed);
213        self
214    }
215
216    /// Define the available tools for the agent. Tools must implement the `Tool` trait.
217    ///
218    /// See the [tool attribute macro](`swiftide_macros::tool`) and the [tool derive macro](`swiftide_macros::Tool`)
219    /// for easy ways to create (many) tools.
220    pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
221    where
222        TOOL: Into<Box<dyn Tool>>,
223    {
224        self.tools = Some(
225            tools
226                .into_iter()
227                .map(Into::into)
228                .chain(Agent::default_tools())
229                .collect(),
230        );
231        self
232    }
233}
234
235impl Agent {
236    /// Build a new agent
237    pub fn builder() -> AgentBuilder {
238        AgentBuilder::default()
239    }
240}
241
242impl Agent {
243    /// Default tools for the agent that it always includes
244    fn default_tools() -> HashSet<Box<dyn Tool>> {
245        HashSet::from([Box::new(Stop::default()) as Box<dyn Tool>])
246    }
247
248    /// Run the agent with a user message. The agent will loop completions, make tool calls, until
249    /// no new messages are available.
250    #[tracing::instrument(skip_all, name = "agent.query")]
251    pub async fn query(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
252        self.run_agent(Some(query.into()), false).await
253    }
254
255    /// Run the agent with a user message once.
256    #[tracing::instrument(skip_all, name = "agent.query_once")]
257    pub async fn query_once(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
258        self.run_agent(Some(query.into()), true).await
259    }
260
261    /// Run the agent with without user message. The agent will loop completions, make tool calls, until
262    /// no new messages are available.
263    #[tracing::instrument(skip_all, name = "agent.run")]
264    pub async fn run(&mut self) -> Result<()> {
265        self.run_agent(None, false).await
266    }
267
268    /// Run the agent with without user message. The agent will loop completions, make tool calls, until
269    #[tracing::instrument(skip_all, name = "agent.run_once")]
270    pub async fn run_once(&mut self) -> Result<()> {
271        self.run_agent(None, true).await
272    }
273
274    /// Retrieve the message history of the agent
275    pub async fn history(&self) -> Vec<ChatMessage> {
276        self.context.history().await
277    }
278
279    async fn run_agent(&mut self, maybe_query: Option<String>, just_once: bool) -> Result<()> {
280        if self.state.is_running() {
281            anyhow::bail!("Agent is already running");
282        }
283
284        if self.state.is_pending() {
285            if let Some(system_prompt) = &self.system_prompt {
286                self.context
287                    .add_messages(vec![ChatMessage::System(system_prompt.render().await?)])
288                    .await;
289            }
290            for hook in self.hooks_by_type(HookTypes::BeforeAll) {
291                if let Hook::BeforeAll(hook) = hook {
292                    let span = tracing::info_span!(
293                        "hook",
294                        "otel.name" = format!("hook.{}", HookTypes::AfterTool)
295                    );
296                    tracing::info!("Calling {} hook", HookTypes::AfterTool);
297                    hook(self).instrument(span.or_current()).await?;
298                }
299            }
300        }
301
302        for hook in self.hooks_by_type(HookTypes::OnStart) {
303            if let Hook::OnStart(hook) = hook {
304                let span = tracing::info_span!(
305                    "hook",
306                    "otel.name" = format!("hook.{}", HookTypes::OnStart)
307                );
308                tracing::info!("Calling {} hook", HookTypes::OnStart);
309                hook(self).instrument(span.or_current()).await?;
310            }
311        }
312
313        self.state = state::State::Running;
314
315        if let Some(query) = maybe_query {
316            self.context.add_message(ChatMessage::User(query)).await;
317        }
318
319        let mut loop_counter = 0;
320
321        while let Some(messages) = self.context.next_completion().await {
322            if let Some(limit) = self.limit {
323                if loop_counter >= limit {
324                    tracing::warn!("Agent loop limit reached");
325                    break;
326                }
327            }
328            let result = self.run_completions(&messages).await;
329
330            if let Err(err) = result {
331                self.stop();
332                tracing::error!(error = ?err, "Agent stopped with error {err}");
333                return Err(err);
334            }
335
336            if just_once || self.state.is_stopped() {
337                break;
338            }
339            loop_counter += 1;
340        }
341
342        // If there are no new messages, ensure we update our state
343        self.stop();
344
345        Ok(())
346    }
347
348    #[tracing::instrument(skip_all, err)]
349    async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<()> {
350        debug!(
351            "Running completion for agent with {} messages",
352            messages.len()
353        );
354
355        let mut chat_completion_request = ChatCompletionRequest::builder()
356            .messages(messages)
357            .tools_spec(
358                self.tools
359                    .iter()
360                    .map(swiftide_core::Tool::tool_spec)
361                    .collect::<HashSet<_>>(),
362            )
363            .build()?;
364
365        for hook in self.hooks_by_type(HookTypes::BeforeCompletion) {
366            if let Hook::BeforeCompletion(hook) = hook {
367                let span = tracing::info_span!(
368                    "hook",
369                    "otel.name" = format!("hook.{}", HookTypes::BeforeCompletion)
370                );
371                tracing::info!("Calling {} hook", HookTypes::BeforeCompletion);
372                hook(&*self, &mut chat_completion_request)
373                    .instrument(span.or_current())
374                    .await?;
375            }
376        }
377
378        debug!(
379            "Calling LLM with the following new messages:\n {}",
380            self.context
381                .current_new_messages()
382                .await
383                .iter()
384                .map(ToString::to_string)
385                .collect::<Vec<_>>()
386                .join(",\n")
387        );
388
389        let mut response = self.llm.complete(&chat_completion_request).await?;
390
391        for hook in self.hooks_by_type(HookTypes::AfterCompletion) {
392            if let Hook::AfterCompletion(hook) = hook {
393                let span = tracing::info_span!(
394                    "hook",
395                    "otel.name" = format!("hook.{}", HookTypes::AfterCompletion)
396                );
397                tracing::info!("Calling {} hook", HookTypes::AfterCompletion);
398                hook(&*self, &mut response)
399                    .instrument(span.or_current())
400                    .await?;
401            }
402        }
403        self.add_message(ChatMessage::Assistant(
404            response.message,
405            response.tool_calls.clone(),
406        ))
407        .await?;
408
409        if let Some(tool_calls) = response.tool_calls {
410            self.invoke_tools(tool_calls).await?;
411        };
412
413        for hook in self.hooks_by_type(HookTypes::AfterEach) {
414            if let Hook::AfterEach(hook) = hook {
415                let span = tracing::info_span!(
416                    "hook",
417                    "otel.name" = format!("hook.{}", HookTypes::AfterEach)
418                );
419                tracing::info!("Calling {} hook", HookTypes::AfterEach);
420                hook(&*self).instrument(span.or_current()).await?;
421            }
422        }
423
424        Ok(())
425    }
426
427    async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<()> {
428        debug!("LLM returned tool calls: {:?}", tool_calls);
429
430        let mut handles = vec![];
431        for tool_call in tool_calls {
432            let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
433                tracing::warn!("Tool {} not found", tool_call.name());
434                continue;
435            };
436            tracing::info!("Calling tool `{}`", tool_call.name());
437
438            let tool_args = tool_call.args().map(String::from);
439            let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
440
441            for hook in self.hooks_by_type(HookTypes::BeforeTool) {
442                if let Hook::BeforeTool(hook) = hook {
443                    let span = tracing::info_span!(
444                        "hook",
445                        "otel.name" = format!("hook.{}", HookTypes::BeforeTool)
446                    );
447                    tracing::info!("Calling {} hook", HookTypes::BeforeTool);
448                    hook(&*self, &tool_call)
449                        .instrument(span.or_current())
450                        .await?;
451                }
452            }
453
454            let tool_span = tracing::info_span!(
455                "tool",
456                "otel.name" = format!("tool.{}", tool.name().as_ref())
457            );
458
459            let handle = tokio::spawn(async move {
460                    let tool_args = ArgPreprocessor::preprocess(tool_args.as_deref());
461                    let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
462
463                    tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name = tool.name().as_ref(), "Completed tool call");
464
465                    Ok(output)
466                }.instrument(tool_span.or_current()));
467
468            handles.push((handle, tool_call));
469        }
470
471        for (handle, tool_call) in handles {
472            let mut output = handle.await?;
473
474            // Invoking hooks feels too verbose and repetitive
475            for hook in self.hooks_by_type(HookTypes::AfterTool) {
476                if let Hook::AfterTool(hook) = hook {
477                    let span = tracing::info_span!(
478                        "hook",
479                        "otel.name" = format!("hook.{}", HookTypes::AfterTool)
480                    );
481                    tracing::info!("Calling {} hook", HookTypes::AfterTool);
482                    hook(&*self, &tool_call, &mut output)
483                        .instrument(span.or_current())
484                        .await?;
485                }
486            }
487
488            if let Err(error) = output {
489                let stop = self.tool_calls_over_limit(&tool_call);
490                if stop {
491                    tracing::error!(
492                        "Tool call failed, retry limit reached, stopping agent: {err}",
493                        err = error
494                    );
495                } else {
496                    tracing::warn!(
497                        error = error.to_string(),
498                        tool_call = ?tool_call,
499                        "Tool call failed, retrying",
500                    );
501                }
502                self.add_message(ChatMessage::ToolOutput(
503                    tool_call,
504                    ToolOutput::Fail(error.to_string()),
505                ))
506                .await?;
507                if stop {
508                    self.stop();
509                    return Err(error.into());
510                }
511                continue;
512            }
513
514            let output = output?;
515            self.handle_control_tools(&output);
516            self.add_message(ChatMessage::ToolOutput(tool_call, output))
517                .await?;
518        }
519
520        Ok(())
521    }
522
523    fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
524        self.hooks
525            .iter()
526            .filter(|h| hook_type == (*h).into())
527            .collect()
528    }
529
530    fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
531        self.tools
532            .iter()
533            .find(|tool| tool.name() == tool_name)
534            .cloned()
535    }
536
537    // Handle any tool specific output (e.g. stop)
538    fn handle_control_tools(&mut self, output: &ToolOutput) {
539        if let ToolOutput::Stop = output {
540            tracing::warn!("Stop tool called, stopping agent");
541            self.stop();
542        }
543    }
544
545    fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
546        let mut s = DefaultHasher::new();
547        tool_call.hash(&mut s);
548        let hash = s.finish();
549
550        if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
551            let val = *retries >= self.tool_retry_limit;
552            *retries += 1;
553            val
554        } else {
555            self.tool_retries_counter.insert(hash, 1);
556            false
557        }
558    }
559
560    #[tracing::instrument(skip_all, fields(message = message.to_string()))]
561    async fn add_message(&self, mut message: ChatMessage) -> Result<()> {
562        for hook in self.hooks_by_type(HookTypes::OnNewMessage) {
563            if let Hook::OnNewMessage(hook) = hook {
564                let span = tracing::info_span!(
565                    "hook",
566                    "otel.name" = format!("hook.{}", HookTypes::OnNewMessage)
567                );
568                if let Err(err) = hook(self, &mut message).instrument(span.or_current()).await {
569                    tracing::error!(
570                        "Error in {hooktype} hook: {err}",
571                        hooktype = HookTypes::OnNewMessage,
572                    );
573                }
574            }
575        }
576        self.context.add_message(message).await;
577        Ok(())
578    }
579
580    /// Tell the agent to stop. It will finish it's current loop and then stop.
581    pub fn stop(&mut self) {
582        self.state = state::State::Stopped;
583    }
584
585    /// Access the agent's context
586    pub fn context(&self) -> &dyn AgentContext {
587        &self.context
588    }
589
590    /// The agent is still running
591    pub fn is_running(&self) -> bool {
592        self.state.is_running()
593    }
594
595    /// The agent stopped
596    pub fn is_stopped(&self) -> bool {
597        self.state.is_stopped()
598    }
599
600    /// The agent has not (ever) started
601    pub fn is_pending(&self) -> bool {
602        self.state.is_pending()
603    }
604}
605
606#[cfg(test)]
607mod tests {
608
609    use serde::ser::Error;
610    use swiftide_core::chat_completion::errors::ToolError;
611    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
612    use swiftide_core::test_utils::MockChatCompletion;
613
614    use super::*;
615    use crate::{
616        assistant, chat_request, chat_response, summary, system, tool_failed, tool_output, user,
617    };
618
619    use crate::test_utils::{MockHook, MockTool};
620
621    #[test_log::test(tokio::test)]
622    async fn test_agent_builder_defaults() {
623        // Create a prompt
624        let mock_llm = MockChatCompletion::new();
625
626        // Build the agent
627        let agent = Agent::builder().llm(&mock_llm).build().unwrap();
628
629        // Check that the context is the default context
630
631        // Check that the default tools are added
632        assert!(agent.find_tool_by_name("stop").is_some());
633
634        // Check it does not allow duplicates
635        let agent = Agent::builder()
636            .tools([Stop::default(), Stop::default()])
637            .llm(&mock_llm)
638            .build()
639            .unwrap();
640
641        assert_eq!(agent.tools.len(), 1);
642
643        // It should include the default tool if a different tool is provided
644        let agent = Agent::builder()
645            .tools([MockTool::new("mock_tool")])
646            .llm(&mock_llm)
647            .build()
648            .unwrap();
649
650        assert_eq!(agent.tools.len(), 2);
651        assert!(agent.find_tool_by_name("mock_tool").is_some());
652        assert!(agent.find_tool_by_name("stop").is_some());
653
654        assert!(agent.context().history().await.is_empty());
655    }
656
657    #[test_log::test(tokio::test)]
658    async fn test_agent_tool_calling_loop() {
659        let prompt = "Write a poem";
660        let mock_llm = MockChatCompletion::new();
661        let mock_tool = MockTool::new("mock_tool");
662
663        let chat_request = chat_request! {
664            user!("Write a poem");
665
666            tools = [mock_tool.clone()]
667        };
668
669        let mock_tool_response = chat_response! {
670            "Roses are red";
671            tool_calls = ["mock_tool"]
672
673        };
674
675        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
676
677        let chat_request = chat_request! {
678            user!("Write a poem"),
679            assistant!("Roses are red", ["mock_tool"]),
680            tool_output!("mock_tool", "Great!");
681
682            tools = [mock_tool.clone()]
683        };
684
685        let stop_response = chat_response! {
686            "Roses are red";
687            tool_calls = ["stop"]
688        };
689
690        mock_llm.expect_complete(chat_request, Ok(stop_response));
691        mock_tool.expect_invoke_ok("Great!".into(), None);
692
693        let mut agent = Agent::builder()
694            .tools([mock_tool])
695            .llm(&mock_llm)
696            .no_system_prompt()
697            .build()
698            .unwrap();
699
700        agent.query(prompt).await.unwrap();
701    }
702
703    #[test_log::test(tokio::test)]
704    async fn test_agent_tool_run_once() {
705        let prompt = "Write a poem";
706        let mock_llm = MockChatCompletion::new();
707        let mock_tool = MockTool::default();
708
709        let chat_request = chat_request! {
710            system!("My system prompt"),
711            user!("Write a poem");
712
713            tools = [mock_tool.clone()]
714        };
715
716        let mock_tool_response = chat_response! {
717            "Roses are red";
718            tool_calls = ["mock_tool"]
719
720        };
721
722        mock_tool.expect_invoke_ok("Great!".into(), None);
723        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
724
725        let mut agent = Agent::builder()
726            .tools([mock_tool])
727            .system_prompt("My system prompt")
728            .llm(&mock_llm)
729            .build()
730            .unwrap();
731
732        agent.query_once(prompt).await.unwrap();
733    }
734
735    #[test_log::test(tokio::test(flavor = "multi_thread"))]
736    async fn test_multiple_tool_calls() {
737        let prompt = "Write a poem";
738        let mock_llm = MockChatCompletion::new();
739        let mock_tool = MockTool::new("mock_tool1");
740        let mock_tool2 = MockTool::new("mock_tool2");
741
742        let chat_request = chat_request! {
743            system!("My system prompt"),
744            user!("Write a poem");
745
746
747
748            tools = [mock_tool.clone(), mock_tool2.clone()]
749        };
750
751        let mock_tool_response = chat_response! {
752            "Roses are red";
753
754            tool_calls = ["mock_tool1", "mock_tool2"]
755
756        };
757
758        dbg!(&chat_request);
759        mock_tool.expect_invoke_ok("Great!".into(), None);
760        mock_tool2.expect_invoke_ok("Great!".into(), None);
761        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
762
763        let chat_request = chat_request! {
764            system!("My system prompt"),
765            user!("Write a poem"),
766            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
767            tool_output!("mock_tool1", "Great!"),
768            tool_output!("mock_tool2", "Great!");
769
770            tools = [mock_tool.clone(), mock_tool2.clone()]
771        };
772
773        let mock_tool_response = chat_response! {
774            "Ok!";
775
776            tool_calls = ["stop"]
777
778        };
779
780        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
781
782        let mut agent = Agent::builder()
783            .tools([mock_tool, mock_tool2])
784            .system_prompt("My system prompt")
785            .llm(&mock_llm)
786            .build()
787            .unwrap();
788
789        agent.query(prompt).await.unwrap();
790    }
791
792    #[test_log::test(tokio::test)]
793    async fn test_agent_state_machine() {
794        let prompt = "Write a poem";
795        let mock_llm = MockChatCompletion::new();
796
797        let chat_request = chat_request! {
798            user!("Write a poem");
799            tools = []
800        };
801        let mock_tool_response = chat_response! {
802            "Roses are red";
803            tool_calls = []
804        };
805
806        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
807        let mut agent = Agent::builder()
808            .llm(&mock_llm)
809            .no_system_prompt()
810            .build()
811            .unwrap();
812
813        // Agent has never run and is pending
814        assert!(agent.state.is_pending());
815        agent.query_once(prompt).await.unwrap();
816
817        // Agent is stopped, there might be more messages
818        assert!(agent.state.is_stopped());
819    }
820
821    #[test_log::test(tokio::test)]
822    async fn test_summary() {
823        let prompt = "Write a poem";
824        let mock_llm = MockChatCompletion::new();
825
826        let mock_tool_response = chat_response! {
827            "Roses are red";
828            tool_calls = []
829
830        };
831
832        let expected_chat_request = chat_request! {
833            system!("My system prompt"),
834            user!("Write a poem");
835
836            tools = []
837        };
838
839        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
840
841        let mut agent = Agent::builder()
842            .system_prompt("My system prompt")
843            .llm(&mock_llm)
844            .build()
845            .unwrap();
846
847        agent.query_once(prompt).await.unwrap();
848
849        agent
850            .context
851            .add_message(ChatMessage::new_summary("Summary"))
852            .await;
853
854        let expected_chat_request = chat_request! {
855            system!("My system prompt"),
856            summary!("Summary"),
857            user!("Write another poem");
858            tools = []
859        };
860        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
861
862        agent.query_once("Write another poem").await.unwrap();
863
864        agent
865            .context
866            .add_message(ChatMessage::new_summary("Summary 2"))
867            .await;
868
869        let expected_chat_request = chat_request! {
870            system!("My system prompt"),
871            summary!("Summary 2"),
872            user!("Write a third poem");
873            tools = []
874        };
875        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));
876
877        agent.query_once("Write a third poem").await.unwrap();
878    }
879
880    #[test_log::test(tokio::test)]
881    async fn test_agent_hooks() {
882        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
883        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
884        let mock_before_completion = MockHook::new("before_completion")
885            .expect_calls(2)
886            .to_owned();
887        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
888        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
889        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
890
891        // Once for mock tool and once for stop
892        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
893        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();
894
895        let prompt = "Write a poem";
896        let mock_llm = MockChatCompletion::new();
897        let mock_tool = MockTool::default();
898
899        let chat_request = chat_request! {
900            user!("Write a poem");
901
902            tools = [mock_tool.clone()]
903        };
904
905        let mock_tool_response = chat_response! {
906            "Roses are red";
907            tool_calls = ["mock_tool"]
908
909        };
910
911        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
912
913        let chat_request = chat_request! {
914            user!("Write a poem"),
915            assistant!("Roses are red", ["mock_tool"]),
916            tool_output!("mock_tool", "Great!");
917
918            tools = [mock_tool.clone()]
919        };
920
921        let stop_response = chat_response! {
922            "Roses are red";
923            tool_calls = ["stop"]
924        };
925
926        mock_llm.expect_complete(chat_request, Ok(stop_response));
927        mock_tool.expect_invoke_ok("Great!".into(), None);
928
929        let mut agent = Agent::builder()
930            .tools([mock_tool])
931            .llm(&mock_llm)
932            .no_system_prompt()
933            .before_all(mock_before_all.hook_fn())
934            .on_start(mock_on_start_fn.on_start_fn())
935            .before_completion(mock_before_completion.before_completion_fn())
936            .before_tool(mock_before_tool.before_tool_fn())
937            .after_completion(mock_after_completion.after_completion_fn())
938            .after_tool(mock_after_tool.after_tool_fn())
939            .after_each(mock_after_each.hook_fn())
940            .on_new_message(mock_on_message.message_hook_fn())
941            .build()
942            .unwrap();
943
944        agent.query(prompt).await.unwrap();
945    }
946
947    #[test_log::test(tokio::test)]
948    async fn test_agent_loop_limit() {
949        let prompt = "Generate content"; // Example prompt
950        let mock_llm = MockChatCompletion::new();
951        let mock_tool = MockTool::new("mock_tool");
952
953        let chat_request = chat_request! {
954            user!(prompt);
955            tools = [mock_tool.clone()]
956        };
957        mock_tool.expect_invoke_ok("Great!".into(), None);
958
959        let mock_tool_response = chat_response! {
960            "Some response";
961            tool_calls = ["mock_tool"]
962        };
963
964        // Set expectations for the mock LLM responses
965        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));
966
967        // // Response for terminating the loop
968        let stop_response = chat_response! {
969            "Final response";
970            tool_calls = ["stop"]
971        };
972
973        mock_llm.expect_complete(chat_request, Ok(stop_response));
974
975        let mut agent = Agent::builder()
976            .tools([mock_tool])
977            .llm(&mock_llm)
978            .no_system_prompt()
979            .limit(1) // Setting the loop limit to 1
980            .build()
981            .unwrap();
982
983        // Run the agent
984        agent.query(prompt).await.unwrap();
985
986        // Assert that the remaining message is still in the queue
987        let remaining = mock_llm.expectations.lock().unwrap().pop();
988        assert!(remaining.is_some());
989
990        // Assert that the agent is stopped after reaching the loop limit
991        assert!(agent.is_stopped());
992    }
993
994    #[test_log::test(tokio::test)]
995    async fn test_tool_retry_mechanism() {
996        let prompt = "Execute my tool";
997        let mock_llm = MockChatCompletion::new();
998        let mock_tool = MockTool::new("retry_tool");
999
1000        // Configure mock tool to fail twice. First time is fed back to the LLM, second time is an
1001        // error
1002        mock_tool.expect_invoke(
1003            Err(ToolError::WrongArguments(serde_json::Error::custom(
1004                "missing `query`",
1005            ))),
1006            None,
1007        );
1008        mock_tool.expect_invoke(
1009            Err(ToolError::WrongArguments(serde_json::Error::custom(
1010                "missing `query`",
1011            ))),
1012            None,
1013        );
1014
1015        let chat_request = chat_request! {
1016            user!(prompt);
1017            tools = [mock_tool.clone()]
1018        };
1019        let retry_response = chat_response! {
1020            "First failing attempt";
1021            tool_calls = ["retry_tool"]
1022        };
1023        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
1024
1025        let chat_request = chat_request! {
1026            user!(prompt),
1027            assistant!("First failing attempt", ["retry_tool"]),
1028            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
1029
1030            tools = [mock_tool.clone()]
1031        };
1032        let will_fail_response = chat_response! {
1033            "Finished execution";
1034            tool_calls = ["retry_tool"]
1035        };
1036        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
1037
1038        let mut agent = Agent::builder()
1039            .tools([mock_tool])
1040            .llm(&mock_llm)
1041            .no_system_prompt()
1042            .tool_retry_limit(1) // The test relies on a limit of 2 retries.
1043            .build()
1044            .unwrap();
1045
1046        // Run the agent
1047        let result = agent.query(prompt).await;
1048
1049        assert!(result.is_err());
1050        assert!(result.unwrap_err().to_string().contains("missing `query`"));
1051        assert!(agent.is_stopped());
1052    }
1053}