swiftide_agents/
agent.rs

1#![allow(dead_code)]
2use crate::{
3    default_context::DefaultContext,
4    errors::AgentError,
5    hooks::{
6        AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7        Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn,
8    },
9    invoke_hooks,
10    state::{self, StopReason},
11    system_prompt::SystemPrompt,
12    tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15    collections::{HashMap, HashSet},
16    hash::{DefaultHasher, Hash as _, Hasher as _},
17    sync::Arc,
18};
19
20use derive_builder::Builder;
21use futures_util::stream::StreamExt;
22use swiftide_core::{
23    AgentContext, ToolBox,
24    chat_completion::{
25        ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
26    },
27    prompt::Prompt,
28};
29use tracing::{Instrument, debug};
30
31/// Agents are the main interface for building agentic systems.
32///
33/// Construct agents by calling the builder, setting an llm, configure hooks, tools and other
34/// customizations.
35///
36/// # Important defaults
37///
38/// - The default context is the `DefaultContext`, executing tools locally with the `LocalExecutor`.
39/// - A default `stop` tool is provided for agents to explicitly stop if needed
40/// - The default `SystemPrompt` instructs the agent with chain of thought and some common
41///   safeguards, but is otherwise quite bare. In a lot of cases this can be sufficient.
42///
43///   Agents are *not* cheap to clone. However, if an agent gets cloned, it will operate on the
44///   same context.
45#[derive(Builder)]
46pub struct Agent {
47    /// Hooks are functions that are called at specific points in the agent's lifecycle.
48    #[builder(default, setter(into))]
49    pub(crate) hooks: Vec<Hook>,
50    /// The context in which the agent operates, by default this is the `DefaultContext`.
51    #[builder(
52        setter(custom),
53        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
54    )]
55    pub(crate) context: Arc<dyn AgentContext>,
56    /// Tools the agent can use
57    #[builder(default = Agent::default_tools(), setter(custom))]
58    pub(crate) tools: HashSet<Box<dyn Tool>>,
59
60    /// Toolboxes are collections of tools that can be added to the agent.
61    ///
62    /// Toolboxes make their tools available to the agent at runtime.
63    #[builder(default)]
64    pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,
65
66    /// The language model that the agent uses for completion.
67    #[builder(setter(custom))]
68    pub(crate) llm: Box<dyn ChatCompletion>,
69
70    /// System prompt for the agent when it starts
71    ///
72    /// Some agents profit significantly from a tailored prompt. But it is not always needed.
73    ///
74    /// See [`SystemPrompt`] for an opiniated, customizable system prompt.
75    ///
76    /// Swiftide provides a default system prompt for all agents.
77    ///
78    /// # Example
79    ///
80    /// ```no_run
81    /// # use swiftide_agents::system_prompt::SystemPrompt;
82    /// # use swiftide_agents::Agent;
83    /// Agent::builder()
84    ///     .system_prompt(
85    ///         SystemPrompt::builder().role("You are an expert engineer")
86    ///         .build().unwrap())
87    ///     .build().unwrap();
88    /// ```
89    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default()))]
90    pub(crate) system_prompt: Option<SystemPrompt>,
91
92    /// Initial state of the agent
93    #[builder(private, default = state::State::default())]
94    pub(crate) state: state::State,
95
96    /// Optional limit on the amount of loops the agent can run.
97    /// The counter is reset when the agent is stopped.
98    #[builder(default, setter(strip_option))]
99    pub(crate) limit: Option<usize>,
100
101    /// The maximum amount of times the failed output of a tool will be send
102    /// to an LLM before the agent stops. Defaults to 3.
103    ///
104    /// LLMs sometimes send missing arguments, or a tool might actually fail, but retrying could be
105    /// worth while. If the limit is not reached, the agent will send the formatted error back to
106    /// the LLM.
107    ///
108    /// The limit is hashed based on the tool call name and arguments, so the limit is per tool
109    /// call.
110    ///
111    /// This limit is _not_ reset when the agent is stopped.
112    #[builder(default = 3)]
113    pub(crate) tool_retry_limit: usize,
114
115    /// Enables streaming the chat completion responses for the agent.
116    #[builder(default)]
117    pub(crate) streaming: bool,
118
119    /// When set to true, any tools in `Agent::default_tools` will be omitted. Only works if you
120    /// at at least one tool of your own.
121    #[builder(private, default)]
122    pub(crate) clear_default_tools: bool,
123
124    /// Internally tracks the amount of times a tool has been retried. The key is a hash based on
125    /// the name and args of the tool.
126    #[builder(private, default)]
127    pub(crate) tool_retries_counter: HashMap<u64, usize>,
128
129    /// The name of the agent; optional
130    #[builder(default = "unnamed_agent".into(), setter(into))]
131    pub(crate) name: String,
132}
133
134impl Clone for Agent {
135    fn clone(&self) -> Self {
136        Agent {
137            hooks: self.hooks.clone(),
138            context: Arc::new(self.context.clone()),
139            tools: self.tools.clone(),
140            toolboxes: self.toolboxes.clone(),
141            llm: self.llm.clone(),
142            system_prompt: self.system_prompt.clone(),
143            state: self.state.clone(),
144            limit: self.limit,
145            tool_retry_limit: self.tool_retry_limit,
146            tool_retries_counter: HashMap::new(),
147            streaming: self.streaming,
148            name: self.name.clone(),
149            clear_default_tools: self.clear_default_tools,
150        }
151    }
152}
153
154impl std::fmt::Debug for Agent {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        f.debug_struct("Agent")
157            .field("name", &self.name)
158            // display hooks as a list of type: number of hooks
159            .field(
160                "hooks",
161                &self
162                    .hooks
163                    .iter()
164                    .map(std::string::ToString::to_string)
165                    .collect::<Vec<_>>(),
166            )
167            .field(
168                "tools",
169                &self
170                    .tools
171                    .iter()
172                    .map(swiftide_core::Tool::name)
173                    .collect::<Vec<_>>(),
174            )
175            .field("llm", &"Box<dyn ChatCompletion>")
176            .field("state", &self.state)
177            .finish()
178    }
179}
180
181impl AgentBuilder {
182    /// The context in which the agent operates, by default this is the `DefaultContext`.
183    pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
184    where
185        Self: Clone,
186    {
187        self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
188        self
189    }
190
191    /// Returns a mutable reference to the system prompt, if it is set.
192    pub fn system_prompt_mut(&mut self) -> Option<&mut SystemPrompt> {
193        self.system_prompt.as_mut().and_then(Option::as_mut)
194    }
195
196    /// Disable the system prompt.
197    pub fn no_system_prompt(&mut self) -> &mut Self {
198        self.system_prompt = Some(None);
199
200        self
201    }
202
203    /// Add a hook to the agent.
204    pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
205        let hooks = self.hooks.get_or_insert_with(Vec::new);
206        hooks.push(hook);
207
208        self
209    }
210
211    /// Adds a tool to the agent
212    pub fn add_tool(&mut self, tool: impl Tool + 'static) -> &mut Self {
213        let tools = self.tools.get_or_insert_with(HashSet::new);
214        if let Some(tool) = tools.replace(tool.boxed()) {
215            tracing::debug!("Tool {} already exists, replacing", tool.name());
216        }
217
218        self
219    }
220
221    /// Add a hook that runs once, before all completions. Even if the agent is paused and resumed,
222    /// before all will not trigger again.
223    pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
224        self.add_hook(Hook::BeforeAll(Box::new(hook)))
225    }
226
227    /// Add a hook that runs once, when the agent starts. This hook also runs if the agent stopped
228    /// and then starts again. The hook runs after any `before_all` hooks and before the
229    /// `before_completion` hooks.
230    pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
231        self.add_hook(Hook::OnStart(Box::new(hook)))
232    }
233
234    /// Add a hook that runs when the agent receives a streaming response
235    ///
236    /// The response will always include both the current accumulated message and the delta
237    ///
238    /// This will set `self.streaming` to true, there is no need to set it manually for the default
239    /// behaviour.
240    pub fn on_stream(&mut self, hook: impl OnStreamFn + 'static) -> &mut Self {
241        self.streaming = Some(true);
242        self.add_hook(Hook::OnStream(Box::new(hook)))
243    }
244
245    /// Add a hook that runs before each completion.
246    pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
247        self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
248    }
249
250    /// Add a hook that runs after each tool. The `Result<ToolOutput, ToolError>` is provided
251    /// as mut, so the tool output can be fully modified.
252    ///
253    /// The `ToolOutput` also references the original `ToolCall`, allowing you to match at runtime
254    /// what tool to interact with.
255    pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
256        self.add_hook(Hook::AfterTool(Box::new(hook)))
257    }
258
259    /// Add a hook that runs before each tool. Yields an immutable reference to the `ToolCall`.
260    pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
261        self.add_hook(Hook::BeforeTool(Box::new(hook)))
262    }
263
264    /// Add a hook that runs after each completion, before tool invocation and/or new messages.
265    pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
266        self.add_hook(Hook::AfterCompletion(Box::new(hook)))
267    }
268
269    /// Add a hook that runs after each completion, after tool invocations, right before a new loop
270    /// might start
271    pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
272        self.add_hook(Hook::AfterEach(Box::new(hook)))
273    }
274
275    /// Add a hook that runs when a new message is added to the context. Note that each tool adds a
276    /// separate message.
277    pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
278        self.add_hook(Hook::OnNewMessage(Box::new(hook)))
279    }
280
281    pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
282        self.add_hook(Hook::OnStop(Box::new(hook)))
283    }
284
285    /// Set the LLM for the agent. An LLM must implement the `ChatCompletion` trait.
286    pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
287        let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
288
289        self.llm = Some(boxed);
290        self
291    }
292
293    /// Removes the default `stop` tool from the agent. This allows you to add your own or use
294    /// other methods to stop the agent.
295    ///
296    /// Note that you can also just override the tool if the name of the tool is `stop`.
297    pub fn without_default_stop_tool(&mut self) -> &mut Self {
298        self.clear_default_tools = Some(true);
299        self
300    }
301
302    fn builder_default_tools(&self) -> HashSet<Box<dyn Tool>> {
303        if self.clear_default_tools.is_some_and(|b| b) {
304            HashSet::new()
305        } else {
306            Agent::default_tools()
307        }
308    }
309
310    /// Define the available tools for the agent. Tools must implement the `Tool` trait.
311    ///
312    /// See the [tool attribute macro](`swiftide_macros::tool`) and the [tool derive
313    /// macro](`swiftide_macros::Tool`) for easy ways to create (many) tools.
314    pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
315    where
316        TOOL: Into<Box<dyn Tool>>,
317    {
318        self.tools = Some(
319            self.builder_default_tools()
320                .into_iter()
321                .chain(tools.into_iter().map(Into::into))
322                .collect(),
323        );
324        self
325    }
326
327    /// Add a toolbox to the agent. Toolboxes are collections of tools that can be added to the
328    /// to the agent. Available tools are evaluated at runtime, when the agent starts for the first
329    /// time.
330    ///
331    /// Agents can have many toolboxes.
332    pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
333        let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
334        toolboxes.push(Box::new(toolbox));
335
336        self
337    }
338}
339
340impl Agent {
341    /// Build a new agent
342    pub fn builder() -> AgentBuilder {
343        AgentBuilder::default()
344            .tools(Agent::default_tools())
345            .to_owned()
346    }
347
348    /// The name of the agent
349    pub fn name(&self) -> &str {
350        &self.name
351    }
352
353    /// Default tools for the agent that it always includes
354    /// Right now this is the `stop` tool, which allows the agent to stop itself.
355    pub fn default_tools() -> HashSet<Box<dyn Tool>> {
356        HashSet::from([Stop::default().boxed()])
357    }
358
359    /// Run the agent with a user message. The agent will loop completions, make tool calls, until
360    /// no new messages are available.
361    ///
362    /// # Errors
363    ///
364    /// Errors if anything goes wrong, see `AgentError` for more details.
365    #[tracing::instrument(skip_all, name = "agent.query", err)]
366    pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
367        let query = query
368            .into()
369            .render()
370            .map_err(AgentError::FailedToRenderPrompt)?;
371        self.run_agent(Some(query), false).await
372    }
373
374    /// Adds a tool to an agent at run time
375    pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
376        if let Some(tool) = self.tools.replace(tool) {
377            tracing::debug!("Tool {} already exists, replacing", tool.name());
378        }
379    }
380
381    /// Modify the tools of the agent at runtime
382    ///
383    /// Note that any mcp tools are added to the agent after the first start, and will only then
384    /// also be available here.
385    pub fn tools_mut(&mut self) -> &mut HashSet<Box<dyn Tool>> {
386        &mut self.tools
387    }
388
389    /// Run the agent with a user message once.
390    ///
391    /// # Errors
392    ///
393    /// Errors if anything goes wrong, see `AgentError` for more details.
394    #[tracing::instrument(skip_all, name = "agent.query_once", err)]
395    pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
396        self.run_agent(Some(query), true).await
397    }
398
399    /// Run the agent with without user message. The agent will loop completions, make tool calls,
400    /// until no new messages are available.
401    ///
402    /// # Errors
403    ///
404    /// Errors if anything goes wrong, see `AgentError` for more details.
405    #[tracing::instrument(skip_all, name = "agent.run", err)]
406    pub async fn run(&mut self) -> Result<(), AgentError> {
407        self.run_agent(None::<Prompt>, false).await
408    }
409
410    /// Run the agent with without user message. The agent will loop completions, make tool calls,
411    /// until
412    ///
413    /// # Errors
414    ///
415    /// Errors if anything goes wrong, see `AgentError` for more details.
416    #[tracing::instrument(skip_all, name = "agent.run_once", err)]
417    pub async fn run_once(&mut self) -> Result<(), AgentError> {
418        self.run_agent(None::<Prompt>, true).await
419    }
420
421    /// Retrieve the message history of the agent
422    ///
423    /// # Errors
424    ///
425    /// Error if the message history cannot be retrieved, e.g. if the context is not set up or a
426    /// connection fails
427    pub async fn history(&self) -> Result<Vec<ChatMessage>, AgentError> {
428        self.context
429            .history()
430            .await
431            .map_err(AgentError::MessageHistoryError)
432    }
433
434    pub(crate) async fn run_agent(
435        &mut self,
436        maybe_query: Option<impl Into<Prompt>>,
437        just_once: bool,
438    ) -> Result<(), AgentError> {
439        let maybe_query = maybe_query
440            .map(|q| q.into().render())
441            .transpose()
442            .map_err(AgentError::FailedToRenderPrompt)?;
443        if self.state.is_running() {
444            return Err(AgentError::AlreadyRunning);
445        }
446
447        if self.state.is_pending() {
448            if let Some(system_prompt) = &self.system_prompt {
449                self.context
450                    .add_messages(vec![ChatMessage::System(
451                        system_prompt
452                            .to_prompt()
453                            .render()
454                            .map_err(AgentError::FailedToRenderSystemPrompt)?,
455                    )])
456                    .await
457                    .map_err(AgentError::MessageHistoryError)?;
458            }
459
460            invoke_hooks!(BeforeAll, self);
461
462            self.load_toolboxes().await?;
463        }
464
465        if let Some(query) = maybe_query {
466            if cfg!(feature = "langfuse") {
467                debug!(langfuse.input = query);
468            }
469            self.context
470                .add_message(ChatMessage::User(query))
471                .await
472                .map_err(AgentError::MessageHistoryError)?;
473        }
474
475        invoke_hooks!(OnStart, self);
476
477        self.state = state::State::Running;
478
479        let mut loop_counter = 0;
480
481        while let Some(messages) = self
482            .context
483            .next_completion()
484            .await
485            .map_err(AgentError::MessageHistoryError)?
486        {
487            if let Some(limit) = self.limit
488                && loop_counter >= limit
489            {
490                tracing::warn!("Agent loop limit reached");
491                break;
492            }
493
494            // If the last message contains tool calls that have not been completed,
495            // run the tools first
496            if let Some(&ChatMessage::Assistant(.., Some(ref tool_calls))) =
497                maybe_tool_call_without_output(&messages)
498            {
499                tracing::debug!("Uncompleted tool calls found; invoking tools");
500                self.invoke_tools(tool_calls).await?;
501                // Move on to the next tick, so that the
502                continue;
503            }
504
505            let result = self.step(&messages, loop_counter).await;
506
507            if let Err(err) = result {
508                self.stop_with_error(&err).await;
509                tracing::error!(error = ?err, "Agent stopped with error {err}");
510                return Err(err);
511            }
512
513            if just_once || self.state.is_stopped() {
514                break;
515            }
516            loop_counter += 1;
517        }
518
519        // If there are no new messages, ensure we update our state
520        self.stop(StopReason::NoNewMessages).await;
521
522        Ok(())
523    }
524
525    #[tracing::instrument(skip(self, messages), err, fields(otel.name))]
526    async fn step(
527        &mut self,
528        messages: &[ChatMessage],
529        step_count: usize,
530    ) -> Result<(), AgentError> {
531        tracing::Span::current().record("otel.name", format!("step-{step_count}"));
532
533        debug!(
534            tools = ?self
535                .tools
536                .iter()
537                .map(|t| t.name())
538                .collect::<Vec<_>>()
539                ,
540            "Running completion for agent with {} new messages",
541            messages.len()
542        );
543
544        let mut chat_completion_request = ChatCompletionRequest::builder()
545            .messages(messages)
546            .tools_spec(
547                self.tools
548                    .iter()
549                    .map(swiftide_core::Tool::tool_spec)
550                    .collect::<HashSet<_>>(),
551            )
552            .build()
553            .map_err(AgentError::FailedToBuildRequest)?;
554
555        invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
556
557        debug!(
558            "Calling LLM with the following new messages:\n {}",
559            self.context
560                .current_new_messages()
561                .await
562                .map_err(AgentError::MessageHistoryError)?
563                .iter()
564                .map(ToString::to_string)
565                .collect::<Vec<_>>()
566                .join(",\n")
567        );
568
569        let mut response = if self.streaming {
570            let mut last_response = None;
571            let mut stream = self.llm.complete_stream(&chat_completion_request).await;
572
573            while let Some(response) = stream.next().await {
574                let response = response.map_err(AgentError::CompletionsFailed)?;
575                invoke_hooks!(OnStream, self, &response);
576                last_response = Some(response);
577            }
578            tracing::trace!(?last_response, "Streaming completed");
579            last_response.ok_or(AgentError::EmptyStream)
580        } else {
581            self.llm
582                .complete(&chat_completion_request)
583                .await
584                .map_err(AgentError::CompletionsFailed)
585        }?;
586
587        // The arg preprocessor helps avoid common llm errors.
588        // This must happen as early as possible
589        response
590            .tool_calls
591            .as_deref_mut()
592            .map(ArgPreprocessor::preprocess_tool_calls);
593
594        invoke_hooks!(AfterCompletion, self, &mut response);
595
596        self.add_message(ChatMessage::Assistant(
597            response.message,
598            response.tool_calls.clone(),
599        ))
600        .await?;
601
602        if let Some(tool_calls) = response.tool_calls {
603            self.invoke_tools(&tool_calls).await?;
604        }
605
606        invoke_hooks!(AfterEach, self);
607
608        Ok(())
609    }
610
611    async fn invoke_tools(&mut self, tool_calls: &[ToolCall]) -> Result<(), AgentError> {
612        tracing::debug!("LLM returned tool calls: {:?}", tool_calls);
613
614        let mut handles = vec![];
615        for tool_call in tool_calls {
616            let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
617                tracing::warn!("Tool {} not found", tool_call.name());
618                continue;
619            };
620            tracing::info!("Calling tool `{}`", tool_call.name());
621
622            // let tool_args = tool_call.args().map(String::from);
623            let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
624
625            invoke_hooks!(BeforeTool, self, &tool_call);
626
627            let tool_span = tracing::info_span!(
628                "tool",
629                "otel.name" = format!("tool.{}", tool.name().as_ref()),
630            );
631
632            let handle_tool_call = tool_call.clone();
633            let handle = tokio::spawn(async move {
634                    let handle_tool_call = handle_tool_call;
635                    let output = tool.invoke(&*context, &handle_tool_call)
636                        .await
637                        .map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
638
639                if cfg!(feature = "langfuse") {
640                    tracing::debug!(
641                        langfuse.output = %output,
642                        langfuse.input = handle_tool_call.args(),
643                        tool_name = tool.name().as_ref(),
644                    );
645                } else {
646                    tracing::debug!(output = output.to_string(), args = ?handle_tool_call.args(), tool_name = tool.name().as_ref(), "Completed tool call");
647                }
648
649                    Ok(output)
650                }.instrument(tool_span.or_current()));
651
652            handles.push((handle, tool_call));
653        }
654
655        for (handle, tool_call) in handles {
656            let mut output = handle
657                .await
658                .map_err(|err| AgentError::ToolFailedToJoin(tool_call.name().to_string(), err))?;
659
660            invoke_hooks!(AfterTool, self, &tool_call, &mut output);
661
662            if let Err(error) = output {
663                let stop = self.tool_calls_over_limit(tool_call);
664                if stop {
665                    tracing::error!(
666                        ?error,
667                        "Tool call failed, retry limit reached, stopping agent: {error}",
668                    );
669                } else {
670                    tracing::warn!(
671                        ?error,
672                        tool_call = ?tool_call,
673                        "Tool call failed, retrying",
674                    );
675                }
676                self.add_message(ChatMessage::ToolOutput(
677                    tool_call.clone(),
678                    ToolOutput::fail(error.to_string()),
679                ))
680                .await?;
681                if stop {
682                    self.stop(StopReason::ToolCallsOverLimit(tool_call.to_owned()))
683                        .await;
684                    return Err(error.into());
685                }
686                continue;
687            }
688
689            let output = output?;
690            self.handle_control_tools(tool_call, &output).await;
691
692            // Feedback required leaves the tool call open
693            //
694            // It assumes a follow up invocation of the agent will have the feedback approved
695            if !output.is_feedback_required() {
696                self.add_message(ChatMessage::ToolOutput(tool_call.to_owned(), output))
697                    .await?;
698            }
699        }
700
701        Ok(())
702    }
703
704    fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
705        self.hooks
706            .iter()
707            .filter(|h| hook_type == (*h).into())
708            .collect()
709    }
710
711    fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
712        self.tools
713            .iter()
714            .find(|tool| tool.name() == tool_name)
715            .cloned()
716    }
717
718    // Handle any tool specific output (e.g. stop)
719    async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
720        match output {
721            ToolOutput::Stop(maybe_message) => {
722                tracing::warn!("Stop tool called, stopping agent");
723                self.stop(StopReason::RequestedByTool(
724                    tool_call.clone(),
725                    maybe_message.clone(),
726                ))
727                .await;
728            }
729            ToolOutput::FeedbackRequired(maybe_payload) => {
730                tracing::warn!("Feedback required, stopping agent");
731                self.stop(StopReason::FeedbackRequired {
732                    tool_call: tool_call.clone(),
733                    payload: maybe_payload.clone(),
734                })
735                .await;
736            }
737            ToolOutput::AgentFailed(output) => {
738                tracing::warn!("Agent failed, stopping agent");
739                self.stop(StopReason::AgentFailed(output.clone())).await;
740            }
741            _ => (),
742        }
743    }
744
745    fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
746        let mut s = DefaultHasher::new();
747        tool_call.hash(&mut s);
748        let hash = s.finish();
749
750        if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
751            let val = *retries >= self.tool_retry_limit;
752            *retries += 1;
753            val
754        } else {
755            self.tool_retries_counter.insert(hash, 1);
756            false
757        }
758    }
759
760    /// Add a message to the agent's context
761    ///
762    /// This will trigger a `OnNewMessage` hook if its present.
763    ///
764    /// If you want to add a message without triggering the hook, use the context directly.
765    ///
766    /// # Errors
767    ///
768    /// Errors if the message cannot be added to the context. With the default in memory context
769    /// that is not supposed to happen.
770    #[tracing::instrument(skip_all, fields(message = message.to_string()))]
771    pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
772        invoke_hooks!(OnNewMessage, self, &mut message);
773
774        self.context
775            .add_message(message)
776            .await
777            .map_err(AgentError::MessageHistoryError)?;
778        Ok(())
779    }
780
781    /// Tell the agent to stop. It will finish it's current loop and then stop.
782    pub async fn stop(&mut self, reason: impl Into<StopReason>) {
783        if self.state.is_stopped() {
784            return;
785        }
786
787        let reason = reason.into();
788        invoke_hooks!(OnStop, self, reason.clone(), None);
789
790        if cfg!(feature = "langfuse") {
791            debug!(langfuse.output = serde_json::to_string_pretty(&reason).ok());
792        }
793
794        self.state = state::State::Stopped(reason);
795    }
796
797    pub async fn stop_with_error(&mut self, error: &AgentError) {
798        if self.state.is_stopped() {
799            return;
800        }
801        invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
802
803        self.state = state::State::Stopped(StopReason::Error);
804    }
805
806    /// Access the agent's context
807    pub fn context(&self) -> &dyn AgentContext {
808        &self.context
809    }
810
811    /// The agent is still running
812    pub fn is_running(&self) -> bool {
813        self.state.is_running()
814    }
815
816    /// The agent stopped
817    pub fn is_stopped(&self) -> bool {
818        self.state.is_stopped()
819    }
820
821    /// The agent has not (ever) started
822    pub fn is_pending(&self) -> bool {
823        self.state.is_pending()
824    }
825
826    /// Get a list of tools available to the agent
827    pub fn tools(&self) -> &HashSet<Box<dyn Tool>> {
828        &self.tools
829    }
830
831    pub fn state(&self) -> &state::State {
832        &self.state
833    }
834
835    pub fn stop_reason(&self) -> Option<&StopReason> {
836        self.state.stop_reason()
837    }
838
839    async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
840        for toolbox in &self.toolboxes {
841            let tools = toolbox
842                .available_tools()
843                .await
844                .map_err(AgentError::ToolBoxFailedToLoad)?;
845            self.tools.extend(tools);
846        }
847
848        Ok(())
849    }
850}
851
852/// Reverse searches through messages, if it encounters a tool call before encountering an output,
853/// it will return the chat message with the tool calls, otherwise it returns None
854fn maybe_tool_call_without_output(messages: &[ChatMessage]) -> Option<&ChatMessage> {
855    for message in messages.iter().rev() {
856        if let ChatMessage::ToolOutput(..) = message {
857            return None;
858        }
859
860        if let ChatMessage::Assistant(.., Some(tool_calls)) = message
861            && !tool_calls.is_empty()
862        {
863            return Some(message);
864        }
865    }
866
867    None
868}
869
#[cfg(test)]
mod tests {

