swiftide_core/
test_utils.rs1#![allow(clippy::missing_panics_doc)]
2use std::fmt::Write as _;
3use std::sync::{Arc, Mutex};
4
5use async_trait::async_trait;
6
7use crate::chat_completion::{
8 errors::ChatCompletionError, ChatCompletion, ChatCompletionRequest, ChatCompletionResponse,
9};
10use anyhow::Result;
11use pretty_assertions::assert_eq;
12
/// Generates a `#[tokio::test]` named `test_default_prompt` that renders `default_prompt()`
/// and snapshots the rendered text with `insta::assert_snapshot!`.
///
/// `default_prompt` (and, for the node form, `Node`) must be in scope at the call site.
///
/// Forms:
/// - `assert_default_prompt_snapshot!(node, key => value, ...)` attaches `Node::new(node)` to
///   the prompt with `with_node`, then applies each `with_context_value(key, value)`.
/// - `assert_default_prompt_snapshot!(key => value, ...)` only applies the context values.
///
/// A trailing comma is accepted in both forms.
#[macro_export]
macro_rules! assert_default_prompt_snapshot {
    ($node:expr $(, $key:expr => $value:expr)* $(,)?) => {
        #[tokio::test]
        async fn test_default_prompt() {
            // The template is used exactly once, so build on it directly instead of cloning it.
            let mut prompt = default_prompt().with_node(&Node::new($node));
            $(
                prompt = prompt.with_context_value($key, $value);
            )*
            insta::assert_snapshot!(prompt.render().unwrap());
        }
    };

    ($($key:expr => $value:expr),* $(,)?) => {
        #[tokio::test]
        async fn test_default_prompt() {
            let mut prompt = default_prompt();
            $(
                prompt = prompt.with_context_value($key, $value);
            )*
            insta::assert_snapshot!(prompt.render().unwrap());
        }
    };
}
39
40type Expectations = Arc<Mutex<Vec<(ChatCompletionRequest, Result<ChatCompletionResponse>)>>>;
41
42#[derive(Clone)]
43pub struct MockChatCompletion {
44 pub expectations: Expectations,
45 pub received_expectations: Expectations,
46}
47
48impl Default for MockChatCompletion {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl MockChatCompletion {
55 pub fn new() -> Self {
56 Self {
57 expectations: Arc::new(Mutex::new(Vec::new())),
58 received_expectations: Arc::new(Mutex::new(Vec::new())),
59 }
60 }
61
62 pub fn expect_complete(
63 &self,
64 request: ChatCompletionRequest,
65 response: Result<ChatCompletionResponse>,
66 ) {
67 let mut mutex = self.expectations.lock().unwrap();
68
69 mutex.insert(0, (request, response));
70 }
71}
72
73#[async_trait]
74impl ChatCompletion for MockChatCompletion {
75 async fn complete(
76 &self,
77 request: &ChatCompletionRequest,
78 ) -> Result<ChatCompletionResponse, ChatCompletionError> {
79 let (expected_request, response) =
80 self.expectations.lock().unwrap().pop().unwrap_or_else(|| {
81 panic!(
82 "Received completion request, but no expectations are set\n {}",
83 pretty_request(request)
84 )
85 });
86
87 assert_eq!(
88 &expected_request,
89 request,
90 "Unexpected request\n: {}\nRemaining expectations:\n{}",
91 pretty_request(request),
92 pretty_expectation(&(expected_request.clone(), response))
93 + "---\n"
94 + &self
95 .expectations
96 .lock()
97 .unwrap()
98 .iter()
99 .map(pretty_expectation)
100 .collect::<Vec<_>>()
101 .join("---\n")
102 );
103
104 if let Ok(response) = response {
105 self.received_expectations
106 .lock()
107 .unwrap()
108 .push((expected_request, Ok(response.clone())));
109
110 Ok(response)
111 } else {
112 let err = response.unwrap_err();
113 self.received_expectations
114 .lock()
115 .unwrap()
116 .push((expected_request, Err(anyhow::anyhow!(err.to_string()))));
117
118 Err(err.into())
119 }
120 }
121}
122
123impl Drop for MockChatCompletion {
124 fn drop(&mut self) {
125 if Arc::strong_count(&self.received_expectations) > 1 {
127 return;
128 }
129 let Ok(expectations) = self.expectations.lock() else {
130 return;
131 };
132 let Ok(received) = self.received_expectations.lock() else {
133 return;
134 };
135
136 if expectations.is_empty() {
137 let num_received = received.len();
138 tracing::debug!("[MockChatCompletion] All {num_received} expectations were met");
139 } else {
140 let received = received
141 .iter()
142 .map(pretty_expectation)
143 .collect::<Vec<_>>()
144 .join("---\n");
145
146 let pending = expectations
147 .iter()
148 .map(pretty_expectation)
149 .collect::<Vec<_>>()
150 .join("---\n");
151
152 panic!("[MockChatCompletion] Not all expectations were met\n received:\n{received}\n\npending:\n{pending}");
153 }
154 }
155}
156
157fn pretty_expectation(
158 expectation: &(ChatCompletionRequest, Result<ChatCompletionResponse>),
159) -> String {
160 let mut output = String::new();
161
162 let request = &expectation.0;
163 output.push_str("Request:\n");
164 output.push_str(&pretty_request(request));
165
166 output.push_str(" =>\n");
167
168 let response_result = &expectation.1;
169
170 if let Ok(response) = response_result {
171 output += &pretty_response(response);
172 }
173
174 output
175}
176
177fn pretty_request(request: &ChatCompletionRequest) -> String {
178 let mut output = String::new();
179 for message in request.messages() {
180 writeln!(output, " {message}").unwrap();
181 }
182 output
183}
184
185fn pretty_response(response: &ChatCompletionResponse) -> String {
186 let mut output = String::new();
187 if let Some(message) = response.message() {
188 writeln!(output, " {message}").unwrap();
189 }
190 if let Some(tool_calls) = response.tool_calls() {
191 for tool_call in tool_calls {
192 writeln!(output, " {tool_call}").unwrap();
193 }
194 }
195 output
196}