swiftide_agents/
agent.rs

1#![allow(dead_code)]
2use crate::{
3    default_context::DefaultContext,
4    errors::AgentError,
5    hooks::{
6        AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7        Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn,
8    },
9    invoke_hooks,
10    state::{self, StopReason},
11    system_prompt::SystemPrompt,
12    tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15    collections::{HashMap, HashSet},
16    hash::{DefaultHasher, Hash as _, Hasher as _},
17    sync::Arc,
18};
19
20use derive_builder::Builder;
21use futures_util::stream::StreamExt;
22use swiftide_core::{
23    AgentContext, ToolBox,
24    chat_completion::{
25        ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
26    },
27    prompt::Prompt,
28};
29use tracing::{Instrument, debug};
30
31/// Agents are the main interface for building agentic systems.
32///
33/// Construct agents by calling the builder, setting an llm, configure hooks, tools and other
34/// customizations.
35///
36/// # Important defaults
37///
38/// - The default context is the `DefaultContext`, executing tools locally with the `LocalExecutor`.
39/// - A default `stop` tool is provided for agents to explicitly stop if needed
40/// - The default `SystemPrompt` instructs the agent with chain of thought and some common
41///   safeguards, but is otherwise quite bare. In a lot of cases this can be sufficient.
42///
43///   Agents are *not* cheap to clone. However, if an agent gets cloned, it will operate on the
44///   same context.
45#[derive(Builder)]
46pub struct Agent {
47    /// Hooks are functions that are called at specific points in the agent's lifecycle.
48    #[builder(default, setter(into))]
49    pub(crate) hooks: Vec<Hook>,
50    /// The context in which the agent operates, by default this is the `DefaultContext`.
51    #[builder(
52        setter(custom),
53        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
54    )]
55    pub(crate) context: Arc<dyn AgentContext>,
56    /// Tools the agent can use
57    #[builder(default = Agent::default_tools(), setter(custom))]
58    pub(crate) tools: HashSet<Box<dyn Tool>>,
59
60    /// Toolboxes are collections of tools that can be added to the agent.
61    ///
62    /// Toolboxes make their tools available to the agent at runtime.
63    #[builder(default)]
64    pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,
65
66    /// The language model that the agent uses for completion.
67    #[builder(setter(custom))]
68    pub(crate) llm: Box<dyn ChatCompletion>,
69
70    /// System prompt for the agent when it starts
71    ///
72    /// Some agents profit significantly from a tailored prompt. But it is not always needed.
73    ///
74    /// See [`SystemPrompt`] for an opiniated, customizable system prompt.
75    ///
76    /// Swiftide provides a default system prompt for all agents.
77    ///
78    /// # Example
79    ///
80    /// ```no_run
81    /// # use swiftide_agents::system_prompt::SystemPrompt;
82    /// # use swiftide_agents::Agent;
83    /// Agent::builder()
84    ///     .system_prompt(
85    ///         SystemPrompt::builder().role("You are an expert engineer")
86    ///         .build().unwrap())
87    ///     .build().unwrap();
88    /// ```
89    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default()))]
90    pub(crate) system_prompt: Option<SystemPrompt>,
91
92    /// Initial state of the agent
93    #[builder(private, default = state::State::default())]
94    pub(crate) state: state::State,
95
96    /// Optional limit on the amount of loops the agent can run.
97    /// The counter is reset when the agent is stopped.
98    #[builder(default, setter(strip_option))]
99    pub(crate) limit: Option<usize>,
100
101    /// The maximum amount of times the failed output of a tool will be send
102    /// to an LLM before the agent stops. Defaults to 3.
103    ///
104    /// LLMs sometimes send missing arguments, or a tool might actually fail, but retrying could be
105    /// worth while. If the limit is not reached, the agent will send the formatted error back to
106    /// the LLM.
107    ///
108    /// The limit is hashed based on the tool call name and arguments, so the limit is per tool
109    /// call.
110    ///
111    /// This limit is _not_ reset when the agent is stopped.
112    #[builder(default = 3)]
113    pub(crate) tool_retry_limit: usize,
114
115    /// Enables streaming the chat completion responses for the agent.
116    #[builder(default)]
117    pub(crate) streaming: bool,
118
119    /// When set to true, any tools in `Agent::default_tools` will be omitted. Only works if you
120    /// at at least one tool of your own.
121    #[builder(private, default)]
122    pub(crate) clear_default_tools: bool,
123
124    /// Internally tracks the amount of times a tool has been retried. The key is a hash based on
125    /// the name and args of the tool.
126    #[builder(private, default)]
127    pub(crate) tool_retries_counter: HashMap<u64, usize>,
128
129    /// The name of the agent; optional
130    #[builder(default = "unnamed_agent".into(), setter(into))]
131    pub(crate) name: String,
132}
133
134impl Clone for Agent {
135    fn clone(&self) -> Self {
136        Agent {
137            hooks: self.hooks.clone(),
138            context: Arc::new(self.context.clone()),
139            tools: self.tools.clone(),
140            toolboxes: self.toolboxes.clone(),
141            llm: self.llm.clone(),
142            system_prompt: self.system_prompt.clone(),
143            state: self.state.clone(),
144            limit: self.limit,
145            tool_retry_limit: self.tool_retry_limit,
146            tool_retries_counter: HashMap::new(),
147            streaming: self.streaming,
148            name: self.name.clone(),
149            clear_default_tools: self.clear_default_tools,
150        }
151    }
152}
153
154impl std::fmt::Debug for Agent {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        f.debug_struct("Agent")
157            .field("name", &self.name)
158            // display hooks as a list of type: number of hooks
159            .field(
160                "hooks",
161                &self
162                    .hooks
163                    .iter()
164                    .map(std::string::ToString::to_string)
165                    .collect::<Vec<_>>(),
166            )
167            .field(
168                "tools",
169                &self
170                    .tools
171                    .iter()
172                    .map(swiftide_core::Tool::name)
173                    .collect::<Vec<_>>(),
174            )
175            .field("llm", &"Box<dyn ChatCompletion>")
176            .field("state", &self.state)
177            .finish()
178    }
179}
180
181impl AgentBuilder {
182    /// The context in which the agent operates, by default this is the `DefaultContext`.
183    pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
184    where
185        Self: Clone,
186    {
187        self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
188        self
189    }
190
191    /// Returns a mutable reference to the system prompt, if it is set.
192    pub fn system_prompt_mut(&mut self) -> Option<&mut SystemPrompt> {
193        self.system_prompt.as_mut().and_then(Option::as_mut)
194    }
195
196    /// Disable the system prompt.
197    pub fn no_system_prompt(&mut self) -> &mut Self {
198        self.system_prompt = Some(None);
199
200        self
201    }
202
203    /// Add a hook to the agent.
204    pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
205        let hooks = self.hooks.get_or_insert_with(Vec::new);
206        hooks.push(hook);
207
208        self
209    }
210
211    /// Adds a tool to the agent
212    pub fn add_tool(&mut self, tool: impl Tool + 'static) -> &mut Self {
213        self.tools = Some(
214            self.tools
215                .take()
216                .unwrap_or_default()
217                .into_iter()
218                .chain([Box::new(tool) as Box<dyn Tool>])
219                .collect(),
220        );
221        self
222    }
223
224    /// Add a hook that runs once, before all completions. Even if the agent is paused and resumed,
225    /// before all will not trigger again.
226    pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
227        self.add_hook(Hook::BeforeAll(Box::new(hook)))
228    }
229
230    /// Add a hook that runs once, when the agent starts. This hook also runs if the agent stopped
231    /// and then starts again. The hook runs after any `before_all` hooks and before the
232    /// `before_completion` hooks.
233    pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
234        self.add_hook(Hook::OnStart(Box::new(hook)))
235    }
236
237    /// Add a hook that runs when the agent receives a streaming response
238    ///
239    /// The response will always include both the current accumulated message and the delta
240    ///
241    /// This will set `self.streaming` to true, there is no need to set it manually for the default
242    /// behaviour.
243    pub fn on_stream(&mut self, hook: impl OnStreamFn + 'static) -> &mut Self {
244        self.streaming = Some(true);
245        self.add_hook(Hook::OnStream(Box::new(hook)))
246    }
247
248    /// Add a hook that runs before each completion.
249    pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
250        self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
251    }
252
253    /// Add a hook that runs after each tool. The `Result<ToolOutput, ToolError>` is provided
254    /// as mut, so the tool output can be fully modified.
255    ///
256    /// The `ToolOutput` also references the original `ToolCall`, allowing you to match at runtime
257    /// what tool to interact with.
258    pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
259        self.add_hook(Hook::AfterTool(Box::new(hook)))
260    }
261
262    /// Add a hook that runs before each tool. Yields an immutable reference to the `ToolCall`.
263    pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
264        self.add_hook(Hook::BeforeTool(Box::new(hook)))
265    }
266
267    /// Add a hook that runs after each completion, before tool invocation and/or new messages.
268    pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
269        self.add_hook(Hook::AfterCompletion(Box::new(hook)))
270    }
271
272    /// Add a hook that runs after each completion, after tool invocations, right before a new loop
273    /// might start
274    pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
275        self.add_hook(Hook::AfterEach(Box::new(hook)))
276    }
277
278    /// Add a hook that runs when a new message is added to the context. Note that each tool adds a
279    /// separate message.
280    pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
281        self.add_hook(Hook::OnNewMessage(Box::new(hook)))
282    }
283
284    pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
285        self.add_hook(Hook::OnStop(Box::new(hook)))
286    }
287
288    /// Set the LLM for the agent. An LLM must implement the `ChatCompletion` trait.
289    pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
290        let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
291
292        self.llm = Some(boxed);
293        self
294    }
295
296    fn builder_default_tools(&self) -> HashSet<Box<dyn Tool>> {
297        if self.clear_default_tools.is_some_and(|b| b) {
298            HashSet::new()
299        } else {
300            Agent::default_tools()
301        }
302    }
303
304    /// Define the available tools for the agent. Tools must implement the `Tool` trait.
305    ///
306    /// See the [tool attribute macro](`swiftide_macros::tool`) and the [tool derive
307    /// macro](`swiftide_macros::Tool`) for easy ways to create (many) tools.
308    pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
309    where
310        TOOL: Into<Box<dyn Tool>>,
311    {
312        self.tools = Some(
313            tools
314                .into_iter()
315                .map(Into::into)
316                .chain(self.builder_default_tools())
317                .collect(),
318        );
319        self
320    }
321
322    /// Add a toolbox to the agent. Toolboxes are collections of tools that can be added to the
323    /// to the agent. Available tools are evaluated at runtime, when the agent starts for the first
324    /// time.
325    ///
326    /// Agents can have many toolboxes.
327    pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
328        let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
329        toolboxes.push(Box::new(toolbox));
330
331        self
332    }
333}
334
335impl Agent {
336    /// Build a new agent
337    pub fn builder() -> AgentBuilder {
338        AgentBuilder::default()
339            .tools(Agent::default_tools())
340            .to_owned()
341    }
342
343    /// The name of the agent
344    pub fn name(&self) -> &str {
345        &self.name
346    }
347
348    /// Default tools for the agent that it always includes
349    /// Right now this is the `stop` tool, which allows the agent to stop itself.
350    pub fn default_tools() -> HashSet<Box<dyn Tool>> {
351        HashSet::from([Stop::default().boxed()])
352    }
353
354    /// Run the agent with a user message. The agent will loop completions, make tool calls, until
355    /// no new messages are available.
356    ///
357    /// # Errors
358    ///
359    /// Errors if anything goes wrong, see `AgentError` for more details.
360    #[tracing::instrument(skip_all, name = "agent.query")]
361    pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
362        let query = query
363            .into()
364            .render()
365            .map_err(AgentError::FailedToRenderPrompt)?;
366        self.run_agent(Some(query), false).await
367    }
368
369    /// Adds a tool to an agent at run time
370    pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
371        self.tools.insert(tool);
372    }
373
374    /// Modify the tools of the agent at runtime
375    ///
376    /// Note that any mcp tools are added to the agent after the first start, and will only then
377    /// also be available here.
378    pub fn tools_mut(&mut self) -> &mut HashSet<Box<dyn Tool>> {
379        &mut self.tools
380    }
381
382    /// Run the agent with a user message once.
383    ///
384    /// # Errors
385    ///
386    /// Errors if anything goes wrong, see `AgentError` for more details.
387    #[tracing::instrument(skip_all, name = "agent.query_once")]
388    pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
389        self.run_agent(Some(query), true).await
390    }
391
392    /// Run the agent with without user message. The agent will loop completions, make tool calls,
393    /// until no new messages are available.
394    ///
395    /// # Errors
396    ///
397    /// Errors if anything goes wrong, see `AgentError` for more details.
398    #[tracing::instrument(skip_all, name = "agent.run")]
399    pub async fn run(&mut self) -> Result<(), AgentError> {
400        self.run_agent(None::<Prompt>, false).await
401    }
402
403    /// Run the agent with without user message. The agent will loop completions, make tool calls,
404    /// until
405    ///
406    /// # Errors
407    ///
408    /// Errors if anything goes wrong, see `AgentError` for more details.
409    #[tracing::instrument(skip_all, name = "agent.run_once")]
410    pub async fn run_once(&mut self) -> Result<(), AgentError> {
411        self.run_agent(None::<Prompt>, true).await
412    }
413
414    /// Retrieve the message history of the agent
415    ///
416    /// # Errors
417    ///
418    /// Error if the message history cannot be retrieved, e.g. if the context is not set up or a
419    /// connection fails
420    pub async fn history(&self) -> Result<Vec<ChatMessage>, AgentError> {
421        self.context
422            .history()
423            .await
424            .map_err(AgentError::MessageHistoryError)
425    }
426
427    pub(crate) async fn run_agent(
428        &mut self,
429        maybe_query: Option<impl Into<Prompt>>,
430        just_once: bool,
431    ) -> Result<(), AgentError> {
432        let maybe_query = maybe_query
433            .map(|q| q.into().render())
434            .transpose()
435            .map_err(AgentError::FailedToRenderPrompt)?;
436        if self.state.is_running() {
437            return Err(AgentError::AlreadyRunning);
438        }
439
440        if self.state.is_pending() {
441            if let Some(system_prompt) = &self.system_prompt {
442                self.context
443                    .add_messages(vec![ChatMessage::System(
444                        system_prompt
445                            .to_prompt()
446                            .render()
447                            .map_err(AgentError::FailedToRenderSystemPrompt)?,
448                    )])
449                    .await
450                    .map_err(AgentError::MessageHistoryError)?;
451            }
452
453            invoke_hooks!(BeforeAll, self);
454
455            self.load_toolboxes().await?;
456        }
457
458        if let Some(query) = maybe_query {
459            self.context
460                .add_message(ChatMessage::User(query))
461                .await
462                .map_err(AgentError::MessageHistoryError)?;
463        }
464
465        invoke_hooks!(OnStart, self);
466
467        self.state = state::State::Running;
468
469        let mut loop_counter = 0;
470
471        while let Some(messages) = self
472            .context
473            .next_completion()
474            .await
475            .map_err(AgentError::MessageHistoryError)?
476        {
477            if let Some(limit) = self.limit
478                && loop_counter >= limit
479            {
480                tracing::warn!("Agent loop limit reached");
481                break;
482            }
483
484            // If the last message contains tool calls that have not been completed,
485            // run the tools first
486            if let Some(&ChatMessage::Assistant(.., Some(ref tool_calls))) =
487                maybe_tool_call_without_output(&messages)
488            {
489                tracing::debug!("Uncompleted tool calls found; invoking tools");
490                self.invoke_tools(tool_calls).await?;
491                // Move on to the next tick, so that the
492                continue;
493            }
494
495            let result = self.run_completions(&messages).await;
496
497            if let Err(err) = result {
498                self.stop_with_error(&err).await;
499                tracing::error!(error = ?err, "Agent stopped with error {err}");
500                return Err(err);
501            }
502
503            if just_once || self.state.is_stopped() {
504                break;
505            }
506            loop_counter += 1;
507        }
508
509        // If there are no new messages, ensure we update our state
510        self.stop(StopReason::NoNewMessages).await;
511
512        Ok(())
513    }
514
515    #[tracing::instrument(skip_all, err)]
516    async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<(), AgentError> {
517        debug!(
518            tools = ?self
519                .tools
520                .iter()
521                .map(|t| t.name())
522                .collect::<Vec<_>>()
523                ,
524            "Running completion for agent with {} new messages",
525            messages.len()
526        );
527
528        let mut chat_completion_request = ChatCompletionRequest::builder()
529            .messages(messages)
530            .tools_spec(
531                self.tools
532                    .iter()
533                    .map(swiftide_core::Tool::tool_spec)
534                    .collect::<HashSet<_>>(),
535            )
536            .build()
537            .map_err(AgentError::FailedToBuildRequest)?;
538
539        invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
540
541        debug!(
542            "Calling LLM with the following new messages:\n {}",
543            self.context
544                .current_new_messages()
545                .await
546                .map_err(AgentError::MessageHistoryError)?
547                .iter()
548                .map(ToString::to_string)
549                .collect::<Vec<_>>()
550                .join(",\n")
551        );
552
553        let mut response = if self.streaming {
554            let mut last_response = None;
555            let mut stream = self.llm.complete_stream(&chat_completion_request).await;
556
557            while let Some(response) = stream.next().await {
558                let response = response.map_err(AgentError::CompletionsFailed)?;
559                invoke_hooks!(OnStream, self, &response);
560                last_response = Some(response);
561            }
562            tracing::trace!(?last_response, "Streaming completed");
563            last_response.ok_or(AgentError::EmptyStream)
564        } else {
565            self.llm
566                .complete(&chat_completion_request)
567                .await
568                .map_err(AgentError::CompletionsFailed)
569        }?;
570
571        // The arg preprocessor helps avoid common llm errors.
572        // This must happen as early as possible
573        response
574            .tool_calls
575            .as_deref_mut()
576            .map(ArgPreprocessor::preprocess_tool_calls);
577
578        invoke_hooks!(AfterCompletion, self, &mut response);
579
580        self.add_message(ChatMessage::Assistant(
581            response.message,
582            response.tool_calls.clone(),
583        ))
584        .await?;
585
586        if let Some(tool_calls) = response.tool_calls {
587            self.invoke_tools(&tool_calls).await?;
588        }
589
590        invoke_hooks!(AfterEach, self);
591
592        Ok(())
593    }
594
595    async fn invoke_tools(&mut self, tool_calls: &[ToolCall]) -> Result<(), AgentError> {
596        tracing::debug!("LLM returned tool calls: {:?}", tool_calls);
597
598        let mut handles = vec![];
599        for tool_call in tool_calls {
600            let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
601                tracing::warn!("Tool {} not found", tool_call.name());
602                continue;
603            };
604            tracing::info!("Calling tool `{}`", tool_call.name());
605
606            // let tool_args = tool_call.args().map(String::from);
607            let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
608
609            invoke_hooks!(BeforeTool, self, &tool_call);
610
611            let tool_span = tracing::info_span!(
612                "tool",
613                "otel.name" = format!("tool.{}", tool.name().as_ref())
614            );
615
616            let handle_tool_call = tool_call.clone();
617            let handle = tokio::spawn(async move {
618                    let handle_tool_call = handle_tool_call;
619                    let output = tool.invoke(&*context, &handle_tool_call)
620                        .await
621                        .map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
622
623                    tracing::debug!(output = output.to_string(), args = ?handle_tool_call.args(), tool_name = tool.name().as_ref(), "Completed tool call");
624
625                    Ok(output)
626                }.instrument(tool_span.or_current()));
627
628            handles.push((handle, tool_call));
629        }
630
631        for (handle, tool_call) in handles {
632            let mut output = handle
633                .await
634                .map_err(|err| AgentError::ToolFailedToJoin(tool_call.name().to_string(), err))?;
635
636            invoke_hooks!(AfterTool, self, &tool_call, &mut output);
637
638            if let Err(error) = output {
639                let stop = self.tool_calls_over_limit(tool_call);
640                if stop {
641                    tracing::error!(
642                        ?error,
643                        "Tool call failed, retry limit reached, stopping agent: {error}",
644                    );
645                } else {
646                    tracing::warn!(
647                        ?error,
648                        tool_call = ?tool_call,
649                        "Tool call failed, retrying",
650                    );
651                }
652                self.add_message(ChatMessage::ToolOutput(
653                    tool_call.clone(),
654                    ToolOutput::fail(error.to_string()),
655                ))
656                .await?;
657                if stop {
658                    self.stop(StopReason::ToolCallsOverLimit(tool_call.to_owned()))
659                        .await;
660                    return Err(error.into());
661                }
662                continue;
663            }
664
665            let output = output?;
666            self.handle_control_tools(tool_call, &output).await;
667
668            // Feedback required leaves the tool call open
669            //
670            // It assumes a follow up invocation of the agent will have the feedback approved
671            if !output.is_feedback_required() {
672                self.add_message(ChatMessage::ToolOutput(tool_call.to_owned(), output))
673                    .await?;
674            }
675        }
676
677        Ok(())
678    }
679
680    fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
681        self.hooks
682            .iter()
683            .filter(|h| hook_type == (*h).into())
684            .collect()
685    }
686
687    fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
688        self.tools
689            .iter()
690            .find(|tool| tool.name() == tool_name)
691            .cloned()
692    }
693
694    // Handle any tool specific output (e.g. stop)
695    async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
696        match output {
697            ToolOutput::Stop(maybe_message) => {
698                tracing::warn!("Stop tool called, stopping agent");
699                self.stop(StopReason::RequestedByTool(
700                    tool_call.clone(),
701                    maybe_message.clone(),
702                ))
703                .await;
704            }
705            ToolOutput::FeedbackRequired(maybe_payload) => {
706                tracing::warn!("Feedback required, stopping agent");
707                self.stop(StopReason::FeedbackRequired {
708                    tool_call: tool_call.clone(),
709                    payload: maybe_payload.clone(),
710                })
711                .await;
712            }
713            ToolOutput::AgentFailed(output) => {
714                tracing::warn!("Agent failed, stopping agent");
715                self.stop(StopReason::AgentFailed(output.clone())).await;
716            }
717            _ => (),
718        }
719    }
720
721    fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
722        let mut s = DefaultHasher::new();
723        tool_call.hash(&mut s);
724        let hash = s.finish();
725
726        if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
727            let val = *retries >= self.tool_retry_limit;
728            *retries += 1;
729            val
730        } else {
731            self.tool_retries_counter.insert(hash, 1);
732            false
733        }
734    }
735
736    /// Add a message to the agent's context
737    ///
738    /// This will trigger a `OnNewMessage` hook if its present.
739    ///
740    /// If you want to add a message without triggering the hook, use the context directly.
741    ///
742    /// # Errors
743    ///
744    /// Errors if the message cannot be added to the context. With the default in memory context
745    /// that is not supposed to happen.
746    #[tracing::instrument(skip_all, fields(message = message.to_string()))]
747    pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
748        invoke_hooks!(OnNewMessage, self, &mut message);
749
750        self.context
751            .add_message(message)
752            .await
753            .map_err(AgentError::MessageHistoryError)?;
754        Ok(())
755    }
756
757    /// Tell the agent to stop. It will finish it's current loop and then stop.
758    pub async fn stop(&mut self, reason: impl Into<StopReason>) {
759        if self.state.is_stopped() {
760            return;
761        }
762        let reason = reason.into();
763        invoke_hooks!(OnStop, self, reason.clone(), None);
764
765        self.state = state::State::Stopped(reason);
766    }
767
768    pub async fn stop_with_error(&mut self, error: &AgentError) {
769        if self.state.is_stopped() {
770            return;
771        }
772        invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
773
774        self.state = state::State::Stopped(StopReason::Error);
775    }
776
777    /// Access the agent's context
778    pub fn context(&self) -> &dyn AgentContext {
779        &self.context
780    }
781
782    /// The agent is still running
783    pub fn is_running(&self) -> bool {
784        self.state.is_running()
785    }
786
787    /// The agent stopped
788    pub fn is_stopped(&self) -> bool {
789        self.state.is_stopped()
790    }
791
792    /// The agent has not (ever) started
793    pub fn is_pending(&self) -> bool {
794        self.state.is_pending()
795    }
796
797    /// Get a list of tools available to the agent
798    pub fn tools(&self) -> &HashSet<Box<dyn Tool>> {
799        &self.tools
800    }
801
802    pub fn state(&self) -> &state::State {
803        &self.state
804    }
805
806    pub fn stop_reason(&self) -> Option<&StopReason> {
807        self.state.stop_reason()
808    }
809
810    async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
811        for toolbox in &self.toolboxes {
812            let tools = toolbox
813                .available_tools()
814                .await
815                .map_err(AgentError::ToolBoxFailedToLoad)?;
816            self.tools.extend(tools);
817        }
818
819        Ok(())
820    }
821}
822
823/// Reverse searches through messages, if it encounters a tool call before encountering an output,
824/// it will return the chat message with the tool calls, otherwise it returns None
825fn maybe_tool_call_without_output(messages: &[ChatMessage]) -> Option<&ChatMessage> {
826    for message in messages.iter().rev() {
827        if let ChatMessage::ToolOutput(..) = message {
828            return None;
829        }
830
831        if let ChatMessage::Assistant(.., Some(tool_calls)) = message
832            && !tool_calls.is_empty()
833        {
834            return Some(message);
835        }
836    }
837
838    None
839}
840
841#[cfg(test)]
842mod tests {
843
844    use serde::ser::Error;
845    use swiftide_core::ToolFeedback;
846    use swiftide_core::chat_completion::errors::ToolError;
847    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
848    use swiftide_core::test_utils::MockChatCompletion;
849
850    use super::*;
851    use crate::{
852        State, assistant, chat_request, chat_response, summary, system, tool_failed, tool_output,
853        user,
854    };
855
856    use crate::test_utils::{MockHook, MockTool};
857
    /// Verifies the builder defaults:
    ///
    /// - the `stop` tool is always registered;
    /// - duplicate tools collapse into one entry;
    /// - user-supplied tools are added alongside `stop`;
    /// - a freshly built agent starts with an empty history.
    #[test_log::test(tokio::test)]
    async fn test_agent_builder_defaults() {
        // A mock LLM with no expectations; none of the agents built here are run.
        let mock_llm = MockChatCompletion::new();

        // Build the agent
        let agent = Agent::builder().llm(&mock_llm).build().unwrap();

        // NOTE(review): the context type is not asserted here; the empty-history check at the
        // end of this test is the only (indirect) check on the default context.

        // Check that the default tools are added
        assert!(agent.find_tool_by_name("stop").is_some());

        // Check it does not allow duplicates: two `Stop` tools plus the default `stop` tool
        // end up as a single entry.
        let agent = Agent::builder()
            .tools([Stop::default(), Stop::default()])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 1);

        // It should include the default tool if a different tool is provided
        let agent = Agent::builder()
            .tools([MockTool::new("mock_tool")])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 2);
        assert!(agent.find_tool_by_name("mock_tool").is_some());
        assert!(agent.find_tool_by_name("stop").is_some());

        assert!(agent.context().history().await.unwrap().is_empty());
    }
