1#![allow(dead_code)]
2use crate::{
3 default_context::DefaultContext,
4 hooks::{
5 AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
6 Hook, HookTypes, MessageHookFn, OnStartFn,
7 },
8 state,
9 system_prompt::SystemPrompt,
10 tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
11};
12use std::{
13 collections::{HashMap, HashSet},
14 hash::{DefaultHasher, Hash as _, Hasher as _},
15 sync::Arc,
16};
17
18use anyhow::Result;
19use derive_builder::Builder;
20use swiftide_core::{
21 chat_completion::{
22 ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
23 },
24 prompt::Prompt,
25 AgentContext,
26};
27use tracing::{debug, Instrument};
28
/// An LLM-driven agent that repeatedly requests chat completions, invokes the
/// tools the LLM asks for and fires lifecycle hooks, until it is stopped.
///
/// Construct it with [`Agent::builder`].
#[derive(Clone, Builder)]
pub struct Agent {
    /// Lifecycle hooks, kept in registration order and dispatched by their
    /// [`HookTypes`] at the matching point of the agent loop.
    #[builder(default, setter(into))]
    pub(crate) hooks: Vec<Hook>,
    /// Conversation context: stores the message history and yields the
    /// messages for the next completion. Defaults to [`DefaultContext`].
    #[builder(
        setter(custom),
        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
    )]
    pub(crate) context: Arc<dyn AgentContext>,
    /// Tools the LLM may call. Both the builder default and the custom
    /// `tools` setter include the default tools (the `Stop` tool).
    #[builder(default = Agent::default_tools(), setter(custom))]
    pub(crate) tools: HashSet<Box<dyn Tool>>,

    /// The chat-completion backend. Required; set with `AgentBuilder::llm`.
    #[builder(setter(custom))]
    pub(crate) llm: Box<dyn ChatCompletion>,

    /// Optional system prompt, rendered and added as a system message the
    /// first time the agent runs. Defaults to [`SystemPrompt::default`];
    /// disable it with `AgentBuilder::no_system_prompt`.
    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
    pub(crate) system_prompt: Option<Prompt>,

    /// Lifecycle state (pending / running / stopped). Not settable from the
    /// builder.
    #[builder(private, default = state::State::default())]
    pub(crate) state: state::State,

    /// Maximum number of completion iterations per run; `None` means no limit.
    #[builder(default, setter(strip_option))]
    pub(crate) limit: Option<usize>,

    /// How many times a failing tool call may be retried before the agent
    /// stops with an error. Defaults to 3.
    #[builder(default = 3)]
    pub(crate) tool_retry_limit: usize,

    /// Failure counts per tool call, keyed by a hash of the whole `ToolCall`.
    #[builder(private, default)]
    pub(crate) tool_retries_counter: HashMap<u64, usize>,
}
109
110impl std::fmt::Debug for Agent {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 f.debug_struct("Agent")
113 .field(
115 "hooks",
116 &self
117 .hooks
118 .iter()
119 .map(std::string::ToString::to_string)
120 .collect::<Vec<_>>(),
121 )
122 .field(
123 "tools",
124 &self
125 .tools
126 .iter()
127 .map(swiftide_core::Tool::name)
128 .collect::<Vec<_>>(),
129 )
130 .field("llm", &"Box<dyn ChatCompletion>")
131 .field("state", &self.state)
132 .finish()
133 }
134}
135
136impl AgentBuilder {
137 pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
139 where
140 Self: Clone,
141 {
142 self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
143 self
144 }
145
146 pub fn no_system_prompt(&mut self) -> &mut Self {
148 self.system_prompt = Some(None);
149
150 self
151 }
152
153 pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
155 let hooks = self.hooks.get_or_insert_with(Vec::new);
156 hooks.push(hook);
157
158 self
159 }
160
161 pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
164 self.add_hook(Hook::BeforeAll(Box::new(hook)))
165 }
166
167 pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
171 self.add_hook(Hook::OnStart(Box::new(hook)))
172 }
173
174 pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
176 self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
177 }
178
179 pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
185 self.add_hook(Hook::AfterTool(Box::new(hook)))
186 }
187
188 pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
190 self.add_hook(Hook::BeforeTool(Box::new(hook)))
191 }
192
193 pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
195 self.add_hook(Hook::AfterCompletion(Box::new(hook)))
196 }
197
198 pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
201 self.add_hook(Hook::AfterEach(Box::new(hook)))
202 }
203
204 pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
207 self.add_hook(Hook::OnNewMessage(Box::new(hook)))
208 }
209
210 pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
212 let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
213
214 self.llm = Some(boxed);
215 self
216 }
217
218 pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
223 where
224 TOOL: Into<Box<dyn Tool>>,
225 {
226 self.tools = Some(
227 tools
228 .into_iter()
229 .map(Into::into)
230 .chain(Agent::default_tools())
231 .collect(),
232 );
233 self
234 }
235}
236
237impl Agent {
238 pub fn builder() -> AgentBuilder {
240 AgentBuilder::default()
241 }
242}
243
244impl Agent {
245 fn default_tools() -> HashSet<Box<dyn Tool>> {
247 HashSet::from([Box::new(Stop::default()) as Box<dyn Tool>])
248 }
249
250 #[tracing::instrument(skip_all, name = "agent.query")]
253 pub async fn query(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
254 self.run_agent(Some(query.into()), false).await
255 }
256
257 #[tracing::instrument(skip_all, name = "agent.query_once")]
259 pub async fn query_once(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
260 self.run_agent(Some(query.into()), true).await
261 }
262
263 #[tracing::instrument(skip_all, name = "agent.run")]
266 pub async fn run(&mut self) -> Result<()> {
267 self.run_agent(None, false).await
268 }
269
270 #[tracing::instrument(skip_all, name = "agent.run_once")]
273 pub async fn run_once(&mut self) -> Result<()> {
274 self.run_agent(None, true).await
275 }
276
277 pub async fn history(&self) -> Vec<ChatMessage> {
279 self.context.history().await
280 }
281
282 async fn run_agent(&mut self, maybe_query: Option<String>, just_once: bool) -> Result<()> {
283 if self.state.is_running() {
284 anyhow::bail!("Agent is already running");
285 }
286
287 if self.state.is_pending() {
288 if let Some(system_prompt) = &self.system_prompt {
289 self.context
290 .add_messages(vec![ChatMessage::System(system_prompt.render().await?)])
291 .await;
292 }
293 for hook in self.hooks_by_type(HookTypes::BeforeAll) {
294 if let Hook::BeforeAll(hook) = hook {
295 let span = tracing::info_span!(
296 "hook",
297 "otel.name" = format!("hook.{}", HookTypes::AfterTool)
298 );
299 tracing::info!("Calling {} hook", HookTypes::AfterTool);
300 hook(self).instrument(span.or_current()).await?;
301 }
302 }
303 }
304
305 for hook in self.hooks_by_type(HookTypes::OnStart) {
306 if let Hook::OnStart(hook) = hook {
307 let span = tracing::info_span!(
308 "hook",
309 "otel.name" = format!("hook.{}", HookTypes::OnStart)
310 );
311 tracing::info!("Calling {} hook", HookTypes::OnStart);
312 hook(self).instrument(span.or_current()).await?;
313 }
314 }
315
316 self.state = state::State::Running;
317
318 if let Some(query) = maybe_query {
319 self.context.add_message(ChatMessage::User(query)).await;
320 }
321
322 let mut loop_counter = 0;
323
324 while let Some(messages) = self.context.next_completion().await {
325 if let Some(limit) = self.limit {
326 if loop_counter >= limit {
327 tracing::warn!("Agent loop limit reached");
328 break;
329 }
330 }
331 let result = self.run_completions(&messages).await;
332
333 if let Err(err) = result {
334 self.stop();
335 tracing::error!(error = ?err, "Agent stopped with error {err}");
336 return Err(err);
337 }
338
339 if just_once || self.state.is_stopped() {
340 break;
341 }
342 loop_counter += 1;
343 }
344
345 self.stop();
347
348 Ok(())
349 }
350
351 #[tracing::instrument(skip_all, err)]
352 async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<()> {
353 debug!(
354 "Running completion for agent with {} messages",
355 messages.len()
356 );
357
358 let mut chat_completion_request = ChatCompletionRequest::builder()
359 .messages(messages)
360 .tools_spec(
361 self.tools
362 .iter()
363 .map(swiftide_core::Tool::tool_spec)
364 .collect::<HashSet<_>>(),
365 )
366 .build()?;
367
368 for hook in self.hooks_by_type(HookTypes::BeforeCompletion) {
369 if let Hook::BeforeCompletion(hook) = hook {
370 let span = tracing::info_span!(
371 "hook",
372 "otel.name" = format!("hook.{}", HookTypes::BeforeCompletion)
373 );
374 tracing::info!("Calling {} hook", HookTypes::BeforeCompletion);
375 hook(&*self, &mut chat_completion_request)
376 .instrument(span.or_current())
377 .await?;
378 }
379 }
380
381 debug!(
382 "Calling LLM with the following new messages:\n {}",
383 self.context
384 .current_new_messages()
385 .await
386 .iter()
387 .map(ToString::to_string)
388 .collect::<Vec<_>>()
389 .join(",\n")
390 );
391
392 let mut response = self.llm.complete(&chat_completion_request).await?;
393
394 for hook in self.hooks_by_type(HookTypes::AfterCompletion) {
395 if let Hook::AfterCompletion(hook) = hook {
396 let span = tracing::info_span!(
397 "hook",
398 "otel.name" = format!("hook.{}", HookTypes::AfterCompletion)
399 );
400 tracing::info!("Calling {} hook", HookTypes::AfterCompletion);
401 hook(&*self, &mut response)
402 .instrument(span.or_current())
403 .await?;
404 }
405 }
406 self.add_message(ChatMessage::Assistant(
407 response.message,
408 response.tool_calls.clone(),
409 ))
410 .await?;
411
412 if let Some(tool_calls) = response.tool_calls {
413 self.invoke_tools(tool_calls).await?;
414 };
415
416 for hook in self.hooks_by_type(HookTypes::AfterEach) {
417 if let Hook::AfterEach(hook) = hook {
418 let span = tracing::info_span!(
419 "hook",
420 "otel.name" = format!("hook.{}", HookTypes::AfterEach)
421 );
422 tracing::info!("Calling {} hook", HookTypes::AfterEach);
423 hook(&*self).instrument(span.or_current()).await?;
424 }
425 }
426
427 Ok(())
428 }
429
430 async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<()> {
431 debug!("LLM returned tool calls: {:?}", tool_calls);
432
433 let mut handles = vec![];
434 for tool_call in tool_calls {
435 let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
436 tracing::warn!("Tool {} not found", tool_call.name());
437 continue;
438 };
439 tracing::info!("Calling tool `{}`", tool_call.name());
440
441 let tool_args = tool_call.args().map(String::from);
442 let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
443
444 for hook in self.hooks_by_type(HookTypes::BeforeTool) {
445 if let Hook::BeforeTool(hook) = hook {
446 let span = tracing::info_span!(
447 "hook",
448 "otel.name" = format!("hook.{}", HookTypes::BeforeTool)
449 );
450 tracing::info!("Calling {} hook", HookTypes::BeforeTool);
451 hook(&*self, &tool_call)
452 .instrument(span.or_current())
453 .await?;
454 }
455 }
456
457 let tool_span = tracing::info_span!(
458 "tool",
459 "otel.name" = format!("tool.{}", tool.name().as_ref())
460 );
461
462 let handle = tokio::spawn(async move {
463 let tool_args = ArgPreprocessor::preprocess(tool_args.as_deref());
464 let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
465
466 tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name = tool.name().as_ref(), "Completed tool call");
467
468 Ok(output)
469 }.instrument(tool_span.or_current()));
470
471 handles.push((handle, tool_call));
472 }
473
474 for (handle, tool_call) in handles {
475 let mut output = handle.await?;
476
477 for hook in self.hooks_by_type(HookTypes::AfterTool) {
479 if let Hook::AfterTool(hook) = hook {
480 let span = tracing::info_span!(
481 "hook",
482 "otel.name" = format!("hook.{}", HookTypes::AfterTool)
483 );
484 tracing::info!("Calling {} hook", HookTypes::AfterTool);
485 hook(&*self, &tool_call, &mut output)
486 .instrument(span.or_current())
487 .await?;
488 }
489 }
490
491 if let Err(error) = output {
492 let stop = self.tool_calls_over_limit(&tool_call);
493 if stop {
494 tracing::error!(
495 "Tool call failed, retry limit reached, stopping agent: {err}",
496 err = error
497 );
498 } else {
499 tracing::warn!(
500 error = error.to_string(),
501 tool_call = ?tool_call,
502 "Tool call failed, retrying",
503 );
504 }
505 self.add_message(ChatMessage::ToolOutput(
506 tool_call,
507 ToolOutput::Fail(error.to_string()),
508 ))
509 .await?;
510 if stop {
511 self.stop();
512 return Err(error.into());
513 }
514 continue;
515 }
516
517 let output = output?;
518 self.handle_control_tools(&output);
519 self.add_message(ChatMessage::ToolOutput(tool_call, output))
520 .await?;
521 }
522
523 Ok(())
524 }
525
526 fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
527 self.hooks
528 .iter()
529 .filter(|h| hook_type == (*h).into())
530 .collect()
531 }
532
533 fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
534 self.tools
535 .iter()
536 .find(|tool| tool.name() == tool_name)
537 .cloned()
538 }
539
540 fn handle_control_tools(&mut self, output: &ToolOutput) {
542 if let ToolOutput::Stop = output {
543 tracing::warn!("Stop tool called, stopping agent");
544 self.stop();
545 }
546 }
547
548 fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
549 let mut s = DefaultHasher::new();
550 tool_call.hash(&mut s);
551 let hash = s.finish();
552
553 if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
554 let val = *retries >= self.tool_retry_limit;
555 *retries += 1;
556 val
557 } else {
558 self.tool_retries_counter.insert(hash, 1);
559 false
560 }
561 }
562
563 #[tracing::instrument(skip_all, fields(message = message.to_string()))]
564 async fn add_message(&self, mut message: ChatMessage) -> Result<()> {
565 for hook in self.hooks_by_type(HookTypes::OnNewMessage) {
566 if let Hook::OnNewMessage(hook) = hook {
567 let span = tracing::info_span!(
568 "hook",
569 "otel.name" = format!("hook.{}", HookTypes::OnNewMessage)
570 );
571 if let Err(err) = hook(self, &mut message).instrument(span.or_current()).await {
572 tracing::error!(
573 "Error in {hooktype} hook: {err}",
574 hooktype = HookTypes::OnNewMessage,
575 );
576 }
577 }
578 }
579 self.context.add_message(message).await;
580 Ok(())
581 }
582
583 pub fn stop(&mut self) {
585 self.state = state::State::Stopped;
586 }
587
588 pub fn context(&self) -> &dyn AgentContext {
590 &self.context
591 }
592
593 pub fn is_running(&self) -> bool {
595 self.state.is_running()
596 }
597
598 pub fn is_stopped(&self) -> bool {
600 self.state.is_stopped()
601 }
602
603 pub fn is_pending(&self) -> bool {
605 self.state.is_pending()
606 }
607}
608
#[cfg(test)]
mod tests {

