1pub mod claude;
2pub mod error;
3pub mod factory;
4pub mod gemini;
5pub mod openai;
6pub mod provider;
7pub mod sse;
8pub mod util;
9pub mod xai;
10
11use crate::auth::storage::Credential;
12use crate::auth::{AuthSource, ProviderRegistry};
13use crate::config::model::{ModelId, ModelParameters};
14use crate::config::provider::ProviderId;
15use crate::config::{LlmConfigProvider, ResolvedAuth};
16use crate::error::Result;
17use crate::model_registry::ModelRegistry;
18pub use error::{ApiError, ProviderStreamErrorKind, SseParseError, StreamError};
19pub use factory::{create_provider, create_provider_with_directive};
20use futures::StreamExt;
21pub use provider::{CompletionResponse, CompletionStream, Provider, StreamChunk, TokenUsage};
22use rand::Rng;
23use std::collections::HashMap;
24use std::sync::Arc;
25use std::sync::RwLock;
26use std::time::Duration;
27use steer_tools::ToolSchema;
28use tokio_util::sync::CancellationToken;
29use tracing::debug;
30use tracing::warn;
31
32use crate::app::SystemContext;
33use crate::app::conversation::Message;
34
// Base delay for exponential backoff between retry attempts. Tests use a
// 1 ms base so retry paths stay fast under `cargo test`.
#[cfg(not(test))]
const RETRY_BASE_DELAY_MS: u64 = 250;
#[cfg(test)]
const RETRY_BASE_DELAY_MS: u64 = 1;
// Maximum number of total attempts (initial call + retries) for transient failures.
const RETRY_MAX_ATTEMPTS: usize = 5;
40
/// Shared LLM API client that caches one provider instance per `ProviderId`
/// and layers retry and auth-invalidation logic on top of the raw providers.
#[derive(Clone)]
pub struct Client {
    // Cache of constructed providers keyed by provider id; entries are
    // removed when an auth-related failure invalidates them.
    provider_map: Arc<RwLock<HashMap<ProviderId, ProviderEntry>>>,
    // Resolves credentials (API keys / plugin directives) for providers.
    config_provider: LlmConfigProvider,
    provider_registry: Arc<ProviderRegistry>,
    model_registry: Arc<ModelRegistry>,
}
48
/// A cached provider instance together with the auth source it was built
/// from, so fallback logic can tell plugin-based auth from stored API keys.
#[derive(Clone)]
struct ProviderEntry {
    provider: Arc<dyn Provider>,
    auth_source: AuthSource,
}
54
55impl Client {
56 fn invalidate_provider(&self, provider_id: &ProviderId) {
58 let Ok(mut map) = self.provider_map.write() else {
59 warn!(
60 target: "api::client",
61 "Provider cache lock poisoned while invalidating provider"
62 );
63 return;
64 };
65 map.remove(provider_id);
66 }
67
68 fn should_invalidate_provider(error: &ApiError) -> bool {
70 matches!(
71 error,
72 ApiError::AuthenticationFailed { .. } | ApiError::AuthError(_)
73 ) || matches!(
74 error,
75 ApiError::ServerError { status_code, .. } if matches!(status_code, 401 | 403)
76 )
77 }
78
79 fn should_retry_error(error: &ApiError) -> bool {
81 match error {
82 ApiError::Network(_) => true,
83 ApiError::Timeout { .. } => true,
84 ApiError::RateLimited { .. } => true,
85 ApiError::ServerError { status_code, .. } => {
86 matches!(status_code, 408 | 409 | 429 | 500 | 502 | 503 | 504)
87 }
88 _ => false,
89 }
90 }
91
92 fn retry_delay(attempt: usize) -> Duration {
93 let base_ms = RETRY_BASE_DELAY_MS * (1u64 << attempt.min(4));
94 let jitter_percent = rand::thread_rng().gen_range(80_u64..=120_u64);
95 let jittered_ms = base_ms
96 .saturating_mul(jitter_percent)
97 .saturating_div(100)
98 .max(1);
99 Duration::from_millis(jittered_ms)
100 }
101
102 fn should_retry_stream_error(error: &StreamError) -> bool {
103 match error {
104 StreamError::SseParse(SseParseError::Transport { .. }) => true,
105 StreamError::Provider { kind, .. } => kind.is_retryable(),
106 StreamError::Cancelled | StreamError::SseParse(_) => false,
107 }
108 }
109
    /// Call `provider.complete`, retrying transient failures.
    ///
    /// Makes up to `max_attempts` attempts; only errors that
    /// `should_retry_error` classifies as transient are retried, with an
    /// exponentially backed-off, jittered sleep between attempts. The token
    /// is checked before each attempt so cancellation wins over retrying.
    #[expect(
        clippy::too_many_arguments,
        reason = "Retry helper mirrors provider API inputs plus retry controls"
    )]
    async fn run_complete_with_retry(
        provider: &Arc<dyn Provider>,
        model_id: &ModelId,
        messages: &[Message],
        system: &Option<SystemContext>,
        tools: &Option<Vec<ToolSchema>>,
        call_options: Option<ModelParameters>,
        token: &CancellationToken,
        max_attempts: usize,
    ) -> std::result::Result<CompletionResponse, ApiError> {
        let mut attempt = 0usize;

        loop {
            // Bail out early if the caller cancelled between attempts.
            if token.is_cancelled() {
                return Err(ApiError::Cancelled {
                    provider: provider.name().to_string(),
                });
            }

            match provider
                .complete(
                    model_id,
                    messages.to_vec(),
                    system.clone(),
                    tools.clone(),
                    call_options,
                    token.clone(),
                )
                .await
            {
                Ok(response) => return Ok(response),
                // Transient failure with attempts remaining: back off, retry.
                Err(error)
                    if Self::should_retry_error(&error)
                        && attempt + 1 < max_attempts
                        && !token.is_cancelled() =>
                {
                    attempt += 1;
                    // `attempt - 1` keeps the first retry at the base delay.
                    let delay = Self::retry_delay(attempt - 1);
                    warn!(
                        target: "api::complete",
                        provider = provider.name(),
                        ?model_id,
                        attempt,
                        max_attempts,
                        ?delay,
                        error = %error,
                        "Retrying API completion after transient error"
                    );
                    tokio::time::sleep(delay).await;
                }
                // Non-retryable (or out of attempts): propagate.
                Err(error) => return Err(error),
            }
        }
    }