893
    /// Drives `query` through a two-turn loop:
    ///
    /// 1. The first completion calls `mock_tool`.
    /// 2. The tool's output is appended to the history and sent back to the LLM.
    /// 3. The second completion calls `stop`, ending the loop.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_calling_loop() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        // First turn: only the user prompt is sent (no system prompt, see builder below).
        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        // Second turn: the history now contains the assistant's tool call and the tool output.
        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        // Calling `stop` ends the loop.
        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }
939
    /// Runs a single iteration with `query_once` and a custom system prompt.
    ///
    /// Exactly one completion is expected (system and user messages only), and `mock_tool` is
    /// set up to be invoked for the tool call in that completion's response.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        // Presumably the default mock tool is named "mock_tool" (see `tool_calls` below) —
        // confirm in `test_utils`.
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }
971
    /// Same scenario as `test_agent_tool_run_once`, but `mock_tool` is supplied through a
    /// toolbox (`add_toolbox`) instead of `tools`.
    ///
    /// The expected request still lists the tool, so toolbox tools must already be loaded when
    /// the first completion is requested.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_via_toolbox_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        // A `Vec<Box<dyn Tool>>` is used directly as the toolbox.
        let mut agent = Agent::builder()
            .add_toolbox(vec![mock_tool.boxed()])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }
