swiftide_agents/
agent.rs

1#![allow(dead_code)]
2use crate::{
3    default_context::DefaultContext,
4    errors::AgentError,
5    hooks::{
6        AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7        Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn,
8    },
9    invoke_hooks,
10    state::{self, StopReason},
11    system_prompt::SystemPrompt,
12    tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15    collections::{HashMap, HashSet},
16    hash::{DefaultHasher, Hash as _, Hasher as _},
17    sync::Arc,
18};
19
20use derive_builder::Builder;
21use futures_util::stream::StreamExt;
22use swiftide_core::{
23    AgentContext, ToolBox,
24    chat_completion::{
25        ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
26    },
27    prompt::Prompt,
28};
29use tracing::{Instrument, debug};
30
31/// Agents are the main interface for building agentic systems.
32///
33/// Construct agents by calling the builder, setting an llm, configure hooks, tools and other
34/// customizations.
35///
36/// # Important defaults
37///
38/// - The default context is the `DefaultContext`, executing tools locally with the `LocalExecutor`.
39/// - A default `stop` tool is provided for agents to explicitly stop if needed
40/// - The default `SystemPrompt` instructs the agent with chain of thought and some common
41///   safeguards, but is otherwise quite bare. In a lot of cases this can be sufficient.
42#[derive(Clone, Builder)]
43pub struct Agent {
44    /// Hooks are functions that are called at specific points in the agent's lifecycle.
45    #[builder(default, setter(into))]
46    pub(crate) hooks: Vec<Hook>,
47    /// The context in which the agent operates, by default this is the `DefaultContext`.
48    #[builder(
49        setter(custom),
50        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
51    )]
52    pub(crate) context: Arc<dyn AgentContext>,
53    /// Tools the agent can use
54    #[builder(default = Agent::default_tools(), setter(custom))]
55    pub(crate) tools: HashSet<Box<dyn Tool>>,
56
57    /// Toolboxes are collections of tools that can be added to the agent.
58    ///
59    /// Toolboxes make their tools available to the agent at runtime.
60    #[builder(default)]
61    pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,
62
63    /// The language model that the agent uses for completion.
64    #[builder(setter(custom))]
65    pub(crate) llm: Box<dyn ChatCompletion>,
66
67    /// System prompt for the agent when it starts
68    ///
69    /// Some agents profit significantly from a tailored prompt. But it is not always needed.
70    ///
71    /// See [`SystemPrompt`] for an opiniated, customizable system prompt.
72    ///
73    /// Swiftide provides a default system prompt for all agents.
74    ///
75    /// # Example
76    ///
77    /// ```no_run
78    /// # use swiftide_agents::system_prompt::SystemPrompt;
79    /// # use swiftide_agents::Agent;
80    /// Agent::builder()
81    ///     .system_prompt(
82    ///         SystemPrompt::builder().role("You are an expert engineer")
83    ///         .build().unwrap())
84    ///     .build().unwrap();
85    /// ```
86    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
87    pub(crate) system_prompt: Option<Prompt>,
88
89    /// Initial state of the agent
90    #[builder(private, default = state::State::default())]
91    pub(crate) state: state::State,
92
93    /// Optional limit on the amount of loops the agent can run.
94    /// The counter is reset when the agent is stopped.
95    #[builder(default, setter(strip_option))]
96    pub(crate) limit: Option<usize>,
97
98    /// The maximum amount of times the failed output of a tool will be send
99    /// to an LLM before the agent stops. Defaults to 3.
100    ///
101    /// LLMs sometimes send missing arguments, or a tool might actually fail, but retrying could be
102    /// worth while. If the limit is not reached, the agent will send the formatted error back to
103    /// the LLM.
104    ///
105    /// The limit is hashed based on the tool call name and arguments, so the limit is per tool
106    /// call.
107    ///
108    /// This limit is _not_ reset when the agent is stopped.
109    #[builder(default = 3)]
110    pub(crate) tool_retry_limit: usize,
111
112    /// Enables streaming the chat completion responses for the agent.
113    #[builder(default)]
114    pub(crate) streaming: bool,
115
116    /// Internally tracks the amount of times a tool has been retried. The key is a hash based on
117    /// the name and args of the tool.
118    #[builder(private, default)]
119    pub(crate) tool_retries_counter: HashMap<u64, usize>,
120
121    /// Tools loaded from toolboxes
122    #[builder(private, default)]
123    pub(crate) toolbox_tools: HashSet<Box<dyn Tool>>,
124}
125
126impl std::fmt::Debug for Agent {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        f.debug_struct("Agent")
129            // display hooks as a list of type: number of hooks
130            .field(
131                "hooks",
132                &self
133                    .hooks
134                    .iter()
135                    .map(std::string::ToString::to_string)
136                    .collect::<Vec<_>>(),
137            )
138            .field(
139                "tools",
140                &self
141                    .tools
142                    .iter()
143                    .map(swiftide_core::Tool::name)
144                    .collect::<Vec<_>>(),
145            )
146            .field("llm", &"Box<dyn ChatCompletion>")
147            .field("state", &self.state)
148            .finish()
149    }
150}
151
152impl AgentBuilder {
153    /// The context in which the agent operates, by default this is the `DefaultContext`.
154    pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
155    where
156        Self: Clone,
157    {
158        self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
159        self
160    }
161
162    /// Disable the system prompt.
163    pub fn no_system_prompt(&mut self) -> &mut Self {
164        self.system_prompt = Some(None);
165
166        self
167    }
168
169    /// Add a hook to the agent.
170    pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
171        let hooks = self.hooks.get_or_insert_with(Vec::new);
172        hooks.push(hook);
173
174        self
175    }
176
177    /// Add a hook that runs once, before all completions. Even if the agent is paused and resumed,
178    /// before all will not trigger again.
179    pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
180        self.add_hook(Hook::BeforeAll(Box::new(hook)))
181    }
182
183    /// Add a hook that runs once, when the agent starts. This hook also runs if the agent stopped
184    /// and then starts again. The hook runs after any `before_all` hooks and before the
185    /// `before_completion` hooks.
186    pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
187        self.add_hook(Hook::OnStart(Box::new(hook)))
188    }
189
190    /// Add a hook that runs when the agent receives a streaming response
191    ///
192    /// The response will always include both the current accumulated message and the delta
193    ///
194    /// This will set `self.streaming` to true, there is no need to set it manually for the default
195    /// behaviour.
196    pub fn on_stream(&mut self, hook: impl OnStreamFn + 'static) -> &mut Self {
197        self.streaming = Some(true);
198        self.add_hook(Hook::OnStream(Box::new(hook)))
199    }
200
201    /// Add a hook that runs before each completion.
202    pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
203        self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
204    }
205
206    /// Add a hook that runs after each tool. The `Result<ToolOutput, ToolError>` is provided
207    /// as mut, so the tool output can be fully modified.
208    ///
209    /// The `ToolOutput` also references the original `ToolCall`, allowing you to match at runtime
210    /// what tool to interact with.
211    pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
212        self.add_hook(Hook::AfterTool(Box::new(hook)))
213    }
214
215    /// Add a hook that runs before each tool. Yields an immutable reference to the `ToolCall`.
216    pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
217        self.add_hook(Hook::BeforeTool(Box::new(hook)))
218    }
219
220    /// Add a hook that runs after each completion, before tool invocation and/or new messages.
221    pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
222        self.add_hook(Hook::AfterCompletion(Box::new(hook)))
223    }
224
225    /// Add a hook that runs after each completion, after tool invocations, right before a new loop
226    /// might start
227    pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
228        self.add_hook(Hook::AfterEach(Box::new(hook)))
229    }
230
231    /// Add a hook that runs when a new message is added to the context. Note that each tool adds a
232    /// separate message.
233    pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
234        self.add_hook(Hook::OnNewMessage(Box::new(hook)))
235    }
236
237    pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
238        self.add_hook(Hook::OnStop(Box::new(hook)))
239    }
240
241    /// Set the LLM for the agent. An LLM must implement the `ChatCompletion` trait.
242    pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
243        let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
244
245        self.llm = Some(boxed);
246        self
247    }
248
249    /// Define the available tools for the agent. Tools must implement the `Tool` trait.
250    ///
251    /// See the [tool attribute macro](`swiftide_macros::tool`) and the [tool derive
252    /// macro](`swiftide_macros::Tool`) for easy ways to create (many) tools.
253    pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
254    where
255        TOOL: Into<Box<dyn Tool>>,
256    {
257        self.tools = Some(
258            tools
259                .into_iter()
260                .map(Into::into)
261                .chain(Agent::default_tools())
262                .collect(),
263        );
264        self
265    }
266
267    /// Add a toolbox to the agent. Toolboxes are collections of tools that can be added to the
268    /// to the agent. Available tools are evaluated at runtime, when the agent starts for the first
269    /// time.
270    ///
271    /// Agents can have many toolboxes.
272    pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
273        let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
274        toolboxes.push(Box::new(toolbox));
275
276        self
277    }
278}
279
280impl Agent {
281    /// Build a new agent
282    pub fn builder() -> AgentBuilder {
283        AgentBuilder::default()
284            .tools(Agent::default_tools())
285            .to_owned()
286    }
287
288    /// Default tools for the agent that it always includes
289    /// Right now this is the `stop` tool, which allows the agent to stop itself.
290    pub fn default_tools() -> HashSet<Box<dyn Tool>> {
291        HashSet::from([Stop::default().boxed()])
292    }
293
294    /// Run the agent with a user message. The agent will loop completions, make tool calls, until
295    /// no new messages are available.
296    ///
297    /// # Errors
298    ///
299    /// Errors if anything goes wrong, see `AgentError` for more details.
300    #[tracing::instrument(skip_all, name = "agent.query")]
301    pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
302        let query = query
303            .into()
304            .render()
305            .map_err(AgentError::FailedToRenderPrompt)?;
306        self.run_agent(Some(query), false).await
307    }
308
309    /// Run the agent with a user message once.
310    ///
311    /// # Errors
312    ///
313    /// Errors if anything goes wrong, see `AgentError` for more details.
314    #[tracing::instrument(skip_all, name = "agent.query_once")]
315    pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
316        let query = query
317            .into()
318            .render()
319            .map_err(AgentError::FailedToRenderPrompt)?;
320        self.run_agent(Some(query), true).await
321    }
322
323    /// Run the agent with without user message. The agent will loop completions, make tool calls,
324    /// until no new messages are available.
325    ///
326    /// # Errors
327    ///
328    /// Errors if anything goes wrong, see `AgentError` for more details.
329    #[tracing::instrument(skip_all, name = "agent.run")]
330    pub async fn run(&mut self) -> Result<(), AgentError> {
331        self.run_agent(None, false).await
332    }
333
334    /// Run the agent with without user message. The agent will loop completions, make tool calls,
335    /// until
336    ///
337    /// # Errors
338    ///
339    /// Errors if anything goes wrong, see `AgentError` for more details.
340    #[tracing::instrument(skip_all, name = "agent.run_once")]
341    pub async fn run_once(&mut self) -> Result<(), AgentError> {
342        self.run_agent(None, true).await
343    }
344
345    /// Retrieve the message history of the agent
346    ///
347    /// # Errors
348    ///
349    /// Error if the message history cannot be retrieved, e.g. if the context is not set up or a
350    /// connection fails
351    pub async fn history(&self) -> Result<Vec<ChatMessage>, AgentError> {
352        self.context
353            .history()
354            .await
355            .map_err(AgentError::MessageHistoryError)
356    }
357
358    async fn run_agent(
359        &mut self,
360        maybe_query: Option<String>,
361        just_once: bool,
362    ) -> Result<(), AgentError> {
363        if self.state.is_running() {
364            return Err(AgentError::AlreadyRunning);
365        }
366
367        if self.state.is_pending() {
368            if let Some(system_prompt) = &self.system_prompt {
369                self.context
370                    .add_messages(vec![ChatMessage::System(
371                        system_prompt
372                            .render()
373                            .map_err(AgentError::FailedToRenderSystemPrompt)?,
374                    )])
375                    .await
376                    .map_err(AgentError::MessageHistoryError)?;
377            }
378
379            invoke_hooks!(BeforeAll, self);
380
381            self.load_toolboxes().await?;
382        }
383
384        invoke_hooks!(OnStart, self);
385
386        self.state = state::State::Running;
387
388        if let Some(query) = maybe_query {
389            self.context
390                .add_message(ChatMessage::User(query))
391                .await
392                .map_err(AgentError::MessageHistoryError)?;
393        }
394
395        let mut loop_counter = 0;
396
397        while let Some(messages) = self
398            .context
399            .next_completion()
400            .await
401            .map_err(AgentError::MessageHistoryError)?
402        {
403            if let Some(limit) = self.limit {
404                if loop_counter >= limit {
405                    tracing::warn!("Agent loop limit reached");
406                    break;
407                }
408            }
409
410            // If the last message contains tool calls that have not been completed,
411            // run the tools first
412            if let Some(&ChatMessage::Assistant(.., Some(ref tool_calls))) =
413                maybe_tool_call_without_output(&messages)
414            {
415                tracing::debug!("Uncompleted tool calls found; invoking tools");
416                self.invoke_tools(tool_calls).await?;
417                // Move on to the next tick, so that the
418                continue;
419            }
420
421            let result = self.run_completions(&messages).await;
422
423            if let Err(err) = result {
424                self.stop_with_error(&err).await;
425                tracing::error!(error = ?err, "Agent stopped with error {err}");
426                return Err(err);
427            }
428
429            if just_once || self.state.is_stopped() {
430                break;
431            }
432            loop_counter += 1;
433        }
434
435        // If there are no new messages, ensure we update our state
436        self.stop(StopReason::NoNewMessages).await;
437
438        Ok(())
439    }
440
441    #[tracing::instrument(skip_all, err)]
442    async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<(), AgentError> {
443        debug!(
444            tools = ?self
445                .tools
446                .iter()
447                .map(|t| t.name())
448                .collect::<Vec<_>>()
449                ,
450            "Running completion for agent with {} new messages",
451            messages.len()
452        );
453
454        let mut chat_completion_request = ChatCompletionRequest::builder()
455            .messages(messages)
456            .tools_spec(
457                self.tools
458                    .iter()
459                    .map(swiftide_core::Tool::tool_spec)
460                    .collect::<HashSet<_>>(),
461            )
462            .build()
463            .map_err(AgentError::FailedToBuildRequest)?;
464
465        invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
466
467        debug!(
468            "Calling LLM with the following new messages:\n {}",
469            self.context
470                .current_new_messages()
471                .await
472                .map_err(AgentError::MessageHistoryError)?
473                .iter()
474                .map(ToString::to_string)
475                .collect::<Vec<_>>()
476                .join(",\n")
477        );
478
479        let mut response = if self.streaming {
480            let mut last_response = None;
481            let mut stream = self.llm.complete_stream(&chat_completion_request).await;
482
483            while let Some(response) = stream.next().await {
484                let response = response.map_err(AgentError::CompletionsFailed)?;
485                invoke_hooks!(OnStream, self, &response);
486                last_response = Some(response);
487            }
488            tracing::trace!(?last_response, "Streaming completed");
489            last_response.ok_or(AgentError::EmptyStream)
490        } else {
491            self.llm
492                .complete(&chat_completion_request)
493                .await
494                .map_err(AgentError::CompletionsFailed)
495        }?;
496
497        // The arg preprocessor helps avoid common llm errors.
498        // This must happen as early as possible
499        response
500            .tool_calls
501            .as_deref_mut()
502            .map(ArgPreprocessor::preprocess_tool_calls);
503
504        invoke_hooks!(AfterCompletion, self, &mut response);
505
506        self.add_message(ChatMessage::Assistant(
507            response.message,
508            response.tool_calls.clone(),
509        ))
510        .await?;
511
512        if let Some(tool_calls) = response.tool_calls {
513            self.invoke_tools(&tool_calls).await?;
514        }
515
516        invoke_hooks!(AfterEach, self);
517
518        Ok(())
519    }
520
521    async fn invoke_tools(&mut self, tool_calls: &[ToolCall]) -> Result<(), AgentError> {
522        tracing::debug!("LLM returned tool calls: {:?}", tool_calls);
523
524        let mut handles = vec![];
525        for tool_call in tool_calls {
526            let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
527                tracing::warn!("Tool {} not found", tool_call.name());
528                continue;
529            };
530            tracing::info!("Calling tool `{}`", tool_call.name());
531
532            // let tool_args = tool_call.args().map(String::from);
533            let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
534
535            invoke_hooks!(BeforeTool, self, &tool_call);
536
537            let tool_span = tracing::info_span!(
538                "tool",
539                "otel.name" = format!("tool.{}", tool.name().as_ref())
540            );
541
542            let handle_tool_call = tool_call.clone();
543            let handle = tokio::spawn(async move {
544                    let handle_tool_call = handle_tool_call;
545                    let output = tool.invoke(&*context, &handle_tool_call)
546                        .await
547                        .map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
548
549                    tracing::debug!(output = output.to_string(), args = ?handle_tool_call.args(), tool_name = tool.name().as_ref(), "Completed tool call");
550
551                    Ok(output)
552                }.instrument(tool_span.or_current()));
553
554            handles.push((handle, tool_call));
555        }
556
557        for (handle, tool_call) in handles {
558            let mut output = handle.await.map_err(AgentError::ToolFailedToJoin)?;
559
560            invoke_hooks!(AfterTool, self, &tool_call, &mut output);
561
562            if let Err(error) = output {
563                let stop = self.tool_calls_over_limit(tool_call);
564                if stop {
565                    tracing::error!(
566                        ?error,
567                        "Tool call failed, retry limit reached, stopping agent: {error}",
568                    );
569                } else {
570                    tracing::warn!(
571                        ?error,
572                        tool_call = ?tool_call,
573                        "Tool call failed, retrying",
574                    );
575                }
576                self.add_message(ChatMessage::ToolOutput(
577                    tool_call.clone(),
578                    ToolOutput::Fail(error.to_string()),
579                ))
580                .await?;
581                if stop {
582                    self.stop(StopReason::ToolCallsOverLimit(tool_call.to_owned()))
583                        .await;
584                    return Err(error.into());
585                }
586                continue;
587            }
588
589            let output = output?;
590            self.handle_control_tools(tool_call, &output).await;
591
592            // Feedback required leaves the tool call open
593            //
594            // It assumes a follow up invocation of the agent will have the feedback approved
595            if !output.is_feedback_required() {
596                self.add_message(ChatMessage::ToolOutput(tool_call.to_owned(), output))
597                    .await?;
598            }
599        }
600
601        Ok(())
602    }
603
604    fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
605        self.hooks
606            .iter()
607            .filter(|h| hook_type == (*h).into())
608            .collect()
609    }
610
611    fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
612        self.tools
613            .iter()
614            .find(|tool| tool.name() == tool_name)
615            .cloned()
616    }
617
618    // Handle any tool specific output (e.g. stop)
619    async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
620        match output {
621            ToolOutput::Stop => {
622                tracing::warn!("Stop tool called, stopping agent");
623                self.stop(StopReason::RequestedByTool(tool_call.clone()))
624                    .await;
625            }
626
627            ToolOutput::FeedbackRequired(maybe_payload) => {
628                tracing::warn!("Feedback required, stopping agent");
629                self.stop(StopReason::FeedbackRequired {
630                    tool_call: tool_call.clone(),
631                    payload: maybe_payload.clone(),
632                })
633                .await;
634            }
635            _ => (),
636        }
637    }
638
639    fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
640        let mut s = DefaultHasher::new();
641        tool_call.hash(&mut s);
642        let hash = s.finish();
643
644        if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
645            let val = *retries >= self.tool_retry_limit;
646            *retries += 1;
647            val
648        } else {
649            self.tool_retries_counter.insert(hash, 1);
650            false
651        }
652    }
653
654    /// Add a message to the agent's context
655    ///
656    /// This will trigger a `OnNewMessage` hook if its present.
657    ///
658    /// If you want to add a message without triggering the hook, use the context directly.
659    ///
660    /// # Errors
661    ///
662    /// Errors if the message cannot be added to the context. With the default in memory context
663    /// that is not supposed to happen.
664    #[tracing::instrument(skip_all, fields(message = message.to_string()))]
665    pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
666        invoke_hooks!(OnNewMessage, self, &mut message);
667
668        self.context
669            .add_message(message)
670            .await
671            .map_err(AgentError::MessageHistoryError)?;
672        Ok(())
673    }
674
675    /// Tell the agent to stop. It will finish it's current loop and then stop.
676    pub async fn stop(&mut self, reason: impl Into<StopReason>) {
677        if self.state.is_stopped() {
678            return;
679        }
680        let reason = reason.into();
681        invoke_hooks!(OnStop, self, reason.clone(), None);
682
683        self.state = state::State::Stopped(reason);
684    }
685
686    pub async fn stop_with_error(&mut self, error: &AgentError) {
687        if self.state.is_stopped() {
688            return;
689        }
690        invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
691
692        self.state = state::State::Stopped(StopReason::Error);
693    }
694
695    /// Access the agent's context
696    pub fn context(&self) -> &dyn AgentContext {
697        &self.context
698    }
699
700    /// The agent is still running
701    pub fn is_running(&self) -> bool {
702        self.state.is_running()
703    }
704
705    /// The agent stopped
706    pub fn is_stopped(&self) -> bool {
707        self.state.is_stopped()
708    }
709
710    /// The agent has not (ever) started
711    pub fn is_pending(&self) -> bool {
712        self.state.is_pending()
713    }
714
715    /// Get a list of tools available to the agent
716    pub fn tools(&self) -> &HashSet<Box<dyn Tool>> {
717        &self.tools
718    }
719
720    pub fn state(&self) -> &state::State {
721        &self.state
722    }
723
724    pub fn stop_reason(&self) -> Option<&StopReason> {
725        self.state.stop_reason()
726    }
727
728    async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
729        for toolbox in &self.toolboxes {
730            let tools = toolbox
731                .available_tools()
732                .await
733                .map_err(AgentError::ToolBoxFailedToLoad)?;
734            self.toolbox_tools.extend(tools);
735        }
736
737        self.tools.extend(self.toolbox_tools.clone());
738
739        Ok(())
740    }
741}
742
743/// Reverse searches through messages, if it encounters a tool call before encountering an output,
744/// it will return the chat message with the tool calls, otherwise it returns None
745fn maybe_tool_call_without_output(messages: &[ChatMessage]) -> Option<&ChatMessage> {
746    for message in messages.iter().rev() {
747        if let ChatMessage::ToolOutput(..) = message {
748            return None;
749        }
750
751        if let ChatMessage::Assistant(.., Some(tool_calls)) = message {
752            if !tool_calls.is_empty() {
753                return Some(message);
754            }
755        }
756    }
757
758    None
759}
760
761#[cfg(test)]
762mod tests {
763
764    use serde::ser::Error;
765    use swiftide_core::ToolFeedback;
766    use swiftide_core::chat_completion::errors::ToolError;
767    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
768    use swiftide_core::test_utils::MockChatCompletion;
769
770    use super::*;
771    use crate::{
772        State, assistant, chat_request, chat_response, summary, system, tool_failed, tool_output,
773        user,
774    };
775
776    use crate::test_utils::{MockHook, MockTool};
777
778    #[test_log::test(tokio::test)]
779    async fn test_agent_builder_defaults() {
780        // Create a prompt
781        let mock_llm = MockChatCompletion::new();
782
783        // Build the agent
784        let agent = Agent::builder().llm(&mock_llm).build().unwrap();
785
786        // Check that the context is the default context
787
788        // Check that the default tools are added
789        assert!(agent.find_tool_by_name("stop").is_some());
790
791        // Check it does not allow duplicates
792        let agent = Agent::builder()
793            .tools([Stop::default(), Stop::default()])
794            .llm(&mock_llm)
795            .build()
796            .unwrap();
797
798        assert_eq!(agent.tools.len(), 1);
799
800        // It should include the default tool if a different tool is provided
801        let agent = Agent::builder()
802            .tools([MockTool::new("mock_tool")])
803            .llm(&mock_llm)
804            .build()
805            .unwrap();
806
807        assert_eq!(agent.tools.len(), 2);
808        assert!(agent.find_tool_by_name("mock_tool").is_some());
809        assert!(agent.find_tool_by_name("stop").is_some());
810
811        assert!(agent.context().history().await.unwrap().is_empty());
812    }
813
814    #[test_log::test(tokio::test)]
815    async fn test_agent_tool_calling_loop() {
816        let prompt = "Write a poem";
817        let mock_llm = MockChatCompletion::new();
818        let mock_tool = MockTool::new("mock_tool");
819
820        let chat_request = chat_request! {
821            user!("Write a poem");
822
823            tools = [mock_tool.clone()]
824        };
825
826        let mock_tool_response = chat_response! {
827            "Roses are red";
828            tool_calls = ["mock_tool"]
829
830        };
831
832        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
833
834        let chat_request = chat_request! {
835            user!("Write a poem"),
836            assistant!("Roses are red", ["mock_tool"]),
837            tool_output!("mock_tool", "Great!");
838
839            tools = [mock_tool.clone()]
840        };
841
842        let stop_response = chat_response! {
843            "Roses are red";
844            tool_calls = ["stop"]
845        };
846
847        mock_llm.expect_complete(chat_request, Ok(stop_response));
848        mock_tool.expect_invoke_ok("Great!".into(), None);
849
850        let mut agent = Agent::builder()
851            .tools([mock_tool])
852            .llm(&mock_llm)
853            .no_system_prompt()
854            .build()
855            .unwrap();
856
857        agent.query(prompt).await.unwrap();
858    }
859
860    #[test_log::test(tokio::test)]
861    async fn test_agent_tool_run_once() {
862        let prompt = "Write a poem";
863        let mock_llm = MockChatCompletion::new();
864        let mock_tool = MockTool::default();
865
866        let chat_request = chat_request! {
867            system!("My system prompt"),
868            user!("Write a poem");
869
870            tools = [mock_tool.clone()]
871        };
872
873        let mock_tool_response = chat_response! {
874            "Roses are red";
875            tool_calls = ["mock_tool"]
876
877        };
878
879        mock_tool.expect_invoke_ok("Great!".into(), None);
880        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
881
882        let mut agent = Agent::builder()
883            .tools([mock_tool])
884            .system_prompt("My system prompt")
885            .llm(&mock_llm)
886            .build()
887            .unwrap();
888
889        agent.query_once(prompt).await.unwrap();
890    }
891
892    #[test_log::test(tokio::test)]
893    async fn test_agent_tool_via_toolbox_run_once() {
894        let prompt = "Write a poem";
895        let mock_llm = MockChatCompletion::new();
896        let mock_tool = MockTool::default();
897
898        let chat_request = chat_request! {
899            system!("My system prompt"),
900            user!("Write a poem");
901
902            tools = [mock_tool.clone()]
903        };
904
905        let mock_tool_response = chat_response! {
906            "Roses are red";
907            tool_calls = ["mock_tool"]
908
909        };
910
911        mock_tool.expect_invoke_ok("Great!".into(), None);
912        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
913
914        let mut agent = Agent::builder()
915            .add_toolbox(vec![mock_tool.boxed()])
916            .system_prompt("My system prompt")
917            .llm(&mock_llm)
918            .build()
919            .unwrap();
920
921        agent.query_once(prompt).await.unwrap();
922    }
923
924    #[test_log::test(tokio::test(flavor = "multi_thread"))]
925    async fn test_multiple_tool_calls() {
926        let prompt = "Write a poem";
927        let mock_llm = MockChatCompletion::new();
928        let mock_tool = MockTool::new("mock_tool1");
929        let mock_tool2 = MockTool::new("mock_tool2");
930
931        let chat_request = chat_request! {
932            system!("My system prompt"),
933            user!("Write a poem");
934
935
936
937            tools = [mock_tool.clone(), mock_tool2.clone()]
938        };
939
940        let mock_tool_response = chat_response! {
941            "Roses are red";
942
943            tool_calls = ["mock_tool1", "mock_tool2"]
944
945        };
946
947        dbg!(&chat_request);
948        mock_tool.expect_invoke_ok("Great!".into(), None);
949        mock_tool2.expect_invoke_ok("Great!".into(), None);
950        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
951
952        let chat_request = chat_request! {
953            system!("My system prompt"),
954            user!("Write a poem"),
955            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
956            tool_output!("mock_tool1", "Great!"),
957            tool_output!("mock_tool2", "Great!");
958
959            tools = [mock_tool.clone(), mock_tool2.clone()]
960        };
961
962        let mock_tool_response = chat_response! {
963            "Ok!";
964
965            tool_calls = ["stop"]
966
967        };
968
969        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
970
971        let mut agent = Agent::builder()
972            .tools([mock_tool, mock_tool2])
973            .system_prompt("My system prompt")
974            .llm(&mock_llm)
975            .build()
976            .unwrap();
977
978        agent.query(prompt).await.unwrap();
979    }
980
981    #[test_log::test(tokio::test)]
982    async fn test_agent_state_machine() {
983        let prompt = "Write a poem";
984        let mock_llm = MockChatCompletion::new();
985
986        let chat_request = chat_request! {
987            user!("Write a poem");
988            tools = []
989        };
990        let mock_tool_response = chat_response! {
991            "Roses are red";
992            tool_calls = []
993        };
994
995        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
996        let mut agent = Agent::builder()
997            .llm(&mock_llm)
998            .no_system_prompt()
999            .build()
1000            .unwrap();
1001
1002        // Agent has never run and is pending
1003        assert!(agent.state.is_pending());
1004        agent.query_once(prompt).await.unwrap();
1005
1006        // Agent is stopped, there might be more messages
1007        assert!(agent.state.is_stopped());
1008    }
1009
1010    #[test_log::test(tokio::test)]
1011    async fn test_summary() {
1012        let prompt = "Write a poem";
1013        let mock_llm = MockChatCompletion::new();
1014
1015        let mock_tool_response = chat_response! {
1016            "Roses are red";
1017            tool_calls = []
1018
1019        };
1020
1021        let expected_chat_request = chat_request! {
1022            system!("My system prompt"),
1023            user!("Write a poem");
1024
1025            tools = []
1026        };
1027
1028        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1029
1030        let mut agent = Agent::builder()
1031            .system_prompt("My system prompt")
1032            .llm(&mock_llm)
1033            .build()
1034            .unwrap();
1035
1036        agent.query_once(prompt).await.unwrap();
1037
1038        agent
1039            .context
1040            .add_message(ChatMessage::new_summary("Summary"))
1041            .await
1042            .unwrap();
1043
1044        let expected_chat_request = chat_request! {
1045            system!("My system prompt"),
1046            summary!("Summary"),
1047            user!("Write another poem");
1048            tools = []
1049        };
1050        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1051
1052        agent.query_once("Write another poem").await.unwrap();
1053
1054        agent
1055            .context
1056            .add_message(ChatMessage::new_summary("Summary 2"))
1057            .await
1058            .unwrap();
1059
1060        let expected_chat_request = chat_request! {
1061            system!("My system prompt"),
1062            summary!("Summary 2"),
1063            user!("Write a third poem");
1064            tools = []
1065        };
1066        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));
1067
1068        agent.query_once("Write a third poem").await.unwrap();
1069    }
1070
1071    #[test_log::test(tokio::test)]
1072    async fn test_agent_hooks() {
1073        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
1074        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
1075        let mock_before_completion = MockHook::new("before_completion")
1076            .expect_calls(2)
1077            .to_owned();
1078        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
1079        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
1080        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
1081        let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();
1082
1083        // Once for mock tool and once for stop
1084        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
1085        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();
1086
1087        let prompt = "Write a poem";
1088        let mock_llm = MockChatCompletion::new();
1089        let mock_tool = MockTool::default();
1090
1091        let chat_request = chat_request! {
1092            user!("Write a poem");
1093
1094            tools = [mock_tool.clone()]
1095        };
1096
1097        let mock_tool_response = chat_response! {
1098            "Roses are red";
1099            tool_calls = ["mock_tool"]
1100
1101        };
1102
1103        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1104
1105        let chat_request = chat_request! {
1106            user!("Write a poem"),
1107            assistant!("Roses are red", ["mock_tool"]),
1108            tool_output!("mock_tool", "Great!");
1109
1110            tools = [mock_tool.clone()]
1111        };
1112
1113        let stop_response = chat_response! {
1114            "Roses are red";
1115            tool_calls = ["stop"]
1116        };
1117
1118        mock_llm.expect_complete(chat_request, Ok(stop_response));
1119        mock_tool.expect_invoke_ok("Great!".into(), None);
1120
1121        let mut agent = Agent::builder()
1122            .tools([mock_tool])
1123            .llm(&mock_llm)
1124            .no_system_prompt()
1125            .before_all(mock_before_all.hook_fn())
1126            .on_start(mock_on_start_fn.on_start_fn())
1127            .before_completion(mock_before_completion.before_completion_fn())
1128            .before_tool(mock_before_tool.before_tool_fn())
1129            .after_completion(mock_after_completion.after_completion_fn())
1130            .after_tool(mock_after_tool.after_tool_fn())
1131            .after_each(mock_after_each.hook_fn())
1132            .on_new_message(mock_on_message.message_hook_fn())
1133            .on_stop(mock_on_stop.stop_hook_fn())
1134            .build()
1135            .unwrap();
1136
1137        agent.query(prompt).await.unwrap();
1138    }
1139
1140    #[test_log::test(tokio::test)]
1141    async fn test_agent_loop_limit() {
1142        let prompt = "Generate content"; // Example prompt
1143        let mock_llm = MockChatCompletion::new();
1144        let mock_tool = MockTool::new("mock_tool");
1145
1146        let chat_request = chat_request! {
1147            user!(prompt);
1148            tools = [mock_tool.clone()]
1149        };
1150        mock_tool.expect_invoke_ok("Great!".into(), None);
1151
1152        let mock_tool_response = chat_response! {
1153            "Some response";
1154            tool_calls = ["mock_tool"]
1155        };
1156
1157        // Set expectations for the mock LLM responses
1158        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));
1159
1160        // // Response for terminating the loop
1161        let stop_response = chat_response! {
1162            "Final response";
1163            tool_calls = ["stop"]
1164        };
1165
1166        mock_llm.expect_complete(chat_request, Ok(stop_response));
1167
1168        let mut agent = Agent::builder()
1169            .tools([mock_tool])
1170            .llm(&mock_llm)
1171            .no_system_prompt()
1172            .limit(1) // Setting the loop limit to 1
1173            .build()
1174            .unwrap();
1175
1176        // Run the agent
1177        agent.query(prompt).await.unwrap();
1178
1179        // Assert that the remaining message is still in the queue
1180        let remaining = mock_llm.expectations.lock().unwrap().pop();
1181        assert!(remaining.is_some());
1182
1183        // Assert that the agent is stopped after reaching the loop limit
1184        assert!(agent.is_stopped());
1185    }
1186
1187    #[test_log::test(tokio::test)]
1188    async fn test_tool_retry_mechanism() {
1189        let prompt = "Execute my tool";
1190        let mock_llm = MockChatCompletion::new();
1191        let mock_tool = MockTool::new("retry_tool");
1192
1193        // Configure mock tool to fail twice. First time is fed back to the LLM, second time is an
1194        // error
1195        mock_tool.expect_invoke(
1196            Err(ToolError::WrongArguments(serde_json::Error::custom(
1197                "missing `query`",
1198            ))),
1199            None,
1200        );
1201        mock_tool.expect_invoke(
1202            Err(ToolError::WrongArguments(serde_json::Error::custom(
1203                "missing `query`",
1204            ))),
1205            None,
1206        );
1207
1208        let chat_request = chat_request! {
1209            user!(prompt);
1210            tools = [mock_tool.clone()]
1211        };
1212        let retry_response = chat_response! {
1213            "First failing attempt";
1214            tool_calls = ["retry_tool"]
1215        };
1216        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
1217
1218        let chat_request = chat_request! {
1219            user!(prompt),
1220            assistant!("First failing attempt", ["retry_tool"]),
1221            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
1222
1223            tools = [mock_tool.clone()]
1224        };
1225        let will_fail_response = chat_response! {
1226            "Finished execution";
1227            tool_calls = ["retry_tool"]
1228        };
1229        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
1230
1231        let mut agent = Agent::builder()
1232            .tools([mock_tool])
1233            .llm(&mock_llm)
1234            .no_system_prompt()
1235            .tool_retry_limit(1) // The test relies on a limit of 2 retries.
1236            .build()
1237            .unwrap();
1238
1239        // Run the agent
1240        let result = agent.query(prompt).await;
1241
1242        assert!(result.is_err());
1243        assert!(result.unwrap_err().to_string().contains("missing `query`"));
1244        assert!(agent.is_stopped());
1245    }
1246
1247    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1248    async fn test_streaming() {
1249        let prompt = "Generate content"; // Example prompt
1250        let mock_llm = MockChatCompletion::new();
1251        let on_stream_fn = MockHook::new("on_stream").expect_calls(3).to_owned();
1252
1253        let chat_request = chat_request! {
1254            user!(prompt);
1255
1256            tools = []
1257        };
1258
1259        let response = chat_response! {
1260            "one two three";
1261            tool_calls = ["stop"]
1262        };
1263
1264        // Set expectations for the mock LLM responses
1265        mock_llm.expect_complete(chat_request, Ok(response));
1266
1267        let mut agent = Agent::builder()
1268            .llm(&mock_llm)
1269            .on_stream(on_stream_fn.on_stream_fn())
1270            .no_system_prompt()
1271            .build()
1272            .unwrap();
1273
1274        // Run the agent
1275        agent.query(prompt).await.unwrap();
1276
1277        tracing::debug!("Agent finished running");
1278
1279        // Assert that the agent is stopped after reaching the loop limit
1280        assert!(agent.is_stopped());
1281    }
1282
1283    #[test_log::test(tokio::test)]
1284    async fn test_recovering_agent_existing_history() {
1285        // First, let's run an agent
1286        let prompt = "Write a poem";
1287        let mock_llm = MockChatCompletion::new();
1288        let mock_tool = MockTool::new("mock_tool");
1289
1290        let chat_request = chat_request! {
1291            user!("Write a poem");
1292
1293            tools = [mock_tool.clone()]
1294        };
1295
1296        let mock_tool_response = chat_response! {
1297            "Roses are red";
1298            tool_calls = ["mock_tool"]
1299
1300        };
1301
1302        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1303
1304        let chat_request = chat_request! {
1305            user!("Write a poem"),
1306            assistant!("Roses are red", ["mock_tool"]),
1307            tool_output!("mock_tool", "Great!");
1308
1309            tools = [mock_tool.clone()]
1310        };
1311
1312        let stop_response = chat_response! {
1313            "Roses are red";
1314            tool_calls = ["stop"]
1315        };
1316
1317        mock_llm.expect_complete(chat_request, Ok(stop_response));
1318        mock_tool.expect_invoke_ok("Great!".into(), None);
1319
1320        let mut agent = Agent::builder()
1321            .tools([mock_tool.clone()])
1322            .llm(&mock_llm)
1323            .no_system_prompt()
1324            .build()
1325            .unwrap();
1326
1327        agent.query(prompt).await.unwrap();
1328
1329        // Let's retrieve the history of the agent
1330        let history = agent.history().await.unwrap();
1331
1332        // Store it as a string somewhere
1333        let serialized = serde_json::to_string(&history).unwrap();
1334
1335        // Retrieve it
1336        let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();
1337
1338        // Build a context from the history
1339        let context = DefaultContext::default()
1340            .with_existing_messages(history)
1341            .await
1342            .unwrap()
1343            .to_owned();
1344
1345        let expected_chat_request = chat_request! {
1346            user!("Write a poem"),
1347            assistant!("Roses are red", ["mock_tool"]),
1348            tool_output!("mock_tool", "Great!"),
1349            assistant!("Roses are red", ["stop"]),
1350            tool_output!("stop", ToolOutput::Stop),
1351            user!("Try again!");
1352
1353            tools = [mock_tool.clone()]
1354        };
1355
1356        let stop_response = chat_response! {
1357            "Really stopping now";
1358            tool_calls = ["stop"]
1359        };
1360
1361        mock_llm.expect_complete(expected_chat_request, Ok(stop_response));
1362
1363        let mut agent = Agent::builder()
1364            .context(context)
1365            .tools([mock_tool])
1366            .llm(&mock_llm)
1367            .no_system_prompt()
1368            .build()
1369            .unwrap();
1370
1371        agent.query_once("Try again!").await.unwrap();
1372    }
1373
1374    #[test_log::test(tokio::test)]
1375    async fn test_agent_with_approval_required_tool() {
1376        use super::*;
1377        use crate::tools::control::ApprovalRequired;
1378        use crate::{assistant, chat_request, chat_response, user};
1379        use swiftide_core::chat_completion::ToolCall;
1380
1381        // Step 1: Build a tool that needs approval.
1382        let mock_tool = MockTool::default();
1383        mock_tool.expect_invoke_ok("Great!".into(), None);
1384
1385        let approval_tool = ApprovalRequired(mock_tool.boxed());
1386
1387        // Step 2: Set up the mock LLM.
1388        let mock_llm = MockChatCompletion::new();
1389
1390        let chat_req1 = chat_request! {
1391            user!("Request with approval");
1392            tools = [approval_tool.clone()]
1393        };
1394        let chat_resp1 = chat_response! {
1395            "Completion message";
1396            tool_calls = ["mock_tool"]
1397        };
1398        mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1399
1400        // The response will include the previous request, but no tool output
1401        // from the required tool
1402        let chat_req2 = chat_request! {
1403            user!("Request with approval"),
1404            assistant!("Completion message", ["mock_tool"]),
1405            tool_output!("mock_tool", "Great!");
1406            // Simulate feedback required output
1407            tools = [approval_tool.clone()]
1408        };
1409        let chat_resp2 = chat_response! {
1410            "Post-feedback message";
1411            tool_calls = ["stop"]
1412        };
1413        mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1414
1415        // Step 3: Wire up the agent.
1416        let mut agent = Agent::builder()
1417            .tools([approval_tool])
1418            .llm(&mock_llm)
1419            .no_system_prompt()
1420            .build()
1421            .unwrap();
1422
1423        // Step 4: Run agent to trigger approval.
1424        agent.query_once("Request with approval").await.unwrap();
1425
1426        assert!(matches!(
1427            agent.state,
1428            crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1429        ));
1430
1431        let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1432        else {
1433            panic!("Expected feedback required");
1434        };
1435
1436        // Step 5: Simulate feedback, run again and assert finish.
1437        agent
1438            .context
1439            .feedback_received(&tool_call, &ToolFeedback::approved())
1440            .await
1441            .unwrap();
1442
1443        tracing::debug!("running after approval");
1444        agent.run_once().await.unwrap();
1445        assert!(agent.is_stopped());
1446    }
1447
1448    #[test_log::test(tokio::test)]
1449    async fn test_agent_with_approval_required_tool_denied() {
1450        use super::*;
1451        use crate::tools::control::ApprovalRequired;
1452        use crate::{assistant, chat_request, chat_response, user};
1453        use swiftide_core::chat_completion::ToolCall;
1454
1455        // Step 1: Build a tool that needs approval.
1456        let mock_tool = MockTool::default();
1457
1458        let approval_tool = ApprovalRequired(mock_tool.boxed());
1459
1460        // Step 2: Set up the mock LLM.
1461        let mock_llm = MockChatCompletion::new();
1462
1463        let chat_req1 = chat_request! {
1464            user!("Request with approval");
1465            tools = [approval_tool.clone()]
1466        };
1467        let chat_resp1 = chat_response! {
1468            "Completion message";
1469            tool_calls = ["mock_tool"]
1470        };
1471        mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1472
1473        // The response will include the previous request, but no tool output
1474        // from the required tool
1475        let chat_req2 = chat_request! {
1476            user!("Request with approval"),
1477            assistant!("Completion message", ["mock_tool"]),
1478            tool_output!("mock_tool", "This tool call was refused");
1479            // Simulate feedback required output
1480            tools = [approval_tool.clone()]
1481        };
1482        let chat_resp2 = chat_response! {
1483            "Post-feedback message";
1484            tool_calls = ["stop"]
1485        };
1486        mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1487
1488        // Step 3: Wire up the agent.
1489        let mut agent = Agent::builder()
1490            .tools([approval_tool])
1491            .llm(&mock_llm)
1492            .no_system_prompt()
1493            .build()
1494            .unwrap();
1495
1496        // Step 4: Run agent to trigger approval.
1497        agent.query_once("Request with approval").await.unwrap();
1498
1499        assert!(matches!(
1500            agent.state,
1501            crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1502        ));
1503
1504        let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1505        else {
1506            panic!("Expected feedback required");
1507        };
1508
1509        // Step 5: Simulate feedback, run again and assert finish.
1510        agent
1511            .context
1512            .feedback_received(&tool_call, &ToolFeedback::refused())
1513            .await
1514            .unwrap();
1515
1516        tracing::debug!("running after approval");
1517        agent.run_once().await.unwrap();
1518
1519        let history = agent.context().history().await.unwrap();
1520        history
1521            .iter()
1522            .rfind(|m| {
1523                let ChatMessage::ToolOutput(.., ToolOutput::Text(msg)) = m else {
1524                    return false;
1525                };
1526                msg.contains("refused")
1527            })
1528            .expect("Could not find refusal message");
1529
1530        assert!(agent.is_stopped());
1531    }
1532}