168
    /// Call `provider.stream_complete`, retrying transient start-up failures.
    ///
    /// Mirrors `run_complete_with_retry`: up to `max_attempts` attempts, only
    /// transient errors retried, jittered backoff between attempts, and a
    /// cancellation check before each attempt. Errors yielded *inside* an
    /// established stream are handled by the caller, not here.
    #[expect(
        clippy::too_many_arguments,
        reason = "Retry helper mirrors provider stream API inputs plus retry controls"
    )]
    async fn run_stream_start_with_retry(
        provider: &Arc<dyn Provider>,
        model_id: &ModelId,
        messages: &[Message],
        system: &Option<SystemContext>,
        tools: &Option<Vec<ToolSchema>>,
        call_options: Option<ModelParameters>,
        token: &CancellationToken,
        max_attempts: usize,
    ) -> std::result::Result<CompletionStream, ApiError> {
        let mut attempt = 0usize;

        loop {
            // Bail out early if the caller cancelled between attempts.
            if token.is_cancelled() {
                return Err(ApiError::Cancelled {
                    provider: provider.name().to_string(),
                });
            }

            match provider
                .stream_complete(
                    model_id,
                    messages.to_vec(),
                    system.clone(),
                    tools.clone(),
                    call_options,
                    token.clone(),
                )
                .await
            {
                Ok(stream) => return Ok(stream),
                // Transient failure with attempts remaining: back off, retry.
                Err(error)
                    if Self::should_retry_error(&error)
                        && attempt + 1 < max_attempts
                        && !token.is_cancelled() =>
                {
                    attempt += 1;
                    // `attempt - 1` keeps the first retry at the base delay.
                    let delay = Self::retry_delay(attempt - 1);
                    warn!(
                        target: "api::stream_complete",
                        provider = provider.name(),
                        ?model_id,
                        attempt,
                        max_attempts,
                        ?delay,
                        error = %error,
                        "Retrying API stream initialization after transient error"
                    );
                    tokio::time::sleep(delay).await;
                }
                // Non-retryable (or out of attempts): propagate.
                Err(error) => return Err(error),
            }
        }
    }
