1#![allow(dead_code)]
2use crate::{
3 default_context::DefaultContext,
4 errors::AgentError,
5 hooks::{
6 AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7 Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn,
8 },
9 invoke_hooks,
10 state::{self, StopReason},
11 system_prompt::SystemPrompt,
12 tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15 collections::{HashMap, HashSet},
16 hash::{DefaultHasher, Hash as _, Hasher as _},
17 sync::Arc,
18};
19
20use derive_builder::Builder;
21use futures_util::stream::StreamExt;
22use swiftide_core::{
23 AgentContext, ToolBox,
24 chat_completion::{
25 ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
26 },
27 prompt::Prompt,
28};
29use tracing::{Instrument, debug};
30
31#[derive(Clone, Builder)]
43pub struct Agent {
44 #[builder(default, setter(into))]
46 pub(crate) hooks: Vec<Hook>,
47 #[builder(
49 setter(custom),
50 default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
51 )]
52 pub(crate) context: Arc<dyn AgentContext>,
53 #[builder(default = Agent::default_tools(), setter(custom))]
55 pub(crate) tools: HashSet<Box<dyn Tool>>,
56
57 #[builder(default)]
61 pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,
62
63 #[builder(setter(custom))]
65 pub(crate) llm: Box<dyn ChatCompletion>,
66
67 #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
87 pub(crate) system_prompt: Option<Prompt>,
88
89 #[builder(private, default = state::State::default())]
91 pub(crate) state: state::State,
92
93 #[builder(default, setter(strip_option))]
96 pub(crate) limit: Option<usize>,
97
98 #[builder(default = 3)]
110 pub(crate) tool_retry_limit: usize,
111
112 #[builder(default)]
114 pub(crate) streaming: bool,
115
116 #[builder(private, default)]
119 pub(crate) tool_retries_counter: HashMap<u64, usize>,
120
121 #[builder(private, default)]
123 pub(crate) toolbox_tools: HashSet<Box<dyn Tool>>,
124}
125
126impl std::fmt::Debug for Agent {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 f.debug_struct("Agent")
129 .field(
131 "hooks",
132 &self
133 .hooks
134 .iter()
135 .map(std::string::ToString::to_string)
136 .collect::<Vec<_>>(),
137 )
138 .field(
139 "tools",
140 &self
141 .tools
142 .iter()
143 .map(swiftide_core::Tool::name)
144 .collect::<Vec<_>>(),
145 )
146 .field("llm", &"Box<dyn ChatCompletion>")
147 .field("state", &self.state)
148 .finish()
149 }
150}
151
152impl AgentBuilder {
153 pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
155 where
156 Self: Clone,
157 {
158 self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
159 self
160 }
161
162 pub fn no_system_prompt(&mut self) -> &mut Self {
164 self.system_prompt = Some(None);
165
166 self
167 }
168
169 pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
171 let hooks = self.hooks.get_or_insert_with(Vec::new);
172 hooks.push(hook);
173
174 self
175 }
176
177 pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
180 self.add_hook(Hook::BeforeAll(Box::new(hook)))
181 }
182
183 pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
187 self.add_hook(Hook::OnStart(Box::new(hook)))
188 }
189
190 pub fn on_stream(&mut self, hook: impl OnStreamFn + 'static) -> &mut Self {
197 self.streaming = Some(true);
198 self.add_hook(Hook::OnStream(Box::new(hook)))
199 }
200
201 pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
203 self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
204 }
205
206 pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
212 self.add_hook(Hook::AfterTool(Box::new(hook)))
213 }
214
215 pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
217 self.add_hook(Hook::BeforeTool(Box::new(hook)))
218 }
219
220 pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
222 self.add_hook(Hook::AfterCompletion(Box::new(hook)))
223 }
224
225 pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
228 self.add_hook(Hook::AfterEach(Box::new(hook)))
229 }
230
231 pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
234 self.add_hook(Hook::OnNewMessage(Box::new(hook)))
235 }
236
237 pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
238 self.add_hook(Hook::OnStop(Box::new(hook)))
239 }
240
241 pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
243 let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
244
245 self.llm = Some(boxed);
246 self
247 }
248
249 pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
254 where
255 TOOL: Into<Box<dyn Tool>>,
256 {
257 self.tools = Some(
258 tools
259 .into_iter()
260 .map(Into::into)
261 .chain(Agent::default_tools())
262 .collect(),
263 );
264 self
265 }
266
267 pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
273 let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
274 toolboxes.push(Box::new(toolbox));
275
276 self
277 }
278}
279
280impl Agent {
281 pub fn builder() -> AgentBuilder {
283 AgentBuilder::default()
284 .tools(Agent::default_tools())
285 .to_owned()
286 }
287
288 pub fn default_tools() -> HashSet<Box<dyn Tool>> {
291 HashSet::from([Stop::default().boxed()])
292 }
293
294 #[tracing::instrument(skip_all, name = "agent.query")]
297 pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
298 let query = query
299 .into()
300 .render()
301 .map_err(AgentError::FailedToRenderPrompt)?;
302 self.run_agent(Some(query), false).await
303 }
304
305 #[tracing::instrument(skip_all, name = "agent.query_once")]
307 pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
308 let query = query
309 .into()
310 .render()
311 .map_err(AgentError::FailedToRenderPrompt)?;
312 self.run_agent(Some(query), true).await
313 }
314
315 #[tracing::instrument(skip_all, name = "agent.run")]
318 pub async fn run(&mut self) -> Result<(), AgentError> {
319 self.run_agent(None, false).await
320 }
321
322 #[tracing::instrument(skip_all, name = "agent.run_once")]
325 pub async fn run_once(&mut self) -> Result<(), AgentError> {
326 self.run_agent(None, true).await
327 }
328
329 pub async fn history(&self) -> Result<Vec<ChatMessage>, AgentError> {
336 self.context
337 .history()
338 .await
339 .map_err(AgentError::MessageHistoryError)
340 }
341
342 async fn run_agent(
343 &mut self,
344 maybe_query: Option<String>,
345 just_once: bool,
346 ) -> Result<(), AgentError> {
347 if self.state.is_running() {
348 return Err(AgentError::AlreadyRunning);
349 }
350
351 if self.state.is_pending() {
352 if let Some(system_prompt) = &self.system_prompt {
353 self.context
354 .add_messages(vec![ChatMessage::System(
355 system_prompt
356 .render()
357 .map_err(AgentError::FailedToRenderSystemPrompt)?,
358 )])
359 .await
360 .map_err(AgentError::MessageHistoryError)?;
361 }
362
363 invoke_hooks!(BeforeAll, self);
364
365 self.load_toolboxes().await?;
366 }
367
368 invoke_hooks!(OnStart, self);
369
370 self.state = state::State::Running;
371
372 if let Some(query) = maybe_query {
373 self.context
374 .add_message(ChatMessage::User(query))
375 .await
376 .map_err(AgentError::MessageHistoryError)?;
377 }
378
379 let mut loop_counter = 0;
380
381 while let Some(messages) = self
382 .context
383 .next_completion()
384 .await
385 .map_err(AgentError::MessageHistoryError)?
386 {
387 if let Some(limit) = self.limit {
388 if loop_counter >= limit {
389 tracing::warn!("Agent loop limit reached");
390 break;
391 }
392 }
393
394 if let Some(&ChatMessage::Assistant(.., Some(ref tool_calls))) =
397 maybe_tool_call_without_output(&messages)
398 {
399 tracing::debug!("Uncompleted tool calls found; invoking tools");
400 self.invoke_tools(tool_calls).await?;
401 continue;
403 }
404
405 let result = self.run_completions(&messages).await;
406
407 if let Err(err) = result {
408 self.stop_with_error(&err).await;
409 tracing::error!(error = ?err, "Agent stopped with error {err}");
410 return Err(err);
411 }
412
413 if just_once || self.state.is_stopped() {
414 break;
415 }
416 loop_counter += 1;
417 }
418
419 self.stop(StopReason::NoNewMessages).await;
421
422 Ok(())
423 }
424
425 #[tracing::instrument(skip_all, err)]
426 async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<(), AgentError> {
427 debug!(
428 tools = ?self
429 .tools
430 .iter()
431 .map(|t| t.name())
432 .collect::<Vec<_>>()
433 ,
434 "Running completion for agent with {} new messages",
435 messages.len()
436 );
437
438 let mut chat_completion_request = ChatCompletionRequest::builder()
439 .messages(messages)
440 .tools_spec(
441 self.tools
442 .iter()
443 .map(swiftide_core::Tool::tool_spec)
444 .collect::<HashSet<_>>(),
445 )
446 .build()
447 .map_err(AgentError::FailedToBuildRequest)?;
448
449 invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
450
451 debug!(
452 "Calling LLM with the following new messages:\n {}",
453 self.context
454 .current_new_messages()
455 .await
456 .map_err(AgentError::MessageHistoryError)?
457 .iter()
458 .map(ToString::to_string)
459 .collect::<Vec<_>>()
460 .join(",\n")
461 );
462
463 let mut response = if self.streaming {
464 let mut last_response = None;
465 let mut stream = self.llm.complete_stream(&chat_completion_request).await;
466
467 while let Some(response) = stream.next().await {
468 let response = response.map_err(AgentError::CompletionsFailed)?;
469 invoke_hooks!(OnStream, self, &response);
470 last_response = Some(response);
471 }
472 tracing::trace!(?last_response, "Streaming completed");
473 last_response.ok_or(AgentError::EmptyStream)
474 } else {
475 self.llm
476 .complete(&chat_completion_request)
477 .await
478 .map_err(AgentError::CompletionsFailed)
479 }?;
480
481 response
484 .tool_calls
485 .as_deref_mut()
486 .map(ArgPreprocessor::preprocess_tool_calls);
487
488 invoke_hooks!(AfterCompletion, self, &mut response);
489
490 self.add_message(ChatMessage::Assistant(
491 response.message,
492 response.tool_calls.clone(),
493 ))
494 .await?;
495
496 if let Some(tool_calls) = response.tool_calls {
497 self.invoke_tools(&tool_calls).await?;
498 }
499
500 invoke_hooks!(AfterEach, self);
501
502 Ok(())
503 }
504
505 async fn invoke_tools(&mut self, tool_calls: &[ToolCall]) -> Result<(), AgentError> {
506 tracing::debug!("LLM returned tool calls: {:?}", tool_calls);
507
508 let mut handles = vec![];
509 for tool_call in tool_calls {
510 let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
511 tracing::warn!("Tool {} not found", tool_call.name());
512 continue;
513 };
514 tracing::info!("Calling tool `{}`", tool_call.name());
515
516 let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
518
519 invoke_hooks!(BeforeTool, self, &tool_call);
520
521 let tool_span = tracing::info_span!(
522 "tool",
523 "otel.name" = format!("tool.{}", tool.name().as_ref())
524 );
525
526 let handle_tool_call = tool_call.clone();
527 let handle = tokio::spawn(async move {
528 let handle_tool_call = handle_tool_call;
529 let output = tool.invoke(&*context, &handle_tool_call)
530 .await
531 .map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
532
533 tracing::debug!(output = output.to_string(), args = ?handle_tool_call.args(), tool_name = tool.name().as_ref(), "Completed tool call");
534
535 Ok(output)
536 }.instrument(tool_span.or_current()));
537
538 handles.push((handle, tool_call));
539 }
540
541 for (handle, tool_call) in handles {
542 let mut output = handle.await.map_err(AgentError::ToolFailedToJoin)?;
543
544 invoke_hooks!(AfterTool, self, &tool_call, &mut output);
545
546 if let Err(error) = output {
547 let stop = self.tool_calls_over_limit(tool_call);
548 if stop {
549 tracing::error!(
550 ?error,
551 "Tool call failed, retry limit reached, stopping agent: {error}",
552 );
553 } else {
554 tracing::warn!(
555 ?error,
556 tool_call = ?tool_call,
557 "Tool call failed, retrying",
558 );
559 }
560 self.add_message(ChatMessage::ToolOutput(
561 tool_call.clone(),
562 ToolOutput::Fail(error.to_string()),
563 ))
564 .await?;
565 if stop {
566 self.stop(StopReason::ToolCallsOverLimit(tool_call.to_owned()))
567 .await;
568 return Err(error.into());
569 }
570 continue;
571 }
572
573 let output = output?;
574 self.handle_control_tools(tool_call, &output).await;
575
576 if !output.is_feedback_required() {
580 self.add_message(ChatMessage::ToolOutput(tool_call.to_owned(), output))
581 .await?;
582 }
583 }
584
585 Ok(())
586 }
587
588 fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
589 self.hooks
590 .iter()
591 .filter(|h| hook_type == (*h).into())
592 .collect()
593 }
594
595 fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
596 self.tools
597 .iter()
598 .find(|tool| tool.name() == tool_name)
599 .cloned()
600 }
601
602 async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
604 match output {
605 ToolOutput::Stop => {
606 tracing::warn!("Stop tool called, stopping agent");
607 self.stop(StopReason::RequestedByTool(tool_call.clone()))
608 .await;
609 }
610
611 ToolOutput::FeedbackRequired(maybe_payload) => {
612 tracing::warn!("Feedback required, stopping agent");
613 self.stop(StopReason::FeedbackRequired {
614 tool_call: tool_call.clone(),
615 payload: maybe_payload.clone(),
616 })
617 .await;
618 }
619 _ => (),
620 }
621 }
622
623 fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
624 let mut s = DefaultHasher::new();
625 tool_call.hash(&mut s);
626 let hash = s.finish();
627
628 if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
629 let val = *retries >= self.tool_retry_limit;
630 *retries += 1;
631 val
632 } else {
633 self.tool_retries_counter.insert(hash, 1);
634 false
635 }
636 }
637
638 #[tracing::instrument(skip_all, fields(message = message.to_string()))]
644 pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
645 invoke_hooks!(OnNewMessage, self, &mut message);
646
647 self.context
648 .add_message(message)
649 .await
650 .map_err(AgentError::MessageHistoryError)?;
651 Ok(())
652 }
653
654 pub async fn stop(&mut self, reason: impl Into<StopReason>) {
656 if self.state.is_stopped() {
657 return;
658 }
659 let reason = reason.into();
660 invoke_hooks!(OnStop, self, reason.clone(), None);
661
662 self.state = state::State::Stopped(reason);
663 }
664
665 pub async fn stop_with_error(&mut self, error: &AgentError) {
666 if self.state.is_stopped() {
667 return;
668 }
669 invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
670
671 self.state = state::State::Stopped(StopReason::Error);
672 }
673
674 pub fn context(&self) -> &dyn AgentContext {
676 &self.context
677 }
678
679 pub fn is_running(&self) -> bool {
681 self.state.is_running()
682 }
683
684 pub fn is_stopped(&self) -> bool {
686 self.state.is_stopped()
687 }
688
689 pub fn is_pending(&self) -> bool {
691 self.state.is_pending()
692 }
693
694 pub fn tools(&self) -> &HashSet<Box<dyn Tool>> {
696 &self.tools
697 }
698
699 pub fn state(&self) -> &state::State {
700 &self.state
701 }
702
703 pub fn stop_reason(&self) -> Option<&StopReason> {
704 self.state.stop_reason()
705 }
706
707 async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
708 for toolbox in &self.toolboxes {
709 let tools = toolbox
710 .available_tools()
711 .await
712 .map_err(AgentError::ToolBoxFailedToLoad)?;
713 self.toolbox_tools.extend(tools);
714 }
715
716 self.tools.extend(self.toolbox_tools.clone());
717
718 Ok(())
719 }
720}
721
722fn maybe_tool_call_without_output(messages: &[ChatMessage]) -> Option<&ChatMessage> {
725 for message in messages.iter().rev() {
726 if let ChatMessage::ToolOutput(..) = message {
727 return None;
728 }
729
730 if let ChatMessage::Assistant(.., Some(tool_calls)) = message {
731 if !tool_calls.is_empty() {
732 return Some(message);
733 }
734 }
735 }
736
737 None
738}
739
740#[cfg(test)]
741mod tests {
742
743 use serde::ser::Error;
744 use swiftide_core::ToolFeedback;
745 use swiftide_core::chat_completion::errors::ToolError;
746 use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
747 use swiftide_core::test_utils::MockChatCompletion;
748
749 use super::*;
750 use crate::{
751 State, assistant, chat_request, chat_response, summary, system, tool_failed, tool_output,
752 user,
753 };
754
755 use crate::test_utils::{MockHook, MockTool};
756
757 #[test_log::test(tokio::test)]
758 async fn test_agent_builder_defaults() {
759 let mock_llm = MockChatCompletion::new();
761
762 let agent = Agent::builder().llm(&mock_llm).build().unwrap();
764
765 assert!(agent.find_tool_by_name("stop").is_some());
769
770 let agent = Agent::builder()
772 .tools([Stop::default(), Stop::default()])
773 .llm(&mock_llm)
774 .build()
775 .unwrap();
776
777 assert_eq!(agent.tools.len(), 1);
778
779 let agent = Agent::builder()
781 .tools([MockTool::new("mock_tool")])
782 .llm(&mock_llm)
783 .build()
784 .unwrap();
785
786 assert_eq!(agent.tools.len(), 2);
787 assert!(agent.find_tool_by_name("mock_tool").is_some());
788 assert!(agent.find_tool_by_name("stop").is_some());
789
790 assert!(agent.context().history().await.unwrap().is_empty());
791 }
792
793 #[test_log::test(tokio::test)]
794 async fn test_agent_tool_calling_loop() {
795 let prompt = "Write a poem";
796 let mock_llm = MockChatCompletion::new();
797 let mock_tool = MockTool::new("mock_tool");
798
799 let chat_request = chat_request! {
800 user!("Write a poem");
801
802 tools = [mock_tool.clone()]
803 };
804
805 let mock_tool_response = chat_response! {
806 "Roses are red";
807 tool_calls = ["mock_tool"]
808
809 };
810
811 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
812
813 let chat_request = chat_request! {
814 user!("Write a poem"),
815 assistant!("Roses are red", ["mock_tool"]),
816 tool_output!("mock_tool", "Great!");
817
818 tools = [mock_tool.clone()]
819 };
820
821 let stop_response = chat_response! {
822 "Roses are red";
823 tool_calls = ["stop"]
824 };
825
826 mock_llm.expect_complete(chat_request, Ok(stop_response));
827 mock_tool.expect_invoke_ok("Great!".into(), None);
828
829 let mut agent = Agent::builder()
830 .tools([mock_tool])
831 .llm(&mock_llm)
832 .no_system_prompt()
833 .build()
834 .unwrap();
835
836 agent.query(prompt).await.unwrap();
837 }
838
839 #[test_log::test(tokio::test)]
840 async fn test_agent_tool_run_once() {
841 let prompt = "Write a poem";
842 let mock_llm = MockChatCompletion::new();
843 let mock_tool = MockTool::default();
844
845 let chat_request = chat_request! {
846 system!("My system prompt"),
847 user!("Write a poem");
848
849 tools = [mock_tool.clone()]
850 };
851
852 let mock_tool_response = chat_response! {
853 "Roses are red";
854 tool_calls = ["mock_tool"]
855
856 };
857
858 mock_tool.expect_invoke_ok("Great!".into(), None);
859 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
860
861 let mut agent = Agent::builder()
862 .tools([mock_tool])
863 .system_prompt("My system prompt")
864 .llm(&mock_llm)
865 .build()
866 .unwrap();
867
868 agent.query_once(prompt).await.unwrap();
869 }
870
871 #[test_log::test(tokio::test)]
872 async fn test_agent_tool_via_toolbox_run_once() {
873 let prompt = "Write a poem";
874 let mock_llm = MockChatCompletion::new();
875 let mock_tool = MockTool::default();
876
877 let chat_request = chat_request! {
878 system!("My system prompt"),
879 user!("Write a poem");
880
881 tools = [mock_tool.clone()]
882 };
883
884 let mock_tool_response = chat_response! {
885 "Roses are red";
886 tool_calls = ["mock_tool"]
887
888 };
889
890 mock_tool.expect_invoke_ok("Great!".into(), None);
891 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
892
893 let mut agent = Agent::builder()
894 .add_toolbox(vec![mock_tool.boxed()])
895 .system_prompt("My system prompt")
896 .llm(&mock_llm)
897 .build()
898 .unwrap();
899
900 agent.query_once(prompt).await.unwrap();
901 }
902
903 #[test_log::test(tokio::test(flavor = "multi_thread"))]
904 async fn test_multiple_tool_calls() {
905 let prompt = "Write a poem";
906 let mock_llm = MockChatCompletion::new();
907 let mock_tool = MockTool::new("mock_tool1");
908 let mock_tool2 = MockTool::new("mock_tool2");
909
910 let chat_request = chat_request! {
911 system!("My system prompt"),
912 user!("Write a poem");
913
914
915
916 tools = [mock_tool.clone(), mock_tool2.clone()]
917 };
918
919 let mock_tool_response = chat_response! {
920 "Roses are red";
921
922 tool_calls = ["mock_tool1", "mock_tool2"]
923
924 };
925
926 dbg!(&chat_request);
927 mock_tool.expect_invoke_ok("Great!".into(), None);
928 mock_tool2.expect_invoke_ok("Great!".into(), None);
929 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
930
931 let chat_request = chat_request! {
932 system!("My system prompt"),
933 user!("Write a poem"),
934 assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
935 tool_output!("mock_tool1", "Great!"),
936 tool_output!("mock_tool2", "Great!");
937
938 tools = [mock_tool.clone(), mock_tool2.clone()]
939 };
940
941 let mock_tool_response = chat_response! {
942 "Ok!";
943
944 tool_calls = ["stop"]
945
946 };
947
948 mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
949
950 let mut agent = Agent::builder()
951 .tools([mock_tool, mock_tool2])
952 .system_prompt("My system prompt")
953 .llm(&mock_llm)
954 .build()
955 .unwrap();
956
957 agent.query(prompt).await.unwrap();
958 }
959
960 #[test_log::test(tokio::test)]
961 async fn test_agent_state_machine() {
962 let prompt = "Write a poem";
963 let mock_llm = MockChatCompletion::new();
964
965 let chat_request = chat_request! {
966 user!("Write a poem");
967 tools = []
968 };
969 let mock_tool_response = chat_response! {
970 "Roses are red";
971 tool_calls = []
972 };
973
974 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
975 let mut agent = Agent::builder()
976 .llm(&mock_llm)
977 .no_system_prompt()
978 .build()
979 .unwrap();
980
981 assert!(agent.state.is_pending());
983 agent.query_once(prompt).await.unwrap();
984
985 assert!(agent.state.is_stopped());
987 }
988
989 #[test_log::test(tokio::test)]
990 async fn test_summary() {
991 let prompt = "Write a poem";
992 let mock_llm = MockChatCompletion::new();
993
994 let mock_tool_response = chat_response! {
995 "Roses are red";
996 tool_calls = []
997
998 };
999
1000 let expected_chat_request = chat_request! {
1001 system!("My system prompt"),
1002 user!("Write a poem");
1003
1004 tools = []
1005 };
1006
1007 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1008
1009 let mut agent = Agent::builder()
1010 .system_prompt("My system prompt")
1011 .llm(&mock_llm)
1012 .build()
1013 .unwrap();
1014
1015 agent.query_once(prompt).await.unwrap();
1016
1017 agent
1018 .context
1019 .add_message(ChatMessage::new_summary("Summary"))
1020 .await
1021 .unwrap();
1022
1023 let expected_chat_request = chat_request! {
1024 system!("My system prompt"),
1025 summary!("Summary"),
1026 user!("Write another poem");
1027 tools = []
1028 };
1029 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1030
1031 agent.query_once("Write another poem").await.unwrap();
1032
1033 agent
1034 .context
1035 .add_message(ChatMessage::new_summary("Summary 2"))
1036 .await
1037 .unwrap();
1038
1039 let expected_chat_request = chat_request! {
1040 system!("My system prompt"),
1041 summary!("Summary 2"),
1042 user!("Write a third poem");
1043 tools = []
1044 };
1045 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));
1046
1047 agent.query_once("Write a third poem").await.unwrap();
1048 }
1049
1050 #[test_log::test(tokio::test)]
1051 async fn test_agent_hooks() {
1052 let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
1053 let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
1054 let mock_before_completion = MockHook::new("before_completion")
1055 .expect_calls(2)
1056 .to_owned();
1057 let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
1058 let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
1059 let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
1060 let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();
1061
1062 let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
1064 let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();
1065
1066 let prompt = "Write a poem";
1067 let mock_llm = MockChatCompletion::new();
1068 let mock_tool = MockTool::default();
1069
1070 let chat_request = chat_request! {
1071 user!("Write a poem");
1072
1073 tools = [mock_tool.clone()]
1074 };
1075
1076 let mock_tool_response = chat_response! {
1077 "Roses are red";
1078 tool_calls = ["mock_tool"]
1079
1080 };
1081
1082 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1083
1084 let chat_request = chat_request! {
1085 user!("Write a poem"),
1086 assistant!("Roses are red", ["mock_tool"]),
1087 tool_output!("mock_tool", "Great!");
1088
1089 tools = [mock_tool.clone()]
1090 };
1091
1092 let stop_response = chat_response! {
1093 "Roses are red";
1094 tool_calls = ["stop"]
1095 };
1096
1097 mock_llm.expect_complete(chat_request, Ok(stop_response));
1098 mock_tool.expect_invoke_ok("Great!".into(), None);
1099
1100 let mut agent = Agent::builder()
1101 .tools([mock_tool])
1102 .llm(&mock_llm)
1103 .no_system_prompt()
1104 .before_all(mock_before_all.hook_fn())
1105 .on_start(mock_on_start_fn.on_start_fn())
1106 .before_completion(mock_before_completion.before_completion_fn())
1107 .before_tool(mock_before_tool.before_tool_fn())
1108 .after_completion(mock_after_completion.after_completion_fn())
1109 .after_tool(mock_after_tool.after_tool_fn())
1110 .after_each(mock_after_each.hook_fn())
1111 .on_new_message(mock_on_message.message_hook_fn())
1112 .on_stop(mock_on_stop.stop_hook_fn())
1113 .build()
1114 .unwrap();
1115
1116 agent.query(prompt).await.unwrap();
1117 }
1118
1119 #[test_log::test(tokio::test)]
1120 async fn test_agent_loop_limit() {
1121 let prompt = "Generate content"; let mock_llm = MockChatCompletion::new();
1123 let mock_tool = MockTool::new("mock_tool");
1124
1125 let chat_request = chat_request! {
1126 user!(prompt);
1127 tools = [mock_tool.clone()]
1128 };
1129 mock_tool.expect_invoke_ok("Great!".into(), None);
1130
1131 let mock_tool_response = chat_response! {
1132 "Some response";
1133 tool_calls = ["mock_tool"]
1134 };
1135
1136 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));
1138
1139 let stop_response = chat_response! {
1141 "Final response";
1142 tool_calls = ["stop"]
1143 };
1144
1145 mock_llm.expect_complete(chat_request, Ok(stop_response));
1146
1147 let mut agent = Agent::builder()
1148 .tools([mock_tool])
1149 .llm(&mock_llm)
1150 .no_system_prompt()
1151 .limit(1) .build()
1153 .unwrap();
1154
1155 agent.query(prompt).await.unwrap();
1157
1158 let remaining = mock_llm.expectations.lock().unwrap().pop();
1160 assert!(remaining.is_some());
1161
1162 assert!(agent.is_stopped());
1164 }
1165
1166 #[test_log::test(tokio::test)]
1167 async fn test_tool_retry_mechanism() {
1168 let prompt = "Execute my tool";
1169 let mock_llm = MockChatCompletion::new();
1170 let mock_tool = MockTool::new("retry_tool");
1171
1172 mock_tool.expect_invoke(
1175 Err(ToolError::WrongArguments(serde_json::Error::custom(
1176 "missing `query`",
1177 ))),
1178 None,
1179 );
1180 mock_tool.expect_invoke(
1181 Err(ToolError::WrongArguments(serde_json::Error::custom(
1182 "missing `query`",
1183 ))),
1184 None,
1185 );
1186
1187 let chat_request = chat_request! {
1188 user!(prompt);
1189 tools = [mock_tool.clone()]
1190 };
1191 let retry_response = chat_response! {
1192 "First failing attempt";
1193 tool_calls = ["retry_tool"]
1194 };
1195 mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
1196
1197 let chat_request = chat_request! {
1198 user!(prompt),
1199 assistant!("First failing attempt", ["retry_tool"]),
1200 tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
1201
1202 tools = [mock_tool.clone()]
1203 };
1204 let will_fail_response = chat_response! {
1205 "Finished execution";
1206 tool_calls = ["retry_tool"]
1207 };
1208 mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
1209
1210 let mut agent = Agent::builder()
1211 .tools([mock_tool])
1212 .llm(&mock_llm)
1213 .no_system_prompt()
1214 .tool_retry_limit(1) .build()
1216 .unwrap();
1217
1218 let result = agent.query(prompt).await;
1220
1221 assert!(result.is_err());
1222 assert!(result.unwrap_err().to_string().contains("missing `query`"));
1223 assert!(agent.is_stopped());
1224 }
1225
1226 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1227 async fn test_streaming() {
1228 let prompt = "Generate content"; let mock_llm = MockChatCompletion::new();
1230 let on_stream_fn = MockHook::new("on_stream").expect_calls(3).to_owned();
1231
1232 let chat_request = chat_request! {
1233 user!(prompt);
1234
1235 tools = []
1236 };
1237
1238 let response = chat_response! {
1239 "one two three";
1240 tool_calls = ["stop"]
1241 };
1242
1243 mock_llm.expect_complete(chat_request, Ok(response));
1245
1246 let mut agent = Agent::builder()
1247 .llm(&mock_llm)
1248 .on_stream(on_stream_fn.on_stream_fn())
1249 .no_system_prompt()
1250 .build()
1251 .unwrap();
1252
1253 agent.query(prompt).await.unwrap();
1255
1256 tracing::debug!("Agent finished running");
1257
1258 assert!(agent.is_stopped());
1260 }
1261
1262 #[test_log::test(tokio::test)]
1263 async fn test_recovering_agent_existing_history() {
1264 let prompt = "Write a poem";
1266 let mock_llm = MockChatCompletion::new();
1267 let mock_tool = MockTool::new("mock_tool");
1268
1269 let chat_request = chat_request! {
1270 user!("Write a poem");
1271
1272 tools = [mock_tool.clone()]
1273 };
1274
1275 let mock_tool_response = chat_response! {
1276 "Roses are red";
1277 tool_calls = ["mock_tool"]
1278
1279 };
1280
1281 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1282
1283 let chat_request = chat_request! {
1284 user!("Write a poem"),
1285 assistant!("Roses are red", ["mock_tool"]),
1286 tool_output!("mock_tool", "Great!");
1287
1288 tools = [mock_tool.clone()]
1289 };
1290
1291 let stop_response = chat_response! {
1292 "Roses are red";
1293 tool_calls = ["stop"]
1294 };
1295
1296 mock_llm.expect_complete(chat_request, Ok(stop_response));
1297 mock_tool.expect_invoke_ok("Great!".into(), None);
1298
1299 let mut agent = Agent::builder()
1300 .tools([mock_tool.clone()])
1301 .llm(&mock_llm)
1302 .no_system_prompt()
1303 .build()
1304 .unwrap();
1305
1306 agent.query(prompt).await.unwrap();
1307
1308 let history = agent.history().await.unwrap();
1310
1311 let serialized = serde_json::to_string(&history).unwrap();
1313
1314 let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();
1316
1317 let context = DefaultContext::default()
1319 .with_existing_messages(history)
1320 .await
1321 .unwrap()
1322 .to_owned();
1323
1324 let expected_chat_request = chat_request! {
1325 user!("Write a poem"),
1326 assistant!("Roses are red", ["mock_tool"]),
1327 tool_output!("mock_tool", "Great!"),
1328 assistant!("Roses are red", ["stop"]),
1329 tool_output!("stop", ToolOutput::Stop),
1330 user!("Try again!");
1331
1332 tools = [mock_tool.clone()]
1333 };
1334
1335 let stop_response = chat_response! {
1336 "Really stopping now";
1337 tool_calls = ["stop"]
1338 };
1339
1340 mock_llm.expect_complete(expected_chat_request, Ok(stop_response));
1341
1342 let mut agent = Agent::builder()
1343 .context(context)
1344 .tools([mock_tool])
1345 .llm(&mock_llm)
1346 .no_system_prompt()
1347 .build()
1348 .unwrap();
1349
1350 agent.query_once("Try again!").await.unwrap();
1351 }
1352
1353 #[test_log::test(tokio::test)]
1354 async fn test_agent_with_approval_required_tool() {
1355 use super::*;
1356 use crate::tools::control::ApprovalRequired;
1357 use crate::{assistant, chat_request, chat_response, user};
1358 use swiftide_core::chat_completion::ToolCall;
1359
1360 let mock_tool = MockTool::default();
1362 mock_tool.expect_invoke_ok("Great!".into(), None);
1363
1364 let approval_tool = ApprovalRequired(mock_tool.boxed());
1365
1366 let mock_llm = MockChatCompletion::new();
1368
1369 let chat_req1 = chat_request! {
1370 user!("Request with approval");
1371 tools = [approval_tool.clone()]
1372 };
1373 let chat_resp1 = chat_response! {
1374 "Completion message";
1375 tool_calls = ["mock_tool"]
1376 };
1377 mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1378
1379 let chat_req2 = chat_request! {
1382 user!("Request with approval"),
1383 assistant!("Completion message", ["mock_tool"]),
1384 tool_output!("mock_tool", "Great!");
1385 tools = [approval_tool.clone()]
1387 };
1388 let chat_resp2 = chat_response! {
1389 "Post-feedback message";
1390 tool_calls = ["stop"]
1391 };
1392 mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1393
1394 let mut agent = Agent::builder()
1396 .tools([approval_tool])
1397 .llm(&mock_llm)
1398 .no_system_prompt()
1399 .build()
1400 .unwrap();
1401
1402 agent.query_once("Request with approval").await.unwrap();
1404
1405 assert!(matches!(
1406 agent.state,
1407 crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1408 ));
1409
1410 let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1411 else {
1412 panic!("Expected feedback required");
1413 };
1414
1415 agent
1417 .context
1418 .feedback_received(&tool_call, &ToolFeedback::approved())
1419 .await
1420 .unwrap();
1421
1422 tracing::debug!("running after approval");
1423 agent.run_once().await.unwrap();
1424 assert!(agent.is_stopped());
1425 }
1426
1427 #[test_log::test(tokio::test)]
1428 async fn test_agent_with_approval_required_tool_denied() {
1429 use super::*;
1430 use crate::tools::control::ApprovalRequired;
1431 use crate::{assistant, chat_request, chat_response, user};
1432 use swiftide_core::chat_completion::ToolCall;
1433
1434 let mock_tool = MockTool::default();
1436
1437 let approval_tool = ApprovalRequired(mock_tool.boxed());
1438
1439 let mock_llm = MockChatCompletion::new();
1441
1442 let chat_req1 = chat_request! {
1443 user!("Request with approval");
1444 tools = [approval_tool.clone()]
1445 };
1446 let chat_resp1 = chat_response! {
1447 "Completion message";
1448 tool_calls = ["mock_tool"]
1449 };
1450 mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1451
1452 let chat_req2 = chat_request! {
1455 user!("Request with approval"),
1456 assistant!("Completion message", ["mock_tool"]),
1457 tool_output!("mock_tool", "This tool call was refused");
1458 tools = [approval_tool.clone()]
1460 };
1461 let chat_resp2 = chat_response! {
1462 "Post-feedback message";
1463 tool_calls = ["stop"]
1464 };
1465 mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1466
1467 let mut agent = Agent::builder()
1469 .tools([approval_tool])
1470 .llm(&mock_llm)
1471 .no_system_prompt()
1472 .build()
1473 .unwrap();
1474
1475 agent.query_once("Request with approval").await.unwrap();
1477
1478 assert!(matches!(
1479 agent.state,
1480 crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1481 ));
1482
1483 let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1484 else {
1485 panic!("Expected feedback required");
1486 };
1487
1488 agent
1490 .context
1491 .feedback_received(&tool_call, &ToolFeedback::refused())
1492 .await
1493 .unwrap();
1494
1495 tracing::debug!("running after approval");
1496 agent.run_once().await.unwrap();
1497
1498 let history = agent.context().history().await.unwrap();
1499 history
1500 .iter()
1501 .rfind(|m| {
1502 let ChatMessage::ToolOutput(.., ToolOutput::Text(msg)) = m else {
1503 return false;
1504 };
1505 msg.contains("refused")
1506 })
1507 .expect("Could not find refusal message");
1508
1509 assert!(agent.is_stopped());
1510 }
1511}