swiftide_core/
test_utils.rs

1#![allow(clippy::missing_panics_doc)]
2use std::fmt::Write as _;
3use std::sync::{Arc, Mutex};
4
5use async_trait::async_trait;
6
7use crate::chat_completion::{
8    errors::ChatCompletionError, ChatCompletion, ChatCompletionRequest, ChatCompletionResponse,
9};
10use anyhow::Result;
11use pretty_assertions::assert_eq;
12
/// Generates a `#[tokio::test]` named `test_default_prompt` that renders the
/// `default_prompt()` in scope at the call site and snapshots the result with `insta`.
///
/// Two invocation forms are supported (a trailing comma is accepted in both):
///
/// * `assert_default_prompt_snapshot!(node, key => value, ...)` passes `node` to
///   `Node::new`, attaches it with `with_node`, then adds each context value.
///   The key/value pairs are optional, so `assert_default_prompt_snapshot!(node)` works too.
/// * `assert_default_prompt_snapshot!(key => value, ...)` only adds context values.
///
/// The caller must have `default_prompt` (and, for the first form, `Node`) in scope,
/// and the calling crate must depend on `tokio` and `insta`.
#[macro_export]
macro_rules! assert_default_prompt_snapshot {
    ($node:expr $(, $key:expr => $value:expr)* $(,)?) => {
        #[tokio::test]
        async fn test_default_prompt() {
            let template = default_prompt();
            let prompt = template.clone().with_node(&Node::new($node));
            // Shadow rather than reassign a `mut` binding, so an invocation without
            // any context values does not trigger the `unused_mut` lint.
            $(
                let prompt = prompt.with_context_value($key, $value);
            )*
            insta::assert_snapshot!(prompt.render().unwrap());
        }
    };

    ($($key:expr => $value:expr),* $(,)?) => {
        #[tokio::test]
        async fn test_default_prompt() {
            let prompt = default_prompt();
            // Shadowing avoids `unused_mut` when no context values are given.
            $(
                let prompt = prompt.with_context_value($key, $value);
            )*
            insta::assert_snapshot!(prompt.render().unwrap());
        }
    };
}
39
40type Expectations = Arc<Mutex<Vec<(ChatCompletionRequest, Result<ChatCompletionResponse>)>>>;
41
42#[derive(Clone)]
43pub struct MockChatCompletion {
44    pub expectations: Expectations,
45    pub received_expectations: Expectations,
46}
47
48impl Default for MockChatCompletion {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl MockChatCompletion {
55    pub fn new() -> Self {
56        Self {
57            expectations: Arc::new(Mutex::new(Vec::new())),
58            received_expectations: Arc::new(Mutex::new(Vec::new())),
59        }
60    }
61
62    pub fn expect_complete(
63        &self,
64        request: ChatCompletionRequest,
65        response: Result<ChatCompletionResponse>,
66    ) {
67        let mut mutex = self.expectations.lock().unwrap();
68
69        mutex.insert(0, (request, response));
70    }
71}
72
73#[async_trait]
74impl ChatCompletion for MockChatCompletion {
75    async fn complete(
76        &self,
77        request: &ChatCompletionRequest,
78    ) -> Result<ChatCompletionResponse, ChatCompletionError> {
79        let (expected_request, response) =
80            self.expectations.lock().unwrap().pop().unwrap_or_else(|| {
81                panic!(
82                    "Received completion request, but no expectations are set\n {}",
83                    pretty_request(request)
84                )
85            });
86
87        assert_eq!(
88            &expected_request,
89            request,
90            "Unexpected request\n: {}\nRemaining expectations:\n{}",
91            pretty_request(request),
92            pretty_expectation(&(expected_request.clone(), response))
93                + "---\n"
94                + &self
95                    .expectations
96                    .lock()
97                    .unwrap()
98                    .iter()
99                    .map(pretty_expectation)
100                    .collect::<Vec<_>>()
101                    .join("---\n")
102        );
103
104        if let Ok(response) = response {
105            self.received_expectations
106                .lock()
107                .unwrap()
108                .push((expected_request, Ok(response.clone())));
109
110            Ok(response)
111        } else {
112            let err = response.unwrap_err();
113            self.received_expectations
114                .lock()
115                .unwrap()
116                .push((expected_request, Err(anyhow::anyhow!(err.to_string()))));
117
118            Err(err.into())
119        }
120    }
121}
122
123impl Drop for MockChatCompletion {
124    fn drop(&mut self) {
125        // We are still cloned, so do not check assertions yet
126        if Arc::strong_count(&self.received_expectations) > 1 {
127            return;
128        }
129        let Ok(expectations) = self.expectations.lock() else {
130            return;
131        };
132        let Ok(received) = self.received_expectations.lock() else {
133            return;
134        };
135
136        if expectations.is_empty() {
137            let num_received = received.len();
138            tracing::debug!("[MockChatCompletion] All {num_received} expectations were met");
139        } else {
140            let received = received
141                .iter()
142                .map(pretty_expectation)
143                .collect::<Vec<_>>()
144                .join("---\n");
145
146            let pending = expectations
147                .iter()
148                .map(pretty_expectation)
149                .collect::<Vec<_>>()
150                .join("---\n");
151
152            panic!("[MockChatCompletion] Not all expectations were met\n received:\n{received}\n\npending:\n{pending}");
153        }
154    }
155}
156
157fn pretty_expectation(
158    expectation: &(ChatCompletionRequest, Result<ChatCompletionResponse>),
159) -> String {
160    let mut output = String::new();
161
162    let request = &expectation.0;
163    output.push_str("Request:\n");
164    output.push_str(&pretty_request(request));
165
166    output.push_str(" =>\n");
167
168    let response_result = &expectation.1;
169
170    if let Ok(response) = response_result {
171        output += &pretty_response(response);
172    }
173
174    output
175}
176
177fn pretty_request(request: &ChatCompletionRequest) -> String {
178    let mut output = String::new();
179    for message in request.messages() {
180        writeln!(output, " {message}").unwrap();
181    }
182    output
183}
184
185fn pretty_response(response: &ChatCompletionResponse) -> String {
186    let mut output = String::new();
187    if let Some(message) = response.message() {
188        writeln!(output, " {message}").unwrap();
189    }
190    if let Some(tool_calls) = response.tool_calls() {
191        for tool_call in tool_calls {
192            writeln!(output, " {tool_call}").unwrap();
193        }
194    }
195    output
196}