// spec_ai_core/agent/providers/mock.rs

use crate::agent::model::{
    GenerationConfig, ModelProvider, ModelResponse, ProviderKind, ProviderMetadata, TokenUsage,
};
use anyhow::Result;
use async_stream::stream;
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;

/// A deterministic [`ModelProvider`] for tests: replays a fixed list of
/// canned responses, cycling back to the first once the list is exhausted.
#[derive(Debug, Clone)]
pub struct MockProvider {
    // Canned responses handed out in order (wraps around via modulo).
    responses: Vec<String>,
    // Cursor for the next response; behind an Arc<Mutex<..>> so clones of
    // this provider share one position.
    current_index: std::sync::Arc<std::sync::Mutex<usize>>,
    // Model name echoed back in every `ModelResponse`.
    model_name: String,
}

impl MockProvider {
    /// Creates a provider that always returns `response`.
    pub fn new(response: impl Into<String>) -> Self {
        Self {
            responses: vec![response.into()],
            current_index: std::sync::Arc::new(std::sync::Mutex::new(0)),
            model_name: "mock-model".to_string(),
        }
    }

    /// Creates a provider that cycles through `responses` in order.
    /// An empty vector is tolerated: every call then yields `""`.
    pub fn with_responses(responses: Vec<String>) -> Self {
        Self {
            responses,
            current_index: std::sync::Arc::new(std::sync::Mutex::new(0)),
            model_name: "mock-model".to_string(),
        }
    }

    /// Overrides the model name reported in responses (builder style).
    pub fn with_model_name(mut self, model_name: impl Into<String>) -> Self {
        self.model_name = model_name.into();
        self
    }

    /// Returns the next canned response and advances the shared cursor.
    fn next_response(&self) -> String {
        // Recover from a poisoned mutex: a panic on another test thread
        // should not cascade into this mock (the index is always valid).
        let mut index = self
            .current_index
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // Guard the empty case — `*index % 0` would panic.
        if self.responses.is_empty() {
            return String::new();
        }
        let response = self.responses[*index % self.responses.len()].clone();
        *index += 1;
        response
    }
}

impl Default for MockProvider {
    /// A provider with a single generic canned response.
    fn default() -> Self {
        Self::new("This is a mock response from the test provider.")
    }
}
65#[async_trait]
66impl ModelProvider for MockProvider {
67 async fn generate(&self, _prompt: &str, _config: &GenerationConfig) -> Result<ModelResponse> {
68 let content = self.next_response();
69 let prompt_tokens = 10; let completion_tokens = content.split_whitespace().count() as u32;
71
72 Ok(ModelResponse {
73 content,
74 model: self.model_name.clone(),
75 usage: Some(TokenUsage {
76 prompt_tokens,
77 completion_tokens,
78 total_tokens: prompt_tokens + completion_tokens,
79 }),
80 finish_reason: Some("stop".to_string()),
81 tool_calls: None,
82 reasoning: None,
83 })
84 }
85
86 async fn stream(
87 &self,
88 _prompt: &str,
89 _config: &GenerationConfig,
90 ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
91 let content = self.next_response();
92 let words: Vec<String> = content.split_whitespace().map(|s| s.to_string()).collect();
93
94 let stream = stream! {
95 for word in words {
96 yield Ok(format!("{} ", word));
97 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
99 }
100 };
101
102 Ok(Box::pin(stream))
103 }
104
105 fn metadata(&self) -> ProviderMetadata {
106 ProviderMetadata {
107 name: "Mock Provider".to_string(),
108 supported_models: vec![
109 "mock-model".to_string(),
110 "mock-gpt-4".to_string(),
111 "mock-claude-3".to_string(),
112 ],
113 supports_streaming: true,
114 }
115 }
116
117 fn kind(&self) -> ProviderKind {
118 ProviderKind::Mock
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// generate() echoes the canned response with synthetic usage data.
    #[tokio::test]
    async fn test_mock_provider_generate() {
        let provider = MockProvider::new("Hello, world!");
        let config = GenerationConfig::default();

        let response = provider.generate("test prompt", &config).await.unwrap();

        assert_eq!(response.content, "Hello, world!");
        assert_eq!(response.model, "mock-model");
        assert!(response.usage.is_some());
        assert_eq!(response.finish_reason, Some("stop".to_string()));
    }

    /// Responses are served in order and wrap around after the last one.
    #[tokio::test]
    async fn test_mock_provider_multiple_responses() {
        let provider = MockProvider::with_responses(vec![
            "First response".to_string(),
            "Second response".to_string(),
            "Third response".to_string(),
        ]);
        let config = GenerationConfig::default();

        // One full cycle plus one extra call to confirm wrap-around.
        let expected = [
            "First response",
            "Second response",
            "Third response",
            "First response",
        ];
        for want in expected {
            let got = provider.generate("prompt", &config).await.unwrap();
            assert_eq!(got.content, want);
        }
    }

    /// stream() yields one chunk per word of the canned response.
    #[tokio::test]
    async fn test_mock_provider_stream() {
        let provider = MockProvider::new("Hello world test");
        let config = GenerationConfig::default();

        let stream = provider.stream("test prompt", &config).await.unwrap();
        let chunks: Vec<String> = stream.map(|chunk| chunk.unwrap()).collect().await;

        assert_eq!(chunks.len(), 3);
        assert!(chunks[0].contains("Hello"));
        assert!(chunks[1].contains("world"));
        assert!(chunks[2].contains("test"));
    }

    /// metadata() reports the static mock-provider description.
    #[tokio::test]
    async fn test_mock_provider_metadata() {
        let metadata = MockProvider::default().metadata();

        assert_eq!(metadata.name, "Mock Provider");
        assert!(metadata.supports_streaming);
        assert!(!metadata.supported_models.is_empty());
    }

    /// with_model_name() is reflected in generated responses.
    #[tokio::test]
    async fn test_mock_provider_custom_model_name() {
        let provider = MockProvider::new("test").with_model_name("custom-model");
        let config = GenerationConfig::default();

        let response = provider.generate("prompt", &config).await.unwrap();
        assert_eq!(response.model, "custom-model");
    }
}