1#![allow(dead_code)]
2use crate::{
3 default_context::DefaultContext,
4 errors::AgentError,
5 hooks::{
6 AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7 Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn,
8 },
9 invoke_hooks,
10 state::{self, StopReason},
11 system_prompt::SystemPrompt,
12 tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15 collections::{HashMap, HashSet},
16 hash::{DefaultHasher, Hash as _, Hasher as _},
17 sync::Arc,
18};
19
20use derive_builder::Builder;
21use futures_util::stream::StreamExt;
22use swiftide_core::{
23 AgentContext, ToolBox,
24 chat_completion::{
25 ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
26 },
27 prompt::Prompt,
28};
29use tracing::{Instrument, debug};
30
31#[derive(Clone, Builder)]
43pub struct Agent {
44 #[builder(default, setter(into))]
46 pub(crate) hooks: Vec<Hook>,
47 #[builder(
49 setter(custom),
50 default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
51 )]
52 pub(crate) context: Arc<dyn AgentContext>,
53 #[builder(default = Agent::default_tools(), setter(custom))]
55 pub(crate) tools: HashSet<Box<dyn Tool>>,
56
57 #[builder(default)]
61 pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,
62
63 #[builder(setter(custom))]
65 pub(crate) llm: Box<dyn ChatCompletion>,
66
67 #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
87 pub(crate) system_prompt: Option<Prompt>,
88
89 #[builder(private, default = state::State::default())]
91 pub(crate) state: state::State,
92
93 #[builder(default, setter(strip_option))]
96 pub(crate) limit: Option<usize>,
97
98 #[builder(default = 3)]
110 pub(crate) tool_retry_limit: usize,
111
112 #[builder(default)]
114 pub(crate) streaming: bool,
115
116 #[builder(private, default)]
119 pub(crate) tool_retries_counter: HashMap<u64, usize>,
120
121 #[builder(private, default)]
123 pub(crate) toolbox_tools: HashSet<Box<dyn Tool>>,
124}
125
126impl std::fmt::Debug for Agent {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 f.debug_struct("Agent")
129 .field(
131 "hooks",
132 &self
133 .hooks
134 .iter()
135 .map(std::string::ToString::to_string)
136 .collect::<Vec<_>>(),
137 )
138 .field(
139 "tools",
140 &self
141 .tools
142 .iter()
143 .map(swiftide_core::Tool::name)
144 .collect::<Vec<_>>(),
145 )
146 .field("llm", &"Box<dyn ChatCompletion>")
147 .field("state", &self.state)
148 .finish()
149 }
150}
151
152impl AgentBuilder {
153 pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
155 where
156 Self: Clone,
157 {
158 self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
159 self
160 }
161
162 pub fn no_system_prompt(&mut self) -> &mut Self {
164 self.system_prompt = Some(None);
165
166 self
167 }
168
169 pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
171 let hooks = self.hooks.get_or_insert_with(Vec::new);
172 hooks.push(hook);
173
174 self
175 }
176
177 pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
180 self.add_hook(Hook::BeforeAll(Box::new(hook)))
181 }
182
183 pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
187 self.add_hook(Hook::OnStart(Box::new(hook)))
188 }
189
190 pub fn on_stream(&mut self, hook: impl OnStreamFn + 'static) -> &mut Self {
197 self.streaming = Some(true);
198 self.add_hook(Hook::OnStream(Box::new(hook)))
199 }
200
201 pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
203 self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
204 }
205
206 pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
212 self.add_hook(Hook::AfterTool(Box::new(hook)))
213 }
214
215 pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
217 self.add_hook(Hook::BeforeTool(Box::new(hook)))
218 }
219
220 pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
222 self.add_hook(Hook::AfterCompletion(Box::new(hook)))
223 }
224
225 pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
228 self.add_hook(Hook::AfterEach(Box::new(hook)))
229 }
230
231 pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
234 self.add_hook(Hook::OnNewMessage(Box::new(hook)))
235 }
236
237 pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
238 self.add_hook(Hook::OnStop(Box::new(hook)))
239 }
240
241 pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
243 let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
244
245 self.llm = Some(boxed);
246 self
247 }
248
249 pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
254 where
255 TOOL: Into<Box<dyn Tool>>,
256 {
257 self.tools = Some(
258 tools
259 .into_iter()
260 .map(Into::into)
261 .chain(Agent::default_tools())
262 .collect(),
263 );
264 self
265 }
266
267 pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
273 let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
274 toolboxes.push(Box::new(toolbox));
275
276 self
277 }
278}
279
280impl Agent {
281 pub fn builder() -> AgentBuilder {
283 AgentBuilder::default()
284 }
285}
286
287impl Agent {
288 fn default_tools() -> HashSet<Box<dyn Tool>> {
290 HashSet::from([Box::new(Stop::default()) as Box<dyn Tool>])
291 }
292
293 #[tracing::instrument(skip_all, name = "agent.query")]
296 pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
297 let query = query
298 .into()
299 .render()
300 .map_err(AgentError::FailedToRenderPrompt)?;
301 self.run_agent(Some(query), false).await
302 }
303
304 #[tracing::instrument(skip_all, name = "agent.query_once")]
306 pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
307 let query = query
308 .into()
309 .render()
310 .map_err(AgentError::FailedToRenderPrompt)?;
311 self.run_agent(Some(query), true).await
312 }
313
314 #[tracing::instrument(skip_all, name = "agent.run")]
317 pub async fn run(&mut self) -> Result<(), AgentError> {
318 self.run_agent(None, false).await
319 }
320
321 #[tracing::instrument(skip_all, name = "agent.run_once")]
324 pub async fn run_once(&mut self) -> Result<(), AgentError> {
325 self.run_agent(None, true).await
326 }
327
328 pub async fn history(&self) -> Vec<ChatMessage> {
330 self.context.history().await
331 }
332
333 async fn run_agent(
334 &mut self,
335 maybe_query: Option<String>,
336 just_once: bool,
337 ) -> Result<(), AgentError> {
338 if self.state.is_running() {
339 return Err(AgentError::AlreadyRunning);
340 }
341
342 if self.state.is_pending() {
343 if let Some(system_prompt) = &self.system_prompt {
344 self.context
345 .add_messages(vec![ChatMessage::System(
346 system_prompt
347 .render()
348 .map_err(AgentError::FailedToRenderSystemPrompt)?,
349 )])
350 .await;
351 }
352
353 invoke_hooks!(BeforeAll, self);
354
355 self.load_toolboxes().await?;
356 }
357
358 invoke_hooks!(OnStart, self);
359
360 self.state = state::State::Running;
361
362 if let Some(query) = maybe_query {
363 self.context.add_message(ChatMessage::User(query)).await;
364 }
365
366 let mut loop_counter = 0;
367
368 while let Some(messages) = self.context.next_completion().await {
369 if let Some(limit) = self.limit {
370 if loop_counter >= limit {
371 tracing::warn!("Agent loop limit reached");
372 break;
373 }
374 }
375 let result = self.run_completions(&messages).await;
376
377 if let Err(err) = result {
378 self.stop_with_error(&err).await;
379 tracing::error!(error = ?err, "Agent stopped with error {err}");
380 return Err(err);
381 }
382
383 if just_once || self.state.is_stopped() {
384 break;
385 }
386 loop_counter += 1;
387 }
388
389 self.stop(StopReason::NoNewMessages).await;
391
392 Ok(())
393 }
394
395 #[tracing::instrument(skip_all, err)]
396 async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<(), AgentError> {
397 debug!(
398 "Running completion for agent with {} messages",
399 messages.len()
400 );
401
402 let mut chat_completion_request = ChatCompletionRequest::builder()
403 .messages(messages)
404 .tools_spec(
405 self.tools
406 .iter()
407 .map(swiftide_core::Tool::tool_spec)
408 .collect::<HashSet<_>>(),
409 )
410 .build()
411 .map_err(AgentError::FailedToBuildRequest)?;
412
413 invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
414
415 debug!(
416 "Calling LLM with the following new messages:\n {}",
417 self.context
418 .current_new_messages()
419 .await
420 .iter()
421 .map(ToString::to_string)
422 .collect::<Vec<_>>()
423 .join(",\n")
424 );
425
426 let mut response = if self.streaming {
427 let mut last_response = None;
428 let mut stream = self.llm.complete_stream(&chat_completion_request).await;
429
430 while let Some(response) = stream.next().await {
431 let response = response.map_err(AgentError::CompletionsFailed)?;
432 invoke_hooks!(OnStream, self, &response);
433 last_response = Some(response);
434 }
435 tracing::trace!(?last_response, "Streaming completed");
436 last_response.ok_or(AgentError::EmptyStream)
437 } else {
438 self.llm
439 .complete(&chat_completion_request)
440 .await
441 .map_err(AgentError::CompletionsFailed)
442 }?;
443
444 invoke_hooks!(AfterCompletion, self, &mut response);
445
446 self.add_message(ChatMessage::Assistant(
447 response.message,
448 response.tool_calls.clone(),
449 ))
450 .await?;
451
452 if let Some(tool_calls) = response.tool_calls {
453 self.invoke_tools(tool_calls).await?;
454 }
455
456 invoke_hooks!(AfterEach, self);
457
458 Ok(())
459 }
460
461 async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<(), AgentError> {
462 debug!("LLM returned tool calls: {:?}", tool_calls);
463
464 let mut handles = vec![];
465 for tool_call in tool_calls {
466 let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
467 tracing::warn!("Tool {} not found", tool_call.name());
468 continue;
469 };
470 tracing::info!("Calling tool `{}`", tool_call.name());
471
472 let tool_args = tool_call.args().map(String::from);
473 let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
474
475 invoke_hooks!(BeforeTool, self, &tool_call);
476
477 let tool_span = tracing::info_span!(
478 "tool",
479 "otel.name" = format!("tool.{}", tool.name().as_ref())
480 );
481
482 let handle = tokio::spawn(async move {
483 let tool_args = ArgPreprocessor::preprocess(tool_args.as_deref());
484 let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
485
486 tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name = tool.name().as_ref(), "Completed tool call");
487
488 Ok(output)
489 }.instrument(tool_span.or_current()));
490
491 handles.push((handle, tool_call));
492 }
493
494 for (handle, tool_call) in handles {
495 let mut output = handle.await.map_err(AgentError::ToolFailedToJoin)?;
496
497 invoke_hooks!(AfterTool, self, &tool_call, &mut output);
498
499 if let Err(error) = output {
500 let stop = self.tool_calls_over_limit(&tool_call);
501 if stop {
502 tracing::error!(
503 ?error,
504 "Tool call failed, retry limit reached, stopping agent: {error}",
505 );
506 } else {
507 tracing::warn!(
508 ?error,
509 tool_call = ?tool_call,
510 "Tool call failed, retrying",
511 );
512 }
513 self.add_message(ChatMessage::ToolOutput(
514 tool_call.clone(),
515 ToolOutput::Fail(error.to_string()),
516 ))
517 .await?;
518 if stop {
519 self.stop(StopReason::ToolCallsOverLimit(tool_call)).await;
520 return Err(error.into());
521 }
522 continue;
523 }
524
525 let output = output?;
526 self.handle_control_tools(&tool_call, &output).await;
527 self.add_message(ChatMessage::ToolOutput(tool_call, output))
528 .await?;
529 }
530
531 Ok(())
532 }
533
534 fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
535 self.hooks
536 .iter()
537 .filter(|h| hook_type == (*h).into())
538 .collect()
539 }
540
541 fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
542 self.tools
543 .iter()
544 .find(|tool| tool.name() == tool_name)
545 .cloned()
546 }
547
548 async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
550 if let ToolOutput::Stop = output {
551 tracing::warn!("Stop tool called, stopping agent");
552 self.stop(StopReason::RequestedByTool(tool_call.clone()))
553 .await;
554 }
555 }
556
557 fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
558 let mut s = DefaultHasher::new();
559 tool_call.hash(&mut s);
560 let hash = s.finish();
561
562 if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
563 let val = *retries >= self.tool_retry_limit;
564 *retries += 1;
565 val
566 } else {
567 self.tool_retries_counter.insert(hash, 1);
568 false
569 }
570 }
571
572 #[tracing::instrument(skip_all, fields(message = message.to_string()))]
578 pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
579 invoke_hooks!(OnNewMessage, self, &mut message);
580
581 self.context.add_message(message).await;
582 Ok(())
583 }
584
585 pub async fn stop(&mut self, reason: impl Into<StopReason>) {
587 if self.state.is_stopped() {
588 return;
589 }
590 let reason = reason.into();
591 invoke_hooks!(OnStop, self, reason.clone(), None);
592
593 self.state = state::State::Stopped(reason);
594 }
595
596 pub async fn stop_with_error(&mut self, error: &AgentError) {
597 if self.state.is_stopped() {
598 return;
599 }
600 invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
601
602 self.state = state::State::Stopped(StopReason::Error);
603 }
604
605 pub fn context(&self) -> &dyn AgentContext {
607 &self.context
608 }
609
610 pub fn is_running(&self) -> bool {
612 self.state.is_running()
613 }
614
615 pub fn is_stopped(&self) -> bool {
617 self.state.is_stopped()
618 }
619
620 pub fn is_pending(&self) -> bool {
622 self.state.is_pending()
623 }
624
625 pub fn tools(&self) -> &HashSet<Box<dyn Tool>> {
627 &self.tools
628 }
629
630 async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
631 for toolbox in &self.toolboxes {
632 let tools = toolbox
633 .available_tools()
634 .await
635 .map_err(AgentError::ToolBoxFailedToLoad)?;
636 self.toolbox_tools.extend(tools);
637 }
638
639 self.tools.extend(self.toolbox_tools.clone());
640
641 Ok(())
642 }
643}
644
#[cfg(test)]
mod tests {

