swiftide_agents/
test_utils.rs

1use std::borrow::Cow;
2use std::sync::{Arc, Mutex};
3
4use async_trait::async_trait;
5use swiftide_core::chat_completion::ToolCall;
6use swiftide_core::chat_completion::{Tool, ToolOutput, ToolSpec, errors::ToolError};
7
8use swiftide_core::AgentContext;
9
10use crate::Agent;
11use crate::hooks::{
12    AfterCompletionFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn, MessageHookFn,
13    OnStartFn, OnStopFn, OnStreamFn,
14};
15
/// Builds a `ChatCompletionRequest` for use in test assertions.
///
/// The request holds the given messages and a `HashSet` of tool specs. That set always
/// includes the specs of `Agent::default_tools()` in addition to the listed tools.
///
/// Two forms are supported:
/// - `chat_request!(msg, ...; tools = [tool, ...])`: each tool is boxed as
///   `Box<dyn Tool>` and its spec is taken from `Tool::tool_spec`.
/// - `chat_request!(msg, ...; tool_specs = [spec, ...])`: the specs are used as given.
///
/// All paths are fully qualified, so callers do not need `Tool` or `Agent` in scope.
///
/// # Panics
///
/// Panics if the request builder fails to build.
#[macro_export]
macro_rules! chat_request {
    ($($message:expr),+; tools = [$($tool:expr),*]) => {
        swiftide_core::chat_completion::ChatCompletionRequest::builder()
            .messages(vec![$($message),*])
            .tools_spec(
                vec![$(Box::new($tool) as Box<dyn swiftide_core::chat_completion::Tool>),*]
                    .into_iter()
                    .chain($crate::Agent::default_tools())
                    .map(|tool| tool.tool_spec())
                    .collect::<std::collections::HashSet<_>>(),
            )
            .build()
            .unwrap()
    };
    ($($message:expr),+; tool_specs = [$($tool:expr),*]) => {
        swiftide_core::chat_completion::ChatCompletionRequest::builder()
            .messages(vec![$($message),*])
            .tools_spec(
                vec![$(($tool)),*]
                    .into_iter()
                    .chain(
                        $crate::Agent::default_tools()
                            .into_iter()
                            .map(|tool| tool.tool_spec()),
                    )
                    .collect::<std::collections::HashSet<_>>(),
            )
            .build()
            .unwrap()
    };
}
44
/// Builds a `ChatMessage::User` whose content is `$content.to_string()`.
#[macro_export]
macro_rules! user {
    ($content:expr) => {
        swiftide_core::chat_completion::ChatMessage::User($content.to_string())
    };
}
51
/// Builds a `ChatMessage::System` whose content is `$content.to_string()`.
#[macro_export]
macro_rules! system {
    ($content:expr) => {
        swiftide_core::chat_completion::ChatMessage::System($content.to_string())
    };
}
58
/// Builds a `ChatMessage::Summary` whose content is `$content.to_string()`.
#[macro_export]
macro_rules! summary {
    ($content:expr) => {
        swiftide_core::chat_completion::ChatMessage::Summary($content.to_string())
    };
}
65
/// Builds a `ChatMessage::Assistant` for tests.
///
/// - `assistant!(msg)` produces `Assistant(Some(msg.to_string()), None)`.
/// - `assistant!(msg, [name, ...])` also attaches one `ToolCall` per name. Every
///   generated tool call uses the id `"1"`.
///
/// Paths are fully qualified in both arms, so callers do not need `ChatMessage` or
/// `ToolCall` in scope.
///
/// # Panics
///
/// Panics if a `ToolCall` fails to build.
#[macro_export]
macro_rules! assistant {
    ($message:expr) => {
        swiftide_core::chat_completion::ChatMessage::Assistant(Some($message.to_string()), None)
    };
    ($message:expr, [$($tool_call_name:expr),*]) => {{
        let tool_calls = vec![
            $(
                swiftide_core::chat_completion::ToolCall::builder()
                    .name($tool_call_name)
                    .id("1")
                    .build()
                    .unwrap()
            ),*
        ];

        swiftide_core::chat_completion::ChatMessage::Assistant(
            Some($message.to_string()),
            Some(tool_calls),
        )
    }};
}
85
/// Builds a `ChatMessage::ToolOutput` for tests.
///
/// The message pairs a tool call named `$tool_name` (id `"1"`) with `$message` converted
/// via `Into`. Paths are fully qualified, so callers do not need `ChatMessage` or
/// `ToolCall` in scope.
///
/// # Panics
///
/// Panics if the `ToolCall` fails to build.
#[macro_export]
macro_rules! tool_output {
    ($tool_name:expr, $message:expr) => {{
        swiftide_core::chat_completion::ChatMessage::ToolOutput(
            swiftide_core::chat_completion::ToolCall::builder()
                .name($tool_name)
                .id("1")
                .build()
                .unwrap(),
            $message.into(),
        )
    }};
}
99
/// Builds a failed `ChatMessage::ToolOutput` for tests.
///
/// The message pairs a tool call named `$tool_name` (id `"1"`) with
/// `ToolOutput::fail($message)`. Paths are fully qualified, so callers do not need
/// `ChatMessage`, `ToolCall` or `ToolOutput` in scope.
///
/// # Panics
///
/// Panics if the `ToolCall` fails to build.
#[macro_export]
macro_rules! tool_failed {
    ($tool_name:expr, $message:expr) => {{
        swiftide_core::chat_completion::ChatMessage::ToolOutput(
            swiftide_core::chat_completion::ToolCall::builder()
                .name($tool_name)
                .id("1")
                .build()
                .unwrap(),
            swiftide_core::chat_completion::ToolOutput::fail($message),
        )
    }};
}
113
/// Builds a `ChatCompletionResponse` for tests.
///
/// - `chat_response!(msg; tool_calls = [name, ...])` sets the message and the tool calls.
/// - `chat_response!(tool_calls = [name, ...])` sets only the tool calls.
///
/// Every generated `ToolCall` uses the id `"1"`. Paths are fully qualified, so callers
/// do not need `ToolCall` or `ChatCompletionResponse` in scope.
///
/// # Panics
///
/// Panics if a `ToolCall` or the response fails to build.
#[macro_export]
macro_rules! chat_response {
    ($message:expr; tool_calls = [$($tool_name:expr),*]) => {{
        let tool_calls = vec![
            $(
                swiftide_core::chat_completion::ToolCall::builder()
                    .name($tool_name)
                    .id("1")
                    .build()
                    .unwrap()
            ),*
        ];

        swiftide_core::chat_completion::ChatCompletionResponse::builder()
            .message($message)
            .tool_calls(tool_calls)
            .build()
            .unwrap()
    }};
    (tool_calls = [$($tool_name:expr),*]) => {{
        let tool_calls = vec![
            $(
                swiftide_core::chat_completion::ToolCall::builder()
                    .name($tool_name)
                    .id("1")
                    .build()
                    .unwrap()
            ),*
        ];

        swiftide_core::chat_completion::ChatCompletionResponse::builder()
            .tool_calls(tool_calls)
            .build()
            .unwrap()
    }};
}
140
141type Expectations = Arc<Mutex<Vec<(Result<ToolOutput, ToolError>, Option<&'static str>)>>>;
142
143#[derive(Debug, Clone)]
144pub struct MockTool {
145    expectations: Expectations,
146    name: &'static str,
147}
148
149impl MockTool {
150    #[allow(clippy::should_implement_trait)]
151    pub fn default() -> Self {
152        Self::new("mock_tool")
153    }
154    pub fn new(name: &'static str) -> Self {
155        Self {
156            expectations: Arc::new(Mutex::new(Vec::new())),
157            name,
158        }
159    }
160    pub fn expect_invoke_ok(
161        &self,
162        expected_result: ToolOutput,
163        expected_args: Option<&'static str>,
164    ) {
165        self.expect_invoke(Ok(expected_result), expected_args);
166    }
167
168    #[allow(clippy::missing_panics_doc)]
169    pub fn expect_invoke(
170        &self,
171        expected_result: Result<ToolOutput, ToolError>,
172        expected_args: Option<&'static str>,
173    ) {
174        self.expectations
175            .lock()
176            .unwrap()
177            .push((expected_result, expected_args));
178    }
179}
180
181#[async_trait]
182impl Tool for MockTool {
183    async fn invoke(
184        &self,
185        _agent_context: &dyn AgentContext,
186        tool_call: &ToolCall,
187    ) -> std::result::Result<ToolOutput, ToolError> {
188        tracing::debug!(
189            "[MockTool] Invoked `{}` with args: {:?}",
190            self.name,
191            tool_call
192        );
193        let expectation = self
194            .expectations
195            .lock()
196            .unwrap()
197            .pop()
198            .unwrap_or_else(|| panic!("[MockTool] No expectations left for `{}`", self.name));
199
200        assert_eq!(expectation.1, tool_call.args());
201
202        expectation.0
203    }
204
205    fn name(&self) -> Cow<'_, str> {
206        self.name.into()
207    }
208
209    fn tool_spec(&self) -> ToolSpec {
210        ToolSpec::builder()
211            .name(self.name().as_ref())
212            .description("A fake tool for testing purposes")
213            .build()
214            .unwrap()
215    }
216}
217
218impl From<MockTool> for Box<dyn Tool> {
219    fn from(val: MockTool) -> Self {
220        Box::new(val) as Box<dyn Tool>
221    }
222}
223
224impl Drop for MockTool {
225    fn drop(&mut self) {
226        // Mock still borrowed elsewhere and expectations still be invoked
227        if Arc::strong_count(&self.expectations) > 1 {
228            return;
229        }
230        if self.expectations.lock().is_err() {
231            return;
232        }
233
234        let name = self.name;
235        if self.expectations.lock().unwrap().is_empty() {
236            tracing::debug!("[MockTool] All expectations were met for `{name}`");
237        } else {
238            panic!(
239                "[MockTool] Not all expectations were met for `{name}: {:?}",
240                *self.expectations.lock().unwrap()
241            );
242        }
243    }
244}
245
/// A stand-in hook that counts how often any of its hook closures is invoked.
///
/// Every closure produced by a `MockHook`, and every clone of it, shares one counter.
/// When a `MockHook` is dropped while holding the only remaining reference to that
/// counter, the count is compared with `expected_calls` (zero unless set through
/// `expect_calls`). A mismatch panics.
#[derive(Debug, Clone)]
pub struct MockHook {
    // Field order is kept as-is: it determines the derived `Debug` output.
    name: &'static str,
    called: Arc<Mutex<usize>>,
    expected_calls: usize,
}
252
253impl MockHook {
254    pub fn new(name: &'static str) -> Self {
255        Self {
256            name,
257            called: Arc::new(Mutex::new(0)),
258            expected_calls: 0,
259        }
260    }
261
262    pub fn expect_calls(&mut self, expected_calls: usize) -> &mut Self {
263        self.expected_calls = expected_calls;
264        self
265    }
266
267    #[allow(clippy::missing_panics_doc)]
268    pub fn hook_fn(&self) -> impl BeforeAllFn + use<> {
269        let called = Arc::clone(&self.called);
270        move |_: &Agent| {
271            let called = Arc::clone(&called);
272            Box::pin(async move {
273                let mut called = called.lock().unwrap();
274                *called += 1;
275                Ok(())
276            })
277        }
278    }
279
280    #[allow(clippy::missing_panics_doc)]
281    pub fn on_start_fn(&self) -> impl OnStartFn + use<> {
282        let called = Arc::clone(&self.called);
283        move |_: &Agent| {
284            let called = Arc::clone(&called);
285            Box::pin(async move {
286                let mut called = called.lock().unwrap();
287                *called += 1;
288                Ok(())
289            })
290        }
291    }
292    #[allow(clippy::missing_panics_doc)]
293    pub fn before_completion_fn(&self) -> impl BeforeCompletionFn + use<> {
294        let called = Arc::clone(&self.called);
295        move |_: &Agent, _| {
296            let called = Arc::clone(&called);
297            Box::pin(async move {
298                let mut called = called.lock().unwrap();
299                *called += 1;
300                Ok(())
301            })
302        }
303    }
304
305    #[allow(clippy::missing_panics_doc)]
306    pub fn after_completion_fn(&self) -> impl AfterCompletionFn + use<> {
307        let called = Arc::clone(&self.called);
308        move |_: &Agent, _| {
309            let called = Arc::clone(&called);
310            Box::pin(async move {
311                let mut called = called.lock().unwrap();
312                *called += 1;
313                Ok(())
314            })
315        }
316    }
317
318    #[allow(clippy::missing_panics_doc)]
319    pub fn after_tool_fn(&self) -> impl AfterToolFn + use<> {
320        let called = Arc::clone(&self.called);
321        move |_: &Agent, _, _| {
322            let called = Arc::clone(&called);
323            Box::pin(async move {
324                let mut called = called.lock().unwrap();
325                *called += 1;
326                Ok(())
327            })
328        }
329    }
330
331    #[allow(clippy::missing_panics_doc)]
332    pub fn before_tool_fn(&self) -> impl BeforeToolFn + use<> {
333        let called = Arc::clone(&self.called);
334        move |_: &Agent, _| {
335            let called = Arc::clone(&called);
336            Box::pin(async move {
337                let mut called = called.lock().unwrap();
338                *called += 1;
339                Ok(())
340            })
341        }
342    }
343
344    #[allow(clippy::missing_panics_doc)]
345    pub fn message_hook_fn(&self) -> impl MessageHookFn + use<> {
346        let called = Arc::clone(&self.called);
347        move |_: &Agent, _| {
348            let called = Arc::clone(&called);
349            Box::pin(async move {
350                let mut called = called.lock().unwrap();
351                *called += 1;
352                Ok(())
353            })
354        }
355    }
356
357    #[allow(clippy::missing_panics_doc)]
358    pub fn stop_hook_fn(&self) -> impl OnStopFn + use<> {
359        let called = Arc::clone(&self.called);
360        move |_: &Agent, _, _| {
361            let called = Arc::clone(&called);
362            Box::pin(async move {
363                let mut called = called.lock().unwrap();
364                *called += 1;
365                Ok(())
366            })
367        }
368    }
369
370    #[allow(clippy::missing_panics_doc)]
371    pub fn on_stream_fn(&self) -> impl OnStreamFn + use<> {
372        let called = Arc::clone(&self.called);
373        move |_: &Agent, _| {
374            let called = Arc::clone(&called);
375            Box::pin(async move {
376                let mut called = called.lock().unwrap();
377                *called += 1;
378                Ok(())
379            })
380        }
381    }
382}
383
384impl Drop for MockHook {
385    fn drop(&mut self) {
386        if Arc::strong_count(&self.called) > 1 {
387            return;
388        }
389        let Ok(called) = self.called.lock() else {
390            return;
391        };
392
393        if *called == self.expected_calls {
394            tracing::debug!(
395                "[MockHook] `{}` all expectations met; called {} times",
396                self.name,
397                *called
398            );
399        } else {
400            panic!(
401                "[MockHook] `{}` was called {} times but expected {}",
402                self.name, *called, self.expected_calls
403            )
404        }
405    }
406}