    use serde::ser::Error;
    use swiftide_core::ToolFeedback;
    use swiftide_core::chat_completion::errors::ToolError;
    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
    use swiftide_core::test_utils::MockChatCompletion;

    use super::*;
    use crate::{
        State, assistant, chat_request, chat_response, summary, system, tool_failed, tool_output,
        user,
    };

    use crate::test_utils::{MockHook, MockTool};

    #[test_log::test(tokio::test)]
    async fn test_agent_builder_defaults() {
        // Create a mock LLM
        let mock_llm = MockChatCompletion::new();

        // Build the agent
        let agent = Agent::builder().llm(&mock_llm).build().unwrap();

        // Check that the default tools are added
        assert!(agent.find_tool_by_name("stop").is_some());

        // Check it does not allow duplicates
        let agent = Agent::builder()
            .tools([Stop::default(), Stop::default()])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 1);

        // It should include the default tool if a different tool is provided
        let agent = Agent::builder()
            .tools([MockTool::new("mock_tool")])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 2);
        assert!(agent.find_tool_by_name("mock_tool").is_some());
        assert!(agent.find_tool_by_name("stop").is_some());

        // A freshly built agent starts with an empty history
        assert!(agent.context().history().await.unwrap().is_empty());
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_tool_calling_loop() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_tool_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_tool_via_toolbox_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .add_toolbox(vec![mock_tool.boxed()])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_multiple_tool_calls() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool1");
        let mock_tool2 = MockTool::new("mock_tool2");

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool1", "mock_tool2"]
        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_tool2.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
            tool_output!("mock_tool1", "Great!"),
            tool_output!("mock_tool2", "Great!");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Ok!";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool, mock_tool2])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_state_machine() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = []
        };
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        // Agent has never run and is pending
        assert!(agent.state.is_pending());
        agent.query_once(prompt).await.unwrap();

        // Agent is stopped, there might be more messages
        assert!(agent.state.is_stopped());
    }

    #[test_log::test(tokio::test)]
    async fn test_summary() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = []
        };

        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        let mut agent = Agent::builder()
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary"))
            .await
            .unwrap();

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary"),
            user!("Write another poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        agent.query_once("Write another poem").await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary 2"))
            .await
            .unwrap();

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary 2"),
            user!("Write a third poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));

        agent.query_once("Write a third poem").await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_hooks() {
        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
        let mock_before_completion = MockHook::new("before_completion")
            .expect_calls(2)
            .to_owned();
        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
        let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();

        // Once for mock tool and once for stop
        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();

        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .before_all(mock_before_all.hook_fn())
            .on_start(mock_on_start_fn.on_start_fn())
            .before_completion(mock_before_completion.before_completion_fn())
            .before_tool(mock_before_tool.before_tool_fn())
            .after_completion(mock_after_completion.after_completion_fn())
            .after_tool(mock_after_tool.after_tool_fn())
            .after_each(mock_after_each.hook_fn())
            .on_new_message(mock_on_message.message_hook_fn())
            .on_stop(mock_on_stop.stop_hook_fn())
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_loop_limit() {
        let prompt = "Generate content"; // Example prompt
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mock_tool_response = chat_response! {
            "Some response";
            tool_calls = ["mock_tool"]
        };

        // Set expectations for the mock LLM responses
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));

        // Would terminate the loop, but the loop limit stops the agent before it is requested
        let stop_response = chat_response! {
            "Final response";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .limit(1) // Setting the loop limit to 1
            .build()
            .unwrap();

        // Run the agent
        agent.query(prompt).await.unwrap();

        // Assert that the remaining message is still in the queue
        let remaining = mock_llm.expectations.lock().unwrap().pop();
        assert!(remaining.is_some());

        // Assert that the agent is stopped after reaching the loop limit
        assert!(agent.is_stopped());
    }

    #[test_log::test(tokio::test)]
    async fn test_tool_retry_mechanism() {
        let prompt = "Execute my tool";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("retry_tool");

        // Configure mock tool to fail twice. First time is fed back to the LLM, second time is an
        // error
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        let retry_response = chat_response! {
            "First failing attempt";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));

        let chat_request = chat_request! {
            user!(prompt),
            assistant!("First failing attempt", ["retry_tool"]),
            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");

            tools = [mock_tool.clone()]
        };
        let will_fail_response = chat_response! {
            "Finished execution";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            // A single retry: the first failure is fed back to the LLM, the second is returned
            // as an error.
            .tool_retry_limit(1)
            .build()
            .unwrap();

        // Run the agent
        let result = agent.query(prompt).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing `query`"));
        assert!(agent.is_stopped());
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_streaming() {
        let prompt = "Generate content"; // Example prompt
        let mock_llm = MockChatCompletion::new();
        let on_stream_fn = MockHook::new("on_stream").expect_calls(3).to_owned();

        let chat_request = chat_request! {
            user!(prompt);

            tools = []
        };

        let response = chat_response! {
            "one two three";
            tool_calls = ["stop"]
        };

        // Set expectations for the mock LLM responses
        mock_llm.expect_complete(chat_request, Ok(response));

        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .on_stream(on_stream_fn.on_stream_fn())
            .no_system_prompt()
            .build()
            .unwrap();

        // Run the agent
        agent.query(prompt).await.unwrap();

        tracing::debug!("Agent finished running");

        // Assert that the agent stopped after the `stop` tool call
        assert!(agent.is_stopped());
    }

    #[test_log::test(tokio::test)]
    async fn test_recovering_agent_existing_history() {
        // First, let's run an agent
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool.clone()])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        // Let's retrieve the history of the agent
        let history = agent.history().await.unwrap();

        // Store it as a string somewhere
        let serialized = serde_json::to_string(&history).unwrap();

        // Retrieve it
        let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();

        // Build a context from the history
        let context = DefaultContext::default()
            .with_existing_messages(history)
            .await
            .unwrap()
            .to_owned();

        let stop_output = ToolOutput::stop();
        let expected_chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!"),
            assistant!("Roses are red", ["stop"]),
            tool_output!("stop", stop_output),
            user!("Try again!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Really stopping now";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(expected_chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .context(context)
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query_once("Try again!").await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_with_approval_required_tool() {
        use crate::tools::control::ApprovalRequired;

        // Step 1: Build a tool that needs approval.
        let mock_tool = MockTool::default();
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let approval_tool = ApprovalRequired(mock_tool.boxed());

        // Step 2: Set up the mock LLM.
        let mock_llm = MockChatCompletion::new();

        let chat_req1 = chat_request! {
            user!("Request with approval");
            tools = [approval_tool.clone()]
        };
        let chat_resp1 = chat_response! {
            "Completion message";
            tool_calls = ["mock_tool"]
        };
        mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));

        // Once approved, the tool is invoked and its output is part of the follow-up request
        let chat_req2 = chat_request! {
            user!("Request with approval"),
            assistant!("Completion message", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");
            tools = [approval_tool.clone()]
        };
        let chat_resp2 = chat_response! {
            "Post-feedback message";
            tool_calls = ["stop"]
        };
        mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));

        // Step 3: Wire up the agent.
        let mut agent = Agent::builder()
            .tools([approval_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        // Step 4: Run agent to trigger approval.
        agent.query_once("Request with approval").await.unwrap();

        assert!(matches!(
            agent.state,
            crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
        ));

        let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
        else {
            panic!("Expected feedback required");
        };

        // Step 5: Simulate feedback, run again and assert finish.
        agent
            .context
            .feedback_received(&tool_call, &ToolFeedback::approved())
            .await
            .unwrap();

        tracing::debug!("running after approval");
        agent.run_once().await.unwrap();
        assert!(agent.is_stopped());
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_with_approval_required_tool_denied() {
        use crate::tools::control::ApprovalRequired;

        // Step 1: Build a tool that needs approval. It must never be invoked.
        let mock_tool = MockTool::default();

        let approval_tool = ApprovalRequired(mock_tool.boxed());

        // Step 2: Set up the mock LLM.
        let mock_llm = MockChatCompletion::new();

        let chat_req1 = chat_request! {
            user!("Request with approval");
            tools = [approval_tool.clone()]
        };
        let chat_resp1 = chat_response! {
            "Completion message";
            tool_calls = ["mock_tool"]
        };
        mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));

        // Once refused, the tool is not invoked; the follow-up request carries the refusal
        // message as the tool output instead
        let chat_req2 = chat_request! {
            user!("Request with approval"),
            assistant!("Completion message", ["mock_tool"]),
            tool_output!("mock_tool", "This tool call was refused");
            tools = [approval_tool.clone()]
        };
        let chat_resp2 = chat_response! {
            "Post-feedback message";
            tool_calls = ["stop"]
        };
        mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));

        // Step 3: Wire up the agent.
        let mut agent = Agent::builder()
            .tools([approval_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        // Step 4: Run agent to trigger approval.
        agent.query_once("Request with approval").await.unwrap();

        assert!(matches!(
            agent.state,
            crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
        ));

        let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
        else {
            panic!("Expected feedback required");
        };

        // Step 5: Simulate refusal, run again and assert finish.
        agent
            .context
            .feedback_received(&tool_call, &ToolFeedback::refused())
            .await
            .unwrap();

        tracing::debug!("running after refusal");
        agent.run_once().await.unwrap();

        let history = agent.context().history().await.unwrap();
        history
            .iter()
            .rfind(|m| {
                let ChatMessage::ToolOutput(.., ToolOutput::Text(msg)) = m else {
                    return false;
                };
                msg.contains("refused")
            })
            .expect("Could not find refusal message");

        assert!(agent.is_stopped());
    }

    #[test_log::test(tokio::test)]
    async fn test_removing_default_stop_tool() {
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        // Build agent with without_default_stop_tool
        let agent = Agent::builder()
            .without_default_stop_tool()
            .tools([mock_tool.clone()])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        // Check that "stop" tool is NOT included
        assert!(agent.find_tool_by_name("stop").is_none());
        // Check that our provided tool is still present
        assert!(agent.find_tool_by_name("mock_tool").is_some());
    }
}