1pub mod claude;
2pub mod error;
3pub mod gemini;
4pub mod openai;
5pub mod provider;
6pub mod xai;
7
8use crate::config::{ApiAuth, LlmConfigProvider};
9use crate::error::Result;
10pub use claude::AnthropicClient;
11pub use error::ApiError;
12pub use gemini::GeminiClient;
13pub use openai::OpenAIClient;
14pub use provider::{CompletionResponse, Provider};
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::RwLock;
18pub use steer_tools::{InputSchema, ToolCall, ToolSchema};
19use strum::Display;
20use strum::EnumIter;
21use strum::IntoStaticStr;
22use strum_macros::{AsRefStr, EnumString};
23use tokio_util::sync::CancellationToken;
24use tracing::debug;
25use tracing::warn;
26pub use xai::XAIClient;
27
28use crate::app::conversation::Message;
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Display, IntoStaticStr)]
31#[strum(serialize_all = "lowercase")]
32pub enum ProviderKind {
33 Anthropic,
34 OpenAI,
35 Google,
36 #[strum(serialize = "xai")]
37 XAI,
38}
39
40impl ProviderKind {
41 pub fn display_name(&self) -> String {
42 match self {
43 ProviderKind::Anthropic => "Anthropic".to_string(),
44 ProviderKind::OpenAI => "OpenAI".to_string(),
45 ProviderKind::Google => "Google".to_string(),
46 ProviderKind::XAI => "xAI".to_string(),
47 }
48 }
49}
50
51#[derive(
52 Debug,
53 Clone,
54 Copy,
55 PartialEq,
56 Eq,
57 Hash,
58 EnumIter,
59 EnumString,
60 AsRefStr,
61 Display,
62 IntoStaticStr,
63 serde::Serialize,
64 serde::Deserialize,
65 Default,
66)]
67pub enum Model {
68 #[strum(serialize = "claude-3-5-sonnet-20240620")]
69 Claude3_5Sonnet20240620,
70 #[strum(serialize = "claude-3-5-sonnet-20241022")]
71 Claude3_5Sonnet20241022,
72 #[strum(serialize = "claude-3-7-sonnet-20250219")]
73 Claude3_7Sonnet20250219,
74 #[strum(serialize = "claude-3-5-haiku-20241022")]
75 Claude3_5Haiku20241022,
76 #[strum(serialize = "claude-sonnet-4-20250514", serialize = "sonnet")]
77 ClaudeSonnet4_20250514,
78 #[strum(serialize = "claude-opus-4-20250514", serialize = "opus-4")]
79 ClaudeOpus4_20250514,
80 #[strum(
81 serialize = "claude-opus-4-1-20250805",
82 serialize = "opus",
83 serialize = "opus-4-1"
84 )]
85 #[default]
86 ClaudeOpus4_1_20250805,
87 #[strum(serialize = "gpt-4.1-2025-04-14")]
88 Gpt4_1_20250414,
89 #[strum(serialize = "gpt-4.1-mini-2025-04-14")]
90 Gpt4_1Mini20250414,
91 #[strum(serialize = "gpt-4.1-nano-2025-04-14")]
92 Gpt4_1Nano20250414,
93 #[strum(serialize = "o3-2025-04-16", serialize = "o3")]
94 O3_20250416,
95 #[strum(serialize = "o3-pro-2025-06-10", serialize = "o3-pro")]
96 O3Pro20250610,
97 #[strum(serialize = "o4-mini-2025-04-16", serialize = "o4-mini")]
98 O4Mini20250416,
99 #[strum(serialize = "gemini-2.5-flash-preview-04-17")]
100 Gemini2_5FlashPreview0417,
101 #[strum(serialize = "gemini-2.5-pro-preview-05-06")]
102 Gemini2_5ProPreview0506,
103 #[strum(serialize = "gemini-2.5-pro-preview-06-05", serialize = "gemini")]
104 Gemini2_5ProPreview0605,
105 #[strum(serialize = "grok-3")]
106 Grok3,
107 #[strum(serialize = "grok-3-mini", serialize = "grok-mini")]
108 Grok3Mini,
109 #[strum(serialize = "grok-4-0709", serialize = "grok")]
110 Grok4_0709,
111}
112
113impl Model {
114 pub fn should_show(&self) -> bool {
116 matches!(
117 self,
118 Model::ClaudeOpus4_20250514
119 | Model::ClaudeOpus4_1_20250805
120 | Model::ClaudeSonnet4_20250514
121 | Model::O3_20250416
122 | Model::O3Pro20250610
123 | Model::Gemini2_5ProPreview0605
124 | Model::Grok4_0709
125 | Model::Grok3
126 | Model::Gpt4_1_20250414
127 | Model::O4Mini20250416
128 )
129 }
130
131 pub fn iter_recommended() -> impl Iterator<Item = Model> {
132 use strum::IntoEnumIterator;
133 Model::iter().filter(|m| m.should_show())
134 }
135
136 pub fn provider(&self) -> ProviderKind {
137 match self {
138 Model::Claude3_7Sonnet20250219
139 | Model::Claude3_5Sonnet20240620
140 | Model::Claude3_5Sonnet20241022
141 | Model::Claude3_5Haiku20241022
142 | Model::ClaudeSonnet4_20250514
143 | Model::ClaudeOpus4_20250514
144 | Model::ClaudeOpus4_1_20250805 => ProviderKind::Anthropic,
145
146 Model::Gpt4_1_20250414
147 | Model::Gpt4_1Mini20250414
148 | Model::Gpt4_1Nano20250414
149 | Model::O3_20250416
150 | Model::O3Pro20250610
151 | Model::O4Mini20250416 => ProviderKind::OpenAI,
152
153 Model::Gemini2_5FlashPreview0417
154 | Model::Gemini2_5ProPreview0506
155 | Model::Gemini2_5ProPreview0605 => ProviderKind::Google,
156
157 Model::Grok3 | Model::Grok3Mini | Model::Grok4_0709 => ProviderKind::XAI,
158 }
159 }
160
161 pub fn aliases(&self) -> Vec<&'static str> {
162 match self {
163 Model::ClaudeSonnet4_20250514 => vec!["sonnet"],
164 Model::ClaudeOpus4_20250514 => vec!["opus-4-0"],
165 Model::ClaudeOpus4_1_20250805 => vec!["opus-4-1", "opus"],
166 Model::O3_20250416 => vec!["o3"],
167 Model::O3Pro20250610 => vec!["o3-pro"],
168 Model::O4Mini20250416 => vec!["o4-mini"],
169 Model::Gemini2_5ProPreview0605 => vec!["gemini"],
170 Model::Grok3 => vec![],
171 Model::Grok3Mini => vec!["grok-mini"],
172 Model::Grok4_0709 => vec!["grok"],
173 _ => vec![],
174 }
175 }
176
177 pub fn supports_thinking(&self) -> bool {
178 matches!(
179 self,
180 Model::Claude3_7Sonnet20250219
181 | Model::ClaudeSonnet4_20250514
182 | Model::ClaudeOpus4_20250514
183 | Model::ClaudeOpus4_1_20250805
184 | Model::O3_20250416
185 | Model::O3Pro20250610
186 | Model::O4Mini20250416
187 | Model::Gemini2_5FlashPreview0417
188 | Model::Gemini2_5ProPreview0506
189 | Model::Gemini2_5ProPreview0605
190 | Model::Grok3Mini
191 | Model::Grok4_0709
192 )
193 }
194
195 pub fn default_system_prompt_file(&self) -> Option<&'static str> {
196 match self {
197 Model::O3_20250416 => Some("models/o3.md"),
198 Model::O3Pro20250610 => Some("models/o3.md"),
199 Model::O4Mini20250416 => Some("models/o3.md"),
200 _ => None,
201 }
202 }
203
204 pub fn all() -> Vec<Model> {
206 use strum::IntoEnumIterator;
207 Model::iter().collect()
208 }
209}
210
211#[derive(Clone)]
212pub struct Client {
213 provider_map: Arc<RwLock<HashMap<Model, Arc<dyn Provider>>>>,
214 config_provider: LlmConfigProvider,
215}
216
217impl Client {
218 pub fn new_with_provider(provider: LlmConfigProvider) -> Self {
219 Self {
220 provider_map: Arc::new(RwLock::new(HashMap::new())),
221 config_provider: provider,
222 }
223 }
224
225 async fn get_or_create_provider(&self, model: Model) -> Result<Arc<dyn Provider>> {
226 {
228 let map = self.provider_map.read().unwrap();
229 if let Some(provider) = map.get(&model) {
230 return Ok(provider.clone());
231 }
232 }
233
234 let provider_kind = model.provider();
236 let auth = self
237 .config_provider
238 .get_auth_for_provider(provider_kind)
239 .await?;
240
241 let mut map = self.provider_map.write().unwrap();
243
244 if let Some(provider) = map.get(&model) {
246 return Ok(provider.clone());
247 }
248
249 let provider_instance: Arc<dyn Provider> = match auth {
251 Some(ApiAuth::OAuth) => {
252 if provider_kind == ProviderKind::Anthropic {
253 let storage = self.config_provider.auth_storage();
254 Arc::new(AnthropicClient::with_oauth(storage.clone()))
255 } else {
256 return Err(crate::error::Error::Api(ApiError::Configuration(format!(
257 "OAuth is not supported for {provider_kind:?} provider"
258 ))));
259 }
260 }
261 Some(ApiAuth::Key(key)) => match provider_kind {
262 ProviderKind::Anthropic => Arc::new(AnthropicClient::with_api_key(&key)),
263 ProviderKind::OpenAI => Arc::new(OpenAIClient::new(key)),
264 ProviderKind::Google => Arc::new(GeminiClient::new(&key)),
265 ProviderKind::XAI => Arc::new(XAIClient::new(key)),
266 },
267
268 None => {
269 return Err(crate::error::Error::Api(ApiError::Configuration(format!(
270 "No authentication configured for {provider_kind:?} needed by model {model:?}"
271 ))));
272 }
273 };
274 map.insert(model, provider_instance.clone());
275 Ok(provider_instance)
276 }
277
278 pub async fn complete(
279 &self,
280 model: Model,
281 messages: Vec<Message>,
282 system: Option<String>,
283 tools: Option<Vec<ToolSchema>>,
284 token: CancellationToken,
285 ) -> std::result::Result<CompletionResponse, ApiError> {
286 let provider = self
287 .get_or_create_provider(model)
288 .await
289 .map_err(ApiError::from)?;
290
291 if token.is_cancelled() {
292 return Err(ApiError::Cancelled {
293 provider: provider.name().to_string(),
294 });
295 }
296
297 provider
298 .complete(model, messages, system, tools, token)
299 .await
300 }
301
302 pub async fn complete_with_retry(
303 &self,
304 model: Model,
305 messages: &[Message],
306 system_prompt: &Option<String>,
307 tools: &Option<Vec<ToolSchema>>,
308 token: CancellationToken,
309 max_attempts: usize,
310 ) -> std::result::Result<CompletionResponse, ApiError> {
311 let mut attempts = 0;
312 debug!(
313 target: "api::complete",
314 model =% model,
315 "system: {:?}",
316 system_prompt
317 );
318 debug!(
319 target: "api::complete",
320 model =% model,
321 "messages: {:?}",
322 messages
323 );
324 loop {
325 match self
326 .complete(
327 model,
328 messages.to_vec(),
329 system_prompt.clone(),
330 tools.clone(),
331 token.clone(),
332 )
333 .await
334 {
335 Ok(response) => {
336 return Ok(response);
337 }
338 Err(error) => {
339 attempts += 1;
340 warn!(
341 "API completion attempt {}/{} failed for model {}: {:?}",
342 attempts,
343 max_attempts,
344 model.as_ref(),
345 error
346 );
347
348 if attempts >= max_attempts {
349 return Err(error);
350 }
351
352 match error {
353 ApiError::RateLimited { provider, details } => {
354 let sleep_duration =
355 std::time::Duration::from_secs(1 << (attempts - 1));
356 warn!(
357 "Rate limited by API: {} {} (retrying in {} seconds)",
358 provider,
359 details,
360 sleep_duration.as_secs()
361 );
362 tokio::time::sleep(sleep_duration).await;
363 }
364 ApiError::NoChoices { provider } => {
365 warn!("No choices returned from API: {}", provider);
366 }
367 ApiError::ServerError {
368 provider,
369 status_code,
370 details,
371 } => {
372 warn!(
373 "Server error for API: {} {} {}",
374 provider, status_code, details
375 );
376 }
377 _ => {
378 return Err(error);
380 }
381 }
382 }
383 }
384 }
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Asserts that each input string parses to the model paired with it.
    fn assert_parses(cases: &[(&str, Model)]) {
        for &(input, expected) in cases {
            assert_eq!(Model::from_str(input).unwrap(), expected);
        }
    }

    #[test]
    fn test_model_from_str() {
        assert_parses(&[("claude-3-7-sonnet-20250219", Model::Claude3_7Sonnet20250219)]);
    }

    #[test]
    fn test_model_aliases() {
        // Short aliases resolve to their pinned model.
        assert_parses(&[
            ("sonnet", Model::ClaudeSonnet4_20250514),
            ("opus", Model::ClaudeOpus4_1_20250805),
            ("o3", Model::O3_20250416),
            ("o3-pro", Model::O3Pro20250610),
            ("gemini", Model::Gemini2_5ProPreview0605),
            ("grok", Model::Grok4_0709),
            ("grok-mini", Model::Grok3Mini),
        ]);

        // Full identifiers keep working alongside the aliases.
        assert_parses(&[
            ("claude-sonnet-4-20250514", Model::ClaudeSonnet4_20250514),
            ("o3-2025-04-16", Model::O3_20250416),
            ("o4-mini-2025-04-16", Model::O4Mini20250416),
            ("grok-3", Model::Grok3),
            ("grok-4-0709", Model::Grok4_0709),
            ("grok-3-mini", Model::Grok3Mini),
        ]);
    }
}