Skip to main content

spec_ai/spec_ai_core/agent/
model.rs

//! Model Provider Abstraction Layer
//!
//! This module defines the core traits and types for integrating with various LLM providers.
//! It provides a unified interface that abstracts away provider-specific details.
5
6use anyhow::Result;
7use async_trait::async_trait;
8use futures::Stream;
9use serde::{Deserialize, Serialize};
10use std::pin::Pin;
11
/// Configuration for model generation requests
///
/// Every setting is an `Option`, so callers can leave any of them unset (`None`).
/// NOTE(review): how a provider treats an unset value is not visible in this module;
/// presumably it falls back to the provider's own default (confirm per provider).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
    /// Sampling temperature (0.0 - 2.0); the range is not validated by this type
    pub temperature: Option<f32>,
    /// Maximum tokens to generate
    pub max_tokens: Option<u32>,
    /// Stop sequences
    pub stop_sequences: Option<Vec<String>>,
    /// Top-p sampling
    pub top_p: Option<f32>,
    /// Frequency penalty
    pub frequency_penalty: Option<f32>,
    /// Presence penalty
    pub presence_penalty: Option<f32>,
}
28
29impl Default for GenerationConfig {
30    fn default() -> Self {
31        Self {
32            temperature: Some(0.7),
33            max_tokens: Some(2048),
34            stop_sequences: None,
35            top_p: Some(1.0),
36            frequency_penalty: None,
37            presence_penalty: None,
38        }
39    }
40}
41
/// Tool call from a model response
///
/// Carries the identifier, target tool name and JSON arguments of a single call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique identifier for this tool call
    pub id: String,
    /// Name of the function/tool to call
    pub function_name: String,
    /// Arguments as JSON; this type imposes no particular shape on the value
    pub arguments: serde_json::Value,
}
52
/// Parse thinking/reasoning tokens from a model response
///
/// Extracts the text between the first `<think>` tag and the first `</think>` tag that
/// follows it as reasoning, and returns the text after the first `</think>` as the main
/// response. Both parts are trimmed of surrounding whitespace.
///
/// The tags are located with plain substring searches, so no pattern has to be compiled
/// on each call and the function cannot panic.
///
/// # Arguments
/// * `response` - Raw model response that may contain thinking tokens
///
/// # Returns
/// A tuple of (reasoning, content) where:
/// - `reasoning` is `Some(String)` if a complete `<think>...</think>` pair was found
///   (an empty pair yields `Some("")`), `None` otherwise
/// - `content` is the trimmed text after the first `</think>`, or the full, untrimmed
///   response if no `</think>` is present
///
/// # Edge cases
/// - Only the first pair is extracted; any later `<think>` blocks stay in `content`.
/// - Text that precedes the first `</think>` (including text before `<think>`) is not
///   part of `content`.
/// - A `</think>` with no preceding `<think>` gives `reasoning == None`, but `content` is
///   still the text after that closing tag.
/// - An unclosed `<think>` gives `reasoning == None` and the response back unchanged.
///
/// # Example
/// ```
/// use spec_ai_core::agent::model::parse_thinking_tokens;
///
/// let response = "<think>Let me consider this...</think>Here's my answer.";
/// let (reasoning, content) = parse_thinking_tokens(response);
/// assert_eq!(reasoning, Some("Let me consider this...".to_string()));
/// assert_eq!(content, "Here's my answer.");
/// ```
pub fn parse_thinking_tokens(response: &str) -> (Option<String>, String) {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    // Reasoning: text between the first opening tag and the nearest closing tag after it.
    // Both tags are ASCII, so the byte offsets used for slicing are always on char
    // boundaries.
    let reasoning = response.find(OPEN_TAG).and_then(|open_idx| {
        let after_open = &response[open_idx + OPEN_TAG.len()..];
        after_open
            .find(CLOSE_TAG)
            .map(|close_idx| after_open[..close_idx].trim().to_string())
    });

    // Content: everything after the first closing tag anywhere in the response, or the
    // untouched response when there is no closing tag at all.
    let content = match response.find(CLOSE_TAG) {
        Some(close_idx) => response[close_idx + CLOSE_TAG.len()..].trim().to_string(),
        None => response.to_string(),
    };

    (reasoning, content)
}
98
/// Response from a model generation request
///
/// `tool_calls` and `reasoning` are omitted from the serialized form when they are `None`
/// (see the `skip_serializing_if` attributes); the other fields are always written.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResponse {
    /// Generated content (with thinking tokens removed if present)
    pub content: String,
    /// Model used for generation
    pub model: String,
    /// Token usage statistics
    pub usage: Option<TokenUsage>,
    /// Finish reason
    pub finish_reason: Option<String>,
    /// Tool calls from the model (if any)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Reasoning/thinking content extracted from <think> tags (if present)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}