227
    /// Drain `stream` until a `MessageComplete` chunk arrives, discarding
    /// incremental deltas and converting stream errors into `ApiError`s.
    ///
    /// A stream that ends without ever producing `MessageComplete` is
    /// reported as a `StreamError`.
    async fn collect_completion_from_stream(
        mut stream: CompletionStream,
        provider: String,
    ) -> std::result::Result<CompletionResponse, ApiError> {
        while let Some(chunk) = stream.next().await {
            match chunk {
                // Terminal success: the fully assembled response.
                StreamChunk::MessageComplete(response) => return Ok(response),
                // Cancellation maps to the dedicated cancelled error.
                StreamChunk::Error(StreamError::Cancelled) => {
                    return Err(ApiError::Cancelled {
                        provider: provider.clone(),
                    });
                }
                // Any other stream error is surfaced with its display text.
                StreamChunk::Error(error) => {
                    return Err(ApiError::StreamError {
                        provider: provider.clone(),
                        details: error.to_string(),
                    });
                }
                // Incremental chunks carry no final payload; skip them.
                StreamChunk::TextDelta(_)
                | StreamChunk::ThinkingDelta(_)
                | StreamChunk::ToolUseStart { .. }
                | StreamChunk::ToolUseInputDelta { .. }
                | StreamChunk::ContentBlockStop { .. }
                | StreamChunk::Reset => {}
            }
        }

        Err(ApiError::StreamError {
            provider,
            details: "Stream ended without completion response".to_string(),
        })
    }
260
261 pub fn new_with_deps(
264 config_provider: LlmConfigProvider,
265 provider_registry: Arc<ProviderRegistry>,
266 model_registry: Arc<ModelRegistry>,
267 ) -> Self {
268 Self {
269 provider_map: Arc::new(RwLock::new(HashMap::new())),
270 config_provider,
271 provider_registry,
272 model_registry,
273 }
274 }
275
276 pub fn model_context_window_tokens(&self, model_id: &ModelId) -> Option<u32> {
277 self.model_registry
278 .get(model_id)
279 .and_then(|model| model.context_window_tokens)
280 }
281
282 pub fn model_max_output_tokens(&self, model_id: &ModelId) -> Option<u32> {
283 self.model_registry
284 .get(model_id)
285 .and_then(|model| model.parameters)
286 .and_then(|parameters| parameters.max_output_tokens)
287 }
288
289 #[cfg(any(test, feature = "test-utils"))]
290 pub fn insert_test_provider(&self, provider_id: ProviderId, provider: Arc<dyn Provider>) {
291 match self.provider_map.write() {
292 Ok(mut map) => {
293 map.insert(
294 provider_id,
295 ProviderEntry {
296 provider,
297 auth_source: AuthSource::None,
298 },
299 );
300 }
301 Err(_) => {
302 warn!(
303 target: "api::client",
304 "Provider cache lock poisoned while inserting test provider"
305 );
306 }
307 }
308 }
309
    /// Fetch the cached provider for `provider_id`, building (and caching) it
    /// from registry config plus resolved auth on a cache miss.
    ///
    /// Uses a read-then-write double check so concurrent callers racing on a
    /// miss don't build duplicate providers; auth resolution happens outside
    /// any lock since it is async.
    async fn get_or_create_provider_entry(&self, provider_id: ProviderId) -> Result<ProviderEntry> {
        {
            // Fast path under a shared read lock, scoped so the guard drops
            // before the async auth resolution below.
            let map = self.provider_map.read().map_err(|_| {
                crate::error::Error::Api(ApiError::Configuration(
                    "Provider cache lock poisoned".to_string(),
                ))
            })?;
            if let Some(entry) = map.get(&provider_id) {
                return Ok(entry.clone());
            }
        }

        let provider_config = self.provider_registry.get(&provider_id).ok_or_else(|| {
            crate::error::Error::Api(ApiError::Configuration(format!(
                "No provider configuration found for {provider_id:?}"
            )))
        })?;

        let resolved = self
            .config_provider
            .resolve_auth_for_provider(&provider_id)
            .await?;

        let mut map = self.provider_map.write().map_err(|_| {
            crate::error::Error::Api(ApiError::Configuration(
                "Provider cache lock poisoned".to_string(),
            ))
        })?;

        // Re-check under the write lock: another task may have populated the
        // entry while auth was being resolved.
        if let Some(entry) = map.get(&provider_id) {
            return Ok(entry.clone());
        }

        let entry = Self::build_provider_entry(provider_config, &resolved)?;

        map.insert(provider_id, entry.clone());
        Ok(entry)
    }
