swiftide_agents/
agent.rs

1#![allow(dead_code)]
2use crate::{
3    default_context::DefaultContext,
4    errors::AgentError,
5    hooks::{
6        AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7        Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn,
8    },
9    invoke_hooks,
10    state::{self, StopReason},
11    system_prompt::SystemPrompt,
12    tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15    collections::{HashMap, HashSet},
16    hash::{DefaultHasher, Hash as _, Hasher as _},
17    sync::Arc,
18};
19
20use derive_builder::Builder;
21use futures_util::stream::StreamExt;
22use swiftide_core::{
23    AgentContext, ToolBox,
24    chat_completion::{
25        ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
26    },
27    prompt::Prompt,
28};
29use tracing::{Instrument, debug};
30
/// Agents are the main interface for building agentic systems.
///
/// Construct agents by calling the builder, setting an llm, and configuring hooks, tools and other
/// customizations.
///
/// # Important defaults
///
/// - The default context is the `DefaultContext`, executing tools locally with the `LocalExecutor`.
/// - A default `stop` tool is provided for agents to explicitly stop if needed
/// - The default `SystemPrompt` instructs the agent with chain of thought and some common
///   safeguards, but is otherwise quite bare. In a lot of cases this can be sufficient.
///
/// # Cloning
///
/// Agents are *not* cheap to clone. However, if an agent gets cloned, it will operate on the
/// same context.
#[derive(Builder)]
pub struct Agent {
    /// Hooks are functions that are called at specific points in the agent's lifecycle.
    #[builder(default, setter(into))]
    pub(crate) hooks: Vec<Hook>,
    /// The context in which the agent operates, by default this is the `DefaultContext`.
    #[builder(
        setter(custom),
        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
    )]
    pub(crate) context: Arc<dyn AgentContext>,
    /// Tools the agent can use
    #[builder(default = Agent::default_tools(), setter(custom))]
    pub(crate) tools: HashSet<Box<dyn Tool>>,

    /// Toolboxes are collections of tools that can be added to the agent.
    ///
    /// Toolboxes make their tools available to the agent at runtime.
    #[builder(default)]
    pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,

    /// The language model that the agent uses for completion.
    #[builder(setter(custom))]
    pub(crate) llm: Box<dyn ChatCompletion>,

    /// System prompt for the agent when it starts
    ///
    /// Some agents profit significantly from a tailored prompt. But it is not always needed.
    ///
    /// See [`SystemPrompt`] for an opinionated, customizable system prompt.
    ///
    /// Swiftide provides a default system prompt for all agents.
    ///
    /// Alternatively you can also provide a `Prompt` directly, or disable the system prompt.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use swiftide_agents::system_prompt::SystemPrompt;
    /// # use swiftide_agents::Agent;
    /// Agent::builder()
    ///     .system_prompt(
    ///         SystemPrompt::builder().role("You are an expert engineer")
    ///         .build().unwrap())
    ///     .build().unwrap();
    /// ```
    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default()))]
    pub(crate) system_prompt: Option<SystemPrompt>,

    /// Initial state of the agent
    #[builder(private, default = state::State::default())]
    pub(crate) state: state::State,

    /// Optional limit on the number of loops the agent can run.
    /// The counter is local to a single run, so it starts from zero again on every run.
    #[builder(default, setter(strip_option))]
    pub(crate) limit: Option<usize>,

    /// The maximum number of times the failed output of a tool will be sent
    /// to an LLM before the agent stops. Defaults to 3.
    ///
    /// LLMs sometimes send missing arguments, or a tool might actually fail, but retrying could be
    /// worthwhile. If the limit is not reached, the agent will send the formatted error back to
    /// the LLM.
    ///
    /// The limit is tracked per hash of the tool call (see `tool_retries_counter`), so the limit
    /// is per tool call.
    ///
    /// This limit is _not_ reset when the agent is stopped.
    #[builder(default = 3)]
    pub(crate) tool_retry_limit: usize,

    /// Enables streaming the chat completion responses for the agent.
    #[builder(default)]
    pub(crate) streaming: bool,

    /// When set to true, any tools in `Agent::default_tools` will be omitted. Only works if you
    /// add at least one tool of your own.
    #[builder(private, default)]
    pub(crate) clear_default_tools: bool,

    /// Internally tracks the number of times a tool has been retried. The key is a hash based on
    /// the name and args of the tool.
    #[builder(private, default)]
    pub(crate) tool_retries_counter: HashMap<u64, usize>,

    /// The name of the agent; optional, defaults to `unnamed_agent`
    #[builder(default = "unnamed_agent".into(), setter(into))]
    pub(crate) name: String,
}
135
136impl Clone for Agent {
137    fn clone(&self) -> Self {
138        Agent {
139            hooks: self.hooks.clone(),
140            context: Arc::new(self.context.clone()),
141            tools: self.tools.clone(),
142            toolboxes: self.toolboxes.clone(),
143            llm: self.llm.clone(),
144            system_prompt: self.system_prompt.clone(),
145            state: self.state.clone(),
146            limit: self.limit,
147            tool_retry_limit: self.tool_retry_limit,
148            tool_retries_counter: HashMap::new(),
149            streaming: self.streaming,
150            name: self.name.clone(),
151            clear_default_tools: self.clear_default_tools,
152        }
153    }
154}
155
156impl std::fmt::Debug for Agent {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        f.debug_struct("Agent")
159            .field("name", &self.name)
160            // display hooks as a list of type: number of hooks
161            .field(
162                "hooks",
163                &self
164                    .hooks
165                    .iter()
166                    .map(std::string::ToString::to_string)
167                    .collect::<Vec<_>>(),
168            )
169            .field(
170                "tools",
171                &self
172                    .tools
173                    .iter()
174                    .map(swiftide_core::Tool::name)
175                    .collect::<Vec<_>>(),
176            )
177            .field("llm", &"Box<dyn ChatCompletion>")
178            .field("state", &self.state)
179            .finish()
180    }
181}
182
183impl AgentBuilder {
184    /// The context in which the agent operates, by default this is the `DefaultContext`.
185    pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
186    where
187        Self: Clone,
188    {
189        self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
190        self
191    }
192
193    /// Returns a mutable reference to the system prompt, if it is set.
194    pub fn system_prompt_mut(&mut self) -> Option<&mut SystemPrompt> {
195        self.system_prompt.as_mut().and_then(Option::as_mut)
196    }
197
198    /// Disable the system prompt.
199    pub fn no_system_prompt(&mut self) -> &mut Self {
200        self.system_prompt = Some(None);
201
202        self
203    }
204
205    /// Add a hook to the agent.
206    pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
207        let hooks = self.hooks.get_or_insert_with(Vec::new);
208        hooks.push(hook);
209
210        self
211    }
212
213    /// Adds a tool to the agent
214    pub fn add_tool(&mut self, tool: impl Tool + 'static) -> &mut Self {
215        let tools = self.tools.get_or_insert_with(HashSet::new);
216        if let Some(tool) = tools.replace(tool.boxed()) {
217            tracing::debug!("Tool {} already exists, replacing", tool.name());
218        }
219
220        self
221    }
222
223    /// Add a hook that runs once, before all completions. Even if the agent is paused and resumed,
224    /// before all will not trigger again.
225    pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
226        self.add_hook(Hook::BeforeAll(Box::new(hook)))
227    }
228
229    /// Add a hook that runs once, when the agent starts. This hook also runs if the agent stopped
230    /// and then starts again. The hook runs after any `before_all` hooks and before the
231    /// `before_completion` hooks.
232    pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
233        self.add_hook(Hook::OnStart(Box::new(hook)))
234    }
235
236    /// Add a hook that runs when the agent receives a streaming response
237    ///
238    /// The response will always include both the current accumulated message and the delta
239    ///
240    /// This will set `self.streaming` to true, there is no need to set it manually for the default
241    /// behaviour.
242    pub fn on_stream(&mut self, hook: impl OnStreamFn + 'static) -> &mut Self {
243        self.streaming = Some(true);
244        self.add_hook(Hook::OnStream(Box::new(hook)))
245    }
246
247    /// Add a hook that runs before each completion.
248    pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
249        self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
250    }
251
252    /// Add a hook that runs after each tool. The `Result<ToolOutput, ToolError>` is provided
253    /// as mut, so the tool output can be fully modified.
254    ///
255    /// The `ToolOutput` also references the original `ToolCall`, allowing you to match at runtime
256    /// what tool to interact with.
257    pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
258        self.add_hook(Hook::AfterTool(Box::new(hook)))
259    }
260
261    /// Add a hook that runs before each tool. Yields an immutable reference to the `ToolCall`.
262    pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
263        self.add_hook(Hook::BeforeTool(Box::new(hook)))
264    }
265
266    /// Add a hook that runs after each completion, before tool invocation and/or new messages.
267    pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
268        self.add_hook(Hook::AfterCompletion(Box::new(hook)))
269    }
270
271    /// Add a hook that runs after each completion, after tool invocations, right before a new loop
272    /// might start
273    pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
274        self.add_hook(Hook::AfterEach(Box::new(hook)))
275    }
276
277    /// Add a hook that runs when a new message is added to the context. Note that each tool adds a
278    /// separate message.
279    pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
280        self.add_hook(Hook::OnNewMessage(Box::new(hook)))
281    }
282
283    pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
284        self.add_hook(Hook::OnStop(Box::new(hook)))
285    }
286
287    /// Set the LLM for the agent. An LLM must implement the `ChatCompletion` trait.
288    pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
289        let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
290
291        self.llm = Some(boxed);
292        self
293    }
294
295    /// Removes the default `stop` tool from the agent. This allows you to add your own or use
296    /// other methods to stop the agent.
297    ///
298    /// Note that you can also just override the tool if the name of the tool is `stop`.
299    pub fn without_default_stop_tool(&mut self) -> &mut Self {
300        self.clear_default_tools = Some(true);
301        self
302    }
303
304    fn builder_default_tools(&self) -> HashSet<Box<dyn Tool>> {
305        if self.clear_default_tools.is_some_and(|b| b) {
306            HashSet::new()
307        } else {
308            Agent::default_tools()
309        }
310    }
311
312    /// Define the available tools for the agent. Tools must implement the `Tool` trait.
313    ///
314    /// See the [tool attribute macro](`swiftide_macros::tool`) and the [tool derive
315    /// macro](`swiftide_macros::Tool`) for easy ways to create (many) tools.
316    pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
317    where
318        TOOL: Into<Box<dyn Tool>>,
319    {
320        self.tools = Some(
321            self.builder_default_tools()
322                .into_iter()
323                .chain(tools.into_iter().map(Into::into))
324                .collect(),
325        );
326        self
327    }
328
329    /// Add a toolbox to the agent. Toolboxes are collections of tools that can be added to the
330    /// to the agent. Available tools are evaluated at runtime, when the agent starts for the first
331    /// time.
332    ///
333    /// Agents can have many toolboxes.
334    pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
335        let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
336        toolboxes.push(Box::new(toolbox));
337
338        self
339    }
340}
341
342impl Agent {
343    /// Build a new agent
344    pub fn builder() -> AgentBuilder {
345        AgentBuilder::default()
346            .tools(Agent::default_tools())
347            .to_owned()
348    }
349
350    /// The name of the agent
351    pub fn name(&self) -> &str {
352        &self.name
353    }
354
355    /// Default tools for the agent that it always includes
356    /// Right now this is the `stop` tool, which allows the agent to stop itself.
357    pub fn default_tools() -> HashSet<Box<dyn Tool>> {
358        HashSet::from([Stop::default().boxed()])
359    }
360
361    /// Run the agent with a user message. The agent will loop completions, make tool calls, until
362    /// no new messages are available.
363    ///
364    /// # Errors
365    ///
366    /// Errors if anything goes wrong, see `AgentError` for more details.
367    #[tracing::instrument(skip_all, name = "agent.query", err)]
368    pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
369        let query = query
370            .into()
371            .render()
372            .map_err(AgentError::FailedToRenderPrompt)?;
373        self.run_agent(Some(query), false).await
374    }
375
376    /// Adds a tool to an agent at run time
377    pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
378        if let Some(tool) = self.tools.replace(tool) {
379            tracing::debug!("Tool {} already exists, replacing", tool.name());
380        }
381    }
382
    /// Modify the tools of the agent at runtime
    ///
    /// Note that tools provided by toolboxes are only loaded into this set when the agent runs
    /// for the first time (while it is still pending), so they are not visible here before then.
    pub fn tools_mut(&mut self) -> &mut HashSet<Box<dyn Tool>> {
        &mut self.tools
    }
