swiftide_agents/
agent.rs

1#![allow(dead_code)]
2use crate::{
3    default_context::DefaultContext,
4    errors::AgentError,
5    hooks::{
6        AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7        Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn,
8    },
9    invoke_hooks,
10    state::{self, StopReason},
11    system_prompt::SystemPrompt,
12    tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15    collections::{HashMap, HashSet},
16    hash::{DefaultHasher, Hash as _, Hasher as _},
17    sync::Arc,
18};
19
20use derive_builder::Builder;
21use swiftide_core::{
22    chat_completion::{
23        ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
24    },
25    prompt::Prompt,
26    AgentContext, ToolBox,
27};
28use tracing::{debug, Instrument};
29
/// Agents are the main interface for building agentic systems.
///
/// Construct agents by calling the builder, setting an llm, and configuring hooks, tools and
/// other customizations.
///
/// # Important defaults
///
/// - The default context is the `DefaultContext`, executing tools locally with the `LocalExecutor`.
/// - A default `stop` tool is provided for agents to explicitly stop if needed
/// - The default `SystemPrompt` instructs the agent with chain of thought and some common
///   safeguards, but is otherwise quite bare. In a lot of cases this can be sufficient.
#[derive(Clone, Builder)]
pub struct Agent {
    /// Hooks are functions that are called at specific points in the agent's lifecycle.
    #[builder(default, setter(into))]
    pub(crate) hooks: Vec<Hook>,
    /// The context in which the agent operates, by default this is the `DefaultContext`.
    #[builder(
        setter(custom),
        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
    )]
    pub(crate) context: Arc<dyn AgentContext>,
    /// Tools the agent can use. Defaults to `Agent::default_tools()`.
    #[builder(default = Agent::default_tools(), setter(custom))]
    pub(crate) tools: HashSet<Box<dyn Tool>>,

    /// Toolboxes are collections of tools that can be added to the agent.
    ///
    /// Toolboxes make their tools available to the agent at runtime.
    #[builder(default)]
    pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,

    /// The language model that the agent uses for completion.
    #[builder(setter(custom))]
    pub(crate) llm: Box<dyn ChatCompletion>,

    /// System prompt for the agent when it starts
    ///
    /// Some agents profit significantly from a tailored prompt. But it is not always needed.
    ///
    /// See [`SystemPrompt`] for an opinionated, customizable system prompt.
    ///
    /// Swiftide provides a default system prompt for all agents.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use swiftide_agents::system_prompt::SystemPrompt;
    /// # use swiftide_agents::Agent;
    /// Agent::builder()
    ///     .system_prompt(
    ///         SystemPrompt::builder().role("You are an expert engineer")
    ///         .build().unwrap())
    ///     .build().unwrap();
    /// ```
    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
    pub(crate) system_prompt: Option<Prompt>,

    /// Initial state of the agent
    #[builder(private, default = state::State::default())]
    pub(crate) state: state::State,

    /// Optional limit on the number of loops the agent can run.
    /// The counter is reset when the agent is stopped.
    #[builder(default, setter(strip_option))]
    pub(crate) limit: Option<usize>,

    /// The maximum number of times the failed output of a tool will be sent
    /// to an LLM before the agent stops. Defaults to 3.
    ///
    /// LLMs sometimes send missing arguments, or a tool might actually fail, but retrying could be
    /// worthwhile. If the limit is not reached, the agent will send the formatted error back to
    /// the LLM.
    ///
    /// The limit is hashed based on the tool call name and arguments, so the limit is per tool
    /// call.
    ///
    /// This limit is _not_ reset when the agent is stopped.
    #[builder(default = 3)]
    pub(crate) tool_retry_limit: usize,

    /// Internally tracks the number of times a tool has been retried. The key is a hash based on
    /// the name and args of the tool.
    #[builder(private, default)]
    pub(crate) tool_retries_counter: HashMap<u64, usize>,

    /// Tools loaded from toolboxes
    #[builder(private, default)]
    pub(crate) toolbox_tools: HashSet<Box<dyn Tool>>,
}
120
121impl std::fmt::Debug for Agent {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        f.debug_struct("Agent")
124            // display hooks as a list of type: number of hooks
125            .field(
126                "hooks",
127                &self
128                    .hooks
129                    .iter()
130                    .map(std::string::ToString::to_string)
131                    .collect::<Vec<_>>(),
132            )
133            .field(
134                "tools",
135                &self
136                    .tools
137                    .iter()
138                    .map(swiftide_core::Tool::name)
139                    .collect::<Vec<_>>(),
140            )
141            .field("llm", &"Box<dyn ChatCompletion>")
142            .field("state", &self.state)
143            .finish()
144    }
145}
146
147impl AgentBuilder {
148    /// The context in which the agent operates, by default this is the `DefaultContext`.
149    pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
150    where
151        Self: Clone,
152    {
153        self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
154        self
155    }
156
157    /// Disable the system prompt.
158    pub fn no_system_prompt(&mut self) -> &mut Self {
159        self.system_prompt = Some(None);
160
161        self
162    }
163
164    /// Add a hook to the agent.
165    pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
166        let hooks = self.hooks.get_or_insert_with(Vec::new);
167        hooks.push(hook);
168
169        self
170    }
171
172    /// Add a hook that runs once, before all completions. Even if the agent is paused and resumed,
173    /// before all will not trigger again.
174    pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
175        self.add_hook(Hook::BeforeAll(Box::new(hook)))
176    }
177
178    /// Add a hook that runs once, when the agent starts. This hook also runs if the agent stopped
179    /// and then starts again. The hook runs after any `before_all` hooks and before the
180    /// `before_completion` hooks.
181    pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
182        self.add_hook(Hook::OnStart(Box::new(hook)))
183    }
184
185    /// Add a hook that runs before each completion.
186    pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
187        self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
188    }
189
190    /// Add a hook that runs after each tool. The `Result<ToolOutput, ToolError>` is provided
191    /// as mut, so the tool output can be fully modified.
192    ///
193    /// The `ToolOutput` also references the original `ToolCall`, allowing you to match at runtime
194    /// what tool to interact with.
195    pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
196        self.add_hook(Hook::AfterTool(Box::new(hook)))
197    }
198
199    /// Add a hook that runs before each tool. Yields an immutable reference to the `ToolCall`.
200    pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
201        self.add_hook(Hook::BeforeTool(Box::new(hook)))
202    }
203
204    /// Add a hook that runs after each completion, before tool invocation and/or new messages.
205    pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
206        self.add_hook(Hook::AfterCompletion(Box::new(hook)))
207    }
208
209    /// Add a hook that runs after each completion, after tool invocations, right before a new loop
210    /// might start
211    pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
212        self.add_hook(Hook::AfterEach(Box::new(hook)))
213    }
214
215    /// Add a hook that runs when a new message is added to the context. Note that each tool adds a
216    /// separate message.
217    pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
218        self.add_hook(Hook::OnNewMessage(Box::new(hook)))
219    }
220
221    pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
222        self.add_hook(Hook::OnStop(Box::new(hook)))
223    }
224
225    /// Set the LLM for the agent. An LLM must implement the `ChatCompletion` trait.
226    pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
227        let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
228
229        self.llm = Some(boxed);
230        self
231    }
232
233    /// Define the available tools for the agent. Tools must implement the `Tool` trait.
234    ///
235    /// See the [tool attribute macro](`swiftide_macros::tool`) and the [tool derive
236    /// macro](`swiftide_macros::Tool`) for easy ways to create (many) tools.
237    pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
238    where
239        TOOL: Into<Box<dyn Tool>>,
240    {
241        self.tools = Some(
242            tools
243                .into_iter()
244                .map(Into::into)
245                .chain(Agent::default_tools())
246                .collect(),
247        );
248        self
249    }
250
251    /// Add a toolbox to the agent. Toolboxes are collections of tools that can be added to the
252    /// to the agent. Available tools are evaluated at runtime, when the agent starts for the first
253    /// time.
254    ///
255    /// Agents can have many toolboxes.
256    pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
257        self.toolboxes.get_or_insert_with(Vec::new);
258
259        self.toolboxes.as_mut().unwrap().push(Box::new(toolbox));
260        self
261    }
262}
263
264impl Agent {
265    /// Build a new agent
266    pub fn builder() -> AgentBuilder {
267        AgentBuilder::default()
268    }
269}
270
271impl Agent {
272    /// Default tools for the agent that it always includes
273    fn default_tools() -> HashSet<Box<dyn Tool>> {
274        HashSet::from([Box::new(Stop::default()) as Box<dyn Tool>])
275    }
276
277    /// Run the agent with a user message. The agent will loop completions, make tool calls, until
278    /// no new messages are available.
279    #[tracing::instrument(skip_all, name = "agent.query")]
280    pub async fn query(
281        &mut self,
282        query: impl Into<String> + std::fmt::Debug,
283    ) -> Result<(), AgentError> {
284        self.run_agent(Some(query.into()), false).await
285    }
286
287    /// Run the agent with a user message once.
288    #[tracing::instrument(skip_all, name = "agent.query_once")]
289    pub async fn query_once(
290        &mut self,
291        query: impl Into<String> + std::fmt::Debug,
292    ) -> Result<(), AgentError> {
293        self.run_agent(Some(query.into()), true).await
294    }
295
296    /// Run the agent with without user message. The agent will loop completions, make tool calls,
297    /// until no new messages are available.
298    #[tracing::instrument(skip_all, name = "agent.run")]
299    pub async fn run(&mut self) -> Result<(), AgentError> {
300        self.run_agent(None, false).await
301    }
302
303    /// Run the agent with without user message. The agent will loop completions, make tool calls,
304    /// until
305    #[tracing::instrument(skip_all, name = "agent.run_once")]
306    pub async fn run_once(&mut self) -> Result<(), AgentError> {
307        self.run_agent(None, true).await
308    }
309
310    /// Retrieve the message history of the agent
311    pub async fn history(&self) -> Vec<ChatMessage> {
312        self.context.history().await
313    }
314
315    async fn run_agent(
316        &mut self,
317        maybe_query: Option<String>,
318        just_once: bool,
319    ) -> Result<(), AgentError> {
320        if self.state.is_running() {
321            return Err(AgentError::AlreadyRunning);
322        }
323
324        if self.state.is_pending() {
325            if let Some(system_prompt) = &self.system_prompt {
326                self.context
327                    .add_messages(vec![ChatMessage::System(
328                        system_prompt
329                            .render()
330                            .map_err(AgentError::FailedToRenderSystemPrompt)?,
331                    )])
332                    .await;
333            }
334
335            invoke_hooks!(BeforeAll, self);
336
337            self.load_toolboxes().await?;
338        }
339
340        invoke_hooks!(OnStart, self);
341
342        self.state = state::State::Running;
343
344        if let Some(query) = maybe_query {
345            self.context.add_message(ChatMessage::User(query)).await;
346        }
347
348        let mut loop_counter = 0;
349
350        while let Some(messages) = self.context.next_completion().await {
351            if let Some(limit) = self.limit {
352                if loop_counter >= limit {
353                    tracing::warn!("Agent loop limit reached");
354                    break;
355                }
356            }
357            let result = self.run_completions(&messages).await;
358
359            if let Err(err) = result {
360                self.stop_with_error(&err).await;
361                tracing::error!(error = ?err, "Agent stopped with error {err}");
362                return Err(err);
363            }
364
365            if just_once || self.state.is_stopped() {
366                break;
367            }
368            loop_counter += 1;
369        }
370
371        // If there are no new messages, ensure we update our state
372        self.stop(StopReason::NoNewMessages).await;
373
374        Ok(())
375    }
376
377    #[tracing::instrument(skip_all, err)]
378    async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<(), AgentError> {
379        debug!(
380            "Running completion for agent with {} messages",
381            messages.len()
382        );
383
384        let mut chat_completion_request = ChatCompletionRequest::builder()
385            .messages(messages)
386            .tools_spec(
387                self.tools
388                    .iter()
389                    .map(swiftide_core::Tool::tool_spec)
390                    .collect::<HashSet<_>>(),
391            )
392            .build()
393            .map_err(AgentError::FailedToBuildRequest)?;
394
395        invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
396
397        debug!(
398            "Calling LLM with the following new messages:\n {}",
399            self.context
400                .current_new_messages()
401                .await
402                .iter()
403                .map(ToString::to_string)
404                .collect::<Vec<_>>()
405                .join(",\n")
406        );
407
408        let mut response = self
409            .llm
410            .complete(&chat_completion_request)
411            .await
412            .map_err(AgentError::CompletionsFailed)?;
413
414        invoke_hooks!(AfterCompletion, self, &mut response);
415
416        self.add_message(ChatMessage::Assistant(
417            response.message,
418            response.tool_calls.clone(),
419        ))
420        .await?;
421
422        if let Some(tool_calls) = response.tool_calls {
423            self.invoke_tools(tool_calls).await?;
424        }
425
426        invoke_hooks!(AfterEach, self);
427
428        Ok(())
429    }
430
431    async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<(), AgentError> {
432        debug!("LLM returned tool calls: {:?}", tool_calls);
433
434        let mut handles = vec![];
435        for tool_call in tool_calls {
436            let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
437                tracing::warn!("Tool {} not found", tool_call.name());
438                continue;
439            };
440            tracing::info!("Calling tool `{}`", tool_call.name());
441
442            let tool_args = tool_call.args().map(String::from);
443            let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
444
445            invoke_hooks!(BeforeTool, self, &tool_call);
446
447            let tool_span = tracing::info_span!(
448                "tool",
449                "otel.name" = format!("tool.{}", tool.name().as_ref())
450            );
451
452            let handle = tokio::spawn(async move {
453                    let tool_args = ArgPreprocessor::preprocess(tool_args.as_deref());
454                    let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
455
456                    tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name = tool.name().as_ref(), "Completed tool call");
457
458                    Ok(output)
459                }.instrument(tool_span.or_current()));
460
461            handles.push((handle, tool_call));
462        }
463
464        for (handle, tool_call) in handles {
465            let mut output = handle.await.map_err(AgentError::ToolFailedToJoin)?;
466
467            invoke_hooks!(AfterTool, self, &tool_call, &mut output);
468
469            if let Err(error) = output {
470                let stop = self.tool_calls_over_limit(&tool_call);
471                if stop {
472                    tracing::error!(
473                        ?error,
474                        "Tool call failed, retry limit reached, stopping agent: {error}",
475                    );
476                } else {
477                    tracing::warn!(
478                        ?error,
479                        tool_call = ?tool_call,
480                        "Tool call failed, retrying",
481                    );
482                }
483                self.add_message(ChatMessage::ToolOutput(
484                    tool_call.clone(),
485                    ToolOutput::Fail(error.to_string()),
486                ))
487                .await?;
488                if stop {
489                    self.stop(StopReason::ToolCallsOverLimit(tool_call)).await;
490                    return Err(error.into());
491                }
492                continue;
493            }
494
495            let output = output?;
496            self.handle_control_tools(&tool_call, &output).await;
497            self.add_message(ChatMessage::ToolOutput(tool_call, output))
498                .await?;
499        }
500
501        Ok(())
502    }
503
504    fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
505        self.hooks
506            .iter()
507            .filter(|h| hook_type == (*h).into())
508            .collect()
509    }
510
511    fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
512        self.tools
513            .iter()
514            .find(|tool| tool.name() == tool_name)
515            .cloned()
516    }
517
518    // Handle any tool specific output (e.g. stop)
519    async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
520        if let ToolOutput::Stop = output {
521            tracing::warn!("Stop tool called, stopping agent");
522            self.stop(StopReason::RequestedByTool(tool_call.clone()))
523                .await;
524        }
525    }
526
527    fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
528        let mut s = DefaultHasher::new();
529        tool_call.hash(&mut s);
530        let hash = s.finish();
531
532        if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
533            let val = *retries >= self.tool_retry_limit;
534            *retries += 1;
535            val
536        } else {
537            self.tool_retries_counter.insert(hash, 1);
538            false
539        }
540    }
541
542    /// Add a message to the agent's context
543    ///
544    /// This will trigger a `OnNewMessage` hook if its present.
545    ///
546    /// If you want to add a message without triggering the hook, use the context directly.
547    #[tracing::instrument(skip_all, fields(message = message.to_string()))]
548    pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
549        invoke_hooks!(OnNewMessage, self, &mut message);
550
551        self.context.add_message(message).await;
552        Ok(())
553    }
554
555    /// Tell the agent to stop. It will finish it's current loop and then stop.
556    pub async fn stop(&mut self, reason: impl Into<StopReason>) {
557        if self.state.is_stopped() {
558            return;
559        }
560        let reason = reason.into();
561        invoke_hooks!(OnStop, self, reason.clone(), None);
562
563        self.state = state::State::Stopped(reason);
564    }
565
566    pub async fn stop_with_error(&mut self, error: &AgentError) {
567        if self.state.is_stopped() {
568            return;
569        }
570        invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
571
572        self.state = state::State::Stopped(StopReason::Error);
573    }
574
575    /// Access the agent's context
576    pub fn context(&self) -> &dyn AgentContext {
577        &self.context
578    }
579
    /// Returns `true` when the agent's state reports that it is running.
    pub fn is_running(&self) -> bool {
        self.state.is_running()
    }
