use anyhow::Result;
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use std::pin::Pin;

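/// Sampling and decoding parameters passed to a model provider.
///
/// All fields are optional so providers can fall back to their own defaults.
/// A construction sketch (`ignore` fence, since the crate path is not shown
/// here and is assumed):
///
/// ```ignore
/// // Override only the fields you care about; keep the rest from Default.
/// let config = GenerationConfig {
///     temperature: Some(0.2),
///     ..Default::default()
/// };
/// assert_eq!(config.max_tokens, Some(2048));
/// ```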
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
    /// Sampling temperature; higher values give more varied output.
    pub temperature: Option<f32>,
    /// Maximum number of tokens to generate.
    pub max_tokens: Option<u32>,
    /// Sequences that terminate generation when emitted.
    pub stop_sequences: Option<Vec<String>>,
    /// Nucleus-sampling cutoff; `1.0` disables the filter.
    pub top_p: Option<f32>,
    /// Penalty proportional to how often a token has already appeared.
    pub frequency_penalty: Option<f32>,
    /// Penalty applied once a token has appeared at all.
    pub presence_penalty: Option<f32>,
}

impl Default for GenerationConfig {
    fn default() -> Self {
        Self {
            temperature: Some(0.7),
            max_tokens: Some(2048),
            stop_sequences: None,
            top_p: Some(1.0),
            frequency_penalty: None,
            presence_penalty: None,
        }
    }
}

/// A tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned identifier for this specific call.
    pub id: String,
    /// Name of the function the model wants invoked.
    pub function_name: String,
    /// Arguments for the call, as raw JSON.
    pub arguments: serde_json::Value,
}

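/// Splits a raw model response into an optional `<think>…</think>` reasoning
/// block and the remaining visible content.
///
/// Returns `(reasoning, content)`. When no complete tag pair is present, the
/// reasoning is `None` and the whole response is returned as content.
///
/// A usage sketch (`ignore` fence, since the crate path is assumed):
///
/// ```ignore
/// let raw = "<think>weigh both options</think>Use option B.";
/// let (reasoning, content) = parse_thinking_tokens(raw);
/// assert_eq!(reasoning.as_deref(), Some("weigh both options"));
/// assert_eq!(content, "Use option B.");
/// ```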
pub fn parse_thinking_tokens(response: &str) -> (Option<String>, String) {
    // The pattern is a valid literal, so `unwrap` cannot fail here.
    let think_pattern = regex::Regex::new(r"<think>([\s\S]*?)</think>").unwrap();

    // Reasoning is extracted only from a complete <think>…</think> pair.
    let reasoning = if let Some(captures) = think_pattern.captures(response) {
        captures.get(1).map(|m| m.as_str().trim().to_string())
    } else {
        None
    };

    // Visible content is everything after the closing tag; without one, the
    // full response (including any dangling <think>) is kept unchanged.
    let content = if let Some(end_idx) = response.find("</think>") {
        let after_think = &response[end_idx + "</think>".len()..];
        after_think.trim().to_string()
    } else {
        response.to_string()
    };

    (reasoning, content)
}

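/// A complete (non-streaming) response from a model provider.
///
/// The `tool_calls` and `reasoning` fields are omitted from serialized JSON
/// when `None`. A sketch (`ignore` fence, crate path assumed):
///
/// ```ignore
/// let response = ModelResponse {
///     content: "hi".to_string(),
///     model: "mock".to_string(),
///     usage: None,
///     finish_reason: Some("stop".to_string()),
///     tool_calls: None,
///     reasoning: None,
/// };
/// let json = serde_json::to_string(&response).unwrap();
/// // Neither "tool_calls" nor "reasoning" appears as a key.
/// assert!(!json.contains("tool_calls"));
/// ```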
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelResponse {
    /// The generated text.
    pub content: String,
    /// Identifier of the model that produced the response.
    pub model: String,
    /// Token accounting, when the provider reports it.
    pub usage: Option<TokenUsage>,
    /// Why generation stopped (e.g. "stop" or "length").
    pub finish_reason: Option<String>,
    /// Tool calls requested by the model, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Extracted `<think>` reasoning, if the model emitted any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}

/// Token counts reported by a provider for a single request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Static capability information advertised by a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderMetadata {
    /// Human-readable provider name.
    pub name: String,
    /// Model identifiers this provider can serve.
    pub supported_models: Vec<String>,
    /// Whether the provider implements token streaming.
    pub supports_streaming: bool,
}

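/// The set of compiled-in provider backends. Every variant except `Mock` is
/// gated behind a Cargo feature, so the enum shrinks with the build.
///
/// Because of `rename_all = "lowercase"`, variants serialize as their
/// lowercase names (`ignore` fence, crate path assumed):
///
/// ```ignore
/// let json = serde_json::to_string(&ProviderKind::Mock).unwrap();
/// assert_eq!(json, "\"mock\"");
/// ```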
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderKind {
    Mock,
    #[cfg(feature = "openai")]
    OpenAI,
    #[cfg(feature = "anthropic")]
    Anthropic,
    #[cfg(feature = "ollama")]
    Ollama,
    #[cfg(feature = "mlx")]
    MLX,
    #[cfg(feature = "lmstudio")]
    LMStudio,
}

impl ProviderKind {
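    /// Parses a case-insensitive provider name, returning `None` for unknown
    /// or feature-disabled providers. This is an inherent `Option`-returning
    /// method rather than an implementation of `std::str::FromStr`.
    ///
    /// A sketch (`ignore` fence, crate path assumed):
    ///
    /// ```ignore
    /// assert_eq!(ProviderKind::from_str("MOCK"), Some(ProviderKind::Mock));
    /// assert_eq!(ProviderKind::from_str("unknown"), None);
    /// ```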
    pub fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "mock" => Some(ProviderKind::Mock),
            #[cfg(feature = "openai")]
            "openai" => Some(ProviderKind::OpenAI),
            #[cfg(feature = "anthropic")]
            "anthropic" => Some(ProviderKind::Anthropic),
            #[cfg(feature = "ollama")]
            "ollama" => Some(ProviderKind::Ollama),
            #[cfg(feature = "mlx")]
            "mlx" => Some(ProviderKind::MLX),
            #[cfg(feature = "lmstudio")]
            "lmstudio" => Some(ProviderKind::LMStudio),
            _ => None,
        }
    }

    /// Returns the canonical lowercase name, the inverse of [`Self::from_str`].
    pub fn as_str(&self) -> &'static str {
        match self {
            ProviderKind::Mock => "mock",
            #[cfg(feature = "openai")]
            ProviderKind::OpenAI => "openai",
            #[cfg(feature = "anthropic")]
            ProviderKind::Anthropic => "anthropic",
            #[cfg(feature = "ollama")]
            ProviderKind::Ollama => "ollama",
            #[cfg(feature = "mlx")]
            ProviderKind::MLX => "mlx",
            #[cfg(feature = "lmstudio")]
            ProviderKind::LMStudio => "lmstudio",
        }
    }
}

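/// Common interface implemented by every model backend.
///
/// A minimal in-process implementation sketch; `EchoProvider` is a
/// hypothetical type, and the `ignore` fence skips compilation because the
/// crate path is not shown here:
///
/// ```ignore
/// struct EchoProvider;
///
/// #[async_trait]
/// impl ModelProvider for EchoProvider {
///     async fn generate(&self, prompt: &str, _config: &GenerationConfig) -> Result<ModelResponse> {
///         // Echo the prompt straight back as the "generated" content.
///         Ok(ModelResponse {
///             content: prompt.to_string(),
///             model: "echo".to_string(),
///             usage: None,
///             finish_reason: Some("stop".to_string()),
///             tool_calls: None,
///             reasoning: None,
///         })
///     }
///
///     async fn stream(
///         &self,
///         prompt: &str,
///         _config: &GenerationConfig,
///     ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
///         // Emit the whole prompt back as a single chunk.
///         let chunk: Result<String> = Ok(prompt.to_string());
///         Ok(Box::pin(futures::stream::once(async move { chunk })))
///     }
///
///     fn metadata(&self) -> ProviderMetadata {
///         ProviderMetadata {
///             name: "echo".to_string(),
///             supported_models: vec!["echo".to_string()],
///             supports_streaming: true,
///         }
///     }
///
///     fn kind(&self) -> ProviderKind {
///         ProviderKind::Mock
///     }
/// }
/// ```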
#[async_trait]
pub trait ModelProvider: Send + Sync {
    /// Generates a complete response for `prompt`.
    async fn generate(&self, prompt: &str, config: &GenerationConfig) -> Result<ModelResponse>;

    /// Streams the response as incremental text chunks.
    async fn stream(
        &self,
        prompt: &str,
        config: &GenerationConfig,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>>;

    /// Describes this provider's capabilities.
    fn metadata(&self) -> ProviderMetadata;

    /// Identifies which backend this provider is.
    fn kind(&self) -> ProviderKind;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_kind_from_str() {
        assert_eq!(ProviderKind::from_str("mock"), Some(ProviderKind::Mock));
        assert_eq!(ProviderKind::from_str("Mock"), Some(ProviderKind::Mock));
        assert_eq!(ProviderKind::from_str("MOCK"), Some(ProviderKind::Mock));
        assert_eq!(ProviderKind::from_str("invalid"), None);
    }

    #[test]
    fn test_provider_kind_as_str() {
        assert_eq!(ProviderKind::Mock.as_str(), "mock");
    }
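
    // `as_str` and `from_str` should round-trip; only the always-compiled
    // `Mock` variant is exercised so this passes under any feature set.
    #[test]
    fn test_provider_kind_round_trip() {
        let kind = ProviderKind::Mock;
        assert_eq!(ProviderKind::from_str(kind.as_str()), Some(kind));
    }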

    #[test]
    fn test_generation_config_default() {
        let config = GenerationConfig::default();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_tokens, Some(2048));
        assert_eq!(config.top_p, Some(1.0));
    }

    #[test]
    fn test_generation_config_serialization() {
        let config = GenerationConfig {
            temperature: Some(0.9),
            max_tokens: Some(1024),
            stop_sequences: Some(vec!["STOP".to_string()]),
            top_p: Some(0.95),
            frequency_penalty: None,
            presence_penalty: None,
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: GenerationConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.temperature, deserialized.temperature);
        assert_eq!(config.max_tokens, deserialized.max_tokens);
    }

    #[test]
    fn test_parse_thinking_tokens_with_tags() {
        let response = "<think>Let me consider this carefully...</think>Here's my final answer.";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(
            reasoning,
            Some("Let me consider this carefully...".to_string())
        );
        assert_eq!(content, "Here's my final answer.");
    }

    #[test]
    fn test_parse_thinking_tokens_without_tags() {
        let response = "This is a normal response without thinking tags.";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(reasoning, None);
        assert_eq!(content, "This is a normal response without thinking tags.");
    }

    #[test]
    fn test_parse_thinking_tokens_multiline() {
        let response = "<think>\nFirst, I need to analyze the problem.\nThen I'll formulate a solution.\n</think>\n\nHere's the answer: 42";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert!(reasoning.is_some());
        let reasoning_text = reasoning.unwrap();
        assert!(reasoning_text.contains("analyze the problem"));
        assert!(reasoning_text.contains("formulate a solution"));
        assert_eq!(content, "Here's the answer: 42");
    }

    #[test]
    fn test_parse_thinking_tokens_empty_think() {
        let response = "<think></think>Content after empty think.";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(reasoning, Some("".to_string()));
        assert_eq!(content, "Content after empty think.");
    }

    #[test]
    fn test_parse_thinking_tokens_whitespace_handling() {
        let response = "<think> \n Some reasoning \n </think> \n Final answer";
        let (reasoning, content) = parse_thinking_tokens(response);

        assert_eq!(reasoning, Some("Some reasoning".to_string()));
        assert_eq!(content, "Final answer");
    }

    #[test]
    fn test_parse_thinking_tokens_incomplete_tag() {
        let response = "<think>Incomplete thinking...";
        let (reasoning, content) = parse_thinking_tokens(response);

        // No closing tag: nothing is extracted and the input passes through.
        assert_eq!(reasoning, None);
        assert_eq!(content, "<think>Incomplete thinking...");
    }
}