390
    /// Run the agent with a user message once.
    ///
    /// Unlike `query`, the run ends after a single completion step instead of looping until no
    /// new messages are available.
    ///
    /// # Errors
    ///
    /// Errors if anything goes wrong, see `AgentError` for more details.
    #[tracing::instrument(skip_all, name = "agent.query_once", err)]
    pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
        self.run_agent(Some(query), true).await
    }
400
    /// Run the agent without a user message. The agent will loop completions and make tool
    /// calls until no new messages are available.
    ///
    /// # Errors
    ///
    /// Errors if anything goes wrong, see `AgentError` for more details.
    #[tracing::instrument(skip_all, name = "agent.run", err)]
    pub async fn run(&mut self) -> Result<(), AgentError> {
        self.run_agent(None::<Prompt>, false).await
    }
411
    /// Run the agent without a user message, once.
    ///
    /// Unlike `run`, the run ends after a single completion step instead of looping until no
    /// new messages are available.
    ///
    /// # Errors
    ///
    /// Errors if anything goes wrong, see `AgentError` for more details.
    #[tracing::instrument(skip_all, name = "agent.run_once", err)]
    pub async fn run_once(&mut self) -> Result<(), AgentError> {
        self.run_agent(None::<Prompt>, true).await
    }
422
423    /// Retrieve the message history of the agent
424    ///
425    /// # Errors
426    ///
427    /// Error if the message history cannot be retrieved, e.g. if the context is not set up or a
428    /// connection fails
429    pub async fn history(&self) -> Result<Vec<ChatMessage>, AgentError> {
430        self.context
431            .history()
432            .await
433            .map_err(AgentError::MessageHistoryError)
434    }
435
436    pub(crate) async fn run_agent(
437        &mut self,
438        maybe_query: Option<impl Into<Prompt>>,
439        just_once: bool,
440    ) -> Result<(), AgentError> {
441        let maybe_query = maybe_query
442            .map(|q| q.into().render())
443            .transpose()
444            .map_err(AgentError::FailedToRenderPrompt)?;
445        if self.state.is_running() {
446            return Err(AgentError::AlreadyRunning);
447        }
448
449        if self.state.is_pending() {
450            if let Some(system_prompt) = &self.system_prompt {
451                self.context
452                    .add_messages(vec![ChatMessage::System(
453                        system_prompt
454                            .to_prompt()
455                            .render()
456                            .map_err(AgentError::FailedToRenderSystemPrompt)?,
457                    )])
458                    .await
459                    .map_err(AgentError::MessageHistoryError)?;
460            }
461
462            invoke_hooks!(BeforeAll, self);
463
464            self.load_toolboxes().await?;
465        }
466
467        if let Some(query) = maybe_query {
468            if cfg!(feature = "langfuse") {
469                debug!(langfuse.input = query);
470            }
471            self.context
472                .add_message(ChatMessage::User(query))
473                .await
474                .map_err(AgentError::MessageHistoryError)?;
475        }
476
477        invoke_hooks!(OnStart, self);
478
479        self.state = state::State::Running;
480
481        let mut loop_counter = 0;
482
483        while let Some(messages) = self
484            .context
485            .next_completion()
486            .await
487            .map_err(AgentError::MessageHistoryError)?
488        {
489            if let Some(limit) = self.limit
490                && loop_counter >= limit
491            {
492                tracing::warn!("Agent loop limit reached");
493                break;
494            }
495
496            // If the last message contains tool calls that have not been completed,
497            // run the tools first
498            if let Some(&ChatMessage::Assistant(.., Some(ref tool_calls))) =
499                maybe_tool_call_without_output(&messages)
500            {
501                tracing::debug!("Uncompleted tool calls found; invoking tools");
502                self.invoke_tools(tool_calls).await?;
503                // Move on to the next tick, so that the
504                continue;
505            }
506
507            let result = self.step(&messages, loop_counter).await;
508
509            if let Err(err) = result {
510                self.stop_with_error(&err).await;
511                tracing::error!(error = ?err, "Agent stopped with error {err}");
512                return Err(err);
513            }
514
515            if just_once || self.state.is_stopped() {
516                break;
517            }
518            loop_counter += 1;
519        }
520
521        // If there are no new messages, ensure we update our state
522        self.stop(StopReason::NoNewMessages).await;
523
524        Ok(())
525    }
526
527    #[tracing::instrument(skip(self, messages), err, fields(otel.name))]
528    async fn step(
529        &mut self,
530        messages: &[ChatMessage],
531        step_count: usize,
532    ) -> Result<(), AgentError> {
533        tracing::Span::current().record("otel.name", format!("step-{step_count}"));
534
535        debug!(
536            tools = ?self
537                .tools
538                .iter()
539                .map(|t| t.name())
540                .collect::<Vec<_>>()
541                ,
542            "Running completion for agent with {} new messages",
543            messages.len()
544        );
545
546        let mut chat_completion_request = ChatCompletionRequest::builder()
547            .messages(messages.to_vec())
548            .tool_specs(self.tools.iter().map(swiftide_core::Tool::tool_spec))
549            .build()
550            .map_err(AgentError::FailedToBuildRequest)?;
551
552        invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
553
554        debug!(
555            "Calling LLM with the following new messages:\n {}",
556            self.context
557                .current_new_messages()
558                .await
559                .map_err(AgentError::MessageHistoryError)?
560                .iter()
561                .map(ToString::to_string)
562                .collect::<Vec<_>>()
563                .join(",\n")
564        );
565
566        let mut response = if self.streaming {
567            let mut last_response = None;
568            let mut stream = self.llm.complete_stream(&chat_completion_request).await;
569
570            while let Some(response) = stream.next().await {
571                let response = response.map_err(AgentError::CompletionsFailed)?;
572                invoke_hooks!(OnStream, self, &response);
573                last_response = Some(response);
574            }
575            tracing::trace!(?last_response, "Streaming completed");
576            last_response.ok_or(AgentError::EmptyStream)
577        } else {
578            self.llm
579                .complete(&chat_completion_request)
580                .await
581                .map_err(AgentError::CompletionsFailed)
582        }?;
583
584        // The arg preprocessor helps avoid common llm errors.
585        // This must happen as early as possible
586        response
587            .tool_calls
588            .as_deref_mut()
589            .map(ArgPreprocessor::preprocess_tool_calls);
590
591        invoke_hooks!(AfterCompletion, self, &mut response);
592
593        self.add_message(ChatMessage::Assistant(
594            response.message,
595            response.tool_calls.clone(),
596        ))
597        .await?;
598
599        if let Some(tool_calls) = response.tool_calls {
600            self.invoke_tools(&tool_calls).await?;
601        }
602
603        invoke_hooks!(AfterEach, self);
604
605        Ok(())
606    }
607
608    async fn invoke_tools(&mut self, tool_calls: &[ToolCall]) -> Result<(), AgentError> {
609        tracing::debug!("LLM returned tool calls: {:?}", tool_calls);
610
611        let mut handles = vec![];
612        for tool_call in tool_calls {
613            let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
614                tracing::warn!("Tool {} not found", tool_call.name());
615                continue;
616            };
617            tracing::info!("Calling tool `{}`", tool_call.name());
618
619            // let tool_args = tool_call.args().map(String::from);
620            let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
621
622            invoke_hooks!(BeforeTool, self, &tool_call);
623
624            let tool_span = tracing::info_span!(
625                "tool",
626                "otel.name" = format!("tool.{}", tool.name().as_ref()),
627            );
628
629            let handle_tool_call = tool_call.clone();
630            let handle = tokio::spawn(async move {
631                    let handle_tool_call = handle_tool_call;
632                    let output = tool.invoke(&*context, &handle_tool_call)
633                        .await?;
634
635                if cfg!(feature = "langfuse") {
636                    tracing::debug!(
637                        langfuse.output = %output,
638                        langfuse.input = handle_tool_call.args(),
639                        tool_name = tool.name().as_ref(),
640                    );
641                } else {
642                    tracing::debug!(output = output.to_string(), args = ?handle_tool_call.args(), tool_name = tool.name().as_ref(), "Completed tool call");
643                }
644
645                    Ok(output)
646                }.instrument(tool_span.or_current()));
647
648            handles.push((handle, tool_call));
649        }
650
651        for (handle, tool_call) in handles {
652            let mut output = handle
653                .await
654                .map_err(|err| AgentError::ToolFailedToJoin(tool_call.name().to_string(), err))?;
655
656            invoke_hooks!(AfterTool, self, &tool_call, &mut output);
657
658            if let Err(error) = output {
659                let stop = self.tool_calls_over_limit(tool_call);
660                if stop {
661                    tracing::error!(
662                        ?error,
663                        "Tool call failed, retry limit reached, stopping agent: {error}",
664                    );
665                } else {
666                    tracing::warn!(
667                        ?error,
668                        tool_call = ?tool_call,
669                        "Tool call failed, retrying",
670                    );
671                }
672                self.add_message(ChatMessage::ToolOutput(
673                    tool_call.clone(),
674                    ToolOutput::fail(error.to_string()),
675                ))
676                .await?;
677                if stop {
678                    self.stop(StopReason::ToolCallsOverLimit(tool_call.to_owned()))
679                        .await;
680                    return Err(error.into());
681                }
682                continue;
683            }
684
685            let output = output?;
686            self.handle_control_tools(tool_call, &output).await;
687
688            // Feedback required leaves the tool call open
689            //
690            // It assumes a follow up invocation of the agent will have the feedback approved
691            if !output.is_feedback_required() {
692                self.add_message(ChatMessage::ToolOutput(tool_call.to_owned(), output))
693                    .await?;
694            }
695        }
696
697        Ok(())
698    }
699
700    fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
701        self.hooks
702            .iter()
703            .filter(|h| hook_type == (*h).into())
704            .collect()
705    }
706
707    fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
708        self.tools
709            .iter()
710            .find(|tool| tool.name() == tool_name)
711            .cloned()
712    }
713
714    // Handle any tool specific output (e.g. stop)
715    async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
716        match output {
717            ToolOutput::Stop(maybe_message) => {
718                tracing::warn!("Stop tool called, stopping agent");
719                self.stop(StopReason::RequestedByTool(
720                    tool_call.clone(),
721                    maybe_message.clone(),
722                ))
723                .await;
724            }
725            ToolOutput::FeedbackRequired(maybe_payload) => {
726                tracing::warn!("Feedback required, stopping agent");
727                self.stop(StopReason::FeedbackRequired {
728                    tool_call: tool_call.clone(),
729                    payload: maybe_payload.clone(),
730                })
731                .await;
732            }
733            ToolOutput::AgentFailed(output) => {
734                tracing::warn!("Agent failed, stopping agent");
735                self.stop(StopReason::AgentFailed(output.clone())).await;
736            }
737            _ => (),
738        }
739    }
740
    /// Retrieve the system prompt, if it is set. Returns `None` when the agent has no system
    /// prompt.
    pub fn system_prompt(&self) -> Option<&SystemPrompt> {
        self.system_prompt.as_ref()
    }