584
    /// Returns `true` when the agent's state reports that it is stopped.
    pub fn is_stopped(&self) -> bool {
        self.state.is_stopped()
    }
589
    /// Returns `true` when the agent's state reports that it is pending, i.e. it has not (ever)
    /// started.
    pub fn is_pending(&self) -> bool {
        self.state.is_pending()
    }
594
    /// Borrow the set of tools available to the agent.
    ///
    /// After the toolboxes are loaded, this also contains the tools they provided.
    fn tools(&self) -> &HashSet<Box<dyn Tool>> {
        &self.tools
    }
599
600    async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
601        for toolbox in &self.toolboxes {
602            let tools = toolbox
603                .available_tools()
604                .await
605                .map_err(AgentError::ToolBoxFailedToLoad)?;
606            self.toolbox_tools.extend(tools);
607        }
608
609        self.tools.extend(self.toolbox_tools.clone());
610
611        Ok(())
612    }
613}
614
615#[cfg(test)]
616mod tests {
617
618    use serde::ser::Error;
619    use swiftide_core::chat_completion::errors::ToolError;
620    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
621    use swiftide_core::test_utils::MockChatCompletion;
622
623    use super::*;
624    use crate::{
625        assistant, chat_request, chat_response, summary, system, tool_failed, tool_output, user,
626    };
627
628    use crate::test_utils::{MockHook, MockTool};
629
    /// Verifies the builder defaults: the `stop` tool is always present, duplicate tools are
    /// collapsed, and a freshly built agent has an empty history.
    #[test_log::test(tokio::test)]
    async fn test_agent_builder_defaults() {
        // Create a mock llm; no completions are expected in this test
        let mock_llm = MockChatCompletion::new();

        // Build the agent
        let agent = Agent::builder().llm(&mock_llm).build().unwrap();

        // No context was configured, so the builder default is used

        // Check that the default tools are added
        assert!(agent.find_tool_by_name("stop").is_some());

        // Check it does not allow duplicates
        let agent = Agent::builder()
            .tools([Stop::default(), Stop::default()])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 1);

        // It should include the default tool if a different tool is provided
        let agent = Agent::builder()
            .tools([MockTool::new("mock_tool")])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 2);
        assert!(agent.find_tool_by_name("mock_tool").is_some());
        assert!(agent.find_tool_by_name("stop").is_some());

        // The agent never ran, so nothing was added to the history
        assert!(agent.context().history().await.is_empty());
    }
