1pub mod claude;
2pub mod error;
3pub mod gemini;
4pub mod openai;
5pub mod provider;
6pub mod xai;
7
8use crate::config::{ApiAuth, LlmConfigProvider};
9use crate::error::Result;
10pub use claude::AnthropicClient;
11pub use error::ApiError;
12pub use gemini::GeminiClient;
13pub use openai::OpenAIClient;
14pub use provider::{CompletionResponse, Provider};
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::RwLock;
18pub use steer_tools::{InputSchema, ToolCall, ToolSchema};
19use strum::Display;
20use strum::EnumIter;
21use strum::IntoStaticStr;
22use strum_macros::{AsRefStr, EnumString};
23use tokio_util::sync::CancellationToken;
24use tracing::debug;
25use tracing::warn;
26pub use xai::XAIClient;
27
28use crate::app::conversation::Message;
29
/// The backend service that hosts a [`Model`].
///
/// `Display` / `IntoStaticStr` render variants lowercased via
/// `serialize_all = "lowercase"` (e.g. `Anthropic` -> "anthropic");
/// `XAI` additionally pins its serialization to "xai" explicitly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Display, IntoStaticStr)]
#[strum(serialize_all = "lowercase")]
pub enum ProviderKind {
    Anthropic,
    OpenAI,
    Google,
    #[strum(serialize = "xai")]
    XAI,
}
39
40impl ProviderKind {
41 pub fn display_name(&self) -> String {
42 match self {
43 ProviderKind::Anthropic => "Anthropic".to_string(),
44 ProviderKind::OpenAI => "OpenAI".to_string(),
45 ProviderKind::Google => "Google".to_string(),
46 ProviderKind::XAI => "xAI".to_string(),
47 }
48 }
49}
50
/// Every model identifier the client knows how to talk to.
///
/// Each variant's first `#[strum(serialize = "...")]` value is the canonical
/// dated model id; any additional `serialize` values are short aliases
/// accepted by `Model::from_str` (e.g. "sonnet", "opus", "gpt-5").
/// The default model is `ClaudeOpus4_1_20250805`.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    EnumIter,
    EnumString,
    AsRefStr,
    Display,
    IntoStaticStr,
    serde::Serialize,
    serde::Deserialize,
    Default,
)]
pub enum Model {
    #[strum(serialize = "claude-3-5-sonnet-20240620")]
    Claude3_5Sonnet20240620,
    #[strum(serialize = "claude-3-5-sonnet-20241022")]
    Claude3_5Sonnet20241022,
    #[strum(serialize = "claude-3-7-sonnet-20250219")]
    Claude3_7Sonnet20250219,
    #[strum(serialize = "claude-3-5-haiku-20241022")]
    Claude3_5Haiku20241022,
    #[strum(serialize = "claude-sonnet-4-20250514", serialize = "sonnet")]
    ClaudeSonnet4_20250514,
    #[strum(serialize = "claude-opus-4-20250514", serialize = "opus-4")]
    ClaudeOpus4_20250514,
    // Bare "opus" points at the newest Opus snapshot.
    #[strum(
        serialize = "claude-opus-4-1-20250805",
        serialize = "opus",
        serialize = "opus-4-1"
    )]
    #[default]
    ClaudeOpus4_1_20250805,
    #[strum(serialize = "gpt-4.1-2025-04-14")]
    Gpt4_1_20250414,
    #[strum(serialize = "gpt-4.1-mini-2025-04-14")]
    Gpt4_1Mini20250414,
    #[strum(serialize = "gpt-4.1-nano-2025-04-14")]
    Gpt4_1Nano20250414,
    #[strum(serialize = "gpt-5-2025-08-07", serialize = "gpt-5")]
    Gpt5_20250807,
    #[strum(serialize = "o3-2025-04-16", serialize = "o3")]
    O3_20250416,
    #[strum(serialize = "o3-pro-2025-06-10", serialize = "o3-pro")]
    O3Pro20250610,
    #[strum(serialize = "o4-mini-2025-04-16", serialize = "o4-mini")]
    O4Mini20250416,
    #[strum(serialize = "gemini-2.5-flash-preview-04-17")]
    Gemini2_5FlashPreview0417,
    #[strum(serialize = "gemini-2.5-pro-preview-05-06")]
    Gemini2_5ProPreview0506,
    // Bare "gemini" points at the newest Pro preview snapshot.
    #[strum(serialize = "gemini-2.5-pro-preview-06-05", serialize = "gemini")]
    Gemini2_5ProPreview0605,
    #[strum(serialize = "grok-3")]
    Grok3,
    #[strum(serialize = "grok-3-mini", serialize = "grok-mini")]
    Grok3Mini,
    // Bare "grok" points at Grok 4.
    #[strum(serialize = "grok-4-0709", serialize = "grok")]
    Grok4_0709,
}
114
115impl Model {
116 pub fn should_show(&self) -> bool {
118 matches!(
119 self,
120 Model::ClaudeOpus4_20250514
121 | Model::ClaudeOpus4_1_20250805
122 | Model::ClaudeSonnet4_20250514
123 | Model::O3_20250416
124 | Model::O3Pro20250610
125 | Model::Gemini2_5ProPreview0605
126 | Model::Grok4_0709
127 | Model::Grok3
128 | Model::Gpt4_1_20250414
129 | Model::Gpt5_20250807
130 | Model::O4Mini20250416
131 )
132 }
133
134 pub fn iter_recommended() -> impl Iterator<Item = Model> {
135 use strum::IntoEnumIterator;
136 Model::iter().filter(|m| m.should_show())
137 }
138
139 pub fn provider(&self) -> ProviderKind {
140 match self {
141 Model::Claude3_7Sonnet20250219
142 | Model::Claude3_5Sonnet20240620
143 | Model::Claude3_5Sonnet20241022
144 | Model::Claude3_5Haiku20241022
145 | Model::ClaudeSonnet4_20250514
146 | Model::ClaudeOpus4_20250514
147 | Model::ClaudeOpus4_1_20250805 => ProviderKind::Anthropic,
148
149 Model::Gpt4_1_20250414
150 | Model::Gpt4_1Mini20250414
151 | Model::Gpt4_1Nano20250414
152 | Model::Gpt5_20250807
153 | Model::O3_20250416
154 | Model::O3Pro20250610
155 | Model::O4Mini20250416 => ProviderKind::OpenAI,
156
157 Model::Gemini2_5FlashPreview0417
158 | Model::Gemini2_5ProPreview0506
159 | Model::Gemini2_5ProPreview0605 => ProviderKind::Google,
160
161 Model::Grok3 | Model::Grok3Mini | Model::Grok4_0709 => ProviderKind::XAI,
162 }
163 }
164
165 pub fn aliases(&self) -> Vec<&'static str> {
166 match self {
167 Model::ClaudeSonnet4_20250514 => vec!["sonnet"],
168 Model::ClaudeOpus4_20250514 => vec!["opus-4-0"],
169 Model::ClaudeOpus4_1_20250805 => vec!["opus-4-1", "opus"],
170 Model::O3_20250416 => vec!["o3"],
171 Model::O3Pro20250610 => vec!["o3-pro"],
172 Model::O4Mini20250416 => vec!["o4-mini"],
173 Model::Gemini2_5ProPreview0605 => vec!["gemini"],
174 Model::Grok3 => vec![],
175 Model::Grok3Mini => vec!["grok-mini"],
176 Model::Grok4_0709 => vec!["grok"],
177 Model::Gpt5_20250807 => vec!["gpt-5"],
178 _ => vec![],
179 }
180 }
181
182 pub fn supports_thinking(&self) -> bool {
183 matches!(
184 self,
185 Model::Claude3_7Sonnet20250219
186 | Model::ClaudeSonnet4_20250514
187 | Model::ClaudeOpus4_20250514
188 | Model::ClaudeOpus4_1_20250805
189 | Model::Gpt5_20250807
190 | Model::O3_20250416
191 | Model::O3Pro20250610
192 | Model::O4Mini20250416
193 | Model::Gemini2_5FlashPreview0417
194 | Model::Gemini2_5ProPreview0506
195 | Model::Gemini2_5ProPreview0605
196 | Model::Grok3Mini
197 | Model::Grok4_0709
198 )
199 }
200
201 pub fn default_system_prompt_file(&self) -> Option<&'static str> {
202 match self {
203 Model::O3_20250416 => Some("models/o3.md"),
204 Model::O3Pro20250610 => Some("models/o3.md"),
205 Model::O4Mini20250416 => Some("models/o3.md"),
206 _ => None,
207 }
208 }
209
210 pub fn all() -> Vec<Model> {
212 use strum::IntoEnumIterator;
213 Model::iter().collect()
214 }
215}
216
/// Cloneable entry point for issuing completions against any supported model.
///
/// Clones share the same provider cache and config via `Arc`.
#[derive(Clone)]
pub struct Client {
    // Lazily-populated cache of provider clients, keyed by model.
    // Populated by `get_or_create_provider`.
    provider_map: Arc<RwLock<HashMap<Model, Arc<dyn Provider>>>>,
    // Source of per-provider authentication used when building a client.
    config_provider: LlmConfigProvider,
}
222
impl Client {
    /// Creates a client with an empty provider cache; providers are built
    /// on first use per model.
    pub fn new_with_provider(provider: LlmConfigProvider) -> Self {
        Self {
            provider_map: Arc::new(RwLock::new(HashMap::new())),
            config_provider: provider,
        }
    }

