1use std::sync::atomic::{AtomicUsize, Ordering};
30
31use async_trait::async_trait;
32
33use traitclaw_core::traits::provider::Provider;
34use traitclaw_core::types::completion::{
35 CompletionRequest, CompletionResponse, ResponseContent, Usage,
36};
37use traitclaw_core::types::model_info::{ModelInfo, ModelTier};
38use traitclaw_core::types::stream::CompletionStream;
39use traitclaw_core::types::tool_call::ToolCall;
40use traitclaw_core::{Error, Result};
41
42fn default_usage() -> Usage {
44 Usage {
45 prompt_tokens: 10,
46 completion_tokens: 5,
47 total_tokens: 15,
48 }
49}
50
51pub struct MockProvider {
72 pub info: ModelInfo,
74 pub responses: Vec<CompletionResponse>,
76 call_idx: AtomicUsize,
78 error_message: Option<String>,
80}
81
82impl MockProvider {
83 pub fn text(text: &str) -> Self {
93 Self {
94 info: ModelInfo::new("mock-model", ModelTier::Small, 4096, false, false, false),
95 responses: vec![CompletionResponse {
96 content: ResponseContent::Text(text.into()),
97 usage: default_usage(),
98 }],
99 call_idx: AtomicUsize::new(0),
100 error_message: None,
101 }
102 }
103
104 pub fn sequence(responses: Vec<CompletionResponse>) -> Self {
131 assert!(
132 !responses.is_empty(),
133 "MockProvider::sequence requires at least one response"
134 );
135 Self {
136 info: ModelInfo::new("mock-model", ModelTier::Small, 4096, true, false, false),
137 responses,
138 call_idx: AtomicUsize::new(0),
139 error_message: None,
140 }
141 }
142
143 pub fn tool_then_text(tool_calls: Vec<ToolCall>, final_text: &str) -> Self {
163 Self::sequence(vec![
164 CompletionResponse {
165 content: ResponseContent::ToolCalls(tool_calls),
166 usage: default_usage(),
167 },
168 CompletionResponse {
169 content: ResponseContent::Text(final_text.into()),
170 usage: default_usage(),
171 },
172 ])
173 }
174
175 pub fn always_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
179 Self {
180 info: ModelInfo::new("mock-model", ModelTier::Small, 4096, true, false, false),
181 responses: vec![CompletionResponse {
182 content: ResponseContent::ToolCalls(tool_calls),
183 usage: default_usage(),
184 }],
185 call_idx: AtomicUsize::new(0),
186 error_message: None,
187 }
188 }
189
190 pub fn error(msg: &str) -> Self {
203 Self {
204 info: ModelInfo::new("mock-model", ModelTier::Small, 4096, false, false, false),
205 responses: vec![],
206 call_idx: AtomicUsize::new(0),
207 error_message: Some(msg.to_string()),
208 }
209 }
210
211 pub fn call_count(&self) -> usize {
213 self.call_idx.load(Ordering::SeqCst)
214 }
215}
216
217#[async_trait]
218impl Provider for MockProvider {
219 async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse> {
220 if let Some(msg) = &self.error_message {
222 return Err(Error::Runtime(msg.clone()));
223 }
224
225 let idx = self.call_idx.fetch_add(1, Ordering::SeqCst);
226 Ok(self.responses[idx.min(self.responses.len() - 1)].clone())
227 }
228
229 async fn stream(&self, _req: CompletionRequest) -> Result<CompletionStream> {
230 unimplemented!("MockProvider does not support streaming")
231 }
232
233 fn model_info(&self) -> &ModelInfo {
234 &self.info
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use traitclaw_core::types::message::Message;

    /// Minimal request; the mock ignores its contents, it only has to be valid.
    fn dummy_request() -> CompletionRequest {
        CompletionRequest {
            model: "mock-model".to_string(),
            messages: vec![Message::user("test")],
            tools: vec![],
            max_tokens: None,
            temperature: None,
            response_format: None,
            stream: false,
        }
    }

    #[tokio::test]
    async fn test_text_returns_correct_response() {
        let p = MockProvider::text("hello");
        let resp = p.complete(dummy_request()).await.unwrap();
        match resp.content {
            ResponseContent::Text(t) => assert_eq!(t, "hello"),
            ResponseContent::ToolCalls(_) => panic!("expected Text"),
        }
    }

    #[tokio::test]
    async fn test_text_returns_same_response_on_multiple_calls() {
        let p = MockProvider::text("constant");
        for _ in 0..5 {
            let resp = p.complete(dummy_request()).await.unwrap();
            match &resp.content {
                ResponseContent::Text(t) => assert_eq!(t, "constant"),
                ResponseContent::ToolCalls(_) => panic!("expected Text"),
            }
        }
        assert_eq!(p.call_count(), 5);
    }

    #[tokio::test]
    async fn test_sequence_returns_in_order() {
        let p = MockProvider::sequence(vec![
            CompletionResponse {
                content: ResponseContent::Text("first".into()),
                usage: default_usage(),
            },
            CompletionResponse {
                content: ResponseContent::Text("second".into()),
                usage: default_usage(),
            },
        ]);

        let r1 = p.complete(dummy_request()).await.unwrap();
        let r2 = p.complete(dummy_request()).await.unwrap();

        match r1.content {
            ResponseContent::Text(t) => assert_eq!(t, "first"),
            ResponseContent::ToolCalls(_) => panic!("expected first"),
        }
        match r2.content {
            ResponseContent::Text(t) => assert_eq!(t, "second"),
            ResponseContent::ToolCalls(_) => panic!("expected second"),
        }
    }

    #[tokio::test]
    async fn test_sequence_clamps_to_last_response() {
        let p = MockProvider::sequence(vec![
            CompletionResponse {
                content: ResponseContent::Text("only".into()),
                usage: default_usage(),
            },
            CompletionResponse {
                content: ResponseContent::Text("last".into()),
                usage: default_usage(),
            },
        ]);

        // Consume the whole two-entry script first ...
        let _ = p.complete(dummy_request()).await.unwrap();
        let _ = p.complete(dummy_request()).await.unwrap();
        // ... so that calls three and four are both past the end.
        let r3 = p.complete(dummy_request()).await.unwrap();
        let r4 = p.complete(dummy_request()).await.unwrap();

        match r3.content {
            ResponseContent::Text(t) => assert_eq!(t, "last"),
            ResponseContent::ToolCalls(_) => panic!("expected last"),
        }
        match r4.content {
            ResponseContent::Text(t) => assert_eq!(t, "last"),
            ResponseContent::ToolCalls(_) => panic!("expected last"),
        }
    }

    #[tokio::test]
    async fn test_tool_then_text_returns_tool_calls_then_text() {
        let tool_call = ToolCall {
            id: "call_1".into(),
            name: "echo".into(),
            arguments: r#"{"text":"hi"}"#.into(),
        };
        let p = MockProvider::tool_then_text(vec![tool_call], "done");

        let r1 = p.complete(dummy_request()).await.unwrap();
        match &r1.content {
            ResponseContent::ToolCalls(calls) => {
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].name, "echo");
            }
            ResponseContent::Text(_) => panic!("expected ToolCalls on first call"),
        }

        let r2 = p.complete(dummy_request()).await.unwrap();
        match r2.content {
            ResponseContent::Text(t) => assert_eq!(t, "done"),
            ResponseContent::ToolCalls(_) => panic!("expected Text on second call"),
        }
    }

    #[tokio::test]
    async fn test_error_returns_error() {
        let p = MockProvider::error("rate limited");
        let result = p.complete(dummy_request()).await;
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(err_str.contains("rate limited"), "got: {err_str}");
    }

    #[tokio::test]
    async fn test_always_tool_calls_never_returns_text() {
        let tool_call = ToolCall {
            id: "1".into(),
            name: "search".into(),
            arguments: "{}".into(),
        };
        let p = MockProvider::always_tool_calls(vec![tool_call]);

        for _ in 0..3 {
            let resp = p.complete(dummy_request()).await.unwrap();
            assert!(
                matches!(resp.content, ResponseContent::ToolCalls(_)),
                "expected ToolCalls"
            );
        }
    }

    #[test]
    fn test_mock_provider_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MockProvider>();
    }

    #[tokio::test]
    async fn test_call_count_tracks_invocations() {
        let p = MockProvider::text("x");
        assert_eq!(p.call_count(), 0);

        // The counter must advance by exactly one per successful call.
        for expected in 1..=3 {
            let _ = p.complete(dummy_request()).await.unwrap();
            assert_eq!(p.call_count(), expected);
        }
    }

    #[test]
    fn test_model_info_returns_expected_defaults() {
        let p = MockProvider::text("x");
        let info = p.model_info();
        assert_eq!(info.name, "mock-model");
        assert_eq!(info.context_window, 4096);
    }
}