1#![allow(dead_code)]
2use crate::{
3 default_context::DefaultContext,
4 hooks::{
5 AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
6 Hook, HookTypes, MessageHookFn, OnStartFn,
7 },
8 state,
9 system_prompt::SystemPrompt,
10 tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
11};
12use std::{
13 collections::{HashMap, HashSet},
14 hash::{DefaultHasher, Hash as _, Hasher as _},
15 sync::Arc,
16};
17
18use anyhow::Result;
19use derive_builder::Builder;
20use swiftide_core::{
21 chat_completion::{
22 ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
23 },
24 prompt::Prompt,
25 AgentContext,
26};
27use tracing::{debug, Instrument};
28
/// An agent that repeatedly asks a chat completion model for a response, invokes the tools the
/// model requests, and calls user supplied hooks along the way.
///
/// Instances are assembled with the generated `AgentBuilder`; see `Agent::builder`.
#[derive(Clone, Builder)]
pub struct Agent {
    /// Hooks registered on the agent. Each phase of a run looks up the hooks of its own type.
    #[builder(default, setter(into))]
    pub(crate) hooks: Vec<Hook>,
    /// Shared context that stores the messages and is handed to every tool invocation.
    ///
    /// Defaults to a `DefaultContext`.
    #[builder(
        setter(custom),
        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
    )]
    pub(crate) context: Arc<dyn AgentContext>,
    /// Tools whose specs are sent with every completion request.
    ///
    /// Defaults to `Agent::default_tools()`.
    #[builder(default = Agent::default_tools(), setter(custom))]
    pub(crate) tools: HashSet<Box<dyn Tool>>,

    /// The chat completion backend used to generate responses. Has no default.
    #[builder(setter(custom))]
    pub(crate) llm: Box<dyn ChatCompletion>,

    /// Optional system prompt, rendered and added to the context while the agent is still
    /// pending (i.e. on its first run).
    ///
    /// Defaults to `SystemPrompt::default()`.
    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
    pub(crate) system_prompt: Option<Prompt>,

    /// Lifecycle state of the agent (pending, running or stopped).
    #[builder(private, default = state::State::default())]
    pub(crate) state: state::State,

    /// Optional cap on the number of completion iterations in a single run; `None` means no cap.
    #[builder(default, setter(strip_option))]
    pub(crate) limit: Option<usize>,

    /// Threshold of recorded failures for one and the same tool call after which the agent
    /// stops with an error instead of reporting the failure back to the model. Defaults to 3.
    #[builder(default = 3)]
    pub(crate) tool_retry_limit: usize,

    /// Number of recorded failures per tool call, keyed by the hash of the `ToolCall`.
    #[builder(private, default)]
    pub(crate) tool_retries_counter: HashMap<u64, usize>,
}
107
108impl std::fmt::Debug for Agent {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 f.debug_struct("Agent")
111 .field(
113 "hooks",
114 &self
115 .hooks
116 .iter()
117 .map(std::string::ToString::to_string)
118 .collect::<Vec<_>>(),
119 )
120 .field(
121 "tools",
122 &self
123 .tools
124 .iter()
125 .map(swiftide_core::Tool::name)
126 .collect::<Vec<_>>(),
127 )
128 .field("llm", &"Box<dyn ChatCompletion>")
129 .field("state", &self.state)
130 .finish()
131 }
132}
133
134impl AgentBuilder {
135 pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
137 where
138 Self: Clone,
139 {
140 self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
141 self
142 }
143
144 pub fn no_system_prompt(&mut self) -> &mut Self {
146 self.system_prompt = Some(None);
147
148 self
149 }
150
151 pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
153 let hooks = self.hooks.get_or_insert_with(Vec::new);
154 hooks.push(hook);
155
156 self
157 }
158
159 pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
162 self.add_hook(Hook::BeforeAll(Box::new(hook)))
163 }
164
165 pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
169 self.add_hook(Hook::OnStart(Box::new(hook)))
170 }
171
172 pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
174 self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
175 }
176
177 pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
183 self.add_hook(Hook::AfterTool(Box::new(hook)))
184 }
185
186 pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
188 self.add_hook(Hook::BeforeTool(Box::new(hook)))
189 }
190
191 pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
193 self.add_hook(Hook::AfterCompletion(Box::new(hook)))
194 }
195
196 pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
199 self.add_hook(Hook::AfterEach(Box::new(hook)))
200 }
201
202 pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
205 self.add_hook(Hook::OnNewMessage(Box::new(hook)))
206 }
207
208 pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
210 let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
211
212 self.llm = Some(boxed);
213 self
214 }
215
216 pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
221 where
222 TOOL: Into<Box<dyn Tool>>,
223 {
224 self.tools = Some(
225 tools
226 .into_iter()
227 .map(Into::into)
228 .chain(Agent::default_tools())
229 .collect(),
230 );
231 self
232 }
233}
234
impl Agent {
    /// Returns a fresh `AgentBuilder` with nothing configured yet.
    pub fn builder() -> AgentBuilder {
        AgentBuilder::default()
    }
}
241
242impl Agent {
243 fn default_tools() -> HashSet<Box<dyn Tool>> {
245 HashSet::from([Box::new(Stop::default()) as Box<dyn Tool>])
246 }
247
248 #[tracing::instrument(skip_all, name = "agent.query")]
251 pub async fn query(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
252 self.run_agent(Some(query.into()), false).await
253 }
254
255 #[tracing::instrument(skip_all, name = "agent.query_once")]
257 pub async fn query_once(&mut self, query: impl Into<String> + std::fmt::Debug) -> Result<()> {
258 self.run_agent(Some(query.into()), true).await
259 }
260
261 #[tracing::instrument(skip_all, name = "agent.run")]
264 pub async fn run(&mut self) -> Result<()> {
265 self.run_agent(None, false).await
266 }
267
268 #[tracing::instrument(skip_all, name = "agent.run_once")]
270 pub async fn run_once(&mut self) -> Result<()> {
271 self.run_agent(None, true).await
272 }
273
274 pub async fn history(&self) -> Vec<ChatMessage> {
276 self.context.history().await
277 }
278
279 async fn run_agent(&mut self, maybe_query: Option<String>, just_once: bool) -> Result<()> {
280 if self.state.is_running() {
281 anyhow::bail!("Agent is already running");
282 }
283
284 if self.state.is_pending() {
285 if let Some(system_prompt) = &self.system_prompt {
286 self.context
287 .add_messages(vec![ChatMessage::System(system_prompt.render().await?)])
288 .await;
289 }
290 for hook in self.hooks_by_type(HookTypes::BeforeAll) {
291 if let Hook::BeforeAll(hook) = hook {
292 let span = tracing::info_span!(
293 "hook",
294 "otel.name" = format!("hook.{}", HookTypes::AfterTool)
295 );
296 tracing::info!("Calling {} hook", HookTypes::AfterTool);
297 hook(self).instrument(span.or_current()).await?;
298 }
299 }
300 }
301
302 for hook in self.hooks_by_type(HookTypes::OnStart) {
303 if let Hook::OnStart(hook) = hook {
304 let span = tracing::info_span!(
305 "hook",
306 "otel.name" = format!("hook.{}", HookTypes::OnStart)
307 );
308 tracing::info!("Calling {} hook", HookTypes::OnStart);
309 hook(self).instrument(span.or_current()).await?;
310 }
311 }
312
313 self.state = state::State::Running;
314
315 if let Some(query) = maybe_query {
316 self.context.add_message(ChatMessage::User(query)).await;
317 }
318
319 let mut loop_counter = 0;
320
321 while let Some(messages) = self.context.next_completion().await {
322 if let Some(limit) = self.limit {
323 if loop_counter >= limit {
324 tracing::warn!("Agent loop limit reached");
325 break;
326 }
327 }
328 let result = self.run_completions(&messages).await;
329
330 if let Err(err) = result {
331 self.stop();
332 tracing::error!(error = ?err, "Agent stopped with error {err}");
333 return Err(err);
334 }
335
336 if just_once || self.state.is_stopped() {
337 break;
338 }
339 loop_counter += 1;
340 }
341
342 self.stop();
344
345 Ok(())
346 }
347
348 #[tracing::instrument(skip_all, err)]
349 async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<()> {
350 debug!(
351 "Running completion for agent with {} messages",
352 messages.len()
353 );
354
355 let mut chat_completion_request = ChatCompletionRequest::builder()
356 .messages(messages)
357 .tools_spec(
358 self.tools
359 .iter()
360 .map(swiftide_core::Tool::tool_spec)
361 .collect::<HashSet<_>>(),
362 )
363 .build()?;
364
365 for hook in self.hooks_by_type(HookTypes::BeforeCompletion) {
366 if let Hook::BeforeCompletion(hook) = hook {
367 let span = tracing::info_span!(
368 "hook",
369 "otel.name" = format!("hook.{}", HookTypes::BeforeCompletion)
370 );
371 tracing::info!("Calling {} hook", HookTypes::BeforeCompletion);
372 hook(&*self, &mut chat_completion_request)
373 .instrument(span.or_current())
374 .await?;
375 }
376 }
377
378 debug!(
379 "Calling LLM with the following new messages:\n {}",
380 self.context
381 .current_new_messages()
382 .await
383 .iter()
384 .map(ToString::to_string)
385 .collect::<Vec<_>>()
386 .join(",\n")
387 );
388
389 let mut response = self.llm.complete(&chat_completion_request).await?;
390
391 for hook in self.hooks_by_type(HookTypes::AfterCompletion) {
392 if let Hook::AfterCompletion(hook) = hook {
393 let span = tracing::info_span!(
394 "hook",
395 "otel.name" = format!("hook.{}", HookTypes::AfterCompletion)
396 );
397 tracing::info!("Calling {} hook", HookTypes::AfterCompletion);
398 hook(&*self, &mut response)
399 .instrument(span.or_current())
400 .await?;
401 }
402 }
403 self.add_message(ChatMessage::Assistant(
404 response.message,
405 response.tool_calls.clone(),
406 ))
407 .await?;
408
409 if let Some(tool_calls) = response.tool_calls {
410 self.invoke_tools(tool_calls).await?;
411 };
412
413 for hook in self.hooks_by_type(HookTypes::AfterEach) {
414 if let Hook::AfterEach(hook) = hook {
415 let span = tracing::info_span!(
416 "hook",
417 "otel.name" = format!("hook.{}", HookTypes::AfterEach)
418 );
419 tracing::info!("Calling {} hook", HookTypes::AfterEach);
420 hook(&*self).instrument(span.or_current()).await?;
421 }
422 }
423
424 Ok(())
425 }
426
427 async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<()> {
428 debug!("LLM returned tool calls: {:?}", tool_calls);
429
430 let mut handles = vec![];
431 for tool_call in tool_calls {
432 let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
433 tracing::warn!("Tool {} not found", tool_call.name());
434 continue;
435 };
436 tracing::info!("Calling tool `{}`", tool_call.name());
437
438 let tool_args = tool_call.args().map(String::from);
439 let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
440
441 for hook in self.hooks_by_type(HookTypes::BeforeTool) {
442 if let Hook::BeforeTool(hook) = hook {
443 let span = tracing::info_span!(
444 "hook",
445 "otel.name" = format!("hook.{}", HookTypes::BeforeTool)
446 );
447 tracing::info!("Calling {} hook", HookTypes::BeforeTool);
448 hook(&*self, &tool_call)
449 .instrument(span.or_current())
450 .await?;
451 }
452 }
453
454 let tool_span =
455 tracing::info_span!("tool", "otel.name" = format!("tool.{}", tool.name()));
456
457 let handle = tokio::spawn(async move {
458 let tool_args = ArgPreprocessor::preprocess(tool_args.as_deref());
459 let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
460
461 tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name = tool.name(), "Completed tool call");
462
463 Ok(output)
464 }.instrument(tool_span.or_current()));
465
466 handles.push((handle, tool_call));
467 }
468
469 for (handle, tool_call) in handles {
470 let mut output = handle.await?;
471
472 for hook in self.hooks_by_type(HookTypes::AfterTool) {
474 if let Hook::AfterTool(hook) = hook {
475 let span = tracing::info_span!(
476 "hook",
477 "otel.name" = format!("hook.{}", HookTypes::AfterTool)
478 );
479 tracing::info!("Calling {} hook", HookTypes::AfterTool);
480 hook(&*self, &tool_call, &mut output)
481 .instrument(span.or_current())
482 .await?;
483 }
484 }
485
486 if let Err(error) = output {
487 if self.tool_calls_over_limit(&tool_call) {
488 tracing::error!(
489 "Tool call failed, retry limit reached, stopping agent: {err}",
490 err = error
491 );
492 self.stop();
493 return Err(error.into());
494 }
495 tracing::warn!(
496 error = error.to_string(),
497 tool_call = ?tool_call,
498 "Tool call failed, retrying",
499 );
500 self.add_message(ChatMessage::ToolOutput(
501 tool_call,
502 ToolOutput::Fail(error.to_string()),
503 ))
504 .await?;
505 continue;
506 }
507
508 let output = output?;
509 self.handle_control_tools(&output);
510 self.add_message(ChatMessage::ToolOutput(tool_call, output))
511 .await?;
512 }
513
514 Ok(())
515 }
516
517 fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
518 self.hooks
519 .iter()
520 .filter(|h| hook_type == (*h).into())
521 .collect()
522 }
523
524 fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
525 self.tools
526 .iter()
527 .find(|tool| tool.name() == tool_name)
528 .cloned()
529 }
530
531 fn handle_control_tools(&mut self, output: &ToolOutput) {
533 if let ToolOutput::Stop = output {
534 tracing::warn!("Stop tool called, stopping agent");
535 self.stop();
536 }
537 }
538
539 fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
540 let mut s = DefaultHasher::new();
541 tool_call.hash(&mut s);
542 let hash = s.finish();
543
544 if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
545 let val = *retries >= self.tool_retry_limit;
546 *retries += 1;
547 val
548 } else {
549 self.tool_retries_counter.insert(hash, 1);
550 false
551 }
552 }
553
554 #[tracing::instrument(skip_all, fields(message = message.to_string()))]
555 async fn add_message(&self, mut message: ChatMessage) -> Result<()> {
556 for hook in self.hooks_by_type(HookTypes::OnNewMessage) {
557 if let Hook::OnNewMessage(hook) = hook {
558 let span = tracing::info_span!(
559 "hook",
560 "otel.name" = format!("hook.{}", HookTypes::OnNewMessage)
561 );
562 if let Err(err) = hook(self, &mut message).instrument(span.or_current()).await {
563 tracing::error!(
564 "Error in {hooktype} hook: {err}",
565 hooktype = HookTypes::OnNewMessage,
566 );
567 }
568 }
569 }
570 self.context.add_message(message).await;
571 Ok(())
572 }
573
574 pub fn stop(&mut self) {
576 self.state = state::State::Stopped;
577 }
578
579 pub fn context(&self) -> &dyn AgentContext {
581 &self.context
582 }
583
584 pub fn is_running(&self) -> bool {
586 self.state.is_running()
587 }
588
589 pub fn is_stopped(&self) -> bool {
591 self.state.is_stopped()
592 }
593
594 pub fn is_pending(&self) -> bool {
596 self.state.is_pending()
597 }
598}
599
#[cfg(test)]
mod tests {

