1#![allow(dead_code)]
2use crate::{
3 default_context::DefaultContext,
4 hooks::{
5 AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
6 Hook, HookTypes, MessageHookFn, OnStartFn,
7 },
8 state,
9 system_prompt::SystemPrompt,
10 tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
11};
12use std::{
13 collections::{HashMap, HashSet},
14 hash::{DefaultHasher, Hash as _, Hasher as _},
15 sync::Arc,
16};
17
18use anyhow::Result;
19use derive_builder::Builder;
20use swiftide_core::{
21 chat_completion::{
22 ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
23 },
24 prompt::Prompt,
25 AgentContext,
26};
27use tracing::{debug, Instrument};
28
/// An agent that repeatedly asks a [`ChatCompletion`] model for the next step,
/// invokes the tools it requests and records everything in its [`AgentContext`],
/// until it is stopped or runs out of completions.
///
/// Construct it with [`Agent::builder`]; every field except `llm` has a builder default.
#[derive(Clone, Builder)]
pub struct Agent {
    /// Lifecycle hooks; hooks of the same type run in the order they were added.
    #[builder(default, setter(into))]
    pub(crate) hooks: Vec<Hook>,
    /// Context holding the conversation; defaults to a [`DefaultContext`].
    #[builder(
        setter(custom),
        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
    )]
    pub(crate) context: Arc<dyn AgentContext>,
    /// Tools the LLM may call. Always includes the default tools (the [`Stop`] tool),
    /// also when set through the builder's `tools` method.
    #[builder(default = Agent::default_tools(), setter(custom))]
    pub(crate) tools: HashSet<Box<dyn Tool>>,

    /// The model used for completions; the only field without a builder default.
    #[builder(setter(custom))]
    pub(crate) llm: Box<dyn ChatCompletion>,

    /// Prompt rendered into a system message on the agent's first run.
    /// Defaults to [`SystemPrompt::default`]; `no_system_prompt` sets it to `None`.
    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
    pub(crate) system_prompt: Option<Prompt>,

    /// Lifecycle state (pending / running / stopped); not settable through the builder.
    #[builder(private, default = state::State::default())]
    pub(crate) state: state::State,

    /// Maximum number of completion steps per run; `None` means unlimited.
    #[builder(default, setter(strip_option))]
    pub(crate) limit: Option<usize>,

    /// How often a failing tool call may be retried before the agent stops with
    /// the tool's error. Defaults to 3.
    #[builder(default = 3)]
    pub(crate) tool_retry_limit: usize,

    /// Failure counts per tool call, keyed by a hash of the `ToolCall`.
    #[builder(private, default)]
    pub(crate) tool_retries_counter: HashMap<u64, usize>,
}
109
110impl std::fmt::Debug for Agent {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 f.debug_struct("Agent")
113 .field(
115 "hooks",
116 &self
117 .hooks
118 .iter()
119 .map(std::string::ToString::to_string)
120 .collect::<Vec<_>>(),
121 )
122 .field(
123 "tools",
124 &self
125 .tools
126 .iter()
127 .map(swiftide_core::Tool::name)
128 .collect::<Vec<_>>(),
129 )
130 .field("llm", &"Box<dyn ChatCompletion>")
131 .field("state", &self.state)
132 .finish()
133 }
134}
135
136impl AgentBuilder {
137 pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
139 where
140 Self: Clone,
141 {
142 self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
143 self
144 }
145
146 pub fn no_system_prompt(&mut self) -> &mut Self {
148 self.system_prompt = Some(None);
149
150 self
151 }
152
153 pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
155 let hooks = self.hooks.get_or_insert_with(Vec::new);
156 hooks.push(hook);
157
158 self
159 }
160
161 pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
164 self.add_hook(Hook::BeforeAll(Box::new(hook)))
165 }
166
167 pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
171 self.add_hook(Hook::OnStart(Box::new(hook)))
172 }
173
174 pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
176 self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
177 }
178
179 pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
185 self.add_hook(Hook::AfterTool(Box::new(hook)))
186 }
187
188 pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
190 self.add_hook(Hook::BeforeTool(Box::new(hook)))
191 }
192
193 pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
195 self.add_hook(Hook::AfterCompletion(Box::new(hook)))
196 }
197
198 pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
201 self.add_hook(Hook::AfterEach(Box::new(hook)))
202 }
203
204 pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
207 self.add_hook(Hook::OnNewMessage(Box::new(hook)))
208 }
209
210 pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
212 let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
213
214 self.llm = Some(boxed);
215 self
216 }
217
218 pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
223 where
224 TOOL: Into<Box<dyn Tool>>,
225 {
226 self.tools = Some(
227 tools
228 .into_iter()
229 .map(Into::into)
230 .chain(Agent::default_tools())
231 .collect(),
232 );
233 self
234 }
235}
236
237impl Agent {
238 pub fn builder() -> AgentBuilder {
240 AgentBuilder::default()
241 }
242}
243
244impl Agent {
245 fn default_tools() -> HashSet<Box<dyn Tool>> {
247 HashSet::from([Box::new(Stop::default()) as Box<dyn Tool>])
248 }
249
250 #[tracing::instrument(skip_all, name = "agent.query")]
253 pub async fn query(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
254 self.run_agent(Some(query.into()), false).await
255 }
256
257 #[tracing::instrument(skip_all, name = "agent.query_once")]
259 pub async fn query_once(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
260 self.run_agent(Some(query.into()), true).await
261 }
262
263 #[tracing::instrument(skip_all, name = "agent.run")]
266 pub async fn run(&mut self) -> Result<()> {
267 self.run_agent(None, false).await
268 }
269
270 #[tracing::instrument(skip_all, name = "agent.run_once")]
273 pub async fn run_once(&mut self) -> Result<()> {
274 self.run_agent(None, true).await
275 }
276
277 pub async fn history(&self) -> Vec<ChatMessage> {
279 self.context.history().await
280 }
281
282 async fn run_agent(&mut self, maybe_query: Option<String>, just_once: bool) -> Result<()> {
283 if self.state.is_running() {
284 anyhow::bail!("Agent is already running");
285 }
286
287 if self.state.is_pending() {
288 if let Some(system_prompt) = &self.system_prompt {
289 self.context
290 .add_messages(vec![ChatMessage::System(system_prompt.render().await?)])
291 .await;
292 }
293 for hook in self.hooks_by_type(HookTypes::BeforeAll) {
294 if let Hook::BeforeAll(hook) = hook {
295 let span = tracing::info_span!(
296 "hook",
297 "otel.name" = format!("hook.{}", HookTypes::AfterTool)
298 );
299 tracing::info!("Calling {} hook", HookTypes::AfterTool);
300 hook(self).instrument(span.or_current()).await?;
301 }
302 }
303 }
304
305 for hook in self.hooks_by_type(HookTypes::OnStart) {
306 if let Hook::OnStart(hook) = hook {
307 let span = tracing::info_span!(
308 "hook",
309 "otel.name" = format!("hook.{}", HookTypes::OnStart)
310 );
311 tracing::info!("Calling {} hook", HookTypes::OnStart);
312 hook(self).instrument(span.or_current()).await?;
313 }
314 }
315
316 self.state = state::State::Running;
317
318 if let Some(query) = maybe_query {
319 self.context.add_message(ChatMessage::User(query)).await;
320 }
321
322 let mut loop_counter = 0;
323
324 while let Some(messages) = self.context.next_completion().await {
325 if let Some(limit) = self.limit {
326 if loop_counter >= limit {
327 tracing::warn!("Agent loop limit reached");
328 break;
329 }
330 }
331 let result = self.run_completions(&messages).await;
332
333 if let Err(err) = result {
334 self.stop();
335 tracing::error!(error = ?err, "Agent stopped with error {err}");
336 return Err(err);
337 }
338
339 if just_once || self.state.is_stopped() {
340 break;
341 }
342 loop_counter += 1;
343 }
344
345 self.stop();
347
348 Ok(())
349 }
350
351 #[tracing::instrument(skip_all, err)]
352 async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<()> {
353 debug!(
354 "Running completion for agent with {} messages",
355 messages.len()
356 );
357
358 let mut chat_completion_request = ChatCompletionRequest::builder()
359 .messages(messages)
360 .tools_spec(
361 self.tools
362 .iter()
363 .map(swiftide_core::Tool::tool_spec)
364 .collect::<HashSet<_>>(),
365 )
366 .build()?;
367
368 for hook in self.hooks_by_type(HookTypes::BeforeCompletion) {
369 if let Hook::BeforeCompletion(hook) = hook {
370 let span = tracing::info_span!(
371 "hook",
372 "otel.name" = format!("hook.{}", HookTypes::BeforeCompletion)
373 );
374 tracing::info!("Calling {} hook", HookTypes::BeforeCompletion);
375 hook(&*self, &mut chat_completion_request)
376 .instrument(span.or_current())
377 .await?;
378 }
379 }
380
381 debug!(
382 "Calling LLM with the following new messages:\n {}",
383 self.context
384 .current_new_messages()
385 .await
386 .iter()
387 .map(ToString::to_string)
388 .collect::<Vec<_>>()
389 .join(",\n")
390 );
391
392 let mut response = self.llm.complete(&chat_completion_request).await?;
393
394 for hook in self.hooks_by_type(HookTypes::AfterCompletion) {
395 if let Hook::AfterCompletion(hook) = hook {
396 let span = tracing::info_span!(
397 "hook",
398 "otel.name" = format!("hook.{}", HookTypes::AfterCompletion)
399 );
400 tracing::info!("Calling {} hook", HookTypes::AfterCompletion);
401 hook(&*self, &mut response)
402 .instrument(span.or_current())
403 .await?;
404 }
405 }
406 self.add_message(ChatMessage::Assistant(
407 response.message,
408 response.tool_calls.clone(),
409 ))
410 .await?;
411
412 if let Some(tool_calls) = response.tool_calls {
413 self.invoke_tools(tool_calls).await?;
414 }
415
416 for hook in self.hooks_by_type(HookTypes::AfterEach) {
417 if let Hook::AfterEach(hook) = hook {
418 let span = tracing::info_span!(
419 "hook",
420 "otel.name" = format!("hook.{}", HookTypes::AfterEach)
421 );
422 tracing::info!("Calling {} hook", HookTypes::AfterEach);
423 hook(&*self).instrument(span.or_current()).await?;
424 }
425 }
426
427 Ok(())
428 }
429
430 async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<()> {
431 debug!("LLM returned tool calls: {:?}", tool_calls);
432
433 let mut handles = vec![];
434 for tool_call in tool_calls {
435 let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
436 tracing::warn!("Tool {} not found", tool_call.name());
437 continue;
438 };
439 tracing::info!("Calling tool `{}`", tool_call.name());
440
441 let tool_args = tool_call.args().map(String::from);
442 let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
443
444 for hook in self.hooks_by_type(HookTypes::BeforeTool) {
445 if let Hook::BeforeTool(hook) = hook {
446 let span = tracing::info_span!(
447 "hook",
448 "otel.name" = format!("hook.{}", HookTypes::BeforeTool)
449 );
450 tracing::info!("Calling {} hook", HookTypes::BeforeTool);
451 hook(&*self, &tool_call)
452 .instrument(span.or_current())
453 .await?;
454 }
455 }
456
457 let tool_span = tracing::info_span!(
458 "tool",
459 "otel.name" = format!("tool.{}", tool.name().as_ref())
460 );
461
462 let handle = tokio::spawn(async move {
463 let tool_args = ArgPreprocessor::preprocess(tool_args.as_deref());
464 let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
465
466 tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name = tool.name().as_ref(), "Completed tool call");
467
468 Ok(output)
469 }.instrument(tool_span.or_current()));
470
471 handles.push((handle, tool_call));
472 }
473
474 for (handle, tool_call) in handles {
475 let mut output = handle.await?;
476
477 for hook in self.hooks_by_type(HookTypes::AfterTool) {
479 if let Hook::AfterTool(hook) = hook {
480 let span = tracing::info_span!(
481 "hook",
482 "otel.name" = format!("hook.{}", HookTypes::AfterTool)
483 );
484 tracing::info!("Calling {} hook", HookTypes::AfterTool);
485 hook(&*self, &tool_call, &mut output)
486 .instrument(span.or_current())
487 .await?;
488 }
489 }
490
491 if let Err(error) = output {
492 let stop = self.tool_calls_over_limit(&tool_call);
493 if stop {
494 tracing::error!(
495 "Tool call failed, retry limit reached, stopping agent: {err}",
496 err = error
497 );
498 } else {
499 tracing::warn!(
500 error = error.to_string(),
501 tool_call = ?tool_call,
502 "Tool call failed, retrying",
503 );
504 }
505 self.add_message(ChatMessage::ToolOutput(
506 tool_call,
507 ToolOutput::Fail(error.to_string()),
508 ))
509 .await?;
510 if stop {
511 self.stop();
512 return Err(error.into());
513 }
514 continue;
515 }
516
517 let output = output?;
518 self.handle_control_tools(&output);
519 self.add_message(ChatMessage::ToolOutput(tool_call, output))
520 .await?;
521 }
522
523 Ok(())
524 }
525
526 fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
527 self.hooks
528 .iter()
529 .filter(|h| hook_type == (*h).into())
530 .collect()
531 }
532
533 fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
534 self.tools
535 .iter()
536 .find(|tool| tool.name() == tool_name)
537 .cloned()
538 }
539
540 fn handle_control_tools(&mut self, output: &ToolOutput) {
542 if let ToolOutput::Stop = output {
543 tracing::warn!("Stop tool called, stopping agent");
544 self.stop();
545 }
546 }
547
548 fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
549 let mut s = DefaultHasher::new();
550 tool_call.hash(&mut s);
551 let hash = s.finish();
552
553 if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
554 let val = *retries >= self.tool_retry_limit;
555 *retries += 1;
556 val
557 } else {
558 self.tool_retries_counter.insert(hash, 1);
559 false
560 }
561 }
562
563 #[tracing::instrument(skip_all, fields(message = message.to_string()))]
569 pub async fn add_message(&self, mut message: ChatMessage) -> Result<()> {
570 for hook in self.hooks_by_type(HookTypes::OnNewMessage) {
571 if let Hook::OnNewMessage(hook) = hook {
572 let span = tracing::info_span!(
573 "hook",
574 "otel.name" = format!("hook.{}", HookTypes::OnNewMessage)
575 );
576 if let Err(err) = hook(self, &mut message).instrument(span.or_current()).await {
577 tracing::error!(
578 "Error in {hooktype} hook: {err}",
579 hooktype = HookTypes::OnNewMessage,
580 );
581 }
582 }
583 }
584 self.context.add_message(message).await;
585 Ok(())
586 }
587
588 pub fn stop(&mut self) {
590 self.state = state::State::Stopped;
591 }
592
593 pub fn context(&self) -> &dyn AgentContext {
595 &self.context
596 }
597
598 pub fn is_running(&self) -> bool {
600 self.state.is_running()
601 }
602
603 pub fn is_stopped(&self) -> bool {
605 self.state.is_stopped()
606 }
607
608 pub fn is_pending(&self) -> bool {
610 self.state.is_pending()
611 }
612}
613
#[cfg(test)]
mod tests {

