1#![allow(dead_code)]
2use crate::{
3 default_context::DefaultContext,
4 errors::AgentError,
5 hooks::{
6 AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7 Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn,
8 },
9 invoke_hooks,
10 state::{self, StopReason},
11 system_prompt::SystemPrompt,
12 tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15 collections::{HashMap, HashSet},
16 hash::{DefaultHasher, Hash as _, Hasher as _},
17 sync::Arc,
18};
19
20use derive_builder::Builder;
21use swiftide_core::{
22 chat_completion::{
23 ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
24 },
25 prompt::Prompt,
26 AgentContext, ToolBox,
27};
28use tracing::{debug, Instrument};
29
#[derive(Clone, Builder)]
/// An agent that sends its context's messages to an LLM, executes the tool calls the LLM
/// returns, and repeats until it is stopped, runs out of new messages, or hits its loop limit.
///
/// Construct one with [`Agent::builder`]; the `llm` field has no default and must be set.
pub struct Agent {
    /// Hooks invoked at points of the agent's lifecycle. Registered through the builder's
    /// hook helpers (`before_all`, `on_start`, `on_stop`, ...) or `add_hook`.
    #[builder(default, setter(into))]
    pub(crate) hooks: Vec<Hook>,
    /// Context the agent reads completions from and writes messages to. It is shared (via
    /// `Arc`) with every spawned tool call. Defaults to a `DefaultContext`.
    #[builder(
        setter(custom),
        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
    )]
    pub(crate) context: Arc<dyn AgentContext>,
    /// Tools the LLM may call. Always contains the default `Stop` tool, also when tools are
    /// set through the builder.
    #[builder(default = Agent::default_tools(), setter(custom))]
    pub(crate) tools: HashSet<Box<dyn Tool>>,

    /// Toolboxes whose tools are loaded into `tools` on the agent's first run.
    #[builder(default)]
    pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,

    /// The chat completion model used for every completion request. Required.
    #[builder(setter(custom))]
    pub(crate) llm: Box<dyn ChatCompletion>,

    /// Prompt rendered and added as a system message on the agent's first run.
    /// Defaults to `SystemPrompt::default()`; the builder's `no_system_prompt` removes it.
    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
    pub(crate) system_prompt: Option<Prompt>,

    /// Lifecycle state (pending, running, stopped); managed by the agent itself.
    #[builder(private, default = state::State::default())]
    pub(crate) state: state::State,

    /// Maximum number of completion-loop iterations per run; `None` means no limit.
    #[builder(default, setter(strip_option))]
    pub(crate) limit: Option<usize>,

    /// The agent stops once the same tool call (compared by its hash) has failed more than
    /// this many times. Defaults to 3.
    #[builder(default = 3)]
    pub(crate) tool_retry_limit: usize,

    /// Failure counts per tool call, keyed by the hash of the `ToolCall`.
    #[builder(private, default)]
    pub(crate) tool_retries_counter: HashMap<u64, usize>,

    /// Tools loaded from `toolboxes`, kept separately in addition to being merged into `tools`.
    #[builder(private, default)]
    pub(crate) toolbox_tools: HashSet<Box<dyn Tool>>,
}
120
121impl std::fmt::Debug for Agent {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 f.debug_struct("Agent")
124 .field(
126 "hooks",
127 &self
128 .hooks
129 .iter()
130 .map(std::string::ToString::to_string)
131 .collect::<Vec<_>>(),
132 )
133 .field(
134 "tools",
135 &self
136 .tools
137 .iter()
138 .map(swiftide_core::Tool::name)
139 .collect::<Vec<_>>(),
140 )
141 .field("llm", &"Box<dyn ChatCompletion>")
142 .field("state", &self.state)
143 .finish()
144 }
145}
146
147impl AgentBuilder {
148 pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
150 where
151 Self: Clone,
152 {
153 self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
154 self
155 }
156
157 pub fn no_system_prompt(&mut self) -> &mut Self {
159 self.system_prompt = Some(None);
160
161 self
162 }
163
164 pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
166 let hooks = self.hooks.get_or_insert_with(Vec::new);
167 hooks.push(hook);
168
169 self
170 }
171
172 pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
175 self.add_hook(Hook::BeforeAll(Box::new(hook)))
176 }
177
178 pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
182 self.add_hook(Hook::OnStart(Box::new(hook)))
183 }
184
185 pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
187 self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
188 }
189
190 pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
196 self.add_hook(Hook::AfterTool(Box::new(hook)))
197 }
198
199 pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
201 self.add_hook(Hook::BeforeTool(Box::new(hook)))
202 }
203
204 pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
206 self.add_hook(Hook::AfterCompletion(Box::new(hook)))
207 }
208
209 pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
212 self.add_hook(Hook::AfterEach(Box::new(hook)))
213 }
214
215 pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
218 self.add_hook(Hook::OnNewMessage(Box::new(hook)))
219 }
220
221 pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
222 self.add_hook(Hook::OnStop(Box::new(hook)))
223 }
224
225 pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
227 let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
228
229 self.llm = Some(boxed);
230 self
231 }
232
233 pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
238 where
239 TOOL: Into<Box<dyn Tool>>,
240 {
241 self.tools = Some(
242 tools
243 .into_iter()
244 .map(Into::into)
245 .chain(Agent::default_tools())
246 .collect(),
247 );
248 self
249 }
250
251 pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
257 self.toolboxes.get_or_insert_with(Vec::new);
258
259 self.toolboxes.as_mut().unwrap().push(Box::new(toolbox));
260 self
261 }
262}
263
264impl Agent {
265 pub fn builder() -> AgentBuilder {
267 AgentBuilder::default()
268 }
269}
270
271impl Agent {
272 fn default_tools() -> HashSet<Box<dyn Tool>> {
274 HashSet::from([Box::new(Stop::default()) as Box<dyn Tool>])
275 }
276
277 #[tracing::instrument(skip_all, name = "agent.query")]
280 pub async fn query(
281 &mut self,
282 query: impl Into<String> + std::fmt::Debug,
283 ) -> Result<(), AgentError> {
284 self.run_agent(Some(query.into()), false).await
285 }
286
287 #[tracing::instrument(skip_all, name = "agent.query_once")]
289 pub async fn query_once(
290 &mut self,
291 query: impl Into<String> + std::fmt::Debug,
292 ) -> Result<(), AgentError> {
293 self.run_agent(Some(query.into()), true).await
294 }
295
296 #[tracing::instrument(skip_all, name = "agent.run")]
299 pub async fn run(&mut self) -> Result<(), AgentError> {
300 self.run_agent(None, false).await
301 }
302
303 #[tracing::instrument(skip_all, name = "agent.run_once")]
306 pub async fn run_once(&mut self) -> Result<(), AgentError> {
307 self.run_agent(None, true).await
308 }
309
310 pub async fn history(&self) -> Vec<ChatMessage> {
312 self.context.history().await
313 }
314
315 async fn run_agent(
316 &mut self,
317 maybe_query: Option<String>,
318 just_once: bool,
319 ) -> Result<(), AgentError> {
320 if self.state.is_running() {
321 return Err(AgentError::AlreadyRunning);
322 }
323
324 if self.state.is_pending() {
325 if let Some(system_prompt) = &self.system_prompt {
326 self.context
327 .add_messages(vec![ChatMessage::System(
328 system_prompt
329 .render()
330 .map_err(AgentError::FailedToRenderSystemPrompt)?,
331 )])
332 .await;
333 }
334
335 invoke_hooks!(BeforeAll, self);
336
337 self.load_toolboxes().await?;
338 }
339
340 invoke_hooks!(OnStart, self);
341
342 self.state = state::State::Running;
343
344 if let Some(query) = maybe_query {
345 self.context.add_message(ChatMessage::User(query)).await;
346 }
347
348 let mut loop_counter = 0;
349
350 while let Some(messages) = self.context.next_completion().await {
351 if let Some(limit) = self.limit {
352 if loop_counter >= limit {
353 tracing::warn!("Agent loop limit reached");
354 break;
355 }
356 }
357 let result = self.run_completions(&messages).await;
358
359 if let Err(err) = result {
360 self.stop_with_error(&err).await;
361 tracing::error!(error = ?err, "Agent stopped with error {err}");
362 return Err(err);
363 }
364
365 if just_once || self.state.is_stopped() {
366 break;
367 }
368 loop_counter += 1;
369 }
370
371 self.stop(StopReason::NoNewMessages).await;
373
374 Ok(())
375 }
376
377 #[tracing::instrument(skip_all, err)]
378 async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<(), AgentError> {
379 debug!(
380 "Running completion for agent with {} messages",
381 messages.len()
382 );
383
384 let mut chat_completion_request = ChatCompletionRequest::builder()
385 .messages(messages)
386 .tools_spec(
387 self.tools
388 .iter()
389 .map(swiftide_core::Tool::tool_spec)
390 .collect::<HashSet<_>>(),
391 )
392 .build()
393 .map_err(AgentError::FailedToBuildRequest)?;
394
395 invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
396
397 debug!(
398 "Calling LLM with the following new messages:\n {}",
399 self.context
400 .current_new_messages()
401 .await
402 .iter()
403 .map(ToString::to_string)
404 .collect::<Vec<_>>()
405 .join(",\n")
406 );
407
408 let mut response = self
409 .llm
410 .complete(&chat_completion_request)
411 .await
412 .map_err(AgentError::CompletionsFailed)?;
413
414 invoke_hooks!(AfterCompletion, self, &mut response);
415
416 self.add_message(ChatMessage::Assistant(
417 response.message,
418 response.tool_calls.clone(),
419 ))
420 .await?;
421
422 if let Some(tool_calls) = response.tool_calls {
423 self.invoke_tools(tool_calls).await?;
424 }
425
426 invoke_hooks!(AfterEach, self);
427
428 Ok(())
429 }
430
431 async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<(), AgentError> {
432 debug!("LLM returned tool calls: {:?}", tool_calls);
433
434 let mut handles = vec![];
435 for tool_call in tool_calls {
436 let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
437 tracing::warn!("Tool {} not found", tool_call.name());
438 continue;
439 };
440 tracing::info!("Calling tool `{}`", tool_call.name());
441
442 let tool_args = tool_call.args().map(String::from);
443 let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
444
445 invoke_hooks!(BeforeTool, self, &tool_call);
446
447 let tool_span = tracing::info_span!(
448 "tool",
449 "otel.name" = format!("tool.{}", tool.name().as_ref())
450 );
451
452 let handle = tokio::spawn(async move {
453 let tool_args = ArgPreprocessor::preprocess(tool_args.as_deref());
454 let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
455
456 tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name = tool.name().as_ref(), "Completed tool call");
457
458 Ok(output)
459 }.instrument(tool_span.or_current()));
460
461 handles.push((handle, tool_call));
462 }
463
464 for (handle, tool_call) in handles {
465 let mut output = handle.await.map_err(AgentError::ToolFailedToJoin)?;
466
467 invoke_hooks!(AfterTool, self, &tool_call, &mut output);
468
469 if let Err(error) = output {
470 let stop = self.tool_calls_over_limit(&tool_call);
471 if stop {
472 tracing::error!(
473 ?error,
474 "Tool call failed, retry limit reached, stopping agent: {error}",
475 );
476 } else {
477 tracing::warn!(
478 ?error,
479 tool_call = ?tool_call,
480 "Tool call failed, retrying",
481 );
482 }
483 self.add_message(ChatMessage::ToolOutput(
484 tool_call.clone(),
485 ToolOutput::Fail(error.to_string()),
486 ))
487 .await?;
488 if stop {
489 self.stop(StopReason::ToolCallsOverLimit(tool_call)).await;
490 return Err(error.into());
491 }
492 continue;
493 }
494
495 let output = output?;
496 self.handle_control_tools(&tool_call, &output).await;
497 self.add_message(ChatMessage::ToolOutput(tool_call, output))
498 .await?;
499 }
500
501 Ok(())
502 }
503
504 fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
505 self.hooks
506 .iter()
507 .filter(|h| hook_type == (*h).into())
508 .collect()
509 }
510
511 fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
512 self.tools
513 .iter()
514 .find(|tool| tool.name() == tool_name)
515 .cloned()
516 }
517
518 async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
520 if let ToolOutput::Stop = output {
521 tracing::warn!("Stop tool called, stopping agent");
522 self.stop(StopReason::RequestedByTool(tool_call.clone()))
523 .await;
524 }
525 }
526
527 fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
528 let mut s = DefaultHasher::new();
529 tool_call.hash(&mut s);
530 let hash = s.finish();
531
532 if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
533 let val = *retries >= self.tool_retry_limit;
534 *retries += 1;
535 val
536 } else {
537 self.tool_retries_counter.insert(hash, 1);
538 false
539 }
540 }
541
542 #[tracing::instrument(skip_all, fields(message = message.to_string()))]
548 pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
549 invoke_hooks!(OnNewMessage, self, &mut message);
550
551 self.context.add_message(message).await;
552 Ok(())
553 }
554
555 pub async fn stop(&mut self, reason: impl Into<StopReason>) {
557 if self.state.is_stopped() {
558 return;
559 }
560 let reason = reason.into();
561 invoke_hooks!(OnStop, self, reason.clone(), None);
562
563 self.state = state::State::Stopped(reason);
564 }
565
566 pub async fn stop_with_error(&mut self, error: &AgentError) {
567 if self.state.is_stopped() {
568 return;
569 }
570 invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
571
572 self.state = state::State::Stopped(StopReason::Error);
573 }
574
575 pub fn context(&self) -> &dyn AgentContext {
577 &self.context
578 }
579
580 pub fn is_running(&self) -> bool {
582 self.state.is_running()
583 }
584
585 pub fn is_stopped(&self) -> bool {
587 self.state.is_stopped()
588 }
589
590 pub fn is_pending(&self) -> bool {
592 self.state.is_pending()
593 }
594
595 fn tools(&self) -> &HashSet<Box<dyn Tool>> {
597 &self.tools
598 }
599
600 async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
601 for toolbox in &self.toolboxes {
602 let tools = toolbox
603 .available_tools()
604 .await
605 .map_err(AgentError::ToolBoxFailedToLoad)?;
606 self.toolbox_tools.extend(tools);
607 }
608
609 self.tools.extend(self.toolbox_tools.clone());
610
611 Ok(())
612 }
613}
614
#[cfg(test)]
mod tests {