1003
1004    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1005    async fn test_multiple_tool_calls() {
1006        let prompt = "Write a poem";
1007        let mock_llm = MockChatCompletion::new();
1008        let mock_tool = MockTool::new("mock_tool1");
1009        let mock_tool2 = MockTool::new("mock_tool2");
1010
1011        let chat_request = chat_request! {
1012            system!("My system prompt"),
1013            user!("Write a poem");
1014
1015
1016
1017            tools = [mock_tool.clone(), mock_tool2.clone()]
1018        };
1019
1020        let mock_tool_response = chat_response! {
1021            "Roses are red";
1022
1023            tool_calls = ["mock_tool1", "mock_tool2"]
1024
1025        };
1026
1027        dbg!(&chat_request);
1028        mock_tool.expect_invoke_ok("Great!".into(), None);
1029        mock_tool2.expect_invoke_ok("Great!".into(), None);
1030        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1031
1032        let chat_request = chat_request! {
1033            system!("My system prompt"),
1034            user!("Write a poem"),
1035            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
1036            tool_output!("mock_tool1", "Great!"),
1037            tool_output!("mock_tool2", "Great!");
1038
1039            tools = [mock_tool.clone(), mock_tool2.clone()]
1040        };
1041
1042        let mock_tool_response = chat_response! {
1043            "Ok!";
1044
1045            tool_calls = ["stop"]
1046
1047        };
1048
1049        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
1050
1051        let mut agent = Agent::builder()
1052            .tools([mock_tool, mock_tool2])
1053            .system_prompt("My system prompt")
1054            .llm(&mock_llm)
1055            .build()
1056            .unwrap();
1057
1058        agent.query(prompt).await.unwrap();
1059    }
1060
    /// Checks the agent's lifecycle state: pending before the first run, stopped after a single
    /// `query_once` iteration whose response contains no tool calls.
    #[test_log::test(tokio::test)]
    async fn test_agent_state_machine() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = []
        };
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        // Agent has never run and is pending
        assert!(agent.state.is_pending());
        agent.query_once(prompt).await.unwrap();

        // After one iteration the agent is stopped, even though further messages could still be
        // added in a later run.
        assert!(agent.state.is_stopped());
    }