    use serde::ser::Error;
    use swiftide_core::chat_completion::errors::ToolError;
    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
    use swiftide_core::test_utils::MockChatCompletion;

    use super::*;
    use crate::{
        assistant, chat_request, chat_response, summary, system, tool_failed, tool_output, user,
    };

    use crate::test_utils::{MockHook, MockTool};

    /// The builder always provides the stop tool, deduplicates tools and starts with an empty
    /// history.
    #[test_log::test(tokio::test)]
    async fn test_agent_builder_defaults() {
        let mock_llm = MockChatCompletion::new();

        // Without any tools configured the stop tool is still present.
        let agent = Agent::builder().llm(&mock_llm).build().unwrap();

        assert!(agent.find_tool_by_name("stop").is_some());

        // Duplicate tools collapse into a single entry.
        let agent = Agent::builder()
            .tools([Stop::default(), Stop::default()])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 1);

        // Custom tools are added next to the default stop tool.
        let agent = Agent::builder()
            .tools([MockTool::new("mock_tool")])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 2);
        assert!(agent.find_tool_by_name("mock_tool").is_some());
        assert!(agent.find_tool_by_name("stop").is_some());

        assert!(agent.context().history().await.is_empty());
    }

    /// The agent calls the requested tool, feeds its output back to the LLM and stops when the
    /// LLM calls the stop tool.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_calling_loop() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    /// `query_once` runs a single completion cycle, including the requested tool call.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    /// Several tool calls in one response are all invoked and their outputs are added in the
    /// order of the calls.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_multiple_tool_calls() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool1");
        let mock_tool2 = MockTool::new("mock_tool2");

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";

            tool_calls = ["mock_tool1", "mock_tool2"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_tool2.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
            tool_output!("mock_tool1", "Great!"),
            tool_output!("mock_tool2", "Great!");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Ok!";

            tool_calls = ["stop"]

        };

        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool, mock_tool2])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    /// A freshly built agent is pending and is stopped after a run.
    #[test_log::test(tokio::test)]
    async fn test_agent_state_machine() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = []
        };
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        // Nothing has run yet.
        assert!(agent.state.is_pending());
        agent.query_once(prompt).await.unwrap();

        // A finished run leaves the agent stopped.
        assert!(agent.state.is_stopped());
    }

    /// After a summary is added to the context, the next request contains the system prompt,
    /// the latest summary and only the messages added after it.
    #[test_log::test(tokio::test)]
    async fn test_summary() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []

        };

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = []
        };

        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        let mut agent = Agent::builder()
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary"),
            user!("Write another poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        agent.query_once("Write another poem").await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary 2"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary 2"),
            user!("Write a third poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));

        agent.query_once("Write a third poem").await.unwrap();
    }

    /// Every hook type is called the expected number of times during a two cycle run.
    #[test_log::test(tokio::test)]
    async fn test_agent_hooks() {
        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
        let mock_before_completion = MockHook::new("before_completion")
            .expect_calls(2)
            .to_owned();
        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();

        // Two tool calls are made in total: the mock tool and the stop tool.
        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();

        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .before_all(mock_before_all.hook_fn())
            .on_start(mock_on_start_fn.on_start_fn())
            .before_completion(mock_before_completion.before_completion_fn())
            .before_tool(mock_before_tool.before_tool_fn())
            .after_completion(mock_after_completion.after_completion_fn())
            .after_tool(mock_after_tool.after_tool_fn())
            .after_each(mock_after_each.hook_fn())
            .on_new_message(mock_on_message.message_hook_fn())
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    /// With a loop limit of 1 the agent stops after the first completion cycle, leaving the
    /// second LLM expectation unused.
    #[test_log::test(tokio::test)]
    async fn test_agent_loop_limit() {
        let prompt = "Generate content";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mock_tool_response = chat_response! {
            "Some response";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));

        // This response must never be requested because the limit is hit first.
        let stop_response = chat_response! {
            "Final response";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .limit(1)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        // One expectation is still queued, so only one completion was made.
        let remaining = mock_llm.expectations.lock().unwrap().pop();
        assert!(remaining.is_some());

        assert!(agent.is_stopped());
    }

    /// With a retry limit of 1, the first failure of a tool call is reported back to the LLM
    /// and the second failure of the same call stops the agent with the tool error.
    #[test_log::test(tokio::test)]
    async fn test_tool_retry_mechanism() {
        let prompt = "Execute my tool";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("retry_tool");

        // The tool fails on both invocations.
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        let retry_response = chat_response! {
            "First failing attempt";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));

        let chat_request = chat_request! {
            user!(prompt),
            assistant!("First failing attempt", ["retry_tool"]),
            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");

            tools = [mock_tool.clone()]
        };
        let will_fail_response = chat_response! {
            "Finished execution";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .tool_retry_limit(1)
            .build()
            .unwrap();

        let result = agent.query(prompt).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing `query`"));
        assert!(agent.is_stopped());
    }
}