665
    /// Runs a full loop: the first completion calls `mock_tool`, the second completion receives
    /// the tool output and calls `stop`.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_calling_loop() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        // First request: only the user message (the system prompt is disabled below)
        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        // Second request: the history now includes the assistant message and the tool output
        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }
711
    /// Runs a single round with `query_once`: one completion is expected, which calls the mock
    /// tool; no second completion is set up.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        // NOTE(review): presumably the default mock tool is named "mock_tool", matching the tool
        // call in the response below - confirm in the test utils.
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }
743
    /// Same as the run-once test, but the mock tool is provided through a toolbox instead of
    /// being set directly with `tools`.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_via_toolbox_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        // The expected request lists the toolbox tool among the available tools
        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        // A `Vec` of boxed tools is used as the toolbox here
        let mut agent = Agent::builder()
            .add_toolbox(vec![mock_tool.boxed()])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }
775
776    #[test_log::test(tokio::test(flavor = "multi_thread"))]
777    async fn test_multiple_tool_calls() {
778        let prompt = "Write a poem";
779        let mock_llm = MockChatCompletion::new();
780        let mock_tool = MockTool::new("mock_tool1");
781        let mock_tool2 = MockTool::new("mock_tool2");
782
783        let chat_request = chat_request! {
784            system!("My system prompt"),
785            user!("Write a poem");
786
787
788
789            tools = [mock_tool.clone(), mock_tool2.clone()]
790        };
791
792        let mock_tool_response = chat_response! {
793            "Roses are red";
794
795            tool_calls = ["mock_tool1", "mock_tool2"]
796
797        };
798
799        dbg!(&chat_request);
800        mock_tool.expect_invoke_ok("Great!".into(), None);
801        mock_tool2.expect_invoke_ok("Great!".into(), None);
802        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
803
804        let chat_request = chat_request! {
805            system!("My system prompt"),
806            user!("Write a poem"),
807            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
808            tool_output!("mock_tool1", "Great!"),
809            tool_output!("mock_tool2", "Great!");
810
811            tools = [mock_tool.clone(), mock_tool2.clone()]
812        };
813
814        let mock_tool_response = chat_response! {
815            "Ok!";
816
817            tool_calls = ["stop"]
818
819        };
820
821        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
822
823        let mut agent = Agent::builder()
824            .tools([mock_tool, mock_tool2])
825            .system_prompt("My system prompt")
826            .llm(&mock_llm)
827            .build()
828            .unwrap();
829
830        agent.query(prompt).await.unwrap();
831    }
832
    /// Checks the state transitions of a single run: pending before the first run, stopped
    /// afterwards.
    #[test_log::test(tokio::test)]
    async fn test_agent_state_machine() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = []
        };
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        // Agent has never run and is pending
        assert!(agent.state.is_pending());
        agent.query_once(prompt).await.unwrap();

        // Agent is stopped, there might be more messages
        assert!(agent.state.is_stopped());
    }
