1#![allow(dead_code)]
2use crate::{
3 default_context::DefaultContext,
4 hooks::{
5 AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
6 Hook, HookTypes, MessageHookFn, OnStartFn,
7 },
8 state,
9 system_prompt::SystemPrompt,
10 tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
11};
12use std::{
13 collections::{HashMap, HashSet},
14 hash::{DefaultHasher, Hash as _, Hasher as _},
15 sync::Arc,
16};
17
18use anyhow::Result;
19use derive_builder::Builder;
20use swiftide_core::{
21 chat_completion::{
22 ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
23 },
24 prompt::Prompt,
25 AgentContext,
26};
27use tracing::{debug, Instrument};
28
/// An LLM-driven agent that repeatedly asks a chat completion model for a response and executes
/// the tool calls the model requests. It keeps going until it is stopped, its loop limit is hit,
/// or the context has nothing left to complete.
///
/// Build one with [`Agent::builder`]; `llm` is the only field without a builder default.
#[derive(Clone, Builder)]
pub struct Agent {
    /// Lifecycle hooks (before all, on start, before/after completion, before/after tool, after
    /// each round, on new message). Hooks of the same type run in registration order.
    #[builder(default, setter(into))]
    pub(crate) hooks: Vec<Hook>,
    /// The context the agent runs in. It stores the conversation and supplies the messages for
    /// each completion round. Defaults to a [`DefaultContext`].
    #[builder(
        setter(custom),
        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
    )]
    pub(crate) context: Arc<dyn AgentContext>,
    /// Tools the LLM may call. This always includes the built-in [`Stop`] tool. It is present in
    /// the default, and the custom builder setter adds it to any user-supplied tools.
    #[builder(default = Agent::default_tools(), setter(custom))]
    pub(crate) tools: HashSet<Box<dyn Tool>>,

    /// The chat completion model used for every round. It has no builder default, so it must be
    /// set through `AgentBuilder::llm`.
    #[builder(setter(custom))]
    pub(crate) llm: Box<dyn ChatCompletion>,

    /// Prompt that is rendered and added as a system message the first time the agent runs.
    /// Defaults to [`SystemPrompt::default`]; `AgentBuilder::no_system_prompt` disables it.
    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
    pub(crate) system_prompt: Option<Prompt>,

    /// Lifecycle state (pending, running, stopped). Not settable through the builder.
    #[builder(private, default = state::State::default())]
    pub(crate) state: state::State,

    /// Optional cap on the number of completion rounds per run. The round counter lives inside
    /// each run, so it starts from zero every time the agent is (re)started.
    #[builder(default, setter(strip_option))]
    pub(crate) limit: Option<usize>,

    /// Number of times a failing tool call is handed back to the LLM for another attempt before
    /// the agent stops with the tool's error. Defaults to 3.
    ///
    /// Failures are counted per distinct tool call (by its hash). The count is *not* reset when
    /// the agent stops.
    #[builder(default = 3)]
    pub(crate) tool_retry_limit: usize,

    /// Internal bookkeeping: failure counts per tool call, keyed by the tool call's hash.
    #[builder(private, default)]
    pub(crate) tool_retries_counter: HashMap<u64, usize>,
}
107
108impl std::fmt::Debug for Agent {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 f.debug_struct("Agent")
111 .field(
113 "hooks",
114 &self
115 .hooks
116 .iter()
117 .map(std::string::ToString::to_string)
118 .collect::<Vec<_>>(),
119 )
120 .field(
121 "tools",
122 &self
123 .tools
124 .iter()
125 .map(swiftide_core::Tool::name)
126 .collect::<Vec<_>>(),
127 )
128 .field("llm", &"Box<dyn ChatCompletion>")
129 .field("state", &self.state)
130 .finish()
131 }
132}
133
134impl AgentBuilder {
135 pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
137 where
138 Self: Clone,
139 {
140 self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
141 self
142 }
143
144 pub fn no_system_prompt(&mut self) -> &mut Self {
146 self.system_prompt = Some(None);
147
148 self
149 }
150
151 pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
153 let hooks = self.hooks.get_or_insert_with(Vec::new);
154 hooks.push(hook);
155
156 self
157 }
158
159 pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
162 self.add_hook(Hook::BeforeAll(Box::new(hook)))
163 }
164
165 pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
169 self.add_hook(Hook::OnStart(Box::new(hook)))
170 }
171
172 pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
174 self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
175 }
176
177 pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
183 self.add_hook(Hook::AfterTool(Box::new(hook)))
184 }
185
186 pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
188 self.add_hook(Hook::BeforeTool(Box::new(hook)))
189 }
190
191 pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
193 self.add_hook(Hook::AfterCompletion(Box::new(hook)))
194 }
195
196 pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
199 self.add_hook(Hook::AfterEach(Box::new(hook)))
200 }
201
202 pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
205 self.add_hook(Hook::OnNewMessage(Box::new(hook)))
206 }
207
208 pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
210 let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
211
212 self.llm = Some(boxed);
213 self
214 }
215
216 pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
221 where
222 TOOL: Into<Box<dyn Tool>>,
223 {
224 self.tools = Some(
225 tools
226 .into_iter()
227 .map(Into::into)
228 .chain(Agent::default_tools())
229 .collect(),
230 );
231 self
232 }
233}
234
235impl Agent {
236 pub fn builder() -> AgentBuilder {
238 AgentBuilder::default()
239 }
240}
241
242impl Agent {
243 fn default_tools() -> HashSet<Box<dyn Tool>> {
245 HashSet::from([Box::new(Stop::default()) as Box<dyn Tool>])
246 }
247
248 #[tracing::instrument(skip_all, name = "agent.query")]
251 pub async fn query(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
252 self.run_agent(Some(query.into()), false).await
253 }
254
255 #[tracing::instrument(skip_all, name = "agent.query_once")]
257 pub async fn query_once(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
258 self.run_agent(Some(query.into()), true).await
259 }
260
261 #[tracing::instrument(skip_all, name = "agent.run")]
264 pub async fn run(&mut self) -> Result<()> {
265 self.run_agent(None, false).await
266 }
267
268 #[tracing::instrument(skip_all, name = "agent.run_once")]
270 pub async fn run_once(&mut self) -> Result<()> {
271 self.run_agent(None, true).await
272 }
273
274 pub async fn history(&self) -> Vec<ChatMessage> {
276 self.context.history().await
277 }
278
279 async fn run_agent(&mut self, maybe_query: Option<String>, just_once: bool) -> Result<()> {
280 if self.state.is_running() {
281 anyhow::bail!("Agent is already running");
282 }
283
284 if self.state.is_pending() {
285 if let Some(system_prompt) = &self.system_prompt {
286 self.context
287 .add_messages(vec![ChatMessage::System(system_prompt.render().await?)])
288 .await;
289 }
290 for hook in self.hooks_by_type(HookTypes::BeforeAll) {
291 if let Hook::BeforeAll(hook) = hook {
292 let span = tracing::info_span!(
293 "hook",
294 "otel.name" = format!("hook.{}", HookTypes::AfterTool)
295 );
296 tracing::info!("Calling {} hook", HookTypes::AfterTool);
297 hook(self).instrument(span.or_current()).await?;
298 }
299 }
300 }
301
302 for hook in self.hooks_by_type(HookTypes::OnStart) {
303 if let Hook::OnStart(hook) = hook {
304 let span = tracing::info_span!(
305 "hook",
306 "otel.name" = format!("hook.{}", HookTypes::OnStart)
307 );
308 tracing::info!("Calling {} hook", HookTypes::OnStart);
309 hook(self).instrument(span.or_current()).await?;
310 }
311 }
312
313 self.state = state::State::Running;
314
315 if let Some(query) = maybe_query {
316 self.context.add_message(ChatMessage::User(query)).await;
317 }
318
319 let mut loop_counter = 0;
320
321 while let Some(messages) = self.context.next_completion().await {
322 if let Some(limit) = self.limit {
323 if loop_counter >= limit {
324 tracing::warn!("Agent loop limit reached");
325 break;
326 }
327 }
328 let result = self.run_completions(&messages).await;
329
330 if let Err(err) = result {
331 self.stop();
332 tracing::error!(error = ?err, "Agent stopped with error {err}");
333 return Err(err);
334 }
335
336 if just_once || self.state.is_stopped() {
337 break;
338 }
339 loop_counter += 1;
340 }
341
342 self.stop();
344
345 Ok(())
346 }
347
348 #[tracing::instrument(skip_all, err)]
349 async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<()> {
350 debug!(
351 "Running completion for agent with {} messages",
352 messages.len()
353 );
354
355 let mut chat_completion_request = ChatCompletionRequest::builder()
356 .messages(messages)
357 .tools_spec(
358 self.tools
359 .iter()
360 .map(swiftide_core::Tool::tool_spec)
361 .collect::<HashSet<_>>(),
362 )
363 .build()?;
364
365 for hook in self.hooks_by_type(HookTypes::BeforeCompletion) {
366 if let Hook::BeforeCompletion(hook) = hook {
367 let span = tracing::info_span!(
368 "hook",
369 "otel.name" = format!("hook.{}", HookTypes::BeforeCompletion)
370 );
371 tracing::info!("Calling {} hook", HookTypes::BeforeCompletion);
372 hook(&*self, &mut chat_completion_request)
373 .instrument(span.or_current())
374 .await?;
375 }
376 }
377
378 debug!(
379 "Calling LLM with the following new messages:\n {}",
380 self.context
381 .current_new_messages()
382 .await
383 .iter()
384 .map(ToString::to_string)
385 .collect::<Vec<_>>()
386 .join(",\n")
387 );
388
389 let mut response = self.llm.complete(&chat_completion_request).await?;
390
391 for hook in self.hooks_by_type(HookTypes::AfterCompletion) {
392 if let Hook::AfterCompletion(hook) = hook {
393 let span = tracing::info_span!(
394 "hook",
395 "otel.name" = format!("hook.{}", HookTypes::AfterCompletion)
396 );
397 tracing::info!("Calling {} hook", HookTypes::AfterCompletion);
398 hook(&*self, &mut response)
399 .instrument(span.or_current())
400 .await?;
401 }
402 }
403 self.add_message(ChatMessage::Assistant(
404 response.message,
405 response.tool_calls.clone(),
406 ))
407 .await?;
408
409 if let Some(tool_calls) = response.tool_calls {
410 self.invoke_tools(tool_calls).await?;
411 };
412
413 for hook in self.hooks_by_type(HookTypes::AfterEach) {
414 if let Hook::AfterEach(hook) = hook {
415 let span = tracing::info_span!(
416 "hook",
417 "otel.name" = format!("hook.{}", HookTypes::AfterEach)
418 );
419 tracing::info!("Calling {} hook", HookTypes::AfterEach);
420 hook(&*self).instrument(span.or_current()).await?;
421 }
422 }
423
424 Ok(())
425 }
426
427 async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<()> {
428 debug!("LLM returned tool calls: {:?}", tool_calls);
429
430 let mut handles = vec![];
431 for tool_call in tool_calls {
432 let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
433 tracing::warn!("Tool {} not found", tool_call.name());
434 continue;
435 };
436 tracing::info!("Calling tool `{}`", tool_call.name());
437
438 let tool_args = tool_call.args().map(String::from);
439 let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
440
441 for hook in self.hooks_by_type(HookTypes::BeforeTool) {
442 if let Hook::BeforeTool(hook) = hook {
443 let span = tracing::info_span!(
444 "hook",
445 "otel.name" = format!("hook.{}", HookTypes::BeforeTool)
446 );
447 tracing::info!("Calling {} hook", HookTypes::BeforeTool);
448 hook(&*self, &tool_call)
449 .instrument(span.or_current())
450 .await?;
451 }
452 }
453
454 let tool_span = tracing::info_span!(
455 "tool",
456 "otel.name" = format!("tool.{}", tool.name().as_ref())
457 );
458
459 let handle = tokio::spawn(async move {
460 let tool_args = ArgPreprocessor::preprocess(tool_args.as_deref());
461 let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
462
463 tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name = tool.name().as_ref(), "Completed tool call");
464
465 Ok(output)
466 }.instrument(tool_span.or_current()));
467
468 handles.push((handle, tool_call));
469 }
470
471 for (handle, tool_call) in handles {
472 let mut output = handle.await?;
473
474 for hook in self.hooks_by_type(HookTypes::AfterTool) {
476 if let Hook::AfterTool(hook) = hook {
477 let span = tracing::info_span!(
478 "hook",
479 "otel.name" = format!("hook.{}", HookTypes::AfterTool)
480 );
481 tracing::info!("Calling {} hook", HookTypes::AfterTool);
482 hook(&*self, &tool_call, &mut output)
483 .instrument(span.or_current())
484 .await?;
485 }
486 }
487
488 if let Err(error) = output {
489 let stop = self.tool_calls_over_limit(&tool_call);
490 if stop {
491 tracing::error!(
492 "Tool call failed, retry limit reached, stopping agent: {err}",
493 err = error
494 );
495 } else {
496 tracing::warn!(
497 error = error.to_string(),
498 tool_call = ?tool_call,
499 "Tool call failed, retrying",
500 );
501 }
502 self.add_message(ChatMessage::ToolOutput(
503 tool_call,
504 ToolOutput::Fail(error.to_string()),
505 ))
506 .await?;
507 if stop {
508 self.stop();
509 return Err(error.into());
510 }
511 continue;
512 }
513
514 let output = output?;
515 self.handle_control_tools(&output);
516 self.add_message(ChatMessage::ToolOutput(tool_call, output))
517 .await?;
518 }
519
520 Ok(())
521 }
522
523 fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
524 self.hooks
525 .iter()
526 .filter(|h| hook_type == (*h).into())
527 .collect()
528 }
529
530 fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
531 self.tools
532 .iter()
533 .find(|tool| tool.name() == tool_name)
534 .cloned()
535 }
536
537 fn handle_control_tools(&mut self, output: &ToolOutput) {
539 if let ToolOutput::Stop = output {
540 tracing::warn!("Stop tool called, stopping agent");
541 self.stop();
542 }
543 }
544
545 fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
546 let mut s = DefaultHasher::new();
547 tool_call.hash(&mut s);
548 let hash = s.finish();
549
550 if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
551 let val = *retries >= self.tool_retry_limit;
552 *retries += 1;
553 val
554 } else {
555 self.tool_retries_counter.insert(hash, 1);
556 false
557 }
558 }
559
560 #[tracing::instrument(skip_all, fields(message = message.to_string()))]
561 async fn add_message(&self, mut message: ChatMessage) -> Result<()> {
562 for hook in self.hooks_by_type(HookTypes::OnNewMessage) {
563 if let Hook::OnNewMessage(hook) = hook {
564 let span = tracing::info_span!(
565 "hook",
566 "otel.name" = format!("hook.{}", HookTypes::OnNewMessage)
567 );
568 if let Err(err) = hook(self, &mut message).instrument(span.or_current()).await {
569 tracing::error!(
570 "Error in {hooktype} hook: {err}",
571 hooktype = HookTypes::OnNewMessage,
572 );
573 }
574 }
575 }
576 self.context.add_message(message).await;
577 Ok(())
578 }
579
580 pub fn stop(&mut self) {
582 self.state = state::State::Stopped;
583 }
584
585 pub fn context(&self) -> &dyn AgentContext {
587 &self.context
588 }
589
590 pub fn is_running(&self) -> bool {
592 self.state.is_running()
593 }
594
595 pub fn is_stopped(&self) -> bool {
597 self.state.is_stopped()
598 }
599
600 pub fn is_pending(&self) -> bool {
602 self.state.is_pending()
603 }
604}
605
#[cfg(test)]
mod tests {

