steer_core/api/
mod.rs

1pub mod claude;
2pub mod error;
3pub mod factory;
4pub mod gemini;
5pub mod openai;
6pub mod provider;
7pub mod util;
8pub mod xai;
9
10use crate::auth::ProviderRegistry;
11use crate::auth::storage::{Credential, CredentialType};
12use crate::config::model::ModelId;
13use crate::config::provider::ProviderId;
14use crate::config::{ApiAuth, LlmConfigProvider};
15use crate::error::Result;
16use crate::model_registry::ModelRegistry;
17pub use error::ApiError;
18pub use factory::{create_provider, create_provider_with_storage};
19pub use provider::{CompletionResponse, Provider};
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::sync::RwLock;
23use steer_tools::ToolSchema;
24use tokio_util::sync::CancellationToken;
25use tracing::debug;
26use tracing::warn;
27
28use crate::app::conversation::Message;
29
30#[derive(Clone)]
31pub struct Client {
32    provider_map: Arc<RwLock<HashMap<ProviderId, Arc<dyn Provider>>>>,
33    config_provider: LlmConfigProvider,
34    provider_registry: Arc<ProviderRegistry>,
35    model_registry: Arc<ModelRegistry>,
36}
37
38impl Client {
39    /// Remove a cached provider so that future calls re-create it with fresh credentials.
40    fn invalidate_provider(&self, provider_id: &ProviderId) {
41        let mut map = self.provider_map.write().unwrap();
42        map.remove(provider_id);
43    }
44
45    /// Determine if an API error should invalidate the cached provider (typically auth failures).
46    fn should_invalidate_provider(error: &ApiError) -> bool {
47        matches!(
48            error,
49            ApiError::AuthenticationFailed { .. } | ApiError::AuthError(_)
50        ) || matches!(
51            error,
52            ApiError::ServerError { status_code, .. } if matches!(status_code, 401 | 403)
53        )
54    }
55
56    /// Create a new Client with all dependencies injected.
57    /// This is the preferred constructor to avoid internal registry loading.
58    pub fn new_with_deps(
59        config_provider: LlmConfigProvider,
60        provider_registry: Arc<ProviderRegistry>,
61        model_registry: Arc<ModelRegistry>,
62    ) -> Self {
63        Self {
64            provider_map: Arc::new(RwLock::new(HashMap::new())),
65            config_provider,
66            provider_registry,
67            model_registry,
68        }
69    }
70
71    async fn get_or_create_provider(&self, provider_id: ProviderId) -> Result<Arc<dyn Provider>> {
72        // First check without holding the lock across await
73        {
74            let map = self.provider_map.read().unwrap();
75            if let Some(provider) = map.get(&provider_id) {
76                return Ok(provider.clone());
77            }
78        }
79
80        // Get the provider config from registry
81        let provider_config = self.provider_registry.get(&provider_id).ok_or_else(|| {
82            crate::error::Error::Api(ApiError::Configuration(format!(
83                "No provider configuration found for {provider_id:?}"
84            )))
85        })?;
86
87        // Get credential for the provider
88        let credential = match self
89            .config_provider
90            .get_auth_for_provider(&provider_id)
91            .await?
92        {
93            Some(ApiAuth::OAuth) => {
94                // Get OAuth credential from storage using the centralized storage_key()
95                self.config_provider
96                    .auth_storage()
97                    .get_credential(&provider_id.storage_key(), CredentialType::OAuth2)
98                    .await
99                    .map_err(|e| {
100                        crate::error::Error::Api(ApiError::Configuration(format!(
101                            "Failed to get OAuth credential: {e}"
102                        )))
103                    })?
104                    .ok_or_else(|| {
105                        crate::error::Error::Api(ApiError::Configuration(
106                            "OAuth credential not found in storage".to_string(),
107                        ))
108                    })?
109            }
110            Some(ApiAuth::Key(key)) => Credential::ApiKey { value: key },
111            None => {
112                return Err(crate::error::Error::Api(ApiError::Configuration(format!(
113                    "No authentication configured for {provider_id:?}"
114                ))));
115            }
116        };
117
118        // Now acquire write lock and create provider
119        let mut map = self.provider_map.write().unwrap();
120
121        // Check again in case another thread added it
122        if let Some(provider) = map.get(&provider_id) {
123            return Ok(provider.clone());
124        }
125
126        // Create the provider using factory
127        let provider_instance = if matches!(&credential, Credential::OAuth2(_)) {
128            factory::create_provider_with_storage(
129                provider_config,
130                &credential,
131                self.config_provider.auth_storage().clone(),
132            )
133            .map_err(crate::error::Error::Api)?
134        } else {
135            factory::create_provider(provider_config, &credential)
136                .map_err(crate::error::Error::Api)?
137        };
138
139        map.insert(provider_id, provider_instance.clone());
140        Ok(provider_instance)
141    }
142
143    /// Complete a prompt with a specific model ID and optional parameters
144    pub async fn complete(
145        &self,
146        model_id: &ModelId,
147        messages: Vec<Message>,
148        system: Option<String>,
149        tools: Option<Vec<ToolSchema>>,
150        call_options: Option<crate::config::model::ModelParameters>,
151        token: CancellationToken,
152    ) -> std::result::Result<CompletionResponse, ApiError> {
153        // Get provider from model ID
154        let provider_id = model_id.0.clone();
155        let provider = self
156            .get_or_create_provider(provider_id.clone())
157            .await
158            .map_err(ApiError::from)?;
159
160        if token.is_cancelled() {
161            return Err(ApiError::Cancelled {
162                provider: provider.name().to_string(),
163            });
164        }
165
166        // Get model config and merge parameters
167        let model_config = self.model_registry.get(model_id);
168        let effective_params = match (model_config, &call_options) {
169            (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
170            (Some(config), None) => config.effective_parameters(None),
171            (None, Some(opts)) => Some(*opts),
172            (None, None) => None,
173        };
174
175        debug!(
176            target: "api::complete",
177            ?model_id,
178            ?call_options,
179            ?effective_params,
180            "Final parameters for model"
181        );
182
183        let result = provider
184            .complete(model_id, messages, system, tools, effective_params, token)
185            .await;
186
187        if let Err(ref err) = result {
188            if Self::should_invalidate_provider(err) {
189                self.invalidate_provider(&provider_id);
190            }
191        }
192
193        result
194    }
195
196    pub async fn complete_with_retry(
197        &self,
198        model_id: &ModelId,
199        messages: &[Message],
200        system_prompt: &Option<String>,
201        tools: &Option<Vec<ToolSchema>>,
202        token: CancellationToken,
203        max_attempts: usize,
204    ) -> std::result::Result<CompletionResponse, ApiError> {
205        let mut attempts = 0;
206
207        // Prepare provider and parameters once
208        let provider_id = model_id.0.clone();
209        let provider = self
210            .get_or_create_provider(provider_id.clone())
211            .await
212            .map_err(ApiError::from)?;
213
214        let model_config = self.model_registry.get(model_id);
215        debug!(
216            target: "api::complete_with_retry",
217            ?model_id,
218            ?model_config,
219            "Model config"
220        );
221        let effective_params = model_config.and_then(|cfg| cfg.effective_parameters(None));
222
223        debug!(
224            target: "api::complete_with_retry",
225            ?model_id,
226            ?effective_params,
227            "system: {:?}",
228            system_prompt
229        );
230        debug!(
231            target: "api::complete_with_retry",
232            ?model_id,
233            "messages: {:?}",
234            messages
235        );
236
237        loop {
238            if token.is_cancelled() {
239                return Err(ApiError::Cancelled {
240                    provider: provider.name().to_string(),
241                });
242            }
243
244            match provider
245                .complete(
246                    model_id,
247                    messages.to_vec(),
248                    system_prompt.clone(),
249                    tools.clone(),
250                    effective_params,
251                    token.clone(),
252                )
253                .await
254            {
255                Ok(response) => {
256                    return Ok(response);
257                }
258                Err(error) => {
259                    attempts += 1;
260                    warn!(
261                        "API completion attempt {}/{} failed for model {:?}: {:?}",
262                        attempts, max_attempts, model_id, error
263                    );
264
265                    if Self::should_invalidate_provider(&error) {
266                        self.invalidate_provider(&provider_id);
267                        return Err(error);
268                    }
269
270                    if attempts >= max_attempts {
271                        return Err(error);
272                    }
273
274                    match error {
275                        ApiError::RateLimited { provider, details } => {
276                            let sleep_duration =
277                                std::time::Duration::from_secs(1 << (attempts - 1));
278                            warn!(
279                                "Rate limited by API: {} {} (retrying in {} seconds)",
280                                provider,
281                                details,
282                                sleep_duration.as_secs()
283                            );
284                            tokio::time::sleep(sleep_duration).await;
285                        }
286                        ApiError::NoChoices { provider } => {
287                            warn!("No choices returned from API: {}", provider);
288                        }
289                        ApiError::ServerError {
290                            provider,
291                            status_code,
292                            details,
293                        } => {
294                            warn!(
295                                "Server error for API: {} {} {}",
296                                provider, status_code, details
297                            );
298                        }
299                        _ => {
300                            // Not retryable
301                            return Err(error);
302                        }
303                    }
304                }
305            }
306        }
307    }
308}
309
310#[cfg(test)]
311mod tests {
312    use super::*;
313    use crate::config::provider::ProviderId;
314    use async_trait::async_trait;
315    use tokio_util::sync::CancellationToken;
316
317    #[derive(Clone, Copy)]
318    enum StubErrorKind {
319        Auth,
320        Server401,
321    }
322
323    #[derive(Clone)]
324    struct StubProvider {
325        error_kind: StubErrorKind,
326    }
327
328    impl StubProvider {
329        fn new(error_kind: StubErrorKind) -> Self {
330            Self { error_kind }
331        }
332    }
333
334    #[async_trait]
335    impl Provider for StubProvider {
336        fn name(&self) -> &'static str {
337            "stub"
338        }
339
340        async fn complete(
341            &self,
342            _model_id: &ModelId,
343            _messages: Vec<Message>,
344            _system: Option<String>,
345            _tools: Option<Vec<ToolSchema>>,
346            _call_options: Option<crate::config::model::ModelParameters>,
347            _token: CancellationToken,
348        ) -> std::result::Result<CompletionResponse, ApiError> {
349            let err = match self.error_kind {
350                StubErrorKind::Auth => ApiError::AuthenticationFailed {
351                    provider: "stub".to_string(),
352                    details: "bad key".to_string(),
353                },
354                StubErrorKind::Server401 => ApiError::ServerError {
355                    provider: "stub".to_string(),
356                    status_code: 401,
357                    details: "unauthorized".to_string(),
358                },
359            };
360            Err(err)
361        }
362    }
363
364    fn test_client() -> Client {
365        let auth_storage = Arc::new(crate::test_utils::InMemoryAuthStorage::new());
366        let config_provider = LlmConfigProvider::new(auth_storage);
367        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
368        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
369
370        Client::new_with_deps(config_provider, provider_registry, model_registry)
371    }
372
373    fn insert_stub_provider(client: &Client, provider_id: ProviderId, error: StubErrorKind) {
374        client
375            .provider_map
376            .write()
377            .unwrap()
378            .insert(provider_id, Arc::new(StubProvider::new(error)));
379    }
380
381    #[tokio::test]
382    async fn invalidates_cached_provider_on_auth_failure() {
383        let client = test_client();
384        let provider_id = ProviderId("stub-auth".to_string());
385        let model_id = (provider_id.clone(), "stub-model".to_string());
386
387        insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Auth);
388
389        let err = client
390            .complete(
391                &model_id,
392                vec![],
393                None,
394                None,
395                None,
396                CancellationToken::new(),
397            )
398            .await
399            .unwrap_err();
400
401        assert!(matches!(err, ApiError::AuthenticationFailed { .. }));
402        assert!(
403            !client
404                .provider_map
405                .read()
406                .unwrap()
407                .contains_key(&provider_id)
408        );
409    }
410
411    #[tokio::test]
412    async fn invalidates_cached_provider_on_unauthorized_status_code() {
413        let client = test_client();
414        let provider_id = ProviderId("stub-unauthorized".to_string());
415        let model_id = (provider_id.clone(), "stub-model".to_string());
416
417        insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Server401);
418
419        let err = client
420            .complete(
421                &model_id,
422                vec![],
423                None,
424                None,
425                None,
426                CancellationToken::new(),
427            )
428            .await
429            .unwrap_err();
430
431        assert!(matches!(
432            err,
433            ApiError::ServerError {
434                status_code: 401,
435                ..
436            }
437        ));
438        assert!(
439            !client
440                .provider_map
441                .read()
442                .unwrap()
443                .contains_key(&provider_id)
444        );
445    }
446}