steer_core/api/
mod.rs

1pub mod claude;
2pub mod error;
3pub mod gemini;
4pub mod openai;
5pub mod provider;
6pub mod xai;
7
8use crate::config::{ApiAuth, LlmConfigProvider};
9use crate::error::Result;
10pub use claude::AnthropicClient;
11pub use error::ApiError;
12pub use gemini::GeminiClient;
13pub use openai::OpenAIClient;
14pub use provider::{CompletionResponse, Provider};
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::RwLock;
18pub use steer_tools::{InputSchema, ToolCall, ToolSchema};
19use strum::Display;
20use strum::EnumIter;
21use strum::IntoStaticStr;
22use strum_macros::{AsRefStr, EnumString};
23use tokio_util::sync::CancellationToken;
24use tracing::debug;
25use tracing::warn;
26pub use xai::XAIClient;
27
28use crate::app::conversation::Message;
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Display, IntoStaticStr)]
31#[strum(serialize_all = "lowercase")]
32pub enum ProviderKind {
33    Anthropic,
34    OpenAI,
35    Google,
36    #[strum(serialize = "xai")]
37    XAI,
38}
39
40impl ProviderKind {
41    pub fn display_name(&self) -> String {
42        match self {
43            ProviderKind::Anthropic => "Anthropic".to_string(),
44            ProviderKind::OpenAI => "OpenAI".to_string(),
45            ProviderKind::Google => "Google".to_string(),
46            ProviderKind::XAI => "xAI".to_string(),
47        }
48    }
49}
50
51#[derive(
52    Debug,
53    Clone,
54    Copy,
55    PartialEq,
56    Eq,
57    Hash,
58    EnumIter,
59    EnumString,
60    AsRefStr,
61    Display,
62    IntoStaticStr,
63    serde::Serialize,
64    serde::Deserialize,
65    Default,
66)]
67pub enum Model {
68    #[strum(serialize = "claude-3-5-sonnet-20240620")]
69    Claude3_5Sonnet20240620,
70    #[strum(serialize = "claude-3-5-sonnet-20241022")]
71    Claude3_5Sonnet20241022,
72    #[strum(serialize = "claude-3-7-sonnet-20250219")]
73    Claude3_7Sonnet20250219,
74    #[strum(serialize = "claude-3-5-haiku-20241022")]
75    Claude3_5Haiku20241022,
76    #[strum(serialize = "claude-sonnet-4-20250514", serialize = "sonnet")]
77    ClaudeSonnet4_20250514,
78    #[strum(serialize = "claude-opus-4-20250514", serialize = "opus-4")]
79    ClaudeOpus4_20250514,
80    #[strum(
81        serialize = "claude-opus-4-1-20250805",
82        serialize = "opus",
83        serialize = "opus-4-1"
84    )]
85    #[default]
86    ClaudeOpus4_1_20250805,
87    #[strum(serialize = "gpt-4.1-2025-04-14")]
88    Gpt4_1_20250414,
89    #[strum(serialize = "gpt-4.1-mini-2025-04-14")]
90    Gpt4_1Mini20250414,
91    #[strum(serialize = "gpt-4.1-nano-2025-04-14")]
92    Gpt4_1Nano20250414,
93    #[strum(serialize = "gpt-5-2025-08-07", serialize = "gpt-5")]
94    Gpt5_20250807,
95    #[strum(serialize = "o3-2025-04-16", serialize = "o3")]
96    O3_20250416,
97    #[strum(serialize = "o3-pro-2025-06-10", serialize = "o3-pro")]
98    O3Pro20250610,
99    #[strum(serialize = "o4-mini-2025-04-16", serialize = "o4-mini")]
100    O4Mini20250416,
101    #[strum(serialize = "gemini-2.5-flash-preview-04-17")]
102    Gemini2_5FlashPreview0417,
103    #[strum(serialize = "gemini-2.5-pro-preview-05-06")]
104    Gemini2_5ProPreview0506,
105    #[strum(serialize = "gemini-2.5-pro-preview-06-05", serialize = "gemini")]
106    Gemini2_5ProPreview0605,
107    #[strum(serialize = "grok-3")]
108    Grok3,
109    #[strum(serialize = "grok-3-mini", serialize = "grok-mini")]
110    Grok3Mini,
111    #[strum(serialize = "grok-4-0709", serialize = "grok")]
112    Grok4_0709,
113}
114
115impl Model {
116    /// Returns true if this model should be shown in the model picker UI
117    pub fn should_show(&self) -> bool {
118        matches!(
119            self,
120            Model::ClaudeOpus4_20250514
121                | Model::ClaudeOpus4_1_20250805
122                | Model::ClaudeSonnet4_20250514
123                | Model::O3_20250416
124                | Model::O3Pro20250610
125                | Model::Gemini2_5ProPreview0605
126                | Model::Grok4_0709
127                | Model::Grok3
128                | Model::Gpt4_1_20250414
129                | Model::Gpt5_20250807
130                | Model::O4Mini20250416
131        )
132    }
133
134    pub fn iter_recommended() -> impl Iterator<Item = Model> {
135        use strum::IntoEnumIterator;
136        Model::iter().filter(|m| m.should_show())
137    }
138
139    pub fn provider(&self) -> ProviderKind {
140        match self {
141            Model::Claude3_7Sonnet20250219
142            | Model::Claude3_5Sonnet20240620
143            | Model::Claude3_5Sonnet20241022
144            | Model::Claude3_5Haiku20241022
145            | Model::ClaudeSonnet4_20250514
146            | Model::ClaudeOpus4_20250514
147            | Model::ClaudeOpus4_1_20250805 => ProviderKind::Anthropic,
148
149            Model::Gpt4_1_20250414
150            | Model::Gpt4_1Mini20250414
151            | Model::Gpt4_1Nano20250414
152            | Model::Gpt5_20250807
153            | Model::O3_20250416
154            | Model::O3Pro20250610
155            | Model::O4Mini20250416 => ProviderKind::OpenAI,
156
157            Model::Gemini2_5FlashPreview0417
158            | Model::Gemini2_5ProPreview0506
159            | Model::Gemini2_5ProPreview0605 => ProviderKind::Google,
160
161            Model::Grok3 | Model::Grok3Mini | Model::Grok4_0709 => ProviderKind::XAI,
162        }
163    }
164
165    pub fn aliases(&self) -> Vec<&'static str> {
166        match self {
167            Model::ClaudeSonnet4_20250514 => vec!["sonnet"],
168            Model::ClaudeOpus4_20250514 => vec!["opus-4-0"],
169            Model::ClaudeOpus4_1_20250805 => vec!["opus-4-1", "opus"],
170            Model::O3_20250416 => vec!["o3"],
171            Model::O3Pro20250610 => vec!["o3-pro"],
172            Model::O4Mini20250416 => vec!["o4-mini"],
173            Model::Gemini2_5ProPreview0605 => vec!["gemini"],
174            Model::Grok3 => vec![],
175            Model::Grok3Mini => vec!["grok-mini"],
176            Model::Grok4_0709 => vec!["grok"],
177            Model::Gpt5_20250807 => vec!["gpt-5"],
178            _ => vec![],
179        }
180    }
181
182    pub fn supports_thinking(&self) -> bool {
183        matches!(
184            self,
185            Model::Claude3_7Sonnet20250219
186                | Model::ClaudeSonnet4_20250514
187                | Model::ClaudeOpus4_20250514
188                | Model::ClaudeOpus4_1_20250805
189                | Model::Gpt5_20250807
190                | Model::O3_20250416
191                | Model::O3Pro20250610
192                | Model::O4Mini20250416
193                | Model::Gemini2_5FlashPreview0417
194                | Model::Gemini2_5ProPreview0506
195                | Model::Gemini2_5ProPreview0605
196                | Model::Grok3Mini
197                | Model::Grok4_0709
198        )
199    }
200
201    pub fn default_system_prompt_file(&self) -> Option<&'static str> {
202        match self {
203            Model::O3_20250416 => Some("models/o3.md"),
204            Model::O3Pro20250610 => Some("models/o3.md"),
205            Model::O4Mini20250416 => Some("models/o3.md"),
206            _ => None,
207        }
208    }
209
210    /// Get all available models
211    pub fn all() -> Vec<Model> {
212        use strum::IntoEnumIterator;
213        Model::iter().collect()
214    }
215}
216
217#[derive(Clone)]
218pub struct Client {
219    provider_map: Arc<RwLock<HashMap<Model, Arc<dyn Provider>>>>,
220    config_provider: LlmConfigProvider,
221}
222
223impl Client {
224    pub fn new_with_provider(provider: LlmConfigProvider) -> Self {
225        Self {
226            provider_map: Arc::new(RwLock::new(HashMap::new())),
227            config_provider: provider,
228        }
229    }
230
231    async fn get_or_create_provider(&self, model: Model) -> Result<Arc<dyn Provider>> {
232        // First check without holding the lock across await
233        {
234            let map = self.provider_map.read().unwrap();
235            if let Some(provider) = map.get(&model) {
236                return Ok(provider.clone());
237            }
238        }
239
240        // Get provider kind and auth before acquiring write lock
241        let provider_kind = model.provider();
242        let auth = self
243            .config_provider
244            .get_auth_for_provider(provider_kind)
245            .await?;
246
247        // Now acquire write lock and create provider
248        let mut map = self.provider_map.write().unwrap();
249
250        // Check again in case another thread added it
251        if let Some(provider) = map.get(&model) {
252            return Ok(provider.clone());
253        }
254
255        // Create and insert the provider
256        let provider_instance: Arc<dyn Provider> = match auth {
257            Some(ApiAuth::OAuth) => {
258                if provider_kind == ProviderKind::Anthropic {
259                    let storage = self.config_provider.auth_storage();
260                    Arc::new(AnthropicClient::with_oauth(storage.clone()))
261                } else {
262                    return Err(crate::error::Error::Api(ApiError::Configuration(format!(
263                        "OAuth is not supported for {provider_kind:?} provider"
264                    ))));
265                }
266            }
267            Some(ApiAuth::Key(key)) => match provider_kind {
268                ProviderKind::Anthropic => Arc::new(AnthropicClient::with_api_key(&key)),
269                ProviderKind::OpenAI => Arc::new(OpenAIClient::new(key)),
270                ProviderKind::Google => Arc::new(GeminiClient::new(&key)),
271                ProviderKind::XAI => Arc::new(XAIClient::new(key)),
272            },
273
274            None => {
275                return Err(crate::error::Error::Api(ApiError::Configuration(format!(
276                    "No authentication configured for {provider_kind:?} needed by model {model:?}"
277                ))));
278            }
279        };
280        map.insert(model, provider_instance.clone());
281        Ok(provider_instance)
282    }
283
284    pub async fn complete(
285        &self,
286        model: Model,
287        messages: Vec<Message>,
288        system: Option<String>,
289        tools: Option<Vec<ToolSchema>>,
290        token: CancellationToken,
291    ) -> std::result::Result<CompletionResponse, ApiError> {
292        let provider = self
293            .get_or_create_provider(model)
294            .await
295            .map_err(ApiError::from)?;
296
297        if token.is_cancelled() {
298            return Err(ApiError::Cancelled {
299                provider: provider.name().to_string(),
300            });
301        }
302
303        provider
304            .complete(model, messages, system, tools, token)
305            .await
306    }
307
308    pub async fn complete_with_retry(
309        &self,
310        model: Model,
311        messages: &[Message],
312        system_prompt: &Option<String>,
313        tools: &Option<Vec<ToolSchema>>,
314        token: CancellationToken,
315        max_attempts: usize,
316    ) -> std::result::Result<CompletionResponse, ApiError> {
317        let mut attempts = 0;
318        debug!(
319            target: "api::complete",
320            model =% model,
321            "system: {:?}",
322            system_prompt
323        );
324        debug!(
325            target: "api::complete",
326            model =% model,
327            "messages: {:?}",
328            messages
329        );
330        loop {
331            match self
332                .complete(
333                    model,
334                    messages.to_vec(),
335                    system_prompt.clone(),
336                    tools.clone(),
337                    token.clone(),
338                )
339                .await
340            {
341                Ok(response) => {
342                    return Ok(response);
343                }
344                Err(error) => {
345                    attempts += 1;
346                    warn!(
347                        "API completion attempt {}/{} failed for model {}: {:?}",
348                        attempts,
349                        max_attempts,
350                        model.as_ref(),
351                        error
352                    );
353
354                    if attempts >= max_attempts {
355                        return Err(error);
356                    }
357
358                    match error {
359                        ApiError::RateLimited { provider, details } => {
360                            let sleep_duration =
361                                std::time::Duration::from_secs(1 << (attempts - 1));
362                            warn!(
363                                "Rate limited by API: {} {} (retrying in {} seconds)",
364                                provider,
365                                details,
366                                sleep_duration.as_secs()
367                            );
368                            tokio::time::sleep(sleep_duration).await;
369                        }
370                        ApiError::NoChoices { provider } => {
371                            warn!("No choices returned from API: {}", provider);
372                        }
373                        ApiError::ServerError {
374                            provider,
375                            status_code,
376                            details,
377                        } => {
378                            warn!(
379                                "Server error for API: {} {} {}",
380                                provider, status_code, details
381                            );
382                        }
383                        _ => {
384                            // Not retryable
385                            return Err(error);
386                        }
387                    }
388                }
389            }
390        }
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Asserts that `input` parses to `expected`, naming the offending input on failure.
    fn assert_parses(input: &str, expected: Model) {
        match Model::from_str(input) {
            Ok(model) => assert_eq!(model, expected, "input {input:?} parsed to the wrong model"),
            Err(err) => panic!("failed to parse {input:?}: {err:?}"),
        }
    }

    #[test]
    fn test_model_from_str() {
        assert_parses("claude-3-7-sonnet-20250219", Model::Claude3_7Sonnet20250219);
    }

    #[test]
    fn test_model_aliases() {
        let short_aliases = [
            ("sonnet", Model::ClaudeSonnet4_20250514),
            ("opus", Model::ClaudeOpus4_1_20250805),
            ("opus-4-1", Model::ClaudeOpus4_1_20250805),
            ("opus-4", Model::ClaudeOpus4_20250514),
            ("o3", Model::O3_20250416),
            ("o3-pro", Model::O3Pro20250610),
            ("o4-mini", Model::O4Mini20250416),
            ("gemini", Model::Gemini2_5ProPreview0605),
            ("grok", Model::Grok4_0709),
            ("grok-mini", Model::Grok3Mini),
            ("gpt-5", Model::Gpt5_20250807),
        ];
        let full_ids = [
            ("claude-sonnet-4-20250514", Model::ClaudeSonnet4_20250514),
            ("o3-2025-04-16", Model::O3_20250416),
            ("o4-mini-2025-04-16", Model::O4Mini20250416),
            ("grok-3", Model::Grok3),
            ("grok-4-0709", Model::Grok4_0709),
            ("grok-3-mini", Model::Grok3Mini),
            ("gpt-5-2025-08-07", Model::Gpt5_20250807),
        ];
        for &(input, expected) in short_aliases.iter().chain(full_ids.iter()) {
            assert_parses(input, expected);
        }
    }

    #[test]
    fn test_display_is_full_model_id() {
        assert_eq!(
            Model::ClaudeOpus4_1_20250805.to_string(),
            "claude-opus-4-1-20250805"
        );
        assert_eq!(Model::Grok4_0709.to_string(), "grok-4-0709");
        assert_eq!(Model::Gpt5_20250807.as_ref(), "gpt-5-2025-08-07");
        assert_eq!(Model::Grok3Mini.as_ref(), "grok-3-mini");
    }

    #[test]
    fn test_every_model_round_trips_through_its_name() {
        for model in Model::all() {
            assert_parses(model.as_ref(), model);
        }
    }

    #[test]
    fn test_default_model() {
        assert_eq!(Model::default(), Model::ClaudeOpus4_1_20250805);
    }

    #[test]
    fn test_provider_mapping() {
        assert_eq!(
            Model::ClaudeSonnet4_20250514.provider(),
            ProviderKind::Anthropic
        );
        assert_eq!(Model::O4Mini20250416.provider(), ProviderKind::OpenAI);
        assert_eq!(
            Model::Gemini2_5ProPreview0605.provider(),
            ProviderKind::Google
        );
        assert_eq!(Model::Grok3Mini.provider(), ProviderKind::XAI);
        assert_eq!(ProviderKind::XAI.to_string(), "xai");
        assert_eq!(ProviderKind::XAI.display_name(), "xAI");
    }
}
450}