1089
    /// Verifies how summary messages trim the history sent to the LLM.
    ///
    /// After a summary is added to the context, the next request contains only the system
    /// prompt, the most recent summary and the new user message. Earlier messages, including
    /// earlier summaries, are left out.
    #[test_log::test(tokio::test)]
    async fn test_summary() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        // The same tool-call-free response is reused for every turn.
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []

        };

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = []
        };

        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        let mut agent = Agent::builder()
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();

        // Summarise the first exchange directly on the context.
        agent
            .context
            .add_message(ChatMessage::new_summary("Summary"))
            .await
            .unwrap();

        // Only the summary replaces the first exchange.
        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary"),
            user!("Write another poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        agent.query_once("Write another poem").await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary 2"))
            .await
            .unwrap();

        // The newer summary supersedes the older one.
        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary 2"),
            user!("Write a third poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));

        agent.query_once("Write a third poem").await.unwrap();
    }
1150
    /// Registers a `MockHook` for every lifecycle hook and checks how many times each fires.
    ///
    /// The run has two completions: the first calls `mock_tool`, the second calls `stop`.
    #[test_log::test(tokio::test)]
    async fn test_agent_hooks() {
        // Run-level hooks fire once; per-completion hooks fire once per completion (2).
        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
        let mock_before_completion = MockHook::new("before_completion")
            .expect_calls(2)
            .to_owned();
        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
        // NOTE(review): 4 new messages are expected over the run — confirm which ones are
        // counted.
        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
        let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();

        // Once for mock tool and once for stop
        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();

        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .before_all(mock_before_all.hook_fn())
            .on_start(mock_on_start_fn.on_start_fn())
            .before_completion(mock_before_completion.before_completion_fn())
            .before_tool(mock_before_tool.before_tool_fn())
            .after_completion(mock_after_completion.after_completion_fn())
            .after_tool(mock_after_tool.after_tool_fn())
            .after_each(mock_after_each.hook_fn())
            .on_new_message(mock_on_message.message_hook_fn())
            .on_stop(mock_on_stop.stop_hook_fn())
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }
1219
    /// With `limit(1)`, the agent must stop after a single loop iteration.
    ///
    /// This holds even though the LLM keeps requesting tools: the second expectation is left
    /// unconsumed and the agent ends up stopped.
    #[test_log::test(tokio::test)]
    async fn test_agent_loop_limit() {
        let prompt = "Generate content"; // Example prompt
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mock_tool_response = chat_response! {
            "Some response";
            tool_calls = ["mock_tool"]
        };

        // Set expectations for the mock LLM responses
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));

        // Would terminate the loop, but with `limit(1)` the agent stops before requesting a
        // second completion, so this expectation is meant to stay unconsumed (asserted below).
        let stop_response = chat_response! {
            "Final response";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .limit(1) // Setting the loop limit to 1
            .build()
            .unwrap();

        // Run the agent
        agent.query(prompt).await.unwrap();

        // Assert that the remaining message is still in the queue
        let remaining = mock_llm.expectations.lock().unwrap().pop();
        assert!(remaining.is_some());

        // Assert that the agent is stopped after reaching the loop limit
        assert!(agent.is_stopped());
    }
