swiftide_agents/
agent.rs

1#![allow(dead_code)]
2use crate::{
3    default_context::DefaultContext,
4    errors::AgentError,
5    hooks::{
6        AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7        Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn,
8    },
9    invoke_hooks,
10    state::{self, StopReason},
11    system_prompt::SystemPrompt,
12    tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15    collections::{HashMap, HashSet},
16    hash::{DefaultHasher, Hash as _, Hasher as _},
17    sync::Arc,
18};
19
20use derive_builder::Builder;
21use futures_util::stream::StreamExt;
22use swiftide_core::{
23    AgentContext, ToolBox,
24    chat_completion::{
25        ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
26    },
27    prompt::Prompt,
28};
29use tracing::{Instrument, debug};
30
31/// Agents are the main interface for building agentic systems.
32///
33/// Construct agents by calling the builder, setting an llm, configure hooks, tools and other
34/// customizations.
35///
36/// # Important defaults
37///
38/// - The default context is the `DefaultContext`, executing tools locally with the `LocalExecutor`.
39/// - A default `stop` tool is provided for agents to explicitly stop if needed
40/// - The default `SystemPrompt` instructs the agent with chain of thought and some common
41///   safeguards, but is otherwise quite bare. In a lot of cases this can be sufficient.
42#[derive(Clone, Builder)]
43pub struct Agent {
44    /// Hooks are functions that are called at specific points in the agent's lifecycle.
45    #[builder(default, setter(into))]
46    pub(crate) hooks: Vec<Hook>,
47    /// The context in which the agent operates, by default this is the `DefaultContext`.
48    #[builder(
49        setter(custom),
50        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
51    )]
52    pub(crate) context: Arc<dyn AgentContext>,
53    /// Tools the agent can use
54    #[builder(default = Agent::default_tools(), setter(custom))]
55    pub(crate) tools: HashSet<Box<dyn Tool>>,
56
57    /// Toolboxes are collections of tools that can be added to the agent.
58    ///
59    /// Toolboxes make their tools available to the agent at runtime.
60    #[builder(default)]
61    pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,
62
63    /// The language model that the agent uses for completion.
64    #[builder(setter(custom))]
65    pub(crate) llm: Box<dyn ChatCompletion>,
66
67    /// System prompt for the agent when it starts
68    ///
69    /// Some agents profit significantly from a tailored prompt. But it is not always needed.
70    ///
71    /// See [`SystemPrompt`] for an opiniated, customizable system prompt.
72    ///
73    /// Swiftide provides a default system prompt for all agents.
74    ///
75    /// # Example
76    ///
77    /// ```no_run
78    /// # use swiftide_agents::system_prompt::SystemPrompt;
79    /// # use swiftide_agents::Agent;
80    /// Agent::builder()
81    ///     .system_prompt(
82    ///         SystemPrompt::builder().role("You are an expert engineer")
83    ///         .build().unwrap())
84    ///     .build().unwrap();
85    /// ```
86    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
87    pub(crate) system_prompt: Option<Prompt>,
88
89    /// Initial state of the agent
90    #[builder(private, default = state::State::default())]
91    pub(crate) state: state::State,
92
93    /// Optional limit on the amount of loops the agent can run.
94    /// The counter is reset when the agent is stopped.
95    #[builder(default, setter(strip_option))]
96    pub(crate) limit: Option<usize>,
97
98    /// The maximum amount of times the failed output of a tool will be send
99    /// to an LLM before the agent stops. Defaults to 3.
100    ///
101    /// LLMs sometimes send missing arguments, or a tool might actually fail, but retrying could be
102    /// worth while. If the limit is not reached, the agent will send the formatted error back to
103    /// the LLM.
104    ///
105    /// The limit is hashed based on the tool call name and arguments, so the limit is per tool
106    /// call.
107    ///
108    /// This limit is _not_ reset when the agent is stopped.
109    #[builder(default = 3)]
110    pub(crate) tool_retry_limit: usize,
111
112    /// Enables streaming the chat completion responses for the agent.
113    #[builder(default)]
114    pub(crate) streaming: bool,
115
116    /// Internally tracks the amount of times a tool has been retried. The key is a hash based on
117    /// the name and args of the tool.
118    #[builder(private, default)]
119    pub(crate) tool_retries_counter: HashMap<u64, usize>,
120
121    /// Tools loaded from toolboxes
122    #[builder(private, default)]
123    pub(crate) toolbox_tools: HashSet<Box<dyn Tool>>,
124}
125
126impl std::fmt::Debug for Agent {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        f.debug_struct("Agent")
129            // display hooks as a list of type: number of hooks
130            .field(
131                "hooks",
132                &self
133                    .hooks
134                    .iter()
135                    .map(std::string::ToString::to_string)
136                    .collect::<Vec<_>>(),
137            )
138            .field(
139                "tools",
140                &self
141                    .tools
142                    .iter()
143                    .map(swiftide_core::Tool::name)
144                    .collect::<Vec<_>>(),
145            )
146            .field("llm", &"Box<dyn ChatCompletion>")
147            .field("state", &self.state)
148            .finish()
149    }
150}
151
152impl AgentBuilder {
153    /// The context in which the agent operates, by default this is the `DefaultContext`.
154    pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
155    where
156        Self: Clone,
157    {
158        self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
159        self
160    }
161
162    /// Disable the system prompt.
163    pub fn no_system_prompt(&mut self) -> &mut Self {
164        self.system_prompt = Some(None);
165
166        self
167    }
168
169    /// Add a hook to the agent.
170    pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
171        let hooks = self.hooks.get_or_insert_with(Vec::new);
172        hooks.push(hook);
173
174        self
175    }
176
177    /// Add a hook that runs once, before all completions. Even if the agent is paused and resumed,
178    /// before all will not trigger again.
179    pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
180        self.add_hook(Hook::BeforeAll(Box::new(hook)))
181    }
182
183    /// Add a hook that runs once, when the agent starts. This hook also runs if the agent stopped
184    /// and then starts again. The hook runs after any `before_all` hooks and before the
185    /// `before_completion` hooks.
186    pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
187        self.add_hook(Hook::OnStart(Box::new(hook)))
188    }
189
190    /// Add a hook that runs when the agent receives a streaming response
191    ///
192    /// The response will always include both the current accumulated message and the delta
193    ///
194    /// This will set `self.streaming` to true, there is no need to set it manually for the default
195    /// behaviour.
196    pub fn on_stream(&mut self, hook: impl OnStreamFn + 'static) -> &mut Self {
197        self.streaming = Some(true);
198        self.add_hook(Hook::OnStream(Box::new(hook)))
199    }
200
201    /// Add a hook that runs before each completion.
202    pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
203        self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
204    }
205
206    /// Add a hook that runs after each tool. The `Result<ToolOutput, ToolError>` is provided
207    /// as mut, so the tool output can be fully modified.
208    ///
209    /// The `ToolOutput` also references the original `ToolCall`, allowing you to match at runtime
210    /// what tool to interact with.
211    pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
212        self.add_hook(Hook::AfterTool(Box::new(hook)))
213    }
214
215    /// Add a hook that runs before each tool. Yields an immutable reference to the `ToolCall`.
216    pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
217        self.add_hook(Hook::BeforeTool(Box::new(hook)))
218    }
219
220    /// Add a hook that runs after each completion, before tool invocation and/or new messages.
221    pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
222        self.add_hook(Hook::AfterCompletion(Box::new(hook)))
223    }
224
225    /// Add a hook that runs after each completion, after tool invocations, right before a new loop
226    /// might start
227    pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
228        self.add_hook(Hook::AfterEach(Box::new(hook)))
229    }
230
231    /// Add a hook that runs when a new message is added to the context. Note that each tool adds a
232    /// separate message.
233    pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
234        self.add_hook(Hook::OnNewMessage(Box::new(hook)))
235    }
236
237    pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
238        self.add_hook(Hook::OnStop(Box::new(hook)))
239    }
240
241    /// Set the LLM for the agent. An LLM must implement the `ChatCompletion` trait.
242    pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
243        let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
244
245        self.llm = Some(boxed);
246        self
247    }
248
249    /// Define the available tools for the agent. Tools must implement the `Tool` trait.
250    ///
251    /// See the [tool attribute macro](`swiftide_macros::tool`) and the [tool derive
252    /// macro](`swiftide_macros::Tool`) for easy ways to create (many) tools.
253    pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
254    where
255        TOOL: Into<Box<dyn Tool>>,
256    {
257        self.tools = Some(
258            tools
259                .into_iter()
260                .map(Into::into)
261                .chain(Agent::default_tools())
262                .collect(),
263        );
264        self
265    }
266
267    /// Add a toolbox to the agent. Toolboxes are collections of tools that can be added to the
268    /// to the agent. Available tools are evaluated at runtime, when the agent starts for the first
269    /// time.
270    ///
271    /// Agents can have many toolboxes.
272    pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
273        let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
274        toolboxes.push(Box::new(toolbox));
275
276        self
277    }
278}
279
280impl Agent {
281    /// Build a new agent
282    pub fn builder() -> AgentBuilder {
283        AgentBuilder::default()
284    }
285}
286
287impl Agent {
288    /// Default tools for the agent that it always includes
289    fn default_tools() -> HashSet<Box<dyn Tool>> {
290        HashSet::from([Box::new(Stop::default()) as Box<dyn Tool>])
291    }
292
293    /// Run the agent with a user message. The agent will loop completions, make tool calls, until
294    /// no new messages are available.
295    #[tracing::instrument(skip_all, name = "agent.query")]
296    pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
297        let query = query
298            .into()
299            .render()
300            .map_err(AgentError::FailedToRenderPrompt)?;
301        self.run_agent(Some(query), false).await
302    }
303
304    /// Run the agent with a user message once.
305    #[tracing::instrument(skip_all, name = "agent.query_once")]
306    pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
307        let query = query
308            .into()
309            .render()
310            .map_err(AgentError::FailedToRenderPrompt)?;
311        self.run_agent(Some(query), true).await
312    }
313
314    /// Run the agent with without user message. The agent will loop completions, make tool calls,
315    /// until no new messages are available.
316    #[tracing::instrument(skip_all, name = "agent.run")]
317    pub async fn run(&mut self) -> Result<(), AgentError> {
318        self.run_agent(None, false).await
319    }
320
321    /// Run the agent with without user message. The agent will loop completions, make tool calls,
322    /// until
323    #[tracing::instrument(skip_all, name = "agent.run_once")]
324    pub async fn run_once(&mut self) -> Result<(), AgentError> {
325        self.run_agent(None, true).await
326    }
327
328    /// Retrieve the message history of the agent
329    pub async fn history(&self) -> Vec<ChatMessage> {
330        self.context.history().await
331    }
332
333    async fn run_agent(
334        &mut self,
335        maybe_query: Option<String>,
336        just_once: bool,
337    ) -> Result<(), AgentError> {
338        if self.state.is_running() {
339            return Err(AgentError::AlreadyRunning);
340        }
341
342        if self.state.is_pending() {
343            if let Some(system_prompt) = &self.system_prompt {
344                self.context
345                    .add_messages(vec![ChatMessage::System(
346                        system_prompt
347                            .render()
348                            .map_err(AgentError::FailedToRenderSystemPrompt)?,
349                    )])
350                    .await;
351            }
352
353            invoke_hooks!(BeforeAll, self);
354
355            self.load_toolboxes().await?;
356        }
357
358        invoke_hooks!(OnStart, self);
359
360        self.state = state::State::Running;
361
362        if let Some(query) = maybe_query {
363            self.context.add_message(ChatMessage::User(query)).await;
364        }
365
366        let mut loop_counter = 0;
367
368        while let Some(messages) = self.context.next_completion().await {
369            if let Some(limit) = self.limit {
370                if loop_counter >= limit {
371                    tracing::warn!("Agent loop limit reached");
372                    break;
373                }
374            }
375            let result = self.run_completions(&messages).await;
376
377            if let Err(err) = result {
378                self.stop_with_error(&err).await;
379                tracing::error!(error = ?err, "Agent stopped with error {err}");
380                return Err(err);
381            }
382
383            if just_once || self.state.is_stopped() {
384                break;
385            }
386            loop_counter += 1;
387        }
388
389        // If there are no new messages, ensure we update our state
390        self.stop(StopReason::NoNewMessages).await;
391
392        Ok(())
393    }
394
395    #[tracing::instrument(skip_all, err)]
396    async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<(), AgentError> {
397        debug!(
398            "Running completion for agent with {} messages",
399            messages.len()
400        );
401
402        let mut chat_completion_request = ChatCompletionRequest::builder()
403            .messages(messages)
404            .tools_spec(
405                self.tools
406                    .iter()
407                    .map(swiftide_core::Tool::tool_spec)
408                    .collect::<HashSet<_>>(),
409            )
410            .build()
411            .map_err(AgentError::FailedToBuildRequest)?;
412
413        invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
414
415        debug!(
416            "Calling LLM with the following new messages:\n {}",
417            self.context
418                .current_new_messages()
419                .await
420                .iter()
421                .map(ToString::to_string)
422                .collect::<Vec<_>>()
423                .join(",\n")
424        );
425
426        let mut response = if self.streaming {
427            let mut last_response = None;
428            let mut stream = self.llm.complete_stream(&chat_completion_request).await;
429
430            while let Some(response) = stream.next().await {
431                let response = response.map_err(AgentError::CompletionsFailed)?;
432                invoke_hooks!(OnStream, self, &response);
433                last_response = Some(response);
434            }
435            tracing::trace!(?last_response, "Streaming completed");
436            last_response.ok_or(AgentError::EmptyStream)
437        } else {
438            self.llm
439                .complete(&chat_completion_request)
440                .await
441                .map_err(AgentError::CompletionsFailed)
442        }?;
443
444        invoke_hooks!(AfterCompletion, self, &mut response);
445
446        self.add_message(ChatMessage::Assistant(
447            response.message,
448            response.tool_calls.clone(),
449        ))
450        .await?;
451
452        if let Some(tool_calls) = response.tool_calls {
453            self.invoke_tools(tool_calls).await?;
454        }
455
456        invoke_hooks!(AfterEach, self);
457
458        Ok(())
459    }
460
461    async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<(), AgentError> {
462        debug!("LLM returned tool calls: {:?}", tool_calls);
463
464        let mut handles = vec![];
465        for tool_call in tool_calls {
466            let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
467                tracing::warn!("Tool {} not found", tool_call.name());
468                continue;
469            };
470            tracing::info!("Calling tool `{}`", tool_call.name());
471
472            let tool_args = tool_call.args().map(String::from);
473            let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
474
475            invoke_hooks!(BeforeTool, self, &tool_call);
476
477            let tool_span = tracing::info_span!(
478                "tool",
479                "otel.name" = format!("tool.{}", tool.name().as_ref())
480            );
481
482            let handle = tokio::spawn(async move {
483                    let tool_args = ArgPreprocessor::preprocess(tool_args.as_deref());
484                    let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
485
486                    tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name = tool.name().as_ref(), "Completed tool call");
487
488                    Ok(output)
489                }.instrument(tool_span.or_current()));
490
491            handles.push((handle, tool_call));
492        }
493
494        for (handle, tool_call) in handles {
495            let mut output = handle.await.map_err(AgentError::ToolFailedToJoin)?;
496
497            invoke_hooks!(AfterTool, self, &tool_call, &mut output);
498
499            if let Err(error) = output {
500                let stop = self.tool_calls_over_limit(&tool_call);
501                if stop {
502                    tracing::error!(
503                        ?error,
504                        "Tool call failed, retry limit reached, stopping agent: {error}",
505                    );
506                } else {
507                    tracing::warn!(
508                        ?error,
509                        tool_call = ?tool_call,
510                        "Tool call failed, retrying",
511                    );
512                }
513                self.add_message(ChatMessage::ToolOutput(
514                    tool_call.clone(),
515                    ToolOutput::Fail(error.to_string()),
516                ))
517                .await?;
518                if stop {
519                    self.stop(StopReason::ToolCallsOverLimit(tool_call)).await;
520                    return Err(error.into());
521                }
522                continue;
523            }
524
525            let output = output?;
526            self.handle_control_tools(&tool_call, &output).await;
527            self.add_message(ChatMessage::ToolOutput(tool_call, output))
528                .await?;
529        }
530
531        Ok(())
532    }
533
534    fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
535        self.hooks
536            .iter()
537            .filter(|h| hook_type == (*h).into())
538            .collect()
539    }
540
541    fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
542        self.tools
543            .iter()
544            .find(|tool| tool.name() == tool_name)
545            .cloned()
546    }
547
548    // Handle any tool specific output (e.g. stop)
549    async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
550        if let ToolOutput::Stop = output {
551            tracing::warn!("Stop tool called, stopping agent");
552            self.stop(StopReason::RequestedByTool(tool_call.clone()))
553                .await;
554        }
555    }
556
557    fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
558        let mut s = DefaultHasher::new();
559        tool_call.hash(&mut s);
560        let hash = s.finish();
561
562        if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
563            let val = *retries >= self.tool_retry_limit;
564            *retries += 1;
565            val
566        } else {
567            self.tool_retries_counter.insert(hash, 1);
568            false
569        }
570    }
571
572    /// Add a message to the agent's context
573    ///
574    /// This will trigger a `OnNewMessage` hook if its present.
575    ///
576    /// If you want to add a message without triggering the hook, use the context directly.
577    #[tracing::instrument(skip_all, fields(message = message.to_string()))]
578    pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
579        invoke_hooks!(OnNewMessage, self, &mut message);
580
581        self.context.add_message(message).await;
582        Ok(())
583    }
584
585    /// Tell the agent to stop. It will finish it's current loop and then stop.
586    pub async fn stop(&mut self, reason: impl Into<StopReason>) {
587        if self.state.is_stopped() {
588            return;
589        }
590        let reason = reason.into();
591        invoke_hooks!(OnStop, self, reason.clone(), None);
592
593        self.state = state::State::Stopped(reason);
594    }
595
596    pub async fn stop_with_error(&mut self, error: &AgentError) {
597        if self.state.is_stopped() {
598            return;
599        }
600        invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
601
602        self.state = state::State::Stopped(StopReason::Error);
603    }
604
605    /// Access the agent's context
606    pub fn context(&self) -> &dyn AgentContext {
607        &self.context
608    }
609
610    /// The agent is still running
611    pub fn is_running(&self) -> bool {
612        self.state.is_running()
613    }
614
615    /// The agent stopped
616    pub fn is_stopped(&self) -> bool {
617        self.state.is_stopped()
618    }
619
620    /// The agent has not (ever) started
621    pub fn is_pending(&self) -> bool {
622        self.state.is_pending()
623    }
624
625    /// Get a list of tools available to the agent
626    pub fn tools(&self) -> &HashSet<Box<dyn Tool>> {
627        &self.tools
628    }
629
630    async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
631        for toolbox in &self.toolboxes {
632            let tools = toolbox
633                .available_tools()
634                .await
635                .map_err(AgentError::ToolBoxFailedToLoad)?;
636            self.toolbox_tools.extend(tools);
637        }
638
639        self.tools.extend(self.toolbox_tools.clone());
640
641        Ok(())
642    }
643}
644
645#[cfg(test)]
646mod tests {
647
648    use serde::ser::Error;
649    use swiftide_core::chat_completion::errors::ToolError;
650    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
651    use swiftide_core::test_utils::MockChatCompletion;
652
653    use super::*;
654    use crate::{
655        assistant, chat_request, chat_response, summary, system, tool_failed, tool_output, user,
656    };
657
658    use crate::test_utils::{MockHook, MockTool};
659
660    #[test_log::test(tokio::test)]
661    async fn test_agent_builder_defaults() {
662        // Create a prompt
663        let mock_llm = MockChatCompletion::new();
664
665        // Build the agent
666        let agent = Agent::builder().llm(&mock_llm).build().unwrap();
667
668        // Check that the context is the default context
669
670        // Check that the default tools are added
671        assert!(agent.find_tool_by_name("stop").is_some());
672
673        // Check it does not allow duplicates
674        let agent = Agent::builder()
675            .tools([Stop::default(), Stop::default()])
676            .llm(&mock_llm)
677            .build()
678            .unwrap();
679
680        assert_eq!(agent.tools.len(), 1);
681
682        // It should include the default tool if a different tool is provided
683        let agent = Agent::builder()
684            .tools([MockTool::new("mock_tool")])
685            .llm(&mock_llm)
686            .build()
687            .unwrap();
688
689        assert_eq!(agent.tools.len(), 2);
690        assert!(agent.find_tool_by_name("mock_tool").is_some());
691        assert!(agent.find_tool_by_name("stop").is_some());
692
693        assert!(agent.context().history().await.is_empty());
694    }
695
696    #[test_log::test(tokio::test)]
697    async fn test_agent_tool_calling_loop() {
698        let prompt = "Write a poem";
699        let mock_llm = MockChatCompletion::new();
700        let mock_tool = MockTool::new("mock_tool");
701
702        let chat_request = chat_request! {
703            user!("Write a poem");
704
705            tools = [mock_tool.clone()]
706        };
707
708        let mock_tool_response = chat_response! {
709            "Roses are red";
710            tool_calls = ["mock_tool"]
711
712        };
713
714        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
715
716        let chat_request = chat_request! {
717            user!("Write a poem"),
718            assistant!("Roses are red", ["mock_tool"]),
719            tool_output!("mock_tool", "Great!");
720
721            tools = [mock_tool.clone()]
722        };
723
724        let stop_response = chat_response! {
725            "Roses are red";
726            tool_calls = ["stop"]
727        };
728
729        mock_llm.expect_complete(chat_request, Ok(stop_response));
730        mock_tool.expect_invoke_ok("Great!".into(), None);
731
732        let mut agent = Agent::builder()
733            .tools([mock_tool])
734            .llm(&mock_llm)
735            .no_system_prompt()
736            .build()
737            .unwrap();
738
739        agent.query(prompt).await.unwrap();
740    }
741
742    #[test_log::test(tokio::test)]
743    async fn test_agent_tool_run_once() {
744        let prompt = "Write a poem";
745        let mock_llm = MockChatCompletion::new();
746        let mock_tool = MockTool::default();
747
748        let chat_request = chat_request! {
749            system!("My system prompt"),
750            user!("Write a poem");
751
752            tools = [mock_tool.clone()]
753        };
754
755        let mock_tool_response = chat_response! {
756            "Roses are red";
757            tool_calls = ["mock_tool"]
758
759        };
760
761        mock_tool.expect_invoke_ok("Great!".into(), None);
762        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
763
764        let mut agent = Agent::builder()
765            .tools([mock_tool])
766            .system_prompt("My system prompt")
767            .llm(&mock_llm)
768            .build()
769            .unwrap();
770
771        agent.query_once(prompt).await.unwrap();
772    }
773
774    #[test_log::test(tokio::test)]
775    async fn test_agent_tool_via_toolbox_run_once() {
776        let prompt = "Write a poem";
777        let mock_llm = MockChatCompletion::new();
778        let mock_tool = MockTool::default();
779
780        let chat_request = chat_request! {
781            system!("My system prompt"),
782            user!("Write a poem");
783
784            tools = [mock_tool.clone()]
785        };
786
787        let mock_tool_response = chat_response! {
788            "Roses are red";
789            tool_calls = ["mock_tool"]
790
791        };
792
793        mock_tool.expect_invoke_ok("Great!".into(), None);
794        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
795
796        let mut agent = Agent::builder()
797            .add_toolbox(vec![mock_tool.boxed()])
798            .system_prompt("My system prompt")
799            .llm(&mock_llm)
800            .build()
801            .unwrap();
802
803        agent.query_once(prompt).await.unwrap();
804    }
805
806    #[test_log::test(tokio::test(flavor = "multi_thread"))]
807    async fn test_multiple_tool_calls() {
808        let prompt = "Write a poem";
809        let mock_llm = MockChatCompletion::new();
810        let mock_tool = MockTool::new("mock_tool1");
811        let mock_tool2 = MockTool::new("mock_tool2");
812
813        let chat_request = chat_request! {
814            system!("My system prompt"),
815            user!("Write a poem");
816
817
818
819            tools = [mock_tool.clone(), mock_tool2.clone()]
820        };
821
822        let mock_tool_response = chat_response! {
823            "Roses are red";
824
825            tool_calls = ["mock_tool1", "mock_tool2"]
826
827        };
828
829        dbg!(&chat_request);
830        mock_tool.expect_invoke_ok("Great!".into(), None);
831        mock_tool2.expect_invoke_ok("Great!".into(), None);
832        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
833
834        let chat_request = chat_request! {
835            system!("My system prompt"),
836            user!("Write a poem"),
837            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
838            tool_output!("mock_tool1", "Great!"),
839            tool_output!("mock_tool2", "Great!");
840
841            tools = [mock_tool.clone(), mock_tool2.clone()]
842        };
843
844        let mock_tool_response = chat_response! {
845            "Ok!";
846
847            tool_calls = ["stop"]
848
849        };
850
851        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
852
853        let mut agent = Agent::builder()
854            .tools([mock_tool, mock_tool2])
855            .system_prompt("My system prompt")
856            .llm(&mock_llm)
857            .build()
858            .unwrap();
859
860        agent.query(prompt).await.unwrap();
861    }
862
863    #[test_log::test(tokio::test)]
864    async fn test_agent_state_machine() {
865        let prompt = "Write a poem";
866        let mock_llm = MockChatCompletion::new();
867
868        let chat_request = chat_request! {
869            user!("Write a poem");
870            tools = []
871        };
872        let mock_tool_response = chat_response! {
873            "Roses are red";
874            tool_calls = []
875        };
876
877        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
878        let mut agent = Agent::builder()
879            .llm(&mock_llm)
880            .no_system_prompt()
881            .build()
882            .unwrap();
883
884        // Agent has never run and is pending
885        assert!(agent.state.is_pending());
886        agent.query_once(prompt).await.unwrap();
887
888        // Agent is stopped, there might be more messages
889        assert!(agent.state.is_stopped());
890    }
891
892    #[test_log::test(tokio::test)]
893    async fn test_summary() {
894        let prompt = "Write a poem";
895        let mock_llm = MockChatCompletion::new();
896
897        let mock_tool_response = chat_response! {
898            "Roses are red";
899            tool_calls = []
900
901        };
902
903        let expected_chat_request = chat_request! {
904            system!("My system prompt"),
905            user!("Write a poem");
906
907            tools = []
908        };
909
910        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
911
912        let mut agent = Agent::builder()
913            .system_prompt("My system prompt")
914            .llm(&mock_llm)
915            .build()
916            .unwrap();
917
918        agent.query_once(prompt).await.unwrap();
919
920        agent
921            .context
922            .add_message(ChatMessage::new_summary("Summary"))
923            .await;
924
925        let expected_chat_request = chat_request! {
926            system!("My system prompt"),
927            summary!("Summary"),
928            user!("Write another poem");
929            tools = []
930        };
931        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
932
933        agent.query_once("Write another poem").await.unwrap();
934
935        agent
936            .context
937            .add_message(ChatMessage::new_summary("Summary 2"))
938            .await;
939
940        let expected_chat_request = chat_request! {
941            system!("My system prompt"),
942            summary!("Summary 2"),
943            user!("Write a third poem");
944            tools = []
945        };
946        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));
947
948        agent.query_once("Write a third poem").await.unwrap();
949    }
950
951    #[test_log::test(tokio::test)]
952    async fn test_agent_hooks() {
953        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
954        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
955        let mock_before_completion = MockHook::new("before_completion")
956            .expect_calls(2)
957            .to_owned();
958        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
959        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
960        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
961        let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();
962
963        // Once for mock tool and once for stop
964        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
965        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();
966
967        let prompt = "Write a poem";
968        let mock_llm = MockChatCompletion::new();
969        let mock_tool = MockTool::default();
970
971        let chat_request = chat_request! {
972            user!("Write a poem");
973
974            tools = [mock_tool.clone()]
975        };
976
977        let mock_tool_response = chat_response! {
978            "Roses are red";
979            tool_calls = ["mock_tool"]
980
981        };
982
983        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
984
985        let chat_request = chat_request! {
986            user!("Write a poem"),
987            assistant!("Roses are red", ["mock_tool"]),
988            tool_output!("mock_tool", "Great!");
989
990            tools = [mock_tool.clone()]
991        };
992
993        let stop_response = chat_response! {
994            "Roses are red";
995            tool_calls = ["stop"]
996        };
997
998        mock_llm.expect_complete(chat_request, Ok(stop_response));
999        mock_tool.expect_invoke_ok("Great!".into(), None);
1000
1001        let mut agent = Agent::builder()
1002            .tools([mock_tool])
1003            .llm(&mock_llm)
1004            .no_system_prompt()
1005            .before_all(mock_before_all.hook_fn())
1006            .on_start(mock_on_start_fn.on_start_fn())
1007            .before_completion(mock_before_completion.before_completion_fn())
1008            .before_tool(mock_before_tool.before_tool_fn())
1009            .after_completion(mock_after_completion.after_completion_fn())
1010            .after_tool(mock_after_tool.after_tool_fn())
1011            .after_each(mock_after_each.hook_fn())
1012            .on_new_message(mock_on_message.message_hook_fn())
1013            .on_stop(mock_on_stop.stop_hook_fn())
1014            .build()
1015            .unwrap();
1016
1017        agent.query(prompt).await.unwrap();
1018    }
1019
1020    #[test_log::test(tokio::test)]
1021    async fn test_agent_loop_limit() {
1022        let prompt = "Generate content"; // Example prompt
1023        let mock_llm = MockChatCompletion::new();
1024        let mock_tool = MockTool::new("mock_tool");
1025
1026        let chat_request = chat_request! {
1027            user!(prompt);
1028            tools = [mock_tool.clone()]
1029        };
1030        mock_tool.expect_invoke_ok("Great!".into(), None);
1031
1032        let mock_tool_response = chat_response! {
1033            "Some response";
1034            tool_calls = ["mock_tool"]
1035        };
1036
1037        // Set expectations for the mock LLM responses
1038        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));
1039
1040        // // Response for terminating the loop
1041        let stop_response = chat_response! {
1042            "Final response";
1043            tool_calls = ["stop"]
1044        };
1045
1046        mock_llm.expect_complete(chat_request, Ok(stop_response));
1047
1048        let mut agent = Agent::builder()
1049            .tools([mock_tool])
1050            .llm(&mock_llm)
1051            .no_system_prompt()
1052            .limit(1) // Setting the loop limit to 1
1053            .build()
1054            .unwrap();
1055
1056        // Run the agent
1057        agent.query(prompt).await.unwrap();
1058
1059        // Assert that the remaining message is still in the queue
1060        let remaining = mock_llm.expectations.lock().unwrap().pop();
1061        assert!(remaining.is_some());
1062
1063        // Assert that the agent is stopped after reaching the loop limit
1064        assert!(agent.is_stopped());
1065    }
1066
1067    #[test_log::test(tokio::test)]
1068    async fn test_tool_retry_mechanism() {
1069        let prompt = "Execute my tool";
1070        let mock_llm = MockChatCompletion::new();
1071        let mock_tool = MockTool::new("retry_tool");
1072
1073        // Configure mock tool to fail twice. First time is fed back to the LLM, second time is an
1074        // error
1075        mock_tool.expect_invoke(
1076            Err(ToolError::WrongArguments(serde_json::Error::custom(
1077                "missing `query`",
1078            ))),
1079            None,
1080        );
1081        mock_tool.expect_invoke(
1082            Err(ToolError::WrongArguments(serde_json::Error::custom(
1083                "missing `query`",
1084            ))),
1085            None,
1086        );
1087
1088        let chat_request = chat_request! {
1089            user!(prompt);
1090            tools = [mock_tool.clone()]
1091        };
1092        let retry_response = chat_response! {
1093            "First failing attempt";
1094            tool_calls = ["retry_tool"]
1095        };
1096        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
1097
1098        let chat_request = chat_request! {
1099            user!(prompt),
1100            assistant!("First failing attempt", ["retry_tool"]),
1101            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
1102
1103            tools = [mock_tool.clone()]
1104        };
1105        let will_fail_response = chat_response! {
1106            "Finished execution";
1107            tool_calls = ["retry_tool"]
1108        };
1109        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
1110
1111        let mut agent = Agent::builder()
1112            .tools([mock_tool])
1113            .llm(&mock_llm)
1114            .no_system_prompt()
1115            .tool_retry_limit(1) // The test relies on a limit of 2 retries.
1116            .build()
1117            .unwrap();
1118
1119        // Run the agent
1120        let result = agent.query(prompt).await;
1121
1122        assert!(result.is_err());
1123        assert!(result.unwrap_err().to_string().contains("missing `query`"));
1124        assert!(agent.is_stopped());
1125    }
1126
1127    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1128    async fn test_streaming() {
1129        let prompt = "Generate content"; // Example prompt
1130        let mock_llm = MockChatCompletion::new();
1131        let on_stream_fn = MockHook::new("on_stream").expect_calls(3).to_owned();
1132
1133        let chat_request = chat_request! {
1134            user!(prompt);
1135
1136            tools = []
1137        };
1138
1139        let response = chat_response! {
1140            "one two three";
1141            tool_calls = ["stop"]
1142        };
1143
1144        // Set expectations for the mock LLM responses
1145        mock_llm.expect_complete(chat_request, Ok(response));
1146
1147        let mut agent = Agent::builder()
1148            .llm(&mock_llm)
1149            .on_stream(on_stream_fn.on_stream_fn())
1150            .no_system_prompt()
1151            .build()
1152            .unwrap();
1153
1154        // Run the agent
1155        agent.query(prompt).await.unwrap();
1156
1157        tracing::debug!("Agent finished running");
1158
1159        // Assert that the agent is stopped after reaching the loop limit
1160        assert!(agent.is_stopped());
1161    }
1162
1163    #[test_log::test(tokio::test)]
1164    async fn test_recovering_agent_existing_history() {
1165        // First, let's run an agent
1166        let prompt = "Write a poem";
1167        let mock_llm = MockChatCompletion::new();
1168        let mock_tool = MockTool::new("mock_tool");
1169
1170        let chat_request = chat_request! {
1171            user!("Write a poem");
1172
1173            tools = [mock_tool.clone()]
1174        };
1175
1176        let mock_tool_response = chat_response! {
1177            "Roses are red";
1178            tool_calls = ["mock_tool"]
1179
1180        };
1181
1182        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1183
1184        let chat_request = chat_request! {
1185            user!("Write a poem"),
1186            assistant!("Roses are red", ["mock_tool"]),
1187            tool_output!("mock_tool", "Great!");
1188
1189            tools = [mock_tool.clone()]
1190        };
1191
1192        let stop_response = chat_response! {
1193            "Roses are red";
1194            tool_calls = ["stop"]
1195        };
1196
1197        mock_llm.expect_complete(chat_request, Ok(stop_response));
1198        mock_tool.expect_invoke_ok("Great!".into(), None);
1199
1200        let mut agent = Agent::builder()
1201            .tools([mock_tool.clone()])
1202            .llm(&mock_llm)
1203            .no_system_prompt()
1204            .build()
1205            .unwrap();
1206
1207        agent.query(prompt).await.unwrap();
1208
1209        // Let's retrieve the history of the agent
1210        let history = agent.history().await;
1211
1212        // Store it as a string somewhere
1213        let serialized = serde_json::to_string(&history).unwrap();
1214
1215        // Retrieve it
1216        let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();
1217
1218        // Build a context from the history
1219        let context = DefaultContext::default()
1220            .with_message_history(history)
1221            .to_owned();
1222
1223        let expected_chat_request = chat_request! {
1224            user!("Write a poem"),
1225            assistant!("Roses are red", ["mock_tool"]),
1226            tool_output!("mock_tool", "Great!"),
1227            assistant!("Roses are red", ["stop"]),
1228            tool_output!("stop", ToolOutput::Stop),
1229            user!("Try again!");
1230
1231            tools = [mock_tool.clone()]
1232        };
1233
1234        let stop_response = chat_response! {
1235            "Really stopping now";
1236            tool_calls = ["stop"]
1237        };
1238
1239        mock_llm.expect_complete(expected_chat_request, Ok(stop_response));
1240
1241        let mut agent = Agent::builder()
1242            .context(context)
1243            .tools([mock_tool])
1244            .llm(&mock_llm)
1245            .no_system_prompt()
1246            .build()
1247            .unwrap();
1248
1249        agent.query_once("Try again!").await.unwrap();
1250    }
1251}