352
353 fn build_provider_entry(
354 provider_config: &crate::config::provider::ProviderConfig,
355 resolved: &ResolvedAuth,
356 ) -> std::result::Result<ProviderEntry, ApiError> {
357 let provider = match resolved {
358 ResolvedAuth::Plugin { directive, .. } => {
359 factory::create_provider_with_directive(provider_config, directive)?
360 }
361 ResolvedAuth::ApiKey { credential, .. } => {
362 factory::create_provider(provider_config, credential)?
363 }
364 ResolvedAuth::None => {
365 return Err(ApiError::Configuration(format!(
366 "No authentication configured for {:?}",
367 provider_config.id
368 )));
369 }
370 };
371
372 Ok(ProviderEntry {
373 provider,
374 auth_source: resolved.source(),
375 })
376 }
377
378 async fn fallback_api_key_entry(
379 &self,
380 provider_id: &ProviderId,
381 ) -> std::result::Result<Option<ProviderEntry>, ApiError> {
382 let Some((key, origin)) = self
383 .config_provider
384 .resolve_api_key_for_provider(provider_id)
385 .await?
386 else {
387 return Ok(None);
388 };
389
390 let provider_config = self.provider_registry.get(provider_id).ok_or_else(|| {
391 ApiError::Configuration(format!(
392 "No provider configuration found for {provider_id:?}"
393 ))
394 })?;
395
396 let credential = Credential::ApiKey { value: key };
397 let provider = factory::create_provider(provider_config, &credential)?;
398
399 Ok(Some(ProviderEntry {
400 provider,
401 auth_source: AuthSource::ApiKey { origin },
402 }))
403 }
404
405 pub async fn complete(
407 &self,
408 model_id: &ModelId,
409 messages: Vec<Message>,
410 system: Option<SystemContext>,
411 tools: Option<Vec<ToolSchema>>,
412 call_options: Option<crate::config::model::ModelParameters>,
413 token: CancellationToken,
414 ) -> std::result::Result<CompletionResponse, ApiError> {
415 debug!(
416 target: "api::complete",
417 ?model_id,
418 "Completing by consuming the streamed endpoint"
419 );
420
421 let stream = self
422 .stream_complete(model_id, messages, system, tools, call_options, token)
423 .await?;
424
425 Self::collect_completion_from_stream(stream, model_id.provider.to_string()).await
426 }
427
    /// Start a streaming completion for `model_id`.
    ///
    /// Resolves (or builds) the provider, merges registry model parameters
    /// with per-call overrides, and retries both stream start-up and
    /// retryable mid-stream failures up to `RETRY_MAX_ATTEMPTS`. On
    /// auth-style failures the cached provider is invalidated and, when the
    /// entry came from plugin auth, a stored API key is tried as a fallback.
    pub async fn stream_complete(
        &self,
        model_id: &ModelId,
        messages: Vec<Message>,
        system: Option<SystemContext>,
        tools: Option<Vec<ToolSchema>>,
        call_options: Option<crate::config::model::ModelParameters>,
        token: CancellationToken,
    ) -> std::result::Result<CompletionStream, ApiError> {
        let provider_id = model_id.provider.clone();
        let entry = self
            .get_or_create_provider_entry(provider_id.clone())
            .await
            .map_err(ApiError::from)?;
        let provider = entry.provider.clone();

        if token.is_cancelled() {
            return Err(ApiError::Cancelled {
                provider: provider.name().to_string(),
            });
        }

        // Registry defaults layered under any per-call overrides.
        let model_config = self.model_registry.get(model_id);
        let effective_params = match (model_config, &call_options) {
            (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
            (Some(config), None) => config.effective_parameters(None),
            (None, Some(opts)) => Some(*opts),
            (None, None) => None,
        };

        debug!(
            target: "api::stream_complete",
            ?model_id,
            ?call_options,
            ?effective_params,
            "Streaming with parameters"
        );

        // First stream start with retry; on an auth-style failure, optionally
        // swap in an API-key fallback provider and cache it on success.
        let (initial_stream, provider_for_retry) = match Self::run_stream_start_with_retry(
            &provider,
            model_id,
            &messages,
            &system,
            &tools,
            effective_params,
            &token,
            RETRY_MAX_ATTEMPTS,
        )
        .await
        {
            Ok(stream) => (stream, provider),
            Err(err) => {
                if Self::should_invalidate_provider(&err) {
                    self.invalidate_provider(&provider_id);

                    // Only plugin-backed auth falls back to a stored API key.
                    if matches!(entry.auth_source, AuthSource::Plugin { .. }) {
                        if let Some(fallback) = self.fallback_api_key_entry(&provider_id).await? {
                            let fallback_provider = fallback.provider.clone();
                            let fallback_stream = Self::run_stream_start_with_retry(
                                &fallback_provider,
                                model_id,
                                &messages,
                                &system,
                                &tools,
                                effective_params,
                                &token,
                                RETRY_MAX_ATTEMPTS,
                            )
                            .await?;
                            // Cache the working fallback for future calls.
                            let mut map = self.provider_map.write().map_err(|_| {
                                ApiError::Configuration("Provider cache lock poisoned".to_string())
                            })?;
                            map.insert(provider_id, fallback);
                            (fallback_stream, fallback_provider)
                        } else {
                            return Err(err);
                        }
                    } else {
                        return Err(err);
                    }
                } else {
                    return Err(err);
                }
            }
        };

        let model_id = model_id.clone();
        // Wrap the provider stream so retryable mid-stream errors restart the
        // stream transparently (emitting `Reset` if partial output was seen).
        let stream = async_stream::stream! {
            let mut attempt = 1usize;
            let mut current_stream = Some(initial_stream);

            'outer: loop {
                // Tracks whether this attempt yielded visible output, so a
                // retry knows to tell consumers to discard partial state.
                let mut saw_output = false;
                let mut stream = if let Some(stream) = current_stream.take() { stream } else {
                    if token.is_cancelled() {
                        yield StreamChunk::Error(StreamError::Cancelled);
                        break;
                    }

                    // Re-establish the stream after a mid-stream failure.
                    let stream_result = Self::run_stream_start_with_retry(
                        &provider_for_retry,
                        &model_id,
                        &messages,
                        &system,
                        &tools,
                        effective_params,
                        &token,
                        RETRY_MAX_ATTEMPTS,
                    )
                    .await;
                    match stream_result {
                        Ok(stream) => stream,
                        Err(err) => {
                            // Restart failed: surface a terminal retry error.
                            yield StreamChunk::Error(StreamError::Provider {
                                provider: provider_for_retry.name().to_string(),
                                kind: ProviderStreamErrorKind::StreamRetry,
                                raw_error_type: Some("stream_retry".to_string()),
                                message: err.to_string(),
                            });
                            break;
                        }
                    }
                };

                while let Some(chunk) = stream.next().await {
                    // Classify error chunks: cancellation and parse errors
                    // are terminal; transport/provider errors may retry.
                    let retryable_stream_error = match &chunk {
                        StreamChunk::Error(stream_err) => match stream_err {
                            StreamError::Cancelled => false,
                            StreamError::SseParse(
                                SseParseError::Parser { .. } | SseParseError::Utf8 { .. },
                            ) => false,
                            _ => Self::should_retry_stream_error(stream_err),
                        },
                        _ => false,
                    };

                    if retryable_stream_error && attempt < RETRY_MAX_ATTEMPTS {
                        attempt += 1;
                        warn!(
                            target: "api::stream_complete",
                            ?model_id,
                            attempt,
                            max_attempts = RETRY_MAX_ATTEMPTS,
                            error = ?chunk,
                            "Retrying stream after transport/provider stream failure"
                        );
                        if saw_output {
                            // Consumers saw partial output; signal a restart.
                            yield StreamChunk::Reset;
                        }
                        current_stream = None;
                        continue 'outer;
                    }

                    if !matches!(chunk, StreamChunk::Error(_)) {
                        saw_output = true;
                    }

                    yield chunk;
                }

                break;
            }
        };

        Ok(Box::pin(stream))
    }