1266
    /// Verifies tool-failure retries with `tool_retry_limit(1)`.
    ///
    /// The first `WrongArguments` failure is fed back to the LLM as a failed tool output. The
    /// second failure is returned as an error from `query`, and the agent ends up stopped.
    #[test_log::test(tokio::test)]
    async fn test_tool_retry_mechanism() {
        let prompt = "Execute my tool";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("retry_tool");

        // Configure mock tool to fail twice. First time is fed back to the LLM, second time is an
        // error
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        let retry_response = chat_response! {
            "First failing attempt";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));

        // The first failure is reported back to the LLM as a failed tool output.
        let chat_request = chat_request! {
            user!(prompt),
            assistant!("First failing attempt", ["retry_tool"]),
            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");

            tools = [mock_tool.clone()]
        };
        let will_fail_response = chat_response! {
            "Finished execution";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .tool_retry_limit(1) // One retry: the first failure is fed back, the second is fatal.
            .build()
            .unwrap();

        // Run the agent
        let result = agent.query(prompt).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing `query`"));
        assert!(agent.is_stopped());
    }
1326
    /// Checks that the `on_stream` hook fires while a completion is streamed, and that the
    /// agent stops after the streamed response calls `stop`.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_streaming() {
        let prompt = "Generate content"; // Example prompt
        let mock_llm = MockChatCompletion::new();
        // Presumably one call per streamed chunk of "one two three" — TODO confirm how the mock
        // LLM splits the response into chunks.
        let on_stream_fn = MockHook::new("on_stream").expect_calls(3).to_owned();

        let chat_request = chat_request! {
            user!(prompt);

            tools = []
        };

        let response = chat_response! {
            "one two three";
            tool_calls = ["stop"]
        };

        // Set expectations for the mock LLM responses
        mock_llm.expect_complete(chat_request, Ok(response));

        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .on_stream(on_stream_fn.on_stream_fn())
            .no_system_prompt()
            .build()
            .unwrap();

        // Run the agent
        agent.query(prompt).await.unwrap();

        tracing::debug!("Agent finished running");

        // The response called `stop`, so the agent must be stopped.
        assert!(agent.is_stopped());
    }
