steer_core/api/
mod.rs

1pub mod claude;
2pub mod error;
3pub mod factory;
4pub mod gemini;
5pub mod openai;
6pub mod provider;
7pub mod util;
8pub mod xai;
9
10use crate::auth::ProviderRegistry;
11use crate::auth::storage::{Credential, CredentialType};
12use crate::config::model::ModelId;
13use crate::config::provider::ProviderId;
14use crate::config::{ApiAuth, LlmConfigProvider};
15use crate::error::Result;
16use crate::model_registry::ModelRegistry;
17pub use error::ApiError;
18pub use factory::{create_provider, create_provider_with_storage};
19pub use provider::{CompletionResponse, Provider};
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::sync::RwLock;
23use steer_tools::ToolSchema;
24use tokio_util::sync::CancellationToken;
25use tracing::debug;
26use tracing::warn;
27
28use crate::app::conversation::Message;
29
30#[derive(Clone)]
31pub struct Client {
32    provider_map: Arc<RwLock<HashMap<ProviderId, Arc<dyn Provider>>>>,
33    config_provider: LlmConfigProvider,
34    provider_registry: Arc<ProviderRegistry>,
35    model_registry: Arc<ModelRegistry>,
36}
37
38impl Client {
39    /// Create a new Client with all dependencies injected.
40    /// This is the preferred constructor to avoid internal registry loading.
41    pub fn new_with_deps(
42        config_provider: LlmConfigProvider,
43        provider_registry: Arc<ProviderRegistry>,
44        model_registry: Arc<ModelRegistry>,
45    ) -> Self {
46        Self {
47            provider_map: Arc::new(RwLock::new(HashMap::new())),
48            config_provider,
49            provider_registry,
50            model_registry,
51        }
52    }
53
54    async fn get_or_create_provider(&self, provider_id: ProviderId) -> Result<Arc<dyn Provider>> {
55        // First check without holding the lock across await
56        {
57            let map = self.provider_map.read().unwrap();
58            if let Some(provider) = map.get(&provider_id) {
59                return Ok(provider.clone());
60            }
61        }
62
63        // Get the provider config from registry
64        let provider_config = self.provider_registry.get(&provider_id).ok_or_else(|| {
65            crate::error::Error::Api(ApiError::Configuration(format!(
66                "No provider configuration found for {provider_id:?}"
67            )))
68        })?;
69
70        // Get credential for the provider
71        let credential = match self
72            .config_provider
73            .get_auth_for_provider(&provider_id)
74            .await?
75        {
76            Some(ApiAuth::OAuth) => {
77                // Get OAuth credential from storage using the centralized storage_key()
78                self.config_provider
79                    .auth_storage()
80                    .get_credential(&provider_id.storage_key(), CredentialType::OAuth2)
81                    .await
82                    .map_err(|e| {
83                        crate::error::Error::Api(ApiError::Configuration(format!(
84                            "Failed to get OAuth credential: {e}"
85                        )))
86                    })?
87                    .ok_or_else(|| {
88                        crate::error::Error::Api(ApiError::Configuration(
89                            "OAuth credential not found in storage".to_string(),
90                        ))
91                    })?
92            }
93            Some(ApiAuth::Key(key)) => Credential::ApiKey { value: key },
94            None => {
95                return Err(crate::error::Error::Api(ApiError::Configuration(format!(
96                    "No authentication configured for {provider_id:?}"
97                ))));
98            }
99        };
100
101        // Now acquire write lock and create provider
102        let mut map = self.provider_map.write().unwrap();
103
104        // Check again in case another thread added it
105        if let Some(provider) = map.get(&provider_id) {
106            return Ok(provider.clone());
107        }
108
109        // Create the provider using factory
110        let provider_instance = if matches!(&credential, Credential::OAuth2(_)) {
111            factory::create_provider_with_storage(
112                provider_config,
113                &credential,
114                self.config_provider.auth_storage().clone(),
115            )
116            .map_err(crate::error::Error::Api)?
117        } else {
118            factory::create_provider(provider_config, &credential)
119                .map_err(crate::error::Error::Api)?
120        };
121
122        map.insert(provider_id, provider_instance.clone());
123        Ok(provider_instance)
124    }
125
126    /// Complete a prompt with a specific model ID and optional parameters
127    pub async fn complete(
128        &self,
129        model_id: &ModelId,
130        messages: Vec<Message>,
131        system: Option<String>,
132        tools: Option<Vec<ToolSchema>>,
133        call_options: Option<crate::config::model::ModelParameters>,
134        token: CancellationToken,
135    ) -> std::result::Result<CompletionResponse, ApiError> {
136        // Get provider from model ID
137        let provider_id = model_id.0.clone();
138        let provider = self
139            .get_or_create_provider(provider_id)
140            .await
141            .map_err(ApiError::from)?;
142
143        if token.is_cancelled() {
144            return Err(ApiError::Cancelled {
145                provider: provider.name().to_string(),
146            });
147        }
148
149        // Get model config and merge parameters
150        let model_config = self.model_registry.get(model_id);
151        let effective_params = match (model_config, &call_options) {
152            (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
153            (Some(config), None) => config.effective_parameters(None),
154            (None, Some(opts)) => Some(*opts),
155            (None, None) => None,
156        };
157
158        debug!(
159            target: "api::complete",
160            ?model_id,
161            ?call_options,
162            ?effective_params,
163            "Final parameters for model"
164        );
165
166        provider
167            .complete(model_id, messages, system, tools, effective_params, token)
168            .await
169    }
170
171    pub async fn complete_with_retry(
172        &self,
173        model_id: &ModelId,
174        messages: &[Message],
175        system_prompt: &Option<String>,
176        tools: &Option<Vec<ToolSchema>>,
177        token: CancellationToken,
178        max_attempts: usize,
179    ) -> std::result::Result<CompletionResponse, ApiError> {
180        let mut attempts = 0;
181
182        // Prepare provider and parameters once
183        let provider_id = model_id.0.clone();
184        let provider = self
185            .get_or_create_provider(provider_id.clone())
186            .await
187            .map_err(ApiError::from)?;
188
189        let model_config = self.model_registry.get(model_id);
190        debug!(
191            target: "api::complete_with_retry",
192            ?model_id,
193            ?model_config,
194            "Model config"
195        );
196        let effective_params = model_config.and_then(|cfg| cfg.effective_parameters(None));
197
198        debug!(
199            target: "api::complete_with_retry",
200            ?model_id,
201            ?effective_params,
202            "system: {:?}",
203            system_prompt
204        );
205        debug!(
206            target: "api::complete_with_retry",
207            ?model_id,
208            "messages: {:?}",
209            messages
210        );
211
212        loop {
213            if token.is_cancelled() {
214                return Err(ApiError::Cancelled {
215                    provider: provider.name().to_string(),
216                });
217            }
218
219            match provider
220                .complete(
221                    model_id,
222                    messages.to_vec(),
223                    system_prompt.clone(),
224                    tools.clone(),
225                    effective_params,
226                    token.clone(),
227                )
228                .await
229            {
230                Ok(response) => {
231                    return Ok(response);
232                }
233                Err(error) => {
234                    attempts += 1;
235                    warn!(
236                        "API completion attempt {}/{} failed for model {:?}: {:?}",
237                        attempts, max_attempts, model_id, error
238                    );
239
240                    if attempts >= max_attempts {
241                        return Err(error);
242                    }
243
244                    match error {
245                        ApiError::RateLimited { provider, details } => {
246                            let sleep_duration =
247                                std::time::Duration::from_secs(1 << (attempts - 1));
248                            warn!(
249                                "Rate limited by API: {} {} (retrying in {} seconds)",
250                                provider,
251                                details,
252                                sleep_duration.as_secs()
253                            );
254                            tokio::time::sleep(sleep_duration).await;
255                        }
256                        ApiError::NoChoices { provider } => {
257                            warn!("No choices returned from API: {}", provider);
258                        }
259                        ApiError::ServerError {
260                            provider,
261                            status_code,
262                            details,
263                        } => {
264                            warn!(
265                                "Server error for API: {} {} {}",
266                                provider, status_code, details
267                            );
268                        }
269                        _ => {
270                            // Not retryable
271                            return Err(error);
272                        }
273                    }
274                }
275            }
276        }
277    }
278}