spec_ai/spec_ai_core/agent/
model.rs1use anyhow::Result;
7use async_trait::async_trait;
8use futures::Stream;
9use serde::{Deserialize, Serialize};
10use std::pin::Pin;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct GenerationConfig {
15 pub temperature: Option<f32>,
17 pub max_tokens: Option<u32>,
19 pub stop_sequences: Option<Vec<String>>,
21 pub top_p: Option<f32>,
23 pub frequency_penalty: Option<f32>,
25 pub presence_penalty: Option<f32>,
27}
28
29impl Default for GenerationConfig {
30 fn default() -> Self {
31 Self {
32 temperature: Some(0.7),
33 max_tokens: Some(2048),
34 stop_sequences: None,
35 top_p: Some(1.0),
36 frequency_penalty: None,
37 presence_penalty: None,
38 }
39 }
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ToolCall {
45 pub id: String,
47 pub function_name: String,
49 pub arguments: serde_json::Value,
51}
52
/// Splits a raw model response into its `<think>…</think>` reasoning and the
/// remaining visible content.
///
/// Returns `(reasoning, content)`:
///
/// * `reasoning` is the trimmed text between the first `<think>` and the first
///   `</think>` that follows it, or `None` when no such complete pair exists.
///   An empty pair yields `Some("")`.
/// * `content` is the trimmed text after the first `</think>` found anywhere in
///   the response. When there is no closing tag at all, the response is
///   returned unchanged (not trimmed), so an unterminated `<think>` stays in
///   the content.
///
/// The tags are fixed literals, so plain substring search is used rather than
/// compiling a regular expression on every call.
pub fn parse_thinking_tokens(response: &str) -> (Option<String>, String) {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    // Reasoning: first opening tag, then the nearest closing tag after it.
    let reasoning = response.find(OPEN_TAG).and_then(|open_idx| {
        let after_open = &response[open_idx + OPEN_TAG.len()..];
        after_open
            .find(CLOSE_TAG)
            .map(|close_idx| after_open[..close_idx].trim().to_string())
    });

    // Content: whatever follows the first closing tag, regardless of whether a
    // matching opening tag was seen.
    let content = match response.split_once(CLOSE_TAG) {
        Some((_, after_close)) => after_close.trim().to_string(),
        None => response.to_string(),
    };

    (reasoning, content)
}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct ModelResponse {
102 pub content: String,
104 pub model: String,
106 pub usage: Option<TokenUsage>,
108 pub finish_reason: Option<String>,
110 #[serde(skip_serializing_if = "Option::is_none")]
112 pub tool_calls: Option<Vec<ToolCall>>,
113 #[serde(skip_serializing_if = "Option::is_none")]
115 pub reasoning: Option<String>,
116}
117
118#[derive(Debug, Clone, Default, Serialize, Deserialize)]
120pub struct TokenUsage {
121 pub prompt_tokens: u32,
122 pub completion_tokens: u32,
123 pub total_tokens: u32,
124}
125
126impl TokenUsage {
127 pub fn add(&mut self, other: &TokenUsage) {
128 self.prompt_tokens += other.prompt_tokens;
129 self.completion_tokens += other.completion_tokens;
130 self.total_tokens += other.total_tokens;
131 }
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct ImageAttachment {
137 pub mime: String,
138 pub data: Vec<u8>,
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct ProviderMetadata {
144 pub name: String,
146 pub supported_models: Vec<String>,
148 pub supports_streaming: bool,
150}
151
152#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
154#[serde(rename_all = "lowercase")]
155pub enum ProviderKind {
156 Mock,
157 #[cfg(feature = "openai")]
158 OpenAI,
159 #[cfg(feature = "anthropic")]
160 Anthropic,
161 #[cfg(feature = "ollama")]
162 Ollama,
163 #[cfg(feature = "mlx")]
164 MLX,
165 #[cfg(feature = "lmstudio")]
166 LMStudio,
167}
168
169impl ProviderKind {
170 #[allow(clippy::should_implement_trait)]
171 pub fn from_str(s: &str) -> Option<Self> {
172 match s.to_lowercase().as_str() {
173 "mock" => Some(ProviderKind::Mock),
174 #[cfg(feature = "openai")]
175 "openai" => Some(ProviderKind::OpenAI),
176 #[cfg(feature = "anthropic")]
177 "anthropic" => Some(ProviderKind::Anthropic),
178 #[cfg(feature = "ollama")]
179 "ollama" => Some(ProviderKind::Ollama),
180 #[cfg(feature = "mlx")]
181 "mlx" => Some(ProviderKind::MLX),
182 #[cfg(feature = "lmstudio")]
183 "lmstudio" => Some(ProviderKind::LMStudio),
184 _ => None,
185 }
186 }
187
188 pub fn as_str(&self) -> &'static str {
189 match self {
190 ProviderKind::Mock => "mock",
191 #[cfg(feature = "openai")]
192 ProviderKind::OpenAI => "openai",
193 #[cfg(feature = "anthropic")]
194 ProviderKind::Anthropic => "anthropic",
195 #[cfg(feature = "ollama")]
196 ProviderKind::Ollama => "ollama",
197 #[cfg(feature = "mlx")]
198 ProviderKind::MLX => "mlx",
199 #[cfg(feature = "lmstudio")]
200 ProviderKind::LMStudio => "lmstudio",
201 }
202 }
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize)]
206#[serde(tag = "type", content = "data")]
207pub enum ModelStreamItem {
208 #[serde(rename = "content")]
209 Content(String),
210 #[serde(rename = "usage")]
211 Usage(TokenUsage),
212 #[serde(rename = "finish_reason")]
213 FinishReason(String),
214}
215
216#[async_trait]
218pub trait ModelProvider: Send + Sync {
219 async fn generate(&self, prompt: &str, config: &GenerationConfig) -> Result<ModelResponse>;
221
222 async fn generate_with_attachments(
224 &self,
225 prompt: &str,
226 attachments: &[ImageAttachment],
227 config: &GenerationConfig,
228 ) -> Result<ModelResponse> {
229 let _ = attachments;
230 self.generate(prompt, config).await
231 }
232
233 async fn stream(
235 &self,
236 prompt: &str,
237 config: &GenerationConfig,
238 ) -> Result<Pin<Box<dyn Stream<Item = Result<ModelStreamItem>> + Send>>>;
239
240 async fn stream_with_attachments(
242 &self,
243 prompt: &str,
244 attachments: &[ImageAttachment],
245 config: &GenerationConfig,
246 ) -> Result<Pin<Box<dyn Stream<Item = Result<ModelStreamItem>> + Send>>> {
247 let _ = attachments;
248 self.stream(prompt, config).await
249 }
250
251 fn metadata(&self) -> ProviderMetadata;
253
254 fn kind(&self) -> ProviderKind;
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    // ---- ProviderKind -------------------------------------------------------

    #[test]
    fn test_provider_kind_from_str() {
        // Every casing of the same name resolves to the same kind.
        for spelling in ["mock", "Mock", "MOCK"] {
            assert_eq!(ProviderKind::from_str(spelling), Some(ProviderKind::Mock));
        }
        assert_eq!(ProviderKind::from_str("invalid"), None);
    }

    #[test]
    fn test_provider_kind_as_str() {
        let name = ProviderKind::Mock.as_str();
        assert_eq!(name, "mock");
    }

    // ---- GenerationConfig ---------------------------------------------------

    #[test]
    fn test_generation_config_default() {
        let GenerationConfig {
            temperature,
            max_tokens,
            top_p,
            ..
        } = GenerationConfig::default();

        assert_eq!(temperature, Some(0.7));
        assert_eq!(max_tokens, Some(2048));
        assert_eq!(top_p, Some(1.0));
    }

    #[test]
    fn test_generation_config_serialization() {
        let original = GenerationConfig {
            temperature: Some(0.9),
            max_tokens: Some(1024),
            stop_sequences: Some(vec!["STOP".to_string()]),
            top_p: Some(0.95),
            frequency_penalty: None,
            presence_penalty: None,
        };

        // Round-trip through JSON and compare the fields of interest.
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: GenerationConfig = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original.temperature, decoded.temperature);
        assert_eq!(original.max_tokens, decoded.max_tokens);
    }

    // ---- parse_thinking_tokens ----------------------------------------------

    #[test]
    fn test_parse_thinking_tokens_with_tags() {
        let (reasoning, content) = parse_thinking_tokens(
            "<think>Let me consider this carefully...</think>Here's my final answer.",
        );

        assert_eq!(
            reasoning,
            Some("Let me consider this carefully...".to_string())
        );
        assert_eq!(content, "Here's my final answer.");
    }

    #[test]
    fn test_parse_thinking_tokens_without_tags() {
        let input = "This is a normal response without thinking tags.";
        let (reasoning, content) = parse_thinking_tokens(input);

        assert_eq!(reasoning, None);
        assert_eq!(content, "This is a normal response without thinking tags.");
    }

    #[test]
    fn test_parse_thinking_tokens_multiline() {
        let input = "<think>\nFirst, I need to analyze the problem.\nThen I'll formulate a solution.\n</think>\n\nHere's the answer: 42";
        let (reasoning, content) = parse_thinking_tokens(input);

        let reasoning_text = reasoning.expect("reasoning should be extracted");
        assert!(reasoning_text.contains("analyze the problem"));
        assert!(reasoning_text.contains("formulate a solution"));
        assert_eq!(content, "Here's the answer: 42");
    }

    #[test]
    fn test_parse_thinking_tokens_empty_think() {
        let (reasoning, content) =
            parse_thinking_tokens("<think></think>Content after empty think.");

        assert_eq!(reasoning, Some(String::new()));
        assert_eq!(content, "Content after empty think.");
    }

    #[test]
    fn test_parse_thinking_tokens_whitespace_handling() {
        let (reasoning, content) =
            parse_thinking_tokens("<think> \n Some reasoning \n </think> \n Final answer");

        assert_eq!(reasoning, Some("Some reasoning".to_string()));
        assert_eq!(content, "Final answer");
    }

    #[test]
    fn test_parse_thinking_tokens_incomplete_tag() {
        let input = "<think>Incomplete thinking...";
        let (reasoning, content) = parse_thinking_tokens(input);

        assert_eq!(reasoning, None);
        // Without a closing tag the response is passed through untouched.
        assert_eq!(content, "<think>Incomplete thinking...");
    }
}