861
    /// Checks how summaries affect the next completion request: after a summary is added, the
    /// expected request contains only the system prompt, the latest summary and the new user
    /// message.
    #[test_log::test(tokio::test)]
    async fn test_summary() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        // The same response (no tool calls) is reused for all three completions
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []

        };

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = []
        };

        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        let mut agent = Agent::builder()
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();

        // Added through the context directly, so no `OnNewMessage` hooks are involved
        agent
            .context
            .add_message(ChatMessage::new_summary("Summary"))
            .await;

        // The first user message and its response are not part of the expected request anymore
        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary"),
            user!("Write another poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        agent.query_once("Write another poem").await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary 2"))
            .await;

        // Only the latest summary is expected, the first one is gone as well
        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary 2"),
            user!("Write a third poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));

        agent.query_once("Write a third poem").await.unwrap();
    }
920
921    #[test_log::test(tokio::test)]
922    async fn test_agent_hooks() {
923        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
924        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
925        let mock_before_completion = MockHook::new("before_completion")
926            .expect_calls(2)
927            .to_owned();
928        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
929        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
930        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
931        let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();
932
933        // Once for mock tool and once for stop
934        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
935        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();
936
937        let prompt = "Write a poem";
938        let mock_llm = MockChatCompletion::new();
939        let mock_tool = MockTool::default();
940
941        let chat_request = chat_request! {
942            user!("Write a poem");
943
944            tools = [mock_tool.clone()]
945        };
946
947        let mock_tool_response = chat_response! {
948            "Roses are red";
949            tool_calls = ["mock_tool"]
950
951        };
952
953        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
954
955        let chat_request = chat_request! {
956            user!("Write a poem"),
957            assistant!("Roses are red", ["mock_tool"]),
958            tool_output!("mock_tool", "Great!");
959
960            tools = [mock_tool.clone()]
961        };
962
963        let stop_response = chat_response! {
964            "Roses are red";
965            tool_calls = ["stop"]
966        };
967
968        mock_llm.expect_complete(chat_request, Ok(stop_response));
969        mock_tool.expect_invoke_ok("Great!".into(), None);
970
971        let mut agent = Agent::builder()
972            .tools([mock_tool])
973            .llm(&mock_llm)
974            .no_system_prompt()
975            .before_all(mock_before_all.hook_fn())
976            .on_start(mock_on_start_fn.on_start_fn())
977            .before_completion(mock_before_completion.before_completion_fn())
978            .before_tool(mock_before_tool.before_tool_fn())
979            .after_completion(mock_after_completion.after_completion_fn())
980            .after_tool(mock_after_tool.after_tool_fn())
981            .after_each(mock_after_each.hook_fn())
982            .on_new_message(mock_on_message.message_hook_fn())
983            .on_stop(mock_on_stop.stop_hook_fn())
984            .build()
985            .unwrap();
986
987        agent.query(prompt).await.unwrap();
988    }
989
990    #[test_log::test(tokio::test)]
991    async fn test_agent_loop_limit() {
992        let prompt = "Generate content"; // Example prompt
993        let mock_llm = MockChatCompletion::new();
994        let mock_tool = MockTool::new("mock_tool");
995
996        let chat_request = chat_request! {
997            user!(prompt);
998            tools = [mock_tool.clone()]
999        };
1000        mock_tool.expect_invoke_ok("Great!".into(), None);
1001
1002        let mock_tool_response = chat_response! {
1003            "Some response";
1004            tool_calls = ["mock_tool"]
1005        };
1006
1007        // Set expectations for the mock LLM responses
1008        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));
1009
1010        // // Response for terminating the loop
1011        let stop_response = chat_response! {
1012            "Final response";
1013            tool_calls = ["stop"]
1014        };
1015
1016        mock_llm.expect_complete(chat_request, Ok(stop_response));
1017
1018        let mut agent = Agent::builder()
1019            .tools([mock_tool])
1020            .llm(&mock_llm)
1021            .no_system_prompt()
1022            .limit(1) // Setting the loop limit to 1
1023            .build()
1024            .unwrap();
1025
1026        // Run the agent
1027        agent.query(prompt).await.unwrap();
1028
1029        // Assert that the remaining message is still in the queue
1030        let remaining = mock_llm.expectations.lock().unwrap().pop();
1031        assert!(remaining.is_some());
1032
1033        // Assert that the agent is stopped after reaching the loop limit
1034        assert!(agent.is_stopped());
1035    }
1036
1037    #[test_log::test(tokio::test)]
1038    async fn test_tool_retry_mechanism() {
1039        let prompt = "Execute my tool";
1040        let mock_llm = MockChatCompletion::new();
1041        let mock_tool = MockTool::new("retry_tool");
1042
1043        // Configure mock tool to fail twice. First time is fed back to the LLM, second time is an
1044        // error
1045        mock_tool.expect_invoke(
1046            Err(ToolError::WrongArguments(serde_json::Error::custom(
1047                "missing `query`",
1048            ))),
1049            None,
1050        );
1051        mock_tool.expect_invoke(
1052            Err(ToolError::WrongArguments(serde_json::Error::custom(
1053                "missing `query`",
1054            ))),
1055            None,
1056        );
1057
1058        let chat_request = chat_request! {
1059            user!(prompt);
1060            tools = [mock_tool.clone()]
1061        };
1062        let retry_response = chat_response! {
1063            "First failing attempt";
1064            tool_calls = ["retry_tool"]
1065        };
1066        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
1067
1068        let chat_request = chat_request! {
1069            user!(prompt),
1070            assistant!("First failing attempt", ["retry_tool"]),
1071            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
1072
1073            tools = [mock_tool.clone()]
1074        };
1075        let will_fail_response = chat_response! {
1076            "Finished execution";
1077            tool_calls = ["retry_tool"]
1078        };
1079        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
1080
1081        let mut agent = Agent::builder()
1082            .tools([mock_tool])
1083            .llm(&mock_llm)
1084            .no_system_prompt()
1085            .tool_retry_limit(1) // The test relies on a limit of 2 retries.
1086            .build()
1087            .unwrap();
1088
1089        // Run the agent
1090        let result = agent.query(prompt).await;
1091
1092        assert!(result.is_err());
1093        assert!(result.unwrap_err().to_string().contains("missing `query`"));
1094        assert!(agent.is_stopped());
1095    }
1096}