spec_ai/spec_ai_core/agent/
model.rs1use anyhow::Result;
7use async_trait::async_trait;
8use futures::Stream;
9use serde::{Deserialize, Serialize};
10use std::pin::Pin;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct GenerationConfig {
15 pub temperature: Option<f32>,
17 pub max_tokens: Option<u32>,
19 pub stop_sequences: Option<Vec<String>>,
21 pub top_p: Option<f32>,
23 pub frequency_penalty: Option<f32>,
25 pub presence_penalty: Option<f32>,
27}
28
29impl Default for GenerationConfig {
30 fn default() -> Self {
31 Self {
32 temperature: Some(0.7),
33 max_tokens: Some(2048),
34 stop_sequences: None,
35 top_p: Some(1.0),
36 frequency_penalty: None,
37 presence_penalty: None,
38 }
39 }
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ToolCall {
45 pub id: String,
47 pub function_name: String,
49 pub arguments: serde_json::Value,
51}
52
/// Splits a raw model response into optional reasoning and the final answer.
///
/// The split point is the first `</think>` in `response`:
///
/// - No `</think>` at all (including an unterminated `<think>`): returns
///   `(None, response)` with `response` unchanged.
/// - A `<think>` before that `</think>`: reasoning is the text between the first
///   `<think>` and the closing tag. Anything before `<think>` is discarded.
/// - No `<think>` before the closing tag (some chat templates emit the opening
///   tag as part of the prompt): the text before `</think>` is the reasoning, or
///   `None` if that text is blank.
///
/// The content is everything after the first `</think>`. Both the reasoning and
/// the content are trimmed of surrounding whitespace.
pub fn parse_thinking_tokens(response: &str) -> (Option<String>, String) {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    // Without a closing tag there is nothing to split off; unterminated
    // reasoning is left in the content untouched.
    let close_idx = match response.find(CLOSE_TAG) {
        Some(idx) => idx,
        None => return (None, response.to_string()),
    };

    // Both tags are ASCII, so these byte offsets are always char boundaries.
    let before_close = &response[..close_idx];
    let content = response[close_idx + CLOSE_TAG.len()..].trim().to_string();

    let reasoning = match before_close.find(OPEN_TAG) {
        // Normal case: everything after the first `<think>` up to the first
        // `</think>`, including any nested `<think>` text.
        Some(open_idx) => Some(before_close[open_idx + OPEN_TAG.len()..].trim().to_string()),
        // Opening tag missing: keep the text as reasoning instead of silently
        // dropping it from both outputs.
        None => {
            let implicit = before_close.trim();
            (!implicit.is_empty()).then(|| implicit.to_string())
        }
    };

    (reasoning, content)
}
99#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct ModelResponse {
102 pub content: String,
104 pub model: String,
106 pub usage: Option<TokenUsage>,
108 pub finish_reason: Option<String>,
110 #[serde(skip_serializing_if = "Option::is_none")]
112 pub tool_calls: Option<Vec<ToolCall>>,
113 #[serde(skip_serializing_if = "Option::is_none")]
115 pub reasoning: Option<String>,
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct TokenUsage {
121 pub prompt_tokens: u32,
122 pub completion_tokens: u32,
123 pub total_tokens: u32,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct ImageAttachment {
129 pub mime: String,
130 pub data: Vec<u8>,
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct ProviderMetadata {
136 pub name: String,
138 pub supported_models: Vec<String>,
140 pub supports_streaming: bool,
142}
143
144#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
146#[serde(rename_all = "lowercase")]
147pub enum ProviderKind {
148 Mock,
149 #[cfg(feature = "openai")]
150 OpenAI,
151 #[cfg(feature = "anthropic")]
152 Anthropic,
153 #[cfg(feature = "ollama")]
154 Ollama,
155 #[cfg(feature = "mlx")]
156 MLX,
157 #[cfg(feature = "lmstudio")]
158 LMStudio,
159}
160
161impl ProviderKind {
162 #[allow(clippy::should_implement_trait)]
163 pub fn from_str(s: &str) -> Option<Self> {
164 match s.to_lowercase().as_str() {
165 "mock" => Some(ProviderKind::Mock),
166 #[cfg(feature = "openai")]
167 "openai" => Some(ProviderKind::OpenAI),
168 #[cfg(feature = "anthropic")]
169 "anthropic" => Some(ProviderKind::Anthropic),
170 #[cfg(feature = "ollama")]
171 "ollama" => Some(ProviderKind::Ollama),
172 #[cfg(feature = "mlx")]
173 "mlx" => Some(ProviderKind::MLX),
174 #[cfg(feature = "lmstudio")]
175 "lmstudio" => Some(ProviderKind::LMStudio),
176 _ => None,
177 }
178 }
179
180 pub fn as_str(&self) -> &'static str {
181 match self {
182 ProviderKind::Mock => "mock",
183 #[cfg(feature = "openai")]
184 ProviderKind::OpenAI => "openai",
185 #[cfg(feature = "anthropic")]
186 ProviderKind::Anthropic => "anthropic",
187 #[cfg(feature = "ollama")]
188 ProviderKind::Ollama => "ollama",
189 #[cfg(feature = "mlx")]
190 ProviderKind::MLX => "mlx",
191 #[cfg(feature = "lmstudio")]
192 ProviderKind::LMStudio => "lmstudio",
193 }
194 }
195}
196
197#[async_trait]
199pub trait ModelProvider: Send + Sync {
200 async fn generate(&self, prompt: &str, config: &GenerationConfig) -> Result<ModelResponse>;
202
203 async fn generate_with_attachments(
205 &self,
206 prompt: &str,
207 attachments: &[ImageAttachment],
208 config: &GenerationConfig,
209 ) -> Result<ModelResponse> {
210 let _ = attachments;
211 self.generate(prompt, config).await
212 }
213
214 async fn stream(
216 &self,
217 prompt: &str,
218 config: &GenerationConfig,
219 ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>>;
220
221 async fn stream_with_attachments(
223 &self,
224 prompt: &str,
225 attachments: &[ImageAttachment],
226 config: &GenerationConfig,
227 ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
228 let _ = attachments;
229 self.stream(prompt, config).await
230 }
231
232 fn metadata(&self) -> ProviderMetadata;
234
235 fn kind(&self) -> ProviderKind;
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_kind_from_str() {
        assert_eq!(ProviderKind::from_str("mock"), Some(ProviderKind::Mock));
        assert_eq!(ProviderKind::from_str("Mock"), Some(ProviderKind::Mock));
        assert_eq!(ProviderKind::from_str("MOCK"), Some(ProviderKind::Mock));
        assert_eq!(ProviderKind::from_str("invalid"), None);
    }

    #[test]
    fn test_provider_kind_as_str() {
        assert_eq!(ProviderKind::Mock.as_str(), "mock");
    }

    /// `as_str`, `from_str` and the serde representation must all agree.
    #[test]
    fn test_provider_kind_round_trip() {
        let kind = ProviderKind::Mock;
        assert_eq!(ProviderKind::from_str(kind.as_str()), Some(kind));

        let json = serde_json::to_string(&kind).unwrap();
        assert_eq!(json, format!("\"{}\"", kind.as_str()));
        let back: ProviderKind = serde_json::from_str(&json).unwrap();
        assert_eq!(back, kind);
    }

    #[test]
    fn test_generation_config_default() {
        let config = GenerationConfig::default();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_tokens, Some(2048));
        assert_eq!(config.top_p, Some(1.0));
    }

    #[test]
    fn test_generation_config_serialization() {
        let config = GenerationConfig {
            temperature: Some(0.9),
            max_tokens: Some(1024),
            stop_sequences: Some(vec!["STOP".to_string()]),
            top_p: Some(0.95),
            frequency_penalty: None,
            presence_penalty: None,
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: GenerationConfig = serde_json::from_str(&json).unwrap();

        // Check every field, not just a subset, so a dropped or renamed field is caught.
        assert_eq!(config.temperature, deserialized.temperature);
        assert_eq!(config.max_tokens, deserialized.max_tokens);
        assert_eq!(config.stop_sequences, deserialized.stop_sequences);
        assert_eq!(config.top_p, deserialized.top_p);
        assert_eq!(config.frequency_penalty, deserialized.frequency_penalty);
        assert_eq!(config.presence_penalty, deserialized.presence_penalty);
    }

    /// `skip_serializing_if` must drop absent `tool_calls`/`reasoning`, while the
    /// other optional fields are still emitted as `null`.
    #[test]
    fn test_model_response_omits_absent_optional_fields() {
        let response = ModelResponse {
            content: "hi".to_string(),
            model: "mock".to_string(),
            usage: None,
            finish_reason: None,
            tool_calls: None,
            reasoning: None,
        };

        let value = serde_json::to_value(&response).unwrap();
        assert!(value.get("tool_calls").is_none());
        assert!(value.get("reasoning").is_none());
        assert_eq!(value.get("usage"), Some(&serde_json::Value::Null));
        assert_eq!(value.get("finish_reason"), Some(&serde_json::Value::Null));
    }

    #[test]
    fn test_parse_thinking_tokens_with_tags() {
        let response = "<think>Let me consider this carefully...</think>Here's my final answer.";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(
            reasoning,
            Some("Let me consider this carefully...".to_string())
        );
        assert_eq!(content, "Here's my final answer.");
    }

    #[test]
    fn test_parse_thinking_tokens_without_tags() {
        let response = "This is a normal response without thinking tags.";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(reasoning, None);
        assert_eq!(content, "This is a normal response without thinking tags.");
    }

    #[test]
    fn test_parse_thinking_tokens_multiline() {
        let response = "<think>\nFirst, I need to analyze the problem.\nThen I'll formulate a solution.\n</think>\n\nHere's the answer: 42";
        let (reasoning, content) = parse_thinking_tokens(response);

        let reasoning_text = reasoning.expect("multiline think block should be captured");
        assert!(reasoning_text.contains("analyze the problem"));
        assert!(reasoning_text.contains("formulate a solution"));
        assert_eq!(content, "Here's the answer: 42");
    }

    #[test]
    fn test_parse_thinking_tokens_empty_think() {
        let response = "<think></think>Content after empty think.";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(reasoning, Some("".to_string()));
        assert_eq!(content, "Content after empty think.");
    }

    #[test]
    fn test_parse_thinking_tokens_whitespace_handling() {
        let response = "<think> \n Some reasoning \n </think> \n Final answer";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(reasoning, Some("Some reasoning".to_string()));
        assert_eq!(content, "Final answer");
    }

    #[test]
    fn test_parse_thinking_tokens_incomplete_tag() {
        let response = "<think>Incomplete thinking...";
        let (reasoning, content) = parse_thinking_tokens(response);

        // Without a closing tag the response is returned untouched.
        assert_eq!(reasoning, None);
        assert_eq!(content, "<think>Incomplete thinking...");
    }
}