1pub mod claude;
2pub mod error;
3pub mod factory;
4pub mod gemini;
5pub mod openai;
6pub mod provider;
7pub mod sse;
8pub mod util;
9pub mod xai;
10
11use crate::auth::storage::Credential;
12use crate::auth::{AuthSource, ProviderRegistry};
13use crate::config::model::ModelId;
14use crate::config::provider::ProviderId;
15use crate::config::{LlmConfigProvider, ResolvedAuth};
16use crate::error::Result;
17use crate::model_registry::ModelRegistry;
18pub use error::{ApiError, SseParseError, StreamError};
19pub use factory::{create_provider, create_provider_with_directive};
20use futures::StreamExt;
21pub use provider::{CompletionResponse, CompletionStream, Provider, StreamChunk, TokenUsage};
22use std::collections::HashMap;
23use std::sync::Arc;
24use std::sync::RwLock;
25use steer_tools::ToolSchema;
26use tokio_util::sync::CancellationToken;
27use tracing::debug;
28use tracing::warn;
29
30use crate::app::SystemContext;
31use crate::app::conversation::Message;
32
/// Upper bound on the attempt counter used by `Client::stream_complete`.
///
/// A stream that reports a transport-level SSE error before producing any
/// output is reopened only while the current attempt number is below this
/// value. Attempts are counted from 1, so a value of 2 allows one retry.
const STREAM_TRANSPORT_RETRY_MAX_ATTEMPTS: usize = 2;
34
/// Entry point for issuing completion requests against configured providers.
///
/// Clones share the same provider cache because `provider_map` sits behind an
/// `Arc`.
#[derive(Clone)]
pub struct Client {
    /// Cache of already-constructed providers, keyed by provider id.
    provider_map: Arc<RwLock<HashMap<ProviderId, ProviderEntry>>>,
    /// Resolves the authentication to use for a provider.
    config_provider: LlmConfigProvider,
    /// Source of per-provider configuration used when constructing providers.
    provider_registry: Arc<ProviderRegistry>,
    /// Source of per-model configuration (context window, default parameters).
    model_registry: Arc<ModelRegistry>,
}
42
/// A cached provider together with the auth source it was constructed from.
#[derive(Clone)]
struct ProviderEntry {
    /// The provider implementation requests are sent through.
    provider: Arc<dyn Provider>,
    /// Where the provider's credentials came from; consulted to decide whether
    /// an API-key fallback may be attempted after an auth failure.
    auth_source: AuthSource,
}
48
49impl Client {
50 fn invalidate_provider(&self, provider_id: &ProviderId) {
52 let Ok(mut map) = self.provider_map.write() else {
53 warn!(
54 target: "api::client",
55 "Provider cache lock poisoned while invalidating provider"
56 );
57 return;
58 };
59 map.remove(provider_id);
60 }
61
62 fn should_invalidate_provider(error: &ApiError) -> bool {
64 matches!(
65 error,
66 ApiError::AuthenticationFailed { .. } | ApiError::AuthError(_)
67 ) || matches!(
68 error,
69 ApiError::ServerError { status_code, .. } if matches!(status_code, 401 | 403)
70 )
71 }
72
73 pub fn new_with_deps(
76 config_provider: LlmConfigProvider,
77 provider_registry: Arc<ProviderRegistry>,
78 model_registry: Arc<ModelRegistry>,
79 ) -> Self {
80 Self {
81 provider_map: Arc::new(RwLock::new(HashMap::new())),
82 config_provider,
83 provider_registry,
84 model_registry,
85 }
86 }
87
88 pub fn model_context_window_tokens(&self, model_id: &ModelId) -> Option<u32> {
89 self.model_registry
90 .get(model_id)
91 .and_then(|model| model.context_window_tokens)
92 }
93
94 #[cfg(any(test, feature = "test-utils"))]
95 pub fn insert_test_provider(&self, provider_id: ProviderId, provider: Arc<dyn Provider>) {
96 match self.provider_map.write() {
97 Ok(mut map) => {
98 map.insert(
99 provider_id,
100 ProviderEntry {
101 provider,
102 auth_source: AuthSource::None,
103 },
104 );
105 }
106 Err(_) => {
107 warn!(
108 target: "api::client",
109 "Provider cache lock poisoned while inserting test provider"
110 );
111 }
112 }
113 }
114
115 async fn get_or_create_provider_entry(&self, provider_id: ProviderId) -> Result<ProviderEntry> {
116 {
118 let map = self.provider_map.read().map_err(|_| {
119 crate::error::Error::Api(ApiError::Configuration(
120 "Provider cache lock poisoned".to_string(),
121 ))
122 })?;
123 if let Some(entry) = map.get(&provider_id) {
124 return Ok(entry.clone());
125 }
126 }
127
128 let provider_config = self.provider_registry.get(&provider_id).ok_or_else(|| {
130 crate::error::Error::Api(ApiError::Configuration(format!(
131 "No provider configuration found for {provider_id:?}"
132 )))
133 })?;
134
135 let resolved = self
136 .config_provider
137 .resolve_auth_for_provider(&provider_id)
138 .await?;
139
140 let mut map = self.provider_map.write().map_err(|_| {
142 crate::error::Error::Api(ApiError::Configuration(
143 "Provider cache lock poisoned".to_string(),
144 ))
145 })?;
146
147 if let Some(entry) = map.get(&provider_id) {
149 return Ok(entry.clone());
150 }
151
152 let entry = Self::build_provider_entry(provider_config, &resolved)?;
153
154 map.insert(provider_id, entry.clone());
155 Ok(entry)
156 }
157
158 fn build_provider_entry(
159 provider_config: &crate::config::provider::ProviderConfig,
160 resolved: &ResolvedAuth,
161 ) -> std::result::Result<ProviderEntry, ApiError> {
162 let provider = match resolved {
163 ResolvedAuth::Plugin { directive, .. } => {
164 factory::create_provider_with_directive(provider_config, directive)?
165 }
166 ResolvedAuth::ApiKey { credential, .. } => {
167 factory::create_provider(provider_config, credential)?
168 }
169 ResolvedAuth::None => {
170 return Err(ApiError::Configuration(format!(
171 "No authentication configured for {:?}",
172 provider_config.id
173 )));
174 }
175 };
176
177 Ok(ProviderEntry {
178 provider,
179 auth_source: resolved.source(),
180 })
181 }
182
183 async fn fallback_api_key_entry(
184 &self,
185 provider_id: &ProviderId,
186 ) -> std::result::Result<Option<ProviderEntry>, ApiError> {
187 let Some((key, origin)) = self
188 .config_provider
189 .resolve_api_key_for_provider(provider_id)
190 .await?
191 else {
192 return Ok(None);
193 };
194
195 let provider_config = self.provider_registry.get(provider_id).ok_or_else(|| {
196 ApiError::Configuration(format!(
197 "No provider configuration found for {provider_id:?}"
198 ))
199 })?;
200
201 let credential = Credential::ApiKey { value: key };
202 let provider = factory::create_provider(provider_config, &credential)?;
203
204 Ok(Some(ProviderEntry {
205 provider,
206 auth_source: AuthSource::ApiKey { origin },
207 }))
208 }
209
210 pub async fn complete(
212 &self,
213 model_id: &ModelId,
214 messages: Vec<Message>,
215 system: Option<SystemContext>,
216 tools: Option<Vec<ToolSchema>>,
217 call_options: Option<crate::config::model::ModelParameters>,
218 token: CancellationToken,
219 ) -> std::result::Result<CompletionResponse, ApiError> {
220 let provider_id = model_id.provider.clone();
222 let entry = self
223 .get_or_create_provider_entry(provider_id.clone())
224 .await
225 .map_err(ApiError::from)?;
226 let provider = entry.provider.clone();
227
228 if token.is_cancelled() {
229 return Err(ApiError::Cancelled {
230 provider: provider.name().to_string(),
231 });
232 }
233
234 let model_config = self.model_registry.get(model_id);
236 let effective_params = match (model_config, &call_options) {
237 (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
238 (Some(config), None) => config.effective_parameters(None),
239 (None, Some(opts)) => Some(*opts),
240 (None, None) => None,
241 };
242
243 debug!(
244 target: "api::complete",
245 ?model_id,
246 ?call_options,
247 ?effective_params,
248 "Final parameters for model"
249 );
250
251 let result = provider
252 .complete(
253 model_id,
254 messages.clone(),
255 system.clone(),
256 tools.clone(),
257 effective_params,
258 token.clone(),
259 )
260 .await;
261
262 if let Err(ref err) = result
263 && Self::should_invalidate_provider(err)
264 {
265 self.invalidate_provider(&provider_id);
266
267 if matches!(entry.auth_source, AuthSource::Plugin { .. })
268 && let Some(fallback) = self.fallback_api_key_entry(&provider_id).await?
269 {
270 let fallback_result = fallback
271 .provider
272 .complete(model_id, messages, system, tools, effective_params, token)
273 .await;
274 if fallback_result.is_ok() {
275 let mut map = self.provider_map.write().map_err(|_| {
276 ApiError::Configuration("Provider cache lock poisoned".to_string())
277 })?;
278 map.insert(provider_id, fallback);
279 }
280 }
281 }
282
283 result
284 }
285
286 pub async fn stream_complete(
287 &self,
288 model_id: &ModelId,
289 messages: Vec<Message>,
290 system: Option<SystemContext>,
291 tools: Option<Vec<ToolSchema>>,
292 call_options: Option<crate::config::model::ModelParameters>,
293 token: CancellationToken,
294 ) -> std::result::Result<CompletionStream, ApiError> {
295 let provider_id = model_id.provider.clone();
296 let entry = self
297 .get_or_create_provider_entry(provider_id.clone())
298 .await
299 .map_err(ApiError::from)?;
300 let provider = entry.provider.clone();
301
302 if token.is_cancelled() {
303 return Err(ApiError::Cancelled {
304 provider: provider.name().to_string(),
305 });
306 }
307
308 let model_config = self.model_registry.get(model_id);
309 let effective_params = match (model_config, &call_options) {
310 (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
311 (Some(config), None) => config.effective_parameters(None),
312 (None, Some(opts)) => Some(*opts),
313 (None, None) => None,
314 };
315
316 debug!(
317 target: "api::stream_complete",
318 ?model_id,
319 ?call_options,
320 ?effective_params,
321 "Streaming with parameters"
322 );
323
324 let (initial_stream, provider_for_retry) = match provider
325 .stream_complete(
326 model_id,
327 messages.clone(),
328 system.clone(),
329 tools.clone(),
330 effective_params,
331 token.clone(),
332 )
333 .await
334 {
335 Ok(stream) => (stream, provider),
336 Err(err) => {
337 if Self::should_invalidate_provider(&err) {
338 self.invalidate_provider(&provider_id);
339
340 if matches!(entry.auth_source, AuthSource::Plugin { .. }) {
341 if let Some(fallback) = self.fallback_api_key_entry(&provider_id).await? {
342 let fallback_provider = fallback.provider.clone();
343 let fallback_stream = fallback_provider
344 .stream_complete(
345 model_id,
346 messages.clone(),
347 system.clone(),
348 tools.clone(),
349 effective_params,
350 token.clone(),
351 )
352 .await?;
353 let mut map = self.provider_map.write().map_err(|_| {
354 ApiError::Configuration("Provider cache lock poisoned".to_string())
355 })?;
356 map.insert(provider_id, fallback);
357 (fallback_stream, fallback_provider)
358 } else {
359 return Err(err);
360 }
361 } else {
362 return Err(err);
363 }
364 } else {
365 return Err(err);
366 }
367 }
368 };
369
370 let model_id = model_id.clone();
371 let stream = async_stream::stream! {
372 let mut attempt = 1usize;
373 let mut current_stream = Some(initial_stream);
374
375 'outer: loop {
376 let mut saw_output = false;
377 let mut stream = if let Some(stream) = current_stream.take() { stream } else {
378 if token.is_cancelled() {
379 yield StreamChunk::Error(StreamError::Cancelled);
380 break;
381 }
382
383 let stream_result = provider_for_retry
384 .stream_complete(
385 &model_id,
386 messages.clone(),
387 system.clone(),
388 tools.clone(),
389 effective_params,
390 token.clone(),
391 )
392 .await;
393 match stream_result {
394 Ok(stream) => stream,
395 Err(err) => {
396 yield StreamChunk::Error(StreamError::Provider {
397 provider: provider_for_retry.name().to_string(),
398 error_type: "stream_retry".to_string(),
399 message: err.to_string(),
400 });
401 break;
402 }
403 }
404 };
405
406 while let Some(chunk) = stream.next().await {
407 if matches!(
408 &chunk,
409 StreamChunk::Error(StreamError::SseParse(
410 SseParseError::Transport { .. }
411 )) if !saw_output && attempt < STREAM_TRANSPORT_RETRY_MAX_ATTEMPTS
412 ) {
413 attempt += 1;
414 warn!(
415 target: "api::stream_complete",
416 ?model_id,
417 attempt,
418 max_attempts = STREAM_TRANSPORT_RETRY_MAX_ATTEMPTS,
419 "Retrying stream after transport error before any output"
420 );
421 current_stream = None;
422 continue 'outer;
423 }
424
425 if !matches!(chunk, StreamChunk::Error(_)) {
426 saw_output = true;
427 }
428
429 yield chunk;
430 }
431
432 break;
433 }
434 };
435
436 Ok(Box::pin(stream))
437 }
438
439 pub async fn complete_with_retry(
440 &self,
441 model_id: &ModelId,
442 messages: &[Message],
443 system_prompt: &Option<SystemContext>,
444 tools: &Option<Vec<ToolSchema>>,
445 token: CancellationToken,
446 max_attempts: usize,
447 ) -> std::result::Result<CompletionResponse, ApiError> {
448 let mut attempts = 0;
449
450 let provider_id = model_id.provider.clone();
452 let entry = self
453 .get_or_create_provider_entry(provider_id.clone())
454 .await
455 .map_err(ApiError::from)?;
456 let provider = entry.provider.clone();
457
458 let model_config = self.model_registry.get(model_id);
459 debug!(
460 target: "api::complete_with_retry",
461 ?model_id,
462 ?model_config,
463 "Model config"
464 );
465 let effective_params = model_config.and_then(|cfg| cfg.effective_parameters(None));
466
467 debug!(
468 target: "api::complete_with_retry",
469 ?model_id,
470 ?effective_params,
471 "system: {:?}",
472 system_prompt
473 );
474 debug!(
475 target: "api::complete_with_retry",
476 ?model_id,
477 "messages: {:?}",
478 messages
479 );
480
481 loop {
482 if token.is_cancelled() {
483 return Err(ApiError::Cancelled {
484 provider: provider.name().to_string(),
485 });
486 }
487
488 match provider
489 .complete(
490 model_id,
491 messages.to_vec(),
492 system_prompt.clone(),
493 tools.clone(),
494 effective_params,
495 token.clone(),
496 )
497 .await
498 {
499 Ok(response) => {
500 return Ok(response);
501 }
502 Err(error) => {
503 attempts += 1;
504 warn!(
505 "API completion attempt {}/{} failed for model {:?}: {:?}",
506 attempts, max_attempts, model_id, error
507 );
508
509 if Self::should_invalidate_provider(&error) {
510 self.invalidate_provider(&provider_id);
511 if matches!(entry.auth_source, AuthSource::Plugin { .. })
512 && let Some(fallback) =
513 self.fallback_api_key_entry(&provider_id).await?
514 {
515 let fallback_result = fallback
516 .provider
517 .complete(
518 model_id,
519 messages.to_vec(),
520 system_prompt.clone(),
521 tools.clone(),
522 effective_params,
523 token.clone(),
524 )
525 .await;
526 if fallback_result.is_ok() {
527 let mut map = self.provider_map.write().map_err(|_| {
528 ApiError::Configuration(
529 "Provider cache lock poisoned".to_string(),
530 )
531 })?;
532 map.insert(provider_id.clone(), fallback);
533 }
534 }
535 return Err(error);
536 }
537
538 if attempts >= max_attempts {
539 return Err(error);
540 }
541
542 match error {
543 ApiError::RateLimited { provider, details } => {
544 let sleep_duration =
545 std::time::Duration::from_secs(1 << (attempts - 1));
546 warn!(
547 "Rate limited by API: {} {} (retrying in {} seconds)",
548 provider,
549 details,
550 sleep_duration.as_secs()
551 );
552 tokio::time::sleep(sleep_duration).await;
553 }
554 ApiError::NoChoices { provider } => {
555 warn!("No choices returned from API: {}", provider);
556 }
557 ApiError::ServerError {
558 provider,
559 status_code,
560 details,
561 } => {
562 warn!(
563 "Server error for API: {} {} {}",
564 provider, status_code, details
565 );
566 }
567 _ => {
568 return Err(error);
570 }
571 }
572 }
573 }
574 }
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ApiKeyOrigin;
    use crate::config::provider::ProviderId;
    use async_trait::async_trait;
    use tokio_util::sync::CancellationToken;

    /// Selects which credential-related failure the stub provider reports.
    #[derive(Clone, Copy)]
    enum StubErrorKind {
        Auth,
        Server401,
    }

    /// Provider double whose `complete` always fails with the configured error.
    #[derive(Clone)]
    struct StubProvider {
        error_kind: StubErrorKind,
    }

    impl StubProvider {
        fn new(error_kind: StubErrorKind) -> Self {
            StubProvider { error_kind }
        }
    }

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            let provider = "stub".to_string();
            Err(match self.error_kind {
                StubErrorKind::Server401 => ApiError::ServerError {
                    provider,
                    status_code: 401,
                    details: "unauthorized".to_string(),
                },
                StubErrorKind::Auth => ApiError::AuthenticationFailed {
                    provider,
                    details: "bad key".to_string(),
                },
            })
        }
    }

    /// Builds a client backed by in-memory auth storage and empty registries.
    fn test_client() -> Client {
        let storage = Arc::new(crate::test_utils::InMemoryAuthStorage::new());
        let config_provider = LlmConfigProvider::new(storage).unwrap();
        let providers = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let models = Arc::new(ModelRegistry::load(&[]).expect("model registry"));

        Client::new_with_deps(config_provider, providers, models)
    }

    /// Caches a failing stub under `provider_id`, marked as API-key authenticated.
    fn insert_stub_provider(client: &Client, provider_id: ProviderId, error: StubErrorKind) {
        let entry = ProviderEntry {
            provider: Arc::new(StubProvider::new(error)),
            auth_source: AuthSource::ApiKey {
                origin: ApiKeyOrigin::Stored,
            },
        };
        client
            .provider_map
            .write()
            .unwrap()
            .insert(provider_id, entry);
    }

    /// Runs `Client::complete` with empty inputs and returns the error it yields.
    async fn failing_completion(client: &Client, model_id: &ModelId) -> ApiError {
        client
            .complete(model_id, vec![], None, None, None, CancellationToken::new())
            .await
            .unwrap_err()
    }

    /// Reports whether the client's cache currently holds `provider_id`.
    fn is_cached(client: &Client, provider_id: &ProviderId) -> bool {
        client
            .provider_map
            .read()
            .unwrap()
            .contains_key(provider_id)
    }

    #[tokio::test]
    async fn invalidates_cached_provider_on_auth_failure() {
        let provider_id = ProviderId("stub-auth".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        let client = test_client();
        insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Auth);

        let err = failing_completion(&client, &model_id).await;

        assert!(matches!(err, ApiError::AuthenticationFailed { .. }));
        assert!(!is_cached(&client, &provider_id));
    }

    #[tokio::test]
    async fn invalidates_cached_provider_on_unauthorized_status_code() {
        let provider_id = ProviderId("stub-unauthorized".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        let client = test_client();
        insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Server401);

        let err = failing_completion(&client, &model_id).await;

        assert!(matches!(
            err,
            ApiError::ServerError {
                status_code: 401,
                ..
            }
        ));
        assert!(!is_cached(&client, &provider_id));
    }
}