1#![allow(dead_code)]
2use crate::{
3 default_context::DefaultContext,
4 hooks::{
5 AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
6 Hook, HookTypes, MessageHookFn, OnStartFn,
7 },
8 state,
9 system_prompt::SystemPrompt,
10 tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
11};
12use std::{
13 collections::{HashMap, HashSet},
14 hash::{DefaultHasher, Hash as _, Hasher as _},
15 sync::Arc,
16};
17
18use anyhow::Result;
19use derive_builder::Builder;
20use swiftide_core::{
21 chat_completion::{
22 ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
23 },
24 prompt::Prompt,
25 AgentContext,
26};
27use tracing::{debug, Instrument};
28
/// An agent that repeatedly asks an LLM for completions, invokes the tools the LLM requests,
/// and records the resulting messages in its [`AgentContext`].
///
/// Construct one with [`Agent::builder`].
#[derive(Clone, Builder)]
pub struct Agent {
    /// Hooks invoked at the various lifecycle points of the agent loop.
    /// Defaults to no hooks.
    #[builder(default, setter(into))]
    pub(crate) hooks: Vec<Hook>,
    /// Context holding the conversation the agent works on.
    /// Defaults to a [`DefaultContext`]; set it with the custom `context` builder method.
    #[builder(
        setter(custom),
        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
    )]
    pub(crate) context: Arc<dyn AgentContext>,
    /// Tools the LLM may call. Defaults to `Agent::default_tools()`; the custom `tools`
    /// builder method always adds those default tools as well.
    #[builder(default = Agent::default_tools(), setter(custom))]
    pub(crate) tools: HashSet<Box<dyn Tool>>,

    /// The LLM used for completions. It has no builder default, so it must be set via the
    /// custom `llm` builder method.
    #[builder(setter(custom))]
    pub(crate) llm: Box<dyn ChatCompletion>,

    /// System prompt that is rendered and added to the context when the agent runs while
    /// still pending. Defaults to `SystemPrompt::default()`; `no_system_prompt()` sets it to
    /// `None`.
    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
    pub(crate) system_prompt: Option<Prompt>,

    /// Lifecycle state of the agent (pending / running / stopped). The builder setter is
    /// private, so it always starts at `State::default()`.
    #[builder(private, default = state::State::default())]
    pub(crate) state: state::State,

    /// Maximum number of completion rounds per run; `None` means no limit.
    #[builder(default, setter(strip_option))]
    pub(crate) limit: Option<usize>,

    /// Retry budget for a failing tool call before the agent stops with the tool's error.
    /// Defaults to 3.
    #[builder(default = 3)]
    pub(crate) tool_retry_limit: usize,

    /// Failure counts per tool call, keyed by the `Hash` of the `ToolCall`. Internal
    /// bookkeeping; the builder setter is private.
    #[builder(private, default)]
    pub(crate) tool_retries_counter: HashMap<u64, usize>,
}
107
108impl std::fmt::Debug for Agent {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 f.debug_struct("Agent")
111 .field(
113 "hooks",
114 &self
115 .hooks
116 .iter()
117 .map(std::string::ToString::to_string)
118 .collect::<Vec<_>>(),
119 )
120 .field(
121 "tools",
122 &self
123 .tools
124 .iter()
125 .map(swiftide_core::Tool::name)
126 .collect::<Vec<_>>(),
127 )
128 .field("llm", &"Box<dyn ChatCompletion>")
129 .field("state", &self.state)
130 .finish()
131 }
132}
133
134impl AgentBuilder {
135 pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
137 where
138 Self: Clone,
139 {
140 self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
141 self
142 }
143
144 pub fn no_system_prompt(&mut self) -> &mut Self {
146 self.system_prompt = Some(None);
147
148 self
149 }
150
151 pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
153 let hooks = self.hooks.get_or_insert_with(Vec::new);
154 hooks.push(hook);
155
156 self
157 }
158
159 pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
162 self.add_hook(Hook::BeforeAll(Box::new(hook)))
163 }
164
165 pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
169 self.add_hook(Hook::OnStart(Box::new(hook)))
170 }
171
172 pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
174 self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
175 }
176
177 pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
183 self.add_hook(Hook::AfterTool(Box::new(hook)))
184 }
185
186 pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
188 self.add_hook(Hook::BeforeTool(Box::new(hook)))
189 }
190
191 pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
193 self.add_hook(Hook::AfterCompletion(Box::new(hook)))
194 }
195
196 pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
199 self.add_hook(Hook::AfterEach(Box::new(hook)))
200 }
201
202 pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
205 self.add_hook(Hook::OnNewMessage(Box::new(hook)))
206 }
207
208 pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
210 let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
211
212 self.llm = Some(boxed);
213 self
214 }
215
216 pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
221 where
222 TOOL: Into<Box<dyn Tool>>,
223 {
224 self.tools = Some(
225 tools
226 .into_iter()
227 .map(Into::into)
228 .chain(Agent::default_tools())
229 .collect(),
230 );
231 self
232 }
233}
234
235impl Agent {
236 pub fn builder() -> AgentBuilder {
238 AgentBuilder::default()
239 }
240}
241
242impl Agent {
243 fn default_tools() -> HashSet<Box<dyn Tool>> {
245 HashSet::from([Box::new(Stop::default()) as Box<dyn Tool>])
246 }
247
248 #[tracing::instrument(skip_all, name = "agent.query")]
251 pub async fn query(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
252 self.run_agent(Some(query.into()), false).await
253 }
254
255 #[tracing::instrument(skip_all, name = "agent.query_once")]
257 pub async fn query_once(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
258 self.run_agent(Some(query.into()), true).await
259 }
260
261 #[tracing::instrument(skip_all, name = "agent.run")]
264 pub async fn run(&mut self) -> Result<()> {
265 self.run_agent(None, false).await
266 }
267
268 #[tracing::instrument(skip_all, name = "agent.run_once")]
270 pub async fn run_once(&mut self) -> Result<()> {
271 self.run_agent(None, true).await
272 }
273
274 pub async fn history(&self) -> Vec<ChatMessage> {
276 self.context.history().await
277 }
278
279 async fn run_agent(&mut self, maybe_query: Option<String>, just_once: bool) -> Result<()> {
280 if self.state.is_running() {
281 anyhow::bail!("Agent is already running");
282 }
283
284 if self.state.is_pending() {
285 if let Some(system_prompt) = &self.system_prompt {
286 self.context
287 .add_messages(vec![ChatMessage::System(system_prompt.render().await?)])
288 .await;
289 }
290 for hook in self.hooks_by_type(HookTypes::BeforeAll) {
291 if let Hook::BeforeAll(hook) = hook {
292 let span = tracing::info_span!(
293 "hook",
294 "otel.name" = format!("hook.{}", HookTypes::AfterTool)
295 );
296 tracing::info!("Calling {} hook", HookTypes::AfterTool);
297 hook(self).instrument(span.or_current()).await?;
298 }
299 }
300 }
301
302 for hook in self.hooks_by_type(HookTypes::OnStart) {
303 if let Hook::OnStart(hook) = hook {
304 let span = tracing::info_span!(
305 "hook",
306 "otel.name" = format!("hook.{}", HookTypes::OnStart)
307 );
308 tracing::info!("Calling {} hook", HookTypes::OnStart);
309 hook(self).instrument(span.or_current()).await?;
310 }
311 }
312
313 self.state = state::State::Running;
314
315 if let Some(query) = maybe_query {
316 self.context.add_message(ChatMessage::User(query)).await;
317 }
318
319 let mut loop_counter = 0;
320
321 while let Some(messages) = self.context.next_completion().await {
322 if let Some(limit) = self.limit {
323 if loop_counter >= limit {
324 tracing::warn!("Agent loop limit reached");
325 break;
326 }
327 }
328 let result = self.run_completions(&messages).await;
329
330 if let Err(err) = result {
331 self.stop();
332 tracing::error!(error = ?err, "Agent stopped with error {err}");
333 return Err(err);
334 }
335
336 if just_once || self.state.is_stopped() {
337 break;
338 }
339 loop_counter += 1;
340 }
341
342 self.stop();
344
345 Ok(())
346 }
347
348 #[tracing::instrument(skip_all, err)]
349 async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<()> {
350 debug!(
351 "Running completion for agent with {} messages",
352 messages.len()
353 );
354
355 let mut chat_completion_request = ChatCompletionRequest::builder()
356 .messages(messages)
357 .tools_spec(
358 self.tools
359 .iter()
360 .map(swiftide_core::Tool::tool_spec)
361 .collect::<HashSet<_>>(),
362 )
363 .build()?;
364
365 for hook in self.hooks_by_type(HookTypes::BeforeCompletion) {
366 if let Hook::BeforeCompletion(hook) = hook {
367 let span = tracing::info_span!(
368 "hook",
369 "otel.name" = format!("hook.{}", HookTypes::BeforeCompletion)
370 );
371 tracing::info!("Calling {} hook", HookTypes::BeforeCompletion);
372 hook(&*self, &mut chat_completion_request)
373 .instrument(span.or_current())
374 .await?;
375 }
376 }
377
378 debug!(
379 "Calling LLM with the following new messages:\n {}",
380 self.context
381 .current_new_messages()
382 .await
383 .iter()
384 .map(ToString::to_string)
385 .collect::<Vec<_>>()
386 .join(",\n")
387 );
388
389 let mut response = self.llm.complete(&chat_completion_request).await?;
390
391 for hook in self.hooks_by_type(HookTypes::AfterCompletion) {
392 if let Hook::AfterCompletion(hook) = hook {
393 let span = tracing::info_span!(
394 "hook",
395 "otel.name" = format!("hook.{}", HookTypes::AfterCompletion)
396 );
397 tracing::info!("Calling {} hook", HookTypes::AfterCompletion);
398 hook(&*self, &mut response)
399 .instrument(span.or_current())
400 .await?;
401 }
402 }
403 self.add_message(ChatMessage::Assistant(
404 response.message,
405 response.tool_calls.clone(),
406 ))
407 .await?;
408
409 if let Some(tool_calls) = response.tool_calls {
410 self.invoke_tools(tool_calls).await?;
411 };
412
413 for hook in self.hooks_by_type(HookTypes::AfterEach) {
414 if let Hook::AfterEach(hook) = hook {
415 let span = tracing::info_span!(
416 "hook",
417 "otel.name" = format!("hook.{}", HookTypes::AfterEach)
418 );
419 tracing::info!("Calling {} hook", HookTypes::AfterEach);
420 hook(&*self).instrument(span.or_current()).await?;
421 }
422 }
423
424 Ok(())
425 }
426
427 async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<()> {
428 debug!("LLM returned tool calls: {:?}", tool_calls);
429
430 let mut handles = vec![];
431 for tool_call in tool_calls {
432 let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
433 tracing::warn!("Tool {} not found", tool_call.name());
434 continue;
435 };
436 tracing::info!("Calling tool `{}`", tool_call.name());
437
438 let tool_args = tool_call.args().map(String::from);
439 let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
440
441 for hook in self.hooks_by_type(HookTypes::BeforeTool) {
442 if let Hook::BeforeTool(hook) = hook {
443 let span = tracing::info_span!(
444 "hook",
445 "otel.name" = format!("hook.{}", HookTypes::BeforeTool)
446 );
447 tracing::info!("Calling {} hook", HookTypes::BeforeTool);
448 hook(&*self, &tool_call)
449 .instrument(span.or_current())
450 .await?;
451 }
452 }
453
454 let tool_span = tracing::info_span!(
455 "tool",
456 "otel.name" = format!("tool.{}", tool.name().as_ref())
457 );
458
459 let handle = tokio::spawn(async move {
460 let tool_args = ArgPreprocessor::preprocess(tool_args.as_deref());
461 let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
462
463 tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name = tool.name().as_ref(), "Completed tool call");
464
465 Ok(output)
466 }.instrument(tool_span.or_current()));
467
468 handles.push((handle, tool_call));
469 }
470
471 for (handle, tool_call) in handles {
472 let mut output = handle.await?;
473
474 for hook in self.hooks_by_type(HookTypes::AfterTool) {
476 if let Hook::AfterTool(hook) = hook {
477 let span = tracing::info_span!(
478 "hook",
479 "otel.name" = format!("hook.{}", HookTypes::AfterTool)
480 );
481 tracing::info!("Calling {} hook", HookTypes::AfterTool);
482 hook(&*self, &tool_call, &mut output)
483 .instrument(span.or_current())
484 .await?;
485 }
486 }
487
488 if let Err(error) = output {
489 if self.tool_calls_over_limit(&tool_call) {
490 tracing::error!(
491 "Tool call failed, retry limit reached, stopping agent: {err}",
492 err = error
493 );
494 self.stop();
495 return Err(error.into());
496 }
497 tracing::warn!(
498 error = error.to_string(),
499 tool_call = ?tool_call,
500 "Tool call failed, retrying",
501 );
502 self.add_message(ChatMessage::ToolOutput(
503 tool_call,
504 ToolOutput::Fail(error.to_string()),
505 ))
506 .await?;
507 continue;
508 }
509
510 let output = output?;
511 self.handle_control_tools(&output);
512 self.add_message(ChatMessage::ToolOutput(tool_call, output))
513 .await?;
514 }
515
516 Ok(())
517 }
518
519 fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
520 self.hooks
521 .iter()
522 .filter(|h| hook_type == (*h).into())
523 .collect()
524 }
525
526 fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
527 self.tools
528 .iter()
529 .find(|tool| tool.name() == tool_name)
530 .cloned()
531 }
532
533 fn handle_control_tools(&mut self, output: &ToolOutput) {
535 if let ToolOutput::Stop = output {
536 tracing::warn!("Stop tool called, stopping agent");
537 self.stop();
538 }
539 }
540
541 fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
542 let mut s = DefaultHasher::new();
543 tool_call.hash(&mut s);
544 let hash = s.finish();
545
546 if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
547 let val = *retries >= self.tool_retry_limit;
548 *retries += 1;
549 val
550 } else {
551 self.tool_retries_counter.insert(hash, 1);
552 false
553 }
554 }
555
556 #[tracing::instrument(skip_all, fields(message = message.to_string()))]
557 async fn add_message(&self, mut message: ChatMessage) -> Result<()> {
558 for hook in self.hooks_by_type(HookTypes::OnNewMessage) {
559 if let Hook::OnNewMessage(hook) = hook {
560 let span = tracing::info_span!(
561 "hook",
562 "otel.name" = format!("hook.{}", HookTypes::OnNewMessage)
563 );
564 if let Err(err) = hook(self, &mut message).instrument(span.or_current()).await {
565 tracing::error!(
566 "Error in {hooktype} hook: {err}",
567 hooktype = HookTypes::OnNewMessage,
568 );
569 }
570 }
571 }
572 self.context.add_message(message).await;
573 Ok(())
574 }
575
576 pub fn stop(&mut self) {
578 self.state = state::State::Stopped;
579 }
580
581 pub fn context(&self) -> &dyn AgentContext {
583 &self.context
584 }
585
586 pub fn is_running(&self) -> bool {
588 self.state.is_running()
589 }
590
591 pub fn is_stopped(&self) -> bool {
593 self.state.is_stopped()
594 }
595
596 pub fn is_pending(&self) -> bool {
598 self.state.is_pending()
599 }
600}
601
#[cfg(test)]
mod tests {

