Skip to main content

traitclaw_test_utils/
provider.rs

1//! Mock LLM provider for deterministic testing.
2//!
3//! [`MockProvider`] returns pre-defined [`CompletionResponse`] values in
4//! sequence, enabling fully deterministic tests without real API calls.
5//!
6//! When the sequence is exhausted, subsequent calls return the **last**
7//! response (clamp-to-end behavior).
8//!
9//! # Quick Start
10//!
11//! ```rust
12//! use traitclaw_test_utils::provider::MockProvider;
13//!
14//! // Single text response — used for simple agent tests
15//! let p = MockProvider::text("Hello!");
16//!
17//! // Tool call → text — used for ReAct loop tests
18//! # use traitclaw_core::types::tool_call::ToolCall;
19//! let p = MockProvider::tool_then_text(
20//!     vec![ToolCall {
21//!         id: "1".into(),
22//!         name: "echo".into(),
23//!         arguments: r#"{"text":"hi"}"#.into(),
24//!     }],
25//!     "Done.",
26//! );
27//! ```
28
29use std::sync::atomic::{AtomicUsize, Ordering};
30
31use async_trait::async_trait;
32
33use traitclaw_core::traits::provider::Provider;
34use traitclaw_core::types::completion::{
35    CompletionRequest, CompletionResponse, ResponseContent, Usage,
36};
37use traitclaw_core::types::model_info::{ModelInfo, ModelTier};
38use traitclaw_core::types::stream::CompletionStream;
39use traitclaw_core::types::tool_call::ToolCall;
40use traitclaw_core::{Error, Result};
41
42/// Default usage stats for mock responses.
43fn default_usage() -> Usage {
44    Usage {
45        prompt_tokens: 10,
46        completion_tokens: 5,
47        total_tokens: 15,
48    }
49}
50
51/// Deterministic mock provider that returns responses in sequence.
52///
53/// Each call to [`complete()`](Provider::complete) returns the next
54/// response from the internal sequence. When all responses have been
55/// returned, subsequent calls return the **last** response (clamp
56/// behavior, not wrap-around).
57///
58/// # Thread Safety
59///
60/// `MockProvider` is [`Send`] + [`Sync`] by design — it uses
61/// [`AtomicUsize`] for lock-free call indexing.
62///
63/// # Example
64///
65/// ```rust
66/// use traitclaw_test_utils::provider::MockProvider;
67///
68/// let p = MockProvider::text("hello");
69/// // p.complete(req).await returns "hello" every time
70/// ```
71pub struct MockProvider {
72    /// Model information returned by [`Provider::model_info`].
73    pub info: ModelInfo,
74    /// Ordered list of responses to return.
75    pub responses: Vec<CompletionResponse>,
76    /// Tracks the current position in the response sequence.
77    call_idx: AtomicUsize,
78    /// Optional: return an error instead of a response.
79    error_message: Option<String>,
80}
81
82impl MockProvider {
83    /// Create a provider that always returns a single text response.
84    ///
85    /// # Example
86    ///
87    /// ```rust
88    /// use traitclaw_test_utils::provider::MockProvider;
89    ///
90    /// let p = MockProvider::text("I am a mock LLM");
91    /// ```
92    pub fn text(text: &str) -> Self {
93        Self {
94            info: ModelInfo::new("mock-model", ModelTier::Small, 4096, false, false, false),
95            responses: vec![CompletionResponse {
96                content: ResponseContent::Text(text.into()),
97                usage: default_usage(),
98            }],
99            call_idx: AtomicUsize::new(0),
100            error_message: None,
101        }
102    }
103
104    /// Create a provider with an explicit sequence of responses.
105    ///
106    /// Responses are returned in order. Once exhausted, the last
107    /// response is repeated.
108    ///
109    /// # Panics
110    ///
111    /// Panics if `responses` is empty.
112    ///
113    /// # Example
114    ///
115    /// ```rust
116    /// use traitclaw_test_utils::provider::MockProvider;
117    /// use traitclaw_core::types::completion::{CompletionResponse, ResponseContent, Usage};
118    ///
119    /// let p = MockProvider::sequence(vec![
120    ///     CompletionResponse {
121    ///         content: ResponseContent::Text("first".into()),
122    ///         usage: Usage { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
123    ///     },
124    ///     CompletionResponse {
125    ///         content: ResponseContent::Text("second".into()),
126    ///         usage: Usage { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
127    ///     },
128    /// ]);
129    /// ```
130    pub fn sequence(responses: Vec<CompletionResponse>) -> Self {
131        assert!(
132            !responses.is_empty(),
133            "MockProvider::sequence requires at least one response"
134        );
135        Self {
136            info: ModelInfo::new("mock-model", ModelTier::Small, 4096, true, false, false),
137            responses,
138            call_idx: AtomicUsize::new(0),
139            error_message: None,
140        }
141    }
142
143    /// Create a provider that returns tool calls first, then a final text response.
144    ///
145    /// This is the standard pattern for testing ReAct-style agent loops.
146    ///
147    /// # Example
148    ///
149    /// ```rust
150    /// use traitclaw_test_utils::provider::MockProvider;
151    /// use traitclaw_core::types::tool_call::ToolCall;
152    ///
153    /// let p = MockProvider::tool_then_text(
154    ///     vec![ToolCall {
155    ///         id: "call_1".into(),
156    ///         name: "search".into(),
157    ///         arguments: r#"{"query":"rust"}"#.into(),
158    ///     }],
159    ///     "Here are the results.",
160    /// );
161    /// ```
162    pub fn tool_then_text(tool_calls: Vec<ToolCall>, final_text: &str) -> Self {
163        Self::sequence(vec![
164            CompletionResponse {
165                content: ResponseContent::ToolCalls(tool_calls),
166                usage: default_usage(),
167            },
168            CompletionResponse {
169                content: ResponseContent::Text(final_text.into()),
170                usage: default_usage(),
171            },
172        ])
173    }
174
175    /// Create a provider that always returns tool calls (never text).
176    ///
177    /// Useful for testing tool-budget guards and loop detection.
178    pub fn always_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
179        Self {
180            info: ModelInfo::new("mock-model", ModelTier::Small, 4096, true, false, false),
181            responses: vec![CompletionResponse {
182                content: ResponseContent::ToolCalls(tool_calls),
183                usage: default_usage(),
184            }],
185            call_idx: AtomicUsize::new(0),
186            error_message: None,
187        }
188    }
189
190    /// Create a provider that always returns an error.
191    ///
192    /// Useful for testing error handling paths in strategies and agents.
193    ///
194    /// # Example
195    ///
196    /// ```rust
197    /// use traitclaw_test_utils::provider::MockProvider;
198    ///
199    /// let p = MockProvider::error("API rate limit exceeded");
200    /// // p.complete(req).await will return Err(Error::Runtime(...))
201    /// ```
202    pub fn error(msg: &str) -> Self {
203        Self {
204            info: ModelInfo::new("mock-model", ModelTier::Small, 4096, false, false, false),
205            responses: vec![],
206            call_idx: AtomicUsize::new(0),
207            error_message: Some(msg.to_string()),
208        }
209    }
210
211    /// Returns how many times `complete()` has been called.
212    pub fn call_count(&self) -> usize {
213        self.call_idx.load(Ordering::SeqCst)
214    }
215}
216
217#[async_trait]
218impl Provider for MockProvider {
219    async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse> {
220        // Error path — always returns error if configured
221        if let Some(msg) = &self.error_message {
222            return Err(Error::Runtime(msg.clone()));
223        }
224
225        let idx = self.call_idx.fetch_add(1, Ordering::SeqCst);
226        Ok(self.responses[idx.min(self.responses.len() - 1)].clone())
227    }
228
229    async fn stream(&self, _req: CompletionRequest) -> Result<CompletionStream> {
230        unimplemented!("MockProvider does not support streaming")
231    }
232
233    fn model_info(&self) -> &ModelInfo {
234        &self.info
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use traitclaw_core::types::message::Message;

    /// Minimal request; `MockProvider` ignores its contents.
    fn dummy_request() -> CompletionRequest {
        CompletionRequest {
            model: "mock-model".to_string(),
            messages: vec![Message::user("test")],
            tools: vec![],
            max_tokens: None,
            temperature: None,
            response_format: None,
            stream: false,
        }
    }

    /// Builds a text response carrying the default usage stats.
    fn text_response(text: &str) -> CompletionResponse {
        CompletionResponse {
            content: ResponseContent::Text(text.into()),
            usage: default_usage(),
        }
    }

    /// Calls `complete()` once and asserts that it yields `expected` as text.
    ///
    /// The match is exhaustive on purpose: a new `ResponseContent` variant
    /// should fail to compile here rather than fall into a catch-all arm.
    async fn assert_next_text(p: &MockProvider, expected: &str) {
        let resp = p.complete(dummy_request()).await.unwrap();
        match resp.content {
            ResponseContent::Text(t) => assert_eq!(t, expected),
            ResponseContent::ToolCalls(_) => panic!("expected Text({expected:?})"),
        }
    }

    #[tokio::test]
    async fn test_text_returns_correct_response() {
        let p = MockProvider::text("hello");
        assert_next_text(&p, "hello").await;
    }

    #[tokio::test]
    async fn test_text_returns_same_response_on_multiple_calls() {
        let p = MockProvider::text("constant");
        for _ in 0..5 {
            assert_next_text(&p, "constant").await;
        }
        assert_eq!(p.call_count(), 5);
    }

    #[tokio::test]
    async fn test_sequence_returns_in_order() {
        let p = MockProvider::sequence(vec![text_response("first"), text_response("second")]);

        assert_next_text(&p, "first").await;
        assert_next_text(&p, "second").await;
    }

    #[tokio::test]
    async fn test_sequence_clamps_to_last_response() {
        let p = MockProvider::sequence(vec![text_response("only"), text_response("last")]);

        // Exhaust the script.
        assert_next_text(&p, "only").await;
        assert_next_text(&p, "last").await;

        // Past the end: every further call repeats the final entry.
        assert_next_text(&p, "last").await;
        assert_next_text(&p, "last").await;
    }

    #[test]
    #[should_panic(expected = "requires at least one response")]
    fn test_sequence_panics_on_empty_input() {
        let _ = MockProvider::sequence(vec![]);
    }

    #[tokio::test]
    async fn test_tool_then_text_returns_tool_calls_then_text() {
        let tool_call = ToolCall {
            id: "call_1".into(),
            name: "echo".into(),
            arguments: r#"{"text":"hi"}"#.into(),
        };
        let p = MockProvider::tool_then_text(vec![tool_call], "done");

        let r1 = p.complete(dummy_request()).await.unwrap();
        match &r1.content {
            ResponseContent::ToolCalls(calls) => {
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].name, "echo");
            }
            ResponseContent::Text(_) => panic!("expected ToolCalls on first call"),
        }

        assert_next_text(&p, "done").await;
    }

    #[tokio::test]
    async fn test_error_returns_error() {
        let p = MockProvider::error("rate limited");
        let result = p.complete(dummy_request()).await;
        assert!(result.is_err());
        let err_str = result.unwrap_err().to_string();
        assert!(err_str.contains("rate limited"), "got: {err_str}");
    }

    #[tokio::test]
    async fn test_error_is_returned_on_every_call() {
        let p = MockProvider::error("boom");
        for _ in 0..3 {
            assert!(p.complete(dummy_request()).await.is_err());
        }
    }

    #[tokio::test]
    async fn test_always_tool_calls_never_returns_text() {
        let tool_call = ToolCall {
            id: "1".into(),
            name: "search".into(),
            arguments: "{}".into(),
        };
        let p = MockProvider::always_tool_calls(vec![tool_call]);

        for _ in 0..3 {
            let resp = p.complete(dummy_request()).await.unwrap();
            assert!(
                matches!(resp.content, ResponseContent::ToolCalls(_)),
                "expected ToolCalls"
            );
        }
    }

    #[test]
    fn test_mock_provider_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MockProvider>();
    }

    #[tokio::test]
    async fn test_call_count_tracks_invocations() {
        let p = MockProvider::text("x");
        assert_eq!(p.call_count(), 0);

        // The counter must advance by exactly one per successful call.
        for expected in 1..=3 {
            let _ = p.complete(dummy_request()).await.unwrap();
            assert_eq!(p.call_count(), expected);
        }
    }

    #[test]
    fn test_model_info_returns_expected_defaults() {
        let p = MockProvider::text("x");
        let info = p.model_info();
        assert_eq!(info.name, "mock-model");
        assert_eq!(info.context_window, 4096);
    }
}