594
    /// Non-streaming completion with caller-controlled retry budget.
    ///
    /// Unlike [`Client::complete`], this calls the provider's `complete`
    /// directly (no streaming) and honors the supplied `max_attempts`. On
    /// auth-style failures the cached provider is invalidated and, for
    /// plugin-backed auth, a stored API key is tried as a fallback; the
    /// fallback entry is cached only if its request succeeds.
    pub async fn complete_with_retry(
        &self,
        model_id: &ModelId,
        messages: &[Message],
        system_prompt: &Option<SystemContext>,
        tools: &Option<Vec<ToolSchema>>,
        token: CancellationToken,
        max_attempts: usize,
    ) -> std::result::Result<CompletionResponse, ApiError> {
        let provider_id = model_id.provider.clone();
        let entry = self
            .get_or_create_provider_entry(provider_id.clone())
            .await
            .map_err(ApiError::from)?;

        let model_config = self.model_registry.get(model_id);
        debug!(
            target: "api::complete_with_retry",
            ?model_id,
            ?model_config,
            "Model config"
        );
        // Registry defaults only; this path takes no per-call overrides.
        let effective_params = model_config.and_then(|cfg| cfg.effective_parameters(None));

        debug!(
            target: "api::complete_with_retry",
            ?model_id,
            ?effective_params,
            "system: {:?}",
            system_prompt
        );
        debug!(
            target: "api::complete_with_retry",
            ?model_id,
            "messages: {:?}",
            messages
        );

        let result = Self::run_complete_with_retry(
            &entry.provider,
            model_id,
            messages,
            system_prompt,
            tools,
            effective_params,
            &token,
            max_attempts,
        )
        .await;

        if let Err(ref error) = result
            && Self::should_invalidate_provider(error)
        {
            self.invalidate_provider(&provider_id);
            if matches!(entry.auth_source, AuthSource::Plugin { .. })
                && let Some(fallback) = self.fallback_api_key_entry(&provider_id).await?
            {
                let fallback_result = Self::run_complete_with_retry(
                    &fallback.provider,
                    model_id,
                    messages,
                    system_prompt,
                    tools,
                    effective_params,
                    &token,
                    max_attempts,
                )
                .await;
                // Cache the fallback only when it actually worked.
                if fallback_result.is_ok() {
                    let mut map = self.provider_map.write().map_err(|_| {
                        ApiError::Configuration("Provider cache lock poisoned".to_string())
                    })?;
                    map.insert(provider_id, fallback);
                }
                return fallback_result;
            }
        }

        result
    }