    use serde::ser::Error;
    use swiftide_core::chat_completion::errors::ToolError;
    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
    use swiftide_core::test_utils::MockChatCompletion;

    use super::*;
    use crate::{
        assistant, chat_request, chat_response, summary, system, tool_failed, tool_output, user,
    };

    use crate::test_utils::{MockHook, MockTool};

    /// The builder always includes the `stop` tool, collapses duplicate tools, and starts
    /// with an empty history.
    #[test_log::test(tokio::test)]
    async fn test_agent_builder_defaults() {
        let mock_llm = MockChatCompletion::new();

        let agent = Agent::builder().llm(&mock_llm).build().unwrap();

        assert!(agent.find_tool_by_name("stop").is_some());

        let agent = Agent::builder()
            .tools([Stop::default(), Stop::default()])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 1);

        let agent = Agent::builder()
            .tools([MockTool::new("mock_tool")])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 2);
        assert!(agent.find_tool_by_name("mock_tool").is_some());
        assert!(agent.find_tool_by_name("stop").is_some());

        assert!(agent.context().history().await.is_empty());
    }

    /// A tool call's output is fed back to the LLM, and the loop ends when it calls `stop`.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_calling_loop() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    /// `query_once` performs exactly one completion (with the system prompt) and its tool call.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    /// Tools provided through a toolbox are loaded and callable like regular tools.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_via_toolbox_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .add_toolbox(vec![mock_tool.boxed()])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    /// Several tool calls in one response are all executed and their outputs recorded in
    /// call order.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_multiple_tool_calls() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool1");
        let mock_tool2 = MockTool::new("mock_tool2");

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";

            tool_calls = ["mock_tool1", "mock_tool2"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_tool2.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
            tool_output!("mock_tool1", "Great!"),
            tool_output!("mock_tool2", "Great!");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Ok!";

            tool_calls = ["stop"]

        };

        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool, mock_tool2])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    /// The agent starts pending and is stopped after a single `query_once`.
    #[test_log::test(tokio::test)]
    async fn test_agent_state_machine() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = []
        };
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        assert!(agent.state.is_pending());
        agent.query_once(prompt).await.unwrap();

        assert!(agent.state.is_stopped());
    }

    /// After a summary message, later requests only contain the system prompt, the latest
    /// summary, and the messages that follow it.
    #[test_log::test(tokio::test)]
    async fn test_summary() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []

        };

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = []
        };

        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        let mut agent = Agent::builder()
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary"),
            user!("Write another poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        agent.query_once("Write another poem").await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary 2"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary 2"),
            user!("Write a third poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));

        agent.query_once("Write a third poem").await.unwrap();
    }

    /// Every hook type fires the expected number of times over a two-step run.
    #[test_log::test(tokio::test)]
    async fn test_agent_hooks() {
        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
        let mock_before_completion = MockHook::new("before_completion")
            .expect_calls(2)
            .to_owned();
        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
        let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();

        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();

        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .before_all(mock_before_all.hook_fn())
            .on_start(mock_on_start_fn.on_start_fn())
            .before_completion(mock_before_completion.before_completion_fn())
            .before_tool(mock_before_tool.before_tool_fn())
            .after_completion(mock_after_completion.after_completion_fn())
            .after_tool(mock_after_tool.after_tool_fn())
            .after_each(mock_after_each.hook_fn())
            .on_new_message(mock_on_message.message_hook_fn())
            .on_stop(mock_on_stop.stop_hook_fn())
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    /// With `limit(1)` only one completion runs; the second expected completion is left
    /// unused and the agent ends stopped.
    #[test_log::test(tokio::test)]
    async fn test_agent_loop_limit() {
        let prompt = "Generate content";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mock_tool_response = chat_response! {
            "Some response";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));

        let stop_response = chat_response! {
            "Final response";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .limit(1)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        let remaining = mock_llm.expectations.lock().unwrap().pop();
        assert!(remaining.is_some());

        assert!(agent.is_stopped());
    }

    /// A failing tool call is reported back to the LLM. Once the same call fails beyond
    /// `tool_retry_limit`, the agent stops and returns the tool's error.
    #[test_log::test(tokio::test)]
    async fn test_tool_retry_mechanism() {
        let prompt = "Execute my tool";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("retry_tool");

        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        let retry_response = chat_response! {
            "First failing attempt";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));

        let chat_request = chat_request! {
            user!(prompt),
            assistant!("First failing attempt", ["retry_tool"]),
            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");

            tools = [mock_tool.clone()]
        };
        let will_fail_response = chat_response! {
            "Finished execution";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .tool_retry_limit(1)
            .build()
            .unwrap();

        let result = agent.query(prompt).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing `query`"));
        assert!(agent.is_stopped());
    }
}