swiftide_core/
test_utils.rs1#![allow(clippy::missing_panics_doc)]
2use std::fmt::Write as _;
3use std::sync::{Arc, Mutex};
4
5use async_trait::async_trait;
6
7use crate::ChatCompletionStream;
8use crate::chat_completion::{
9 ChatCompletion, ChatCompletionRequest, ChatCompletionResponse, errors::LanguageModelError,
10};
11use anyhow::Result;
12use pretty_assertions::assert_eq;
13
/// Generates a `#[tokio::test]` named `test_default_prompt` that renders the
/// prompt returned by the calling module's `default_prompt()` and snapshots
/// the rendered text with `insta`.
///
/// Two forms are accepted:
///
/// * `assert_default_prompt_snapshot!(node_input, "key" => value, ...)` builds
///   `Node::new(node_input)`, attaches it with `with_node`, then adds each
///   context value.
/// * `assert_default_prompt_snapshot!("key" => value, ...)` only adds the
///   context values.
///
/// Both forms accept an optional trailing comma.
#[macro_export]
macro_rules! assert_default_prompt_snapshot {
    ($node:expr, $($key:expr => $value:expr),* $(,)?) => {
        #[tokio::test]
        async fn test_default_prompt() {
            let template = default_prompt();
            let mut prompt = template.clone().with_node(&Node::new($node));
            $(
                prompt = prompt.with_context_value($key, $value);
            )*
            insta::assert_snapshot!(prompt.render().unwrap());
        }
    };

    ($($key:expr => $value:expr),* $(,)?) => {
        #[tokio::test]
        async fn test_default_prompt() {
            let template = default_prompt();
            let mut prompt = template;
            $(
                prompt = prompt.with_context_value($key, $value);
            )*
            insta::assert_snapshot!(prompt.render().unwrap());
        }
    };
}
40
41type Expectations = Arc<Mutex<Vec<(ChatCompletionRequest, Result<ChatCompletionResponse>)>>>;
42
43#[derive(Clone)]
44pub struct MockChatCompletion {
45 pub expectations: Expectations,
46 pub received_expectations: Expectations,
47}
48
49impl Default for MockChatCompletion {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl MockChatCompletion {
56 pub fn new() -> Self {
57 Self {
58 expectations: Arc::new(Mutex::new(Vec::new())),
59 received_expectations: Arc::new(Mutex::new(Vec::new())),
60 }
61 }
62
63 pub fn expect_complete(
64 &self,
65 request: ChatCompletionRequest,
66 response: Result<ChatCompletionResponse>,
67 ) {
68 let mut mutex = self.expectations.lock().unwrap();
69
70 mutex.insert(0, (request, response));
71 }
72}
73
74#[async_trait]
75impl ChatCompletion for MockChatCompletion {
76 async fn complete(
77 &self,
78 request: &ChatCompletionRequest,
79 ) -> Result<ChatCompletionResponse, LanguageModelError> {
80 let (expected_request, response) =
81 self.expectations.lock().unwrap().pop().unwrap_or_else(|| {
82 panic!(
83 "Received completion request, but no expectations are set\n {}",
84 pretty_request(request)
85 )
86 });
87
88 assert_eq!(
89 &expected_request,
90 request,
91 "Unexpected request\n: {}\nRemaining expectations:\n{}",
92 pretty_request(request),
93 pretty_expectation(&(expected_request.clone(), response))
94 + "---\n"
95 + &self
96 .expectations
97 .lock()
98 .unwrap()
99 .iter()
100 .map(pretty_expectation)
101 .collect::<Vec<_>>()
102 .join("---\n")
103 );
104
105 if let Ok(response) = response {
106 self.received_expectations
107 .lock()
108 .unwrap()
109 .push((expected_request, Ok(response.clone())));
110
111 tracing::debug!(
112 "[MockChatCompletion] Received request:\n{}\nResponse:\n{}",
113 pretty_request(request),
114 pretty_response(&response)
115 );
116 Ok(response)
117 } else {
118 let err = response.unwrap_err();
119 self.received_expectations
120 .lock()
121 .unwrap()
122 .push((expected_request, Err(anyhow::anyhow!(err.to_string()))));
123
124 Err(LanguageModelError::PermanentError(err.into()))
125 }
126 }
127
128 async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
131 let response = match self.complete(request).await {
132 Ok(response) => response,
133 Err(err) => return err.into(),
134 };
135
136 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<
137 Result<ChatCompletionResponse, LanguageModelError>,
138 >();
139
140 tokio::spawn(async move {
141 let mut chunk_response = ChatCompletionResponse::builder()
142 .maybe_tool_calls(response.tool_calls.clone())
143 .build()
144 .unwrap();
145
146 for chunk in response.message().unwrap().split_whitespace() {
147 tracing::debug!("[MockChatCompletion] Sending chunk: {chunk}");
148
149 let chunk_response = chunk_response.append_message_delta(Some(chunk)).clone();
150 let _ = tx.send(Ok(chunk_response));
151 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
152 }
153 });
154
155 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
156 }
157}
158
159impl Drop for MockChatCompletion {
160 fn drop(&mut self) {
161 if Arc::strong_count(&self.received_expectations) > 1 {
163 return;
164 }
165 let Ok(expectations) = self.expectations.lock() else {
166 return;
167 };
168 let Ok(received) = self.received_expectations.lock() else {
169 return;
170 };
171
172 if expectations.is_empty() {
173 let num_received = received.len();
174 tracing::debug!("[MockChatCompletion] All {num_received} expectations were met");
175 } else {
176 let received = received
177 .iter()
178 .map(pretty_expectation)
179 .collect::<Vec<_>>()
180 .join("---\n");
181
182 let pending = expectations
183 .iter()
184 .map(pretty_expectation)
185 .collect::<Vec<_>>()
186 .join("---\n");
187
188 panic!(
189 "[MockChatCompletion] Not all expectations were met\n received:\n{received}\n\npending:\n{pending}"
190 );
191 }
192 }
193}
194
195fn pretty_expectation(
196 expectation: &(ChatCompletionRequest, Result<ChatCompletionResponse>),
197) -> String {
198 let mut output = String::new();
199
200 let request = &expectation.0;
201 output.push_str("Request:\n");
202 output.push_str(&pretty_request(request));
203
204 output.push_str(" =>\n");
205
206 let response_result = &expectation.1;
207
208 if let Ok(response) = response_result {
209 output += &pretty_response(response);
210 }
211
212 output
213}
214
215fn pretty_request(request: &ChatCompletionRequest) -> String {
216 let mut output = String::new();
217 for message in request.messages() {
218 writeln!(output, " {message}").unwrap();
219 }
220 output
221}
222
223fn pretty_response(response: &ChatCompletionResponse) -> String {
224 let mut output = String::new();
225 if let Some(message) = response.message() {
226 writeln!(output, " {message}").unwrap();
227 }
228 if let Some(tool_calls) = response.tool_calls() {
229 for tool_call in tool_calls {
230 writeln!(output, " {tool_call}").unwrap();
231 }
232 }
233 output
234}