1#![allow(dead_code)]
2use crate::{
3 default_context::DefaultContext,
4 errors::AgentError,
5 hooks::{
6 AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7 Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn,
8 },
9 invoke_hooks,
10 state::{self, StopReason},
11 system_prompt::SystemPrompt,
12 tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15 collections::{HashMap, HashSet},
16 hash::{DefaultHasher, Hash as _, Hasher as _},
17 sync::Arc,
18};
19
20use derive_builder::Builder;
21use futures_util::stream::StreamExt;
22use swiftide_core::{
23 AgentContext, ToolBox,
24 chat_completion::{
25 ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
26 },
27 prompt::Prompt,
28};
29use tracing::{Instrument, debug};
30
31#[derive(Builder)]
46pub struct Agent {
47 #[builder(default, setter(into))]
49 pub(crate) hooks: Vec<Hook>,
50 #[builder(
52 setter(custom),
53 default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
54 )]
55 pub(crate) context: Arc<dyn AgentContext>,
56 #[builder(default = Agent::default_tools(), setter(custom))]
58 pub(crate) tools: HashSet<Box<dyn Tool>>,
59
60 #[builder(default)]
64 pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,
65
66 #[builder(setter(custom))]
68 pub(crate) llm: Box<dyn ChatCompletion>,
69
70 #[builder(setter(into, strip_option), default = Some(SystemPrompt::default()))]
90 pub(crate) system_prompt: Option<SystemPrompt>,
91
92 #[builder(private, default = state::State::default())]
94 pub(crate) state: state::State,
95
96 #[builder(default, setter(strip_option))]
99 pub(crate) limit: Option<usize>,
100
101 #[builder(default = 3)]
113 pub(crate) tool_retry_limit: usize,
114
115 #[builder(default)]
117 pub(crate) streaming: bool,
118
119 #[builder(private, default)]
122 pub(crate) clear_default_tools: bool,
123
124 #[builder(private, default)]
127 pub(crate) tool_retries_counter: HashMap<u64, usize>,
128
129 #[builder(default = "unnamed_agent".into(), setter(into))]
131 pub(crate) name: String,
132}
133
134impl Clone for Agent {
135 fn clone(&self) -> Self {
136 Agent {
137 hooks: self.hooks.clone(),
138 context: Arc::new(self.context.clone()),
139 tools: self.tools.clone(),
140 toolboxes: self.toolboxes.clone(),
141 llm: self.llm.clone(),
142 system_prompt: self.system_prompt.clone(),
143 state: self.state.clone(),
144 limit: self.limit,
145 tool_retry_limit: self.tool_retry_limit,
146 tool_retries_counter: HashMap::new(),
147 streaming: self.streaming,
148 name: self.name.clone(),
149 clear_default_tools: self.clear_default_tools,
150 }
151 }
152}
153
154impl std::fmt::Debug for Agent {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 f.debug_struct("Agent")
157 .field("name", &self.name)
158 .field(
160 "hooks",
161 &self
162 .hooks
163 .iter()
164 .map(std::string::ToString::to_string)
165 .collect::<Vec<_>>(),
166 )
167 .field(
168 "tools",
169 &self
170 .tools
171 .iter()
172 .map(swiftide_core::Tool::name)
173 .collect::<Vec<_>>(),
174 )
175 .field("llm", &"Box<dyn ChatCompletion>")
176 .field("state", &self.state)
177 .finish()
178 }
179}
180
181impl AgentBuilder {
182 pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
184 where
185 Self: Clone,
186 {
187 self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
188 self
189 }
190
191 pub fn system_prompt_mut(&mut self) -> Option<&mut SystemPrompt> {
193 self.system_prompt.as_mut().and_then(Option::as_mut)
194 }
195
196 pub fn no_system_prompt(&mut self) -> &mut Self {
198 self.system_prompt = Some(None);
199
200 self
201 }
202
203 pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
205 let hooks = self.hooks.get_or_insert_with(Vec::new);
206 hooks.push(hook);
207
208 self
209 }
210
211 pub fn add_tool(&mut self, tool: impl Tool + 'static) -> &mut Self {
213 self.tools = Some(
214 self.tools
215 .take()
216 .unwrap_or_default()
217 .into_iter()
218 .chain([Box::new(tool) as Box<dyn Tool>])
219 .collect(),
220 );
221 self
222 }
223
224 pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
227 self.add_hook(Hook::BeforeAll(Box::new(hook)))
228 }
229
230 pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
234 self.add_hook(Hook::OnStart(Box::new(hook)))
235 }
236
237 pub fn on_stream(&mut self, hook: impl OnStreamFn + 'static) -> &mut Self {
244 self.streaming = Some(true);
245 self.add_hook(Hook::OnStream(Box::new(hook)))
246 }
247
248 pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
250 self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
251 }
252
253 pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
259 self.add_hook(Hook::AfterTool(Box::new(hook)))
260 }
261
262 pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
264 self.add_hook(Hook::BeforeTool(Box::new(hook)))
265 }
266
267 pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
269 self.add_hook(Hook::AfterCompletion(Box::new(hook)))
270 }
271
272 pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
275 self.add_hook(Hook::AfterEach(Box::new(hook)))
276 }
277
278 pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
281 self.add_hook(Hook::OnNewMessage(Box::new(hook)))
282 }
283
284 pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
285 self.add_hook(Hook::OnStop(Box::new(hook)))
286 }
287
288 pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
290 let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
291
292 self.llm = Some(boxed);
293 self
294 }
295
296 fn builder_default_tools(&self) -> HashSet<Box<dyn Tool>> {
297 if self.clear_default_tools.is_some_and(|b| b) {
298 HashSet::new()
299 } else {
300 Agent::default_tools()
301 }
302 }
303
304 pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
309 where
310 TOOL: Into<Box<dyn Tool>>,
311 {
312 self.tools = Some(
313 tools
314 .into_iter()
315 .map(Into::into)
316 .chain(self.builder_default_tools())
317 .collect(),
318 );
319 self
320 }
321
322 pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
328 let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
329 toolboxes.push(Box::new(toolbox));
330
331 self
332 }
333}
334
335impl Agent {
336 pub fn builder() -> AgentBuilder {
338 AgentBuilder::default()
339 .tools(Agent::default_tools())
340 .to_owned()
341 }
342
343 pub fn name(&self) -> &str {
345 &self.name
346 }
347
348 pub fn default_tools() -> HashSet<Box<dyn Tool>> {
351 HashSet::from([Stop::default().boxed()])
352 }
353
354 #[tracing::instrument(skip_all, name = "agent.query")]
361 pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
362 let query = query
363 .into()
364 .render()
365 .map_err(AgentError::FailedToRenderPrompt)?;
366 self.run_agent(Some(query), false).await
367 }
368
369 pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
371 self.tools.insert(tool);
372 }
373
374 pub fn tools_mut(&mut self) -> &mut HashSet<Box<dyn Tool>> {
379 &mut self.tools
380 }
381
382 #[tracing::instrument(skip_all, name = "agent.query_once")]
388 pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
389 self.run_agent(Some(query), true).await
390 }
391
392 #[tracing::instrument(skip_all, name = "agent.run")]
399 pub async fn run(&mut self) -> Result<(), AgentError> {
400 self.run_agent(None::<Prompt>, false).await
401 }
402
403 #[tracing::instrument(skip_all, name = "agent.run_once")]
410 pub async fn run_once(&mut self) -> Result<(), AgentError> {
411 self.run_agent(None::<Prompt>, true).await
412 }
413
414 pub async fn history(&self) -> Result<Vec<ChatMessage>, AgentError> {
421 self.context
422 .history()
423 .await
424 .map_err(AgentError::MessageHistoryError)
425 }
426
427 pub(crate) async fn run_agent(
428 &mut self,
429 maybe_query: Option<impl Into<Prompt>>,
430 just_once: bool,
431 ) -> Result<(), AgentError> {
432 let maybe_query = maybe_query
433 .map(|q| q.into().render())
434 .transpose()
435 .map_err(AgentError::FailedToRenderPrompt)?;
436 if self.state.is_running() {
437 return Err(AgentError::AlreadyRunning);
438 }
439
440 if self.state.is_pending() {
441 if let Some(system_prompt) = &self.system_prompt {
442 self.context
443 .add_messages(vec![ChatMessage::System(
444 system_prompt
445 .to_prompt()
446 .render()
447 .map_err(AgentError::FailedToRenderSystemPrompt)?,
448 )])
449 .await
450 .map_err(AgentError::MessageHistoryError)?;
451 }
452
453 invoke_hooks!(BeforeAll, self);
454
455 self.load_toolboxes().await?;
456 }
457
458 if let Some(query) = maybe_query {
459 self.context
460 .add_message(ChatMessage::User(query))
461 .await
462 .map_err(AgentError::MessageHistoryError)?;
463 }
464
465 invoke_hooks!(OnStart, self);
466
467 self.state = state::State::Running;
468
469 let mut loop_counter = 0;
470
471 while let Some(messages) = self
472 .context
473 .next_completion()
474 .await
475 .map_err(AgentError::MessageHistoryError)?
476 {
477 if let Some(limit) = self.limit
478 && loop_counter >= limit
479 {
480 tracing::warn!("Agent loop limit reached");
481 break;
482 }
483
484 if let Some(&ChatMessage::Assistant(.., Some(ref tool_calls))) =
487 maybe_tool_call_without_output(&messages)
488 {
489 tracing::debug!("Uncompleted tool calls found; invoking tools");
490 self.invoke_tools(tool_calls).await?;
491 continue;
493 }
494
495 let result = self.run_completions(&messages).await;
496
497 if let Err(err) = result {
498 self.stop_with_error(&err).await;
499 tracing::error!(error = ?err, "Agent stopped with error {err}");
500 return Err(err);
501 }
502
503 if just_once || self.state.is_stopped() {
504 break;
505 }
506 loop_counter += 1;
507 }
508
509 self.stop(StopReason::NoNewMessages).await;
511
512 Ok(())
513 }
514
515 #[tracing::instrument(skip_all, err)]
516 async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<(), AgentError> {
517 debug!(
518 tools = ?self
519 .tools
520 .iter()
521 .map(|t| t.name())
522 .collect::<Vec<_>>()
523 ,
524 "Running completion for agent with {} new messages",
525 messages.len()
526 );
527
528 let mut chat_completion_request = ChatCompletionRequest::builder()
529 .messages(messages)
530 .tools_spec(
531 self.tools
532 .iter()
533 .map(swiftide_core::Tool::tool_spec)
534 .collect::<HashSet<_>>(),
535 )
536 .build()
537 .map_err(AgentError::FailedToBuildRequest)?;
538
539 invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
540
541 debug!(
542 "Calling LLM with the following new messages:\n {}",
543 self.context
544 .current_new_messages()
545 .await
546 .map_err(AgentError::MessageHistoryError)?
547 .iter()
548 .map(ToString::to_string)
549 .collect::<Vec<_>>()
550 .join(",\n")
551 );
552
553 let mut response = if self.streaming {
554 let mut last_response = None;
555 let mut stream = self.llm.complete_stream(&chat_completion_request).await;
556
557 while let Some(response) = stream.next().await {
558 let response = response.map_err(AgentError::CompletionsFailed)?;
559 invoke_hooks!(OnStream, self, &response);
560 last_response = Some(response);
561 }
562 tracing::trace!(?last_response, "Streaming completed");
563 last_response.ok_or(AgentError::EmptyStream)
564 } else {
565 self.llm
566 .complete(&chat_completion_request)
567 .await
568 .map_err(AgentError::CompletionsFailed)
569 }?;
570
571 response
574 .tool_calls
575 .as_deref_mut()
576 .map(ArgPreprocessor::preprocess_tool_calls);
577
578 invoke_hooks!(AfterCompletion, self, &mut response);
579
580 self.add_message(ChatMessage::Assistant(
581 response.message,
582 response.tool_calls.clone(),
583 ))
584 .await?;
585
586 if let Some(tool_calls) = response.tool_calls {
587 self.invoke_tools(&tool_calls).await?;
588 }
589
590 invoke_hooks!(AfterEach, self);
591
592 Ok(())
593 }
594
595 async fn invoke_tools(&mut self, tool_calls: &[ToolCall]) -> Result<(), AgentError> {
596 tracing::debug!("LLM returned tool calls: {:?}", tool_calls);
597
598 let mut handles = vec![];
599 for tool_call in tool_calls {
600 let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
601 tracing::warn!("Tool {} not found", tool_call.name());
602 continue;
603 };
604 tracing::info!("Calling tool `{}`", tool_call.name());
605
606 let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
608
609 invoke_hooks!(BeforeTool, self, &tool_call);
610
611 let tool_span = tracing::info_span!(
612 "tool",
613 "otel.name" = format!("tool.{}", tool.name().as_ref())
614 );
615
616 let handle_tool_call = tool_call.clone();
617 let handle = tokio::spawn(async move {
618 let handle_tool_call = handle_tool_call;
619 let output = tool.invoke(&*context, &handle_tool_call)
620 .await
621 .map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
622
623 tracing::debug!(output = output.to_string(), args = ?handle_tool_call.args(), tool_name = tool.name().as_ref(), "Completed tool call");
624
625 Ok(output)
626 }.instrument(tool_span.or_current()));
627
628 handles.push((handle, tool_call));
629 }
630
631 for (handle, tool_call) in handles {
632 let mut output = handle
633 .await
634 .map_err(|err| AgentError::ToolFailedToJoin(tool_call.name().to_string(), err))?;
635
636 invoke_hooks!(AfterTool, self, &tool_call, &mut output);
637
638 if let Err(error) = output {
639 let stop = self.tool_calls_over_limit(tool_call);
640 if stop {
641 tracing::error!(
642 ?error,
643 "Tool call failed, retry limit reached, stopping agent: {error}",
644 );
645 } else {
646 tracing::warn!(
647 ?error,
648 tool_call = ?tool_call,
649 "Tool call failed, retrying",
650 );
651 }
652 self.add_message(ChatMessage::ToolOutput(
653 tool_call.clone(),
654 ToolOutput::fail(error.to_string()),
655 ))
656 .await?;
657 if stop {
658 self.stop(StopReason::ToolCallsOverLimit(tool_call.to_owned()))
659 .await;
660 return Err(error.into());
661 }
662 continue;
663 }
664
665 let output = output?;
666 self.handle_control_tools(tool_call, &output).await;
667
668 if !output.is_feedback_required() {
672 self.add_message(ChatMessage::ToolOutput(tool_call.to_owned(), output))
673 .await?;
674 }
675 }
676
677 Ok(())
678 }
679
680 fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
681 self.hooks
682 .iter()
683 .filter(|h| hook_type == (*h).into())
684 .collect()
685 }
686
687 fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
688 self.tools
689 .iter()
690 .find(|tool| tool.name() == tool_name)
691 .cloned()
692 }
693
694 async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
696 match output {
697 ToolOutput::Stop(maybe_message) => {
698 tracing::warn!("Stop tool called, stopping agent");
699 self.stop(StopReason::RequestedByTool(
700 tool_call.clone(),
701 maybe_message.clone(),
702 ))
703 .await;
704 }
705 ToolOutput::FeedbackRequired(maybe_payload) => {
706 tracing::warn!("Feedback required, stopping agent");
707 self.stop(StopReason::FeedbackRequired {
708 tool_call: tool_call.clone(),
709 payload: maybe_payload.clone(),
710 })
711 .await;
712 }
713 ToolOutput::AgentFailed(output) => {
714 tracing::warn!("Agent failed, stopping agent");
715 self.stop(StopReason::AgentFailed(output.clone())).await;
716 }
717 _ => (),
718 }
719 }
720
721 fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
722 let mut s = DefaultHasher::new();
723 tool_call.hash(&mut s);
724 let hash = s.finish();
725
726 if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
727 let val = *retries >= self.tool_retry_limit;
728 *retries += 1;
729 val
730 } else {
731 self.tool_retries_counter.insert(hash, 1);
732 false
733 }
734 }
735
736 #[tracing::instrument(skip_all, fields(message = message.to_string()))]
747 pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
748 invoke_hooks!(OnNewMessage, self, &mut message);
749
750 self.context
751 .add_message(message)
752 .await
753 .map_err(AgentError::MessageHistoryError)?;
754 Ok(())
755 }
756
757 pub async fn stop(&mut self, reason: impl Into<StopReason>) {
759 if self.state.is_stopped() {
760 return;
761 }
762 let reason = reason.into();
763 invoke_hooks!(OnStop, self, reason.clone(), None);
764
765 self.state = state::State::Stopped(reason);
766 }
767
768 pub async fn stop_with_error(&mut self, error: &AgentError) {
769 if self.state.is_stopped() {
770 return;
771 }
772 invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
773
774 self.state = state::State::Stopped(StopReason::Error);
775 }
776
777 pub fn context(&self) -> &dyn AgentContext {
779 &self.context
780 }
781
782 pub fn is_running(&self) -> bool {
784 self.state.is_running()
785 }
786
787 pub fn is_stopped(&self) -> bool {
789 self.state.is_stopped()
790 }
791
792 pub fn is_pending(&self) -> bool {
794 self.state.is_pending()
795 }
796
797 pub fn tools(&self) -> &HashSet<Box<dyn Tool>> {
799 &self.tools
800 }
801
802 pub fn state(&self) -> &state::State {
803 &self.state
804 }
805
806 pub fn stop_reason(&self) -> Option<&StopReason> {
807 self.state.stop_reason()
808 }
809
810 async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
811 for toolbox in &self.toolboxes {
812 let tools = toolbox
813 .available_tools()
814 .await
815 .map_err(AgentError::ToolBoxFailedToLoad)?;
816 self.tools.extend(tools);
817 }
818
819 Ok(())
820 }
821}
822
823fn maybe_tool_call_without_output(messages: &[ChatMessage]) -> Option<&ChatMessage> {
826 for message in messages.iter().rev() {
827 if let ChatMessage::ToolOutput(..) = message {
828 return None;
829 }
830
831 if let ChatMessage::Assistant(.., Some(tool_calls)) = message
832 && !tool_calls.is_empty()
833 {
834 return Some(message);
835 }
836 }
837
838 None
839}
840
841#[cfg(test)]
842mod tests {
843
844 use serde::ser::Error;
845 use swiftide_core::ToolFeedback;
846 use swiftide_core::chat_completion::errors::ToolError;
847 use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
848 use swiftide_core::test_utils::MockChatCompletion;
849
850 use super::*;
851 use crate::{
852 State, assistant, chat_request, chat_response, summary, system, tool_failed, tool_output,
853 user,
854 };
855
856 use crate::test_utils::{MockHook, MockTool};
857
858 #[test_log::test(tokio::test)]
859 async fn test_agent_builder_defaults() {
860 let mock_llm = MockChatCompletion::new();
862
863 let agent = Agent::builder().llm(&mock_llm).build().unwrap();
865
866 assert!(agent.find_tool_by_name("stop").is_some());
870
871 let agent = Agent::builder()
873 .tools([Stop::default(), Stop::default()])
874 .llm(&mock_llm)
875 .build()
876 .unwrap();
877
878 assert_eq!(agent.tools.len(), 1);
879
880 let agent = Agent::builder()
882 .tools([MockTool::new("mock_tool")])
883 .llm(&mock_llm)
884 .build()
885 .unwrap();
886
887 assert_eq!(agent.tools.len(), 2);
888 assert!(agent.find_tool_by_name("mock_tool").is_some());
889 assert!(agent.find_tool_by_name("stop").is_some());
890
891 assert!(agent.context().history().await.unwrap().is_empty());
892 }
893
894 #[test_log::test(tokio::test)]
895 async fn test_agent_tool_calling_loop() {
896 let prompt = "Write a poem";
897 let mock_llm = MockChatCompletion::new();
898 let mock_tool = MockTool::new("mock_tool");
899
900 let chat_request = chat_request! {
901 user!("Write a poem");
902
903 tools = [mock_tool.clone()]
904 };
905
906 let mock_tool_response = chat_response! {
907 "Roses are red";
908 tool_calls = ["mock_tool"]
909
910 };
911
912 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
913
914 let chat_request = chat_request! {
915 user!("Write a poem"),
916 assistant!("Roses are red", ["mock_tool"]),
917 tool_output!("mock_tool", "Great!");
918
919 tools = [mock_tool.clone()]
920 };
921
922 let stop_response = chat_response! {
923 "Roses are red";
924 tool_calls = ["stop"]
925 };
926
927 mock_llm.expect_complete(chat_request, Ok(stop_response));
928 mock_tool.expect_invoke_ok("Great!".into(), None);
929
930 let mut agent = Agent::builder()
931 .tools([mock_tool])
932 .llm(&mock_llm)
933 .no_system_prompt()
934 .build()
935 .unwrap();
936
937 agent.query(prompt).await.unwrap();
938 }
939
940 #[test_log::test(tokio::test)]
941 async fn test_agent_tool_run_once() {
942 let prompt = "Write a poem";
943 let mock_llm = MockChatCompletion::new();
944 let mock_tool = MockTool::default();
945
946 let chat_request = chat_request! {
947 system!("My system prompt"),
948 user!("Write a poem");
949
950 tools = [mock_tool.clone()]
951 };
952
953 let mock_tool_response = chat_response! {
954 "Roses are red";
955 tool_calls = ["mock_tool"]
956
957 };
958
959 mock_tool.expect_invoke_ok("Great!".into(), None);
960 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
961
962 let mut agent = Agent::builder()
963 .tools([mock_tool])
964 .system_prompt("My system prompt")
965 .llm(&mock_llm)
966 .build()
967 .unwrap();
968
969 agent.query_once(prompt).await.unwrap();
970 }
971
972 #[test_log::test(tokio::test)]
973 async fn test_agent_tool_via_toolbox_run_once() {
974 let prompt = "Write a poem";
975 let mock_llm = MockChatCompletion::new();
976 let mock_tool = MockTool::default();
977
978 let chat_request = chat_request! {
979 system!("My system prompt"),
980 user!("Write a poem");
981
982 tools = [mock_tool.clone()]
983 };
984
985 let mock_tool_response = chat_response! {
986 "Roses are red";
987 tool_calls = ["mock_tool"]
988
989 };
990
991 mock_tool.expect_invoke_ok("Great!".into(), None);
992 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
993
994 let mut agent = Agent::builder()
995 .add_toolbox(vec![mock_tool.boxed()])
996 .system_prompt("My system prompt")
997 .llm(&mock_llm)
998 .build()
999 .unwrap();
1000
1001 agent.query_once(prompt).await.unwrap();
1002 }
1003
1004 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1005 async fn test_multiple_tool_calls() {
1006 let prompt = "Write a poem";
1007 let mock_llm = MockChatCompletion::new();
1008 let mock_tool = MockTool::new("mock_tool1");
1009 let mock_tool2 = MockTool::new("mock_tool2");
1010
1011 let chat_request = chat_request! {
1012 system!("My system prompt"),
1013 user!("Write a poem");
1014
1015
1016
1017 tools = [mock_tool.clone(), mock_tool2.clone()]
1018 };
1019
1020 let mock_tool_response = chat_response! {
1021 "Roses are red";
1022
1023 tool_calls = ["mock_tool1", "mock_tool2"]
1024
1025 };
1026
1027 dbg!(&chat_request);
1028 mock_tool.expect_invoke_ok("Great!".into(), None);
1029 mock_tool2.expect_invoke_ok("Great!".into(), None);
1030 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1031
1032 let chat_request = chat_request! {
1033 system!("My system prompt"),
1034 user!("Write a poem"),
1035 assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
1036 tool_output!("mock_tool1", "Great!"),
1037 tool_output!("mock_tool2", "Great!");
1038
1039 tools = [mock_tool.clone(), mock_tool2.clone()]
1040 };
1041
1042 let mock_tool_response = chat_response! {
1043 "Ok!";
1044
1045 tool_calls = ["stop"]
1046
1047 };
1048
1049 mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
1050
1051 let mut agent = Agent::builder()
1052 .tools([mock_tool, mock_tool2])
1053 .system_prompt("My system prompt")
1054 .llm(&mock_llm)
1055 .build()
1056 .unwrap();
1057
1058 agent.query(prompt).await.unwrap();
1059 }
1060
1061 #[test_log::test(tokio::test)]
1062 async fn test_agent_state_machine() {
1063 let prompt = "Write a poem";
1064 let mock_llm = MockChatCompletion::new();
1065
1066 let chat_request = chat_request! {
1067 user!("Write a poem");
1068 tools = []
1069 };
1070 let mock_tool_response = chat_response! {
1071 "Roses are red";
1072 tool_calls = []
1073 };
1074
1075 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1076 let mut agent = Agent::builder()
1077 .llm(&mock_llm)
1078 .no_system_prompt()
1079 .build()
1080 .unwrap();
1081
1082 assert!(agent.state.is_pending());
1084 agent.query_once(prompt).await.unwrap();
1085
1086 assert!(agent.state.is_stopped());
1088 }
1089
1090 #[test_log::test(tokio::test)]
1091 async fn test_summary() {
1092 let prompt = "Write a poem";
1093 let mock_llm = MockChatCompletion::new();
1094
1095 let mock_tool_response = chat_response! {
1096 "Roses are red";
1097 tool_calls = []
1098
1099 };
1100
1101 let expected_chat_request = chat_request! {
1102 system!("My system prompt"),
1103 user!("Write a poem");
1104
1105 tools = []
1106 };
1107
1108 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1109
1110 let mut agent = Agent::builder()
1111 .system_prompt("My system prompt")
1112 .llm(&mock_llm)
1113 .build()
1114 .unwrap();
1115
1116 agent.query_once(prompt).await.unwrap();
1117
1118 agent
1119 .context
1120 .add_message(ChatMessage::new_summary("Summary"))
1121 .await
1122 .unwrap();
1123
1124 let expected_chat_request = chat_request! {
1125 system!("My system prompt"),
1126 summary!("Summary"),
1127 user!("Write another poem");
1128 tools = []
1129 };
1130 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1131
1132 agent.query_once("Write another poem").await.unwrap();
1133
1134 agent
1135 .context
1136 .add_message(ChatMessage::new_summary("Summary 2"))
1137 .await
1138 .unwrap();
1139
1140 let expected_chat_request = chat_request! {
1141 system!("My system prompt"),
1142 summary!("Summary 2"),
1143 user!("Write a third poem");
1144 tools = []
1145 };
1146 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));
1147
1148 agent.query_once("Write a third poem").await.unwrap();
1149 }
1150
1151 #[test_log::test(tokio::test)]
1152 async fn test_agent_hooks() {
1153 let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
1154 let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
1155 let mock_before_completion = MockHook::new("before_completion")
1156 .expect_calls(2)
1157 .to_owned();
1158 let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
1159 let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
1160 let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
1161 let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();
1162
1163 let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
1165 let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();
1166
1167 let prompt = "Write a poem";
1168 let mock_llm = MockChatCompletion::new();
1169 let mock_tool = MockTool::default();
1170
1171 let chat_request = chat_request! {
1172 user!("Write a poem");
1173
1174 tools = [mock_tool.clone()]
1175 };
1176
1177 let mock_tool_response = chat_response! {
1178 "Roses are red";
1179 tool_calls = ["mock_tool"]
1180
1181 };
1182
1183 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1184
1185 let chat_request = chat_request! {
1186 user!("Write a poem"),
1187 assistant!("Roses are red", ["mock_tool"]),
1188 tool_output!("mock_tool", "Great!");
1189
1190 tools = [mock_tool.clone()]
1191 };
1192
1193 let stop_response = chat_response! {
1194 "Roses are red";
1195 tool_calls = ["stop"]
1196 };
1197
1198 mock_llm.expect_complete(chat_request, Ok(stop_response));
1199 mock_tool.expect_invoke_ok("Great!".into(), None);
1200
1201 let mut agent = Agent::builder()
1202 .tools([mock_tool])
1203 .llm(&mock_llm)
1204 .no_system_prompt()
1205 .before_all(mock_before_all.hook_fn())
1206 .on_start(mock_on_start_fn.on_start_fn())
1207 .before_completion(mock_before_completion.before_completion_fn())
1208 .before_tool(mock_before_tool.before_tool_fn())
1209 .after_completion(mock_after_completion.after_completion_fn())
1210 .after_tool(mock_after_tool.after_tool_fn())
1211 .after_each(mock_after_each.hook_fn())
1212 .on_new_message(mock_on_message.message_hook_fn())
1213 .on_stop(mock_on_stop.stop_hook_fn())
1214 .build()
1215 .unwrap();
1216
1217 agent.query(prompt).await.unwrap();
1218 }
1219
1220 #[test_log::test(tokio::test)]
1221 async fn test_agent_loop_limit() {
1222 let prompt = "Generate content"; let mock_llm = MockChatCompletion::new();
1224 let mock_tool = MockTool::new("mock_tool");
1225
1226 let chat_request = chat_request! {
1227 user!(prompt);
1228 tools = [mock_tool.clone()]
1229 };
1230 mock_tool.expect_invoke_ok("Great!".into(), None);
1231
1232 let mock_tool_response = chat_response! {
1233 "Some response";
1234 tool_calls = ["mock_tool"]
1235 };
1236
1237 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));
1239
1240 let stop_response = chat_response! {
1242 "Final response";
1243 tool_calls = ["stop"]
1244 };
1245
1246 mock_llm.expect_complete(chat_request, Ok(stop_response));
1247
1248 let mut agent = Agent::builder()
1249 .tools([mock_tool])
1250 .llm(&mock_llm)
1251 .no_system_prompt()
1252 .limit(1) .build()
1254 .unwrap();
1255
1256 agent.query(prompt).await.unwrap();
1258
1259 let remaining = mock_llm.expectations.lock().unwrap().pop();
1261 assert!(remaining.is_some());
1262
1263 assert!(agent.is_stopped());
1265 }
1266
1267 #[test_log::test(tokio::test)]
1268 async fn test_tool_retry_mechanism() {
1269 let prompt = "Execute my tool";
1270 let mock_llm = MockChatCompletion::new();
1271 let mock_tool = MockTool::new("retry_tool");
1272
1273 mock_tool.expect_invoke(
1276 Err(ToolError::WrongArguments(serde_json::Error::custom(
1277 "missing `query`",
1278 ))),
1279 None,
1280 );
1281 mock_tool.expect_invoke(
1282 Err(ToolError::WrongArguments(serde_json::Error::custom(
1283 "missing `query`",
1284 ))),
1285 None,
1286 );
1287
1288 let chat_request = chat_request! {
1289 user!(prompt);
1290 tools = [mock_tool.clone()]
1291 };
1292 let retry_response = chat_response! {
1293 "First failing attempt";
1294 tool_calls = ["retry_tool"]
1295 };
1296 mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
1297
1298 let chat_request = chat_request! {
1299 user!(prompt),
1300 assistant!("First failing attempt", ["retry_tool"]),
1301 tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
1302
1303 tools = [mock_tool.clone()]
1304 };
1305 let will_fail_response = chat_response! {
1306 "Finished execution";
1307 tool_calls = ["retry_tool"]
1308 };
1309 mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
1310
1311 let mut agent = Agent::builder()
1312 .tools([mock_tool])
1313 .llm(&mock_llm)
1314 .no_system_prompt()
1315 .tool_retry_limit(1) .build()
1317 .unwrap();
1318
1319 let result = agent.query(prompt).await;
1321
1322 assert!(result.is_err());
1323 assert!(result.unwrap_err().to_string().contains("missing `query`"));
1324 assert!(agent.is_stopped());
1325 }
1326
1327 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1328 async fn test_streaming() {
1329 let prompt = "Generate content"; let mock_llm = MockChatCompletion::new();
1331 let on_stream_fn = MockHook::new("on_stream").expect_calls(3).to_owned();
1332
1333 let chat_request = chat_request! {
1334 user!(prompt);
1335
1336 tools = []
1337 };
1338
1339 let response = chat_response! {
1340 "one two three";
1341 tool_calls = ["stop"]
1342 };
1343
1344 mock_llm.expect_complete(chat_request, Ok(response));
1346
1347 let mut agent = Agent::builder()
1348 .llm(&mock_llm)
1349 .on_stream(on_stream_fn.on_stream_fn())
1350 .no_system_prompt()
1351 .build()
1352 .unwrap();
1353
1354 agent.query(prompt).await.unwrap();
1356
1357 tracing::debug!("Agent finished running");
1358
1359 assert!(agent.is_stopped());
1361 }
1362
1363 #[test_log::test(tokio::test)]
1364 async fn test_recovering_agent_existing_history() {
1365 let prompt = "Write a poem";
1367 let mock_llm = MockChatCompletion::new();
1368 let mock_tool = MockTool::new("mock_tool");
1369
1370 let chat_request = chat_request! {
1371 user!("Write a poem");
1372
1373 tools = [mock_tool.clone()]
1374 };
1375
1376 let mock_tool_response = chat_response! {
1377 "Roses are red";
1378 tool_calls = ["mock_tool"]
1379
1380 };
1381
1382 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1383
1384 let chat_request = chat_request! {
1385 user!("Write a poem"),
1386 assistant!("Roses are red", ["mock_tool"]),
1387 tool_output!("mock_tool", "Great!");
1388
1389 tools = [mock_tool.clone()]
1390 };
1391
1392 let stop_response = chat_response! {
1393 "Roses are red";
1394 tool_calls = ["stop"]
1395 };
1396
1397 mock_llm.expect_complete(chat_request, Ok(stop_response));
1398 mock_tool.expect_invoke_ok("Great!".into(), None);
1399
1400 let mut agent = Agent::builder()
1401 .tools([mock_tool.clone()])
1402 .llm(&mock_llm)
1403 .no_system_prompt()
1404 .build()
1405 .unwrap();
1406
1407 agent.query(prompt).await.unwrap();
1408
1409 let history = agent.history().await.unwrap();
1411
1412 let serialized = serde_json::to_string(&history).unwrap();
1414
1415 let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();
1417
1418 let context = DefaultContext::default()
1420 .with_existing_messages(history)
1421 .await
1422 .unwrap()
1423 .to_owned();
1424
1425 let stop_output = ToolOutput::stop();
1426 let expected_chat_request = chat_request! {
1427 user!("Write a poem"),
1428 assistant!("Roses are red", ["mock_tool"]),
1429 tool_output!("mock_tool", "Great!"),
1430 assistant!("Roses are red", ["stop"]),
1431 tool_output!("stop", stop_output),
1432 user!("Try again!");
1433
1434 tools = [mock_tool.clone()]
1435 };
1436
1437 let stop_response = chat_response! {
1438 "Really stopping now";
1439 tool_calls = ["stop"]
1440 };
1441
1442 mock_llm.expect_complete(expected_chat_request, Ok(stop_response));
1443
1444 let mut agent = Agent::builder()
1445 .context(context)
1446 .tools([mock_tool])
1447 .llm(&mock_llm)
1448 .no_system_prompt()
1449 .build()
1450 .unwrap();
1451
1452 agent.query_once("Try again!").await.unwrap();
1453 }
1454
1455 #[test_log::test(tokio::test)]
1456 async fn test_agent_with_approval_required_tool() {
1457 use super::*;
1458 use crate::tools::control::ApprovalRequired;
1459 use crate::{assistant, chat_request, chat_response, user};
1460 use swiftide_core::chat_completion::ToolCall;
1461
1462 let mock_tool = MockTool::default();
1464 mock_tool.expect_invoke_ok("Great!".into(), None);
1465
1466 let approval_tool = ApprovalRequired(mock_tool.boxed());
1467
1468 let mock_llm = MockChatCompletion::new();
1470
1471 let chat_req1 = chat_request! {
1472 user!("Request with approval");
1473 tools = [approval_tool.clone()]
1474 };
1475 let chat_resp1 = chat_response! {
1476 "Completion message";
1477 tool_calls = ["mock_tool"]
1478 };
1479 mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1480
1481 let chat_req2 = chat_request! {
1484 user!("Request with approval"),
1485 assistant!("Completion message", ["mock_tool"]),
1486 tool_output!("mock_tool", "Great!");
1487 tools = [approval_tool.clone()]
1489 };
1490 let chat_resp2 = chat_response! {
1491 "Post-feedback message";
1492 tool_calls = ["stop"]
1493 };
1494 mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1495
1496 let mut agent = Agent::builder()
1498 .tools([approval_tool])
1499 .llm(&mock_llm)
1500 .no_system_prompt()
1501 .build()
1502 .unwrap();
1503
1504 agent.query_once("Request with approval").await.unwrap();
1506
1507 assert!(matches!(
1508 agent.state,
1509 crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1510 ));
1511
1512 let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1513 else {
1514 panic!("Expected feedback required");
1515 };
1516
1517 agent
1519 .context
1520 .feedback_received(&tool_call, &ToolFeedback::approved())
1521 .await
1522 .unwrap();
1523
1524 tracing::debug!("running after approval");
1525 agent.run_once().await.unwrap();
1526 assert!(agent.is_stopped());
1527 }
1528
1529 #[test_log::test(tokio::test)]
1530 async fn test_agent_with_approval_required_tool_denied() {
1531 use super::*;
1532 use crate::tools::control::ApprovalRequired;
1533 use crate::{assistant, chat_request, chat_response, user};
1534 use swiftide_core::chat_completion::ToolCall;
1535
1536 let mock_tool = MockTool::default();
1538
1539 let approval_tool = ApprovalRequired(mock_tool.boxed());
1540
1541 let mock_llm = MockChatCompletion::new();
1543
1544 let chat_req1 = chat_request! {
1545 user!("Request with approval");
1546 tools = [approval_tool.clone()]
1547 };
1548 let chat_resp1 = chat_response! {
1549 "Completion message";
1550 tool_calls = ["mock_tool"]
1551 };
1552 mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1553
1554 let chat_req2 = chat_request! {
1557 user!("Request with approval"),
1558 assistant!("Completion message", ["mock_tool"]),
1559 tool_output!("mock_tool", "This tool call was refused");
1560 tools = [approval_tool.clone()]
1562 };
1563 let chat_resp2 = chat_response! {
1564 "Post-feedback message";
1565 tool_calls = ["stop"]
1566 };
1567 mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1568
1569 let mut agent = Agent::builder()
1571 .tools([approval_tool])
1572 .llm(&mock_llm)
1573 .no_system_prompt()
1574 .build()
1575 .unwrap();
1576
1577 agent.query_once("Request with approval").await.unwrap();
1579
1580 assert!(matches!(
1581 agent.state,
1582 crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1583 ));
1584
1585 let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1586 else {
1587 panic!("Expected feedback required");
1588 };
1589
1590 agent
1592 .context
1593 .feedback_received(&tool_call, &ToolFeedback::refused())
1594 .await
1595 .unwrap();
1596
1597 tracing::debug!("running after approval");
1598 agent.run_once().await.unwrap();
1599
1600 let history = agent.context().history().await.unwrap();
1601 history
1602 .iter()
1603 .rfind(|m| {
1604 let ChatMessage::ToolOutput(.., ToolOutput::Text(msg)) = m else {
1605 return false;
1606 };
1607 msg.contains("refused")
1608 })
1609 .expect("Could not find refusal message");
1610
1611 assert!(agent.is_stopped());
1612 }
1613}