swiftide_agents/
test_utils.rs1use std::borrow::Cow;
2use std::sync::{Arc, Mutex};
3
4use async_trait::async_trait;
5use swiftide_core::chat_completion::ToolCall;
6use swiftide_core::chat_completion::{Tool, ToolOutput, ToolSpec, errors::ToolError};
7
8use swiftide_core::AgentContext;
9
10use crate::Agent;
11use crate::hooks::{
12 AfterCompletionFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn, MessageHookFn,
13 OnStartFn, OnStopFn, OnStreamFn,
14};
15
/// Builds a `ChatCompletionRequest` from one or more messages plus tool definitions.
///
/// The agent's default tools (`Agent::default_tools()`) are always appended after the
/// explicitly listed ones.
///
/// Two forms are supported:
///
/// - `chat_request!(msg, ...; tools = [tool, ...])` — `tool_spec()` is called on each `tool`.
/// - `chat_request!(msg, ...; tool_specs = [spec, ...])` — each `spec` is used as-is.
///
/// # Panics
///
/// Panics if the request builder fails to build.
#[macro_export]
macro_rules! chat_request {
    ($($message:expr),+; tools = [$($tool:expr),* $(,)?]) => {{
        let mut builder = swiftide_core::chat_completion::ChatCompletionRequest::builder();
        builder.messages(vec![$($message),*]);

        let mut tool_specs = Vec::new();
        $(tool_specs.push({
            let tool = $tool;
            tool.tool_spec()
        });)*

        // `$crate::Agent` resolves regardless of what the call site has imported.
        tool_specs.extend(
            $crate::Agent::default_tools()
                .into_iter()
                .map(|tool| tool.tool_spec()),
        );

        builder.tool_specs(tool_specs);

        builder.build().unwrap()
    }};
    ($($message:expr),+; tool_specs = [$($tool:expr),* $(,)?]) => {{
        let mut builder = swiftide_core::chat_completion::ChatCompletionRequest::builder();
        builder.messages(vec![$($message),*]);

        let mut tool_specs = Vec::new();
        $(tool_specs.push($tool);)*
        // `$crate::Agent` resolves regardless of what the call site has imported.
        tool_specs.extend(
            $crate::Agent::default_tools()
                .into_iter()
                .map(|tool| tool.tool_spec()),
        );

        builder.tool_specs(tool_specs);

        builder.build().unwrap()
    }};
}
47
/// Builds a `ChatMessage::User` by calling `.to_string()` on the argument.
///
/// A trailing comma is accepted, e.g. when a long call is wrapped across lines.
#[macro_export]
macro_rules! user {
    ($message:expr $(,)?) => {
        swiftide_core::chat_completion::ChatMessage::User($message.to_string())
    };
}
54
/// Builds a `ChatMessage::System` by calling `.to_string()` on the argument.
///
/// A trailing comma is accepted, e.g. when a long call is wrapped across lines.
#[macro_export]
macro_rules! system {
    ($message:expr $(,)?) => {
        swiftide_core::chat_completion::ChatMessage::System($message.to_string())
    };
}
61
/// Builds a `ChatMessage::Summary` by calling `.to_string()` on the argument.
///
/// A trailing comma is accepted, e.g. when a long call is wrapped across lines.
#[macro_export]
macro_rules! summary {
    ($message:expr $(,)?) => {
        swiftide_core::chat_completion::ChatMessage::Summary($message.to_string())
    };
}
68
/// Builds a `ChatMessage::Assistant` carrying the given text.
///
/// - `assistant!(msg)` — no tool calls.
/// - `assistant!(msg, [name, ...])` — one `ToolCall` per name. Every tool call gets the id
///   `"1"`.
///
/// All paths are fully qualified, so callers do not need to import `ChatMessage` or
/// `ToolCall`.
///
/// # Panics
///
/// Panics if a `ToolCall` fails to build.
#[macro_export]
macro_rules! assistant {
    ($message:expr $(,)?) => {
        swiftide_core::chat_completion::ChatMessage::Assistant(Some($message.to_string()), None)
    };
    ($message:expr, [$($tool_call_name:expr),* $(,)?] $(,)?) => {{
        let tool_calls = vec![
            $(
                swiftide_core::chat_completion::ToolCall::builder()
                    .name($tool_call_name)
                    .id("1")
                    .build()
                    .unwrap()
            ),*
        ];

        swiftide_core::chat_completion::ChatMessage::Assistant(
            Some($message.to_string()),
            Some(tool_calls),
        )
    }};
}
88
/// Builds a `ChatMessage::ToolOutput` for a tool call named `$tool_name` with id `"1"`.
///
/// The output is produced with `.into()` from `$message`.
///
/// # Panics
///
/// Panics if the `ToolCall` fails to build.
#[macro_export]
macro_rules! tool_output {
    ($tool_name:expr, $message:expr $(,)?) => {{
        swiftide_core::chat_completion::ChatMessage::ToolOutput(
            swiftide_core::chat_completion::ToolCall::builder()
                .name($tool_name)
                .id("1")
                .build()
                .unwrap(),
            $message.into(),
        )
    }};
}
102
/// Builds a `ChatMessage::ToolOutput` holding `ToolOutput::fail($message)` for a tool call
/// named `$tool_name` with id `"1"`.
///
/// # Panics
///
/// Panics if the `ToolCall` fails to build.
#[macro_export]
macro_rules! tool_failed {
    ($tool_name:expr, $message:expr $(,)?) => {{
        swiftide_core::chat_completion::ChatMessage::ToolOutput(
            swiftide_core::chat_completion::ToolCall::builder()
                .name($tool_name)
                .id("1")
                .build()
                .unwrap(),
            swiftide_core::chat_completion::ToolOutput::fail($message),
        )
    }};
}
116
/// Builds a `ChatCompletionResponse`, e.g. as a canned LLM reply in tests.
///
/// - `chat_response!(tool_calls = [name, ...])` — a response with tool calls only.
/// - `chat_response!(msg; tool_calls = [name, ...])` — a response with a message and tool calls.
///
/// Every tool call gets the id `"1"`.
///
/// # Panics
///
/// Panics if a `ToolCall` or the response fails to build.
#[macro_export]
macro_rules! chat_response {
    // The literal `tool_calls = [...]` form is matched first, so it is never speculatively
    // parsed as a `$message:expr` (which would read it as an assignment expression).
    (tool_calls = [$($tool_name:expr),* $(,)?]) => {{
        let tool_calls = vec![
            $(
                swiftide_core::chat_completion::ToolCall::builder()
                    .name($tool_name)
                    .id("1")
                    .build()
                    .unwrap()
            ),*
        ];

        swiftide_core::chat_completion::ChatCompletionResponse::builder()
            .tool_calls(tool_calls)
            .build()
            .unwrap()
    }};
    ($message:expr; tool_calls = [$($tool_name:expr),* $(,)?]) => {{
        let tool_calls = vec![
            $(
                swiftide_core::chat_completion::ToolCall::builder()
                    .name($tool_name)
                    .id("1")
                    .build()
                    .unwrap()
            ),*
        ];

        swiftide_core::chat_completion::ChatCompletionResponse::builder()
            .message($message)
            .tool_calls(tool_calls)
            .build()
            .unwrap()
    }};
}
143
/// Programmed invocations for a [`MockTool`]. Each entry pairs the result to return with the
/// arguments the invocation is expected to carry (compared against `ToolCall::args()`).
type Expectations = Arc<Mutex<Vec<(Result<ToolOutput, ToolError>, Option<&'static str>)>>>;

/// A [`Tool`] test double that returns pre-programmed results and asserts on the arguments it
/// is invoked with.
///
/// Clones share one expectation list (it sits behind an `Arc`). When the last clone is
/// dropped, any expectations that were never consumed cause a panic.
#[derive(Debug, Clone)]
pub struct MockTool {
    // Shared by all clones; filled by `expect_invoke*` and consumed by `invoke`.
    expectations: Expectations,
    // Returned by `Tool::name`, used in the tool spec and in diagnostic messages.
    name: &'static str,
}
151
152impl MockTool {
153 #[allow(clippy::should_implement_trait)]
154 pub fn default() -> Self {
155 Self::new("mock_tool")
156 }
157 pub fn new(name: &'static str) -> Self {
158 Self {
159 expectations: Arc::new(Mutex::new(Vec::new())),
160 name,
161 }
162 }
163 pub fn expect_invoke_ok(
164 &self,
165 expected_result: ToolOutput,
166 expected_args: Option<&'static str>,
167 ) {
168 self.expect_invoke(Ok(expected_result), expected_args);
169 }
170
171 #[allow(clippy::missing_panics_doc)]
172 pub fn expect_invoke(
173 &self,
174 expected_result: Result<ToolOutput, ToolError>,
175 expected_args: Option<&'static str>,
176 ) {
177 self.expectations
178 .lock()
179 .unwrap()
180 .push((expected_result, expected_args));
181 }
182}
183
184#[async_trait]
185impl Tool for MockTool {
186 async fn invoke(
187 &self,
188 _agent_context: &dyn AgentContext,
189 tool_call: &ToolCall,
190 ) -> std::result::Result<ToolOutput, ToolError> {
191 tracing::debug!(
192 "[MockTool] Invoked `{}` with args: {:?}",
193 self.name,
194 tool_call
195 );
196 let expectation = self
197 .expectations
198 .lock()
199 .unwrap()
200 .pop()
201 .unwrap_or_else(|| panic!("[MockTool] No expectations left for `{}`", self.name));
202
203 assert_eq!(expectation.1, tool_call.args());
204
205 expectation.0
206 }
207
208 fn name(&self) -> Cow<'_, str> {
209 self.name.into()
210 }
211
212 fn tool_spec(&self) -> ToolSpec {
213 ToolSpec::builder()
214 .name(self.name().as_ref())
215 .description("A fake tool for testing purposes")
216 .build()
217 .unwrap()
218 }
219}
220
221impl From<MockTool> for Box<dyn Tool> {
222 fn from(val: MockTool) -> Self {
223 Box::new(val) as Box<dyn Tool>
224 }
225}
226
227impl Drop for MockTool {
228 fn drop(&mut self) {
229 if Arc::strong_count(&self.expectations) > 1 {
231 return;
232 }
233 if self.expectations.lock().is_err() {
234 return;
235 }
236
237 let name = self.name;
238 if self.expectations.lock().unwrap().is_empty() {
239 tracing::debug!("[MockTool] All expectations were met for `{name}`");
240 } else {
241 panic!(
242 "[MockTool] Not all expectations were met for `{name}: {:?}",
243 *self.expectations.lock().unwrap()
244 );
245 }
246 }
247}
248
/// A hook test double that counts how often the hook closures it hands out are run, and
/// compares that count with an expectation when the last handle to the counter is dropped.
#[derive(Debug, Clone)]
pub struct MockHook {
    // Used in diagnostic messages.
    name: &'static str,
    // Shared counter; every closure returned by the `*_fn` methods increments it.
    called: Arc<Mutex<usize>>,
    // Expected total number of calls (set to 0 by `new`). This is a plain field, so each
    // clone carries its own copy.
    expected_calls: usize,
}
255
256impl MockHook {
257 pub fn new(name: &'static str) -> Self {
258 Self {
259 name,
260 called: Arc::new(Mutex::new(0)),
261 expected_calls: 0,
262 }
263 }
264
265 pub fn expect_calls(&mut self, expected_calls: usize) -> &mut Self {
266 self.expected_calls = expected_calls;
267 self
268 }
269
270 #[allow(clippy::missing_panics_doc)]
271 pub fn hook_fn(&self) -> impl BeforeAllFn + use<> {
272 let called = Arc::clone(&self.called);
273 move |_: &Agent| {
274 let called = Arc::clone(&called);
275 Box::pin(async move {
276 let mut called = called.lock().unwrap();
277 *called += 1;
278 Ok(())
279 })
280 }
281 }
282
283 #[allow(clippy::missing_panics_doc)]
284 pub fn on_start_fn(&self) -> impl OnStartFn + use<> {
285 let called = Arc::clone(&self.called);
286 move |_: &Agent| {
287 let called = Arc::clone(&called);
288 Box::pin(async move {
289 let mut called = called.lock().unwrap();
290 *called += 1;
291 Ok(())
292 })
293 }
294 }
295 #[allow(clippy::missing_panics_doc)]
296 pub fn before_completion_fn(&self) -> impl BeforeCompletionFn + use<> {
297 let called = Arc::clone(&self.called);
298 move |_: &Agent, _| {
299 let called = Arc::clone(&called);
300 Box::pin(async move {
301 let mut called = called.lock().unwrap();
302 *called += 1;
303 Ok(())
304 })
305 }
306 }
307
308 #[allow(clippy::missing_panics_doc)]
309 pub fn after_completion_fn(&self) -> impl AfterCompletionFn + use<> {
310 let called = Arc::clone(&self.called);
311 move |_: &Agent, _| {
312 let called = Arc::clone(&called);
313 Box::pin(async move {
314 let mut called = called.lock().unwrap();
315 *called += 1;
316 Ok(())
317 })
318 }
319 }
320
321 #[allow(clippy::missing_panics_doc)]
322 pub fn after_tool_fn(&self) -> impl AfterToolFn + use<> {
323 let called = Arc::clone(&self.called);
324 move |_: &Agent, _, _| {
325 let called = Arc::clone(&called);
326 Box::pin(async move {
327 let mut called = called.lock().unwrap();
328 *called += 1;
329 Ok(())
330 })
331 }
332 }
333
334 #[allow(clippy::missing_panics_doc)]
335 pub fn before_tool_fn(&self) -> impl BeforeToolFn + use<> {
336 let called = Arc::clone(&self.called);
337 move |_: &Agent, _| {
338 let called = Arc::clone(&called);
339 Box::pin(async move {
340 let mut called = called.lock().unwrap();
341 *called += 1;
342 Ok(())
343 })
344 }
345 }
346
347 #[allow(clippy::missing_panics_doc)]
348 pub fn message_hook_fn(&self) -> impl MessageHookFn + use<> {
349 let called = Arc::clone(&self.called);
350 move |_: &Agent, _| {
351 let called = Arc::clone(&called);
352 Box::pin(async move {
353 let mut called = called.lock().unwrap();
354 *called += 1;
355 Ok(())
356 })
357 }
358 }
359
360 #[allow(clippy::missing_panics_doc)]
361 pub fn stop_hook_fn(&self) -> impl OnStopFn + use<> {
362 let called = Arc::clone(&self.called);
363 move |_: &Agent, _, _| {
364 let called = Arc::clone(&called);
365 Box::pin(async move {
366 let mut called = called.lock().unwrap();
367 *called += 1;
368 Ok(())
369 })
370 }
371 }
372
373 #[allow(clippy::missing_panics_doc)]
374 pub fn on_stream_fn(&self) -> impl OnStreamFn + use<> {
375 let called = Arc::clone(&self.called);
376 move |_: &Agent, _| {
377 let called = Arc::clone(&called);
378 Box::pin(async move {
379 let mut called = called.lock().unwrap();
380 *called += 1;
381 Ok(())
382 })
383 }
384 }
385}
386
387impl Drop for MockHook {
388 fn drop(&mut self) {
389 if Arc::strong_count(&self.called) > 1 {
390 return;
391 }
392 let Ok(called) = self.called.lock() else {
393 return;
394 };
395
396 if *called == self.expected_calls {
397 tracing::debug!(
398 "[MockHook] `{}` all expectations met; called {} times",
399 self.name,
400 *called
401 );
402 } else {
403 panic!(
404 "[MockHook] `{}` was called {} times but expected {}",
405 self.name, *called, self.expected_calls
406 )
407 }
408 }
409}