1use std::borrow::Cow;
2use std::sync::{Arc, Mutex};
3
4use async_trait::async_trait;
5use swiftide_core::chat_completion::ToolCall;
6use swiftide_core::chat_completion::{Tool, ToolOutput, ToolSpec, errors::ToolError};
7
8use swiftide_core::AgentContext;
9
10use crate::Agent;
11use crate::hooks::{
12 AfterCompletionFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn, MessageHookFn,
13 OnStartFn, OnStopFn, OnStreamFn,
14};
15
/// Builds a `ChatCompletionRequest` for tests.
///
/// Two forms are accepted, both taking one or more messages followed by `;`:
///
/// * `tools = [..]` — a list of tool values; each is boxed as a `dyn Tool` and
///   turned into its tool spec.
/// * `tool_specs = [..]` — a list of ready-made tool specs.
///
/// In both forms the specs of `Agent::default_tools()` are appended and the
/// result is collected into a `HashSet`.
///
/// All paths in the expansion are fully qualified (`$crate::Agent`,
/// `swiftide_core::chat_completion::Tool`) so the macro does not depend on what
/// the call site happens to have imported.
///
/// # Panics
///
/// Panics if the request builder fails to build.
#[macro_export]
macro_rules! chat_request {
    ($($message:expr),+; tools = [$($tool:expr),*]) => {
        swiftide_core::chat_completion::ChatCompletionRequest::builder()
            .messages(vec![$($message),*])
            .tools_spec(
                vec![$(Box::new($tool) as Box<dyn swiftide_core::chat_completion::Tool>),*]
                    .into_iter()
                    .chain($crate::Agent::default_tools())
                    .map(|tool| tool.tool_spec())
                    .collect::<std::collections::HashSet<_>>(),
            )
            .build()
            .unwrap()
    };
    ($($message:expr),+; tool_specs = [$($tool:expr),*]) => {
        swiftide_core::chat_completion::ChatCompletionRequest::builder()
            .messages(vec![$($message),*])
            .tools_spec(
                vec![$(($tool)),*]
                    .into_iter()
                    .chain(
                        $crate::Agent::default_tools()
                            .into_iter()
                            .map(|tool| tool.tool_spec()),
                    )
                    .collect::<std::collections::HashSet<_>>(),
            )
            .build()
            .unwrap()
    }
}
44
/// Shorthand for a `ChatMessage::User` whose content is `$text.to_string()`.
#[macro_export]
macro_rules! user {
    ($text:expr) => {
        swiftide_core::chat_completion::ChatMessage::User($text.to_string())
    };
}
51
/// Shorthand for a `ChatMessage::System` whose content is `$text.to_string()`.
#[macro_export]
macro_rules! system {
    ($text:expr) => {
        swiftide_core::chat_completion::ChatMessage::System($text.to_string())
    };
}
58
/// Shorthand for a `ChatMessage::Summary` whose content is `$text.to_string()`.
#[macro_export]
macro_rules! summary {
    ($text:expr) => {
        swiftide_core::chat_completion::ChatMessage::Summary($text.to_string())
    };
}
65
/// Shorthand for a `ChatMessage::Assistant`.
///
/// * `assistant!(msg)` — an assistant message with text and no tool calls.
/// * `assistant!(msg, [name, ..])` — an assistant message with text plus one
///   tool call per listed name. Every generated tool call uses the fixed id
///   `"1"`.
///
/// Both arms use fully qualified paths so the macro expands the same way
/// regardless of the imports at the call site.
///
/// # Panics
///
/// The second form panics if a `ToolCall` builder fails to build.
#[macro_export]
macro_rules! assistant {
    ($message:expr) => {
        swiftide_core::chat_completion::ChatMessage::Assistant(Some($message.to_string()), None)
    };
    ($message:expr, [$($tool_call_name:expr),*]) => {{
        let tool_calls = vec![
            $(
                swiftide_core::chat_completion::ToolCall::builder()
                    .name($tool_call_name)
                    .id("1")
                    .build()
                    .unwrap()
            ),*
        ];

        swiftide_core::chat_completion::ChatMessage::Assistant(
            Some($message.to_string()),
            Some(tool_calls),
        )
    }};
}
85
/// Shorthand for a `ChatMessage::ToolOutput` produced by the tool `$tool_name`.
///
/// The accompanying tool call is built with the given name and the fixed id
/// `"1"`; `$message` is converted with `.into()` to form the output.
///
/// Paths are fully qualified so the call site needs no extra imports.
///
/// # Panics
///
/// Panics if the `ToolCall` builder fails to build.
#[macro_export]
macro_rules! tool_output {
    ($tool_name:expr, $message:expr) => {{
        swiftide_core::chat_completion::ChatMessage::ToolOutput(
            swiftide_core::chat_completion::ToolCall::builder()
                .name($tool_name)
                .id("1")
                .build()
                .unwrap(),
            $message.into(),
        )
    }};
}
99
/// Shorthand for a `ChatMessage::ToolOutput` carrying `ToolOutput::fail($message)`
/// for the tool `$tool_name`.
///
/// The accompanying tool call is built with the given name and the fixed id
/// `"1"`.
///
/// Paths are fully qualified so the call site needs no extra imports.
///
/// # Panics
///
/// Panics if the `ToolCall` builder fails to build.
#[macro_export]
macro_rules! tool_failed {
    ($tool_name:expr, $message:expr) => {{
        swiftide_core::chat_completion::ChatMessage::ToolOutput(
            swiftide_core::chat_completion::ToolCall::builder()
                .name($tool_name)
                .id("1")
                .build()
                .unwrap(),
            swiftide_core::chat_completion::ToolOutput::fail($message),
        )
    }};
}
113
/// Builds a `ChatCompletionResponse` for tests.
///
/// * `chat_response!(msg; tool_calls = [name, ..])` — a response with a message
///   and one tool call per listed name.
/// * `chat_response!(tool_calls = [name, ..])` — a response with tool calls
///   only; no message is set on the builder.
///
/// Every generated tool call uses the fixed id `"1"`.
///
/// Paths are fully qualified so the call site needs no extra imports.
// NOTE(review): assumes `ChatCompletionResponse` lives next to
// `ChatCompletionRequest` in `swiftide_core::chat_completion` — confirm.
///
/// # Panics
///
/// Panics if a `ToolCall` builder or the response builder fails to build.
#[macro_export]
macro_rules! chat_response {
    ($message:expr; tool_calls = [$($tool_name:expr),*]) => {{
        let tool_calls = vec![
            $(
                swiftide_core::chat_completion::ToolCall::builder()
                    .name($tool_name)
                    .id("1")
                    .build()
                    .unwrap()
            ),*
        ];

        swiftide_core::chat_completion::ChatCompletionResponse::builder()
            .message($message)
            .tool_calls(tool_calls)
            .build()
            .unwrap()
    }};
    (tool_calls = [$($tool_name:expr),*]) => {{
        let tool_calls = vec![
            $(
                swiftide_core::chat_completion::ToolCall::builder()
                    .name($tool_name)
                    .id("1")
                    .build()
                    .unwrap()
            ),*
        ];

        swiftide_core::chat_completion::ChatCompletionResponse::builder()
            .tool_calls(tool_calls)
            .build()
            .unwrap()
    }};
}
140
141type Expectations = Arc<Mutex<Vec<(Result<ToolOutput, ToolError>, Option<&'static str>)>>>;
142
143#[derive(Debug, Clone)]
144pub struct MockTool {
145 expectations: Expectations,
146 name: &'static str,
147}
148
149impl MockTool {
150 #[allow(clippy::should_implement_trait)]
151 pub fn default() -> Self {
152 Self::new("mock_tool")
153 }
154 pub fn new(name: &'static str) -> Self {
155 Self {
156 expectations: Arc::new(Mutex::new(Vec::new())),
157 name,
158 }
159 }
160 pub fn expect_invoke_ok(
161 &self,
162 expected_result: ToolOutput,
163 expected_args: Option<&'static str>,
164 ) {
165 self.expect_invoke(Ok(expected_result), expected_args);
166 }
167
168 #[allow(clippy::missing_panics_doc)]
169 pub fn expect_invoke(
170 &self,
171 expected_result: Result<ToolOutput, ToolError>,
172 expected_args: Option<&'static str>,
173 ) {
174 self.expectations
175 .lock()
176 .unwrap()
177 .push((expected_result, expected_args));
178 }
179}
180
181#[async_trait]
182impl Tool for MockTool {
183 async fn invoke(
184 &self,
185 _agent_context: &dyn AgentContext,
186 tool_call: &ToolCall,
187 ) -> std::result::Result<ToolOutput, ToolError> {
188 tracing::debug!(
189 "[MockTool] Invoked `{}` with args: {:?}",
190 self.name,
191 tool_call
192 );
193 let expectation = self
194 .expectations
195 .lock()
196 .unwrap()
197 .pop()
198 .unwrap_or_else(|| panic!("[MockTool] No expectations left for `{}`", self.name));
199
200 assert_eq!(expectation.1, tool_call.args());
201
202 expectation.0
203 }
204
205 fn name(&self) -> Cow<'_, str> {
206 self.name.into()
207 }
208
209 fn tool_spec(&self) -> ToolSpec {
210 ToolSpec::builder()
211 .name(self.name().as_ref())
212 .description("A fake tool for testing purposes")
213 .build()
214 .unwrap()
215 }
216}
217
218impl From<MockTool> for Box<dyn Tool> {
219 fn from(val: MockTool) -> Self {
220 Box::new(val) as Box<dyn Tool>
221 }
222}
223
224impl Drop for MockTool {
225 fn drop(&mut self) {
226 if Arc::strong_count(&self.expectations) > 1 {
228 return;
229 }
230 if self.expectations.lock().is_err() {
231 return;
232 }
233
234 let name = self.name;
235 if self.expectations.lock().unwrap().is_empty() {
236 tracing::debug!("[MockTool] All expectations were met for `{name}`");
237 } else {
238 panic!(
239 "[MockTool] Not all expectations were met for `{name}: {:?}",
240 *self.expectations.lock().unwrap()
241 );
242 }
243 }
244}
245
246#[derive(Debug, Clone)]
247pub struct MockHook {
248 name: &'static str,
249 called: Arc<Mutex<usize>>,
250 expected_calls: usize,
251}
252
253impl MockHook {
254 pub fn new(name: &'static str) -> Self {
255 Self {
256 name,
257 called: Arc::new(Mutex::new(0)),
258 expected_calls: 0,
259 }
260 }
261
262 pub fn expect_calls(&mut self, expected_calls: usize) -> &mut Self {
263 self.expected_calls = expected_calls;
264 self
265 }
266
267 #[allow(clippy::missing_panics_doc)]
268 pub fn hook_fn(&self) -> impl BeforeAllFn + use<> {
269 let called = Arc::clone(&self.called);
270 move |_: &Agent| {
271 let called = Arc::clone(&called);
272 Box::pin(async move {
273 let mut called = called.lock().unwrap();
274 *called += 1;
275 Ok(())
276 })
277 }
278 }
279
280 #[allow(clippy::missing_panics_doc)]
281 pub fn on_start_fn(&self) -> impl OnStartFn + use<> {
282 let called = Arc::clone(&self.called);
283 move |_: &Agent| {
284 let called = Arc::clone(&called);
285 Box::pin(async move {
286 let mut called = called.lock().unwrap();
287 *called += 1;
288 Ok(())
289 })
290 }
291 }
292 #[allow(clippy::missing_panics_doc)]
293 pub fn before_completion_fn(&self) -> impl BeforeCompletionFn + use<> {
294 let called = Arc::clone(&self.called);
295 move |_: &Agent, _| {
296 let called = Arc::clone(&called);
297 Box::pin(async move {
298 let mut called = called.lock().unwrap();
299 *called += 1;
300 Ok(())
301 })
302 }
303 }
304
305 #[allow(clippy::missing_panics_doc)]
306 pub fn after_completion_fn(&self) -> impl AfterCompletionFn + use<> {
307 let called = Arc::clone(&self.called);
308 move |_: &Agent, _| {
309 let called = Arc::clone(&called);
310 Box::pin(async move {
311 let mut called = called.lock().unwrap();
312 *called += 1;
313 Ok(())
314 })
315 }
316 }
317
318 #[allow(clippy::missing_panics_doc)]
319 pub fn after_tool_fn(&self) -> impl AfterToolFn + use<> {
320 let called = Arc::clone(&self.called);
321 move |_: &Agent, _, _| {
322 let called = Arc::clone(&called);
323 Box::pin(async move {
324 let mut called = called.lock().unwrap();
325 *called += 1;
326 Ok(())
327 })
328 }
329 }
330
331 #[allow(clippy::missing_panics_doc)]
332 pub fn before_tool_fn(&self) -> impl BeforeToolFn + use<> {
333 let called = Arc::clone(&self.called);
334 move |_: &Agent, _| {
335 let called = Arc::clone(&called);
336 Box::pin(async move {
337 let mut called = called.lock().unwrap();
338 *called += 1;
339 Ok(())
340 })
341 }
342 }
343
344 #[allow(clippy::missing_panics_doc)]
345 pub fn message_hook_fn(&self) -> impl MessageHookFn + use<> {
346 let called = Arc::clone(&self.called);
347 move |_: &Agent, _| {
348 let called = Arc::clone(&called);
349 Box::pin(async move {
350 let mut called = called.lock().unwrap();
351 *called += 1;
352 Ok(())
353 })
354 }
355 }
356
357 #[allow(clippy::missing_panics_doc)]
358 pub fn stop_hook_fn(&self) -> impl OnStopFn + use<> {
359 let called = Arc::clone(&self.called);
360 move |_: &Agent, _, _| {
361 let called = Arc::clone(&called);
362 Box::pin(async move {
363 let mut called = called.lock().unwrap();
364 *called += 1;
365 Ok(())
366 })
367 }
368 }
369
370 #[allow(clippy::missing_panics_doc)]
371 pub fn on_stream_fn(&self) -> impl OnStreamFn + use<> {
372 let called = Arc::clone(&self.called);
373 move |_: &Agent, _| {
374 let called = Arc::clone(&called);
375 Box::pin(async move {
376 let mut called = called.lock().unwrap();
377 *called += 1;
378 Ok(())
379 })
380 }
381 }
382}
383
384impl Drop for MockHook {
385 fn drop(&mut self) {
386 if Arc::strong_count(&self.called) > 1 {
387 return;
388 }
389 let Ok(called) = self.called.lock() else {
390 return;
391 };
392
393 if *called == self.expected_calls {
394 tracing::debug!(
395 "[MockHook] `{}` all expectations met; called {} times",
396 self.name,
397 *called
398 );
399 } else {
400 panic!(
401 "[MockHook] `{}` was called {} times but expected {}",
402 self.name, *called, self.expected_calls
403 )
404 }
405 }
406}