    /// Returns the cached provider for `model`, creating and caching it on
    /// first use.
    ///
    /// Uses a double-checked pattern: a read-locked fast path, then the
    /// (awaited) auth lookup is done *before* taking the write lock so no
    /// lock guard is held across an `.await`; the cache is re-checked under
    /// the write lock in case another task inserted first.
    ///
    /// # Errors
    /// Returns a configuration error if no auth is configured for the
    /// model's provider, or if OAuth is requested for a non-Anthropic
    /// provider.
    async fn get_or_create_provider(&self, model: Model) -> Result<Arc<dyn Provider>> {
        // Fast path: provider already cached. Scoped so the read guard is
        // dropped before any await below.
        {
            let map = self.provider_map.read().unwrap();
            if let Some(provider) = map.get(&model) {
                return Ok(provider.clone());
            }
        }

        // Resolve auth without holding any lock (this awaits).
        let provider_kind = model.provider();
        let auth = self
            .config_provider
            .get_auth_for_provider(provider_kind)
            .await?;

        let mut map = self.provider_map.write().unwrap();

        // Re-check: another task may have created the provider while we
        // were fetching auth.
        if let Some(provider) = map.get(&model) {
            return Ok(provider.clone());
        }

        let provider_instance: Arc<dyn Provider> = match auth {
            Some(ApiAuth::OAuth) => {
                // OAuth is only wired up for the Anthropic client.
                if provider_kind == ProviderKind::Anthropic {
                    let storage = self.config_provider.auth_storage();
                    Arc::new(AnthropicClient::with_oauth(storage.clone()))
                } else {
                    return Err(crate::error::Error::Api(ApiError::Configuration(format!(
                        "OAuth is not supported for {provider_kind:?} provider"
                    ))));
                }
            }
            Some(ApiAuth::Key(key)) => match provider_kind {
                ProviderKind::Anthropic => Arc::new(AnthropicClient::with_api_key(&key)),
                ProviderKind::OpenAI => Arc::new(OpenAIClient::new(key)),
                ProviderKind::Google => Arc::new(GeminiClient::new(&key)),
                ProviderKind::XAI => Arc::new(XAIClient::new(key)),
            },

            None => {
                return Err(crate::error::Error::Api(ApiError::Configuration(format!(
                    "No authentication configured for {provider_kind:?} needed by model {model:?}"
                ))));
            }
        };
        map.insert(model, provider_instance.clone());
        Ok(provider_instance)
    }

