1#![allow(dead_code)]
2use crate::{
3 default_context::DefaultContext,
4 errors::AgentError,
5 hooks::{
6 AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7 Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn,
8 },
9 invoke_hooks,
10 state::{self, StopReason},
11 system_prompt::SystemPrompt,
12 tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15 collections::{HashMap, HashSet},
16 hash::{DefaultHasher, Hash as _, Hasher as _},
17 sync::Arc,
18};
19
20use derive_builder::Builder;
21use futures_util::stream::StreamExt;
22use swiftide_core::{
23 AgentContext, ToolBox,
24 chat_completion::{
25 ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
26 },
27 prompt::Prompt,
28};
29use tracing::{Instrument, debug};
30
31#[derive(Clone, Builder)]
43pub struct Agent {
44 #[builder(default, setter(into))]
46 pub(crate) hooks: Vec<Hook>,
47 #[builder(
49 setter(custom),
50 default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
51 )]
52 pub(crate) context: Arc<dyn AgentContext>,
53 #[builder(default = Agent::default_tools(), setter(custom))]
55 pub(crate) tools: HashSet<Box<dyn Tool>>,
56
57 #[builder(default)]
61 pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,
62
63 #[builder(setter(custom))]
65 pub(crate) llm: Box<dyn ChatCompletion>,
66
67 #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
87 pub(crate) system_prompt: Option<Prompt>,
88
89 #[builder(private, default = state::State::default())]
91 pub(crate) state: state::State,
92
93 #[builder(default, setter(strip_option))]
96 pub(crate) limit: Option<usize>,
97
98 #[builder(default = 3)]
110 pub(crate) tool_retry_limit: usize,
111
112 #[builder(default)]
114 pub(crate) streaming: bool,
115
116 #[builder(private, default)]
119 pub(crate) tool_retries_counter: HashMap<u64, usize>,
120
121 #[builder(private, default)]
123 pub(crate) toolbox_tools: HashSet<Box<dyn Tool>>,
124}
125
126impl std::fmt::Debug for Agent {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 f.debug_struct("Agent")
129 .field(
131 "hooks",
132 &self
133 .hooks
134 .iter()
135 .map(std::string::ToString::to_string)
136 .collect::<Vec<_>>(),
137 )
138 .field(
139 "tools",
140 &self
141 .tools
142 .iter()
143 .map(swiftide_core::Tool::name)
144 .collect::<Vec<_>>(),
145 )
146 .field("llm", &"Box<dyn ChatCompletion>")
147 .field("state", &self.state)
148 .finish()
149 }
150}
151
152impl AgentBuilder {
153 pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
155 where
156 Self: Clone,
157 {
158 self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
159 self
160 }
161
162 pub fn no_system_prompt(&mut self) -> &mut Self {
164 self.system_prompt = Some(None);
165
166 self
167 }
168
169 pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
171 let hooks = self.hooks.get_or_insert_with(Vec::new);
172 hooks.push(hook);
173
174 self
175 }
176
177 pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
180 self.add_hook(Hook::BeforeAll(Box::new(hook)))
181 }
182
183 pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
187 self.add_hook(Hook::OnStart(Box::new(hook)))
188 }
189
190 pub fn on_stream(&mut self, hook: impl OnStreamFn + 'static) -> &mut Self {
197 self.streaming = Some(true);
198 self.add_hook(Hook::OnStream(Box::new(hook)))
199 }
200
201 pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
203 self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
204 }
205
206 pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
212 self.add_hook(Hook::AfterTool(Box::new(hook)))
213 }
214
215 pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
217 self.add_hook(Hook::BeforeTool(Box::new(hook)))
218 }
219
220 pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
222 self.add_hook(Hook::AfterCompletion(Box::new(hook)))
223 }
224
225 pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
228 self.add_hook(Hook::AfterEach(Box::new(hook)))
229 }
230
231 pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
234 self.add_hook(Hook::OnNewMessage(Box::new(hook)))
235 }
236
237 pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
238 self.add_hook(Hook::OnStop(Box::new(hook)))
239 }
240
241 pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
243 let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
244
245 self.llm = Some(boxed);
246 self
247 }
248
249 pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
254 where
255 TOOL: Into<Box<dyn Tool>>,
256 {
257 self.tools = Some(
258 tools
259 .into_iter()
260 .map(Into::into)
261 .chain(Agent::default_tools())
262 .collect(),
263 );
264 self
265 }
266
267 pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
273 let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
274 toolboxes.push(Box::new(toolbox));
275
276 self
277 }
278}
279
280impl Agent {
281 pub fn builder() -> AgentBuilder {
283 AgentBuilder::default()
284 .tools(Agent::default_tools())
285 .to_owned()
286 }
287
288 pub fn default_tools() -> HashSet<Box<dyn Tool>> {
291 HashSet::from([Stop::default().boxed()])
292 }
293
294 #[tracing::instrument(skip_all, name = "agent.query")]
301 pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
302 let query = query
303 .into()
304 .render()
305 .map_err(AgentError::FailedToRenderPrompt)?;
306 self.run_agent(Some(query), false).await
307 }
308
309 #[tracing::instrument(skip_all, name = "agent.query_once")]
315 pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
316 let query = query
317 .into()
318 .render()
319 .map_err(AgentError::FailedToRenderPrompt)?;
320 self.run_agent(Some(query), true).await
321 }
322
323 #[tracing::instrument(skip_all, name = "agent.run")]
330 pub async fn run(&mut self) -> Result<(), AgentError> {
331 self.run_agent(None, false).await
332 }
333
334 #[tracing::instrument(skip_all, name = "agent.run_once")]
341 pub async fn run_once(&mut self) -> Result<(), AgentError> {
342 self.run_agent(None, true).await
343 }
344
345 pub async fn history(&self) -> Result<Vec<ChatMessage>, AgentError> {
352 self.context
353 .history()
354 .await
355 .map_err(AgentError::MessageHistoryError)
356 }
357
358 async fn run_agent(
359 &mut self,
360 maybe_query: Option<String>,
361 just_once: bool,
362 ) -> Result<(), AgentError> {
363 if self.state.is_running() {
364 return Err(AgentError::AlreadyRunning);
365 }
366
367 if self.state.is_pending() {
368 if let Some(system_prompt) = &self.system_prompt {
369 self.context
370 .add_messages(vec![ChatMessage::System(
371 system_prompt
372 .render()
373 .map_err(AgentError::FailedToRenderSystemPrompt)?,
374 )])
375 .await
376 .map_err(AgentError::MessageHistoryError)?;
377 }
378
379 invoke_hooks!(BeforeAll, self);
380
381 self.load_toolboxes().await?;
382 }
383
384 invoke_hooks!(OnStart, self);
385
386 self.state = state::State::Running;
387
388 if let Some(query) = maybe_query {
389 self.context
390 .add_message(ChatMessage::User(query))
391 .await
392 .map_err(AgentError::MessageHistoryError)?;
393 }
394
395 let mut loop_counter = 0;
396
397 while let Some(messages) = self
398 .context
399 .next_completion()
400 .await
401 .map_err(AgentError::MessageHistoryError)?
402 {
403 if let Some(limit) = self.limit {
404 if loop_counter >= limit {
405 tracing::warn!("Agent loop limit reached");
406 break;
407 }
408 }
409
410 if let Some(&ChatMessage::Assistant(.., Some(ref tool_calls))) =
413 maybe_tool_call_without_output(&messages)
414 {
415 tracing::debug!("Uncompleted tool calls found; invoking tools");
416 self.invoke_tools(tool_calls).await?;
417 continue;
419 }
420
421 let result = self.run_completions(&messages).await;
422
423 if let Err(err) = result {
424 self.stop_with_error(&err).await;
425 tracing::error!(error = ?err, "Agent stopped with error {err}");
426 return Err(err);
427 }
428
429 if just_once || self.state.is_stopped() {
430 break;
431 }
432 loop_counter += 1;
433 }
434
435 self.stop(StopReason::NoNewMessages).await;
437
438 Ok(())
439 }
440
441 #[tracing::instrument(skip_all, err)]
442 async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<(), AgentError> {
443 debug!(
444 tools = ?self
445 .tools
446 .iter()
447 .map(|t| t.name())
448 .collect::<Vec<_>>()
449 ,
450 "Running completion for agent with {} new messages",
451 messages.len()
452 );
453
454 let mut chat_completion_request = ChatCompletionRequest::builder()
455 .messages(messages)
456 .tools_spec(
457 self.tools
458 .iter()
459 .map(swiftide_core::Tool::tool_spec)
460 .collect::<HashSet<_>>(),
461 )
462 .build()
463 .map_err(AgentError::FailedToBuildRequest)?;
464
465 invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
466
467 debug!(
468 "Calling LLM with the following new messages:\n {}",
469 self.context
470 .current_new_messages()
471 .await
472 .map_err(AgentError::MessageHistoryError)?
473 .iter()
474 .map(ToString::to_string)
475 .collect::<Vec<_>>()
476 .join(",\n")
477 );
478
479 let mut response = if self.streaming {
480 let mut last_response = None;
481 let mut stream = self.llm.complete_stream(&chat_completion_request).await;
482
483 while let Some(response) = stream.next().await {
484 let response = response.map_err(AgentError::CompletionsFailed)?;
485 invoke_hooks!(OnStream, self, &response);
486 last_response = Some(response);
487 }
488 tracing::trace!(?last_response, "Streaming completed");
489 last_response.ok_or(AgentError::EmptyStream)
490 } else {
491 self.llm
492 .complete(&chat_completion_request)
493 .await
494 .map_err(AgentError::CompletionsFailed)
495 }?;
496
497 response
500 .tool_calls
501 .as_deref_mut()
502 .map(ArgPreprocessor::preprocess_tool_calls);
503
504 invoke_hooks!(AfterCompletion, self, &mut response);
505
506 self.add_message(ChatMessage::Assistant(
507 response.message,
508 response.tool_calls.clone(),
509 ))
510 .await?;
511
512 if let Some(tool_calls) = response.tool_calls {
513 self.invoke_tools(&tool_calls).await?;
514 }
515
516 invoke_hooks!(AfterEach, self);
517
518 Ok(())
519 }
520
521 async fn invoke_tools(&mut self, tool_calls: &[ToolCall]) -> Result<(), AgentError> {
522 tracing::debug!("LLM returned tool calls: {:?}", tool_calls);
523
524 let mut handles = vec![];
525 for tool_call in tool_calls {
526 let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
527 tracing::warn!("Tool {} not found", tool_call.name());
528 continue;
529 };
530 tracing::info!("Calling tool `{}`", tool_call.name());
531
532 let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
534
535 invoke_hooks!(BeforeTool, self, &tool_call);
536
537 let tool_span = tracing::info_span!(
538 "tool",
539 "otel.name" = format!("tool.{}", tool.name().as_ref())
540 );
541
542 let handle_tool_call = tool_call.clone();
543 let handle = tokio::spawn(async move {
544 let handle_tool_call = handle_tool_call;
545 let output = tool.invoke(&*context, &handle_tool_call)
546 .await
547 .map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
548
549 tracing::debug!(output = output.to_string(), args = ?handle_tool_call.args(), tool_name = tool.name().as_ref(), "Completed tool call");
550
551 Ok(output)
552 }.instrument(tool_span.or_current()));
553
554 handles.push((handle, tool_call));
555 }
556
557 for (handle, tool_call) in handles {
558 let mut output = handle.await.map_err(AgentError::ToolFailedToJoin)?;
559
560 invoke_hooks!(AfterTool, self, &tool_call, &mut output);
561
562 if let Err(error) = output {
563 let stop = self.tool_calls_over_limit(tool_call);
564 if stop {
565 tracing::error!(
566 ?error,
567 "Tool call failed, retry limit reached, stopping agent: {error}",
568 );
569 } else {
570 tracing::warn!(
571 ?error,
572 tool_call = ?tool_call,
573 "Tool call failed, retrying",
574 );
575 }
576 self.add_message(ChatMessage::ToolOutput(
577 tool_call.clone(),
578 ToolOutput::Fail(error.to_string()),
579 ))
580 .await?;
581 if stop {
582 self.stop(StopReason::ToolCallsOverLimit(tool_call.to_owned()))
583 .await;
584 return Err(error.into());
585 }
586 continue;
587 }
588
589 let output = output?;
590 self.handle_control_tools(tool_call, &output).await;
591
592 if !output.is_feedback_required() {
596 self.add_message(ChatMessage::ToolOutput(tool_call.to_owned(), output))
597 .await?;
598 }
599 }
600
601 Ok(())
602 }
603
604 fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
605 self.hooks
606 .iter()
607 .filter(|h| hook_type == (*h).into())
608 .collect()
609 }
610
611 fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
612 self.tools
613 .iter()
614 .find(|tool| tool.name() == tool_name)
615 .cloned()
616 }
617
618 async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
620 match output {
621 ToolOutput::Stop => {
622 tracing::warn!("Stop tool called, stopping agent");
623 self.stop(StopReason::RequestedByTool(tool_call.clone()))
624 .await;
625 }
626
627 ToolOutput::FeedbackRequired(maybe_payload) => {
628 tracing::warn!("Feedback required, stopping agent");
629 self.stop(StopReason::FeedbackRequired {
630 tool_call: tool_call.clone(),
631 payload: maybe_payload.clone(),
632 })
633 .await;
634 }
635 _ => (),
636 }
637 }
638
639 fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
640 let mut s = DefaultHasher::new();
641 tool_call.hash(&mut s);
642 let hash = s.finish();
643
644 if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
645 let val = *retries >= self.tool_retry_limit;
646 *retries += 1;
647 val
648 } else {
649 self.tool_retries_counter.insert(hash, 1);
650 false
651 }
652 }
653
654 #[tracing::instrument(skip_all, fields(message = message.to_string()))]
665 pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
666 invoke_hooks!(OnNewMessage, self, &mut message);
667
668 self.context
669 .add_message(message)
670 .await
671 .map_err(AgentError::MessageHistoryError)?;
672 Ok(())
673 }
674
675 pub async fn stop(&mut self, reason: impl Into<StopReason>) {
677 if self.state.is_stopped() {
678 return;
679 }
680 let reason = reason.into();
681 invoke_hooks!(OnStop, self, reason.clone(), None);
682
683 self.state = state::State::Stopped(reason);
684 }
685
686 pub async fn stop_with_error(&mut self, error: &AgentError) {
687 if self.state.is_stopped() {
688 return;
689 }
690 invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
691
692 self.state = state::State::Stopped(StopReason::Error);
693 }
694
695 pub fn context(&self) -> &dyn AgentContext {
697 &self.context
698 }
699
700 pub fn is_running(&self) -> bool {
702 self.state.is_running()
703 }
704
705 pub fn is_stopped(&self) -> bool {
707 self.state.is_stopped()
708 }
709
710 pub fn is_pending(&self) -> bool {
712 self.state.is_pending()
713 }
714
715 pub fn tools(&self) -> &HashSet<Box<dyn Tool>> {
717 &self.tools
718 }
719
720 pub fn state(&self) -> &state::State {
721 &self.state
722 }
723
724 pub fn stop_reason(&self) -> Option<&StopReason> {
725 self.state.stop_reason()
726 }
727
728 async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
729 for toolbox in &self.toolboxes {
730 let tools = toolbox
731 .available_tools()
732 .await
733 .map_err(AgentError::ToolBoxFailedToLoad)?;
734 self.toolbox_tools.extend(tools);
735 }
736
737 self.tools.extend(self.toolbox_tools.clone());
738
739 Ok(())
740 }
741}
742
743fn maybe_tool_call_without_output(messages: &[ChatMessage]) -> Option<&ChatMessage> {
746 for message in messages.iter().rev() {
747 if let ChatMessage::ToolOutput(..) = message {
748 return None;
749 }
750
751 if let ChatMessage::Assistant(.., Some(tool_calls)) = message {
752 if !tool_calls.is_empty() {
753 return Some(message);
754 }
755 }
756 }
757
758 None
759}
760
761#[cfg(test)]
762mod tests {
763
764 use serde::ser::Error;
765 use swiftide_core::ToolFeedback;
766 use swiftide_core::chat_completion::errors::ToolError;
767 use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
768 use swiftide_core::test_utils::MockChatCompletion;
769
770 use super::*;
771 use crate::{
772 State, assistant, chat_request, chat_response, summary, system, tool_failed, tool_output,
773 user,
774 };
775
776 use crate::test_utils::{MockHook, MockTool};
777
778 #[test_log::test(tokio::test)]
779 async fn test_agent_builder_defaults() {
780 let mock_llm = MockChatCompletion::new();
782
783 let agent = Agent::builder().llm(&mock_llm).build().unwrap();
785
786 assert!(agent.find_tool_by_name("stop").is_some());
790
791 let agent = Agent::builder()
793 .tools([Stop::default(), Stop::default()])
794 .llm(&mock_llm)
795 .build()
796 .unwrap();
797
798 assert_eq!(agent.tools.len(), 1);
799
800 let agent = Agent::builder()
802 .tools([MockTool::new("mock_tool")])
803 .llm(&mock_llm)
804 .build()
805 .unwrap();
806
807 assert_eq!(agent.tools.len(), 2);
808 assert!(agent.find_tool_by_name("mock_tool").is_some());
809 assert!(agent.find_tool_by_name("stop").is_some());
810
811 assert!(agent.context().history().await.unwrap().is_empty());
812 }
813
814 #[test_log::test(tokio::test)]
815 async fn test_agent_tool_calling_loop() {
816 let prompt = "Write a poem";
817 let mock_llm = MockChatCompletion::new();
818 let mock_tool = MockTool::new("mock_tool");
819
820 let chat_request = chat_request! {
821 user!("Write a poem");
822
823 tools = [mock_tool.clone()]
824 };
825
826 let mock_tool_response = chat_response! {
827 "Roses are red";
828 tool_calls = ["mock_tool"]
829
830 };
831
832 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
833
834 let chat_request = chat_request! {
835 user!("Write a poem"),
836 assistant!("Roses are red", ["mock_tool"]),
837 tool_output!("mock_tool", "Great!");
838
839 tools = [mock_tool.clone()]
840 };
841
842 let stop_response = chat_response! {
843 "Roses are red";
844 tool_calls = ["stop"]
845 };
846
847 mock_llm.expect_complete(chat_request, Ok(stop_response));
848 mock_tool.expect_invoke_ok("Great!".into(), None);
849
850 let mut agent = Agent::builder()
851 .tools([mock_tool])
852 .llm(&mock_llm)
853 .no_system_prompt()
854 .build()
855 .unwrap();
856
857 agent.query(prompt).await.unwrap();
858 }
859
860 #[test_log::test(tokio::test)]
861 async fn test_agent_tool_run_once() {
862 let prompt = "Write a poem";
863 let mock_llm = MockChatCompletion::new();
864 let mock_tool = MockTool::default();
865
866 let chat_request = chat_request! {
867 system!("My system prompt"),
868 user!("Write a poem");
869
870 tools = [mock_tool.clone()]
871 };
872
873 let mock_tool_response = chat_response! {
874 "Roses are red";
875 tool_calls = ["mock_tool"]
876
877 };
878
879 mock_tool.expect_invoke_ok("Great!".into(), None);
880 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
881
882 let mut agent = Agent::builder()
883 .tools([mock_tool])
884 .system_prompt("My system prompt")
885 .llm(&mock_llm)
886 .build()
887 .unwrap();
888
889 agent.query_once(prompt).await.unwrap();
890 }
891
892 #[test_log::test(tokio::test)]
893 async fn test_agent_tool_via_toolbox_run_once() {
894 let prompt = "Write a poem";
895 let mock_llm = MockChatCompletion::new();
896 let mock_tool = MockTool::default();
897
898 let chat_request = chat_request! {
899 system!("My system prompt"),
900 user!("Write a poem");
901
902 tools = [mock_tool.clone()]
903 };
904
905 let mock_tool_response = chat_response! {
906 "Roses are red";
907 tool_calls = ["mock_tool"]
908
909 };
910
911 mock_tool.expect_invoke_ok("Great!".into(), None);
912 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
913
914 let mut agent = Agent::builder()
915 .add_toolbox(vec![mock_tool.boxed()])
916 .system_prompt("My system prompt")
917 .llm(&mock_llm)
918 .build()
919 .unwrap();
920
921 agent.query_once(prompt).await.unwrap();
922 }
923
924 #[test_log::test(tokio::test(flavor = "multi_thread"))]
925 async fn test_multiple_tool_calls() {
926 let prompt = "Write a poem";
927 let mock_llm = MockChatCompletion::new();
928 let mock_tool = MockTool::new("mock_tool1");
929 let mock_tool2 = MockTool::new("mock_tool2");
930
931 let chat_request = chat_request! {
932 system!("My system prompt"),
933 user!("Write a poem");
934
935
936
937 tools = [mock_tool.clone(), mock_tool2.clone()]
938 };
939
940 let mock_tool_response = chat_response! {
941 "Roses are red";
942
943 tool_calls = ["mock_tool1", "mock_tool2"]
944
945 };
946
947 dbg!(&chat_request);
948 mock_tool.expect_invoke_ok("Great!".into(), None);
949 mock_tool2.expect_invoke_ok("Great!".into(), None);
950 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
951
952 let chat_request = chat_request! {
953 system!("My system prompt"),
954 user!("Write a poem"),
955 assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
956 tool_output!("mock_tool1", "Great!"),
957 tool_output!("mock_tool2", "Great!");
958
959 tools = [mock_tool.clone(), mock_tool2.clone()]
960 };
961
962 let mock_tool_response = chat_response! {
963 "Ok!";
964
965 tool_calls = ["stop"]
966
967 };
968
969 mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
970
971 let mut agent = Agent::builder()
972 .tools([mock_tool, mock_tool2])
973 .system_prompt("My system prompt")
974 .llm(&mock_llm)
975 .build()
976 .unwrap();
977
978 agent.query(prompt).await.unwrap();
979 }
980
981 #[test_log::test(tokio::test)]
982 async fn test_agent_state_machine() {
983 let prompt = "Write a poem";
984 let mock_llm = MockChatCompletion::new();
985
986 let chat_request = chat_request! {
987 user!("Write a poem");
988 tools = []
989 };
990 let mock_tool_response = chat_response! {
991 "Roses are red";
992 tool_calls = []
993 };
994
995 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
996 let mut agent = Agent::builder()
997 .llm(&mock_llm)
998 .no_system_prompt()
999 .build()
1000 .unwrap();
1001
1002 assert!(agent.state.is_pending());
1004 agent.query_once(prompt).await.unwrap();
1005
1006 assert!(agent.state.is_stopped());
1008 }
1009
1010 #[test_log::test(tokio::test)]
1011 async fn test_summary() {
1012 let prompt = "Write a poem";
1013 let mock_llm = MockChatCompletion::new();
1014
1015 let mock_tool_response = chat_response! {
1016 "Roses are red";
1017 tool_calls = []
1018
1019 };
1020
1021 let expected_chat_request = chat_request! {
1022 system!("My system prompt"),
1023 user!("Write a poem");
1024
1025 tools = []
1026 };
1027
1028 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1029
1030 let mut agent = Agent::builder()
1031 .system_prompt("My system prompt")
1032 .llm(&mock_llm)
1033 .build()
1034 .unwrap();
1035
1036 agent.query_once(prompt).await.unwrap();
1037
1038 agent
1039 .context
1040 .add_message(ChatMessage::new_summary("Summary"))
1041 .await
1042 .unwrap();
1043
1044 let expected_chat_request = chat_request! {
1045 system!("My system prompt"),
1046 summary!("Summary"),
1047 user!("Write another poem");
1048 tools = []
1049 };
1050 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1051
1052 agent.query_once("Write another poem").await.unwrap();
1053
1054 agent
1055 .context
1056 .add_message(ChatMessage::new_summary("Summary 2"))
1057 .await
1058 .unwrap();
1059
1060 let expected_chat_request = chat_request! {
1061 system!("My system prompt"),
1062 summary!("Summary 2"),
1063 user!("Write a third poem");
1064 tools = []
1065 };
1066 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));
1067
1068 agent.query_once("Write a third poem").await.unwrap();
1069 }
1070
1071 #[test_log::test(tokio::test)]
1072 async fn test_agent_hooks() {
1073 let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
1074 let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
1075 let mock_before_completion = MockHook::new("before_completion")
1076 .expect_calls(2)
1077 .to_owned();
1078 let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
1079 let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
1080 let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
1081 let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();
1082
1083 let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
1085 let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();
1086
1087 let prompt = "Write a poem";
1088 let mock_llm = MockChatCompletion::new();
1089 let mock_tool = MockTool::default();
1090
1091 let chat_request = chat_request! {
1092 user!("Write a poem");
1093
1094 tools = [mock_tool.clone()]
1095 };
1096
1097 let mock_tool_response = chat_response! {
1098 "Roses are red";
1099 tool_calls = ["mock_tool"]
1100
1101 };
1102
1103 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1104
1105 let chat_request = chat_request! {
1106 user!("Write a poem"),
1107 assistant!("Roses are red", ["mock_tool"]),
1108 tool_output!("mock_tool", "Great!");
1109
1110 tools = [mock_tool.clone()]
1111 };
1112
1113 let stop_response = chat_response! {
1114 "Roses are red";
1115 tool_calls = ["stop"]
1116 };
1117
1118 mock_llm.expect_complete(chat_request, Ok(stop_response));
1119 mock_tool.expect_invoke_ok("Great!".into(), None);
1120
1121 let mut agent = Agent::builder()
1122 .tools([mock_tool])
1123 .llm(&mock_llm)
1124 .no_system_prompt()
1125 .before_all(mock_before_all.hook_fn())
1126 .on_start(mock_on_start_fn.on_start_fn())
1127 .before_completion(mock_before_completion.before_completion_fn())
1128 .before_tool(mock_before_tool.before_tool_fn())
1129 .after_completion(mock_after_completion.after_completion_fn())
1130 .after_tool(mock_after_tool.after_tool_fn())
1131 .after_each(mock_after_each.hook_fn())
1132 .on_new_message(mock_on_message.message_hook_fn())
1133 .on_stop(mock_on_stop.stop_hook_fn())
1134 .build()
1135 .unwrap();
1136
1137 agent.query(prompt).await.unwrap();
1138 }
1139
1140 #[test_log::test(tokio::test)]
1141 async fn test_agent_loop_limit() {
1142 let prompt = "Generate content"; let mock_llm = MockChatCompletion::new();
1144 let mock_tool = MockTool::new("mock_tool");
1145
1146 let chat_request = chat_request! {
1147 user!(prompt);
1148 tools = [mock_tool.clone()]
1149 };
1150 mock_tool.expect_invoke_ok("Great!".into(), None);
1151
1152 let mock_tool_response = chat_response! {
1153 "Some response";
1154 tool_calls = ["mock_tool"]
1155 };
1156
1157 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));
1159
1160 let stop_response = chat_response! {
1162 "Final response";
1163 tool_calls = ["stop"]
1164 };
1165
1166 mock_llm.expect_complete(chat_request, Ok(stop_response));
1167
1168 let mut agent = Agent::builder()
1169 .tools([mock_tool])
1170 .llm(&mock_llm)
1171 .no_system_prompt()
1172 .limit(1) .build()
1174 .unwrap();
1175
1176 agent.query(prompt).await.unwrap();
1178
1179 let remaining = mock_llm.expectations.lock().unwrap().pop();
1181 assert!(remaining.is_some());
1182
1183 assert!(agent.is_stopped());
1185 }
1186
1187 #[test_log::test(tokio::test)]
1188 async fn test_tool_retry_mechanism() {
1189 let prompt = "Execute my tool";
1190 let mock_llm = MockChatCompletion::new();
1191 let mock_tool = MockTool::new("retry_tool");
1192
1193 mock_tool.expect_invoke(
1196 Err(ToolError::WrongArguments(serde_json::Error::custom(
1197 "missing `query`",
1198 ))),
1199 None,
1200 );
1201 mock_tool.expect_invoke(
1202 Err(ToolError::WrongArguments(serde_json::Error::custom(
1203 "missing `query`",
1204 ))),
1205 None,
1206 );
1207
1208 let chat_request = chat_request! {
1209 user!(prompt);
1210 tools = [mock_tool.clone()]
1211 };
1212 let retry_response = chat_response! {
1213 "First failing attempt";
1214 tool_calls = ["retry_tool"]
1215 };
1216 mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
1217
1218 let chat_request = chat_request! {
1219 user!(prompt),
1220 assistant!("First failing attempt", ["retry_tool"]),
1221 tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
1222
1223 tools = [mock_tool.clone()]
1224 };
1225 let will_fail_response = chat_response! {
1226 "Finished execution";
1227 tool_calls = ["retry_tool"]
1228 };
1229 mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
1230
1231 let mut agent = Agent::builder()
1232 .tools([mock_tool])
1233 .llm(&mock_llm)
1234 .no_system_prompt()
1235 .tool_retry_limit(1) .build()
1237 .unwrap();
1238
1239 let result = agent.query(prompt).await;
1241
1242 assert!(result.is_err());
1243 assert!(result.unwrap_err().to_string().contains("missing `query`"));
1244 assert!(agent.is_stopped());
1245 }
1246
1247 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1248 async fn test_streaming() {
1249 let prompt = "Generate content"; let mock_llm = MockChatCompletion::new();
1251 let on_stream_fn = MockHook::new("on_stream").expect_calls(3).to_owned();
1252
1253 let chat_request = chat_request! {
1254 user!(prompt);
1255
1256 tools = []
1257 };
1258
1259 let response = chat_response! {
1260 "one two three";
1261 tool_calls = ["stop"]
1262 };
1263
1264 mock_llm.expect_complete(chat_request, Ok(response));
1266
1267 let mut agent = Agent::builder()
1268 .llm(&mock_llm)
1269 .on_stream(on_stream_fn.on_stream_fn())
1270 .no_system_prompt()
1271 .build()
1272 .unwrap();
1273
1274 agent.query(prompt).await.unwrap();
1276
1277 tracing::debug!("Agent finished running");
1278
1279 assert!(agent.is_stopped());
1281 }
1282
1283 #[test_log::test(tokio::test)]
1284 async fn test_recovering_agent_existing_history() {
1285 let prompt = "Write a poem";
1287 let mock_llm = MockChatCompletion::new();
1288 let mock_tool = MockTool::new("mock_tool");
1289
1290 let chat_request = chat_request! {
1291 user!("Write a poem");
1292
1293 tools = [mock_tool.clone()]
1294 };
1295
1296 let mock_tool_response = chat_response! {
1297 "Roses are red";
1298 tool_calls = ["mock_tool"]
1299
1300 };
1301
1302 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1303
1304 let chat_request = chat_request! {
1305 user!("Write a poem"),
1306 assistant!("Roses are red", ["mock_tool"]),
1307 tool_output!("mock_tool", "Great!");
1308
1309 tools = [mock_tool.clone()]
1310 };
1311
1312 let stop_response = chat_response! {
1313 "Roses are red";
1314 tool_calls = ["stop"]
1315 };
1316
1317 mock_llm.expect_complete(chat_request, Ok(stop_response));
1318 mock_tool.expect_invoke_ok("Great!".into(), None);
1319
1320 let mut agent = Agent::builder()
1321 .tools([mock_tool.clone()])
1322 .llm(&mock_llm)
1323 .no_system_prompt()
1324 .build()
1325 .unwrap();
1326
1327 agent.query(prompt).await.unwrap();
1328
1329 let history = agent.history().await.unwrap();
1331
1332 let serialized = serde_json::to_string(&history).unwrap();
1334
1335 let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();
1337
1338 let context = DefaultContext::default()
1340 .with_existing_messages(history)
1341 .await
1342 .unwrap()
1343 .to_owned();
1344
1345 let expected_chat_request = chat_request! {
1346 user!("Write a poem"),
1347 assistant!("Roses are red", ["mock_tool"]),
1348 tool_output!("mock_tool", "Great!"),
1349 assistant!("Roses are red", ["stop"]),
1350 tool_output!("stop", ToolOutput::Stop),
1351 user!("Try again!");
1352
1353 tools = [mock_tool.clone()]
1354 };
1355
1356 let stop_response = chat_response! {
1357 "Really stopping now";
1358 tool_calls = ["stop"]
1359 };
1360
1361 mock_llm.expect_complete(expected_chat_request, Ok(stop_response));
1362
1363 let mut agent = Agent::builder()
1364 .context(context)
1365 .tools([mock_tool])
1366 .llm(&mock_llm)
1367 .no_system_prompt()
1368 .build()
1369 .unwrap();
1370
1371 agent.query_once("Try again!").await.unwrap();
1372 }
1373
1374 #[test_log::test(tokio::test)]
1375 async fn test_agent_with_approval_required_tool() {
1376 use super::*;
1377 use crate::tools::control::ApprovalRequired;
1378 use crate::{assistant, chat_request, chat_response, user};
1379 use swiftide_core::chat_completion::ToolCall;
1380
1381 let mock_tool = MockTool::default();
1383 mock_tool.expect_invoke_ok("Great!".into(), None);
1384
1385 let approval_tool = ApprovalRequired(mock_tool.boxed());
1386
1387 let mock_llm = MockChatCompletion::new();
1389
1390 let chat_req1 = chat_request! {
1391 user!("Request with approval");
1392 tools = [approval_tool.clone()]
1393 };
1394 let chat_resp1 = chat_response! {
1395 "Completion message";
1396 tool_calls = ["mock_tool"]
1397 };
1398 mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1399
1400 let chat_req2 = chat_request! {
1403 user!("Request with approval"),
1404 assistant!("Completion message", ["mock_tool"]),
1405 tool_output!("mock_tool", "Great!");
1406 tools = [approval_tool.clone()]
1408 };
1409 let chat_resp2 = chat_response! {
1410 "Post-feedback message";
1411 tool_calls = ["stop"]
1412 };
1413 mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1414
1415 let mut agent = Agent::builder()
1417 .tools([approval_tool])
1418 .llm(&mock_llm)
1419 .no_system_prompt()
1420 .build()
1421 .unwrap();
1422
1423 agent.query_once("Request with approval").await.unwrap();
1425
1426 assert!(matches!(
1427 agent.state,
1428 crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1429 ));
1430
1431 let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1432 else {
1433 panic!("Expected feedback required");
1434 };
1435
1436 agent
1438 .context
1439 .feedback_received(&tool_call, &ToolFeedback::approved())
1440 .await
1441 .unwrap();
1442
1443 tracing::debug!("running after approval");
1444 agent.run_once().await.unwrap();
1445 assert!(agent.is_stopped());
1446 }
1447
1448 #[test_log::test(tokio::test)]
1449 async fn test_agent_with_approval_required_tool_denied() {
1450 use super::*;
1451 use crate::tools::control::ApprovalRequired;
1452 use crate::{assistant, chat_request, chat_response, user};
1453 use swiftide_core::chat_completion::ToolCall;
1454
1455 let mock_tool = MockTool::default();
1457
1458 let approval_tool = ApprovalRequired(mock_tool.boxed());
1459
1460 let mock_llm = MockChatCompletion::new();
1462
1463 let chat_req1 = chat_request! {
1464 user!("Request with approval");
1465 tools = [approval_tool.clone()]
1466 };
1467 let chat_resp1 = chat_response! {
1468 "Completion message";
1469 tool_calls = ["mock_tool"]
1470 };
1471 mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1472
1473 let chat_req2 = chat_request! {
1476 user!("Request with approval"),
1477 assistant!("Completion message", ["mock_tool"]),
1478 tool_output!("mock_tool", "This tool call was refused");
1479 tools = [approval_tool.clone()]
1481 };
1482 let chat_resp2 = chat_response! {
1483 "Post-feedback message";
1484 tool_calls = ["stop"]
1485 };
1486 mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1487
1488 let mut agent = Agent::builder()
1490 .tools([approval_tool])
1491 .llm(&mock_llm)
1492 .no_system_prompt()
1493 .build()
1494 .unwrap();
1495
1496 agent.query_once("Request with approval").await.unwrap();
1498
1499 assert!(matches!(
1500 agent.state,
1501 crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1502 ));
1503
1504 let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1505 else {
1506 panic!("Expected feedback required");
1507 };
1508
1509 agent
1511 .context
1512 .feedback_received(&tool_call, &ToolFeedback::refused())
1513 .await
1514 .unwrap();
1515
1516 tracing::debug!("running after approval");
1517 agent.run_once().await.unwrap();
1518
1519 let history = agent.context().history().await.unwrap();
1520 history
1521 .iter()
1522 .rfind(|m| {
1523 let ChatMessage::ToolOutput(.., ToolOutput::Text(msg)) = m else {
1524 return false;
1525 };
1526 msg.contains("refused")
1527 })
1528 .expect("Could not find refusal message");
1529
1530 assert!(agent.is_stopped());
1531 }
1532}