745
    /// Retrieve a mutable reference to the system prompt, if it is set.
    ///
    /// Note that the system prompt is rendered and added to the context only once, when the
    /// agent starts for the first time; changes made afterwards are not picked up by later runs.
    pub fn system_prompt_mut(&mut self) -> Option<&mut SystemPrompt> {
        self.system_prompt.as_mut()
    }
752
753    fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
754        let mut s = DefaultHasher::new();
755        tool_call.hash(&mut s);
756        let hash = s.finish();
757
758        if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
759            let val = *retries >= self.tool_retry_limit;
760            *retries += 1;
761            val
762        } else {
763            self.tool_retries_counter.insert(hash, 1);
764            false
765        }
766    }
767
768    /// Add a message to the agent's context
769    ///
770    /// This will trigger a `OnNewMessage` hook if its present.
771    ///
772    /// If you want to add a message without triggering the hook, use the context directly.
773    ///
774    /// # Errors
775    ///
776    /// Errors if the message cannot be added to the context. With the default in memory context
777    /// that is not supposed to happen.
778    #[tracing::instrument(skip_all, fields(message = message.to_string()))]
779    pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
780        invoke_hooks!(OnNewMessage, self, &mut message);
781
782        self.context
783            .add_message(message)
784            .await
785            .map_err(AgentError::MessageHistoryError)?;
786        Ok(())
787    }
788
789    /// Tell the agent to stop. It will finish it's current loop and then stop.
790    pub async fn stop(&mut self, reason: impl Into<StopReason>) {
791        if self.state.is_stopped() {
792            return;
793        }
794
795        let reason = reason.into();
796        invoke_hooks!(OnStop, self, reason.clone(), None);
797
798        if cfg!(feature = "langfuse") {
799            debug!(langfuse.output = serde_json::to_string_pretty(&reason).ok());
800        }
801
802        self.state = state::State::Stopped(reason);
803    }
804
805    pub async fn stop_with_error(&mut self, error: &AgentError) {
806        if self.state.is_stopped() {
807            return;
808        }
809        invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
810
811        self.state = state::State::Stopped(StopReason::Error);
812    }
813
    /// Access the agent's context, e.g. to read or add messages without going through the
    /// agent's hooks.
    pub fn context(&self) -> &dyn AgentContext {
        &self.context
    }
