swiftide_agents/
agent.rs

1#![allow(dead_code)]
2use crate::{
3    default_context::DefaultContext,
4    errors::AgentError,
5    hooks::{
6        AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7        Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn,
8    },
9    invoke_hooks,
10    state::{self, StopReason},
11    system_prompt::SystemPrompt,
12    tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15    collections::{HashMap, HashSet},
16    hash::{DefaultHasher, Hash as _, Hasher as _},
17    sync::Arc,
18};
19
20use derive_builder::Builder;
21use futures_util::stream::StreamExt;
22use swiftide_core::{
23    AgentContext, ToolBox,
24    chat_completion::{
25        ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
26    },
27    prompt::Prompt,
28};
29use tracing::{Instrument, debug};
30
/// Agents are the main interface for building agentic systems.
///
/// Construct agents by calling the builder, setting an llm, and configuring hooks, tools and
/// other customizations.
///
/// # Important defaults
///
/// - The default context is the `DefaultContext`, executing tools locally with the `LocalExecutor`.
/// - A default `stop` tool is provided for agents to explicitly stop if needed.
/// - The default `SystemPrompt` instructs the agent with chain of thought and some common
///   safeguards, but is otherwise quite bare. In a lot of cases this can be sufficient.
#[derive(Clone, Builder)]
pub struct Agent {
    /// Hooks are functions that are called at specific points in the agent's lifecycle.
    #[builder(default, setter(into))]
    pub(crate) hooks: Vec<Hook>,
    /// The context in which the agent operates, by default this is the `DefaultContext`.
    #[builder(
        setter(custom),
        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
    )]
    pub(crate) context: Arc<dyn AgentContext>,
    /// Tools the agent can use.
    ///
    /// Always contains the tools from [`Agent::default_tools`] (currently the `stop` tool), both
    /// by default and when tools are set through the builder.
    #[builder(default = Agent::default_tools(), setter(custom))]
    pub(crate) tools: HashSet<Box<dyn Tool>>,

    /// Toolboxes are collections of tools that can be added to the agent.
    ///
    /// Their tools are loaded once, the first time the agent runs, and are then made available
    /// alongside the regular `tools`.
    #[builder(default)]
    pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,

    /// The language model that the agent uses for completion.
    #[builder(setter(custom))]
    pub(crate) llm: Box<dyn ChatCompletion>,

    /// System prompt for the agent when it starts
    ///
    /// Some agents profit significantly from a tailored prompt. But it is not always needed.
    ///
    /// See [`SystemPrompt`] for an opinionated, customizable system prompt.
    ///
    /// Swiftide provides a default system prompt for all agents. Use `no_system_prompt` on the
    /// builder to disable it.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use swiftide_agents::system_prompt::SystemPrompt;
    /// # use swiftide_agents::Agent;
    /// Agent::builder()
    ///     .system_prompt(
    ///         SystemPrompt::builder().role("You are an expert engineer")
    ///         .build().unwrap())
    ///     .build().unwrap();
    /// ```
    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
    pub(crate) system_prompt: Option<Prompt>,

    /// Lifecycle state of the agent; updated as the agent starts running and stops.
    #[builder(private, default = state::State::default())]
    pub(crate) state: state::State,

    /// Optional limit on the number of completion loops the agent runs in a single call to
    /// `query`, `query_once`, `run` or `run_once`.
    ///
    /// The counter starts at zero on every such call. When the limit is reached the agent logs
    /// a warning and stops.
    #[builder(default, setter(strip_option))]
    pub(crate) limit: Option<usize>,

    /// The maximum number of times the failed output of a tool call is sent back to the LLM
    /// before the agent stops. Defaults to 3.
    ///
    /// LLMs sometimes send missing arguments, or a tool might actually fail, but retrying could
    /// be worthwhile. If the limit is not reached, the agent will send the formatted error back
    /// to the LLM.
    ///
    /// The limit is tracked per hash of the tool call, which is based on the tool call name and
    /// arguments, so the limit is per tool call.
    ///
    /// This limit is _not_ reset when the agent is stopped.
    #[builder(default = 3)]
    pub(crate) tool_retry_limit: usize,

    /// Enables streaming the chat completion responses for the agent.
    ///
    /// When enabled, every streamed response is passed to the `OnStream` hooks. Adding an
    /// `on_stream` hook through the builder enables this automatically.
    #[builder(default)]
    pub(crate) streaming: bool,

    /// Internally tracks the amount of times a tool has been retried. The key is a hash based on
    /// the name and args of the tool.
    #[builder(private, default)]
    pub(crate) tool_retries_counter: HashMap<u64, usize>,

    /// Tools loaded from toolboxes when the agent runs for the first time.
    #[builder(private, default)]
    pub(crate) toolbox_tools: HashSet<Box<dyn Tool>>,
}
125
126impl std::fmt::Debug for Agent {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        f.debug_struct("Agent")
129            // display hooks as a list of type: number of hooks
130            .field(
131                "hooks",
132                &self
133                    .hooks
134                    .iter()
135                    .map(std::string::ToString::to_string)
136                    .collect::<Vec<_>>(),
137            )
138            .field(
139                "tools",
140                &self
141                    .tools
142                    .iter()
143                    .map(swiftide_core::Tool::name)
144                    .collect::<Vec<_>>(),
145            )
146            .field("llm", &"Box<dyn ChatCompletion>")
147            .field("state", &self.state)
148            .finish()
149    }
150}
151
152impl AgentBuilder {
153    /// The context in which the agent operates, by default this is the `DefaultContext`.
154    pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
155    where
156        Self: Clone,
157    {
158        self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
159        self
160    }
161
162    /// Disable the system prompt.
163    pub fn no_system_prompt(&mut self) -> &mut Self {
164        self.system_prompt = Some(None);
165
166        self
167    }
168
169    /// Add a hook to the agent.
170    pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
171        let hooks = self.hooks.get_or_insert_with(Vec::new);
172        hooks.push(hook);
173
174        self
175    }
176
177    /// Add a hook that runs once, before all completions. Even if the agent is paused and resumed,
178    /// before all will not trigger again.
179    pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
180        self.add_hook(Hook::BeforeAll(Box::new(hook)))
181    }
182
183    /// Add a hook that runs once, when the agent starts. This hook also runs if the agent stopped
184    /// and then starts again. The hook runs after any `before_all` hooks and before the
185    /// `before_completion` hooks.
186    pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
187        self.add_hook(Hook::OnStart(Box::new(hook)))
188    }
189
190    /// Add a hook that runs when the agent receives a streaming response
191    ///
192    /// The response will always include both the current accumulated message and the delta
193    ///
194    /// This will set `self.streaming` to true, there is no need to set it manually for the default
195    /// behaviour.
196    pub fn on_stream(&mut self, hook: impl OnStreamFn + 'static) -> &mut Self {
197        self.streaming = Some(true);
198        self.add_hook(Hook::OnStream(Box::new(hook)))
199    }
200
201    /// Add a hook that runs before each completion.
202    pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
203        self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
204    }
205
206    /// Add a hook that runs after each tool. The `Result<ToolOutput, ToolError>` is provided
207    /// as mut, so the tool output can be fully modified.
208    ///
209    /// The `ToolOutput` also references the original `ToolCall`, allowing you to match at runtime
210    /// what tool to interact with.
211    pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
212        self.add_hook(Hook::AfterTool(Box::new(hook)))
213    }
214
215    /// Add a hook that runs before each tool. Yields an immutable reference to the `ToolCall`.
216    pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
217        self.add_hook(Hook::BeforeTool(Box::new(hook)))
218    }
219
220    /// Add a hook that runs after each completion, before tool invocation and/or new messages.
221    pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
222        self.add_hook(Hook::AfterCompletion(Box::new(hook)))
223    }
224
225    /// Add a hook that runs after each completion, after tool invocations, right before a new loop
226    /// might start
227    pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
228        self.add_hook(Hook::AfterEach(Box::new(hook)))
229    }
230
231    /// Add a hook that runs when a new message is added to the context. Note that each tool adds a
232    /// separate message.
233    pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
234        self.add_hook(Hook::OnNewMessage(Box::new(hook)))
235    }
236
237    pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
238        self.add_hook(Hook::OnStop(Box::new(hook)))
239    }
240
241    /// Set the LLM for the agent. An LLM must implement the `ChatCompletion` trait.
242    pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
243        let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
244
245        self.llm = Some(boxed);
246        self
247    }
248
249    /// Define the available tools for the agent. Tools must implement the `Tool` trait.
250    ///
251    /// See the [tool attribute macro](`swiftide_macros::tool`) and the [tool derive
252    /// macro](`swiftide_macros::Tool`) for easy ways to create (many) tools.
253    pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
254    where
255        TOOL: Into<Box<dyn Tool>>,
256    {
257        self.tools = Some(
258            tools
259                .into_iter()
260                .map(Into::into)
261                .chain(Agent::default_tools())
262                .collect(),
263        );
264        self
265    }
266
267    /// Add a toolbox to the agent. Toolboxes are collections of tools that can be added to the
268    /// to the agent. Available tools are evaluated at runtime, when the agent starts for the first
269    /// time.
270    ///
271    /// Agents can have many toolboxes.
272    pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
273        let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
274        toolboxes.push(Box::new(toolbox));
275
276        self
277    }
278}
279
280impl Agent {
281    /// Build a new agent
282    pub fn builder() -> AgentBuilder {
283        AgentBuilder::default()
284            .tools(Agent::default_tools())
285            .to_owned()
286    }
287
288    /// Default tools for the agent that it always includes
289    /// Right now this is the `stop` tool, which allows the agent to stop itself.
290    pub fn default_tools() -> HashSet<Box<dyn Tool>> {
291        HashSet::from([Stop::default().boxed()])
292    }
293
294    /// Run the agent with a user message. The agent will loop completions, make tool calls, until
295    /// no new messages are available.
296    #[tracing::instrument(skip_all, name = "agent.query")]
297    pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
298        let query = query
299            .into()
300            .render()
301            .map_err(AgentError::FailedToRenderPrompt)?;
302        self.run_agent(Some(query), false).await
303    }
304
305    /// Run the agent with a user message once.
306    #[tracing::instrument(skip_all, name = "agent.query_once")]
307    pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
308        let query = query
309            .into()
310            .render()
311            .map_err(AgentError::FailedToRenderPrompt)?;
312        self.run_agent(Some(query), true).await
313    }
314
315    /// Run the agent with without user message. The agent will loop completions, make tool calls,
316    /// until no new messages are available.
317    #[tracing::instrument(skip_all, name = "agent.run")]
318    pub async fn run(&mut self) -> Result<(), AgentError> {
319        self.run_agent(None, false).await
320    }
321
322    /// Run the agent with without user message. The agent will loop completions, make tool calls,
323    /// until
324    #[tracing::instrument(skip_all, name = "agent.run_once")]
325    pub async fn run_once(&mut self) -> Result<(), AgentError> {
326        self.run_agent(None, true).await
327    }
328
329    /// Retrieve the message history of the agent
330    ///
331    /// # Errors
332    ///
333    /// Error if the message history cannot be retrieved, e.g. if the context is not set up or a
334    /// connection fails
335    pub async fn history(&self) -> Result<Vec<ChatMessage>, AgentError> {
336        self.context
337            .history()
338            .await
339            .map_err(AgentError::MessageHistoryError)
340    }
341
342    async fn run_agent(
343        &mut self,
344        maybe_query: Option<String>,
345        just_once: bool,
346    ) -> Result<(), AgentError> {
347        if self.state.is_running() {
348            return Err(AgentError::AlreadyRunning);
349        }
350
351        if self.state.is_pending() {
352            if let Some(system_prompt) = &self.system_prompt {
353                self.context
354                    .add_messages(vec![ChatMessage::System(
355                        system_prompt
356                            .render()
357                            .map_err(AgentError::FailedToRenderSystemPrompt)?,
358                    )])
359                    .await
360                    .map_err(AgentError::MessageHistoryError)?;
361            }
362
363            invoke_hooks!(BeforeAll, self);
364
365            self.load_toolboxes().await?;
366        }
367
368        invoke_hooks!(OnStart, self);
369
370        self.state = state::State::Running;
371
372        if let Some(query) = maybe_query {
373            self.context
374                .add_message(ChatMessage::User(query))
375                .await
376                .map_err(AgentError::MessageHistoryError)?;
377        }
378
379        let mut loop_counter = 0;
380
381        while let Some(messages) = self
382            .context
383            .next_completion()
384            .await
385            .map_err(AgentError::MessageHistoryError)?
386        {
387            if let Some(limit) = self.limit {
388                if loop_counter >= limit {
389                    tracing::warn!("Agent loop limit reached");
390                    break;
391                }
392            }
393
394            // If the last message contains tool calls that have not been completed,
395            // run the tools first
396            if let Some(&ChatMessage::Assistant(.., Some(ref tool_calls))) =
397                maybe_tool_call_without_output(&messages)
398            {
399                tracing::debug!("Uncompleted tool calls found; invoking tools");
400                self.invoke_tools(tool_calls).await?;
401                // Move on to the next tick, so that the
402                continue;
403            }
404
405            let result = self.run_completions(&messages).await;
406
407            if let Err(err) = result {
408                self.stop_with_error(&err).await;
409                tracing::error!(error = ?err, "Agent stopped with error {err}");
410                return Err(err);
411            }
412
413            if just_once || self.state.is_stopped() {
414                break;
415            }
416            loop_counter += 1;
417        }
418
419        // If there are no new messages, ensure we update our state
420        self.stop(StopReason::NoNewMessages).await;
421
422        Ok(())
423    }
424
425    #[tracing::instrument(skip_all, err)]
426    async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<(), AgentError> {
427        debug!(
428            tools = ?self
429                .tools
430                .iter()
431                .map(|t| t.name())
432                .collect::<Vec<_>>()
433                ,
434            "Running completion for agent with {} new messages",
435            messages.len()
436        );
437
438        let mut chat_completion_request = ChatCompletionRequest::builder()
439            .messages(messages)
440            .tools_spec(
441                self.tools
442                    .iter()
443                    .map(swiftide_core::Tool::tool_spec)
444                    .collect::<HashSet<_>>(),
445            )
446            .build()
447            .map_err(AgentError::FailedToBuildRequest)?;
448
449        invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
450
451        debug!(
452            "Calling LLM with the following new messages:\n {}",
453            self.context
454                .current_new_messages()
455                .await
456                .map_err(AgentError::MessageHistoryError)?
457                .iter()
458                .map(ToString::to_string)
459                .collect::<Vec<_>>()
460                .join(",\n")
461        );
462
463        let mut response = if self.streaming {
464            let mut last_response = None;
465            let mut stream = self.llm.complete_stream(&chat_completion_request).await;
466
467            while let Some(response) = stream.next().await {
468                let response = response.map_err(AgentError::CompletionsFailed)?;
469                invoke_hooks!(OnStream, self, &response);
470                last_response = Some(response);
471            }
472            tracing::trace!(?last_response, "Streaming completed");
473            last_response.ok_or(AgentError::EmptyStream)
474        } else {
475            self.llm
476                .complete(&chat_completion_request)
477                .await
478                .map_err(AgentError::CompletionsFailed)
479        }?;
480
481        // The arg preprocessor helps avoid common llm errors.
482        // This must happen as early as possible
483        response
484            .tool_calls
485            .as_deref_mut()
486            .map(ArgPreprocessor::preprocess_tool_calls);
487
488        invoke_hooks!(AfterCompletion, self, &mut response);
489
490        self.add_message(ChatMessage::Assistant(
491            response.message,
492            response.tool_calls.clone(),
493        ))
494        .await?;
495
496        if let Some(tool_calls) = response.tool_calls {
497            self.invoke_tools(&tool_calls).await?;
498        }
499
500        invoke_hooks!(AfterEach, self);
501
502        Ok(())
503    }
504
505    async fn invoke_tools(&mut self, tool_calls: &[ToolCall]) -> Result<(), AgentError> {
506        tracing::debug!("LLM returned tool calls: {:?}", tool_calls);
507
508        let mut handles = vec![];
509        for tool_call in tool_calls {
510            let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
511                tracing::warn!("Tool {} not found", tool_call.name());
512                continue;
513            };
514            tracing::info!("Calling tool `{}`", tool_call.name());
515
516            // let tool_args = tool_call.args().map(String::from);
517            let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
518
519            invoke_hooks!(BeforeTool, self, &tool_call);
520
521            let tool_span = tracing::info_span!(
522                "tool",
523                "otel.name" = format!("tool.{}", tool.name().as_ref())
524            );
525
526            let handle_tool_call = tool_call.clone();
527            let handle = tokio::spawn(async move {
528                    let handle_tool_call = handle_tool_call;
529                    let output = tool.invoke(&*context, &handle_tool_call)
530                        .await
531                        .map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
532
533                    tracing::debug!(output = output.to_string(), args = ?handle_tool_call.args(), tool_name = tool.name().as_ref(), "Completed tool call");
534
535                    Ok(output)
536                }.instrument(tool_span.or_current()));
537
538            handles.push((handle, tool_call));
539        }
540
541        for (handle, tool_call) in handles {
542            let mut output = handle.await.map_err(AgentError::ToolFailedToJoin)?;
543
544            invoke_hooks!(AfterTool, self, &tool_call, &mut output);
545
546            if let Err(error) = output {
547                let stop = self.tool_calls_over_limit(tool_call);
548                if stop {
549                    tracing::error!(
550                        ?error,
551                        "Tool call failed, retry limit reached, stopping agent: {error}",
552                    );
553                } else {
554                    tracing::warn!(
555                        ?error,
556                        tool_call = ?tool_call,
557                        "Tool call failed, retrying",
558                    );
559                }
560                self.add_message(ChatMessage::ToolOutput(
561                    tool_call.clone(),
562                    ToolOutput::Fail(error.to_string()),
563                ))
564                .await?;
565                if stop {
566                    self.stop(StopReason::ToolCallsOverLimit(tool_call.to_owned()))
567                        .await;
568                    return Err(error.into());
569                }
570                continue;
571            }
572
573            let output = output?;
574            self.handle_control_tools(tool_call, &output).await;
575
576            // Feedback required leaves the tool call open
577            //
578            // It assumes a follow up invocation of the agent will have the feedback approved
579            if !output.is_feedback_required() {
580                self.add_message(ChatMessage::ToolOutput(tool_call.to_owned(), output))
581                    .await?;
582            }
583        }
584
585        Ok(())
586    }
587
588    fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
589        self.hooks
590            .iter()
591            .filter(|h| hook_type == (*h).into())
592            .collect()
593    }
594
595    fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
596        self.tools
597            .iter()
598            .find(|tool| tool.name() == tool_name)
599            .cloned()
600    }
601
602    // Handle any tool specific output (e.g. stop)
603    async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
604        match output {
605            ToolOutput::Stop => {
606                tracing::warn!("Stop tool called, stopping agent");
607                self.stop(StopReason::RequestedByTool(tool_call.clone()))
608                    .await;
609            }
610
611            ToolOutput::FeedbackRequired(maybe_payload) => {
612                tracing::warn!("Feedback required, stopping agent");
613                self.stop(StopReason::FeedbackRequired {
614                    tool_call: tool_call.clone(),
615                    payload: maybe_payload.clone(),
616                })
617                .await;
618            }
619            _ => (),
620        }
621    }
622
623    fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
624        let mut s = DefaultHasher::new();
625        tool_call.hash(&mut s);
626        let hash = s.finish();
627
628        if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
629            let val = *retries >= self.tool_retry_limit;
630            *retries += 1;
631            val
632        } else {
633            self.tool_retries_counter.insert(hash, 1);
634            false
635        }
636    }
637
638    /// Add a message to the agent's context
639    ///
640    /// This will trigger a `OnNewMessage` hook if its present.
641    ///
642    /// If you want to add a message without triggering the hook, use the context directly.
643    #[tracing::instrument(skip_all, fields(message = message.to_string()))]
644    pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
645        invoke_hooks!(OnNewMessage, self, &mut message);
646
647        self.context
648            .add_message(message)
649            .await
650            .map_err(AgentError::MessageHistoryError)?;
651        Ok(())
652    }
653
654    /// Tell the agent to stop. It will finish it's current loop and then stop.
655    pub async fn stop(&mut self, reason: impl Into<StopReason>) {
656        if self.state.is_stopped() {
657            return;
658        }
659        let reason = reason.into();
660        invoke_hooks!(OnStop, self, reason.clone(), None);
661
662        self.state = state::State::Stopped(reason);
663    }
664
665    pub async fn stop_with_error(&mut self, error: &AgentError) {
666        if self.state.is_stopped() {
667            return;
668        }
669        invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
670
671        self.state = state::State::Stopped(StopReason::Error);
672    }
673
674    /// Access the agent's context
675    pub fn context(&self) -> &dyn AgentContext {
676        &self.context
677    }
678
679    /// The agent is still running
680    pub fn is_running(&self) -> bool {
681        self.state.is_running()
682    }
683
684    /// The agent stopped
685    pub fn is_stopped(&self) -> bool {
686        self.state.is_stopped()
687    }
688
689    /// The agent has not (ever) started
690    pub fn is_pending(&self) -> bool {
691        self.state.is_pending()
692    }
693
694    /// Get a list of tools available to the agent
695    pub fn tools(&self) -> &HashSet<Box<dyn Tool>> {
696        &self.tools
697    }
698
699    pub fn state(&self) -> &state::State {
700        &self.state
701    }
702
703    pub fn stop_reason(&self) -> Option<&StopReason> {
704        self.state.stop_reason()
705    }
706
707    async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
708        for toolbox in &self.toolboxes {
709            let tools = toolbox
710                .available_tools()
711                .await
712                .map_err(AgentError::ToolBoxFailedToLoad)?;
713            self.toolbox_tools.extend(tools);
714        }
715
716        self.tools.extend(self.toolbox_tools.clone());
717
718        Ok(())
719    }
720}
721
722/// Reverse searches through messages, if it encounters a tool call before encountering an output,
723/// it will return the chat message with the tool calls, otherwise it returns None
724fn maybe_tool_call_without_output(messages: &[ChatMessage]) -> Option<&ChatMessage> {
725    for message in messages.iter().rev() {
726        if let ChatMessage::ToolOutput(..) = message {
727            return None;
728        }
729
730        if let ChatMessage::Assistant(.., Some(tool_calls)) = message {
731            if !tool_calls.is_empty() {
732                return Some(message);
733            }
734        }
735    }
736
737    None
738}
739
740#[cfg(test)]
741mod tests {
742
743    use serde::ser::Error;
744    use swiftide_core::ToolFeedback;
745    use swiftide_core::chat_completion::errors::ToolError;
746    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
747    use swiftide_core::test_utils::MockChatCompletion;
748
749    use super::*;
750    use crate::{
751        State, assistant, chat_request, chat_response, summary, system, tool_failed, tool_output,
752        user,
753    };
754
755    use crate::test_utils::{MockHook, MockTool};
756
757    #[test_log::test(tokio::test)]
758    async fn test_agent_builder_defaults() {
759        // Create a prompt
760        let mock_llm = MockChatCompletion::new();
761
762        // Build the agent
763        let agent = Agent::builder().llm(&mock_llm).build().unwrap();
764
765        // Check that the context is the default context
766
767        // Check that the default tools are added
768        assert!(agent.find_tool_by_name("stop").is_some());
769
770        // Check it does not allow duplicates
771        let agent = Agent::builder()
772            .tools([Stop::default(), Stop::default()])
773            .llm(&mock_llm)
774            .build()
775            .unwrap();
776
777        assert_eq!(agent.tools.len(), 1);
778
779        // It should include the default tool if a different tool is provided
780        let agent = Agent::builder()
781            .tools([MockTool::new("mock_tool")])
782            .llm(&mock_llm)
783            .build()
784            .unwrap();
785
786        assert_eq!(agent.tools.len(), 2);
787        assert!(agent.find_tool_by_name("mock_tool").is_some());
788        assert!(agent.find_tool_by_name("stop").is_some());
789
790        assert!(agent.context().history().await.unwrap().is_empty());
791    }
792
793    #[test_log::test(tokio::test)]
794    async fn test_agent_tool_calling_loop() {
795        let prompt = "Write a poem";
796        let mock_llm = MockChatCompletion::new();
797        let mock_tool = MockTool::new("mock_tool");
798
799        let chat_request = chat_request! {
800            user!("Write a poem");
801
802            tools = [mock_tool.clone()]
803        };
804
805        let mock_tool_response = chat_response! {
806            "Roses are red";
807            tool_calls = ["mock_tool"]
808
809        };
810
811        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
812
813        let chat_request = chat_request! {
814            user!("Write a poem"),
815            assistant!("Roses are red", ["mock_tool"]),
816            tool_output!("mock_tool", "Great!");
817
818            tools = [mock_tool.clone()]
819        };
820
821        let stop_response = chat_response! {
822            "Roses are red";
823            tool_calls = ["stop"]
824        };
825
826        mock_llm.expect_complete(chat_request, Ok(stop_response));
827        mock_tool.expect_invoke_ok("Great!".into(), None);
828
829        let mut agent = Agent::builder()
830            .tools([mock_tool])
831            .llm(&mock_llm)
832            .no_system_prompt()
833            .build()
834            .unwrap();
835
836        agent.query(prompt).await.unwrap();
837    }
838
839    #[test_log::test(tokio::test)]
840    async fn test_agent_tool_run_once() {
841        let prompt = "Write a poem";
842        let mock_llm = MockChatCompletion::new();
843        let mock_tool = MockTool::default();
844
845        let chat_request = chat_request! {
846            system!("My system prompt"),
847            user!("Write a poem");
848
849            tools = [mock_tool.clone()]
850        };
851
852        let mock_tool_response = chat_response! {
853            "Roses are red";
854            tool_calls = ["mock_tool"]
855
856        };
857
858        mock_tool.expect_invoke_ok("Great!".into(), None);
859        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
860
861        let mut agent = Agent::builder()
862            .tools([mock_tool])
863            .system_prompt("My system prompt")
864            .llm(&mock_llm)
865            .build()
866            .unwrap();
867
868        agent.query_once(prompt).await.unwrap();
869    }
870
871    #[test_log::test(tokio::test)]
872    async fn test_agent_tool_via_toolbox_run_once() {
873        let prompt = "Write a poem";
874        let mock_llm = MockChatCompletion::new();
875        let mock_tool = MockTool::default();
876
877        let chat_request = chat_request! {
878            system!("My system prompt"),
879            user!("Write a poem");
880
881            tools = [mock_tool.clone()]
882        };
883
884        let mock_tool_response = chat_response! {
885            "Roses are red";
886            tool_calls = ["mock_tool"]
887
888        };
889
890        mock_tool.expect_invoke_ok("Great!".into(), None);
891        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
892
893        let mut agent = Agent::builder()
894            .add_toolbox(vec![mock_tool.boxed()])
895            .system_prompt("My system prompt")
896            .llm(&mock_llm)
897            .build()
898            .unwrap();
899
900        agent.query_once(prompt).await.unwrap();
901    }
902
903    #[test_log::test(tokio::test(flavor = "multi_thread"))]
904    async fn test_multiple_tool_calls() {
905        let prompt = "Write a poem";
906        let mock_llm = MockChatCompletion::new();
907        let mock_tool = MockTool::new("mock_tool1");
908        let mock_tool2 = MockTool::new("mock_tool2");
909
910        let chat_request = chat_request! {
911            system!("My system prompt"),
912            user!("Write a poem");
913
914
915
916            tools = [mock_tool.clone(), mock_tool2.clone()]
917        };
918
919        let mock_tool_response = chat_response! {
920            "Roses are red";
921
922            tool_calls = ["mock_tool1", "mock_tool2"]
923
924        };
925
926        dbg!(&chat_request);
927        mock_tool.expect_invoke_ok("Great!".into(), None);
928        mock_tool2.expect_invoke_ok("Great!".into(), None);
929        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
930
931        let chat_request = chat_request! {
932            system!("My system prompt"),
933            user!("Write a poem"),
934            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
935            tool_output!("mock_tool1", "Great!"),
936            tool_output!("mock_tool2", "Great!");
937
938            tools = [mock_tool.clone(), mock_tool2.clone()]
939        };
940
941        let mock_tool_response = chat_response! {
942            "Ok!";
943
944            tool_calls = ["stop"]
945
946        };
947
948        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
949
950        let mut agent = Agent::builder()
951            .tools([mock_tool, mock_tool2])
952            .system_prompt("My system prompt")
953            .llm(&mock_llm)
954            .build()
955            .unwrap();
956
957        agent.query(prompt).await.unwrap();
958    }
959
960    #[test_log::test(tokio::test)]
961    async fn test_agent_state_machine() {
962        let prompt = "Write a poem";
963        let mock_llm = MockChatCompletion::new();
964
965        let chat_request = chat_request! {
966            user!("Write a poem");
967            tools = []
968        };
969        let mock_tool_response = chat_response! {
970            "Roses are red";
971            tool_calls = []
972        };
973
974        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
975        let mut agent = Agent::builder()
976            .llm(&mock_llm)
977            .no_system_prompt()
978            .build()
979            .unwrap();
980
981        // Agent has never run and is pending
982        assert!(agent.state.is_pending());
983        agent.query_once(prompt).await.unwrap();
984
985        // Agent is stopped, there might be more messages
986        assert!(agent.state.is_stopped());
987    }
988
989    #[test_log::test(tokio::test)]
990    async fn test_summary() {
991        let prompt = "Write a poem";
992        let mock_llm = MockChatCompletion::new();
993
994        let mock_tool_response = chat_response! {
995            "Roses are red";
996            tool_calls = []
997
998        };
999
1000        let expected_chat_request = chat_request! {
1001            system!("My system prompt"),
1002            user!("Write a poem");
1003
1004            tools = []
1005        };
1006
1007        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1008
1009        let mut agent = Agent::builder()
1010            .system_prompt("My system prompt")
1011            .llm(&mock_llm)
1012            .build()
1013            .unwrap();
1014
1015        agent.query_once(prompt).await.unwrap();
1016
1017        agent
1018            .context
1019            .add_message(ChatMessage::new_summary("Summary"))
1020            .await
1021            .unwrap();
1022
1023        let expected_chat_request = chat_request! {
1024            system!("My system prompt"),
1025            summary!("Summary"),
1026            user!("Write another poem");
1027            tools = []
1028        };
1029        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1030
1031        agent.query_once("Write another poem").await.unwrap();
1032
1033        agent
1034            .context
1035            .add_message(ChatMessage::new_summary("Summary 2"))
1036            .await
1037            .unwrap();
1038
1039        let expected_chat_request = chat_request! {
1040            system!("My system prompt"),
1041            summary!("Summary 2"),
1042            user!("Write a third poem");
1043            tools = []
1044        };
1045        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));
1046
1047        agent.query_once("Write a third poem").await.unwrap();
1048    }
1049
1050    #[test_log::test(tokio::test)]
1051    async fn test_agent_hooks() {
1052        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
1053        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
1054        let mock_before_completion = MockHook::new("before_completion")
1055            .expect_calls(2)
1056            .to_owned();
1057        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
1058        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
1059        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
1060        let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();
1061
1062        // Once for mock tool and once for stop
1063        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
1064        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();
1065
1066        let prompt = "Write a poem";
1067        let mock_llm = MockChatCompletion::new();
1068        let mock_tool = MockTool::default();
1069
1070        let chat_request = chat_request! {
1071            user!("Write a poem");
1072
1073            tools = [mock_tool.clone()]
1074        };
1075
1076        let mock_tool_response = chat_response! {
1077            "Roses are red";
1078            tool_calls = ["mock_tool"]
1079
1080        };
1081
1082        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1083
1084        let chat_request = chat_request! {
1085            user!("Write a poem"),
1086            assistant!("Roses are red", ["mock_tool"]),
1087            tool_output!("mock_tool", "Great!");
1088
1089            tools = [mock_tool.clone()]
1090        };
1091
1092        let stop_response = chat_response! {
1093            "Roses are red";
1094            tool_calls = ["stop"]
1095        };
1096
1097        mock_llm.expect_complete(chat_request, Ok(stop_response));
1098        mock_tool.expect_invoke_ok("Great!".into(), None);
1099
1100        let mut agent = Agent::builder()
1101            .tools([mock_tool])
1102            .llm(&mock_llm)
1103            .no_system_prompt()
1104            .before_all(mock_before_all.hook_fn())
1105            .on_start(mock_on_start_fn.on_start_fn())
1106            .before_completion(mock_before_completion.before_completion_fn())
1107            .before_tool(mock_before_tool.before_tool_fn())
1108            .after_completion(mock_after_completion.after_completion_fn())
1109            .after_tool(mock_after_tool.after_tool_fn())
1110            .after_each(mock_after_each.hook_fn())
1111            .on_new_message(mock_on_message.message_hook_fn())
1112            .on_stop(mock_on_stop.stop_hook_fn())
1113            .build()
1114            .unwrap();
1115
1116        agent.query(prompt).await.unwrap();
1117    }
1118
1119    #[test_log::test(tokio::test)]
1120    async fn test_agent_loop_limit() {
1121        let prompt = "Generate content"; // Example prompt
1122        let mock_llm = MockChatCompletion::new();
1123        let mock_tool = MockTool::new("mock_tool");
1124
1125        let chat_request = chat_request! {
1126            user!(prompt);
1127            tools = [mock_tool.clone()]
1128        };
1129        mock_tool.expect_invoke_ok("Great!".into(), None);
1130
1131        let mock_tool_response = chat_response! {
1132            "Some response";
1133            tool_calls = ["mock_tool"]
1134        };
1135
1136        // Set expectations for the mock LLM responses
1137        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));
1138
1139        // // Response for terminating the loop
1140        let stop_response = chat_response! {
1141            "Final response";
1142            tool_calls = ["stop"]
1143        };
1144
1145        mock_llm.expect_complete(chat_request, Ok(stop_response));
1146
1147        let mut agent = Agent::builder()
1148            .tools([mock_tool])
1149            .llm(&mock_llm)
1150            .no_system_prompt()
1151            .limit(1) // Setting the loop limit to 1
1152            .build()
1153            .unwrap();
1154
1155        // Run the agent
1156        agent.query(prompt).await.unwrap();
1157
1158        // Assert that the remaining message is still in the queue
1159        let remaining = mock_llm.expectations.lock().unwrap().pop();
1160        assert!(remaining.is_some());
1161
1162        // Assert that the agent is stopped after reaching the loop limit
1163        assert!(agent.is_stopped());
1164    }
1165
1166    #[test_log::test(tokio::test)]
1167    async fn test_tool_retry_mechanism() {
1168        let prompt = "Execute my tool";
1169        let mock_llm = MockChatCompletion::new();
1170        let mock_tool = MockTool::new("retry_tool");
1171
1172        // Configure mock tool to fail twice. First time is fed back to the LLM, second time is an
1173        // error
1174        mock_tool.expect_invoke(
1175            Err(ToolError::WrongArguments(serde_json::Error::custom(
1176                "missing `query`",
1177            ))),
1178            None,
1179        );
1180        mock_tool.expect_invoke(
1181            Err(ToolError::WrongArguments(serde_json::Error::custom(
1182                "missing `query`",
1183            ))),
1184            None,
1185        );
1186
1187        let chat_request = chat_request! {
1188            user!(prompt);
1189            tools = [mock_tool.clone()]
1190        };
1191        let retry_response = chat_response! {
1192            "First failing attempt";
1193            tool_calls = ["retry_tool"]
1194        };
1195        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
1196
1197        let chat_request = chat_request! {
1198            user!(prompt),
1199            assistant!("First failing attempt", ["retry_tool"]),
1200            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
1201
1202            tools = [mock_tool.clone()]
1203        };
1204        let will_fail_response = chat_response! {
1205            "Finished execution";
1206            tool_calls = ["retry_tool"]
1207        };
1208        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
1209
1210        let mut agent = Agent::builder()
1211            .tools([mock_tool])
1212            .llm(&mock_llm)
1213            .no_system_prompt()
1214            .tool_retry_limit(1) // The test relies on a limit of 2 retries.
1215            .build()
1216            .unwrap();
1217
1218        // Run the agent
1219        let result = agent.query(prompt).await;
1220
1221        assert!(result.is_err());
1222        assert!(result.unwrap_err().to_string().contains("missing `query`"));
1223        assert!(agent.is_stopped());
1224    }
1225
1226    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1227    async fn test_streaming() {
1228        let prompt = "Generate content"; // Example prompt
1229        let mock_llm = MockChatCompletion::new();
1230        let on_stream_fn = MockHook::new("on_stream").expect_calls(3).to_owned();
1231
1232        let chat_request = chat_request! {
1233            user!(prompt);
1234
1235            tools = []
1236        };
1237
1238        let response = chat_response! {
1239            "one two three";
1240            tool_calls = ["stop"]
1241        };
1242
1243        // Set expectations for the mock LLM responses
1244        mock_llm.expect_complete(chat_request, Ok(response));
1245
1246        let mut agent = Agent::builder()
1247            .llm(&mock_llm)
1248            .on_stream(on_stream_fn.on_stream_fn())
1249            .no_system_prompt()
1250            .build()
1251            .unwrap();
1252
1253        // Run the agent
1254        agent.query(prompt).await.unwrap();
1255
1256        tracing::debug!("Agent finished running");
1257
1258        // Assert that the agent is stopped after reaching the loop limit
1259        assert!(agent.is_stopped());
1260    }
1261
1262    #[test_log::test(tokio::test)]
1263    async fn test_recovering_agent_existing_history() {
1264        // First, let's run an agent
1265        let prompt = "Write a poem";
1266        let mock_llm = MockChatCompletion::new();
1267        let mock_tool = MockTool::new("mock_tool");
1268
1269        let chat_request = chat_request! {
1270            user!("Write a poem");
1271
1272            tools = [mock_tool.clone()]
1273        };
1274
1275        let mock_tool_response = chat_response! {
1276            "Roses are red";
1277            tool_calls = ["mock_tool"]
1278
1279        };
1280
1281        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1282
1283        let chat_request = chat_request! {
1284            user!("Write a poem"),
1285            assistant!("Roses are red", ["mock_tool"]),
1286            tool_output!("mock_tool", "Great!");
1287
1288            tools = [mock_tool.clone()]
1289        };
1290
1291        let stop_response = chat_response! {
1292            "Roses are red";
1293            tool_calls = ["stop"]
1294        };
1295
1296        mock_llm.expect_complete(chat_request, Ok(stop_response));
1297        mock_tool.expect_invoke_ok("Great!".into(), None);
1298
1299        let mut agent = Agent::builder()
1300            .tools([mock_tool.clone()])
1301            .llm(&mock_llm)
1302            .no_system_prompt()
1303            .build()
1304            .unwrap();
1305
1306        agent.query(prompt).await.unwrap();
1307
1308        // Let's retrieve the history of the agent
1309        let history = agent.history().await.unwrap();
1310
1311        // Store it as a string somewhere
1312        let serialized = serde_json::to_string(&history).unwrap();
1313
1314        // Retrieve it
1315        let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();
1316
1317        // Build a context from the history
1318        let context = DefaultContext::default()
1319            .with_existing_messages(history)
1320            .await
1321            .unwrap()
1322            .to_owned();
1323
1324        let expected_chat_request = chat_request! {
1325            user!("Write a poem"),
1326            assistant!("Roses are red", ["mock_tool"]),
1327            tool_output!("mock_tool", "Great!"),
1328            assistant!("Roses are red", ["stop"]),
1329            tool_output!("stop", ToolOutput::Stop),
1330            user!("Try again!");
1331
1332            tools = [mock_tool.clone()]
1333        };
1334
1335        let stop_response = chat_response! {
1336            "Really stopping now";
1337            tool_calls = ["stop"]
1338        };
1339
1340        mock_llm.expect_complete(expected_chat_request, Ok(stop_response));
1341
1342        let mut agent = Agent::builder()
1343            .context(context)
1344            .tools([mock_tool])
1345            .llm(&mock_llm)
1346            .no_system_prompt()
1347            .build()
1348            .unwrap();
1349
1350        agent.query_once("Try again!").await.unwrap();
1351    }
1352
1353    #[test_log::test(tokio::test)]
1354    async fn test_agent_with_approval_required_tool() {
1355        use super::*;
1356        use crate::tools::control::ApprovalRequired;
1357        use crate::{assistant, chat_request, chat_response, user};
1358        use swiftide_core::chat_completion::ToolCall;
1359
1360        // Step 1: Build a tool that needs approval.
1361        let mock_tool = MockTool::default();
1362        mock_tool.expect_invoke_ok("Great!".into(), None);
1363
1364        let approval_tool = ApprovalRequired(mock_tool.boxed());
1365
1366        // Step 2: Set up the mock LLM.
1367        let mock_llm = MockChatCompletion::new();
1368
1369        let chat_req1 = chat_request! {
1370            user!("Request with approval");
1371            tools = [approval_tool.clone()]
1372        };
1373        let chat_resp1 = chat_response! {
1374            "Completion message";
1375            tool_calls = ["mock_tool"]
1376        };
1377        mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1378
1379        // The response will include the previous request, but no tool output
1380        // from the required tool
1381        let chat_req2 = chat_request! {
1382            user!("Request with approval"),
1383            assistant!("Completion message", ["mock_tool"]),
1384            tool_output!("mock_tool", "Great!");
1385            // Simulate feedback required output
1386            tools = [approval_tool.clone()]
1387        };
1388        let chat_resp2 = chat_response! {
1389            "Post-feedback message";
1390            tool_calls = ["stop"]
1391        };
1392        mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1393
1394        // Step 3: Wire up the agent.
1395        let mut agent = Agent::builder()
1396            .tools([approval_tool])
1397            .llm(&mock_llm)
1398            .no_system_prompt()
1399            .build()
1400            .unwrap();
1401
1402        // Step 4: Run agent to trigger approval.
1403        agent.query_once("Request with approval").await.unwrap();
1404
1405        assert!(matches!(
1406            agent.state,
1407            crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1408        ));
1409
1410        let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1411        else {
1412            panic!("Expected feedback required");
1413        };
1414
1415        // Step 5: Simulate feedback, run again and assert finish.
1416        agent
1417            .context
1418            .feedback_received(&tool_call, &ToolFeedback::approved())
1419            .await
1420            .unwrap();
1421
1422        tracing::debug!("running after approval");
1423        agent.run_once().await.unwrap();
1424        assert!(agent.is_stopped());
1425    }
1426
1427    #[test_log::test(tokio::test)]
1428    async fn test_agent_with_approval_required_tool_denied() {
1429        use super::*;
1430        use crate::tools::control::ApprovalRequired;
1431        use crate::{assistant, chat_request, chat_response, user};
1432        use swiftide_core::chat_completion::ToolCall;
1433
1434        // Step 1: Build a tool that needs approval.
1435        let mock_tool = MockTool::default();
1436
1437        let approval_tool = ApprovalRequired(mock_tool.boxed());
1438
1439        // Step 2: Set up the mock LLM.
1440        let mock_llm = MockChatCompletion::new();
1441
1442        let chat_req1 = chat_request! {
1443            user!("Request with approval");
1444            tools = [approval_tool.clone()]
1445        };
1446        let chat_resp1 = chat_response! {
1447            "Completion message";
1448            tool_calls = ["mock_tool"]
1449        };
1450        mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1451
1452        // The response will include the previous request, but no tool output
1453        // from the required tool
1454        let chat_req2 = chat_request! {
1455            user!("Request with approval"),
1456            assistant!("Completion message", ["mock_tool"]),
1457            tool_output!("mock_tool", "This tool call was refused");
1458            // Simulate feedback required output
1459            tools = [approval_tool.clone()]
1460        };
1461        let chat_resp2 = chat_response! {
1462            "Post-feedback message";
1463            tool_calls = ["stop"]
1464        };
1465        mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1466
1467        // Step 3: Wire up the agent.
1468        let mut agent = Agent::builder()
1469            .tools([approval_tool])
1470            .llm(&mock_llm)
1471            .no_system_prompt()
1472            .build()
1473            .unwrap();
1474
1475        // Step 4: Run agent to trigger approval.
1476        agent.query_once("Request with approval").await.unwrap();
1477
1478        assert!(matches!(
1479            agent.state,
1480            crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1481        ));
1482
1483        let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1484        else {
1485            panic!("Expected feedback required");
1486        };
1487
1488        // Step 5: Simulate feedback, run again and assert finish.
1489        agent
1490            .context
1491            .feedback_received(&tool_call, &ToolFeedback::refused())
1492            .await
1493            .unwrap();
1494
1495        tracing::debug!("running after approval");
1496        agent.run_once().await.unwrap();
1497
1498        let history = agent.context().history().await.unwrap();
1499        history
1500            .iter()
1501            .rfind(|m| {
1502                let ChatMessage::ToolOutput(.., ToolOutput::Text(msg)) = m else {
1503                    return false;
1504                };
1505                msg.contains("refused")
1506            })
1507            .expect("Could not find refusal message");
1508
1509        assert!(agent.is_stopped());
1510    }
1511}