spec_ai_core/agent/providers/
mock.rs

1//! Mock Model Provider
2//!
3//! A simple mock provider for testing that returns canned responses.
4
5use crate::agent::model::{
6    GenerationConfig, ModelProvider, ModelResponse, ProviderKind, ProviderMetadata, TokenUsage,
7};
8use anyhow::Result;
9use async_stream::stream;
10use async_trait::async_trait;
11use futures::Stream;
12use std::pin::Pin;
13
/// Test double for a model provider that replays canned text.
///
/// Replies are handed out round-robin. The cursor is reference-counted, so
/// every clone of a provider advances through the same shared sequence.
#[derive(Debug, Clone)]
pub struct MockProvider {
    /// Canned replies, returned in order and then repeated from the start.
    responses: Vec<String>,
    /// Position of the next reply; shared between clones through the `Arc`.
    current_index: ::std::sync::Arc<::std::sync::Mutex<usize>>,
    /// Model identifier reported for this provider.
    model_name: String,
}
24
25impl MockProvider {
26    /// Create a new mock provider with a single response
27    pub fn new(response: impl Into<String>) -> Self {
28        Self {
29            responses: vec![response.into()],
30            current_index: std::sync::Arc::new(std::sync::Mutex::new(0)),
31            model_name: "mock-model".to_string(),
32        }
33    }
34
35    /// Create a new mock provider with multiple responses
36    pub fn with_responses(responses: Vec<String>) -> Self {
37        Self {
38            responses,
39            current_index: std::sync::Arc::new(std::sync::Mutex::new(0)),
40            model_name: "mock-model".to_string(),
41        }
42    }
43
44    /// Set the model name
45    pub fn with_model_name(mut self, model_name: impl Into<String>) -> Self {
46        self.model_name = model_name.into();
47        self
48    }
49
50    /// Get the next response (cycles through available responses)
51    fn next_response(&self) -> String {
52        let mut index = self.current_index.lock().unwrap();
53        let response = self.responses[*index % self.responses.len()].clone();
54        *index += 1;
55        response
56    }
57}
58
59impl Default for MockProvider {
60    fn default() -> Self {
61        Self::new("This is a mock response from the test provider.")
62    }
63}
64
65#[async_trait]
66impl ModelProvider for MockProvider {
67    async fn generate(&self, _prompt: &str, _config: &GenerationConfig) -> Result<ModelResponse> {
68        let content = self.next_response();
69        let prompt_tokens = 10; // Mock values
70        let completion_tokens = content.split_whitespace().count() as u32;
71
72        Ok(ModelResponse {
73            content,
74            model: self.model_name.clone(),
75            usage: Some(TokenUsage {
76                prompt_tokens,
77                completion_tokens,
78                total_tokens: prompt_tokens + completion_tokens,
79            }),
80            finish_reason: Some("stop".to_string()),
81            tool_calls: None,
82            reasoning: None,
83        })
84    }
85
86    async fn stream(
87        &self,
88        _prompt: &str,
89        _config: &GenerationConfig,
90    ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
91        let content = self.next_response();
92        let words: Vec<String> = content.split_whitespace().map(|s| s.to_string()).collect();
93
94        let stream = stream! {
95            for word in words {
96                yield Ok(format!("{} ", word));
97                // Simulate network delay
98                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
99            }
100        };
101
102        Ok(Box::pin(stream))
103    }
104
105    fn metadata(&self) -> ProviderMetadata {
106        ProviderMetadata {
107            name: "Mock Provider".to_string(),
108            supported_models: vec![
109                "mock-model".to_string(),
110                "mock-gpt-4".to_string(),
111                "mock-claude-3".to_string(),
112            ],
113            supports_streaming: true,
114        }
115    }
116
117    fn kind(&self) -> ProviderKind {
118        ProviderKind::Mock
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_mock_provider_generate() {
        let cfg = GenerationConfig::default();
        let mock = MockProvider::new("Hello, world!");

        let reply = mock.generate("test prompt", &cfg).await.unwrap();

        assert_eq!(reply.content, "Hello, world!");
        assert_eq!(reply.model, "mock-model");
        assert!(reply.usage.is_some());
        assert_eq!(reply.finish_reason.as_deref(), Some("stop"));
    }

    #[tokio::test]
    async fn test_mock_provider_multiple_responses() {
        let canned = ["First response", "Second response", "Third response"];
        let mock = MockProvider::with_responses(canned.iter().map(|s| s.to_string()).collect());
        let cfg = GenerationConfig::default();

        // Three calls walk the list in order; the fourth wraps around to the start.
        for want in canned.iter().chain(canned.iter().take(1)) {
            let reply = mock.generate("prompt", &cfg).await.unwrap();
            assert_eq!(reply.content, *want);
        }
    }

    #[tokio::test]
    async fn test_mock_provider_stream() {
        let mock = MockProvider::new("Hello world test");
        let cfg = GenerationConfig::default();

        let chunks: Vec<String> = mock
            .stream("test prompt", &cfg)
            .await
            .unwrap()
            .map(|chunk| chunk.unwrap())
            .collect()
            .await;

        // One chunk per word, in order.
        assert_eq!(chunks.len(), 3);
        for (chunk, word) in chunks.iter().zip(["Hello", "world", "test"]) {
            assert!(chunk.contains(word));
        }
    }

    #[tokio::test]
    async fn test_mock_provider_metadata() {
        let meta = MockProvider::default().metadata();

        assert_eq!(meta.name, "Mock Provider");
        assert!(meta.supports_streaming);
        assert!(!meta.supported_models.is_empty());
    }

    #[tokio::test]
    async fn test_mock_provider_custom_model_name() {
        let mock = MockProvider::new("test").with_model_name("custom-model");
        let cfg = GenerationConfig::default();

        let reply = mock.generate("prompt", &cfg).await.unwrap();
        assert_eq!(reply.model, "custom-model");
    }
}