818
    /// Returns `true` while the agent is running, i.e. while a run is in progress.
    pub fn is_running(&self) -> bool {
        self.state.is_running()
    }
823
    /// Returns `true` if the agent has stopped; see `stop_reason` for why.
    pub fn is_stopped(&self) -> bool {
        self.state.is_stopped()
    }
828
    /// Returns `true` if the agent has not (ever) started.
    pub fn is_pending(&self) -> bool {
        self.state.is_pending()
    }
833
    /// Get the set of tools available to the agent.
    ///
    /// Tools provided by toolboxes only show up here after the agent ran for the first time.
    pub fn tools(&self) -> &HashSet<Box<dyn Tool>> {
        &self.tools
    }
838
    /// The current state of the agent.
    pub fn state(&self) -> &state::State {
        &self.state
    }
842
    /// The stop reason reported by the agent's current state, if there is one
    ///
    /// Delegates to `State::stop_reason`; presumably this is `None` unless the agent is stopped
    /// (confirm against `State`).
    pub fn stop_reason(&self) -> Option<&StopReason> {
        self.state.stop_reason()
    }
846
847    async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
848        for toolbox in &self.toolboxes {
849            let tools = toolbox
850                .available_tools()
851                .await
852                .map_err(AgentError::ToolBoxFailedToLoad)?;
853            self.tools.extend(tools);
854        }
855
856        Ok(())
857    }
858}
859
860/// Reverse searches through messages, if it encounters a tool call before encountering an output,
861/// it will return the chat message with the tool calls, otherwise it returns None
862fn maybe_tool_call_without_output(messages: &[ChatMessage]) -> Option<&ChatMessage> {
863    for message in messages.iter().rev() {
864        if let ChatMessage::ToolOutput(..) = message {
865            return None;
866        }
867
868        if let ChatMessage::Assistant(.., Some(tool_calls)) = message
869            && !tool_calls.is_empty()
870        {
871            return Some(message);
872        }
873    }
874
875    None
876}
877
878#[cfg(test)]
879mod tests {
880
881    use serde::ser::Error;
882    use swiftide_core::ToolFeedback;
883    use swiftide_core::chat_completion::errors::ToolError;
884    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
885    use swiftide_core::test_utils::MockChatCompletion;
886
887    use super::*;
888    use crate::{
889        State, assistant, chat_request, chat_response, summary, system, tool_failed, tool_output,
890        user,
891    };
892
893    use crate::test_utils::{MockHook, MockTool};
894
895    #[test_log::test(tokio::test)]
896    async fn test_agent_builder_defaults() {
897        // Create a prompt
898        let mock_llm = MockChatCompletion::new();
899
900        // Build the agent
901        let agent = Agent::builder().llm(&mock_llm).build().unwrap();
902
903        // Check that the context is the default context
904
905        // Check that the default tools are added
906        assert!(agent.find_tool_by_name("stop").is_some());
907
908        // Check it does not allow duplicates
909        let agent = Agent::builder()
910            .tools([Stop::default(), Stop::default()])
911            .llm(&mock_llm)
912            .build()
913            .unwrap();
914
915        assert_eq!(agent.tools.len(), 1);
916
917        // It should include the default tool if a different tool is provided
918        let agent = Agent::builder()
919            .tools([MockTool::new("mock_tool")])
920            .llm(&mock_llm)
921            .build()
922            .unwrap();
923
924        assert_eq!(agent.tools.len(), 2);
925        assert!(agent.find_tool_by_name("mock_tool").is_some());
926        assert!(agent.find_tool_by_name("stop").is_some());
927
928        assert!(agent.context().history().await.unwrap().is_empty());
929    }
930
931    #[test_log::test(tokio::test)]
932    async fn test_agent_tool_calling_loop() {
933        let prompt = "Write a poem";
934        let mock_llm = MockChatCompletion::new();
935        let mock_tool = MockTool::new("mock_tool");
936
937        let chat_request = chat_request! {
938            user!("Write a poem");
939
940            tools = [mock_tool.clone()]
941        };
942
943        let mock_tool_response = chat_response! {
944            "Roses are red";
945            tool_calls = ["mock_tool"]
946
947        };
948
949        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
950
951        let chat_request = chat_request! {
952            user!("Write a poem"),
953            assistant!("Roses are red", ["mock_tool"]),
954            tool_output!("mock_tool", "Great!");
955
956            tools = [mock_tool.clone()]
957        };
958
959        let stop_response = chat_response! {
960            "Roses are red";
961            tool_calls = ["stop"]
962        };
963
964        mock_llm.expect_complete(chat_request, Ok(stop_response));
965        mock_tool.expect_invoke_ok("Great!".into(), None);
966
967        let mut agent = Agent::builder()
968            .tools([mock_tool])
969            .llm(&mock_llm)
970            .no_system_prompt()
971            .build()
972            .unwrap();
973
974        agent.query(prompt).await.unwrap();
975    }
976
977    #[test_log::test(tokio::test)]
978    async fn test_agent_tool_run_once() {
979        let prompt = "Write a poem";
980        let mock_llm = MockChatCompletion::new();
981        let mock_tool = MockTool::default();
982
983        let chat_request = chat_request! {
984            system!("My system prompt"),
985            user!("Write a poem");
986
987            tools = [mock_tool.clone()]
988        };
989
990        let mock_tool_response = chat_response! {
991            "Roses are red";
992            tool_calls = ["mock_tool"]
993
994        };
995
996        mock_tool.expect_invoke_ok("Great!".into(), None);
997        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
998
999        let mut agent = Agent::builder()
1000            .tools([mock_tool])
1001            .system_prompt("My system prompt")
1002            .llm(&mock_llm)
1003            .build()
1004            .unwrap();
1005
1006        agent.query_once(prompt).await.unwrap();
1007    }
1008
1009    #[test_log::test(tokio::test)]
1010    async fn test_agent_tool_via_toolbox_run_once() {
1011        let prompt = "Write a poem";
1012        let mock_llm = MockChatCompletion::new();
1013        let mock_tool = MockTool::default();
1014
1015        let chat_request = chat_request! {
1016            system!("My system prompt"),
1017            user!("Write a poem");
1018
1019            tools = [mock_tool.clone()]
1020        };
1021
1022        let mock_tool_response = chat_response! {
1023            "Roses are red";
1024            tool_calls = ["mock_tool"]
1025
1026        };
1027
1028        mock_tool.expect_invoke_ok("Great!".into(), None);
1029        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1030
1031        let mut agent = Agent::builder()
1032            .add_toolbox(vec![mock_tool.boxed()])
1033            .system_prompt("My system prompt")
1034            .llm(&mock_llm)
1035            .build()
1036            .unwrap();
1037
1038        agent.query_once(prompt).await.unwrap();
1039    }
1040
1041    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1042    async fn test_multiple_tool_calls() {
1043        let prompt = "Write a poem";
1044        let mock_llm = MockChatCompletion::new();
1045        let mock_tool = MockTool::new("mock_tool1");
1046        let mock_tool2 = MockTool::new("mock_tool2");
1047
1048        let chat_request = chat_request! {
1049            system!("My system prompt"),
1050            user!("Write a poem");
1051
1052
1053
1054            tools = [mock_tool.clone(), mock_tool2.clone()]
1055        };
1056
1057        let mock_tool_response = chat_response! {
1058            "Roses are red";
1059
1060            tool_calls = ["mock_tool1", "mock_tool2"]
1061
1062        };
1063
1064        dbg!(&chat_request);
1065        mock_tool.expect_invoke_ok("Great!".into(), None);
1066        mock_tool2.expect_invoke_ok("Great!".into(), None);
1067        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1068
1069        let chat_request = chat_request! {
1070            system!("My system prompt"),
1071            user!("Write a poem"),
1072            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
1073            tool_output!("mock_tool1", "Great!"),
1074            tool_output!("mock_tool2", "Great!");
1075
1076            tools = [mock_tool.clone(), mock_tool2.clone()]
1077        };
1078
1079        let mock_tool_response = chat_response! {
1080            "Ok!";
1081
1082            tool_calls = ["stop"]
1083
1084        };
1085
1086        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
1087
1088        let mut agent = Agent::builder()
1089            .tools([mock_tool, mock_tool2])
1090            .system_prompt("My system prompt")
1091            .llm(&mock_llm)
1092            .build()
1093            .unwrap();
1094
1095        agent.query(prompt).await.unwrap();
1096    }
1097
1098    #[test_log::test(tokio::test)]
1099    async fn test_agent_state_machine() {
1100        let prompt = "Write a poem";
1101        let mock_llm = MockChatCompletion::new();
1102
1103        let chat_request = chat_request! {
1104            user!("Write a poem");
1105            tools = []
1106        };
1107        let mock_tool_response = chat_response! {
1108            "Roses are red";
1109            tool_calls = []
1110        };
1111
1112        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1113        let mut agent = Agent::builder()
1114            .llm(&mock_llm)
1115            .no_system_prompt()
1116            .build()
1117            .unwrap();
1118
1119        // Agent has never run and is pending
1120        assert!(agent.state.is_pending());
1121        agent.query_once(prompt).await.unwrap();
1122
1123        // Agent is stopped, there might be more messages
1124        assert!(agent.state.is_stopped());
1125    }
1126
1127    #[test_log::test(tokio::test)]
1128    async fn test_summary() {
1129        let prompt = "Write a poem";
1130        let mock_llm = MockChatCompletion::new();
1131
1132        let mock_tool_response = chat_response! {
1133            "Roses are red";
1134            tool_calls = []
1135
1136        };
1137
1138        let expected_chat_request = chat_request! {
1139            system!("My system prompt"),
1140            user!("Write a poem");
1141
1142            tools = []
1143        };
1144
1145        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1146
1147        let mut agent = Agent::builder()
1148            .system_prompt("My system prompt")
1149            .llm(&mock_llm)
1150            .build()
1151            .unwrap();
1152
1153        agent.query_once(prompt).await.unwrap();
1154
1155        agent
1156            .context
1157            .add_message(ChatMessage::new_summary("Summary"))
1158            .await
1159            .unwrap();
1160
1161        let expected_chat_request = chat_request! {
1162            system!("My system prompt"),
1163            summary!("Summary"),
1164            user!("Write another poem");
1165            tools = []
1166        };
1167        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1168
1169        agent.query_once("Write another poem").await.unwrap();
1170
1171        agent
1172            .context
1173            .add_message(ChatMessage::new_summary("Summary 2"))
1174            .await
1175            .unwrap();
1176
1177        let expected_chat_request = chat_request! {
1178            system!("My system prompt"),
1179            summary!("Summary 2"),
1180            user!("Write a third poem");
1181            tools = []
1182        };
1183        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));
1184
1185        agent.query_once("Write a third poem").await.unwrap();
1186    }
1187
    // Runs a two-completion loop (tool call, then stop) with a mock hook registered for every
    // hook type, each configured with the number of calls it expects.
    #[test_log::test(tokio::test)]
    async fn test_agent_hooks() {
        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
        // Two completions are scripted below, hence two calls for the per-completion hooks.
        let mock_before_completion = MockHook::new("before_completion")
            .expect_calls(2)
            .to_owned();
        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
        let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();

        // Once for mock tool and once for stop
        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();

        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        // First completion: the model requests the mock tool.
        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        // Second completion: the history contains the tool output and the model stops.
        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        // Register one mock hook per hook type on the agent.
        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .before_all(mock_before_all.hook_fn())
            .on_start(mock_on_start_fn.on_start_fn())
            .before_completion(mock_before_completion.before_completion_fn())
            .before_tool(mock_before_tool.before_tool_fn())
            .after_completion(mock_after_completion.after_completion_fn())
            .after_tool(mock_after_tool.after_tool_fn())
            .after_each(mock_after_each.hook_fn())
            .on_new_message(mock_on_message.message_hook_fn())
            .on_stop(mock_on_stop.stop_hook_fn())
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }
1256
1257    #[test_log::test(tokio::test)]
1258    async fn test_agent_loop_limit() {
1259        let prompt = "Generate content"; // Example prompt
1260        let mock_llm = MockChatCompletion::new();
1261        let mock_tool = MockTool::new("mock_tool");
1262
1263        let chat_request = chat_request! {
1264            user!(prompt);
1265            tools = [mock_tool.clone()]
1266        };
1267        mock_tool.expect_invoke_ok("Great!".into(), None);
1268
1269        let mock_tool_response = chat_response! {
1270            "Some response";
1271            tool_calls = ["mock_tool"]
1272        };
1273
1274        // Set expectations for the mock LLM responses
1275        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));
1276
1277        // // Response for terminating the loop
1278        let stop_response = chat_response! {
1279            "Final response";
1280            tool_calls = ["stop"]
1281        };
1282
1283        mock_llm.expect_complete(chat_request, Ok(stop_response));
1284
1285        let mut agent = Agent::builder()
1286            .tools([mock_tool])
1287            .llm(&mock_llm)
1288            .no_system_prompt()
1289            .limit(1) // Setting the loop limit to 1
1290            .build()
1291            .unwrap();
1292
1293        // Run the agent
1294        agent.query(prompt).await.unwrap();
1295
1296        // Assert that the remaining message is still in the queue
1297        let remaining = mock_llm.expectations.lock().unwrap().pop();
1298        assert!(remaining.is_some());
1299
1300        // Assert that the agent is stopped after reaching the loop limit
1301        assert!(agent.is_stopped());
1302    }
1303
1304    #[test_log::test(tokio::test)]
1305    async fn test_tool_retry_mechanism() {
1306        let prompt = "Execute my tool";
1307        let mock_llm = MockChatCompletion::new();
1308        let mock_tool = MockTool::new("retry_tool");
1309
1310        // Configure mock tool to fail twice. First time is fed back to the LLM, second time is an
1311        // error
1312        mock_tool.expect_invoke(
1313            Err(ToolError::WrongArguments(serde_json::Error::custom(
1314                "missing `query`",
1315            ))),
1316            None,
1317        );
1318        mock_tool.expect_invoke(
1319            Err(ToolError::WrongArguments(serde_json::Error::custom(
1320                "missing `query`",
1321            ))),
1322            None,
1323        );
1324
1325        let chat_request = chat_request! {
1326            user!(prompt);
1327            tools = [mock_tool.clone()]
1328        };
1329        let retry_response = chat_response! {
1330            "First failing attempt";
1331            tool_calls = ["retry_tool"]
1332        };
1333        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
1334
1335        let chat_request = chat_request! {
1336            user!(prompt),
1337            assistant!("First failing attempt", ["retry_tool"]),
1338            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
1339
1340            tools = [mock_tool.clone()]
1341        };
1342        let will_fail_response = chat_response! {
1343            "Finished execution";
1344            tool_calls = ["retry_tool"]
1345        };
1346        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
1347
1348        let mut agent = Agent::builder()
1349            .tools([mock_tool])
1350            .llm(&mock_llm)
1351            .no_system_prompt()
1352            .tool_retry_limit(1) // The test relies on a limit of 2 retries.
1353            .build()
1354            .unwrap();
1355
1356        // Run the agent
1357        let result = agent.query(prompt).await;
1358
1359        assert!(result.is_err());
1360        assert!(result.unwrap_err().to_string().contains("missing `query`"));
1361        assert!(agent.is_stopped());
1362    }
1363
1364    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1365    async fn test_streaming() {
1366        let prompt = "Generate content"; // Example prompt
1367        let mock_llm = MockChatCompletion::new();
1368        let on_stream_fn = MockHook::new("on_stream").expect_calls(3).to_owned();
1369
1370        let chat_request = chat_request! {
1371            user!(prompt);
1372
1373            tools = []
1374        };
1375
1376        let response = chat_response! {
1377            "one two three";
1378            tool_calls = ["stop"]
1379        };
1380
1381        // Set expectations for the mock LLM responses
1382        mock_llm.expect_complete(chat_request, Ok(response));
1383
1384        let mut agent = Agent::builder()
1385            .llm(&mock_llm)
1386            .on_stream(on_stream_fn.on_stream_fn())
1387            .no_system_prompt()
1388            .build()
1389            .unwrap();
1390
1391        // Run the agent
1392        agent.query(prompt).await.unwrap();
1393
1394        tracing::debug!("Agent finished running");
1395
1396        // Assert that the agent is stopped after reaching the loop limit
1397        assert!(agent.is_stopped());
1398    }
1399
    // Runs an agent, round-trips its history through JSON, and checks that a second agent built
    // on a context restored from that history sends the full earlier conversation along with the
    // new prompt.
    #[test_log::test(tokio::test)]
    async fn test_recovering_agent_existing_history() {
        // First, let's run an agent
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool.clone()])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        // Let's retrieve the history of the agent
        let history = agent.history().await.unwrap();

        // Store it as a string somewhere
        let serialized = serde_json::to_string(&history).unwrap();

        // Retrieve it
        let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();

        // Build a context from the history
        let context = DefaultContext::default()
            .with_existing_messages(history)
            .await
            .unwrap()
            .to_owned();

        // The restored history is expected to include the earlier `stop` call and its output,
        // followed by the new user message.
        let stop_output = ToolOutput::stop();
        let expected_chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!"),
            assistant!("Roses are red", ["stop"]),
            tool_output!("stop", stop_output),
            user!("Try again!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Really stopping now";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(expected_chat_request, Ok(stop_response));

        // A fresh agent picks up the conversation from the restored context.
        let mut agent = Agent::builder()
            .context(context)
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query_once("Try again!").await.unwrap();
    }
