spec_ai_core/agent/model.rs

//! Model Provider Abstraction Layer
//!
//! This module defines the core traits and types for integrating with various LLM providers.
//! It provides a unified interface that abstracts away provider-specific details.

use anyhow::Result;
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use std::pin::Pin;

/// Configuration for model generation requests
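///
/// # Example
///
/// A minimal sketch (illustrative, not from the original source) of
/// overriding selected fields while keeping the remaining defaults:
///
/// ```
/// use spec_ai_core::agent::model::GenerationConfig;
///
/// let config = GenerationConfig {
///     temperature: Some(0.2),
///     ..Default::default()
/// };
/// assert_eq!(config.max_tokens, Some(2048));
/// ```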
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
    /// Sampling temperature (0.0 - 2.0)
    pub temperature: Option<f32>,
    /// Maximum tokens to generate
    pub max_tokens: Option<u32>,
    /// Stop sequences
    pub stop_sequences: Option<Vec<String>>,
    /// Top-p sampling
    pub top_p: Option<f32>,
    /// Frequency penalty
    pub frequency_penalty: Option<f32>,
    /// Presence penalty
    pub presence_penalty: Option<f32>,
}

impl Default for GenerationConfig {
    fn default() -> Self {
        Self {
            temperature: Some(0.7),
            max_tokens: Some(2048),
            stop_sequences: None,
            top_p: Some(1.0),
            frequency_penalty: None,
            presence_penalty: None,
        }
    }
}

/// Tool call from a model response
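///
/// # Example
///
/// A sketch of deserializing a tool call from JSON. The field names follow
/// this struct; actual provider wire formats may differ and are assumed to be
/// mapped by the individual provider implementations.
///
/// ```
/// use spec_ai_core::agent::model::ToolCall;
///
/// let call: ToolCall = serde_json::from_str(
///     r#"{"id":"call_1","function_name":"search","arguments":{"query":"rust"}}"#,
/// ).unwrap();
/// assert_eq!(call.function_name, "search");
/// ```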
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique identifier for this tool call
    pub id: String,
    /// Name of the function/tool to call
    pub function_name: String,
    /// Arguments as JSON
    pub arguments: serde_json::Value,
}

/// Parse thinking/reasoning tokens from a model response
///
/// Extracts the content between the first `<think>` and `</think>` tags as
/// reasoning, and returns the content after `</think>` as the main response.
///
/// # Arguments
/// * `response` - Raw model response that may contain thinking tokens
///
/// # Returns
/// A tuple of (reasoning, content) where:
/// - `reasoning` is `Some(String)` if thinking tags were found, `None` otherwise
/// - `content` is the text after `</think>`, or the full response if no tags are present
///
/// # Example
/// ```
/// use spec_ai_core::agent::model::parse_thinking_tokens;
///
/// let response = "<think>Let me consider this...</think>Here's my answer.";
/// let (reasoning, content) = parse_thinking_tokens(response);
/// assert_eq!(reasoning, Some("Let me consider this...".to_string()));
/// assert_eq!(content, "Here's my answer.");
/// ```
pub fn parse_thinking_tokens(response: &str) -> (Option<String>, String) {
    use std::sync::OnceLock;

    // Pattern to match content between <think> and </think>, compiled once
    // and reused across calls; the pattern is a literal, so compilation
    // cannot fail.
    static THINK_PATTERN: OnceLock<regex::Regex> = OnceLock::new();
    let think_pattern = THINK_PATTERN
        .get_or_init(|| regex::Regex::new(r"<think>([\s\S]*?)</think>").expect("valid regex"));

    // Try to find thinking content
    let reasoning = if let Some(captures) = think_pattern.captures(response) {
        captures.get(1).map(|m| m.as_str().trim().to_string())
    } else {
        None
    };

    // Extract content after </think> tag, or return full response if no tags
    let content = if let Some(end_idx) = response.find("</think>") {
        // Get everything after </think>
        let after_think = &response[end_idx + "</think>".len()..];
        after_think.trim().to_string()
    } else {
        // No thinking tags found, return original response
        response.to_string()
    };

    (reasoning, content)
}

/// Response from a model generation request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResponse {
    /// Generated content (with thinking tokens removed if present)
    pub content: String,
    /// Model used for generation
    pub model: String,
    /// Token usage statistics
    pub usage: Option<TokenUsage>,
    /// Finish reason
    pub finish_reason: Option<String>,
    /// Tool calls from the model (if any)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Reasoning/thinking content extracted from `<think>` tags (if present)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}

/// Token usage statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Provider metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderMetadata {
    /// Provider name
    pub name: String,
    /// Supported models
    pub supported_models: Vec<String>,
    /// Supports streaming
    pub supports_streaming: bool,
}

/// Types of model providers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderKind {
    Mock,
    #[cfg(feature = "openai")]
    OpenAI,
    #[cfg(feature = "anthropic")]
    Anthropic,
    #[cfg(feature = "ollama")]
    Ollama,
    #[cfg(feature = "mlx")]
    MLX,
    #[cfg(feature = "lmstudio")]
    LMStudio,
}

impl ProviderKind {
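    /// Parses a provider name, case-insensitively. Returns `None` for unknown
    /// names and for providers compiled out by feature flags.
    ///
    /// ```
    /// use spec_ai_core::agent::model::ProviderKind;
    ///
    /// assert_eq!(ProviderKind::from_str("MOCK"), Some(ProviderKind::Mock));
    /// assert_eq!(ProviderKind::from_str("unknown"), None);
    /// ```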
    pub fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "mock" => Some(ProviderKind::Mock),
            #[cfg(feature = "openai")]
            "openai" => Some(ProviderKind::OpenAI),
            #[cfg(feature = "anthropic")]
            "anthropic" => Some(ProviderKind::Anthropic),
            #[cfg(feature = "ollama")]
            "ollama" => Some(ProviderKind::Ollama),
            #[cfg(feature = "mlx")]
            "mlx" => Some(ProviderKind::MLX),
            #[cfg(feature = "lmstudio")]
            "lmstudio" => Some(ProviderKind::LMStudio),
            _ => None,
        }
    }

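    /// Returns the canonical lowercase name for this provider kind; the
    /// result round-trips through [`ProviderKind::from_str`].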
    pub fn as_str(&self) -> &'static str {
        match self {
            ProviderKind::Mock => "mock",
            #[cfg(feature = "openai")]
            ProviderKind::OpenAI => "openai",
            #[cfg(feature = "anthropic")]
            ProviderKind::Anthropic => "anthropic",
            #[cfg(feature = "ollama")]
            ProviderKind::Ollama => "ollama",
            #[cfg(feature = "mlx")]
            ProviderKind::MLX => "mlx",
            #[cfg(feature = "lmstudio")]
            ProviderKind::LMStudio => "lmstudio",
        }
    }
}

/// Core trait that all model providers must implement
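///
/// # Example
///
/// A minimal sketch of an implementation; `EchoProvider` is illustrative and
/// not part of this crate. It returns the prompt unchanged and streams it
/// back as a single chunk.
///
/// ```
/// use anyhow::Result;
/// use async_trait::async_trait;
/// use futures::{stream, Stream};
/// use spec_ai_core::agent::model::{
///     GenerationConfig, ModelProvider, ModelResponse, ProviderKind, ProviderMetadata,
/// };
/// use std::pin::Pin;
///
/// struct EchoProvider;
///
/// #[async_trait]
/// impl ModelProvider for EchoProvider {
///     async fn generate(&self, prompt: &str, _config: &GenerationConfig) -> Result<ModelResponse> {
///         Ok(ModelResponse {
///             content: prompt.to_string(),
///             model: "echo".to_string(),
///             usage: None,
///             finish_reason: Some("stop".to_string()),
///             tool_calls: None,
///             reasoning: None,
///         })
///     }
///
///     async fn stream(
///         &self,
///         prompt: &str,
///         _config: &GenerationConfig,
///     ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
///         // Emit the whole response as a single owned chunk so the stream is 'static.
///         let chunk: Result<String> = Ok(prompt.to_string());
///         Ok(Box::pin(stream::once(async move { chunk })))
///     }
///
///     fn metadata(&self) -> ProviderMetadata {
///         ProviderMetadata {
///             name: "echo".to_string(),
///             supported_models: vec!["echo".to_string()],
///             supports_streaming: true,
///         }
///     }
///
///     fn kind(&self) -> ProviderKind {
///         ProviderKind::Mock
///     }
/// }
/// ```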
#[async_trait]
pub trait ModelProvider: Send + Sync {
    /// Generate a response to the given prompt
    async fn generate(&self, prompt: &str, config: &GenerationConfig) -> Result<ModelResponse>;

    /// Stream a response to the given prompt
    async fn stream(
        &self,
        prompt: &str,
        config: &GenerationConfig,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>>;

    /// Get provider metadata
    fn metadata(&self) -> ProviderMetadata;

    /// Get the provider kind
    fn kind(&self) -> ProviderKind;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_kind_from_str() {
        assert_eq!(ProviderKind::from_str("mock"), Some(ProviderKind::Mock));
        assert_eq!(ProviderKind::from_str("Mock"), Some(ProviderKind::Mock));
        assert_eq!(ProviderKind::from_str("MOCK"), Some(ProviderKind::Mock));
        assert_eq!(ProviderKind::from_str("invalid"), None);
    }

    #[test]
    fn test_provider_kind_as_str() {
        assert_eq!(ProviderKind::Mock.as_str(), "mock");
    }

    #[test]
    fn test_generation_config_default() {
        let config = GenerationConfig::default();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_tokens, Some(2048));
        assert_eq!(config.top_p, Some(1.0));
    }

    #[test]
    fn test_generation_config_serialization() {
        let config = GenerationConfig {
            temperature: Some(0.9),
            max_tokens: Some(1024),
            stop_sequences: Some(vec!["STOP".to_string()]),
            top_p: Some(0.95),
            frequency_penalty: None,
            presence_penalty: None,
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: GenerationConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.temperature, deserialized.temperature);
        assert_eq!(config.max_tokens, deserialized.max_tokens);
    }

    #[test]
    fn test_parse_thinking_tokens_with_tags() {
        let response = "<think>Let me consider this carefully...</think>Here's my final answer.";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(
            reasoning,
            Some("Let me consider this carefully...".to_string())
        );
        assert_eq!(content, "Here's my final answer.");
    }

    #[test]
    fn test_parse_thinking_tokens_without_tags() {
        let response = "This is a normal response without thinking tags.";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(reasoning, None);
        assert_eq!(content, "This is a normal response without thinking tags.");
    }

    #[test]
    fn test_parse_thinking_tokens_multiline() {
        let response = "<think>\nFirst, I need to analyze the problem.\nThen I'll formulate a solution.\n</think>\n\nHere's the answer: 42";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert!(reasoning.is_some());
        let reasoning_text = reasoning.unwrap();
        assert!(reasoning_text.contains("analyze the problem"));
        assert!(reasoning_text.contains("formulate a solution"));
        assert_eq!(content, "Here's the answer: 42");
    }

    #[test]
    fn test_parse_thinking_tokens_empty_think() {
        let response = "<think></think>Content after empty think.";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(reasoning, Some("".to_string()));
        assert_eq!(content, "Content after empty think.");
    }

    #[test]
    fn test_parse_thinking_tokens_whitespace_handling() {
        let response = "<think>  \n  Some reasoning  \n  </think>  \n  Final answer";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(reasoning, Some("Some reasoning".to_string()));
        assert_eq!(content, "Final answer");
    }

    #[test]
    fn test_parse_thinking_tokens_incomplete_tag() {
        let response = "<think>Incomplete thinking...";
        let (reasoning, content) = parse_thinking_tokens(response);

        // No closing tag means no reasoning extracted
        assert_eq!(reasoning, None);
        assert_eq!(content, "<think>Incomplete thinking...");
    }
}