swiftide_core/
test_utils.rs

1#![allow(clippy::missing_panics_doc)]
2use std::sync::{Arc, Mutex};
3
4use async_trait::async_trait;
5
6use crate::chat_completion::{
7    errors::ChatCompletionError, ChatCompletion, ChatCompletionRequest, ChatCompletionResponse,
8};
9use anyhow::Result;
10use pretty_assertions::assert_eq;
11
/// Generates a `#[tokio::test]` named `test_default_prompt` that renders the
/// template returned by `default_prompt()` and snapshots the output with `insta`.
///
/// Two forms are accepted:
///
/// * `assert_default_prompt_snapshot!(node, "key" => value, ...)` — the prompt
///   gets `Node::new(node)` attached and every `key => value` pair added as a
///   context value. The node may be given on its own, without context values.
/// * `assert_default_prompt_snapshot!("key" => value, ...)` — context values
///   only, no node (also accepts an empty invocation).
///
/// A trailing comma is accepted in both forms.
///
/// The expansion refers to `default_prompt`, `Node`, `tokio` and `insta` by
/// bare name, so they must be in scope at the call site.
#[macro_export]
macro_rules! assert_default_prompt_snapshot {
    // Node first, then optional `key => value` pairs. The node arm is tried first;
    // an input starting with `key => value` fails it at `=>` and falls through
    // to the context-only arm below.
    ($node:expr $(, $key:expr => $value:expr)* $(,)?) => {
        #[tokio::test]
        async fn test_default_prompt() {
            let template = default_prompt();
            let mut prompt = template.to_prompt().with_node(&Node::new($node));
            $(
                prompt = prompt.with_context_value($key, $value);
            )*
            insta::assert_snapshot!(prompt.render().await.unwrap());
        }
    };

    // Context values only.
    ($($key:expr => $value:expr),* $(,)?) => {
        #[tokio::test]
        async fn test_default_prompt() {
            let template = default_prompt();
            let mut prompt = template.to_prompt();
            $(
                prompt = prompt.with_context_value($key, $value);
            )*
            insta::assert_snapshot!(prompt.render().await.unwrap());
        }
    };
}
38
39type Expectations = Arc<Mutex<Vec<(ChatCompletionRequest, Result<ChatCompletionResponse>)>>>;
40
41#[derive(Clone)]
42pub struct MockChatCompletion {
43    pub expectations: Expectations,
44    pub received_expectations: Expectations,
45}
46
47impl Default for MockChatCompletion {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl MockChatCompletion {
54    pub fn new() -> Self {
55        Self {
56            expectations: Arc::new(Mutex::new(Vec::new())),
57            received_expectations: Arc::new(Mutex::new(Vec::new())),
58        }
59    }
60
61    pub fn expect_complete(
62        &self,
63        request: ChatCompletionRequest,
64        response: Result<ChatCompletionResponse>,
65    ) {
66        let mut mutex = self.expectations.lock().unwrap();
67
68        mutex.insert(0, (request, response));
69    }
70}
71
72#[async_trait]
73impl ChatCompletion for MockChatCompletion {
74    async fn complete(
75        &self,
76        request: &ChatCompletionRequest,
77    ) -> Result<ChatCompletionResponse, ChatCompletionError> {
78        let (expected_request, response) =
79            self.expectations.lock().unwrap().pop().unwrap_or_else(|| {
80                panic!(
81                    "Received completion request, but no expectations are set\n {}",
82                    pretty_request(request)
83                )
84            });
85
86        assert_eq!(
87            &expected_request,
88            request,
89            "Unexpected request\n: {}\nRemaining expectations:\n{}",
90            pretty_request(request),
91            pretty_expectation(&(expected_request.clone(), response))
92                + "---\n"
93                + &self
94                    .expectations
95                    .lock()
96                    .unwrap()
97                    .iter()
98                    .map(pretty_expectation)
99                    .collect::<Vec<_>>()
100                    .join("---\n")
101        );
102
103        if let Ok(response) = response {
104            self.received_expectations
105                .lock()
106                .unwrap()
107                .push((expected_request, Ok(response.clone())));
108
109            Ok(response)
110        } else {
111            let err = response.unwrap_err();
112            self.received_expectations
113                .lock()
114                .unwrap()
115                .push((expected_request, Err(anyhow::anyhow!(err.to_string()))));
116
117            Err(err.into())
118        }
119    }
120}
121
122impl Drop for MockChatCompletion {
123    fn drop(&mut self) {
124        // We are still cloned, so do not check assertions yet
125        if Arc::strong_count(&self.received_expectations) > 1 {
126            return;
127        }
128        let Ok(expectations) = self.expectations.lock() else {
129            return;
130        };
131        let Ok(received) = self.received_expectations.lock() else {
132            return;
133        };
134
135        if expectations.is_empty() {
136            let num_received = received.len();
137            tracing::debug!("[MockChatCompletion] All {num_received} expectations were met");
138        } else {
139            let received = received
140                .iter()
141                .map(pretty_expectation)
142                .collect::<Vec<_>>()
143                .join("---\n");
144
145            let pending = expectations
146                .iter()
147                .map(pretty_expectation)
148                .collect::<Vec<_>>()
149                .join("---\n");
150
151            panic!("[MockChatCompletion] Not all expectations were met\n received:\n{received}\n\npending:\n{pending}");
152        }
153    }
154}
155
156fn pretty_expectation(
157    expectation: &(ChatCompletionRequest, Result<ChatCompletionResponse>),
158) -> String {
159    let mut output = String::new();
160
161    let request = &expectation.0;
162    output.push_str("Request:\n");
163    output.push_str(&pretty_request(request));
164
165    output.push_str(" =>\n");
166
167    let response_result = &expectation.1;
168
169    if let Ok(response) = response_result {
170        output += &pretty_response(response);
171    }
172
173    output
174}
175
176fn pretty_request(request: &ChatCompletionRequest) -> String {
177    let mut output = String::new();
178    for message in request.messages() {
179        output.push_str(&format!(" {message}\n"));
180    }
181    output
182}
183
184fn pretty_response(response: &ChatCompletionResponse) -> String {
185    let mut output = String::new();
186    if let Some(message) = response.message() {
187        output.push_str(&format!(" {message}\n"));
188    }
189    if let Some(tool_calls) = response.tool_calls() {
190        for tool_call in tool_calls {
191            output.push_str(&format!(" {tool_call}\n"));
192        }
193    }
194    output
195}