    /// Issues a single completion request against `model`.
    ///
    /// Checks `token` for cancellation after resolving the provider but
    /// before sending; the token is also passed through to the provider.
    ///
    /// # Errors
    /// Propagates provider-creation failures (as `ApiError`), cancellation,
    /// and any error returned by the provider's `complete`.
    pub async fn complete(
        &self,
        model: Model,
        messages: Vec<Message>,
        system: Option<String>,
        tools: Option<Vec<ToolSchema>>,
        token: CancellationToken,
    ) -> std::result::Result<CompletionResponse, ApiError> {
        let provider = self
            .get_or_create_provider(model)
            .await
            .map_err(ApiError::from)?;

        if token.is_cancelled() {
            return Err(ApiError::Cancelled {
                provider: provider.name().to_string(),
            });
        }

        provider
            .complete(model, messages, system, tools, token)
            .await
    }

    /// Like [`Client::complete`] but retries up to `max_attempts` times on
    /// retryable errors.
    ///
    /// Retry policy:
    /// - `RateLimited`: retried after exponential backoff (1, 2, 4, ...
    ///   seconds, doubling per attempt).
    /// - `NoChoices` and `ServerError`: retried immediately (logged only).
    /// - Any other error returns immediately without retrying.
    ///
    /// # Errors
    /// Returns the last error once `max_attempts` is exhausted, or the
    /// first non-retryable error.
    pub async fn complete_with_retry(
        &self,
        model: Model,
        messages: &[Message],
        system_prompt: &Option<String>,
        tools: &Option<Vec<ToolSchema>>,
        token: CancellationToken,
        max_attempts: usize,
    ) -> std::result::Result<CompletionResponse, ApiError> {
        let mut attempts = 0;
        debug!(
            target: "api::complete",
            model =% model,
            "system: {:?}",
            system_prompt
        );
        debug!(
            target: "api::complete",
            model =% model,
            "messages: {:?}",
            messages
        );
        loop {
            match self
                .complete(
                    model,
                    messages.to_vec(),
                    system_prompt.clone(),
                    tools.clone(),
                    token.clone(),
                )
                .await
            {
                Ok(response) => {
                    return Ok(response);
                }
                Err(error) => {
                    attempts += 1;
                    warn!(
                        "API completion attempt {}/{} failed for model {}: {:?}",
                        attempts,
                        max_attempts,
                        model.as_ref(),
                        error
                    );

                    if attempts >= max_attempts {
                        return Err(error);
                    }

                    match error {
                        ApiError::RateLimited { provider, details } => {
                            // Exponential backoff: 2^(attempts-1) seconds.
                            let sleep_duration =
                                std::time::Duration::from_secs(1 << (attempts - 1));
                            warn!(
                                "Rate limited by API: {} {} (retrying in {} seconds)",
                                provider,
                                details,
                                sleep_duration.as_secs()
                            );
                            tokio::time::sleep(sleep_duration).await;
                        }
                        ApiError::NoChoices { provider } => {
                            warn!("No choices returned from API: {}", provider);
                        }
                        ApiError::ServerError {
                            provider,
                            status_code,
                            details,
                        } => {
                            warn!(
                                "Server error for API: {} {} {}",
                                provider, status_code, details
                            );
                        }
                        // Everything else is treated as non-retryable.
                        _ => {
                            return Err(error);
                        }
                    }
                }
            }
        }
    }
}
393
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// A full dated identifier parses to its matching variant.
    #[test]
    fn test_model_from_str() {
        assert_eq!(
            Model::from_str("claude-3-7-sonnet-20250219").unwrap(),
            Model::Claude3_7Sonnet20250219
        );
    }

    /// Short aliases and canonical ids both resolve via `FromStr`.
    #[test]
    fn test_model_aliases() {
        let cases: &[(&str, Model)] = &[
            // Short aliases.
            ("sonnet", Model::ClaudeSonnet4_20250514),
            ("opus", Model::ClaudeOpus4_1_20250805),
            ("o3", Model::O3_20250416),
            ("o3-pro", Model::O3Pro20250610),
            ("gemini", Model::Gemini2_5ProPreview0605),
            ("grok", Model::Grok4_0709),
            ("grok-mini", Model::Grok3Mini),
            // Canonical dated identifiers.
            ("claude-sonnet-4-20250514", Model::ClaudeSonnet4_20250514),
            ("o3-2025-04-16", Model::O3_20250416),
            ("o4-mini-2025-04-16", Model::O4Mini20250416),
            ("grok-3", Model::Grok3),
            ("grok-4-0709", Model::Grok4_0709),
            ("grok-3-mini", Model::Grok3Mini),
            ("gpt-5-2025-08-07", Model::Gpt5_20250807),
            ("gpt-5", Model::Gpt5_20250807),
        ];

        for (input, expected) in cases {
            assert_eq!(
                &Model::from_str(input).unwrap(),
                expected,
                "alias {input:?} did not resolve"
            );
        }
    }
}
450}