Skip to main content

sage_runtime/
mock.rs

//! Mock infrastructure for the Sage testing framework (RFC-0012).
//!
//! This module provides:
//! - `MockResponse` - represents either a value or an error response
//! - `MockQueue` - thread-safe FIFO queue of mock responses
//! - `MockLlmClient` - mock implementation of LLM inference
//! - `MockToolRegistry` - mock implementations for tool calls
//! - `with_mock_tools` / `try_get_mock` - task-local mock context for tool mocking in tests
9
10use crate::error::{SageError, SageResult};
11use serde::de::DeserializeOwned;
12use std::cell::RefCell;
13use std::future::Future;
14use std::sync::{Arc, Mutex};
15
16// Task-local storage for the mock tool registry.
17// This allows tests to intercept tool calls without threading the registry through all code.
18tokio::task_local! {
19    static MOCK_TOOL_REGISTRY: RefCell<Option<MockToolRegistry>>;
20}
21
22/// Run a future with a mock tool registry in scope.
23///
24/// All tool calls made during the execution of the future will check
25/// the registry for mocks before making real calls.
26///
27/// # Example
28/// ```ignore
29/// let registry = MockToolRegistry::new();
30/// registry.register("Http", "get", MockResponse::string("{\"status\": 200}"));
31///
32/// with_mock_tools(registry, async {
33///     // Http.get() calls here will return the mock response
34/// }).await;
35/// ```
36pub async fn with_mock_tools<F, R>(registry: MockToolRegistry, f: F) -> R
37where
38    F: Future<Output = R>,
39{
40    MOCK_TOOL_REGISTRY
41        .scope(RefCell::new(Some(registry)), f)
42        .await
43}
44
45/// Try to get a mock response for a tool function call.
46///
47/// Returns `Some(response)` if a mock is registered and available,
48/// `None` if no mock is registered or if called outside a mock context.
49///
50/// This is called by tool clients to intercept calls during tests.
51pub fn try_get_mock(tool: &str, function: &str) -> Option<MockResponse> {
52    MOCK_TOOL_REGISTRY
53        .try_with(|cell| {
54            cell.borrow_mut()
55                .as_ref()
56                .and_then(|reg| reg.get(tool, function))
57        })
58        .ok()
59        .flatten()
60}
61
62/// A mock response for an `infer` call.
63#[derive(Debug, Clone)]
64pub enum MockResponse {
65    /// A successful response with the given value.
66    Value(serde_json::Value),
67    /// A failure response with the given error message.
68    Fail(String),
69}
70
71impl MockResponse {
72    /// Create a successful mock response from a JSON-serializable value.
73    pub fn value<T: serde::Serialize>(value: T) -> Self {
74        Self::Value(serde_json::to_value(value).expect("failed to serialize mock value"))
75    }
76
77    /// Create a successful mock response from a string.
78    pub fn string(s: impl Into<String>) -> Self {
79        Self::Value(serde_json::Value::String(s.into()))
80    }
81
82    /// Create a failure mock response.
83    pub fn fail(message: impl Into<String>) -> Self {
84        Self::Fail(message.into())
85    }
86}
87
88/// A thread-safe queue of mock responses.
89///
90/// Mock responses are consumed in order - the first `infer` call gets
91/// the first mock, the second gets the second, etc.
92#[derive(Debug, Clone, Default)]
93pub struct MockQueue {
94    responses: Arc<Mutex<Vec<MockResponse>>>,
95}
96
97impl MockQueue {
98    /// Create a new empty mock queue.
99    pub fn new() -> Self {
100        Self::default()
101    }
102
103    /// Create a mock queue with the given responses.
104    pub fn with_responses(responses: Vec<MockResponse>) -> Self {
105        Self {
106            responses: Arc::new(Mutex::new(responses)),
107        }
108    }
109
110    /// Add a mock response to the queue.
111    pub fn push(&self, response: MockResponse) {
112        self.responses.lock().unwrap().push(response);
113    }
114
115    /// Pop the next mock response from the queue.
116    ///
117    /// Returns `None` if the queue is empty.
118    pub fn pop(&self) -> Option<MockResponse> {
119        let mut queue = self.responses.lock().unwrap();
120        if queue.is_empty() {
121            None
122        } else {
123            Some(queue.remove(0))
124        }
125    }
126
127    /// Check if the queue is empty.
128    pub fn is_empty(&self) -> bool {
129        self.responses.lock().unwrap().is_empty()
130    }
131
132    /// Get the number of remaining mock responses.
133    pub fn len(&self) -> usize {
134        self.responses.lock().unwrap().len()
135    }
136}
137
138/// Mock LLM client for testing.
139///
140/// This client uses a `MockQueue` to return pre-configured responses
141/// instead of making real API calls.
142#[derive(Debug, Clone)]
143pub struct MockLlmClient {
144    queue: MockQueue,
145}
146
147impl MockLlmClient {
148    /// Create a new mock client with an empty queue.
149    pub fn new() -> Self {
150        Self {
151            queue: MockQueue::new(),
152        }
153    }
154
155    /// Create a mock client with the given responses.
156    pub fn with_responses(responses: Vec<MockResponse>) -> Self {
157        Self {
158            queue: MockQueue::with_responses(responses),
159        }
160    }
161
162    /// Get a reference to the mock queue for adding responses.
163    pub fn queue(&self) -> &MockQueue {
164        &self.queue
165    }
166
167    /// Call the mock LLM with a prompt and return the raw string response.
168    ///
169    /// Returns an error if no mock responses are queued.
170    pub async fn infer_string(&self, _prompt: &str) -> SageResult<String> {
171        match self.queue.pop() {
172            Some(MockResponse::Value(value)) => {
173                // Convert JSON value to string
174                match value {
175                    serde_json::Value::String(s) => Ok(s),
176                    other => Ok(other.to_string()),
177                }
178            }
179            Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
180            None => Err(SageError::Llm(
181                "infer called with no mock available (E054)".to_string(),
182            )),
183        }
184    }
185
186    /// Call the mock LLM with a prompt and parse the response as the given type.
187    ///
188    /// Returns an error if no mock responses are queued.
189    pub async fn infer<T>(&self, _prompt: &str) -> SageResult<T>
190    where
191        T: DeserializeOwned,
192    {
193        match self.queue.pop() {
194            Some(MockResponse::Value(value)) => serde_json::from_value(value)
195                .map_err(|e| SageError::Llm(format!("failed to deserialize mock value: {e}"))),
196            Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
197            None => Err(SageError::Llm(
198                "infer called with no mock available (E054)".to_string(),
199            )),
200        }
201    }
202
203    /// Call the mock LLM with schema-injected prompt for structured output.
204    ///
205    /// Returns an error if no mock responses are queued.
206    pub async fn infer_structured<T>(&self, _prompt: &str, _schema: &str) -> SageResult<T>
207    where
208        T: DeserializeOwned,
209    {
210        // Same as infer - the schema is ignored for mocks
211        match self.queue.pop() {
212            Some(MockResponse::Value(value)) => serde_json::from_value(value)
213                .map_err(|e| SageError::Llm(format!("failed to deserialize mock value: {e}"))),
214            Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
215            None => Err(SageError::Llm(
216                "infer called with no mock available (E054)".to_string(),
217            )),
218        }
219    }
220}
221
222impl Default for MockLlmClient {
223    fn default() -> Self {
224        Self::new()
225    }
226}
227
228/// Mock registry for tool calls.
229///
230/// Stores mock responses for specific tool.function combinations.
231#[derive(Debug, Clone, Default)]
232pub struct MockToolRegistry {
233    mocks: Arc<Mutex<std::collections::HashMap<String, MockQueue>>>,
234}
235
236impl MockToolRegistry {
237    /// Create a new empty mock registry.
238    pub fn new() -> Self {
239        Self::default()
240    }
241
242    /// Register a mock response for a tool function.
243    ///
244    /// The key is in the format "ToolName.function_name".
245    pub fn register(&self, tool: &str, function: &str, response: MockResponse) {
246        let key = format!("{}.{}", tool, function);
247        let mut mocks = self.mocks.lock().unwrap();
248        mocks
249            .entry(key)
250            .or_insert_with(MockQueue::new)
251            .push(response);
252    }
253
254    /// Get the next mock response for a tool function.
255    ///
256    /// Returns `None` if no mock is registered for this function.
257    pub fn get(&self, tool: &str, function: &str) -> Option<MockResponse> {
258        let key = format!("{}.{}", tool, function);
259        let mocks = self.mocks.lock().unwrap();
260        mocks.get(&key).and_then(|q| q.pop())
261    }
262
263    /// Check if a mock is registered for a tool function.
264    pub fn has_mock(&self, tool: &str, function: &str) -> bool {
265        let key = format!("{}.{}", tool, function);
266        let mocks = self.mocks.lock().unwrap();
267        mocks.get(&key).is_some_and(|q| !q.is_empty())
268    }
269
270    /// Call a mocked tool function and return the result.
271    ///
272    /// Returns an error if no mock is registered.
273    pub async fn call<T>(&self, tool: &str, function: &str) -> SageResult<T>
274    where
275        T: DeserializeOwned,
276    {
277        match self.get(tool, function) {
278            Some(MockResponse::Value(value)) => serde_json::from_value(value).map_err(|e| {
279                SageError::Tool(format!("failed to deserialize mock tool response: {e}"))
280            }),
281            Some(MockResponse::Fail(msg)) => Err(SageError::Tool(msg)),
282            None => Err(SageError::Tool(format!(
283                "no mock registered for {}.{}",
284                tool, function
285            ))),
286        }
287    }
288
289    /// Call a mocked tool function and return the raw string.
290    pub async fn call_string(&self, tool: &str, function: &str) -> SageResult<String> {
291        match self.get(tool, function) {
292            Some(MockResponse::Value(value)) => match value {
293                serde_json::Value::String(s) => Ok(s),
294                other => Ok(other.to_string()),
295            },
296            Some(MockResponse::Fail(msg)) => Err(SageError::Tool(msg)),
297            None => Err(SageError::Tool(format!(
298                "no mock registered for {}.{}",
299                tool, function
300            ))),
301        }
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    // A queued string response comes back verbatim from `infer_string`.
    #[tokio::test]
    async fn mock_infer_string_returns_value() {
        let client = MockLlmClient::with_responses(vec![MockResponse::string("hello world")]);
        let reply = client.infer_string("test").await.unwrap();
        assert_eq!(reply, "hello world");
    }

    // A queued failure surfaces as an error carrying the configured message.
    #[tokio::test]
    async fn mock_infer_string_returns_fail() {
        let client = MockLlmClient::with_responses(vec![MockResponse::fail("test error")]);
        let err = client.infer_string("test").await.unwrap_err();
        assert!(err.to_string().contains("test error"));
    }

    // With nothing queued, the error mentions the E054 code.
    #[tokio::test]
    async fn mock_infer_empty_queue_returns_error() {
        let client = MockLlmClient::new();
        let err = client.infer_string("test").await.unwrap_err();
        assert!(err.to_string().contains("E054"));
    }

    // Responses are handed out in the order they were queued.
    #[tokio::test]
    async fn mock_queue_fifo_order() {
        let client = MockLlmClient::with_responses(vec![
            MockResponse::string("first"),
            MockResponse::string("second"),
            MockResponse::string("third"),
        ]);

        for (prompt, expected) in [("a", "first"), ("b", "second"), ("c", "third")] {
            assert_eq!(client.infer_string(prompt).await.unwrap(), expected);
        }
        assert!(client.infer_string("d").await.is_err());
    }

    // A queued JSON object deserializes into a typed value.
    #[tokio::test]
    async fn mock_infer_typed_value() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Person {
            name: String,
            age: i32,
        }

        let payload = serde_json::json!({ "name": "Ward", "age": 42 });
        let client = MockLlmClient::with_responses(vec![MockResponse::value(payload)]);

        let person: Person = client.infer("test").await.unwrap();
        let expected = Person {
            name: "Ward".to_string(),
            age: 42,
        };
        assert_eq!(person, expected);
    }

    // Pops performed through a clone on another thread affect the original.
    #[test]
    fn mock_queue_thread_safe() {
        let queue = MockQueue::with_responses(vec![
            MockResponse::string("1"),
            MockResponse::string("2"),
            MockResponse::string("3"),
        ]);

        let worker_queue = queue.clone();
        let worker = std::thread::spawn(move || {
            worker_queue.pop();
            worker_queue.pop();
        });
        worker.join().unwrap();

        assert_eq!(queue.len(), 1);
    }

    // `infer_structured` deserializes the queued value; the schema is unused.
    #[tokio::test]
    async fn mock_infer_structured() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Summary {
            text: String,
            confidence: f64,
        }

        let payload = serde_json::json!({
            "text": "A summary",
            "confidence": 0.95
        });
        let client = MockLlmClient::with_responses(vec![MockResponse::value(payload)]);

        let summary: Summary = client
            .infer_structured("summarize", "schema")
            .await
            .unwrap();
        assert_eq!(summary.text, "A summary");
        assert!((summary.confidence - 0.95).abs() < 0.001);
    }

    // Register one mock, consume it, and observe the queue drain.
    #[tokio::test]
    async fn mock_tool_registry_basic() {
        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::string("mocked response"));
        assert!(registry.has_mock("Http", "get"));

        let body: String = registry.call("Http", "get").await.unwrap();
        assert_eq!(body, "mocked response");

        // The single mock was consumed by the call above.
        assert!(!registry.has_mock("Http", "get"));
    }

    // Several mocks for the same function are returned in registration order.
    #[tokio::test]
    async fn mock_tool_registry_multiple() {
        let registry = MockToolRegistry::new();
        for body in ["first", "second"] {
            registry.register("Http", "get", MockResponse::string(body));
        }

        let r1: String = registry.call("Http", "get").await.unwrap();
        let r2: String = registry.call("Http", "get").await.unwrap();

        assert_eq!(r1, "first");
        assert_eq!(r2, "second");
    }

    // A registered failure is reported as an error with its message.
    #[tokio::test]
    async fn mock_tool_registry_fail() {
        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::fail("network error"));

        let err = registry.call::<String>("Http", "get").await.unwrap_err();
        assert!(err.to_string().contains("network error"));
    }

    // Calling a function with no registered mock is an error.
    #[tokio::test]
    async fn mock_tool_registry_no_mock() {
        let registry = MockToolRegistry::new();

        let err = registry.call::<String>("Http", "get").await.unwrap_err();
        assert!(err.to_string().contains("no mock registered"));
    }
}