117
/// Token usage statistics
///
/// The derived `Default` yields all-zero counters.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Token count for the prompt (presumably the input side as reported by the provider; confirm)
    pub prompt_tokens: u32,
    /// Token count for the completion (presumably the generated output; confirm)
    pub completion_tokens: u32,
    /// Total token count; stored as its own counter, this type does not derive it from the other two
    pub total_tokens: u32,
}
125
126impl TokenUsage {
127    pub fn add(&mut self, other: &TokenUsage) {
128        self.prompt_tokens += other.prompt_tokens;
129        self.completion_tokens += other.completion_tokens;
130        self.total_tokens += other.total_tokens;
131    }
132}
133
/// Binary image attachment for multimodal prompts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageAttachment {
    /// MIME type of the image (presumably a value such as `image/png`; not validated here)
    pub mime: String,
    /// Image bytes (NOTE(review): presumably the raw encoded file contents; confirm against callers)
    pub data: Vec<u8>,
}
140
/// Provider metadata
///
/// Descriptive information about a provider: its name, the models it supports and
/// whether it supports streaming.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderMetadata {
    /// Provider name
    pub name: String,
    /// Supported models
    pub supported_models: Vec<String>,
    /// Supports streaming
    pub supports_streaming: bool,
}
151
/// Types of model providers
///
/// `Mock` is always available; every other variant only exists when the cargo feature
/// named in its `cfg` attribute is enabled. With serde, each variant is (de)serialized
/// under its lowercased name (`rename_all = "lowercase"`), e.g. `OpenAI` as `"openai"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderKind {
    /// Mock provider (always compiled in)
    Mock,
    /// OpenAI provider (requires the `openai` feature)
    #[cfg(feature = "openai")]
    OpenAI,
    /// Anthropic provider (requires the `anthropic` feature)
    #[cfg(feature = "anthropic")]
    Anthropic,
    /// Ollama provider (requires the `ollama` feature)
    #[cfg(feature = "ollama")]
    Ollama,
    /// MLX provider (requires the `mlx` feature)
    #[cfg(feature = "mlx")]
    MLX,
    /// LM Studio provider (requires the `lmstudio` feature)
    #[cfg(feature = "lmstudio")]
    LMStudio,
}
168
169impl ProviderKind {
170    #[allow(clippy::should_implement_trait)]
171    pub fn from_str(s: &str) -> Option<Self> {
172        match s.to_lowercase().as_str() {
173            "mock" => Some(ProviderKind::Mock),
174            #[cfg(feature = "openai")]
175            "openai" => Some(ProviderKind::OpenAI),
176            #[cfg(feature = "anthropic")]
177            "anthropic" => Some(ProviderKind::Anthropic),
178            #[cfg(feature = "ollama")]
179            "ollama" => Some(ProviderKind::Ollama),
180            #[cfg(feature = "mlx")]
181            "mlx" => Some(ProviderKind::MLX),
182            #[cfg(feature = "lmstudio")]
183            "lmstudio" => Some(ProviderKind::LMStudio),
184            _ => None,
185        }
186    }
187
188    pub fn as_str(&self) -> &'static str {
189        match self {
190            ProviderKind::Mock => "mock",
191            #[cfg(feature = "openai")]
192            ProviderKind::OpenAI => "openai",
193            #[cfg(feature = "anthropic")]
194            ProviderKind::Anthropic => "anthropic",
195            #[cfg(feature = "ollama")]
196            ProviderKind::Ollama => "ollama",
197            #[cfg(feature = "mlx")]
198            ProviderKind::MLX => "mlx",
199            #[cfg(feature = "lmstudio")]
200            ProviderKind::LMStudio => "lmstudio",
201        }
202    }
203}
204
/// One item produced by a provider's response stream.
///
/// Serialized with serde's adjacently tagged representation: the variant's renamed tag
/// goes in a `"type"` field and its payload in a `"data"` field, for example
/// `{"type":"content","data":"Hello"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum ModelStreamItem {
    /// Text content; tagged `"content"`
    #[serde(rename = "content")]
    Content(String),
    /// Token usage statistics; tagged `"usage"`
    #[serde(rename = "usage")]
    Usage(TokenUsage),
    /// Finish reason; tagged `"finish_reason"`
    #[serde(rename = "finish_reason")]
    FinishReason(String),
}
215
/// Core trait that all model providers must implement
///
/// Implementors must be `Send + Sync`. `generate`, `stream`, `metadata` and `kind` are
/// required. The two `*_with_attachments` methods come with default implementations that
/// discard the attachments and delegate to the text-only method, so only providers that
/// actually handle images need to override them.
#[async_trait]
pub trait ModelProvider: Send + Sync {
    /// Generate a response to the given prompt
    async fn generate(&self, prompt: &str, config: &GenerationConfig) -> Result<ModelResponse>;

    /// Generate a response with optional image attachments (default: ignore attachments).
    async fn generate_with_attachments(
        &self,
        prompt: &str,
        attachments: &[ImageAttachment],
        config: &GenerationConfig,
    ) -> Result<ModelResponse> {
        // Default behaviour: the attachments are explicitly discarded and the call is
        // forwarded to the text-only `generate`.
        let _ = attachments;
        self.generate(prompt, config).await
    }