675}
676
677#[cfg(test)]
678mod tests {
679 use super::*;
680 use crate::app::conversation::AssistantContent;
681 use crate::auth::ApiKeyOrigin;
682 use crate::config::provider::ProviderId;
683 use async_trait::async_trait;
684 use futures::StreamExt;
685 use std::sync::atomic::{AtomicUsize, Ordering};
686 use tokio_util::sync::CancellationToken;
687
    /// Which error the `StubProvider` should return from `complete`.
    #[derive(Clone, Copy)]
    enum StubErrorKind {
        Auth,
        Server401,
    }
693
    /// Provider whose `complete` always fails with a configurable auth-style
    /// error; used to exercise provider-cache invalidation paths.
    #[derive(Clone)]
    struct StubProvider {
        error_kind: StubErrorKind,
    }

    impl StubProvider {
        fn new(error_kind: StubErrorKind) -> Self {
            Self { error_kind }
        }
    }

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            // Always fail with the configured auth-style error.
            let err = match self.error_kind {
                StubErrorKind::Auth => ApiError::AuthenticationFailed {
                    provider: "stub".to_string(),
                    details: "bad key".to_string(),
                },
                StubErrorKind::Server401 => ApiError::ServerError {
                    provider: "stub".to_string(),
                    status_code: 401,
                    details: "unauthorized".to_string(),
                },
            };
            Err(err)
        }
    }
734
    /// Provider whose `stream_complete` fails with a network error a fixed
    /// number of times before succeeding, counting attempts. `Client::complete`
    /// routes through `stream_complete`, so this exercises its retry path.
    #[derive(Clone)]
    struct FlakyCompleteProvider {
        // How many initial stream_complete calls return a network error.
        failures_before_success: usize,
        // Total stream_complete invocations observed.
        attempts: Arc<AtomicUsize>,
    }

    impl FlakyCompleteProvider {
        fn new(failures_before_success: usize, attempts: Arc<AtomicUsize>) -> Self {
            Self {
                failures_before_success,
                attempts,
            }
        }
    }

    #[async_trait]
    impl Provider for FlakyCompleteProvider {
        fn name(&self) -> &'static str {
            "flaky-complete"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            Ok(success_response())
        }

        async fn stream_complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionStream, ApiError> {
            // 1-based attempt counter; fail until the quota is exhausted.
            let attempt = self.attempts.fetch_add(1, Ordering::Relaxed) + 1;
            if attempt <= self.failures_before_success {
                return Err(network_api_error());
            }
            let response = success_response();
            Ok(Box::pin(futures_util::stream::once(async move {
                StreamChunk::MessageComplete(response)
            })))
        }
    }
787
    /// Provider whose stream start-up fails with a network error a fixed
    /// number of times before yielding a successful single-chunk stream.
    #[derive(Clone)]
    struct FlakyStreamStartProvider {
        // How many initial stream_complete calls return a network error.
        failures_before_success: usize,
        // Total stream_complete invocations observed.
        attempts: Arc<AtomicUsize>,
    }

    impl FlakyStreamStartProvider {
        fn new(failures_before_success: usize, attempts: Arc<AtomicUsize>) -> Self {
            Self {
                failures_before_success,
                attempts,
            }
        }
    }

    #[async_trait]
    impl Provider for FlakyStreamStartProvider {
        fn name(&self) -> &'static str {
            "flaky-stream-start"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            Ok(success_response())
        }

        async fn stream_complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionStream, ApiError> {
            // 1-based attempt counter; fail until the quota is exhausted.
            let attempt = self.attempts.fetch_add(1, Ordering::Relaxed) + 1;
            if attempt <= self.failures_before_success {
                return Err(network_api_error());
            }

            let response = success_response();
            Ok(Box::pin(futures_util::stream::once(async move {
                StreamChunk::MessageComplete(response)
            })))
        }
    }
