Skip to main content

spec_ai/spec_ai_core/agent/
model.rs

1//! Model Provider Abstraction Layer
2//!
3//! This module defines the core traits and types for integrating with various LLM providers.
4//! It provides a unified interface that abstracts away provider-specific details.
5
6use anyhow::Result;
7use async_trait::async_trait;
8use futures::Stream;
9use serde::{Deserialize, Serialize};
10use std::pin::Pin;
11
12/// Configuration for model generation requests
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct GenerationConfig {
15    /// Sampling temperature (0.0 - 2.0)
16    pub temperature: Option<f32>,
17    /// Maximum tokens to generate
18    pub max_tokens: Option<u32>,
19    /// Stop sequences
20    pub stop_sequences: Option<Vec<String>>,
21    /// Top-p sampling
22    pub top_p: Option<f32>,
23    /// Frequency penalty
24    pub frequency_penalty: Option<f32>,
25    /// Presence penalty
26    pub presence_penalty: Option<f32>,
27}
28
29impl Default for GenerationConfig {
30    fn default() -> Self {
31        Self {
32            temperature: Some(0.7),
33            max_tokens: Some(2048),
34            stop_sequences: None,
35            top_p: Some(1.0),
36            frequency_penalty: None,
37            presence_penalty: None,
38        }
39    }
40}
41
42/// Tool call from a model response
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ToolCall {
45    /// Unique identifier for this tool call
46    pub id: String,
47    /// Name of the function/tool to call
48    pub function_name: String,
49    /// Arguments as JSON
50    pub arguments: serde_json::Value,
51}
52
/// Parse thinking/reasoning tokens from model response
///
/// Splits the response at the first `</think>` tag. The text after it (trimmed)
/// becomes the main content. If a `<think>` tag appears before that closing tag,
/// the text between the two (trimmed) is returned as the reasoning. Any text
/// before the opening tag is discarded.
///
/// Reasoning and content always come from the same tag pair, so the reasoning
/// is never also left embedded in the content.
///
/// # Arguments
/// * `response` - Raw model response that may contain thinking tokens
///
/// # Returns
/// A tuple of (reasoning, content) where:
/// - `reasoning` is `Some(String)` when a `<think>` tag precedes the first
///   `</think>` tag, `None` otherwise
/// - `content` is the trimmed text after the first `</think>`, or the full,
///   untouched response if no closing tag is present
///
/// # Example
/// ```
/// use spec_ai_core::agent::model::parse_thinking_tokens;
///
/// let response = "<think>Let me consider this...</think>Here's my answer.";
/// let (reasoning, content) = parse_thinking_tokens(response);
/// assert_eq!(reasoning, Some("Let me consider this...".to_string()));
/// assert_eq!(content, "Here's my answer.");
/// ```
pub fn parse_thinking_tokens(response: &str) -> (Option<String>, String) {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    // Without a closing tag there is no complete reasoning block, so the
    // response is returned unchanged (including any dangling `<think>`).
    let close_idx = match response.find(CLOSE_TAG) {
        Some(idx) => idx,
        None => return (None, response.to_string()),
    };

    // Search for the opening tag only *before* the closing tag used to split
    // the content, so both halves of the result agree on a single tag pair.
    // The tags are ASCII, so these byte offsets are always char boundaries.
    let reasoning = response[..close_idx].find(OPEN_TAG).map(|open_idx| {
        response[open_idx + OPEN_TAG.len()..close_idx]
            .trim()
            .to_string()
    });

    let content = response[close_idx + CLOSE_TAG.len()..].trim().to_string();

    (reasoning, content)
}
98
99/// Response from a model generation request
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct ModelResponse {
102    /// Generated content (with thinking tokens removed if present)
103    pub content: String,
104    /// Model used for generation
105    pub model: String,
106    /// Token usage statistics
107    pub usage: Option<TokenUsage>,
108    /// Finish reason
109    pub finish_reason: Option<String>,
110    /// Tool calls from the model (if any)
111    #[serde(skip_serializing_if = "Option::is_none")]
112    pub tool_calls: Option<Vec<ToolCall>>,
113    /// Reasoning/thinking content extracted from <think> tags (if present)
114    #[serde(skip_serializing_if = "Option::is_none")]
115    pub reasoning: Option<String>,
116}
117
118/// Token usage statistics
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct TokenUsage {
121    pub prompt_tokens: u32,
122    pub completion_tokens: u32,
123    pub total_tokens: u32,
124}
125
126/// Binary image attachment for multimodal prompts.
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct ImageAttachment {
129    pub mime: String,
130    pub data: Vec<u8>,
131}
132
133/// Provider metadata
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct ProviderMetadata {
136    /// Provider name
137    pub name: String,
138    /// Supported models
139    pub supported_models: Vec<String>,
140    /// Supports streaming
141    pub supports_streaming: bool,
142}
143
144/// Types of model providers
145#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
146#[serde(rename_all = "lowercase")]
147pub enum ProviderKind {
148    Mock,
149    #[cfg(feature = "openai")]
150    OpenAI,
151    #[cfg(feature = "anthropic")]
152    Anthropic,
153    #[cfg(feature = "ollama")]
154    Ollama,
155    #[cfg(feature = "mlx")]
156    MLX,
157    #[cfg(feature = "lmstudio")]
158    LMStudio,
159}
160
161impl ProviderKind {
162    #[allow(clippy::should_implement_trait)]
163    pub fn from_str(s: &str) -> Option<Self> {
164        match s.to_lowercase().as_str() {
165            "mock" => Some(ProviderKind::Mock),
166            #[cfg(feature = "openai")]
167            "openai" => Some(ProviderKind::OpenAI),
168            #[cfg(feature = "anthropic")]
169            "anthropic" => Some(ProviderKind::Anthropic),
170            #[cfg(feature = "ollama")]
171            "ollama" => Some(ProviderKind::Ollama),
172            #[cfg(feature = "mlx")]
173            "mlx" => Some(ProviderKind::MLX),
174            #[cfg(feature = "lmstudio")]
175            "lmstudio" => Some(ProviderKind::LMStudio),
176            _ => None,
177        }
178    }
179
180    pub fn as_str(&self) -> &'static str {
181        match self {
182            ProviderKind::Mock => "mock",
183            #[cfg(feature = "openai")]
184            ProviderKind::OpenAI => "openai",
185            #[cfg(feature = "anthropic")]
186            ProviderKind::Anthropic => "anthropic",
187            #[cfg(feature = "ollama")]
188            ProviderKind::Ollama => "ollama",
189            #[cfg(feature = "mlx")]
190            ProviderKind::MLX => "mlx",
191            #[cfg(feature = "lmstudio")]
192            ProviderKind::LMStudio => "lmstudio",
193        }
194    }
195}
196
197/// Core trait that all model providers must implement
198#[async_trait]
199pub trait ModelProvider: Send + Sync {
200    /// Generate a response to the given prompt
201    async fn generate(&self, prompt: &str, config: &GenerationConfig) -> Result<ModelResponse>;
202
203    /// Generate a response with optional image attachments (default: ignore attachments).
204    async fn generate_with_attachments(
205        &self,
206        prompt: &str,
207        attachments: &[ImageAttachment],
208        config: &GenerationConfig,
209    ) -> Result<ModelResponse> {
210        let _ = attachments;
211        self.generate(prompt, config).await
212    }
213
214    /// Stream a response to the given prompt
215    async fn stream(
216        &self,
217        prompt: &str,
218        config: &GenerationConfig,
219    ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>>;
220
221    /// Stream a response with optional image attachments (default: ignore attachments).
222    async fn stream_with_attachments(
223        &self,
224        prompt: &str,
225        attachments: &[ImageAttachment],
226        config: &GenerationConfig,
227    ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
228        let _ = attachments;
229        self.stream(prompt, config).await
230    }
231
232    /// Get provider metadata
233    fn metadata(&self) -> ProviderMetadata;
234
235    /// Get the provider kind
236    fn kind(&self) -> ProviderKind;
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_kind_from_str() {
        assert_eq!(ProviderKind::from_str("mock"), Some(ProviderKind::Mock));
        assert_eq!(ProviderKind::from_str("Mock"), Some(ProviderKind::Mock));
        assert_eq!(ProviderKind::from_str("MOCK"), Some(ProviderKind::Mock));
        assert_eq!(ProviderKind::from_str("invalid"), None);
        assert_eq!(ProviderKind::from_str(""), None);
    }

    #[test]
    fn test_provider_kind_as_str() {
        assert_eq!(ProviderKind::Mock.as_str(), "mock");
    }

    #[test]
    fn test_provider_kind_round_trip() {
        // `from_str` must accept whatever `as_str` produces, in any case.
        let kind = ProviderKind::Mock;
        assert_eq!(ProviderKind::from_str(kind.as_str()), Some(kind));
        assert_eq!(
            ProviderKind::from_str(&kind.as_str().to_uppercase()),
            Some(kind)
        );
    }

    #[test]
    fn test_provider_kind_serde_name_matches_as_str() {
        // `rename_all = "lowercase"` must stay in sync with `as_str`.
        let json = serde_json::to_string(&ProviderKind::Mock).unwrap();
        assert_eq!(json, format!("\"{}\"", ProviderKind::Mock.as_str()));
        let parsed: ProviderKind = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, ProviderKind::Mock);
    }

    #[test]
    fn test_generation_config_default() {
        let config = GenerationConfig::default();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_tokens, Some(2048));
        assert_eq!(config.top_p, Some(1.0));
        assert_eq!(config.stop_sequences, None);
        assert_eq!(config.frequency_penalty, None);
        assert_eq!(config.presence_penalty, None);
    }

    #[test]
    fn test_generation_config_serialization() {
        let config = GenerationConfig {
            temperature: Some(0.9),
            max_tokens: Some(1024),
            stop_sequences: Some(vec!["STOP".to_string()]),
            top_p: Some(0.95),
            frequency_penalty: None,
            presence_penalty: None,
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: GenerationConfig = serde_json::from_str(&json).unwrap();

        // Check every field, not just a subset, so a dropped field is caught.
        assert_eq!(config.temperature, deserialized.temperature);
        assert_eq!(config.max_tokens, deserialized.max_tokens);
        assert_eq!(config.stop_sequences, deserialized.stop_sequences);
        assert_eq!(config.top_p, deserialized.top_p);
        assert_eq!(config.frequency_penalty, deserialized.frequency_penalty);
        assert_eq!(config.presence_penalty, deserialized.presence_penalty);
    }

    #[test]
    fn test_generation_config_missing_fields_deserialize_as_none() {
        let config: GenerationConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(config.temperature, None);
        assert_eq!(config.max_tokens, None);
        assert_eq!(config.stop_sequences, None);
        assert_eq!(config.top_p, None);
        assert_eq!(config.frequency_penalty, None);
        assert_eq!(config.presence_penalty, None);
    }

    #[test]
    fn test_model_response_skips_empty_optional_fields() {
        let response = ModelResponse {
            content: "hi".to_string(),
            model: "mock".to_string(),
            usage: None,
            finish_reason: None,
            tool_calls: None,
            reasoning: None,
        };
        let value = serde_json::to_value(&response).unwrap();
        // Only `tool_calls` and `reasoning` are marked `skip_serializing_if`.
        assert!(value.get("tool_calls").is_none());
        assert!(value.get("reasoning").is_none());
        assert!(value.get("usage").is_some());
        assert!(value.get("finish_reason").is_some());
    }

    #[test]
    fn test_parse_thinking_tokens_with_tags() {
        let response = "<think>Let me consider this carefully...</think>Here's my final answer.";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(
            reasoning,
            Some("Let me consider this carefully...".to_string())
        );
        assert_eq!(content, "Here's my final answer.");
    }

    #[test]
    fn test_parse_thinking_tokens_without_tags() {
        let response = "This is a normal response without thinking tags.";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(reasoning, None);
        assert_eq!(content, "This is a normal response without thinking tags.");
    }

    #[test]
    fn test_parse_thinking_tokens_multiline() {
        let response = "<think>\nFirst, I need to analyze the problem.\nThen I'll formulate a solution.\n</think>\n\nHere's the answer: 42";
        let (reasoning, content) = parse_thinking_tokens(response);

        let reasoning_text = reasoning.expect("multiline think block should be extracted");
        assert!(reasoning_text.contains("analyze the problem"));
        assert!(reasoning_text.contains("formulate a solution"));
        assert_eq!(content, "Here's the answer: 42");
    }

    #[test]
    fn test_parse_thinking_tokens_empty_think() {
        let response = "<think></think>Content after empty think.";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(reasoning, Some("".to_string()));
        assert_eq!(content, "Content after empty think.");
    }

    #[test]
    fn test_parse_thinking_tokens_whitespace_handling() {
        let response = "<think>  \n  Some reasoning  \n  </think>  \n  Final answer";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(reasoning, Some("Some reasoning".to_string()));
        assert_eq!(content, "Final answer");
    }

    #[test]
    fn test_parse_thinking_tokens_incomplete_tag() {
        let response = "<think>Incomplete thinking...";
        let (reasoning, content) = parse_thinking_tokens(response);

        // No closing tag means no reasoning extracted
        assert_eq!(reasoning, None);
        assert_eq!(content, "<think>Incomplete thinking...");
    }

    #[test]
    fn test_parse_thinking_tokens_only_closing_tag() {
        // Without an opening tag there is no reasoning, but the content is
        // still taken from after the closing tag.
        let (reasoning, content) = parse_thinking_tokens("stray text</think>  answer");

        assert_eq!(reasoning, None);
        assert_eq!(content, "answer");
    }

    #[test]
    fn test_parse_thinking_tokens_uses_first_block() {
        let response = "<think>first</think>middle<think>second</think>end";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(reasoning, Some("first".to_string()));
        assert_eq!(content, "middle<think>second</think>end");
    }

    #[test]
    fn test_parse_thinking_tokens_non_ascii() {
        let (reasoning, content) = parse_thinking_tokens("<think>réflexion 🤔</think>réponse");

        assert_eq!(reasoning, Some("réflexion 🤔".to_string()));
        assert_eq!(content, "réponse");
    }
}