    use serde::ser::Error;
    use swiftide_core::chat_completion::errors::ToolError;
    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
    use swiftide_core::test_utils::MockChatCompletion;

    use super::*;
    use crate::{
        assistant, chat_request, chat_response, summary, system, tool_failed, tool_output, user,
    };

    use crate::test_utils::{MockHook, MockTool};

    #[test_log::test(tokio::test)]
    async fn test_agent_builder_defaults() {
        let mock_llm = MockChatCompletion::new();

        // Without explicit tools the agent still has the default `stop` tool.
        let agent = Agent::builder().llm(&mock_llm).build().unwrap();

        assert!(agent.find_tool_by_name("stop").is_some());

        // Duplicate tools collapse into a single entry in the tool set.
        let agent = Agent::builder()
            .tools([Stop::default(), Stop::default()])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 1);

        // Custom tools are added alongside the default `stop` tool.
        let agent = Agent::builder()
            .tools([MockTool::new("mock_tool")])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 2);
        assert!(agent.find_tool_by_name("mock_tool").is_some());
        assert!(agent.find_tool_by_name("stop").is_some());

        assert!(agent.context().history().await.is_empty());
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_tool_calling_loop() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        // First completion: the LLM calls `mock_tool`.
        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        // Second completion: the tool output is sent back and the LLM calls `stop`.
        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_tool_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        // A single loop: the tool is invoked, but no follow-up completion is requested.
        agent.query_once(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_tool_via_toolbox_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        // The tool is only provided through a toolbox, loaded on the first run.
        let mut agent = Agent::builder()
            .add_toolbox(vec![mock_tool.boxed()])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_multiple_tool_calls() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool1");
        let mock_tool2 = MockTool::new("mock_tool2");

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";

            tool_calls = ["mock_tool1", "mock_tool2"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_tool2.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        // Tool outputs are expected in the same order as the tool calls.
        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
            tool_output!("mock_tool1", "Great!"),
            tool_output!("mock_tool2", "Great!");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Ok!";

            tool_calls = ["stop"]

        };

        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool, mock_tool2])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_state_machine() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = []
        };
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        // pending -> (running) -> stopped
        assert!(agent.state.is_pending());
        agent.query_once(prompt).await.unwrap();

        assert!(agent.state.is_stopped());
    }

    #[test_log::test(tokio::test)]
    async fn test_summary() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []

        };

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = []
        };

        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        let mut agent = Agent::builder()
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();

        // After a summary, only the system prompt, the summary and newer messages are sent.
        agent
            .context
            .add_message(ChatMessage::new_summary("Summary"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary"),
            user!("Write another poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        agent.query_once("Write another poem").await.unwrap();

        // A newer summary replaces the previous one.
        agent
            .context
            .add_message(ChatMessage::new_summary("Summary 2"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary 2"),
            user!("Write a third poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));

        agent.query_once("Write a third poem").await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_hooks() {
        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
        let mock_before_completion = MockHook::new("before_completion")
            .expect_calls(2)
            .to_owned();
        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
        // Two assistant messages and two tool outputs go through `Agent::add_message`.
        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
        let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();

        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();

        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .before_all(mock_before_all.hook_fn())
            .on_start(mock_on_start_fn.on_start_fn())
            .before_completion(mock_before_completion.before_completion_fn())
            .before_tool(mock_before_tool.before_tool_fn())
            .after_completion(mock_after_completion.after_completion_fn())
            .after_tool(mock_after_tool.after_tool_fn())
            .after_each(mock_after_each.hook_fn())
            .on_new_message(mock_on_message.message_hook_fn())
            .on_stop(mock_on_stop.stop_hook_fn())
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_loop_limit() {
        let prompt = "Generate content";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mock_tool_response = chat_response! {
            "Some response";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));

        // This expectation must stay unused: the limit stops the agent after one loop.
        let stop_response = chat_response! {
            "Final response";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .limit(1)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        // Remove the unused expectation so the mock does not complain about it.
        let remaining = mock_llm.expectations.lock().unwrap().pop();
        assert!(remaining.is_some());

        assert!(agent.is_stopped());
    }

    #[test_log::test(tokio::test)]
    async fn test_tool_retry_mechanism() {
        let prompt = "Execute my tool";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("retry_tool");

        // The tool fails twice; with a retry limit of 1 the second failure stops the agent.
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        let retry_response = chat_response! {
            "First failing attempt";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));

        // The first failure is reported back to the LLM.
        let chat_request = chat_request! {
            user!(prompt),
            assistant!("First failing attempt", ["retry_tool"]),
            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");

            tools = [mock_tool.clone()]
        };
        let will_fail_response = chat_response! {
            "Finished execution";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .tool_retry_limit(1)
            .build()
            .unwrap();

        let result = agent.query(prompt).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing `query`"));
        assert!(agent.is_stopped());
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_streaming() {
        let prompt = "Generate content";
        let mock_llm = MockChatCompletion::new();
        // Presumably the mock streams the response word by word, hence three chunks.
        let on_stream_fn = MockHook::new("on_stream").expect_calls(3).to_owned();

        let chat_request = chat_request! {
            user!(prompt);

            tools = []
        };

        let response = chat_response! {
            "one two three";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(response));

        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .on_stream(on_stream_fn.on_stream_fn())
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        tracing::debug!("Agent finished running");

        assert!(agent.is_stopped());
    }

    #[test_log::test(tokio::test)]
    async fn test_recovering_agent_existing_history() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool.clone()])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        // Round-trip the history through JSON and resume a new agent from it.
        let history = agent.history().await;

        let serialized = serde_json::to_string(&history).unwrap();

        let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();

        let context = DefaultContext::default()
            .with_message_history(history)
            .to_owned();

        let expected_chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!"),
            assistant!("Roses are red", ["stop"]),
            tool_output!("stop", ToolOutput::Stop),
            user!("Try again!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Really stopping now";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(expected_chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .context(context)
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query_once("Try again!").await.unwrap();
    }
}