1362
    /// Verifies that an agent can be resumed from a persisted history.
    ///
    /// The test runs an agent to completion and round-trips its history through JSON. It then
    /// seeds a fresh `DefaultContext` with that history and builds a new agent on it. The new
    /// agent's first request must contain the full restored history, including the `stop` call
    /// and its output, followed by the new user message.
    #[test_log::test(tokio::test)]
    async fn test_recovering_agent_existing_history() {
        // First, let's run an agent
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool.clone()])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        // Let's retrieve the history of the agent
        let history = agent.history().await.unwrap();

        // Store it as a string somewhere
        let serialized = serde_json::to_string(&history).unwrap();

        // Retrieve it
        let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();

        // Build a context from the history
        let context = DefaultContext::default()
            .with_existing_messages(history)
            .await
            .unwrap()
            .to_owned();

        // The restored history ends with the `stop` call and its output; the new prompt is
        // appended after it.
        let stop_output = ToolOutput::stop();
        let expected_chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!"),
            assistant!("Roses are red", ["stop"]),
            tool_output!("stop", stop_output),
            user!("Try again!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Really stopping now";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(expected_chat_request, Ok(stop_response));

        // A second agent resumes on the restored context.
        let mut agent = Agent::builder()
            .context(context)
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query_once("Try again!").await.unwrap();
    }
