steer_core/api/
mod.rs

1pub mod claude;
2pub mod error;
3pub mod gemini;
4pub mod openai;
5pub mod provider;
6pub mod xai;
7
8use crate::config::{ApiAuth, LlmConfigProvider};
9use crate::error::Result;
10pub use claude::AnthropicClient;
11pub use error::ApiError;
12pub use gemini::GeminiClient;
13pub use openai::OpenAIClient;
14pub use provider::{CompletionResponse, Provider};
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::RwLock;
18pub use steer_tools::{InputSchema, ToolCall, ToolSchema};
19use strum::Display;
20use strum::EnumIter;
21use strum::IntoStaticStr;
22use strum_macros::{AsRefStr, EnumString};
23use tokio_util::sync::CancellationToken;
24use tracing::debug;
25use tracing::warn;
26pub use xai::XAIClient;
27
28use crate::app::conversation::Message;
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Display, IntoStaticStr)]
31#[strum(serialize_all = "lowercase")]
32pub enum ProviderKind {
33    Anthropic,
34    OpenAI,
35    Google,
36    #[strum(serialize = "xai")]
37    XAI,
38}
39
40impl ProviderKind {
41    pub fn display_name(&self) -> String {
42        match self {
43            ProviderKind::Anthropic => "Anthropic".to_string(),
44            ProviderKind::OpenAI => "OpenAI".to_string(),
45            ProviderKind::Google => "Google".to_string(),
46            ProviderKind::XAI => "xAI".to_string(),
47        }
48    }
49}
50
51#[derive(
52    Debug,
53    Clone,
54    Copy,
55    PartialEq,
56    Eq,
57    Hash,
58    EnumIter,
59    EnumString,
60    AsRefStr,
61    Display,
62    IntoStaticStr,
63    serde::Serialize,
64    serde::Deserialize,
65    Default,
66)]
67pub enum Model {
68    #[strum(serialize = "claude-3-5-sonnet-20240620")]
69    Claude3_5Sonnet20240620,
70    #[strum(serialize = "claude-3-5-sonnet-20241022")]
71    Claude3_5Sonnet20241022,
72    #[strum(serialize = "claude-3-7-sonnet-20250219")]
73    Claude3_7Sonnet20250219,
74    #[strum(serialize = "claude-3-5-haiku-20241022")]
75    Claude3_5Haiku20241022,
76    #[strum(serialize = "claude-sonnet-4-20250514", serialize = "sonnet")]
77    ClaudeSonnet4_20250514,
78    #[strum(serialize = "claude-opus-4-20250514", serialize = "opus-4")]
79    ClaudeOpus4_20250514,
80    #[strum(
81        serialize = "claude-opus-4-1-20250805",
82        serialize = "opus",
83        serialize = "opus-4-1"
84    )]
85    #[default]
86    ClaudeOpus4_1_20250805,
87    #[strum(serialize = "gpt-4.1-2025-04-14")]
88    Gpt4_1_20250414,
89    #[strum(serialize = "gpt-4.1-mini-2025-04-14")]
90    Gpt4_1Mini20250414,
91    #[strum(serialize = "gpt-4.1-nano-2025-04-14")]
92    Gpt4_1Nano20250414,
93    #[strum(serialize = "o3-2025-04-16", serialize = "o3")]
94    O3_20250416,
95    #[strum(serialize = "o3-pro-2025-06-10", serialize = "o3-pro")]
96    O3Pro20250610,
97    #[strum(serialize = "o4-mini-2025-04-16", serialize = "o4-mini")]
98    O4Mini20250416,
99    #[strum(serialize = "gemini-2.5-flash-preview-04-17")]
100    Gemini2_5FlashPreview0417,
101    #[strum(serialize = "gemini-2.5-pro-preview-05-06")]
102    Gemini2_5ProPreview0506,
103    #[strum(serialize = "gemini-2.5-pro-preview-06-05", serialize = "gemini")]
104    Gemini2_5ProPreview0605,
105    #[strum(serialize = "grok-3")]
106    Grok3,
107    #[strum(serialize = "grok-3-mini", serialize = "grok-mini")]
108    Grok3Mini,
109    #[strum(serialize = "grok-4-0709", serialize = "grok")]
110    Grok4_0709,
111}
112
113impl Model {
114    /// Returns true if this model should be shown in the model picker UI
115    pub fn should_show(&self) -> bool {
116        matches!(
117            self,
118            Model::ClaudeOpus4_20250514
119                | Model::ClaudeOpus4_1_20250805
120                | Model::ClaudeSonnet4_20250514
121                | Model::O3_20250416
122                | Model::O3Pro20250610
123                | Model::Gemini2_5ProPreview0605
124                | Model::Grok4_0709
125                | Model::Grok3
126                | Model::Gpt4_1_20250414
127                | Model::O4Mini20250416
128        )
129    }
130
131    pub fn iter_recommended() -> impl Iterator<Item = Model> {
132        use strum::IntoEnumIterator;
133        Model::iter().filter(|m| m.should_show())
134    }
135
136    pub fn provider(&self) -> ProviderKind {
137        match self {
138            Model::Claude3_7Sonnet20250219
139            | Model::Claude3_5Sonnet20240620
140            | Model::Claude3_5Sonnet20241022
141            | Model::Claude3_5Haiku20241022
142            | Model::ClaudeSonnet4_20250514
143            | Model::ClaudeOpus4_20250514
144            | Model::ClaudeOpus4_1_20250805 => ProviderKind::Anthropic,
145
146            Model::Gpt4_1_20250414
147            | Model::Gpt4_1Mini20250414
148            | Model::Gpt4_1Nano20250414
149            | Model::O3_20250416
150            | Model::O3Pro20250610
151            | Model::O4Mini20250416 => ProviderKind::OpenAI,
152
153            Model::Gemini2_5FlashPreview0417
154            | Model::Gemini2_5ProPreview0506
155            | Model::Gemini2_5ProPreview0605 => ProviderKind::Google,
156
157            Model::Grok3 | Model::Grok3Mini | Model::Grok4_0709 => ProviderKind::XAI,
158        }
159    }
160
161    pub fn aliases(&self) -> Vec<&'static str> {
162        match self {
163            Model::ClaudeSonnet4_20250514 => vec!["sonnet"],
164            Model::ClaudeOpus4_20250514 => vec!["opus-4-0"],
165            Model::ClaudeOpus4_1_20250805 => vec!["opus-4-1", "opus"],
166            Model::O3_20250416 => vec!["o3"],
167            Model::O3Pro20250610 => vec!["o3-pro"],
168            Model::O4Mini20250416 => vec!["o4-mini"],
169            Model::Gemini2_5ProPreview0605 => vec!["gemini"],
170            Model::Grok3 => vec![],
171            Model::Grok3Mini => vec!["grok-mini"],
172            Model::Grok4_0709 => vec!["grok"],
173            _ => vec![],
174        }
175    }
176
177    pub fn supports_thinking(&self) -> bool {
178        matches!(
179            self,
180            Model::Claude3_7Sonnet20250219
181                | Model::ClaudeSonnet4_20250514
182                | Model::ClaudeOpus4_20250514
183                | Model::ClaudeOpus4_1_20250805
184                | Model::O3_20250416
185                | Model::O3Pro20250610
186                | Model::O4Mini20250416
187                | Model::Gemini2_5FlashPreview0417
188                | Model::Gemini2_5ProPreview0506
189                | Model::Gemini2_5ProPreview0605
190                | Model::Grok3Mini
191                | Model::Grok4_0709
192        )
193    }
194
195    pub fn default_system_prompt_file(&self) -> Option<&'static str> {
196        match self {
197            Model::O3_20250416 => Some("models/o3.md"),
198            Model::O3Pro20250610 => Some("models/o3.md"),
199            Model::O4Mini20250416 => Some("models/o3.md"),
200            _ => None,
201        }
202    }
203
204    /// Get all available models
205    pub fn all() -> Vec<Model> {
206        use strum::IntoEnumIterator;
207        Model::iter().collect()
208    }
209}
210
211#[derive(Clone)]
212pub struct Client {
213    provider_map: Arc<RwLock<HashMap<Model, Arc<dyn Provider>>>>,
214    config_provider: LlmConfigProvider,
215}
216
217impl Client {
218    pub fn new_with_provider(provider: LlmConfigProvider) -> Self {
219        Self {
220            provider_map: Arc::new(RwLock::new(HashMap::new())),
221            config_provider: provider,
222        }
223    }
224
225    async fn get_or_create_provider(&self, model: Model) -> Result<Arc<dyn Provider>> {
226        // First check without holding the lock across await
227        {
228            let map = self.provider_map.read().unwrap();
229            if let Some(provider) = map.get(&model) {
230                return Ok(provider.clone());
231            }
232        }
233
234        // Get provider kind and auth before acquiring write lock
235        let provider_kind = model.provider();
236        let auth = self
237            .config_provider
238            .get_auth_for_provider(provider_kind)
239            .await?;
240
241        // Now acquire write lock and create provider
242        let mut map = self.provider_map.write().unwrap();
243
244        // Check again in case another thread added it
245        if let Some(provider) = map.get(&model) {
246            return Ok(provider.clone());
247        }
248
249        // Create and insert the provider
250        let provider_instance: Arc<dyn Provider> = match auth {
251            Some(ApiAuth::OAuth) => {
252                if provider_kind == ProviderKind::Anthropic {
253                    let storage = self.config_provider.auth_storage();
254                    Arc::new(AnthropicClient::with_oauth(storage.clone()))
255                } else {
256                    return Err(crate::error::Error::Api(ApiError::Configuration(format!(
257                        "OAuth is not supported for {provider_kind:?} provider"
258                    ))));
259                }
260            }
261            Some(ApiAuth::Key(key)) => match provider_kind {
262                ProviderKind::Anthropic => Arc::new(AnthropicClient::with_api_key(&key)),
263                ProviderKind::OpenAI => Arc::new(OpenAIClient::new(key)),
264                ProviderKind::Google => Arc::new(GeminiClient::new(&key)),
265                ProviderKind::XAI => Arc::new(XAIClient::new(key)),
266            },
267
268            None => {
269                return Err(crate::error::Error::Api(ApiError::Configuration(format!(
270                    "No authentication configured for {provider_kind:?} needed by model {model:?}"
271                ))));
272            }
273        };
274        map.insert(model, provider_instance.clone());
275        Ok(provider_instance)
276    }
277
278    pub async fn complete(
279        &self,
280        model: Model,
281        messages: Vec<Message>,
282        system: Option<String>,
283        tools: Option<Vec<ToolSchema>>,
284        token: CancellationToken,
285    ) -> std::result::Result<CompletionResponse, ApiError> {
286        let provider = self
287            .get_or_create_provider(model)
288            .await
289            .map_err(ApiError::from)?;
290
291        if token.is_cancelled() {
292            return Err(ApiError::Cancelled {
293                provider: provider.name().to_string(),
294            });
295        }
296
297        provider
298            .complete(model, messages, system, tools, token)
299            .await
300    }
301
302    pub async fn complete_with_retry(
303        &self,
304        model: Model,
305        messages: &[Message],
306        system_prompt: &Option<String>,
307        tools: &Option<Vec<ToolSchema>>,
308        token: CancellationToken,
309        max_attempts: usize,
310    ) -> std::result::Result<CompletionResponse, ApiError> {
311        let mut attempts = 0;
312        debug!(
313            target: "api::complete",
314            model =% model,
315            "system: {:?}",
316            system_prompt
317        );
318        debug!(
319            target: "api::complete",
320            model =% model,
321            "messages: {:?}",
322            messages
323        );
324        loop {
325            match self
326                .complete(
327                    model,
328                    messages.to_vec(),
329                    system_prompt.clone(),
330                    tools.clone(),
331                    token.clone(),
332                )
333                .await
334            {
335                Ok(response) => {
336                    return Ok(response);
337                }
338                Err(error) => {
339                    attempts += 1;
340                    warn!(
341                        "API completion attempt {}/{} failed for model {}: {:?}",
342                        attempts,
343                        max_attempts,
344                        model.as_ref(),
345                        error
346                    );
347
348                    if attempts >= max_attempts {
349                        return Err(error);
350                    }
351
352                    match error {
353                        ApiError::RateLimited { provider, details } => {
354                            let sleep_duration =
355                                std::time::Duration::from_secs(1 << (attempts - 1));
356                            warn!(
357                                "Rate limited by API: {} {} (retrying in {} seconds)",
358                                provider,
359                                details,
360                                sleep_duration.as_secs()
361                            );
362                            tokio::time::sleep(sleep_duration).await;
363                        }
364                        ApiError::NoChoices { provider } => {
365                            warn!("No choices returned from API: {}", provider);
366                        }
367                        ApiError::ServerError {
368                            provider,
369                            status_code,
370                            details,
371                        } => {
372                            warn!(
373                                "Server error for API: {} {} {}",
374                                provider, status_code, details
375                            );
376                        }
377                        _ => {
378                            // Not retryable
379                            return Err(error);
380                        }
381                    }
382                }
383            }
384        }
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    #[test]
    fn test_model_from_str() {
        let model = Model::from_str("claude-3-7-sonnet-20250219").unwrap();
        assert_eq!(model, Model::Claude3_7Sonnet20250219);
    }

    #[test]
    fn test_model_aliases() {
        let cases = [
            // Short aliases
            ("sonnet", Model::ClaudeSonnet4_20250514),
            ("opus", Model::ClaudeOpus4_1_20250805),
            ("o3", Model::O3_20250416),
            ("o3-pro", Model::O3Pro20250610),
            ("gemini", Model::Gemini2_5ProPreview0605),
            ("grok", Model::Grok4_0709),
            ("grok-mini", Model::Grok3Mini),
            // Full names
            ("claude-sonnet-4-20250514", Model::ClaudeSonnet4_20250514),
            ("o3-2025-04-16", Model::O3_20250416),
            ("o4-mini-2025-04-16", Model::O4Mini20250416),
            ("grok-3", Model::Grok3),
            ("grok-4-0709", Model::Grok4_0709),
            ("grok-3-mini", Model::Grok3Mini),
        ];
        for (name, expected) in cases {
            assert_eq!(Model::from_str(name).unwrap(), expected, "parsing {name:?}");
        }
    }

    /// Every model's `Display` output must parse back to the same model, so names written
    /// out (logs, configs) can be read back in.
    #[test]
    fn test_model_display_round_trips() {
        for model in Model::all() {
            let name = model.to_string();
            assert_eq!(
                Model::from_str(&name).ok(),
                Some(model),
                "round-tripping {name:?}"
            );
        }
    }

    #[test]
    fn test_unknown_model_is_rejected() {
        assert!(Model::from_str("not-a-model").is_err());
    }

    #[test]
    fn test_recommended_models_are_shown() {
        assert!(Model::iter_recommended().all(|m| m.should_show()));
        assert!(Model::iter_recommended().any(|m| m == Model::default()));
    }
}