    use serde::ser::Error;
    use swiftide_core::chat_completion::errors::ToolError;
    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
    use swiftide_core::test_utils::MockChatCompletion;

    use super::*;
    use crate::{
        assistant, chat_request, chat_response, summary, system, tool_failed, tool_output, user,
    };

    use crate::test_utils::{MockHook, MockTool};

    /// The builder checks below:
    /// - the `stop` tool is always registered;
    /// - duplicate tools collapse into one;
    /// - user tools are added alongside the defaults;
    /// - the agent starts with an empty history.
    #[test_log::test(tokio::test)]
    async fn test_agent_builder_defaults() {
        let mock_llm = MockChatCompletion::new();

        // With no tools configured, the default `stop` tool is still present.
        let agent = Agent::builder().llm(&mock_llm).build().unwrap();

        assert!(agent.find_tool_by_name("stop").is_some());

        // Two `Stop`s plus the default `Stop` collapse to a single entry in the set.
        let agent = Agent::builder()
            .tools([Stop::default(), Stop::default()])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 1);

        // A custom tool is registered next to the default `stop` tool.
        let agent = Agent::builder()
            .tools([MockTool::new("mock_tool")])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 2);
        assert!(agent.find_tool_by_name("mock_tool").is_some());
        assert!(agent.find_tool_by_name("stop").is_some());

        assert!(agent.context().history().await.is_empty());
    }

    /// A tool result is fed back to the LLM, and the loop continues until the LLM calls `stop`.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_calling_loop() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    /// `query_once` performs exactly one round, including the configured system prompt and the
    /// requested tool call.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    /// Several tool calls in one response are all executed, and their outputs are recorded in
    /// request order.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_multiple_tool_calls() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool1");
        let mock_tool2 = MockTool::new("mock_tool2");

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";

            tool_calls = ["mock_tool1", "mock_tool2"]
        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_tool2.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
            tool_output!("mock_tool1", "Great!"),
            tool_output!("mock_tool2", "Great!");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Ok!";

            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool, mock_tool2])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    /// The agent starts pending and ends stopped after a run.
    #[test_log::test(tokio::test)]
    async fn test_agent_state_machine() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = []
        };
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        assert!(agent.state.is_pending());
        agent.query_once(prompt).await.unwrap();

        assert!(agent.state.is_stopped());
    }

    /// A summary message in the context replaces the earlier conversation in later requests,
    /// while the system prompt is kept.
    #[test_log::test(tokio::test)]
    async fn test_summary() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = []
        };

        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        let mut agent = Agent::builder()
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary"),
            user!("Write another poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        agent.query_once("Write another poem").await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary 2"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary 2"),
            user!("Write a third poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));

        agent.query_once("Write a third poem").await.unwrap();
    }

    /// Every hook type fires the expected number of times over a two-round run.
    #[test_log::test(tokio::test)]
    async fn test_agent_hooks() {
        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
        let mock_before_completion = MockHook::new("before_completion")
            .expect_calls(2)
            .to_owned();
        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
        // Two assistant messages plus two tool outputs (`mock_tool`, then `stop`); the user
        // query goes straight into the context and does not pass through `add_message`.
        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();

        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();

        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .before_all(mock_before_all.hook_fn())
            .on_start(mock_on_start_fn.on_start_fn())
            .before_completion(mock_before_completion.before_completion_fn())
            .before_tool(mock_before_tool.before_tool_fn())
            .after_completion(mock_after_completion.after_completion_fn())
            .after_tool(mock_after_tool.after_tool_fn())
            .after_each(mock_after_each.hook_fn())
            .on_new_message(mock_on_message.message_hook_fn())
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    /// With `limit(1)` only one completion round runs, so the second expected completion is
    /// never consumed and the agent ends stopped.
    #[test_log::test(tokio::test)]
    async fn test_agent_loop_limit() {
        let prompt = "Generate content";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mock_tool_response = chat_response! {
            "Some response";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));

        // Would be consumed by a second round, which the limit prevents.
        let stop_response = chat_response! {
            "Final response";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .limit(1)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        // The unconsumed expectation is still queued. Popping it presumably also keeps the
        // mock's own leftover-expectation check from failing.
        let remaining = mock_llm.expectations.lock().unwrap().pop();
        assert!(remaining.is_some());

        assert!(agent.is_stopped());
    }

    /// With `tool_retry_limit(1)`, a first tool failure is sent back to the LLM for a retry. A
    /// second identical failure stops the agent with the tool's error.
    #[test_log::test(tokio::test)]
    async fn test_tool_retry_mechanism() {
        let prompt = "Execute my tool";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("retry_tool");

        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        let retry_response = chat_response! {
            "First failing attempt";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));

        let chat_request = chat_request! {
            user!(prompt),
            assistant!("First failing attempt", ["retry_tool"]),
            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");

            tools = [mock_tool.clone()]
        };
        let will_fail_response = chat_response! {
            "Finished execution";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .tool_retry_limit(1)
            .build()
            .unwrap();

        let result = agent.query(prompt).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing `query`"));
        assert!(agent.is_stopped());
    }
}