    use serde::ser::Error;
    use swiftide_core::chat_completion::errors::ToolError;
    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
    use swiftide_core::test_utils::MockChatCompletion;

    use super::*;
    use crate::{
        assistant, chat_request, chat_response, summary, system, tool_failed, tool_output, user,
    };

    use crate::test_utils::{MockHook, MockTool};

    /// The builder always installs the `stop` tool, and duplicate tools collapse into one.
    #[test_log::test(tokio::test)]
    async fn test_agent_builder_defaults() {
        let mock_llm = MockChatCompletion::new();

        // Default build: only the default tools.
        let agent = Agent::builder().llm(&mock_llm).build().unwrap();

        assert!(agent.find_tool_by_name("stop").is_some());

        // Duplicate tools (including a duplicate of a default tool) collapse.
        let agent = Agent::builder()
            .tools([Stop::default(), Stop::default()])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 1);

        // Custom tools are added next to the defaults.
        let agent = Agent::builder()
            .tools([MockTool::new("mock_tool")])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 2);
        assert!(agent.find_tool_by_name("mock_tool").is_some());
        assert!(agent.find_tool_by_name("stop").is_some());

        assert!(agent.context().history().await.is_empty());
    }

    /// The agent keeps calling the LLM and tools until the LLM calls `stop`.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_calling_loop() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    /// `query_once` runs exactly one completion round, including its tool calls.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    /// Multiple tool calls in one response are all invoked and recorded in request order.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_multiple_tool_calls() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool1");
        let mock_tool2 = MockTool::new("mock_tool2");

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";

            tool_calls = ["mock_tool1", "mock_tool2"]
        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_tool2.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
            tool_output!("mock_tool1", "Great!"),
            tool_output!("mock_tool2", "Great!");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Ok!";

            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool, mock_tool2])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    /// A new agent is pending and is stopped after a run completes.
    #[test_log::test(tokio::test)]
    async fn test_agent_state_machine() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = []
        };
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        assert!(agent.state.is_pending());
        agent.query_once(prompt).await.unwrap();

        assert!(agent.state.is_stopped());
    }

    /// A summary message replaces earlier history in subsequent completion requests.
    #[test_log::test(tokio::test)]
    async fn test_summary() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = []
        };

        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        let mut agent = Agent::builder()
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary"),
            user!("Write another poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        agent.query_once("Write another poem").await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary 2"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary 2"),
            user!("Write a third poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));

        agent.query_once("Write a third poem").await.unwrap();
    }

    /// Every hook type is invoked the expected number of times over a two-round run.
    #[test_log::test(tokio::test)]
    async fn test_agent_hooks() {
        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
        let mock_before_completion = MockHook::new("before_completion")
            .expect_calls(2)
            .to_owned();
        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();

        // Once for mock_tool and once for stop.
        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();

        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .before_all(mock_before_all.hook_fn())
            .on_start(mock_on_start_fn.on_start_fn())
            .before_completion(mock_before_completion.before_completion_fn())
            .before_tool(mock_before_tool.before_tool_fn())
            .after_completion(mock_after_completion.after_completion_fn())
            .after_tool(mock_after_tool.after_tool_fn())
            .after_each(mock_after_each.hook_fn())
            .on_new_message(mock_on_message.message_hook_fn())
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    /// With `limit(1)` the agent stops after one round, leaving the second LLM call unused.
    #[test_log::test(tokio::test)]
    async fn test_agent_loop_limit() {
        let prompt = "Generate content";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mock_tool_response = chat_response! {
            "Some response";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));

        // Would only be consumed if the loop limit were ignored.
        let stop_response = chat_response! {
            "Final response";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .limit(1)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        let remaining = mock_llm.expectations.lock().unwrap().pop();
        assert!(remaining.is_some());

        assert!(agent.is_stopped());
    }

    /// A tool that keeps failing is retried up to the limit, then the agent stops with the error.
    #[test_log::test(tokio::test)]
    async fn test_tool_retry_mechanism() {
        let prompt = "Execute my tool";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("retry_tool");

        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        let retry_response = chat_response! {
            "First failing attempt";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));

        let chat_request = chat_request! {
            user!(prompt),
            assistant!("First failing attempt", ["retry_tool"]),
            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");

            tools = [mock_tool.clone()]
        };
        let will_fail_response = chat_response! {
            "Finished execution";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .tool_retry_limit(1)
            .build()
            .unwrap();

        let result = agent.query(prompt).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing `query`"));
        assert!(agent.is_stopped());
    }
}