1pub mod claude;
2pub mod error;
3pub mod factory;
4pub mod gemini;
5pub mod openai;
6pub mod provider;
7pub mod util;
8pub mod xai;
9
10use crate::auth::ProviderRegistry;
11use crate::auth::storage::{Credential, CredentialType};
12use crate::config::model::ModelId;
13use crate::config::provider::ProviderId;
14use crate::config::{ApiAuth, LlmConfigProvider};
15use crate::error::Result;
16use crate::model_registry::ModelRegistry;
17pub use error::ApiError;
18pub use factory::{create_provider, create_provider_with_storage};
19pub use provider::{CompletionResponse, Provider};
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::sync::RwLock;
23use steer_tools::ToolSchema;
24use tokio_util::sync::CancellationToken;
25use tracing::debug;
26use tracing::warn;
27
28use crate::app::conversation::Message;
29
30#[derive(Clone)]
31pub struct Client {
32 provider_map: Arc<RwLock<HashMap<ProviderId, Arc<dyn Provider>>>>,
33 config_provider: LlmConfigProvider,
34 provider_registry: Arc<ProviderRegistry>,
35 model_registry: Arc<ModelRegistry>,
36}
37
38impl Client {
39 fn invalidate_provider(&self, provider_id: &ProviderId) {
41 let mut map = self.provider_map.write().unwrap();
42 map.remove(provider_id);
43 }
44
45 fn should_invalidate_provider(error: &ApiError) -> bool {
47 matches!(
48 error,
49 ApiError::AuthenticationFailed { .. } | ApiError::AuthError(_)
50 ) || matches!(
51 error,
52 ApiError::ServerError { status_code, .. } if matches!(status_code, 401 | 403)
53 )
54 }
55
56 pub fn new_with_deps(
59 config_provider: LlmConfigProvider,
60 provider_registry: Arc<ProviderRegistry>,
61 model_registry: Arc<ModelRegistry>,
62 ) -> Self {
63 Self {
64 provider_map: Arc::new(RwLock::new(HashMap::new())),
65 config_provider,
66 provider_registry,
67 model_registry,
68 }
69 }
70
71 async fn get_or_create_provider(&self, provider_id: ProviderId) -> Result<Arc<dyn Provider>> {
72 {
74 let map = self.provider_map.read().unwrap();
75 if let Some(provider) = map.get(&provider_id) {
76 return Ok(provider.clone());
77 }
78 }
79
80 let provider_config = self.provider_registry.get(&provider_id).ok_or_else(|| {
82 crate::error::Error::Api(ApiError::Configuration(format!(
83 "No provider configuration found for {provider_id:?}"
84 )))
85 })?;
86
87 let credential = match self
89 .config_provider
90 .get_auth_for_provider(&provider_id)
91 .await?
92 {
93 Some(ApiAuth::OAuth) => {
94 self.config_provider
96 .auth_storage()
97 .get_credential(&provider_id.storage_key(), CredentialType::OAuth2)
98 .await
99 .map_err(|e| {
100 crate::error::Error::Api(ApiError::Configuration(format!(
101 "Failed to get OAuth credential: {e}"
102 )))
103 })?
104 .ok_or_else(|| {
105 crate::error::Error::Api(ApiError::Configuration(
106 "OAuth credential not found in storage".to_string(),
107 ))
108 })?
109 }
110 Some(ApiAuth::Key(key)) => Credential::ApiKey { value: key },
111 None => {
112 return Err(crate::error::Error::Api(ApiError::Configuration(format!(
113 "No authentication configured for {provider_id:?}"
114 ))));
115 }
116 };
117
118 let mut map = self.provider_map.write().unwrap();
120
121 if let Some(provider) = map.get(&provider_id) {
123 return Ok(provider.clone());
124 }
125
126 let provider_instance = if matches!(&credential, Credential::OAuth2(_)) {
128 factory::create_provider_with_storage(
129 provider_config,
130 &credential,
131 self.config_provider.auth_storage().clone(),
132 )
133 .map_err(crate::error::Error::Api)?
134 } else {
135 factory::create_provider(provider_config, &credential)
136 .map_err(crate::error::Error::Api)?
137 };
138
139 map.insert(provider_id, provider_instance.clone());
140 Ok(provider_instance)
141 }
142
143 pub async fn complete(
145 &self,
146 model_id: &ModelId,
147 messages: Vec<Message>,
148 system: Option<String>,
149 tools: Option<Vec<ToolSchema>>,
150 call_options: Option<crate::config::model::ModelParameters>,
151 token: CancellationToken,
152 ) -> std::result::Result<CompletionResponse, ApiError> {
153 let provider_id = model_id.0.clone();
155 let provider = self
156 .get_or_create_provider(provider_id.clone())
157 .await
158 .map_err(ApiError::from)?;
159
160 if token.is_cancelled() {
161 return Err(ApiError::Cancelled {
162 provider: provider.name().to_string(),
163 });
164 }
165
166 let model_config = self.model_registry.get(model_id);
168 let effective_params = match (model_config, &call_options) {
169 (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
170 (Some(config), None) => config.effective_parameters(None),
171 (None, Some(opts)) => Some(*opts),
172 (None, None) => None,
173 };
174
175 debug!(
176 target: "api::complete",
177 ?model_id,
178 ?call_options,
179 ?effective_params,
180 "Final parameters for model"
181 );
182
183 let result = provider
184 .complete(model_id, messages, system, tools, effective_params, token)
185 .await;
186
187 if let Err(ref err) = result {
188 if Self::should_invalidate_provider(err) {
189 self.invalidate_provider(&provider_id);
190 }
191 }
192
193 result
194 }
195
196 pub async fn complete_with_retry(
197 &self,
198 model_id: &ModelId,
199 messages: &[Message],
200 system_prompt: &Option<String>,
201 tools: &Option<Vec<ToolSchema>>,
202 token: CancellationToken,
203 max_attempts: usize,
204 ) -> std::result::Result<CompletionResponse, ApiError> {
205 let mut attempts = 0;
206
207 let provider_id = model_id.0.clone();
209 let provider = self
210 .get_or_create_provider(provider_id.clone())
211 .await
212 .map_err(ApiError::from)?;
213
214 let model_config = self.model_registry.get(model_id);
215 debug!(
216 target: "api::complete_with_retry",
217 ?model_id,
218 ?model_config,
219 "Model config"
220 );
221 let effective_params = model_config.and_then(|cfg| cfg.effective_parameters(None));
222
223 debug!(
224 target: "api::complete_with_retry",
225 ?model_id,
226 ?effective_params,
227 "system: {:?}",
228 system_prompt
229 );
230 debug!(
231 target: "api::complete_with_retry",
232 ?model_id,
233 "messages: {:?}",
234 messages
235 );
236
237 loop {
238 if token.is_cancelled() {
239 return Err(ApiError::Cancelled {
240 provider: provider.name().to_string(),
241 });
242 }
243
244 match provider
245 .complete(
246 model_id,
247 messages.to_vec(),
248 system_prompt.clone(),
249 tools.clone(),
250 effective_params,
251 token.clone(),
252 )
253 .await
254 {
255 Ok(response) => {
256 return Ok(response);
257 }
258 Err(error) => {
259 attempts += 1;
260 warn!(
261 "API completion attempt {}/{} failed for model {:?}: {:?}",
262 attempts, max_attempts, model_id, error
263 );
264
265 if Self::should_invalidate_provider(&error) {
266 self.invalidate_provider(&provider_id);
267 return Err(error);
268 }
269
270 if attempts >= max_attempts {
271 return Err(error);
272 }
273
274 match error {
275 ApiError::RateLimited { provider, details } => {
276 let sleep_duration =
277 std::time::Duration::from_secs(1 << (attempts - 1));
278 warn!(
279 "Rate limited by API: {} {} (retrying in {} seconds)",
280 provider,
281 details,
282 sleep_duration.as_secs()
283 );
284 tokio::time::sleep(sleep_duration).await;
285 }
286 ApiError::NoChoices { provider } => {
287 warn!("No choices returned from API: {}", provider);
288 }
289 ApiError::ServerError {
290 provider,
291 status_code,
292 details,
293 } => {
294 warn!(
295 "Server error for API: {} {} {}",
296 provider, status_code, details
297 );
298 }
299 _ => {
300 return Err(error);
302 }
303 }
304 }
305 }
306 }
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider::ProviderId;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio_util::sync::CancellationToken;

    /// Error returned by every `StubProvider::complete` call.
    #[derive(Clone, Copy)]
    enum StubErrorKind {
        /// `ApiError::AuthenticationFailed`.
        Auth,
        /// `ApiError::ServerError` with status 401.
        Server401,
        /// `ApiError::ServerError` with status 500 (not an auth failure).
        Server500,
    }

    /// Provider that always fails with a fixed error and counts how often it is called.
    #[derive(Clone)]
    struct StubProvider {
        error_kind: StubErrorKind,
        calls: Arc<AtomicUsize>,
    }

    impl StubProvider {
        fn new(error_kind: StubErrorKind, calls: Arc<AtomicUsize>) -> Self {
            Self { error_kind, calls }
        }
    }

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<String>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            let err = match self.error_kind {
                StubErrorKind::Auth => ApiError::AuthenticationFailed {
                    provider: "stub".to_string(),
                    details: "bad key".to_string(),
                },
                StubErrorKind::Server401 => ApiError::ServerError {
                    provider: "stub".to_string(),
                    status_code: 401,
                    details: "unauthorized".to_string(),
                },
                StubErrorKind::Server500 => ApiError::ServerError {
                    provider: "stub".to_string(),
                    status_code: 500,
                    details: "internal error".to_string(),
                },
            };
            Err(err)
        }
    }

    /// Client with in-memory auth storage and empty provider/model registries.
    fn test_client() -> Client {
        let auth_storage = Arc::new(crate::test_utils::InMemoryAuthStorage::new());
        let config_provider = LlmConfigProvider::new(auth_storage);
        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));

        Client::new_with_deps(config_provider, provider_registry, model_registry)
    }

    /// Caches a `StubProvider` under `provider_id` and returns its call counter.
    fn insert_stub_provider(
        client: &Client,
        provider_id: ProviderId,
        error: StubErrorKind,
    ) -> Arc<AtomicUsize> {
        let calls = Arc::new(AtomicUsize::new(0));
        client
            .provider_map
            .write()
            .unwrap()
            .insert(provider_id, Arc::new(StubProvider::new(error, Arc::clone(&calls))));
        calls
    }

    /// Whether `client` currently caches a provider for `provider_id`.
    fn is_cached(client: &Client, provider_id: &ProviderId) -> bool {
        client
            .provider_map
            .read()
            .unwrap()
            .contains_key(provider_id)
    }

    #[tokio::test]
    async fn invalidates_cached_provider_on_auth_failure() {
        let client = test_client();
        let provider_id = ProviderId("stub-auth".to_string());
        let model_id = (provider_id.clone(), "stub-model".to_string());

        let calls = insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Auth);

        let err = client
            .complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .unwrap_err();

        assert!(matches!(err, ApiError::AuthenticationFailed { .. }));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(!is_cached(&client, &provider_id));
    }

    #[tokio::test]
    async fn invalidates_cached_provider_on_unauthorized_status_code() {
        let client = test_client();
        let provider_id = ProviderId("stub-unauthorized".to_string());
        let model_id = (provider_id.clone(), "stub-model".to_string());

        let calls = insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Server401);

        let err = client
            .complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            ApiError::ServerError {
                status_code: 401,
                ..
            }
        ));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(!is_cached(&client, &provider_id));
    }

    #[tokio::test]
    async fn keeps_cached_provider_on_non_auth_server_error() {
        let client = test_client();
        let provider_id = ProviderId("stub-server-error".to_string());
        let model_id = (provider_id.clone(), "stub-model".to_string());

        let calls = insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Server500);

        let err = client
            .complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            ApiError::ServerError {
                status_code: 500,
                ..
            }
        ));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(is_cached(&client, &provider_id));
    }

    #[tokio::test]
    async fn cancelled_token_skips_provider_call() {
        let client = test_client();
        let provider_id = ProviderId("stub-cancelled".to_string());
        let model_id = (provider_id.clone(), "stub-model".to_string());

        let calls = insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Auth);

        let token = CancellationToken::new();
        token.cancel();

        let err = client
            .complete(&model_id, vec![], None, None, None, token)
            .await
            .unwrap_err();

        assert!(matches!(err, ApiError::Cancelled { .. }));
        assert_eq!(calls.load(Ordering::SeqCst), 0);
        assert!(is_cached(&client, &provider_id));
    }

    #[tokio::test]
    async fn complete_with_retry_stops_and_invalidates_on_auth_failure() {
        let client = test_client();
        let provider_id = ProviderId("stub-retry-auth".to_string());
        let model_id = (provider_id.clone(), "stub-model".to_string());

        let calls = insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Auth);

        let err = client
            .complete_with_retry(&model_id, &[], &None, &None, CancellationToken::new(), 3)
            .await
            .unwrap_err();

        assert!(matches!(err, ApiError::AuthenticationFailed { .. }));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(!is_cached(&client, &provider_id));
    }

    #[tokio::test]
    async fn complete_with_retry_retries_server_errors_until_max_attempts() {
        let client = test_client();
        let provider_id = ProviderId("stub-retry-server".to_string());
        let model_id = (provider_id.clone(), "stub-model".to_string());

        let calls = insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Server500);

        let err = client
            .complete_with_retry(&model_id, &[], &None, &None, CancellationToken::new(), 3)
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            ApiError::ServerError {
                status_code: 500,
                ..
            }
        ));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
        assert!(is_cached(&client, &provider_id));
    }
}