1491
    // A tool wrapped in `ApprovalRequired` first stops the agent with `FeedbackRequired`; once
    // the call is approved, the next run invokes the tool and continues to the `stop` call.
    #[test_log::test(tokio::test)]
    async fn test_agent_with_approval_required_tool() {
        use super::*;
        use crate::tools::control::ApprovalRequired;
        use crate::{assistant, chat_request, chat_response, user};
        use swiftide_core::chat_completion::ToolCall;

        // Step 1: Build a tool that needs approval.
        let mock_tool = MockTool::default();
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let approval_tool = ApprovalRequired(mock_tool.boxed());

        // Step 2: Set up the mock LLM.
        let mock_llm = MockChatCompletion::new();

        let chat_req1 = chat_request! {
            user!("Request with approval");
            tools = [approval_tool.clone()]
        };
        let chat_resp1 = chat_response! {
            "Completion message";
            tool_calls = ["mock_tool"]
        };
        mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));

        // The follow-up request contains the earlier exchange plus the tool output that is
        // produced once the call has been approved.
        let chat_req2 = chat_request! {
            user!("Request with approval"),
            assistant!("Completion message", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");
            // Output of the approved tool call
            tools = [approval_tool.clone()]
        };
        let chat_resp2 = chat_response! {
            "Post-feedback message";
            tool_calls = ["stop"]
        };
        mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));

        // Step 3: Wire up the agent.
        let mut agent = Agent::builder()
            .tools([approval_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        // Step 4: Run agent to trigger approval.
        agent.query_once("Request with approval").await.unwrap();

        assert!(matches!(
            agent.state,
            crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
        ));

        // Extract the pending tool call so feedback can be given for it.
        let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
        else {
            panic!("Expected feedback required");
        };

        // Step 5: Simulate feedback, run again and assert finish.
        agent
            .context
            .feedback_received(&tool_call, &ToolFeedback::approved())
            .await
            .unwrap();

        tracing::debug!("running after approval");
        agent.run_once().await.unwrap();
        assert!(agent.is_stopped());
    }
1565
    // A tool wrapped in `ApprovalRequired` first stops the agent with `FeedbackRequired`; when
    // the call is refused, the next run records a refusal message as the tool output.
    #[test_log::test(tokio::test)]
    async fn test_agent_with_approval_required_tool_denied() {
        use super::*;
        use crate::tools::control::ApprovalRequired;
        use crate::{assistant, chat_request, chat_response, user};
        use swiftide_core::chat_completion::ToolCall;

        // Step 1: Build a tool that needs approval.
        // No invocation expectation is registered: the call is going to be refused.
        let mock_tool = MockTool::default();

        let approval_tool = ApprovalRequired(mock_tool.boxed());

        // Step 2: Set up the mock LLM.
        let mock_llm = MockChatCompletion::new();

        let chat_req1 = chat_request! {
            user!("Request with approval");
            tools = [approval_tool.clone()]
        };
        let chat_resp1 = chat_response! {
            "Completion message";
            tool_calls = ["mock_tool"]
        };
        mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));

        // The follow-up request contains the earlier exchange plus a tool output stating that
        // the call was refused.
        let chat_req2 = chat_request! {
            user!("Request with approval"),
            assistant!("Completion message", ["mock_tool"]),
            tool_output!("mock_tool", "This tool call was refused");
            // Output recorded for the refused tool call
            tools = [approval_tool.clone()]
        };
        let chat_resp2 = chat_response! {
            "Post-feedback message";
            tool_calls = ["stop"]
        };
        mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));

        // Step 3: Wire up the agent.
        let mut agent = Agent::builder()
            .tools([approval_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        // Step 4: Run agent to trigger approval.
        agent.query_once("Request with approval").await.unwrap();

        assert!(matches!(
            agent.state,
            crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
        ));

        // Extract the pending tool call so feedback can be given for it.
        let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
        else {
            panic!("Expected feedback required");
        };

        // Step 5: Simulate a refusal, run again and assert finish.
        agent
            .context
            .feedback_received(&tool_call, &ToolFeedback::refused())
            .await
            .unwrap();

        tracing::debug!("running after approval");
        agent.run_once().await.unwrap();

        // The most recent textual tool output must mention the refusal.
        let history = agent.context().history().await.unwrap();
        history
            .iter()
            .rfind(|m| {
                let ChatMessage::ToolOutput(.., ToolOutput::Text(msg)) = m else {
                    return false;
                };
                msg.contains("refused")
            })
            .expect("Could not find refusal message");

        assert!(agent.is_stopped());
    }
1650
1651    #[test_log::test(tokio::test)]
1652    async fn test_removing_default_stop_tool() {
1653        let mock_llm = MockChatCompletion::new();
1654        let mock_tool = MockTool::new("mock_tool");
1655
1656        // Build agent with without_default_stop_tool
1657        let agent = Agent::builder()
1658            .without_default_stop_tool()
1659            .tools([mock_tool.clone()])
1660            .llm(&mock_llm)
1661            .no_system_prompt()
1662            .build()
1663            .unwrap();
1664
1665        // Check that "stop" tool is NOT included
1666        assert!(agent.find_tool_by_name("stop").is_none());
1667        // Check that our provided tool is still present
1668        assert!(agent.find_tool_by_name("mock_tool").is_some());
1669    }
1670}