// swiftide_agents/test_utils.rs

1use std::borrow::Cow;
2use std::sync::{Arc, Mutex};
3
4use async_trait::async_trait;
5use swiftide_core::chat_completion::ToolCall;
6use swiftide_core::chat_completion::{Tool, ToolOutput, ToolSpec, errors::ToolError};
7
8use swiftide_core::AgentContext;
9
10use crate::Agent;
11use crate::hooks::{
12    AfterCompletionFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn, MessageHookFn,
13    OnStartFn, OnStopFn, OnStreamFn,
14};
15
/// Builds a `ChatCompletionRequest` from the given messages and tools, appending the tool
/// specs of `Agent::default_tools()` after the explicitly listed ones.
///
/// Two forms are supported:
/// - `chat_request!(msg, ...; tools = [tool, ...])` — each tool is turned into its spec via
///   `tool_spec()`.
/// - `chat_request!(msg, ...; tool_specs = [spec, ...])` — specs are used as given.
///
/// # Panics
///
/// Panics if the request builder fails to build.
#[macro_export]
macro_rules! chat_request {
    ($($message:expr),+; tools = [$($tool:expr),*]) => {
        // Convert every tool to its spec and defer to the `tool_specs` form so the two arms
        // cannot drift apart.
        $crate::chat_request!(
            $($message),+;
            tool_specs = [$({
                let tool = $tool;
                tool.tool_spec()
            }),*]
        )
    };
    ($($message:expr),+; tool_specs = [$($tool:expr),*]) => {{
        let mut builder = swiftide_core::chat_completion::ChatCompletionRequest::builder();
        builder.messages(vec![$($message),*]);

        let mut tool_specs = Vec::new();
        $(tool_specs.push($tool);)*
        // `$crate::Agent` keeps the macro usable without `Agent` imported at the call site.
        tool_specs.extend(
            $crate::Agent::default_tools()
                .into_iter()
                .map(|tool| tool.tool_spec()),
        );

        builder.tool_specs(tool_specs);

        builder.build().unwrap()
    }};
}
47
/// Builds a `ChatMessage::User` from any value with a `to_string()` method
/// (typically `&str` or `String`).
#[macro_export]
macro_rules! user {
    ($content:expr) => {{
        let content = $content.to_string();
        swiftide_core::chat_completion::ChatMessage::User(content)
    }};
}
54
/// Builds a `ChatMessage::System` from any value with a `to_string()` method
/// (typically `&str` or `String`).
#[macro_export]
macro_rules! system {
    ($content:expr) => {{
        let content = $content.to_string();
        swiftide_core::chat_completion::ChatMessage::System(content)
    }};
}
61
/// Builds a `ChatMessage::Summary` from any value with a `to_string()` method
/// (typically `&str` or `String`).
#[macro_export]
macro_rules! summary {
    ($content:expr) => {{
        let content = $content.to_string();
        swiftide_core::chat_completion::ChatMessage::Summary(content)
    }};
}
68
/// Builds a `ChatMessage::Assistant`.
///
/// - `assistant!(message)` — text only, no tool calls.
/// - `assistant!(message, [tool_name, ...])` — text plus one `ToolCall` per name. Every call
///   gets the id `"1"`, the same id used by `tool_output!` and `tool_failed!`.
///
/// All paths are fully qualified so callers need not import `ChatMessage` or `ToolCall`.
///
/// # Panics
///
/// Panics if a `ToolCall` fails to build.
#[macro_export]
macro_rules! assistant {
    ($message:expr) => {
        swiftide_core::chat_completion::ChatMessage::Assistant(Some($message.to_string()), None)
    };
    ($message:expr, [$($tool_call_name:expr),*]) => {{
        let tool_calls = vec![
            $(
            swiftide_core::chat_completion::ToolCall::builder()
                .name($tool_call_name)
                .id("1")
                .build()
                .unwrap()
            ),*
        ];

        swiftide_core::chat_completion::ChatMessage::Assistant(
            Some($message.to_string()),
            Some(tool_calls),
        )
    }};
}
88
/// Builds a `ChatMessage::ToolOutput` for a call to `tool_name` (id `"1"`), converting
/// `message` into the output value with `.into()`.
///
/// All paths are fully qualified so callers need not import `ChatMessage` or `ToolCall`.
///
/// # Panics
///
/// Panics if the `ToolCall` fails to build.
#[macro_export]
macro_rules! tool_output {
    ($tool_name:expr, $message:expr) => {{
        swiftide_core::chat_completion::ChatMessage::ToolOutput(
            swiftide_core::chat_completion::ToolCall::builder()
                .name($tool_name)
                .id("1")
                .build()
                .unwrap(),
            $message.into(),
        )
    }};
}
102
/// Builds a `ChatMessage::ToolOutput` for a call to `tool_name` (id `"1"`) whose output is
/// `ToolOutput::fail(message)`.
///
/// All paths are fully qualified so callers need not import `ChatMessage`, `ToolCall` or
/// `ToolOutput`.
///
/// # Panics
///
/// Panics if the `ToolCall` fails to build.
#[macro_export]
macro_rules! tool_failed {
    ($tool_name:expr, $message:expr) => {{
        swiftide_core::chat_completion::ChatMessage::ToolOutput(
            swiftide_core::chat_completion::ToolCall::builder()
                .name($tool_name)
                .id("1")
                .build()
                .unwrap(),
            swiftide_core::chat_completion::ToolOutput::fail($message),
        )
    }};
}
116
/// Builds a `ChatCompletionResponse` with the given tool calls and, optionally, a message.
///
/// - `chat_response!(tool_calls = [name, ...])`
/// - `chat_response!(message; tool_calls = [name, ...])`
///
/// Every tool call gets the id `"1"`, matching `assistant!`, `tool_output!` and `tool_failed!`.
/// Paths are fully qualified so callers need not import `ChatCompletionResponse` or `ToolCall`.
///
/// # Panics
///
/// Panics if a `ToolCall` or the response fails to build.
#[macro_export]
macro_rules! chat_response {
    // Listed first: the literal `tool_calls` prefix is an unambiguous match, so this form is
    // never first parsed as a `$message:expr` assignment expression by the arm below.
    (tool_calls = [$($tool_name:expr),*]) => {{
        let tool_calls = vec![
            $(swiftide_core::chat_completion::ToolCall::builder()
                .name($tool_name)
                .id("1")
                .build()
                .unwrap()),*
        ];

        swiftide_core::chat_completion::ChatCompletionResponse::builder()
            .tool_calls(tool_calls)
            .build()
            .unwrap()
    }};
    ($message:expr; tool_calls = [$($tool_name:expr),*]) => {{
        let tool_calls = vec![
            $(swiftide_core::chat_completion::ToolCall::builder()
                .name($tool_name)
                .id("1")
                .build()
                .unwrap()),*
        ];

        swiftide_core::chat_completion::ChatCompletionResponse::builder()
            .message($message)
            .tool_calls(tool_calls)
            .build()
            .unwrap()
    }};
}
143
144type Expectations = Arc<Mutex<Vec<(Result<ToolOutput, ToolError>, Option<&'static str>)>>>;
145
146#[derive(Debug, Clone)]
147pub struct MockTool {
148    expectations: Expectations,
149    name: &'static str,
150}
151
152impl MockTool {
153    #[allow(clippy::should_implement_trait)]
154    pub fn default() -> Self {
155        Self::new("mock_tool")
156    }
157    pub fn new(name: &'static str) -> Self {
158        Self {
159            expectations: Arc::new(Mutex::new(Vec::new())),
160            name,
161        }
162    }
163    pub fn expect_invoke_ok(
164        &self,
165        expected_result: ToolOutput,
166        expected_args: Option<&'static str>,
167    ) {
168        self.expect_invoke(Ok(expected_result), expected_args);
169    }
170
171    #[allow(clippy::missing_panics_doc)]
172    pub fn expect_invoke(
173        &self,
174        expected_result: Result<ToolOutput, ToolError>,
175        expected_args: Option<&'static str>,
176    ) {
177        self.expectations
178            .lock()
179            .unwrap()
180            .push((expected_result, expected_args));
181    }
182}
183
184#[async_trait]
185impl Tool for MockTool {
186    async fn invoke(
187        &self,
188        _agent_context: &dyn AgentContext,
189        tool_call: &ToolCall,
190    ) -> std::result::Result<ToolOutput, ToolError> {
191        tracing::debug!(
192            "[MockTool] Invoked `{}` with args: {:?}",
193            self.name,
194            tool_call
195        );
196        let expectation = self
197            .expectations
198            .lock()
199            .unwrap()
200            .pop()
201            .unwrap_or_else(|| panic!("[MockTool] No expectations left for `{}`", self.name));
202
203        assert_eq!(expectation.1, tool_call.args());
204
205        expectation.0
206    }
207
208    fn name(&self) -> Cow<'_, str> {
209        self.name.into()
210    }
211
212    fn tool_spec(&self) -> ToolSpec {
213        ToolSpec::builder()
214            .name(self.name().as_ref())
215            .description("A fake tool for testing purposes")
216            .build()
217            .unwrap()
218    }
219}
220
221impl From<MockTool> for Box<dyn Tool> {
222    fn from(val: MockTool) -> Self {
223        Box::new(val) as Box<dyn Tool>
224    }
225}
226
227impl Drop for MockTool {
228    fn drop(&mut self) {
229        // Mock still borrowed elsewhere and expectations still be invoked
230        if Arc::strong_count(&self.expectations) > 1 {
231            return;
232        }
233        if self.expectations.lock().is_err() {
234            return;
235        }
236
237        let name = self.name;
238        if self.expectations.lock().unwrap().is_empty() {
239            tracing::debug!("[MockTool] All expectations were met for `{name}`");
240        } else {
241            panic!(
242                "[MockTool] Not all expectations were met for `{name}: {:?}",
243                *self.expectations.lock().unwrap()
244            );
245        }
246    }
247}
248
/// Counts how often agent hooks produced by this mock are invoked, so a test can check the
/// total against an expected number of calls.
#[derive(Debug, Clone)]
pub struct MockHook {
    /// Label used in log and panic messages.
    name: &'static str,
    /// Invocation counter shared with every hook closure this mock hands out.
    called: Arc<Mutex<usize>>,
    /// Number of invocations the counter must reach for the check to pass.
    expected_calls: usize,
}
255
256impl MockHook {
257    pub fn new(name: &'static str) -> Self {
258        Self {
259            name,
260            called: Arc::new(Mutex::new(0)),
261            expected_calls: 0,
262        }
263    }
264
265    pub fn expect_calls(&mut self, expected_calls: usize) -> &mut Self {
266        self.expected_calls = expected_calls;
267        self
268    }
269
270    #[allow(clippy::missing_panics_doc)]
271    pub fn hook_fn(&self) -> impl BeforeAllFn + use<> {
272        let called = Arc::clone(&self.called);
273        move |_: &Agent| {
274            let called = Arc::clone(&called);
275            Box::pin(async move {
276                let mut called = called.lock().unwrap();
277                *called += 1;
278                Ok(())
279            })
280        }
281    }
282
283    #[allow(clippy::missing_panics_doc)]
284    pub fn on_start_fn(&self) -> impl OnStartFn + use<> {
285        let called = Arc::clone(&self.called);
286        move |_: &Agent| {
287            let called = Arc::clone(&called);
288            Box::pin(async move {
289                let mut called = called.lock().unwrap();
290                *called += 1;
291                Ok(())
292            })
293        }
294    }
295    #[allow(clippy::missing_panics_doc)]
296    pub fn before_completion_fn(&self) -> impl BeforeCompletionFn + use<> {
297        let called = Arc::clone(&self.called);
298        move |_: &Agent, _| {
299            let called = Arc::clone(&called);
300            Box::pin(async move {
301                let mut called = called.lock().unwrap();
302                *called += 1;
303                Ok(())
304            })
305        }
306    }
307
308    #[allow(clippy::missing_panics_doc)]
309    pub fn after_completion_fn(&self) -> impl AfterCompletionFn + use<> {
310        let called = Arc::clone(&self.called);
311        move |_: &Agent, _| {
312            let called = Arc::clone(&called);
313            Box::pin(async move {
314                let mut called = called.lock().unwrap();
315                *called += 1;
316                Ok(())
317            })
318        }
319    }
320
321    #[allow(clippy::missing_panics_doc)]
322    pub fn after_tool_fn(&self) -> impl AfterToolFn + use<> {
323        let called = Arc::clone(&self.called);
324        move |_: &Agent, _, _| {
325            let called = Arc::clone(&called);
326            Box::pin(async move {
327                let mut called = called.lock().unwrap();
328                *called += 1;
329                Ok(())
330            })
331        }
332    }
333
334    #[allow(clippy::missing_panics_doc)]
335    pub fn before_tool_fn(&self) -> impl BeforeToolFn + use<> {
336        let called = Arc::clone(&self.called);
337        move |_: &Agent, _| {
338            let called = Arc::clone(&called);
339            Box::pin(async move {
340                let mut called = called.lock().unwrap();
341                *called += 1;
342                Ok(())
343            })
344        }
345    }
346
347    #[allow(clippy::missing_panics_doc)]
348    pub fn message_hook_fn(&self) -> impl MessageHookFn + use<> {
349        let called = Arc::clone(&self.called);
350        move |_: &Agent, _| {
351            let called = Arc::clone(&called);
352            Box::pin(async move {
353                let mut called = called.lock().unwrap();
354                *called += 1;
355                Ok(())
356            })
357        }
358    }
359
360    #[allow(clippy::missing_panics_doc)]
361    pub fn stop_hook_fn(&self) -> impl OnStopFn + use<> {
362        let called = Arc::clone(&self.called);
363        move |_: &Agent, _, _| {
364            let called = Arc::clone(&called);
365            Box::pin(async move {
366                let mut called = called.lock().unwrap();
367                *called += 1;
368                Ok(())
369            })
370        }
371    }
372
373    #[allow(clippy::missing_panics_doc)]
374    pub fn on_stream_fn(&self) -> impl OnStreamFn + use<> {
375        let called = Arc::clone(&self.called);
376        move |_: &Agent, _| {
377            let called = Arc::clone(&called);
378            Box::pin(async move {
379                let mut called = called.lock().unwrap();
380                *called += 1;
381                Ok(())
382            })
383        }
384    }
385}
386
387impl Drop for MockHook {
388    fn drop(&mut self) {
389        if Arc::strong_count(&self.called) > 1 {
390            return;
391        }
392        let Ok(called) = self.called.lock() else {
393            return;
394        };
395
396        if *called == self.expected_calls {
397            tracing::debug!(
398                "[MockHook] `{}` all expectations met; called {} times",
399                self.name,
400                *called
401            );
402        } else {
403            panic!(
404                "[MockHook] `{}` was called {} times but expected {}",
405                self.name, *called, self.expected_calls
406            )
407        }
408    }
409}