    use serde::ser::Error;
    use swiftide_core::chat_completion::errors::ToolError;
    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
    use swiftide_core::test_utils::MockChatCompletion;

    use super::*;
    use crate::{
        assistant, chat_request, chat_response, summary, system, tool_failed, tool_output, user,
    };

    use crate::test_utils::{MockHook, MockTool};

    // The builder always includes the `stop` tool, deduplicates repeated tools,
    // and starts with an empty history.
    #[test_log::test(tokio::test)]
    async fn test_agent_builder_defaults() {
        let mock_llm = MockChatCompletion::new();

        let agent = Agent::builder().llm(&mock_llm).build().unwrap();

        assert!(agent.find_tool_by_name("stop").is_some());

        let agent = Agent::builder()
            .tools([Stop::default(), Stop::default()])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 1);

        let agent = Agent::builder()
            .tools([MockTool::new("mock_tool")])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 2);
        assert!(agent.find_tool_by_name("mock_tool").is_some());
        assert!(agent.find_tool_by_name("stop").is_some());

        assert!(agent.context().history().await.is_empty());
    }

    // A tool call's output is fed back into the next completion, and the
    // `stop` tool ends the loop.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_calling_loop() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    // `query_once` performs exactly one completion (plus its tool calls).
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    // Several tool calls in one response are all invoked and their outputs
    // recorded in call order.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_multiple_tool_calls() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool1");
        let mock_tool2 = MockTool::new("mock_tool2");

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";

            tool_calls = ["mock_tool1", "mock_tool2"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_tool2.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
            tool_output!("mock_tool1", "Great!"),
            tool_output!("mock_tool2", "Great!");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Ok!";

            tool_calls = ["stop"]

        };

        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool, mock_tool2])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    // A fresh agent is pending; after a run it is stopped.
    #[test_log::test(tokio::test)]
    async fn test_agent_state_machine() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = []
        };
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        assert!(agent.state.is_pending());
        agent.query_once(prompt).await.unwrap();

        assert!(agent.state.is_stopped());
    }

    // A summary message added to the context replaces earlier history in the
    // next completion request, while the system prompt is kept.
    #[test_log::test(tokio::test)]
    async fn test_summary() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []

        };

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = []
        };

        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        let mut agent = Agent::builder()
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary"),
            user!("Write another poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        agent.query_once("Write another poem").await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary 2"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary 2"),
            user!("Write a third poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));

        agent.query_once("Write a third poem").await.unwrap();
    }

    // Every hook type fires the expected number of times over a two-cycle run.
    #[test_log::test(tokio::test)]
    async fn test_agent_hooks() {
        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
        let mock_before_completion = MockHook::new("before_completion")
            .expect_calls(2)
            .to_owned();
        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();

        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();

        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .before_all(mock_before_all.hook_fn())
            .on_start(mock_on_start_fn.on_start_fn())
            .before_completion(mock_before_completion.before_completion_fn())
            .before_tool(mock_before_tool.before_tool_fn())
            .after_completion(mock_after_completion.after_completion_fn())
            .after_tool(mock_after_tool.after_tool_fn())
            .after_each(mock_after_each.hook_fn())
            .on_new_message(mock_on_message.message_hook_fn())
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    // With `limit(1)` only one completion runs; the second expectation is
    // left unconsumed and the agent ends stopped.
    #[test_log::test(tokio::test)]
    async fn test_agent_loop_limit() {
        let prompt = "Generate content";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mock_tool_response = chat_response! {
            "Some response";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));

        // Should never be consumed: the loop limit stops the agent first.
        let stop_response = chat_response! {
            "Final response";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .limit(1)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        let remaining = mock_llm.expectations.lock().unwrap().pop();
        assert!(remaining.is_some());

        assert!(agent.is_stopped());
    }

    // A tool that keeps failing is retried up to `tool_retry_limit`; after
    // that the error is returned and the agent is stopped.
    #[test_log::test(tokio::test)]
    async fn test_tool_retry_mechanism() {
        let prompt = "Execute my tool";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("retry_tool");

        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        let retry_response = chat_response! {
            "First failing attempt";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));

        let chat_request = chat_request! {
            user!(prompt),
            assistant!("First failing attempt", ["retry_tool"]),
            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");

            tools = [mock_tool.clone()]
        };
        let will_fail_response = chat_response! {
            "Finished execution";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .tool_retry_limit(1)
            .build()
            .unwrap();

        let result = agent.query(prompt).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing `query`"));
        assert!(agent.is_stopped());
    }
}