841
    /// Provider that always fails with a non-retryable `InvalidRequest`
    /// error, counting `stream_complete` attempts to prove no retries happen.
    #[derive(Clone)]
    struct InvalidRequestProvider {
        attempts: Arc<AtomicUsize>,
    }

    impl InvalidRequestProvider {
        fn new(attempts: Arc<AtomicUsize>) -> Self {
            Self { attempts }
        }
    }

    #[async_trait]
    impl Provider for InvalidRequestProvider {
        fn name(&self) -> &'static str {
            "invalid-request"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            Err(ApiError::InvalidRequest {
                provider: "stub".to_string(),
                details: "bad request".to_string(),
            })
        }

        async fn stream_complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionStream, ApiError> {
            // Count the call so the test can assert exactly one attempt.
            self.attempts.fetch_add(1, Ordering::Relaxed);
            Err(ApiError::InvalidRequest {
                provider: "stub".to_string(),
                details: "bad request".to_string(),
            })
        }
    }
890
891 fn success_response() -> CompletionResponse {
892 CompletionResponse::new(vec![AssistantContent::Text {
893 text: "ok".to_string(),
894 }])
895 }
896
    /// Provider whose stream yields one text delta and then ends without ever
    /// producing a `MessageComplete` chunk.
    #[derive(Clone)]
    struct StreamWithoutCompletionProvider;

    #[async_trait]
    impl Provider for StreamWithoutCompletionProvider {
        fn name(&self) -> &'static str {
            "stream-without-completion"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            Ok(success_response())
        }

        async fn stream_complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionStream, ApiError> {
            Ok(Box::pin(futures_util::stream::iter(vec![
                StreamChunk::TextDelta("partial".to_string()),
            ])))
        }
    }
932
    /// Provider whose stream immediately reports cancellation.
    #[derive(Clone)]
    struct StreamCancelledProvider;

    #[async_trait]
    impl Provider for StreamCancelledProvider {
        fn name(&self) -> &'static str {
            "stream-cancelled"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            Ok(success_response())
        }

        async fn stream_complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionStream, ApiError> {
            Ok(Box::pin(futures_util::stream::iter(vec![
                StreamChunk::Error(StreamError::Cancelled),
            ])))
        }
    }
968
    /// Provider whose stream immediately yields a provider stream error.
    #[derive(Clone)]
    struct StreamProviderErrorProvider;

    #[async_trait]
    impl Provider for StreamProviderErrorProvider {
        fn name(&self) -> &'static str {
            "stream-provider-error"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            Ok(success_response())
        }

        async fn stream_complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionStream, ApiError> {
            Ok(Box::pin(futures_util::stream::iter(vec![
                StreamChunk::Error(StreamError::Provider {
                    provider: "stub".to_string(),
                    kind: ProviderStreamErrorKind::StreamError,
                    raw_error_type: Some("stream_error".to_string()),
                    message: "upstream failed".to_string(),
                }),
            ])))
        }
    }
1009
1010 fn network_api_error() -> ApiError {
1011 let err = reqwest::Client::new()
1012 .get("http://[::1")
1013 .build()
1014 .expect_err("invalid URL should fail");
1015 ApiError::Network(err)
1016 }
1017
    /// Client backed by in-memory auth storage and empty registries; the
    /// tests seed providers directly into the cache instead.
    fn test_client() -> Client {
        let auth_storage = Arc::new(crate::test_utils::InMemoryAuthStorage::new());
        let config_provider = LlmConfigProvider::new(auth_storage).unwrap();
        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));

        Client::new_with_deps(config_provider, provider_registry, model_registry)
    }
1026
1027 fn insert_provider(client: &Client, provider_id: ProviderId, provider: Arc<dyn Provider>) {
1028 client.provider_map.write().unwrap().insert(
1029 provider_id,
1030 ProviderEntry {
1031 provider,
1032 auth_source: AuthSource::ApiKey {
1033 origin: ApiKeyOrigin::Stored,
1034 },
1035 },
1036 );
1037 }
1038
    /// Convenience wrapper: cache a `StubProvider` with the given error kind.
    fn insert_stub_provider(client: &Client, provider_id: ProviderId, error: StubErrorKind) {
        insert_provider(client, provider_id, Arc::new(StubProvider::new(error)));
    }
1042
    /// An `AuthenticationFailed` error must evict the cached provider so the
    /// next call re-resolves auth.
    #[tokio::test]
    async fn invalidates_cached_provider_on_auth_failure() {
        let client = test_client();
        let provider_id = ProviderId("stub-auth".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");

        insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Auth);

        let err = client
            .complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .unwrap_err();

        assert!(matches!(err, ApiError::AuthenticationFailed { .. }));
        // The failing provider must no longer be cached.
        assert!(
            !client
                .provider_map
                .read()
                .unwrap()
                .contains_key(&provider_id)
        );
    }
