1pub mod claude;
2pub mod error;
3pub mod factory;
4pub mod gemini;
5pub mod openai;
6pub mod provider;
7pub mod sse;
8pub mod util;
9pub mod xai;
10
11use crate::auth::storage::Credential;
12use crate::auth::{AuthSource, ProviderRegistry};
13use crate::config::model::ModelId;
14use crate::config::provider::ProviderId;
15use crate::config::{LlmConfigProvider, ResolvedAuth};
16use crate::error::Result;
17use crate::model_registry::ModelRegistry;
18pub use error::{ApiError, SseParseError, StreamError};
19pub use factory::{create_provider, create_provider_with_directive};
20use futures::StreamExt;
21pub use provider::{CompletionResponse, CompletionStream, Provider, StreamChunk};
22use std::collections::HashMap;
23use std::sync::Arc;
24use std::sync::RwLock;
25use steer_tools::ToolSchema;
26use tokio_util::sync::CancellationToken;
27use tracing::debug;
28use tracing::warn;
29
30use crate::app::SystemContext;
31use crate::app::conversation::Message;
32
/// Upper bound on the total number of connections (the initial one included)
/// that `Client::stream_complete` opens when a stream fails with an SSE
/// transport error before any non-error chunk has been yielded.
const STREAM_TRANSPORT_RETRY_MAX_ATTEMPTS: usize = 2;
34
35#[derive(Clone)]
36pub struct Client {
37 provider_map: Arc<RwLock<HashMap<ProviderId, ProviderEntry>>>,
38 config_provider: LlmConfigProvider,
39 provider_registry: Arc<ProviderRegistry>,
40 model_registry: Arc<ModelRegistry>,
41}
42
43#[derive(Clone)]
44struct ProviderEntry {
45 provider: Arc<dyn Provider>,
46 auth_source: AuthSource,
47}
48
49impl Client {
50 fn invalidate_provider(&self, provider_id: &ProviderId) {
52 let Ok(mut map) = self.provider_map.write() else {
53 warn!(
54 target: "api::client",
55 "Provider cache lock poisoned while invalidating provider"
56 );
57 return;
58 };
59 map.remove(provider_id);
60 }
61
62 fn should_invalidate_provider(error: &ApiError) -> bool {
64 matches!(
65 error,
66 ApiError::AuthenticationFailed { .. } | ApiError::AuthError(_)
67 ) || matches!(
68 error,
69 ApiError::ServerError { status_code, .. } if matches!(status_code, 401 | 403)
70 )
71 }
72
73 pub fn new_with_deps(
76 config_provider: LlmConfigProvider,
77 provider_registry: Arc<ProviderRegistry>,
78 model_registry: Arc<ModelRegistry>,
79 ) -> Self {
80 Self {
81 provider_map: Arc::new(RwLock::new(HashMap::new())),
82 config_provider,
83 provider_registry,
84 model_registry,
85 }
86 }
87
88 #[cfg(any(test, feature = "test-utils"))]
89 pub fn insert_test_provider(&self, provider_id: ProviderId, provider: Arc<dyn Provider>) {
90 match self.provider_map.write() {
91 Ok(mut map) => {
92 map.insert(
93 provider_id,
94 ProviderEntry {
95 provider,
96 auth_source: AuthSource::None,
97 },
98 );
99 }
100 Err(_) => {
101 warn!(
102 target: "api::client",
103 "Provider cache lock poisoned while inserting test provider"
104 );
105 }
106 }
107 }
108
109 async fn get_or_create_provider_entry(&self, provider_id: ProviderId) -> Result<ProviderEntry> {
110 {
112 let map = self.provider_map.read().map_err(|_| {
113 crate::error::Error::Api(ApiError::Configuration(
114 "Provider cache lock poisoned".to_string(),
115 ))
116 })?;
117 if let Some(entry) = map.get(&provider_id) {
118 return Ok(entry.clone());
119 }
120 }
121
122 let provider_config = self.provider_registry.get(&provider_id).ok_or_else(|| {
124 crate::error::Error::Api(ApiError::Configuration(format!(
125 "No provider configuration found for {provider_id:?}"
126 )))
127 })?;
128
129 let resolved = self
130 .config_provider
131 .resolve_auth_for_provider(&provider_id)
132 .await?;
133
134 let mut map = self.provider_map.write().map_err(|_| {
136 crate::error::Error::Api(ApiError::Configuration(
137 "Provider cache lock poisoned".to_string(),
138 ))
139 })?;
140
141 if let Some(entry) = map.get(&provider_id) {
143 return Ok(entry.clone());
144 }
145
146 let entry = Self::build_provider_entry(provider_config, &resolved)?;
147
148 map.insert(provider_id, entry.clone());
149 Ok(entry)
150 }
151
152 fn build_provider_entry(
153 provider_config: &crate::config::provider::ProviderConfig,
154 resolved: &ResolvedAuth,
155 ) -> std::result::Result<ProviderEntry, ApiError> {
156 let provider = match resolved {
157 ResolvedAuth::Plugin { directive, .. } => {
158 factory::create_provider_with_directive(provider_config, directive)?
159 }
160 ResolvedAuth::ApiKey { credential, .. } => {
161 factory::create_provider(provider_config, credential)?
162 }
163 ResolvedAuth::None => {
164 return Err(ApiError::Configuration(format!(
165 "No authentication configured for {:?}",
166 provider_config.id
167 )));
168 }
169 };
170
171 Ok(ProviderEntry {
172 provider,
173 auth_source: resolved.source(),
174 })
175 }
176
177 async fn fallback_api_key_entry(
178 &self,
179 provider_id: &ProviderId,
180 ) -> std::result::Result<Option<ProviderEntry>, ApiError> {
181 let Some((key, origin)) = self
182 .config_provider
183 .resolve_api_key_for_provider(provider_id)
184 .await?
185 else {
186 return Ok(None);
187 };
188
189 let provider_config = self.provider_registry.get(provider_id).ok_or_else(|| {
190 ApiError::Configuration(format!(
191 "No provider configuration found for {provider_id:?}"
192 ))
193 })?;
194
195 let credential = Credential::ApiKey { value: key };
196 let provider = factory::create_provider(provider_config, &credential)?;
197
198 Ok(Some(ProviderEntry {
199 provider,
200 auth_source: AuthSource::ApiKey { origin },
201 }))
202 }
203
204 pub async fn complete(
206 &self,
207 model_id: &ModelId,
208 messages: Vec<Message>,
209 system: Option<SystemContext>,
210 tools: Option<Vec<ToolSchema>>,
211 call_options: Option<crate::config::model::ModelParameters>,
212 token: CancellationToken,
213 ) -> std::result::Result<CompletionResponse, ApiError> {
214 let provider_id = model_id.provider.clone();
216 let entry = self
217 .get_or_create_provider_entry(provider_id.clone())
218 .await
219 .map_err(ApiError::from)?;
220 let provider = entry.provider.clone();
221
222 if token.is_cancelled() {
223 return Err(ApiError::Cancelled {
224 provider: provider.name().to_string(),
225 });
226 }
227
228 let model_config = self.model_registry.get(model_id);
230 let effective_params = match (model_config, &call_options) {
231 (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
232 (Some(config), None) => config.effective_parameters(None),
233 (None, Some(opts)) => Some(*opts),
234 (None, None) => None,
235 };
236
237 debug!(
238 target: "api::complete",
239 ?model_id,
240 ?call_options,
241 ?effective_params,
242 "Final parameters for model"
243 );
244
245 let result = provider
246 .complete(
247 model_id,
248 messages.clone(),
249 system.clone(),
250 tools.clone(),
251 effective_params,
252 token.clone(),
253 )
254 .await;
255
256 if let Err(ref err) = result
257 && Self::should_invalidate_provider(err)
258 {
259 self.invalidate_provider(&provider_id);
260
261 if matches!(entry.auth_source, AuthSource::Plugin { .. })
262 && let Some(fallback) = self.fallback_api_key_entry(&provider_id).await?
263 {
264 let fallback_result = fallback
265 .provider
266 .complete(model_id, messages, system, tools, effective_params, token)
267 .await;
268 if fallback_result.is_ok() {
269 let mut map = self.provider_map.write().map_err(|_| {
270 ApiError::Configuration("Provider cache lock poisoned".to_string())
271 })?;
272 map.insert(provider_id, fallback);
273 }
274 }
275 }
276
277 result
278 }
279
280 pub async fn stream_complete(
281 &self,
282 model_id: &ModelId,
283 messages: Vec<Message>,
284 system: Option<SystemContext>,
285 tools: Option<Vec<ToolSchema>>,
286 call_options: Option<crate::config::model::ModelParameters>,
287 token: CancellationToken,
288 ) -> std::result::Result<CompletionStream, ApiError> {
289 let provider_id = model_id.provider.clone();
290 let entry = self
291 .get_or_create_provider_entry(provider_id.clone())
292 .await
293 .map_err(ApiError::from)?;
294 let provider = entry.provider.clone();
295
296 if token.is_cancelled() {
297 return Err(ApiError::Cancelled {
298 provider: provider.name().to_string(),
299 });
300 }
301
302 let model_config = self.model_registry.get(model_id);
303 let effective_params = match (model_config, &call_options) {
304 (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
305 (Some(config), None) => config.effective_parameters(None),
306 (None, Some(opts)) => Some(*opts),
307 (None, None) => None,
308 };
309
310 debug!(
311 target: "api::stream_complete",
312 ?model_id,
313 ?call_options,
314 ?effective_params,
315 "Streaming with parameters"
316 );
317
318 let (initial_stream, provider_for_retry) = match provider
319 .stream_complete(
320 model_id,
321 messages.clone(),
322 system.clone(),
323 tools.clone(),
324 effective_params,
325 token.clone(),
326 )
327 .await
328 {
329 Ok(stream) => (stream, provider),
330 Err(err) => {
331 if Self::should_invalidate_provider(&err) {
332 self.invalidate_provider(&provider_id);
333
334 if matches!(entry.auth_source, AuthSource::Plugin { .. }) {
335 if let Some(fallback) = self.fallback_api_key_entry(&provider_id).await? {
336 let fallback_provider = fallback.provider.clone();
337 let fallback_stream = fallback_provider
338 .stream_complete(
339 model_id,
340 messages.clone(),
341 system.clone(),
342 tools.clone(),
343 effective_params,
344 token.clone(),
345 )
346 .await?;
347 let mut map = self.provider_map.write().map_err(|_| {
348 ApiError::Configuration("Provider cache lock poisoned".to_string())
349 })?;
350 map.insert(provider_id, fallback);
351 (fallback_stream, fallback_provider)
352 } else {
353 return Err(err);
354 }
355 } else {
356 return Err(err);
357 }
358 } else {
359 return Err(err);
360 }
361 }
362 };
363
364 let model_id = model_id.clone();
365 let stream = async_stream::stream! {
366 let mut attempt = 1usize;
367 let mut current_stream = Some(initial_stream);
368
369 'outer: loop {
370 let mut saw_output = false;
371 let mut stream = if let Some(stream) = current_stream.take() { stream } else {
372 if token.is_cancelled() {
373 yield StreamChunk::Error(StreamError::Cancelled);
374 break;
375 }
376
377 let stream_result = provider_for_retry
378 .stream_complete(
379 &model_id,
380 messages.clone(),
381 system.clone(),
382 tools.clone(),
383 effective_params,
384 token.clone(),
385 )
386 .await;
387 match stream_result {
388 Ok(stream) => stream,
389 Err(err) => {
390 yield StreamChunk::Error(StreamError::Provider {
391 provider: provider_for_retry.name().to_string(),
392 error_type: "stream_retry".to_string(),
393 message: err.to_string(),
394 });
395 break;
396 }
397 }
398 };
399
400 while let Some(chunk) = stream.next().await {
401 if matches!(
402 &chunk,
403 StreamChunk::Error(StreamError::SseParse(
404 SseParseError::Transport { .. }
405 )) if !saw_output && attempt < STREAM_TRANSPORT_RETRY_MAX_ATTEMPTS
406 ) {
407 attempt += 1;
408 warn!(
409 target: "api::stream_complete",
410 ?model_id,
411 attempt,
412 max_attempts = STREAM_TRANSPORT_RETRY_MAX_ATTEMPTS,
413 "Retrying stream after transport error before any output"
414 );
415 current_stream = None;
416 continue 'outer;
417 }
418
419 if !matches!(chunk, StreamChunk::Error(_)) {
420 saw_output = true;
421 }
422
423 yield chunk;
424 }
425
426 break;
427 }
428 };
429
430 Ok(Box::pin(stream))
431 }
432
433 pub async fn complete_with_retry(
434 &self,
435 model_id: &ModelId,
436 messages: &[Message],
437 system_prompt: &Option<SystemContext>,
438 tools: &Option<Vec<ToolSchema>>,
439 token: CancellationToken,
440 max_attempts: usize,
441 ) -> std::result::Result<CompletionResponse, ApiError> {
442 let mut attempts = 0;
443
444 let provider_id = model_id.provider.clone();
446 let entry = self
447 .get_or_create_provider_entry(provider_id.clone())
448 .await
449 .map_err(ApiError::from)?;
450 let provider = entry.provider.clone();
451
452 let model_config = self.model_registry.get(model_id);
453 debug!(
454 target: "api::complete_with_retry",
455 ?model_id,
456 ?model_config,
457 "Model config"
458 );
459 let effective_params = model_config.and_then(|cfg| cfg.effective_parameters(None));
460
461 debug!(
462 target: "api::complete_with_retry",
463 ?model_id,
464 ?effective_params,
465 "system: {:?}",
466 system_prompt
467 );
468 debug!(
469 target: "api::complete_with_retry",
470 ?model_id,
471 "messages: {:?}",
472 messages
473 );
474
475 loop {
476 if token.is_cancelled() {
477 return Err(ApiError::Cancelled {
478 provider: provider.name().to_string(),
479 });
480 }
481
482 match provider
483 .complete(
484 model_id,
485 messages.to_vec(),
486 system_prompt.clone(),
487 tools.clone(),
488 effective_params,
489 token.clone(),
490 )
491 .await
492 {
493 Ok(response) => {
494 return Ok(response);
495 }
496 Err(error) => {
497 attempts += 1;
498 warn!(
499 "API completion attempt {}/{} failed for model {:?}: {:?}",
500 attempts, max_attempts, model_id, error
501 );
502
503 if Self::should_invalidate_provider(&error) {
504 self.invalidate_provider(&provider_id);
505 if matches!(entry.auth_source, AuthSource::Plugin { .. })
506 && let Some(fallback) =
507 self.fallback_api_key_entry(&provider_id).await?
508 {
509 let fallback_result = fallback
510 .provider
511 .complete(
512 model_id,
513 messages.to_vec(),
514 system_prompt.clone(),
515 tools.clone(),
516 effective_params,
517 token.clone(),
518 )
519 .await;
520 if fallback_result.is_ok() {
521 let mut map = self.provider_map.write().map_err(|_| {
522 ApiError::Configuration(
523 "Provider cache lock poisoned".to_string(),
524 )
525 })?;
526 map.insert(provider_id.clone(), fallback);
527 }
528 }
529 return Err(error);
530 }
531
532 if attempts >= max_attempts {
533 return Err(error);
534 }
535
536 match error {
537 ApiError::RateLimited { provider, details } => {
538 let sleep_duration =
539 std::time::Duration::from_secs(1 << (attempts - 1));
540 warn!(
541 "Rate limited by API: {} {} (retrying in {} seconds)",
542 provider,
543 details,
544 sleep_duration.as_secs()
545 );
546 tokio::time::sleep(sleep_duration).await;
547 }
548 ApiError::NoChoices { provider } => {
549 warn!("No choices returned from API: {}", provider);
550 }
551 ApiError::ServerError {
552 provider,
553 status_code,
554 details,
555 } => {
556 warn!(
557 "Server error for API: {} {} {}",
558 provider, status_code, details
559 );
560 }
561 _ => {
562 return Err(error);
564 }
565 }
566 }
567 }
568 }
569 }
570}
571
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ApiKeyOrigin;
    use crate::config::provider::ProviderId;
    use async_trait::async_trait;
    use tokio_util::sync::CancellationToken;

    /// Failure mode the stub provider reports for every completion request.
    #[derive(Clone, Copy)]
    enum StubErrorKind {
        /// Fails with `ApiError::AuthenticationFailed`.
        Auth,
        /// Fails with `ApiError::ServerError` carrying HTTP status 401.
        Server401,
    }

    impl StubErrorKind {
        /// Builds the concrete error this failure mode stands for.
        fn api_error(self) -> ApiError {
            match self {
                Self::Auth => ApiError::AuthenticationFailed {
                    provider: "stub".to_string(),
                    details: "bad key".to_string(),
                },
                Self::Server401 => ApiError::ServerError {
                    provider: "stub".to_string(),
                    status_code: 401,
                    details: "unauthorized".to_string(),
                },
            }
        }
    }

    /// Provider double whose `complete` always fails with a fixed error.
    #[derive(Clone)]
    struct StubProvider {
        error_kind: StubErrorKind,
    }

    impl StubProvider {
        fn new(error_kind: StubErrorKind) -> Self {
            Self { error_kind }
        }
    }

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            Err(self.error_kind.api_error())
        }
    }

    /// Client over in-memory auth storage and empty provider/model registries.
    fn test_client() -> Client {
        let auth_storage = Arc::new(crate::test_utils::InMemoryAuthStorage::new());
        let config_provider = LlmConfigProvider::new(auth_storage).unwrap();
        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
        Client::new_with_deps(config_provider, provider_registry, model_registry)
    }

    /// Seeds the cache with a failing stub provider marked as API-key-authenticated,
    /// so no plugin fallback is attempted.
    fn insert_stub_provider(client: &Client, provider_id: ProviderId, error: StubErrorKind) {
        let entry = ProviderEntry {
            provider: Arc::new(StubProvider::new(error)),
            auth_source: AuthSource::ApiKey {
                origin: ApiKeyOrigin::Stored,
            },
        };
        client.provider_map.write().unwrap().insert(provider_id, entry);
    }

    /// Issues one `complete` call for a stub model under `provider_id` and
    /// returns the error it produced.
    async fn failing_completion(client: &Client, provider_id: &ProviderId) -> ApiError {
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        client
            .complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .unwrap_err()
    }

    /// Whether the client's cache currently holds an entry for `provider_id`.
    fn is_cached(client: &Client, provider_id: &ProviderId) -> bool {
        client
            .provider_map
            .read()
            .unwrap()
            .contains_key(provider_id)
    }

    #[tokio::test]
    async fn invalidates_cached_provider_on_auth_failure() {
        let client = test_client();
        let provider_id = ProviderId("stub-auth".to_string());
        insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Auth);

        let err = failing_completion(&client, &provider_id).await;

        assert!(matches!(err, ApiError::AuthenticationFailed { .. }));
        assert!(!is_cached(&client, &provider_id));
    }

    #[tokio::test]
    async fn invalidates_cached_provider_on_unauthorized_status_code() {
        let client = test_client();
        let provider_id = ProviderId("stub-unauthorized".to_string());
        insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Server401);

        let err = failing_completion(&client, &provider_id).await;

        assert!(matches!(
            err,
            ApiError::ServerError {
                status_code: 401,
                ..
            }
        ));
        assert!(!is_cached(&client, &provider_id));
    }
}