swiftide_core/
test_utils.rs1#![allow(clippy::missing_panics_doc)]
2use std::sync::{Arc, Mutex};
3
4use async_trait::async_trait;
5
6use crate::chat_completion::{
7 errors::ChatCompletionError, ChatCompletion, ChatCompletionRequest, ChatCompletionResponse,
8};
9use anyhow::Result;
10use pretty_assertions::assert_eq;
11
/// Generates a `#[tokio::test] async fn test_default_prompt()` that renders the
/// prompt returned by `default_prompt()` and snapshots it with `insta`.
///
/// Two invocation forms are accepted:
/// - `assert_default_prompt_snapshot!(node_text, key => value, ...)` attaches
///   `Node::new(node_text)` to the prompt, then adds each context value;
/// - `assert_default_prompt_snapshot!(key => value, ...)` only adds context values.
///
/// `default_prompt` (and `Node`, for the first form) are referenced unqualified,
/// so they must be in scope at the invocation site.
#[macro_export]
macro_rules! assert_default_prompt_snapshot {
    ($node:expr, $($key:expr => $value:expr),*) => {
        #[tokio::test]
        async fn test_default_prompt() {
            let template = default_prompt();
            // Build the prompt as one chain: node first, then every context pair in order.
            let prompt = template
                .to_prompt()
                .with_node(&Node::new($node))
                $(.with_context_value($key, $value))*;
            insta::assert_snapshot!(prompt.render().await.unwrap());
        }
    };

    ($($key:expr => $value:expr),*) => {
        #[tokio::test]
        async fn test_default_prompt() {
            let template = default_prompt();
            // Same as above, minus the node.
            let prompt = template
                .to_prompt()
                $(.with_context_value($key, $value))*;
            insta::assert_snapshot!(prompt.render().await.unwrap());
        }
    };
}
38
39type Expectations = Arc<Mutex<Vec<(ChatCompletionRequest, Result<ChatCompletionResponse>)>>>;
40
41#[derive(Clone)]
42pub struct MockChatCompletion {
43 pub expectations: Expectations,
44 pub received_expectations: Expectations,
45}
46
47impl Default for MockChatCompletion {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl MockChatCompletion {
54 pub fn new() -> Self {
55 Self {
56 expectations: Arc::new(Mutex::new(Vec::new())),
57 received_expectations: Arc::new(Mutex::new(Vec::new())),
58 }
59 }
60
61 pub fn expect_complete(
62 &self,
63 request: ChatCompletionRequest,
64 response: Result<ChatCompletionResponse>,
65 ) {
66 let mut mutex = self.expectations.lock().unwrap();
67
68 mutex.insert(0, (request, response));
69 }
70}
71
72#[async_trait]
73impl ChatCompletion for MockChatCompletion {
74 async fn complete(
75 &self,
76 request: &ChatCompletionRequest,
77 ) -> Result<ChatCompletionResponse, ChatCompletionError> {
78 let (expected_request, response) =
79 self.expectations.lock().unwrap().pop().unwrap_or_else(|| {
80 panic!(
81 "Received completion request, but no expectations are set\n {}",
82 pretty_request(request)
83 )
84 });
85
86 assert_eq!(
87 &expected_request,
88 request,
89 "Unexpected request\n: {}\nRemaining expectations:\n{}",
90 pretty_request(request),
91 pretty_expectation(&(expected_request.clone(), response))
92 + "---\n"
93 + &self
94 .expectations
95 .lock()
96 .unwrap()
97 .iter()
98 .map(pretty_expectation)
99 .collect::<Vec<_>>()
100 .join("---\n")
101 );
102
103 if let Ok(response) = response {
104 self.received_expectations
105 .lock()
106 .unwrap()
107 .push((expected_request, Ok(response.clone())));
108
109 Ok(response)
110 } else {
111 let err = response.unwrap_err();
112 self.received_expectations
113 .lock()
114 .unwrap()
115 .push((expected_request, Err(anyhow::anyhow!(err.to_string()))));
116
117 Err(err.into())
118 }
119 }
120}
121
122impl Drop for MockChatCompletion {
123 fn drop(&mut self) {
124 if Arc::strong_count(&self.received_expectations) > 1 {
126 return;
127 }
128 let Ok(expectations) = self.expectations.lock() else {
129 return;
130 };
131 let Ok(received) = self.received_expectations.lock() else {
132 return;
133 };
134
135 if expectations.is_empty() {
136 let num_received = received.len();
137 tracing::debug!("[MockChatCompletion] All {num_received} expectations were met");
138 } else {
139 let received = received
140 .iter()
141 .map(pretty_expectation)
142 .collect::<Vec<_>>()
143 .join("---\n");
144
145 let pending = expectations
146 .iter()
147 .map(pretty_expectation)
148 .collect::<Vec<_>>()
149 .join("---\n");
150
151 panic!("[MockChatCompletion] Not all expectations were met\n received:\n{received}\n\npending:\n{pending}");
152 }
153 }
154}
155
156fn pretty_expectation(
157 expectation: &(ChatCompletionRequest, Result<ChatCompletionResponse>),
158) -> String {
159 let mut output = String::new();
160
161 let request = &expectation.0;
162 output.push_str("Request:\n");
163 output.push_str(&pretty_request(request));
164
165 output.push_str(" =>\n");
166
167 let response_result = &expectation.1;
168
169 if let Ok(response) = response_result {
170 output += &pretty_response(response);
171 }
172
173 output
174}
175
176fn pretty_request(request: &ChatCompletionRequest) -> String {
177 let mut output = String::new();
178 for message in request.messages() {
179 output.push_str(&format!(" {message}\n"));
180 }
181 output
182}
183
184fn pretty_response(response: &ChatCompletionResponse) -> String {
185 let mut output = String::new();
186 if let Some(message) = response.message() {
187 output.push_str(&format!(" {message}\n"));
188 }
189 if let Some(tool_calls) = response.tool_calls() {
190 for tool_call in tool_calls {
191 output.push_str(&format!(" {tool_call}\n"));
192 }
193 }
194 output
195}