steer_core/api/
mod.rs

1pub mod claude;
2pub mod error;
3pub mod gemini;
4pub mod openai;
5pub mod provider;
6pub mod xai;
7
8use crate::config::{ApiAuth, LlmConfigProvider};
9use crate::error::Result;
10pub use claude::AnthropicClient;
11pub use error::ApiError;
12pub use gemini::GeminiClient;
13pub use openai::OpenAIClient;
14pub use provider::{CompletionResponse, Provider};
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::RwLock;
18pub use steer_tools::{InputSchema, ToolCall, ToolSchema};
19use strum::Display;
20use strum::EnumIter;
21use strum::IntoStaticStr;
22use strum_macros::{AsRefStr, EnumString};
23use tokio_util::sync::CancellationToken;
24use tracing::debug;
25use tracing::warn;
26pub use xai::XAIClient;
27
28use crate::app::conversation::Message;
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Display, IntoStaticStr)]
31#[strum(serialize_all = "lowercase")]
32pub enum ProviderKind {
33    Anthropic,
34    OpenAI,
35    Google,
36    #[strum(serialize = "xai")]
37    XAI,
38}
39
40impl ProviderKind {
41    pub fn display_name(&self) -> String {
42        match self {
43            ProviderKind::Anthropic => "Anthropic".to_string(),
44            ProviderKind::OpenAI => "OpenAI".to_string(),
45            ProviderKind::Google => "Google".to_string(),
46            ProviderKind::XAI => "xAI".to_string(),
47        }
48    }
49}
50
51#[derive(
52    Debug,
53    Clone,
54    Copy,
55    PartialEq,
56    Eq,
57    Hash,
58    EnumIter,
59    EnumString,
60    AsRefStr,
61    Display,
62    IntoStaticStr,
63    serde::Serialize,
64    serde::Deserialize,
65    Default,
66)]
67pub enum Model {
68    #[strum(serialize = "claude-3-5-sonnet-20240620")]
69    Claude3_5Sonnet20240620,
70    #[strum(serialize = "claude-3-5-sonnet-20241022")]
71    Claude3_5Sonnet20241022,
72    #[strum(serialize = "claude-3-7-sonnet-20250219")]
73    Claude3_7Sonnet20250219,
74    #[strum(serialize = "claude-3-5-haiku-20241022")]
75    Claude3_5Haiku20241022,
76    #[strum(serialize = "claude-sonnet-4-20250514", serialize = "sonnet")]
77    ClaudeSonnet4_20250514,
78    #[strum(serialize = "claude-opus-4-20250514", serialize = "opus")]
79    #[default]
80    ClaudeOpus4_20250514,
81    #[strum(serialize = "gpt-4.1-2025-04-14")]
82    Gpt4_1_20250414,
83    #[strum(serialize = "gpt-4.1-mini-2025-04-14")]
84    Gpt4_1Mini20250414,
85    #[strum(serialize = "gpt-4.1-nano-2025-04-14")]
86    Gpt4_1Nano20250414,
87    #[strum(serialize = "o3-2025-04-16", serialize = "o3")]
88    O3_20250416,
89    #[strum(serialize = "o3-pro-2025-06-10", serialize = "o3-pro")]
90    O3Pro20250610,
91    #[strum(serialize = "o4-mini-2025-04-16", serialize = "o4-mini")]
92    O4Mini20250416,
93    #[strum(serialize = "gemini-2.5-flash-preview-04-17")]
94    Gemini2_5FlashPreview0417,
95    #[strum(serialize = "gemini-2.5-pro-preview-05-06")]
96    Gemini2_5ProPreview0506,
97    #[strum(serialize = "gemini-2.5-pro-preview-06-05", serialize = "gemini")]
98    Gemini2_5ProPreview0605,
99    #[strum(serialize = "grok-3")]
100    Grok3,
101    #[strum(serialize = "grok-3-mini", serialize = "grok-mini")]
102    Grok3Mini,
103    #[strum(serialize = "grok-4-0709", serialize = "grok")]
104    Grok4_0709,
105}
106
107impl Model {
108    /// Returns true if this model should be shown in the model picker UI
109    pub fn should_show(&self) -> bool {
110        matches!(
111            self,
112            Model::ClaudeOpus4_20250514
113                | Model::ClaudeSonnet4_20250514
114                | Model::O3_20250416
115                | Model::O3Pro20250610
116                | Model::Gemini2_5ProPreview0605
117                | Model::Grok4_0709
118                | Model::Grok3
119                | Model::Gpt4_1_20250414
120                | Model::O4Mini20250416
121        )
122    }
123
124    pub fn iter_recommended() -> impl Iterator<Item = Model> {
125        use strum::IntoEnumIterator;
126        Model::iter().filter(|m| m.should_show())
127    }
128
129    pub fn provider(&self) -> ProviderKind {
130        match self {
131            Model::Claude3_7Sonnet20250219
132            | Model::Claude3_5Sonnet20240620
133            | Model::Claude3_5Sonnet20241022
134            | Model::Claude3_5Haiku20241022
135            | Model::ClaudeSonnet4_20250514
136            | Model::ClaudeOpus4_20250514 => ProviderKind::Anthropic,
137
138            Model::Gpt4_1_20250414
139            | Model::Gpt4_1Mini20250414
140            | Model::Gpt4_1Nano20250414
141            | Model::O3_20250416
142            | Model::O3Pro20250610
143            | Model::O4Mini20250416 => ProviderKind::OpenAI,
144
145            Model::Gemini2_5FlashPreview0417
146            | Model::Gemini2_5ProPreview0506
147            | Model::Gemini2_5ProPreview0605 => ProviderKind::Google,
148
149            Model::Grok3 | Model::Grok3Mini | Model::Grok4_0709 => ProviderKind::XAI,
150        }
151    }
152
153    pub fn aliases(&self) -> Vec<&'static str> {
154        match self {
155            Model::ClaudeSonnet4_20250514 => vec!["sonnet"],
156            Model::ClaudeOpus4_20250514 => vec!["opus"],
157            Model::O3_20250416 => vec!["o3"],
158            Model::O3Pro20250610 => vec!["o3-pro"],
159            Model::O4Mini20250416 => vec!["o4-mini"],
160            Model::Gemini2_5ProPreview0605 => vec!["gemini"],
161            Model::Grok3 => vec![],
162            Model::Grok3Mini => vec!["grok-mini"],
163            Model::Grok4_0709 => vec!["grok"],
164            _ => vec![],
165        }
166    }
167
168    pub fn supports_thinking(&self) -> bool {
169        matches!(
170            self,
171            Model::Claude3_7Sonnet20250219
172                | Model::ClaudeSonnet4_20250514
173                | Model::ClaudeOpus4_20250514
174                | Model::O3_20250416
175                | Model::O3Pro20250610
176                | Model::O4Mini20250416
177                | Model::Gemini2_5FlashPreview0417
178                | Model::Gemini2_5ProPreview0506
179                | Model::Gemini2_5ProPreview0605
180                | Model::Grok3Mini
181                | Model::Grok4_0709
182        )
183    }
184
185    pub fn default_system_prompt_file(&self) -> Option<&'static str> {
186        match self {
187            Model::O3_20250416 => Some("models/o3.md"),
188            Model::O3Pro20250610 => Some("models/o3.md"),
189            Model::O4Mini20250416 => Some("models/o3.md"),
190            _ => None,
191        }
192    }
193
194    /// Get all available models
195    pub fn all() -> Vec<Model> {
196        use strum::IntoEnumIterator;
197        Model::iter().collect()
198    }
199}
200
201#[derive(Clone)]
202pub struct Client {
203    provider_map: Arc<RwLock<HashMap<Model, Arc<dyn Provider>>>>,
204    config_provider: LlmConfigProvider,
205}
206
207impl Client {
208    pub fn new_with_provider(provider: LlmConfigProvider) -> Self {
209        Self {
210            provider_map: Arc::new(RwLock::new(HashMap::new())),
211            config_provider: provider,
212        }
213    }
214
215    async fn get_or_create_provider(&self, model: Model) -> Result<Arc<dyn Provider>> {
216        // First check without holding the lock across await
217        {
218            let map = self.provider_map.read().unwrap();
219            if let Some(provider) = map.get(&model) {
220                return Ok(provider.clone());
221            }
222        }
223
224        // Get provider kind and auth before acquiring write lock
225        let provider_kind = model.provider();
226        let auth = self
227            .config_provider
228            .get_auth_for_provider(provider_kind)
229            .await?;
230
231        // Now acquire write lock and create provider
232        let mut map = self.provider_map.write().unwrap();
233
234        // Check again in case another thread added it
235        if let Some(provider) = map.get(&model) {
236            return Ok(provider.clone());
237        }
238
239        // Create and insert the provider
240        let provider_instance: Arc<dyn Provider> = match auth {
241            Some(ApiAuth::OAuth) => {
242                if provider_kind == ProviderKind::Anthropic {
243                    let storage = self.config_provider.auth_storage();
244                    Arc::new(AnthropicClient::with_oauth(storage.clone()))
245                } else {
246                    return Err(crate::error::Error::Api(ApiError::Configuration(format!(
247                        "OAuth is not supported for {provider_kind:?} provider"
248                    ))));
249                }
250            }
251            Some(ApiAuth::Key(key)) => match provider_kind {
252                ProviderKind::Anthropic => Arc::new(AnthropicClient::with_api_key(&key)),
253                ProviderKind::OpenAI => Arc::new(OpenAIClient::new(key)),
254                ProviderKind::Google => Arc::new(GeminiClient::new(&key)),
255                ProviderKind::XAI => Arc::new(XAIClient::new(key)),
256            },
257
258            None => {
259                return Err(crate::error::Error::Api(ApiError::Configuration(format!(
260                    "No authentication configured for {provider_kind:?} needed by model {model:?}"
261                ))));
262            }
263        };
264        map.insert(model, provider_instance.clone());
265        Ok(provider_instance)
266    }
267
268    pub async fn complete(
269        &self,
270        model: Model,
271        messages: Vec<Message>,
272        system: Option<String>,
273        tools: Option<Vec<ToolSchema>>,
274        token: CancellationToken,
275    ) -> std::result::Result<CompletionResponse, ApiError> {
276        let provider = self
277            .get_or_create_provider(model)
278            .await
279            .map_err(ApiError::from)?;
280
281        if token.is_cancelled() {
282            return Err(ApiError::Cancelled {
283                provider: provider.name().to_string(),
284            });
285        }
286
287        provider
288            .complete(model, messages, system, tools, token)
289            .await
290    }
291
292    pub async fn complete_with_retry(
293        &self,
294        model: Model,
295        messages: &[Message],
296        system_prompt: &Option<String>,
297        tools: &Option<Vec<ToolSchema>>,
298        token: CancellationToken,
299        max_attempts: usize,
300    ) -> std::result::Result<CompletionResponse, ApiError> {
301        let mut attempts = 0;
302        debug!(
303            target: "api::complete",
304            model =% model,
305            "system: {:?}",
306            system_prompt
307        );
308        debug!(
309            target: "api::complete",
310            model =% model,
311            "messages: {:?}",
312            messages
313        );
314        loop {
315            match self
316                .complete(
317                    model,
318                    messages.to_vec(),
319                    system_prompt.clone(),
320                    tools.clone(),
321                    token.clone(),
322                )
323                .await
324            {
325                Ok(response) => {
326                    return Ok(response);
327                }
328                Err(error) => {
329                    attempts += 1;
330                    warn!(
331                        "API completion attempt {}/{} failed for model {}: {:?}",
332                        attempts,
333                        max_attempts,
334                        model.as_ref(),
335                        error
336                    );
337
338                    if attempts >= max_attempts {
339                        return Err(error);
340                    }
341
342                    match error {
343                        ApiError::RateLimited { provider, details } => {
344                            let sleep_duration =
345                                std::time::Duration::from_secs(1 << (attempts - 1));
346                            warn!(
347                                "Rate limited by API: {} {} (retrying in {} seconds)",
348                                provider,
349                                details,
350                                sleep_duration.as_secs()
351                            );
352                            tokio::time::sleep(sleep_duration).await;
353                        }
354                        ApiError::NoChoices { provider } => {
355                            warn!("No choices returned from API: {}", provider);
356                        }
357                        ApiError::ServerError {
358                            provider,
359                            status_code,
360                            details,
361                        } => {
362                            warn!(
363                                "Server error for API: {} {} {}",
364                                provider, status_code, details
365                            );
366                        }
367                        _ => {
368                            // Not retryable
369                            return Err(error);
370                        }
371                    }
372                }
373            }
374        }
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// A full model identifier parses to its variant.
    #[test]
    fn test_model_from_str() {
        let model = Model::from_str("claude-3-7-sonnet-20250219").unwrap();
        assert_eq!(model, Model::Claude3_7Sonnet20250219);
    }

    /// Both the short aliases and the full identifiers are accepted by `from_str`.
    #[test]
    fn test_model_aliases() {
        let cases = [
            // Short aliases
            ("sonnet", Model::ClaudeSonnet4_20250514),
            ("opus", Model::ClaudeOpus4_20250514),
            ("o3", Model::O3_20250416),
            ("o3-pro", Model::O3Pro20250610),
            ("o4-mini", Model::O4Mini20250416),
            ("gemini", Model::Gemini2_5ProPreview0605),
            ("grok", Model::Grok4_0709),
            ("grok-mini", Model::Grok3Mini),
            // Full names
            ("claude-sonnet-4-20250514", Model::ClaudeSonnet4_20250514),
            ("o3-2025-04-16", Model::O3_20250416),
            ("o4-mini-2025-04-16", Model::O4Mini20250416),
            ("grok-3", Model::Grok3),
            ("grok-4-0709", Model::Grok4_0709),
            ("grok-3-mini", Model::Grok3Mini),
        ];

        for (input, expected) in cases {
            assert_eq!(
                Model::from_str(input).unwrap(),
                expected,
                "unexpected model for input {input:?}"
            );
        }
    }

    /// `Model::aliases` is maintained by hand alongside the strum `serialize`
    /// attributes; every alias it reports must parse back to the same model.
    #[test]
    fn test_aliases_round_trip() {
        for model in Model::all() {
            for alias in model.aliases() {
                assert_eq!(
                    Model::from_str(alias).unwrap(),
                    model,
                    "alias {alias:?} does not parse back to {model:?}"
                );
            }
        }
    }
}
429}