1072
    /// A 401 server error is treated like an auth failure and evicts the
    /// cached provider.
    #[tokio::test]
    async fn invalidates_cached_provider_on_unauthorized_status_code() {
        let client = test_client();
        let provider_id = ProviderId("stub-unauthorized".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");

        insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Server401);

        let err = client
            .complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            ApiError::ServerError {
                status_code: 401,
                ..
            }
        ));
        // The failing provider must no longer be cached.
        assert!(
            !client
                .provider_map
                .read()
                .unwrap()
                .contains_key(&provider_id)
        );
    }
1108
    /// Two network failures followed by success: `complete` retries and the
    /// provider records three attempts in total.
    #[tokio::test]
    async fn retries_network_errors_for_complete() {
        let client = test_client();
        let provider_id = ProviderId("flaky-complete".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        let attempts = Arc::new(AtomicUsize::new(0));

        insert_provider(
            &client,
            provider_id,
            Arc::new(FlakyCompleteProvider::new(2, attempts.clone())),
        );

        let response = client
            .complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .expect("complete should retry transient network failures");

        assert_eq!(response.extract_text(), "ok");
        assert_eq!(attempts.load(Ordering::Relaxed), 3);
    }
1137
    /// A stream that ends without `MessageComplete` surfaces as a
    /// `StreamError` with a descriptive message.
    #[tokio::test]
    async fn complete_errors_when_stream_ends_without_message_complete() {
        let client = test_client();
        let provider_id = ProviderId("stream-without-completion".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");

        insert_provider(
            &client,
            provider_id,
            Arc::new(StreamWithoutCompletionProvider),
        );

        let err = client
            .complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            ApiError::StreamError {
                ref provider,
                ref details,
            } if provider == "stream-without-completion" && details == "Stream ended without completion response"
        ));
    }
1170
    /// A `StreamError::Cancelled` chunk maps to `ApiError::Cancelled`.
    #[tokio::test]
    async fn complete_maps_stream_cancelled_to_cancelled_api_error() {
        let client = test_client();
        let provider_id = ProviderId("stream-cancelled".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");

        insert_provider(&client, provider_id, Arc::new(StreamCancelledProvider));

        let err = client
            .complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            ApiError::Cancelled { ref provider } if provider == "stream-cancelled"
        ));
    }
1196
    /// A provider error chunk maps to `ApiError::StreamError` carrying the
    /// upstream error text.
    #[tokio::test]
    async fn complete_maps_stream_provider_error_to_stream_api_error() {
        let client = test_client();
        let provider_id = ProviderId("stream-provider-error".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");

        insert_provider(&client, provider_id, Arc::new(StreamProviderErrorProvider));

        let err = client
            .complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            ApiError::StreamError {
                ref provider,
                ref details,
            } if provider == "stream-provider-error" && details.contains("upstream failed")
        ));
    }
1225
    /// `InvalidRequest` is non-retryable: exactly one attempt is made.
    #[tokio::test]
    async fn does_not_retry_non_retryable_complete_error() {
        let client = test_client();
        let provider_id = ProviderId("invalid-request".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        let attempts = Arc::new(AtomicUsize::new(0));

        insert_provider(
            &client,
            provider_id,
            Arc::new(InvalidRequestProvider::new(attempts.clone())),
        );

        let err = client
            .complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .unwrap_err();

        assert!(matches!(err, ApiError::InvalidRequest { .. }));
        assert_eq!(attempts.load(Ordering::Relaxed), 1);
    }
1254
    /// Two network failures at stream start-up: `stream_complete` retries
    /// and the eventual stream yields the completed response.
    #[tokio::test]
    async fn retries_network_errors_when_starting_stream() {
        let client = test_client();
        let provider_id = ProviderId("flaky-stream-start".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        let attempts = Arc::new(AtomicUsize::new(0));

        insert_provider(
            &client,
            provider_id,
            Arc::new(FlakyStreamStartProvider::new(2, attempts.clone())),
        );

        let mut stream = client
            .stream_complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .expect("stream start should retry transient network failures");

        let chunk = stream.next().await.expect("stream should yield completion");
        match chunk {
            StreamChunk::MessageComplete(response) => assert_eq!(response.extract_text(), "ok"),
            other => panic!("unexpected stream chunk: {other:?}"),
        }

        assert_eq!(attempts.load(Ordering::Relaxed), 3);
    }
1288}