swiftide_core/
test_utils.rs

1#![allow(clippy::missing_panics_doc)]
2use std::fmt::Write as _;
3use std::sync::{Arc, Mutex};
4
5use async_trait::async_trait;
6
7use crate::ChatCompletionStream;
8use crate::chat_completion::{
9    ChatCompletion, ChatCompletionRequest, ChatCompletionResponse, errors::LanguageModelError,
10};
11use anyhow::Result;
12use pretty_assertions::assert_eq;
13
/// Generates a `test_default_prompt` tokio test that renders `default_prompt()` and snapshots
/// the rendered output with `insta`.
///
/// The invoking module must have `default_prompt` in scope (and `Node` for the node form), and
/// the crate needs `tokio` and `insta` as (dev-)dependencies.
///
/// Accepted forms, each with an optional trailing comma:
///
/// - `assert_default_prompt_snapshot!(node_input, "key" => value, ...)`: attaches
///   `Node::new(node_input)` to the prompt, then adds the context values.
/// - `assert_default_prompt_snapshot!("key" => value, ...)`: only adds context values.
#[macro_export]
macro_rules! assert_default_prompt_snapshot {
    ($node:expr $(, $key:expr => $value:expr)* $(,)?) => {
        #[tokio::test]
        async fn test_default_prompt() {
            // `default_prompt()` already yields an owned template, so no clone is needed.
            let mut prompt = default_prompt().with_node(&Node::new($node));
            $(
                prompt = prompt.with_context_value($key, $value);
            )*
            insta::assert_snapshot!(prompt.render().unwrap());
        }
    };

    ($($key:expr => $value:expr),* $(,)?) => {
        #[tokio::test]
        async fn test_default_prompt() {
            let mut prompt = default_prompt();
            $(
                prompt = prompt.with_context_value($key, $value);
            )*
            insta::assert_snapshot!(prompt.render().unwrap());
        }
    };
}
40
41type Expectations = Arc<Mutex<Vec<(ChatCompletionRequest, Result<ChatCompletionResponse>)>>>;
42
43#[derive(Clone)]
44pub struct MockChatCompletion {
45    pub expectations: Expectations,
46    pub received_expectations: Expectations,
47}
48
49impl Default for MockChatCompletion {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl MockChatCompletion {
56    pub fn new() -> Self {
57        Self {
58            expectations: Arc::new(Mutex::new(Vec::new())),
59            received_expectations: Arc::new(Mutex::new(Vec::new())),
60        }
61    }
62
63    pub fn expect_complete(
64        &self,
65        request: ChatCompletionRequest,
66        response: Result<ChatCompletionResponse>,
67    ) {
68        let mut mutex = self.expectations.lock().unwrap();
69
70        mutex.insert(0, (request, response));
71    }
72}
73
74#[async_trait]
75impl ChatCompletion for MockChatCompletion {
76    async fn complete(
77        &self,
78        request: &ChatCompletionRequest,
79    ) -> Result<ChatCompletionResponse, LanguageModelError> {
80        let (expected_request, response) =
81            self.expectations.lock().unwrap().pop().unwrap_or_else(|| {
82                panic!(
83                    "Received completion request, but no expectations are set\n {}",
84                    pretty_request(request)
85                )
86            });
87
88        assert_eq!(
89            &expected_request,
90            request,
91            "Unexpected request\n: {}\nRemaining expectations:\n{}",
92            pretty_request(request),
93            pretty_expectation(&(expected_request.clone(), response))
94                + "---\n"
95                + &self
96                    .expectations
97                    .lock()
98                    .unwrap()
99                    .iter()
100                    .map(pretty_expectation)
101                    .collect::<Vec<_>>()
102                    .join("---\n")
103        );
104
105        if let Ok(response) = response {
106            self.received_expectations
107                .lock()
108                .unwrap()
109                .push((expected_request, Ok(response.clone())));
110
111            tracing::debug!(
112                "[MockChatCompletion] Received request:\n{}\nResponse:\n{}",
113                pretty_request(request),
114                pretty_response(&response)
115            );
116            Ok(response)
117        } else {
118            let err = response.unwrap_err();
119            self.received_expectations
120                .lock()
121                .unwrap()
122                .push((expected_request, Err(anyhow::anyhow!(err.to_string()))));
123
124            Err(LanguageModelError::PermanentError(err.into()))
125        }
126    }
127
128    /// Fakes a stream, first it checks the expectations, then it streams the response
129    /// instantly in small chunks
130    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
131        let response = match self.complete(request).await {
132            Ok(response) => response,
133            Err(err) => return err.into(),
134        };
135
136        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<
137            Result<ChatCompletionResponse, LanguageModelError>,
138        >();
139
140        tokio::spawn(async move {
141            let mut chunk_response = ChatCompletionResponse::builder()
142                .maybe_tool_calls(response.tool_calls.clone())
143                .build()
144                .unwrap();
145
146            for chunk in response.message().unwrap().split_whitespace() {
147                tracing::debug!("[MockChatCompletion] Sending chunk: {chunk}");
148
149                let chunk_response = chunk_response.append_message_delta(Some(chunk)).clone();
150                let _ = tx.send(Ok(chunk_response));
151                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
152            }
153        });
154
155        Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
156    }
157}
158
159impl Drop for MockChatCompletion {
160    fn drop(&mut self) {
161        // We are still cloned, so do not check assertions yet
162        if Arc::strong_count(&self.received_expectations) > 1 {
163            return;
164        }
165        let Ok(expectations) = self.expectations.lock() else {
166            return;
167        };
168        let Ok(received) = self.received_expectations.lock() else {
169            return;
170        };
171
172        if expectations.is_empty() {
173            let num_received = received.len();
174            tracing::debug!("[MockChatCompletion] All {num_received} expectations were met");
175        } else {
176            let received = received
177                .iter()
178                .map(pretty_expectation)
179                .collect::<Vec<_>>()
180                .join("---\n");
181
182            let pending = expectations
183                .iter()
184                .map(pretty_expectation)
185                .collect::<Vec<_>>()
186                .join("---\n");
187
188            panic!(
189                "[MockChatCompletion] Not all expectations were met\n received:\n{received}\n\npending:\n{pending}"
190            );
191        }
192    }
193}
194
195fn pretty_expectation(
196    expectation: &(ChatCompletionRequest, Result<ChatCompletionResponse>),
197) -> String {
198    let mut output = String::new();
199
200    let request = &expectation.0;
201    output.push_str("Request:\n");
202    output.push_str(&pretty_request(request));
203
204    output.push_str(" =>\n");
205
206    let response_result = &expectation.1;
207
208    if let Ok(response) = response_result {
209        output += &pretty_response(response);
210    }
211
212    output
213}
214
215fn pretty_request(request: &ChatCompletionRequest) -> String {
216    let mut output = String::new();
217    for message in request.messages() {
218        writeln!(output, " {message}").unwrap();
219    }
220    output
221}
222
223fn pretty_response(response: &ChatCompletionResponse) -> String {
224    let mut output = String::new();
225    if let Some(message) = response.message() {
226        writeln!(output, " {message}").unwrap();
227    }
228    if let Some(tool_calls) = response.tool_calls() {
229        for tool_call in tool_calls {
230            writeln!(output, " {tool_call}").unwrap();
231        }
232    }
233    output
234}