    use serde::ser::Error;
    use swiftide_core::chat_completion::errors::ToolError;
    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
    use swiftide_core::test_utils::MockChatCompletion;

    use super::*;
    use crate::{
        assistant, chat_request, chat_response, summary, system, tool_failed, tool_output, user,
    };

    use crate::test_utils::{MockHook, MockTool};

    #[test_log::test(tokio::test)]
    async fn test_agent_builder_defaults() {
        let mock_llm = MockChatCompletion::new();

        // Without explicit tools the agent still carries the stop tool.
        let agent = Agent::builder().llm(&mock_llm).build().unwrap();

        assert!(agent.find_tool_by_name("stop").is_some());

        // Duplicate tools collapse into a single entry.
        let agent = Agent::builder()
            .tools([Stop::default(), Stop::default()])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 1);

        // Custom tools are added alongside the default stop tool.
        let agent = Agent::builder()
            .tools([MockTool::new("mock_tool")])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 2);
        assert!(agent.find_tool_by_name("mock_tool").is_some());
        assert!(agent.find_tool_by_name("stop").is_some());

        assert!(agent.context().history().await.is_empty());
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_tool_calling_loop() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_tool_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_multiple_tool_calls() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool1");
        let mock_tool2 = MockTool::new("mock_tool2");

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";

            tool_calls = ["mock_tool1", "mock_tool2"]
        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_tool2.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
            tool_output!("mock_tool1", "Great!"),
            tool_output!("mock_tool2", "Great!");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Ok!";

            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool, mock_tool2])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_state_machine() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = []
        };
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        assert!(agent.state.is_pending());
        agent.query_once(prompt).await.unwrap();

        assert!(agent.state.is_stopped());
    }

    #[test_log::test(tokio::test)]
    async fn test_summary() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = []
        };

        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        let mut agent = Agent::builder()
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();

        // A summary replaces the earlier conversation in later requests.
        agent
            .context
            .add_message(ChatMessage::new_summary("Summary"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary"),
            user!("Write another poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        agent.query_once("Write another poem").await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary 2"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary 2"),
            user!("Write a third poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));

        agent.query_once("Write a third poem").await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_hooks() {
        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
        let mock_before_completion = MockHook::new("before_completion")
            .expect_calls(2)
            .to_owned();
        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();

        // One call each for `mock_tool` and the final `stop` tool.
        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();

        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .before_all(mock_before_all.hook_fn())
            .on_start(mock_on_start_fn.on_start_fn())
            .before_completion(mock_before_completion.before_completion_fn())
            .before_tool(mock_before_tool.before_tool_fn())
            .after_completion(mock_after_completion.after_completion_fn())
            .after_tool(mock_after_tool.after_tool_fn())
            .after_each(mock_after_each.hook_fn())
            .on_new_message(mock_on_message.message_hook_fn())
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_loop_limit() {
        let prompt = "Generate content";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mock_tool_response = chat_response! {
            "Some response";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));

        // Never consumed: the limit of 1 ends the loop before a second completion.
        let stop_response = chat_response! {
            "Final response";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .limit(1)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        // The second expectation must still be pending.
        let remaining = mock_llm.expectations.lock().unwrap().pop();
        assert!(remaining.is_some());

        assert!(agent.is_stopped());
    }

    #[test_log::test(tokio::test)]
    async fn test_tool_retry_mechanism() {
        let prompt = "Execute my tool";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("retry_tool");

        // The tool fails on both invocations.
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        let retry_response = chat_response! {
            "First failing attempt";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));

        let chat_request = chat_request! {
            user!(prompt),
            assistant!("First failing attempt", ["retry_tool"]),
            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");

            tools = [mock_tool.clone()]
        };
        let will_fail_response = chat_response! {
            "Finished execution";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .tool_retry_limit(1)
            .build()
            .unwrap();

        // With a retry limit of 1, the second failure stops the agent with the tool error.
        let result = agent.query(prompt).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing `query`"));
        assert!(agent.is_stopped());
    }
}