1454
1455    #[test_log::test(tokio::test)]
1456    async fn test_agent_with_approval_required_tool() {
1457        use super::*;
1458        use crate::tools::control::ApprovalRequired;
1459        use crate::{assistant, chat_request, chat_response, user};
1460        use swiftide_core::chat_completion::ToolCall;
1461
1462        // Step 1: Build a tool that needs approval.
1463        let mock_tool = MockTool::default();
1464        mock_tool.expect_invoke_ok("Great!".into(), None);
1465
1466        let approval_tool = ApprovalRequired(mock_tool.boxed());
1467
1468        // Step 2: Set up the mock LLM.
1469        let mock_llm = MockChatCompletion::new();
1470
1471        let chat_req1 = chat_request! {
1472            user!("Request with approval");
1473            tools = [approval_tool.clone()]
1474        };
1475        let chat_resp1 = chat_response! {
1476            "Completion message";
1477            tool_calls = ["mock_tool"]
1478        };
1479        mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1480
1481        // The response will include the previous request, but no tool output
1482        // from the required tool
1483        let chat_req2 = chat_request! {
1484            user!("Request with approval"),
1485            assistant!("Completion message", ["mock_tool"]),
1486            tool_output!("mock_tool", "Great!");
1487            // Simulate feedback required output
1488            tools = [approval_tool.clone()]
1489        };
1490        let chat_resp2 = chat_response! {
1491            "Post-feedback message";
1492            tool_calls = ["stop"]
1493        };
1494        mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1495
1496        // Step 3: Wire up the agent.
1497        let mut agent = Agent::builder()
1498            .tools([approval_tool])
1499            .llm(&mock_llm)
1500            .no_system_prompt()
1501            .build()
1502            .unwrap();
1503
1504        // Step 4: Run agent to trigger approval.
1505        agent.query_once("Request with approval").await.unwrap();
1506
1507        assert!(matches!(
1508            agent.state,
1509            crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1510        ));
1511
1512        let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1513        else {
1514            panic!("Expected feedback required");
1515        };
1516
1517        // Step 5: Simulate feedback, run again and assert finish.
1518        agent
1519            .context
1520            .feedback_received(&tool_call, &ToolFeedback::approved())
1521            .await
1522            .unwrap();
1523
1524        tracing::debug!("running after approval");
1525        agent.run_once().await.unwrap();
1526        assert!(agent.is_stopped());
1527    }
1528
1529    #[test_log::test(tokio::test)]
1530    async fn test_agent_with_approval_required_tool_denied() {
1531        use super::*;
1532        use crate::tools::control::ApprovalRequired;
1533        use crate::{assistant, chat_request, chat_response, user};
1534        use swiftide_core::chat_completion::ToolCall;
1535
1536        // Step 1: Build a tool that needs approval.
1537        let mock_tool = MockTool::default();
1538
1539        let approval_tool = ApprovalRequired(mock_tool.boxed());
1540
1541        // Step 2: Set up the mock LLM.
1542        let mock_llm = MockChatCompletion::new();
1543
1544        let chat_req1 = chat_request! {
1545            user!("Request with approval");
1546            tools = [approval_tool.clone()]
1547        };
1548        let chat_resp1 = chat_response! {
1549            "Completion message";
1550            tool_calls = ["mock_tool"]
1551        };
1552        mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1553
1554        // The response will include the previous request, but no tool output
1555        // from the required tool
1556        let chat_req2 = chat_request! {
1557            user!("Request with approval"),
1558            assistant!("Completion message", ["mock_tool"]),
1559            tool_output!("mock_tool", "This tool call was refused");
1560            // Simulate feedback required output
1561            tools = [approval_tool.clone()]
1562        };
1563        let chat_resp2 = chat_response! {
1564            "Post-feedback message";
1565            tool_calls = ["stop"]
1566        };
1567        mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1568
1569        // Step 3: Wire up the agent.
1570        let mut agent = Agent::builder()
1571            .tools([approval_tool])
1572            .llm(&mock_llm)
1573            .no_system_prompt()
1574            .build()
1575            .unwrap();
1576
1577        // Step 4: Run agent to trigger approval.
1578        agent.query_once("Request with approval").await.unwrap();
1579
1580        assert!(matches!(
1581            agent.state,
1582            crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1583        ));
1584
1585        let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1586        else {
1587            panic!("Expected feedback required");
1588        };
1589
1590        // Step 5: Simulate feedback, run again and assert finish.
1591        agent
1592            .context
1593            .feedback_received(&tool_call, &ToolFeedback::refused())
1594            .await
1595            .unwrap();
1596
1597        tracing::debug!("running after approval");
1598        agent.run_once().await.unwrap();
1599
1600        let history = agent.context().history().await.unwrap();
1601        history
1602            .iter()
1603            .rfind(|m| {
1604                let ChatMessage::ToolOutput(.., ToolOutput::Text(msg)) = m else {
1605                    return false;
1606                };
1607                msg.contains("refused")
1608            })
1609            .expect("Could not find refusal message");
1610
1611        assert!(agent.is_stopped());
1612    }
1613}