1pub mod claude;
2pub mod error;
3pub mod gemini;
4pub mod openai;
5pub mod provider;
6pub mod xai;
7
8use crate::config::{ApiAuth, LlmConfigProvider};
9use crate::error::Result;
10pub use claude::AnthropicClient;
11pub use error::ApiError;
12pub use gemini::GeminiClient;
13pub use openai::OpenAIClient;
14pub use provider::{CompletionResponse, Provider};
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::RwLock;
18pub use steer_tools::{InputSchema, ToolCall, ToolSchema};
19use strum::Display;
20use strum::EnumIter;
21use strum::IntoStaticStr;
22use strum_macros::{AsRefStr, EnumString};
23use tokio_util::sync::CancellationToken;
24use tracing::debug;
25use tracing::warn;
26pub use xai::XAIClient;
27
28use crate::app::conversation::Message;
29
/// The LLM backend vendors supported by this crate.
///
/// `Display` / `IntoStaticStr` render the lowercase wire identifier
/// (e.g. "anthropic", "openai") via the strum attributes below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Display, IntoStaticStr)]
#[strum(serialize_all = "lowercase")]
pub enum ProviderKind {
    Anthropic,
    OpenAI,
    Google,
    // Explicit override pins the wire name to "xai" independent of the
    // casing rule applied to the all-caps variant name.
    #[strum(serialize = "xai")]
    XAI,
}
39
40impl ProviderKind {
41 pub fn display_name(&self) -> String {
42 match self {
43 ProviderKind::Anthropic => "Anthropic".to_string(),
44 ProviderKind::OpenAI => "OpenAI".to_string(),
45 ProviderKind::Google => "Google".to_string(),
46 ProviderKind::XAI => "xAI".to_string(),
47 }
48 }
49}
50
/// Every concrete model identifier the client can talk to, across all
/// supported providers.
///
/// The primary `serialize = "..."` attribute is the provider's dated model
/// id (used for parsing via `EnumString` and rendering via `Display`);
/// a second `serialize` on some variants registers a short CLI alias
/// (e.g. "sonnet", "o3", "grok").
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    EnumIter,
    EnumString,
    AsRefStr,
    Display,
    IntoStaticStr,
    serde::Serialize,
    serde::Deserialize,
    Default,
)]
pub enum Model {
    #[strum(serialize = "claude-3-5-sonnet-20240620")]
    Claude3_5Sonnet20240620,
    #[strum(serialize = "claude-3-5-sonnet-20241022")]
    Claude3_5Sonnet20241022,
    #[strum(serialize = "claude-3-7-sonnet-20250219")]
    Claude3_7Sonnet20250219,
    #[strum(serialize = "claude-3-5-haiku-20241022")]
    Claude3_5Haiku20241022,
    #[strum(serialize = "claude-sonnet-4-20250514", serialize = "sonnet")]
    ClaudeSonnet4_20250514,
    // Default model for the whole application.
    #[strum(serialize = "claude-opus-4-20250514", serialize = "opus")]
    #[default]
    ClaudeOpus4_20250514,
    #[strum(serialize = "gpt-4.1-2025-04-14")]
    Gpt4_1_20250414,
    #[strum(serialize = "gpt-4.1-mini-2025-04-14")]
    Gpt4_1Mini20250414,
    #[strum(serialize = "gpt-4.1-nano-2025-04-14")]
    Gpt4_1Nano20250414,
    #[strum(serialize = "o3-2025-04-16", serialize = "o3")]
    O3_20250416,
    #[strum(serialize = "o3-pro-2025-06-10", serialize = "o3-pro")]
    O3Pro20250610,
    #[strum(serialize = "o4-mini-2025-04-16", serialize = "o4-mini")]
    O4Mini20250416,
    #[strum(serialize = "gemini-2.5-flash-preview-04-17")]
    Gemini2_5FlashPreview0417,
    #[strum(serialize = "gemini-2.5-pro-preview-05-06")]
    Gemini2_5ProPreview0506,
    #[strum(serialize = "gemini-2.5-pro-preview-06-05", serialize = "gemini")]
    Gemini2_5ProPreview0605,
    #[strum(serialize = "grok-3")]
    Grok3,
    #[strum(serialize = "grok-3-mini", serialize = "grok-mini")]
    Grok3Mini,
    #[strum(serialize = "grok-4-0709", serialize = "grok")]
    Grok4_0709,
}
106
107impl Model {
108 pub fn should_show(&self) -> bool {
110 matches!(
111 self,
112 Model::ClaudeOpus4_20250514
113 | Model::ClaudeSonnet4_20250514
114 | Model::O3_20250416
115 | Model::O3Pro20250610
116 | Model::Gemini2_5ProPreview0605
117 | Model::Grok4_0709
118 | Model::Grok3
119 | Model::Gpt4_1_20250414
120 | Model::O4Mini20250416
121 )
122 }
123
124 pub fn iter_recommended() -> impl Iterator<Item = Model> {
125 use strum::IntoEnumIterator;
126 Model::iter().filter(|m| m.should_show())
127 }
128
129 pub fn provider(&self) -> ProviderKind {
130 match self {
131 Model::Claude3_7Sonnet20250219
132 | Model::Claude3_5Sonnet20240620
133 | Model::Claude3_5Sonnet20241022
134 | Model::Claude3_5Haiku20241022
135 | Model::ClaudeSonnet4_20250514
136 | Model::ClaudeOpus4_20250514 => ProviderKind::Anthropic,
137
138 Model::Gpt4_1_20250414
139 | Model::Gpt4_1Mini20250414
140 | Model::Gpt4_1Nano20250414
141 | Model::O3_20250416
142 | Model::O3Pro20250610
143 | Model::O4Mini20250416 => ProviderKind::OpenAI,
144
145 Model::Gemini2_5FlashPreview0417
146 | Model::Gemini2_5ProPreview0506
147 | Model::Gemini2_5ProPreview0605 => ProviderKind::Google,
148
149 Model::Grok3 | Model::Grok3Mini | Model::Grok4_0709 => ProviderKind::XAI,
150 }
151 }
152
153 pub fn aliases(&self) -> Vec<&'static str> {
154 match self {
155 Model::ClaudeSonnet4_20250514 => vec!["sonnet"],
156 Model::ClaudeOpus4_20250514 => vec!["opus"],
157 Model::O3_20250416 => vec!["o3"],
158 Model::O3Pro20250610 => vec!["o3-pro"],
159 Model::O4Mini20250416 => vec!["o4-mini"],
160 Model::Gemini2_5ProPreview0605 => vec!["gemini"],
161 Model::Grok3 => vec![],
162 Model::Grok3Mini => vec!["grok-mini"],
163 Model::Grok4_0709 => vec!["grok"],
164 _ => vec![],
165 }
166 }
167
168 pub fn supports_thinking(&self) -> bool {
169 matches!(
170 self,
171 Model::Claude3_7Sonnet20250219
172 | Model::ClaudeSonnet4_20250514
173 | Model::ClaudeOpus4_20250514
174 | Model::O3_20250416
175 | Model::O3Pro20250610
176 | Model::O4Mini20250416
177 | Model::Gemini2_5FlashPreview0417
178 | Model::Gemini2_5ProPreview0506
179 | Model::Gemini2_5ProPreview0605
180 | Model::Grok3Mini
181 | Model::Grok4_0709
182 )
183 }
184
185 pub fn default_system_prompt_file(&self) -> Option<&'static str> {
186 match self {
187 Model::O3_20250416 => Some("models/o3.md"),
188 Model::O3Pro20250610 => Some("models/o3.md"),
189 Model::O4Mini20250416 => Some("models/o3.md"),
190 _ => None,
191 }
192 }
193
194 pub fn all() -> Vec<Model> {
196 use strum::IntoEnumIterator;
197 Model::iter().collect()
198 }
199}
200
/// Cloneable facade over the per-provider API clients.
///
/// Clones share the same provider cache via `Arc`; concrete provider
/// instances are created lazily per model (see `get_or_create_provider`).
#[derive(Clone)]
pub struct Client {
    // Lazily-populated cache of provider instances, keyed by model.
    provider_map: Arc<RwLock<HashMap<Model, Arc<dyn Provider>>>>,
    // Source of authentication/configuration for each provider kind.
    config_provider: LlmConfigProvider,
}
206
207impl Client {
208 pub fn new_with_provider(provider: LlmConfigProvider) -> Self {
209 Self {
210 provider_map: Arc::new(RwLock::new(HashMap::new())),
211 config_provider: provider,
212 }
213 }
214
215 async fn get_or_create_provider(&self, model: Model) -> Result<Arc<dyn Provider>> {
216 {
218 let map = self.provider_map.read().unwrap();
219 if let Some(provider) = map.get(&model) {
220 return Ok(provider.clone());
221 }
222 }
223
224 let provider_kind = model.provider();
226 let auth = self
227 .config_provider
228 .get_auth_for_provider(provider_kind)
229 .await?;
230
231 let mut map = self.provider_map.write().unwrap();
233
234 if let Some(provider) = map.get(&model) {
236 return Ok(provider.clone());
237 }
238
239 let provider_instance: Arc<dyn Provider> = match auth {
241 Some(ApiAuth::OAuth) => {
242 if provider_kind == ProviderKind::Anthropic {
243 let storage = self.config_provider.auth_storage();
244 Arc::new(AnthropicClient::with_oauth(storage.clone()))
245 } else {
246 return Err(crate::error::Error::Api(ApiError::Configuration(format!(
247 "OAuth is not supported for {provider_kind:?} provider"
248 ))));
249 }
250 }
251 Some(ApiAuth::Key(key)) => match provider_kind {
252 ProviderKind::Anthropic => Arc::new(AnthropicClient::with_api_key(&key)),
253 ProviderKind::OpenAI => Arc::new(OpenAIClient::new(key)),
254 ProviderKind::Google => Arc::new(GeminiClient::new(&key)),
255 ProviderKind::XAI => Arc::new(XAIClient::new(key)),
256 },
257
258 None => {
259 return Err(crate::error::Error::Api(ApiError::Configuration(format!(
260 "No authentication configured for {provider_kind:?} needed by model {model:?}"
261 ))));
262 }
263 };
264 map.insert(model, provider_instance.clone());
265 Ok(provider_instance)
266 }
267
268 pub async fn complete(
269 &self,
270 model: Model,
271 messages: Vec<Message>,
272 system: Option<String>,
273 tools: Option<Vec<ToolSchema>>,
274 token: CancellationToken,
275 ) -> std::result::Result<CompletionResponse, ApiError> {
276 let provider = self
277 .get_or_create_provider(model)
278 .await
279 .map_err(ApiError::from)?;
280
281 if token.is_cancelled() {
282 return Err(ApiError::Cancelled {
283 provider: provider.name().to_string(),
284 });
285 }
286
287 provider
288 .complete(model, messages, system, tools, token)
289 .await
290 }
291
292 pub async fn complete_with_retry(
293 &self,
294 model: Model,
295 messages: &[Message],
296 system_prompt: &Option<String>,
297 tools: &Option<Vec<ToolSchema>>,
298 token: CancellationToken,
299 max_attempts: usize,
300 ) -> std::result::Result<CompletionResponse, ApiError> {
301 let mut attempts = 0;
302 debug!(
303 target: "api::complete",
304 model =% model,
305 "system: {:?}",
306 system_prompt
307 );
308 debug!(
309 target: "api::complete",
310 model =% model,
311 "messages: {:?}",
312 messages
313 );
314 loop {
315 match self
316 .complete(
317 model,
318 messages.to_vec(),
319 system_prompt.clone(),
320 tools.clone(),
321 token.clone(),
322 )
323 .await
324 {
325 Ok(response) => {
326 return Ok(response);
327 }
328 Err(error) => {
329 attempts += 1;
330 warn!(
331 "API completion attempt {}/{} failed for model {}: {:?}",
332 attempts,
333 max_attempts,
334 model.as_ref(),
335 error
336 );
337
338 if attempts >= max_attempts {
339 return Err(error);
340 }
341
342 match error {
343 ApiError::RateLimited { provider, details } => {
344 let sleep_duration =
345 std::time::Duration::from_secs(1 << (attempts - 1));
346 warn!(
347 "Rate limited by API: {} {} (retrying in {} seconds)",
348 provider,
349 details,
350 sleep_duration.as_secs()
351 );
352 tokio::time::sleep(sleep_duration).await;
353 }
354 ApiError::NoChoices { provider } => {
355 warn!("No choices returned from API: {}", provider);
356 }
357 ApiError::ServerError {
358 provider,
359 status_code,
360 details,
361 } => {
362 warn!(
363 "Server error for API: {} {} {}",
364 provider, status_code, details
365 );
366 }
367 _ => {
368 return Err(error);
370 }
371 }
372 }
373 }
374 }
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    #[test]
    fn test_model_from_str() {
        // Full dated identifier parses to the matching variant.
        assert_eq!(
            Model::from_str("claude-3-7-sonnet-20250219").unwrap(),
            Model::Claude3_7Sonnet20250219
        );
    }

    #[test]
    fn test_model_aliases() {
        // Table of (input string, expected model): covers both the short
        // CLI aliases and the full dated identifiers.
        let cases: &[(&str, Model)] = &[
            ("sonnet", Model::ClaudeSonnet4_20250514),
            ("opus", Model::ClaudeOpus4_20250514),
            ("o3", Model::O3_20250416),
            ("o3-pro", Model::O3Pro20250610),
            ("gemini", Model::Gemini2_5ProPreview0605),
            ("grok", Model::Grok4_0709),
            ("grok-mini", Model::Grok3Mini),
            ("claude-sonnet-4-20250514", Model::ClaudeSonnet4_20250514),
            ("o3-2025-04-16", Model::O3_20250416),
            ("o4-mini-2025-04-16", Model::O4Mini20250416),
            ("grok-3", Model::Grok3),
            ("grok-4-0709", Model::Grok4_0709),
            ("grok-3-mini", Model::Grok3Mini),
        ];
        for (input, expected) in cases {
            assert_eq!(
                Model::from_str(input).unwrap(),
                *expected,
                "failed to parse {input:?}"
            );
        }
    }
}