semantic_query/clients/
mock.rs1use async_trait::async_trait;
2use std::sync::{Arc, Mutex, Weak};
3use std::collections::VecDeque;
4use crate::{core::LowLevelClient, error::AIError};
5
6#[derive(Debug, Clone)]
8pub enum MockResponse {
9 Success(String),
10 Error(AIError),
11}
12
13#[derive(Debug, Default)]
15pub struct MockState {
16 responses: VecDeque<MockResponse>,
17 fail_on_empty: bool,
18}
19
20impl MockState {
21 pub fn new(fail_on_empty: bool) -> Self {
22 Self {
23 responses: VecDeque::new(),
24 fail_on_empty,
25 }
26 }
27
28 pub fn push_response(&mut self, response: MockResponse) {
29 self.responses.push_back(response);
30 }
31
32 pub fn push_responses(&mut self, responses: Vec<MockResponse>) {
33 for response in responses {
34 self.responses.push_back(response);
35 }
36 }
37
38 pub fn next_response(&mut self) -> Result<MockResponse, AIError> {
39 self.responses.pop_front().ok_or_else(|| {
40 if self.fail_on_empty {
41 AIError::Mock("No mock responses available - did you forget to configure the mock?".to_string())
42 } else {
43 AIError::Mock("Mock queue exhausted".to_string())
45 }
46 })
47 }
48
49 pub fn clear(&mut self) {
50 self.responses.clear();
51 }
52
53 pub fn remaining_count(&self) -> usize {
54 self.responses.len()
55 }
56
57 pub fn is_empty(&self) -> bool {
58 self.responses.is_empty()
59 }
60}
61
62#[derive(Debug)]
64pub struct MockHandle {
65 state: Arc<Mutex<MockState>>,
66}
67
68impl MockHandle {
69 fn new(fail_on_empty: bool) -> Self {
70 Self {
71 state: Arc::new(Mutex::new(MockState::new(fail_on_empty))),
72 }
73 }
74
75 pub fn add_response(&self, response: MockResponse) {
77 let mut state = self.state.lock().unwrap();
78 state.push_response(response);
79 }
80
81 pub fn add_responses(&self, responses: Vec<MockResponse>) {
83 let mut state = self.state.lock().unwrap();
84 state.push_responses(responses);
85 }
86
87 pub fn add_json_response(&self, json: &str) {
89 self.add_response(MockResponse::Success(json.to_string()));
90 }
91
92 pub fn add_json_responses(&self, jsons: Vec<&str>) {
94 let responses: Vec<MockResponse> = jsons
95 .into_iter()
96 .map(|json| MockResponse::Success(json.to_string()))
97 .collect();
98 self.add_responses(responses);
99 }
100
101 pub fn add_error(&self, error: AIError) {
103 self.add_response(MockResponse::Error(error));
104 }
105
106 pub fn clear(&self) {
108 let mut state = self.state.lock().unwrap();
109 state.clear();
110 }
111
112 pub fn remaining_count(&self) -> usize {
114 let state = self.state.lock().unwrap();
115 state.remaining_count()
116 }
117
118 pub fn is_empty(&self) -> bool {
120 let state = self.state.lock().unwrap();
121 state.is_empty()
122 }
123
124 fn next_response(&self) -> Result<MockResponse, AIError> {
126 let mut state = self.state.lock().unwrap();
127 state.next_response()
128 }
129}
130
131#[derive(Debug)]
133pub struct MockClient {
134 handle: Weak<MockHandle>,
135}
136
137impl MockClient {
138 pub fn new() -> (Self, Arc<MockHandle>) {
141 let handle = Arc::new(MockHandle::new(true)); let weak_handle = Arc::downgrade(&handle);
143
144 let client = Self {
145 handle: weak_handle,
146 };
147
148 (client, handle)
149 }
150
151 pub fn with_responses(responses: Vec<MockResponse>) -> (Self, Arc<MockHandle>) {
153 let (client, handle) = Self::new();
154 handle.add_responses(responses);
155 (client, handle)
156 }
157
158 fn try_next_response(&self) -> Result<MockResponse, AIError> {
160 match self.handle.upgrade() {
161 Some(handle) => handle.next_response(),
162 None => Err(AIError::Mock(
163 "MockHandle has been dropped - mock is no longer controllable".to_string()
164 )),
165 }
166 }
167}
168
169impl Clone for MockClient {
170 fn clone(&self) -> Self {
171 Self {
172 handle: self.handle.clone(),
173 }
174 }
175}
176
177#[async_trait]
178impl LowLevelClient for MockClient {
179 async fn ask_raw(&self, _prompt: String) -> Result<String, AIError> {
180 match self.try_next_response()? {
181 MockResponse::Success(response) => Ok(response),
182 MockResponse::Error(error) => Err(error),
183 }
184 }
185
186 fn clone_box(&self) -> Box<dyn LowLevelClient> {
187 Box::new(self.clone())
188 }
189}
190
/// Stateless `LowLevelClient` that answers every prompt with an empty JSON object (`{}`).
#[derive(Clone, Debug, Default)]
pub struct MockVoid;
194
195#[async_trait]
196impl LowLevelClient for MockVoid {
197 async fn ask_raw(&self, _prompt: String) -> Result<String, AIError> {
198 Ok("{}".to_string())
199 }
200
201 fn clone_box(&self) -> Box<dyn LowLevelClient> {
202 Box::new(self.clone())
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_with_responses() {
        let (client, mock_handle) = MockClient::new();

        mock_handle.add_json_responses(vec![
            r#"{"result": "first"}"#,
            r#"{"result": "second"}"#,
        ]);

        let response1 = client.ask_raw("test1".to_string()).await.unwrap();
        let response2 = client.ask_raw("test2".to_string()).await.unwrap();

        assert_eq!(response1, r#"{"result": "first"}"#);
        assert_eq!(response2, r#"{"result": "second"}"#);

        let response3 = client.ask_raw("test3".to_string()).await;
        assert!(response3.is_err());
        assert!(response3
            .unwrap_err()
            .to_string()
            .contains("No mock responses available"));
    }

    #[tokio::test]
    async fn test_mock_handle_dropped() {
        // `_mock_handle` (not `_`) keeps the handle alive until the end of the
        // inner block, so it is dropped exactly when the block ends.
        let client = {
            let (client, _mock_handle) = MockClient::new();
            client
        };

        let result = client.ask_raw("test".to_string()).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("MockHandle has been dropped"));
    }

    #[tokio::test]
    async fn test_mock_with_errors() {
        let (client, mock_handle) = MockClient::new();

        mock_handle.add_response(MockResponse::Success(r#"{"ok": true}"#.to_string()));
        mock_handle.add_error(AIError::Mock("Simulated error".to_string()));

        let response1 = client.ask_raw("test1".to_string()).await.unwrap();
        assert_eq!(response1, r#"{"ok": true}"#);

        let response2 = client.ask_raw("test2".to_string()).await;
        assert!(response2.is_err());
        assert!(response2.unwrap_err().to_string().contains("Simulated error"));
    }

    #[tokio::test]
    async fn test_runtime_mock_configuration() {
        let (client, mock_handle) = MockClient::new();

        let result = client.ask_raw("test".to_string()).await;
        assert!(result.is_err());

        mock_handle.add_json_response(r#"{"added": "later"}"#);

        let response = client.ask_raw("test".to_string()).await.unwrap();
        assert_eq!(response, r#"{"added": "later"}"#);

        assert!(mock_handle.is_empty());
        assert_eq!(mock_handle.remaining_count(), 0);
    }

    #[tokio::test]
    async fn test_with_responses_preloads_queue_in_order() {
        let (client, mock_handle) = MockClient::with_responses(vec![
            MockResponse::Success("a".to_string()),
            MockResponse::Success("b".to_string()),
        ]);

        assert_eq!(mock_handle.remaining_count(), 2);
        assert_eq!(client.ask_raw("p".to_string()).await.unwrap(), "a");
        assert_eq!(client.ask_raw("p".to_string()).await.unwrap(), "b");
        assert!(mock_handle.is_empty());
    }

    #[tokio::test]
    async fn test_clear_discards_pending_responses() {
        let (client, mock_handle) = MockClient::new();

        mock_handle.add_json_responses(vec!["{}", "[]"]);
        mock_handle.clear();

        assert_eq!(mock_handle.remaining_count(), 0);
        let result = client.ask_raw("test".to_string()).await;
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("No mock responses available"));
    }

    #[tokio::test]
    async fn test_cloned_clients_share_one_queue() {
        let (client, mock_handle) = MockClient::new();
        let twin = client.clone();

        mock_handle.add_json_response(r#"{"n": 1}"#);

        // The single queued reply is consumed by whichever clone asks first.
        assert_eq!(twin.ask_raw("x".to_string()).await.unwrap(), r#"{"n": 1}"#);
        assert!(client.ask_raw("x".to_string()).await.is_err());
    }

    #[tokio::test]
    async fn test_mock_void_returns_empty_object() {
        let response = MockVoid.ask_raw("anything".to_string()).await.unwrap();
        assert_eq!(response, "{}");
    }
}