    /// Stream a response to the given prompt
    ///
    /// On success returns a pinned, boxed, `Send` stream whose items are
    /// `Result<ModelStreamItem>`, so failures can be reported per item as well as for the
    /// call as a whole.
    async fn stream(
        &self,
        prompt: &str,
        config: &GenerationConfig,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ModelStreamItem>> + Send>>>;

    /// Stream a response with optional image attachments (default: ignore attachments).
    async fn stream_with_attachments(
        &self,
        prompt: &str,
        attachments: &[ImageAttachment],
        config: &GenerationConfig,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ModelStreamItem>> + Send>>> {
        // Default behaviour: the attachments are explicitly discarded and the call is
        // forwarded to the text-only `stream`.
        let _ = attachments;
        self.stream(prompt, config).await
    }

    /// Get provider metadata
    fn metadata(&self) -> ProviderMetadata;

    /// Get the provider kind
    fn kind(&self) -> ProviderKind;
}
257
#[cfg(test)]
mod tests {
    use super::*;

    // --- ProviderKind -----------------------------------------------------------------

    #[test]
    fn test_provider_kind_from_str() {
        // Every casing of the name resolves to the same variant.
        for &spelling in &["mock", "Mock", "MOCK"] {
            assert_eq!(ProviderKind::from_str(spelling), Some(ProviderKind::Mock));
        }

        // Unknown names are rejected.
        assert_eq!(ProviderKind::from_str("invalid"), None);
    }

    #[test]
    fn test_provider_kind_as_str() {
        let name = ProviderKind::Mock.as_str();
        assert_eq!(name, "mock");
    }

    // --- GenerationConfig -------------------------------------------------------------

    #[test]
    fn test_generation_config_default() {
        let GenerationConfig {
            temperature,
            max_tokens,
            top_p,
            ..
        } = GenerationConfig::default();

        assert_eq!(temperature, Some(0.7));
        assert_eq!(max_tokens, Some(2048));
        assert_eq!(top_p, Some(1.0));
    }

    #[test]
    fn test_generation_config_serialization() {
        let original = GenerationConfig {
            temperature: Some(0.9),
            max_tokens: Some(1024),
            stop_sequences: Some(vec![String::from("STOP")]),
            top_p: Some(0.95),
            frequency_penalty: None,
            presence_penalty: None,
        };

        // Round-trip through JSON and compare the fields of interest.
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: GenerationConfig = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original.temperature, decoded.temperature);
        assert_eq!(original.max_tokens, decoded.max_tokens);
    }

    // --- parse_thinking_tokens --------------------------------------------------------

    #[test]
    fn test_parse_thinking_tokens_with_tags() {
        let (reasoning, content) = parse_thinking_tokens(
            "<think>Let me consider this carefully...</think>Here's my final answer.",
        );

        assert_eq!(
            reasoning.as_deref(),
            Some("Let me consider this carefully...")
        );
        assert_eq!(content, "Here's my final answer.");
    }

    #[test]
    fn test_parse_thinking_tokens_without_tags() {
        let plain = "This is a normal response without thinking tags.";
        let (reasoning, content) = parse_thinking_tokens(plain);

        assert!(reasoning.is_none());
        assert_eq!(content, plain);
    }

    #[test]
    fn test_parse_thinking_tokens_multiline() {
        let raw = "<think>\nFirst, I need to analyze the problem.\nThen I'll formulate a solution.\n</think>\n\nHere's the answer: 42";
        let (reasoning, content) = parse_thinking_tokens(raw);

        let reasoning_text = reasoning.expect("reasoning should be extracted");
        for fragment in &["analyze the problem", "formulate a solution"] {
            assert!(reasoning_text.contains(fragment));
        }
        assert_eq!(content, "Here's the answer: 42");
    }

    #[test]
    fn test_parse_thinking_tokens_empty_think() {
        let (reasoning, content) =
            parse_thinking_tokens("<think></think>Content after empty think.");

        // An empty pair still counts as "tags found".
        assert_eq!(reasoning.as_deref(), Some(""));
        assert_eq!(content, "Content after empty think.");
    }

    #[test]
    fn test_parse_thinking_tokens_whitespace_handling() {
        let (reasoning, content) =
            parse_thinking_tokens("<think>  \n  Some reasoning  \n  </think>  \n  Final answer");

        assert_eq!(reasoning.as_deref(), Some("Some reasoning"));
        assert_eq!(content, "Final answer");
    }

    #[test]
    fn test_parse_thinking_tokens_incomplete_tag() {
        let raw = "<think>Incomplete thinking...";
        let (reasoning, content) = parse_thinking_tokens(raw);

        // Without a closing tag nothing is extracted and the input comes back as-is.
        assert!(reasoning.is_none());
        assert_eq!(content, raw);
    }
}