//! API client for steer_core: per-provider caching, retry with backoff,
//! and API-key fallback for plugin-authenticated providers.
1pub mod claude;
2pub mod error;
3pub mod factory;
4pub mod gemini;
5pub mod openai;
6pub mod provider;
7pub mod sse;
8pub mod util;
9pub mod xai;
10
11use crate::auth::storage::Credential;
12use crate::auth::{AuthSource, ProviderRegistry};
13use crate::config::model::{ModelId, ModelParameters};
14use crate::config::provider::ProviderId;
15use crate::config::{LlmConfigProvider, ResolvedAuth};
16use crate::error::Result;
17use crate::model_registry::ModelRegistry;
18pub use error::{ApiError, ProviderStreamErrorKind, SseParseError, StreamError};
19pub use factory::{create_provider, create_provider_with_directive};
20use futures::StreamExt;
21pub use provider::{CompletionResponse, CompletionStream, Provider, StreamChunk, TokenUsage};
22use rand::Rng;
23use std::collections::HashMap;
24use std::sync::Arc;
25use std::sync::RwLock;
26use std::time::Duration;
27use steer_tools::ToolSchema;
28use tokio_util::sync::CancellationToken;
29use tracing::debug;
30use tracing::warn;
31
32use crate::app::SystemContext;
33use crate::app::conversation::Message;
34
// Base backoff delay before the first retry; doubled per attempt (capped).
// Tests use a 1 ms base so retry paths run quickly.
#[cfg(not(test))]
const RETRY_BASE_DELAY_MS: u64 = 250;
#[cfg(test)]
const RETRY_BASE_DELAY_MS: u64 = 1;
// Default total attempt budget for completion and stream-start retries.
const RETRY_MAX_ATTEMPTS: usize = 5;
40
/// Cloneable API client that lazily builds and caches one provider
/// instance per `ProviderId`, resolving credentials on first use.
#[derive(Clone)]
pub struct Client {
    // Provider cache keyed by provider id; entries are removed on
    // auth-class failures so they can be rebuilt with fresh credentials.
    provider_map: Arc<RwLock<HashMap<ProviderId, ProviderEntry>>>,
    // Resolves authentication (plugin directives or API keys) per provider.
    config_provider: LlmConfigProvider,
    // Static provider configuration lookup.
    provider_registry: Arc<ProviderRegistry>,
    // Model metadata (context windows, default parameters).
    model_registry: Arc<ModelRegistry>,
}
48
/// A cached provider together with the auth source it was built from,
/// so auth failures can decide whether an API-key fallback applies.
#[derive(Clone)]
struct ProviderEntry {
    provider: Arc<dyn Provider>,
    auth_source: AuthSource,
}
54
55impl Client {
56    /// Remove a cached provider so that future calls re-create it with fresh credentials.
57    fn invalidate_provider(&self, provider_id: &ProviderId) {
58        let Ok(mut map) = self.provider_map.write() else {
59            warn!(
60                target: "api::client",
61                "Provider cache lock poisoned while invalidating provider"
62            );
63            return;
64        };
65        map.remove(provider_id);
66    }
67
68    /// Determine if an API error should invalidate the cached provider (typically auth failures).
69    fn should_invalidate_provider(error: &ApiError) -> bool {
70        matches!(
71            error,
72            ApiError::AuthenticationFailed { .. } | ApiError::AuthError(_)
73        ) || matches!(
74            error,
75            ApiError::ServerError { status_code, .. } if matches!(status_code, 401 | 403)
76        )
77    }
78
79    /// Determine if an API error should trigger an automatic retry.
80    fn should_retry_error(error: &ApiError) -> bool {
81        match error {
82            ApiError::Network(_) => true,
83            ApiError::Timeout { .. } => true,
84            ApiError::RateLimited { .. } => true,
85            ApiError::ServerError { status_code, .. } => {
86                matches!(status_code, 408 | 409 | 429 | 500 | 502 | 503 | 504)
87            }
88            _ => false,
89        }
90    }
91
92    fn retry_delay(attempt: usize) -> Duration {
93        let base_ms = RETRY_BASE_DELAY_MS * (1u64 << attempt.min(4));
94        let jitter_percent = rand::thread_rng().gen_range(80_u64..=120_u64);
95        let jittered_ms = base_ms
96            .saturating_mul(jitter_percent)
97            .saturating_div(100)
98            .max(1);
99        Duration::from_millis(jittered_ms)
100    }
101
102    fn should_retry_stream_error(error: &StreamError) -> bool {
103        match error {
104            StreamError::SseParse(SseParseError::Transport { .. }) => true,
105            StreamError::Provider { kind, .. } => kind.is_retryable(),
106            StreamError::Cancelled | StreamError::SseParse(_) => false,
107        }
108    }
109
    /// Invoke `provider.complete`, retrying transient failures with
    /// exponential backoff.
    ///
    /// Retries only errors [`Self::should_retry_error`] accepts, up to
    /// `max_attempts` total attempts, and bails out early when `token`
    /// is cancelled. Returns the first success or the last error.
    #[expect(
        clippy::too_many_arguments,
        reason = "Retry helper mirrors provider API inputs plus retry controls"
    )]
    async fn run_complete_with_retry(
        provider: &Arc<dyn Provider>,
        model_id: &ModelId,
        messages: &[Message],
        system: &Option<SystemContext>,
        tools: &Option<Vec<ToolSchema>>,
        call_options: Option<ModelParameters>,
        token: &CancellationToken,
        max_attempts: usize,
    ) -> std::result::Result<CompletionResponse, ApiError> {
        let mut attempt = 0usize;

        loop {
            // Check cancellation before each (re)attempt.
            if token.is_cancelled() {
                return Err(ApiError::Cancelled {
                    provider: provider.name().to_string(),
                });
            }

            match provider
                .complete(
                    model_id,
                    messages.to_vec(),
                    system.clone(),
                    tools.clone(),
                    call_options,
                    token.clone(),
                )
                .await
            {
                Ok(response) => return Ok(response),
                // Transient error with attempts remaining: back off and retry.
                Err(error)
                    if Self::should_retry_error(&error)
                        && attempt + 1 < max_attempts
                        && !token.is_cancelled() =>
                {
                    attempt += 1;
                    // Delay keyed to the previous (0-based) attempt index.
                    let delay = Self::retry_delay(attempt - 1);
                    warn!(
                        target: "api::complete",
                        provider = provider.name(),
                        ?model_id,
                        attempt,
                        max_attempts,
                        ?delay,
                        error = %error,
                        "Retrying API completion after transient error"
                    );
                    tokio::time::sleep(delay).await;
                }
                // Fatal error or retry budget exhausted.
                Err(error) => return Err(error),
            }
        }
    }
168
    /// Start a provider stream, retrying transient start-up failures with
    /// exponential backoff. Mirrors [`Self::run_complete_with_retry`] but
    /// returns the live `CompletionStream` instead of a final response.
    #[expect(
        clippy::too_many_arguments,
        reason = "Retry helper mirrors provider stream API inputs plus retry controls"
    )]
    async fn run_stream_start_with_retry(
        provider: &Arc<dyn Provider>,
        model_id: &ModelId,
        messages: &[Message],
        system: &Option<SystemContext>,
        tools: &Option<Vec<ToolSchema>>,
        call_options: Option<ModelParameters>,
        token: &CancellationToken,
        max_attempts: usize,
    ) -> std::result::Result<CompletionStream, ApiError> {
        let mut attempt = 0usize;

        loop {
            // Check cancellation before each (re)attempt.
            if token.is_cancelled() {
                return Err(ApiError::Cancelled {
                    provider: provider.name().to_string(),
                });
            }

            match provider
                .stream_complete(
                    model_id,
                    messages.to_vec(),
                    system.clone(),
                    tools.clone(),
                    call_options,
                    token.clone(),
                )
                .await
            {
                Ok(stream) => return Ok(stream),
                // Transient error with attempts remaining: back off and retry.
                Err(error)
                    if Self::should_retry_error(&error)
                        && attempt + 1 < max_attempts
                        && !token.is_cancelled() =>
                {
                    attempt += 1;
                    // Delay keyed to the previous (0-based) attempt index.
                    let delay = Self::retry_delay(attempt - 1);
                    warn!(
                        target: "api::stream_complete",
                        provider = provider.name(),
                        ?model_id,
                        attempt,
                        max_attempts,
                        ?delay,
                        error = %error,
                        "Retrying API stream initialization after transient error"
                    );
                    tokio::time::sleep(delay).await;
                }
                // Fatal error or retry budget exhausted.
                Err(error) => return Err(error),
            }
        }
    }
227
    /// Drain a completion stream until it yields `MessageComplete`,
    /// discarding intermediate delta chunks.
    ///
    /// Stream-level errors are mapped to `ApiError`; a stream that ends
    /// without producing a completion is reported as a `StreamError`.
    async fn collect_completion_from_stream(
        mut stream: CompletionStream,
        provider: String,
    ) -> std::result::Result<CompletionResponse, ApiError> {
        while let Some(chunk) = stream.next().await {
            match chunk {
                StreamChunk::MessageComplete(response) => return Ok(response),
                StreamChunk::Error(StreamError::Cancelled) => {
                    return Err(ApiError::Cancelled {
                        provider: provider.clone(),
                    });
                }
                StreamChunk::Error(error) => {
                    return Err(ApiError::StreamError {
                        provider: provider.clone(),
                        details: error.to_string(),
                    });
                }
                // Incremental chunks carry no final payload; skip them here.
                StreamChunk::TextDelta(_)
                | StreamChunk::ThinkingDelta(_)
                | StreamChunk::ToolUseStart { .. }
                | StreamChunk::ToolUseInputDelta { .. }
                | StreamChunk::ContentBlockStop { .. }
                | StreamChunk::Reset => {}
            }
        }

        Err(ApiError::StreamError {
            provider,
            details: "Stream ended without completion response".to_string(),
        })
    }
260
261    /// Create a new Client with all dependencies injected.
262    /// This is the preferred constructor to avoid internal registry loading.
263    pub fn new_with_deps(
264        config_provider: LlmConfigProvider,
265        provider_registry: Arc<ProviderRegistry>,
266        model_registry: Arc<ModelRegistry>,
267    ) -> Self {
268        Self {
269            provider_map: Arc::new(RwLock::new(HashMap::new())),
270            config_provider,
271            provider_registry,
272            model_registry,
273        }
274    }
275
276    pub fn model_context_window_tokens(&self, model_id: &ModelId) -> Option<u32> {
277        self.model_registry
278            .get(model_id)
279            .and_then(|model| model.context_window_tokens)
280    }
281
282    pub fn model_max_output_tokens(&self, model_id: &ModelId) -> Option<u32> {
283        self.model_registry
284            .get(model_id)
285            .and_then(|model| model.parameters)
286            .and_then(|parameters| parameters.max_output_tokens)
287    }
288
289    #[cfg(any(test, feature = "test-utils"))]
290    pub fn insert_test_provider(&self, provider_id: ProviderId, provider: Arc<dyn Provider>) {
291        match self.provider_map.write() {
292            Ok(mut map) => {
293                map.insert(
294                    provider_id,
295                    ProviderEntry {
296                        provider,
297                        auth_source: AuthSource::None,
298                    },
299                );
300            }
301            Err(_) => {
302                warn!(
303                    target: "api::client",
304                    "Provider cache lock poisoned while inserting test provider"
305                );
306            }
307        }
308    }
309
    /// Fetch the cached provider entry for `provider_id`, creating and
    /// caching it on first use.
    ///
    /// Uses a read-check / release / await / write-recheck sequence so the
    /// `std::sync::RwLock` is never held across an `.await`.
    async fn get_or_create_provider_entry(&self, provider_id: ProviderId) -> Result<ProviderEntry> {
        // First check without holding the lock across await
        {
            let map = self.provider_map.read().map_err(|_| {
                crate::error::Error::Api(ApiError::Configuration(
                    "Provider cache lock poisoned".to_string(),
                ))
            })?;
            if let Some(entry) = map.get(&provider_id) {
                return Ok(entry.clone());
            }
        }

        // Get the provider config from registry
        let provider_config = self.provider_registry.get(&provider_id).ok_or_else(|| {
            crate::error::Error::Api(ApiError::Configuration(format!(
                "No provider configuration found for {provider_id:?}"
            )))
        })?;

        // Resolve credentials; this may await external auth sources.
        let resolved = self
            .config_provider
            .resolve_auth_for_provider(&provider_id)
            .await?;

        // Now acquire write lock and create provider
        let mut map = self.provider_map.write().map_err(|_| {
            crate::error::Error::Api(ApiError::Configuration(
                "Provider cache lock poisoned".to_string(),
            ))
        })?;

        // Check again in case another thread added it
        if let Some(entry) = map.get(&provider_id) {
            return Ok(entry.clone());
        }

        let entry = Self::build_provider_entry(provider_config, &resolved)?;

        map.insert(provider_id, entry.clone());
        Ok(entry)
    }
352
    /// Build a `ProviderEntry` from a provider config and resolved auth.
    ///
    /// Plugin auth goes through the directive-based factory, API keys
    /// through the credential-based factory; a provider with no auth
    /// configured is a configuration error.
    fn build_provider_entry(
        provider_config: &crate::config::provider::ProviderConfig,
        resolved: &ResolvedAuth,
    ) -> std::result::Result<ProviderEntry, ApiError> {
        let provider = match resolved {
            ResolvedAuth::Plugin { directive, .. } => {
                factory::create_provider_with_directive(provider_config, directive)?
            }
            ResolvedAuth::ApiKey { credential, .. } => {
                factory::create_provider(provider_config, credential)?
            }
            ResolvedAuth::None => {
                return Err(ApiError::Configuration(format!(
                    "No authentication configured for {:?}",
                    provider_config.id
                )));
            }
        };

        Ok(ProviderEntry {
            provider,
            // Record where the credentials came from so auth failures can
            // decide whether an API-key fallback applies.
            auth_source: resolved.source(),
        })
    }
377
    /// Build a provider entry from a directly-configured API key, if one
    /// exists for `provider_id`.
    ///
    /// Used as a fallback when plugin-sourced authentication fails.
    /// Returns `Ok(None)` when no API key is configured for the provider.
    async fn fallback_api_key_entry(
        &self,
        provider_id: &ProviderId,
    ) -> std::result::Result<Option<ProviderEntry>, ApiError> {
        let Some((key, origin)) = self
            .config_provider
            .resolve_api_key_for_provider(provider_id)
            .await?
        else {
            return Ok(None);
        };

        let provider_config = self.provider_registry.get(provider_id).ok_or_else(|| {
            ApiError::Configuration(format!(
                "No provider configuration found for {provider_id:?}"
            ))
        })?;

        let credential = Credential::ApiKey { value: key };
        let provider = factory::create_provider(provider_config, &credential)?;

        Ok(Some(ProviderEntry {
            provider,
            // Preserve the key's origin for diagnostics and cache decisions.
            auth_source: AuthSource::ApiKey { origin },
        }))
    }
404
    /// Complete a prompt with a specific model ID and optional parameters.
    ///
    /// Implemented by consuming the streaming endpoint and collecting the
    /// final `MessageComplete` chunk, so streaming and non-streaming paths
    /// share the same retry and auth-fallback behavior.
    pub async fn complete(
        &self,
        model_id: &ModelId,
        messages: Vec<Message>,
        system: Option<SystemContext>,
        tools: Option<Vec<ToolSchema>>,
        call_options: Option<crate::config::model::ModelParameters>,
        token: CancellationToken,
    ) -> std::result::Result<CompletionResponse, ApiError> {
        debug!(
            target: "api::complete",
            ?model_id,
            "Completing by consuming the streamed endpoint"
        );

        let stream = self
            .stream_complete(model_id, messages, system, tools, call_options, token)
            .await?;

        Self::collect_completion_from_stream(stream, model_id.provider.to_string()).await
    }
427
    /// Stream a completion for `model_id`, with retry and auth fallback.
    ///
    /// Resolves (or lazily creates) the provider, merges registry model
    /// parameters with per-call `call_options`, and starts the stream with
    /// retries. If the start fails with an auth-class error and the cached
    /// credentials came from a plugin, the call retries once with a direct
    /// API key. The returned stream also transparently re-establishes the
    /// underlying stream on retryable mid-stream errors, emitting
    /// `StreamChunk::Reset` first when partial output was already yielded.
    pub async fn stream_complete(
        &self,
        model_id: &ModelId,
        messages: Vec<Message>,
        system: Option<SystemContext>,
        tools: Option<Vec<ToolSchema>>,
        call_options: Option<crate::config::model::ModelParameters>,
        token: CancellationToken,
    ) -> std::result::Result<CompletionStream, ApiError> {
        let provider_id = model_id.provider.clone();
        let entry = self
            .get_or_create_provider_entry(provider_id.clone())
            .await
            .map_err(ApiError::from)?;
        let provider = entry.provider.clone();

        if token.is_cancelled() {
            return Err(ApiError::Cancelled {
                provider: provider.name().to_string(),
            });
        }

        // Merge registry defaults with per-call overrides.
        let model_config = self.model_registry.get(model_id);
        let effective_params = match (model_config, &call_options) {
            (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
            (Some(config), None) => config.effective_parameters(None),
            (None, Some(opts)) => Some(*opts),
            (None, None) => None,
        };

        debug!(
            target: "api::stream_complete",
            ?model_id,
            ?call_options,
            ?effective_params,
            "Streaming with parameters"
        );

        // Start the stream; on auth-class failures from a plugin-sourced
        // credential, invalidate the cache and try a direct API key.
        let (initial_stream, provider_for_retry) = match Self::run_stream_start_with_retry(
            &provider,
            model_id,
            &messages,
            &system,
            &tools,
            effective_params,
            &token,
            RETRY_MAX_ATTEMPTS,
        )
        .await
        {
            Ok(stream) => (stream, provider),
            Err(err) => {
                if Self::should_invalidate_provider(&err) {
                    self.invalidate_provider(&provider_id);

                    if matches!(entry.auth_source, AuthSource::Plugin { .. }) {
                        if let Some(fallback) = self.fallback_api_key_entry(&provider_id).await? {
                            let fallback_provider = fallback.provider.clone();
                            let fallback_stream = Self::run_stream_start_with_retry(
                                &fallback_provider,
                                model_id,
                                &messages,
                                &system,
                                &tools,
                                effective_params,
                                &token,
                                RETRY_MAX_ATTEMPTS,
                            )
                            .await?;
                            // Cache the working fallback entry for future calls.
                            let mut map = self.provider_map.write().map_err(|_| {
                                ApiError::Configuration("Provider cache lock poisoned".to_string())
                            })?;
                            map.insert(provider_id, fallback);
                            (fallback_stream, fallback_provider)
                        } else {
                            return Err(err);
                        }
                    } else {
                        return Err(err);
                    }
                } else {
                    return Err(err);
                }
            }
        };

        let model_id = model_id.clone();
        // Wrap the provider stream so retryable mid-stream failures restart
        // the stream instead of surfacing to the caller.
        let stream = async_stream::stream! {
            // Counts stream establishments; the initial stream is attempt #1.
            let mut attempt = 1usize;
            let mut current_stream = Some(initial_stream);

            'outer: loop {
                // Whether this stream incarnation has produced output, so a
                // retry can tell consumers to discard partial state.
                let mut saw_output = false;
                let mut stream = if let Some(stream) = current_stream.take() { stream } else {
                    if token.is_cancelled() {
                        yield StreamChunk::Error(StreamError::Cancelled);
                        break;
                    }

                    // Re-establish the stream after a mid-stream failure.
                    let stream_result = Self::run_stream_start_with_retry(
                        &provider_for_retry,
                        &model_id,
                        &messages,
                        &system,
                        &tools,
                        effective_params,
                        &token,
                        RETRY_MAX_ATTEMPTS,
                    )
                    .await;
                    match stream_result {
                        Ok(stream) => stream,
                        Err(err) => {
                            yield StreamChunk::Error(StreamError::Provider {
                                provider: provider_for_retry.name().to_string(),
                                kind: ProviderStreamErrorKind::StreamRetry,
                                raw_error_type: Some("stream_retry".to_string()),
                                message: err.to_string(),
                            });
                            break;
                        }
                    }
                };

                while let Some(chunk) = stream.next().await {
                    // Classify error chunks: cancellation and parser/UTF-8
                    // corruption never retry; others defer to the shared
                    // retryability check.
                    let retryable_stream_error = match &chunk {
                        StreamChunk::Error(stream_err) => match stream_err {
                            StreamError::Cancelled => false,
                            StreamError::SseParse(
                                SseParseError::Parser { .. } | SseParseError::Utf8 { .. },
                            ) => false,
                            _ => Self::should_retry_stream_error(stream_err),
                        },
                        _ => false,
                    };

                    if retryable_stream_error && attempt < RETRY_MAX_ATTEMPTS {
                        attempt += 1;
                        warn!(
                            target: "api::stream_complete",
                            ?model_id,
                            attempt,
                            max_attempts = RETRY_MAX_ATTEMPTS,
                            error = ?chunk,
                            "Retrying stream after transport/provider stream failure"
                        );
                        if saw_output {
                            // Consumers must drop partial output; the new
                            // stream will replay from the beginning.
                            yield StreamChunk::Reset;
                        }
                        current_stream = None;
                        continue 'outer;
                    }

                    if !matches!(chunk, StreamChunk::Error(_)) {
                        saw_output = true;
                    }

                    yield chunk;
                }

                // Stream finished (or a non-retryable error was forwarded).
                break;
            }
        };

        Ok(Box::pin(stream))
    }
594
    /// Non-streaming completion with a caller-controlled retry budget and
    /// API-key fallback for plugin-authenticated providers.
    ///
    /// Unlike `stream_complete`, per-call parameter overrides are not
    /// accepted; only registry-level model parameters apply.
    pub async fn complete_with_retry(
        &self,
        model_id: &ModelId,
        messages: &[Message],
        system_prompt: &Option<SystemContext>,
        tools: &Option<Vec<ToolSchema>>,
        token: CancellationToken,
        max_attempts: usize,
    ) -> std::result::Result<CompletionResponse, ApiError> {
        let provider_id = model_id.provider.clone();
        let entry = self
            .get_or_create_provider_entry(provider_id.clone())
            .await
            .map_err(ApiError::from)?;

        let model_config = self.model_registry.get(model_id);
        debug!(
            target: "api::complete_with_retry",
            ?model_id,
            ?model_config,
            "Model config"
        );
        let effective_params = model_config.and_then(|cfg| cfg.effective_parameters(None));

        debug!(
            target: "api::complete_with_retry",
            ?model_id,
            ?effective_params,
            "system: {:?}",
            system_prompt
        );
        debug!(
            target: "api::complete_with_retry",
            ?model_id,
            "messages: {:?}",
            messages
        );

        let result = Self::run_complete_with_retry(
            &entry.provider,
            model_id,
            messages,
            system_prompt,
            tools,
            effective_params,
            &token,
            max_attempts,
        )
        .await;

        // On auth-class failures, drop the cached provider. When the
        // original credentials came from a plugin, retry once with a direct
        // API key; a successful fallback replaces the cached entry.
        if let Err(ref error) = result
            && Self::should_invalidate_provider(error)
        {
            self.invalidate_provider(&provider_id);
            if matches!(entry.auth_source, AuthSource::Plugin { .. })
                && let Some(fallback) = self.fallback_api_key_entry(&provider_id).await?
            {
                let fallback_result = Self::run_complete_with_retry(
                    &fallback.provider,
                    model_id,
                    messages,
                    system_prompt,
                    tools,
                    effective_params,
                    &token,
                    max_attempts,
                )
                .await;
                if fallback_result.is_ok() {
                    let mut map = self.provider_map.write().map_err(|_| {
                        ApiError::Configuration("Provider cache lock poisoned".to_string())
                    })?;
                    map.insert(provider_id, fallback);
                }
                return fallback_result;
            }
        }

        result
    }
675}
676
677#[cfg(test)]
678mod tests {
679    use super::*;
680    use crate::app::conversation::AssistantContent;
681    use crate::auth::ApiKeyOrigin;
682    use crate::config::provider::ProviderId;
683    use async_trait::async_trait;
684    use futures::StreamExt;
685    use std::sync::atomic::{AtomicUsize, Ordering};
686    use tokio_util::sync::CancellationToken;
687
    /// Which error variant `StubProvider::complete` should return.
    #[derive(Clone, Copy)]
    enum StubErrorKind {
        Auth,
        Server401,
    }
693
    /// Test provider whose `complete` always fails with a configured error.
    #[derive(Clone)]
    struct StubProvider {
        error_kind: StubErrorKind,
    }
698
    impl StubProvider {
        /// Build a stub that always fails with `error_kind`.
        fn new(error_kind: StubErrorKind) -> Self {
            Self { error_kind }
        }
    }
704
    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        /// Always fails with the configured error variant; used to exercise
        /// provider-invalidation paths.
        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            let err = match self.error_kind {
                StubErrorKind::Auth => ApiError::AuthenticationFailed {
                    provider: "stub".to_string(),
                    details: "bad key".to_string(),
                },
                StubErrorKind::Server401 => ApiError::ServerError {
                    provider: "stub".to_string(),
                    status_code: 401,
                    details: "unauthorized".to_string(),
                },
            };
            Err(err)
        }
    }
734
    /// Test provider whose stream start fails a fixed number of times
    /// before succeeding, recording call counts in a shared counter.
    #[derive(Clone)]
    struct FlakyCompleteProvider {
        failures_before_success: usize,
        attempts: Arc<AtomicUsize>,
    }
740
    impl FlakyCompleteProvider {
        /// Build a provider that fails `failures_before_success` times,
        /// counting attempts via the shared `attempts` counter.
        fn new(failures_before_success: usize, attempts: Arc<AtomicUsize>) -> Self {
            Self {
                failures_before_success,
                attempts,
            }
        }
    }
749
    #[async_trait]
    impl Provider for FlakyCompleteProvider {
        fn name(&self) -> &'static str {
            "flaky-complete"
        }

        /// Always succeeds; attempts are only counted on the stream path.
        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            Ok(success_response())
        }

        /// Fails with a network error for the first
        /// `failures_before_success` calls, then yields one
        /// `MessageComplete` chunk.
        async fn stream_complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionStream, ApiError> {
            let attempt = self.attempts.fetch_add(1, Ordering::Relaxed) + 1;
            if attempt <= self.failures_before_success {
                return Err(network_api_error());
            }
            let response = success_response();
            Ok(Box::pin(futures_util::stream::once(async move {
                StreamChunk::MessageComplete(response)
            })))
        }
    }
787
    /// Test provider whose `stream_complete` fails a fixed number of times
    /// before producing a single successful completion chunk.
    #[derive(Clone)]
    struct FlakyStreamStartProvider {
        failures_before_success: usize,
        attempts: Arc<AtomicUsize>,
    }
793
    impl FlakyStreamStartProvider {
        /// Build a provider that fails `failures_before_success` stream
        /// starts, counting attempts via the shared `attempts` counter.
        fn new(failures_before_success: usize, attempts: Arc<AtomicUsize>) -> Self {
            Self {
                failures_before_success,
                attempts,
            }
        }
    }
802
    #[async_trait]
    impl Provider for FlakyStreamStartProvider {
        fn name(&self) -> &'static str {
            "flaky-stream-start"
        }

        /// Always succeeds on the non-streaming path.
        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            Ok(success_response())
        }

        /// Fails stream establishment with a network error until the
        /// configured failure count is exhausted, then yields one
        /// `MessageComplete` chunk.
        async fn stream_complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionStream, ApiError> {
            let attempt = self.attempts.fetch_add(1, Ordering::Relaxed) + 1;
            if attempt <= self.failures_before_success {
                return Err(network_api_error());
            }

            let response = success_response();
            Ok(Box::pin(futures_util::stream::once(async move {
                StreamChunk::MessageComplete(response)
            })))
        }
    }
841
    /// Test provider that always fails with a non-retryable
    /// `InvalidRequest` error, counting stream attempts.
    #[derive(Clone)]
    struct InvalidRequestProvider {
        attempts: Arc<AtomicUsize>,
    }
846
    impl InvalidRequestProvider {
        /// Build a provider that records stream attempts in `attempts`.
        fn new(attempts: Arc<AtomicUsize>) -> Self {
            Self { attempts }
        }
    }
852
    #[async_trait]
    impl Provider for InvalidRequestProvider {
        fn name(&self) -> &'static str {
            "invalid-request"
        }

        /// Always fails with `InvalidRequest` (non-retryable).
        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            Err(ApiError::InvalidRequest {
                provider: "stub".to_string(),
                details: "bad request".to_string(),
            })
        }

        /// Always fails with `InvalidRequest`, counting the attempt so
        /// tests can assert no retries happened.
        async fn stream_complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionStream, ApiError> {
            self.attempts.fetch_add(1, Ordering::Relaxed);
            Err(ApiError::InvalidRequest {
                provider: "stub".to_string(),
                details: "bad request".to_string(),
            })
        }
    }
890
891    fn success_response() -> CompletionResponse {
892        CompletionResponse::new(vec![AssistantContent::Text {
893            text: "ok".to_string(),
894        }])
895    }
896
    /// Stub provider whose stream yields a text delta but never a
    /// `MessageComplete` chunk, exercising the truncated-stream error path.
    #[derive(Clone)]
    struct StreamWithoutCompletionProvider;
899
900    #[async_trait]
901    impl Provider for StreamWithoutCompletionProvider {
902        fn name(&self) -> &'static str {
903            "stream-without-completion"
904        }
905
906        async fn complete(
907            &self,
908            _model_id: &ModelId,
909            _messages: Vec<Message>,
910            _system: Option<SystemContext>,
911            _tools: Option<Vec<ToolSchema>>,
912            _call_options: Option<crate::config::model::ModelParameters>,
913            _token: CancellationToken,
914        ) -> std::result::Result<CompletionResponse, ApiError> {
915            Ok(success_response())
916        }
917
918        async fn stream_complete(
919            &self,
920            _model_id: &ModelId,
921            _messages: Vec<Message>,
922            _system: Option<SystemContext>,
923            _tools: Option<Vec<ToolSchema>>,
924            _call_options: Option<crate::config::model::ModelParameters>,
925            _token: CancellationToken,
926        ) -> std::result::Result<CompletionStream, ApiError> {
927            Ok(Box::pin(futures_util::stream::iter(vec![
928                StreamChunk::TextDelta("partial".to_string()),
929            ])))
930        }
931    }
932
    /// Stub provider whose stream immediately reports `StreamError::Cancelled`.
    #[derive(Clone)]
    struct StreamCancelledProvider;
935
936    #[async_trait]
937    impl Provider for StreamCancelledProvider {
938        fn name(&self) -> &'static str {
939            "stream-cancelled"
940        }
941
942        async fn complete(
943            &self,
944            _model_id: &ModelId,
945            _messages: Vec<Message>,
946            _system: Option<SystemContext>,
947            _tools: Option<Vec<ToolSchema>>,
948            _call_options: Option<crate::config::model::ModelParameters>,
949            _token: CancellationToken,
950        ) -> std::result::Result<CompletionResponse, ApiError> {
951            Ok(success_response())
952        }
953
954        async fn stream_complete(
955            &self,
956            _model_id: &ModelId,
957            _messages: Vec<Message>,
958            _system: Option<SystemContext>,
959            _tools: Option<Vec<ToolSchema>>,
960            _call_options: Option<crate::config::model::ModelParameters>,
961            _token: CancellationToken,
962        ) -> std::result::Result<CompletionStream, ApiError> {
963            Ok(Box::pin(futures_util::stream::iter(vec![
964                StreamChunk::Error(StreamError::Cancelled),
965            ])))
966        }
967    }
968
    /// Stub provider whose stream reports an in-band upstream provider error.
    #[derive(Clone)]
    struct StreamProviderErrorProvider;
971
972    #[async_trait]
973    impl Provider for StreamProviderErrorProvider {
974        fn name(&self) -> &'static str {
975            "stream-provider-error"
976        }
977
978        async fn complete(
979            &self,
980            _model_id: &ModelId,
981            _messages: Vec<Message>,
982            _system: Option<SystemContext>,
983            _tools: Option<Vec<ToolSchema>>,
984            _call_options: Option<crate::config::model::ModelParameters>,
985            _token: CancellationToken,
986        ) -> std::result::Result<CompletionResponse, ApiError> {
987            Ok(success_response())
988        }
989
990        async fn stream_complete(
991            &self,
992            _model_id: &ModelId,
993            _messages: Vec<Message>,
994            _system: Option<SystemContext>,
995            _tools: Option<Vec<ToolSchema>>,
996            _call_options: Option<crate::config::model::ModelParameters>,
997            _token: CancellationToken,
998        ) -> std::result::Result<CompletionStream, ApiError> {
999            Ok(Box::pin(futures_util::stream::iter(vec![
1000                StreamChunk::Error(StreamError::Provider {
1001                    provider: "stub".to_string(),
1002                    kind: ProviderStreamErrorKind::StreamError,
1003                    raw_error_type: Some("stream_error".to_string()),
1004                    message: "upstream failed".to_string(),
1005                }),
1006            ])))
1007        }
1008    }
1009
1010    fn network_api_error() -> ApiError {
1011        let err = reqwest::Client::new()
1012            .get("http://[::1")
1013            .build()
1014            .expect_err("invalid URL should fail");
1015        ApiError::Network(err)
1016    }
1017
1018    fn test_client() -> Client {
1019        let auth_storage = Arc::new(crate::test_utils::InMemoryAuthStorage::new());
1020        let config_provider = LlmConfigProvider::new(auth_storage).unwrap();
1021        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
1022        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
1023
1024        Client::new_with_deps(config_provider, provider_registry, model_registry)
1025    }
1026
1027    fn insert_provider(client: &Client, provider_id: ProviderId, provider: Arc<dyn Provider>) {
1028        client.provider_map.write().unwrap().insert(
1029            provider_id,
1030            ProviderEntry {
1031                provider,
1032                auth_source: AuthSource::ApiKey {
1033                    origin: ApiKeyOrigin::Stored,
1034                },
1035            },
1036        );
1037    }
1038
1039    fn insert_stub_provider(client: &Client, provider_id: ProviderId, error: StubErrorKind) {
1040        insert_provider(client, provider_id, Arc::new(StubProvider::new(error)));
1041    }
1042
1043    #[tokio::test]
1044    async fn invalidates_cached_provider_on_auth_failure() {
1045        let client = test_client();
1046        let provider_id = ProviderId("stub-auth".to_string());
1047        let model_id = ModelId::new(provider_id.clone(), "stub-model");
1048
1049        insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Auth);
1050
1051        let err = client
1052            .complete(
1053                &model_id,
1054                vec![],
1055                None,
1056                None,
1057                None,
1058                CancellationToken::new(),
1059            )
1060            .await
1061            .unwrap_err();
1062
1063        assert!(matches!(err, ApiError::AuthenticationFailed { .. }));
1064        assert!(
1065            !client
1066                .provider_map
1067                .read()
1068                .unwrap()
1069                .contains_key(&provider_id)
1070        );
1071    }
1072
1073    #[tokio::test]
1074    async fn invalidates_cached_provider_on_unauthorized_status_code() {
1075        let client = test_client();
1076        let provider_id = ProviderId("stub-unauthorized".to_string());
1077        let model_id = ModelId::new(provider_id.clone(), "stub-model");
1078
1079        insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Server401);
1080
1081        let err = client
1082            .complete(
1083                &model_id,
1084                vec![],
1085                None,
1086                None,
1087                None,
1088                CancellationToken::new(),
1089            )
1090            .await
1091            .unwrap_err();
1092
1093        assert!(matches!(
1094            err,
1095            ApiError::ServerError {
1096                status_code: 401,
1097                ..
1098            }
1099        ));
1100        assert!(
1101            !client
1102                .provider_map
1103                .read()
1104                .unwrap()
1105                .contains_key(&provider_id)
1106        );
1107    }
1108
1109    #[tokio::test]
1110    async fn retries_network_errors_for_complete() {
1111        let client = test_client();
1112        let provider_id = ProviderId("flaky-complete".to_string());
1113        let model_id = ModelId::new(provider_id.clone(), "stub-model");
1114        let attempts = Arc::new(AtomicUsize::new(0));
1115
1116        insert_provider(
1117            &client,
1118            provider_id,
1119            Arc::new(FlakyCompleteProvider::new(2, attempts.clone())),
1120        );
1121
1122        let response = client
1123            .complete(
1124                &model_id,
1125                vec![],
1126                None,
1127                None,
1128                None,
1129                CancellationToken::new(),
1130            )
1131            .await
1132            .expect("complete should retry transient network failures");
1133
1134        assert_eq!(response.extract_text(), "ok");
1135        assert_eq!(attempts.load(Ordering::Relaxed), 3);
1136    }
1137
1138    #[tokio::test]
1139    async fn complete_errors_when_stream_ends_without_message_complete() {
1140        let client = test_client();
1141        let provider_id = ProviderId("stream-without-completion".to_string());
1142        let model_id = ModelId::new(provider_id.clone(), "stub-model");
1143
1144        insert_provider(
1145            &client,
1146            provider_id,
1147            Arc::new(StreamWithoutCompletionProvider),
1148        );
1149
1150        let err = client
1151            .complete(
1152                &model_id,
1153                vec![],
1154                None,
1155                None,
1156                None,
1157                CancellationToken::new(),
1158            )
1159            .await
1160            .unwrap_err();
1161
1162        assert!(matches!(
1163            err,
1164            ApiError::StreamError {
1165                ref provider,
1166                ref details,
1167            } if provider == "stream-without-completion" && details == "Stream ended without completion response"
1168        ));
1169    }
1170
1171    #[tokio::test]
1172    async fn complete_maps_stream_cancelled_to_cancelled_api_error() {
1173        let client = test_client();
1174        let provider_id = ProviderId("stream-cancelled".to_string());
1175        let model_id = ModelId::new(provider_id.clone(), "stub-model");
1176
1177        insert_provider(&client, provider_id, Arc::new(StreamCancelledProvider));
1178
1179        let err = client
1180            .complete(
1181                &model_id,
1182                vec![],
1183                None,
1184                None,
1185                None,
1186                CancellationToken::new(),
1187            )
1188            .await
1189            .unwrap_err();
1190
1191        assert!(matches!(
1192            err,
1193            ApiError::Cancelled { ref provider } if provider == "stream-cancelled"
1194        ));
1195    }
1196
1197    #[tokio::test]
1198    async fn complete_maps_stream_provider_error_to_stream_api_error() {
1199        let client = test_client();
1200        let provider_id = ProviderId("stream-provider-error".to_string());
1201        let model_id = ModelId::new(provider_id.clone(), "stub-model");
1202
1203        insert_provider(&client, provider_id, Arc::new(StreamProviderErrorProvider));
1204
1205        let err = client
1206            .complete(
1207                &model_id,
1208                vec![],
1209                None,
1210                None,
1211                None,
1212                CancellationToken::new(),
1213            )
1214            .await
1215            .unwrap_err();
1216
1217        assert!(matches!(
1218            err,
1219            ApiError::StreamError {
1220                ref provider,
1221                ref details,
1222            } if provider == "stream-provider-error" && details.contains("upstream failed")
1223        ));
1224    }
1225
1226    #[tokio::test]
1227    async fn does_not_retry_non_retryable_complete_error() {
1228        let client = test_client();
1229        let provider_id = ProviderId("invalid-request".to_string());
1230        let model_id = ModelId::new(provider_id.clone(), "stub-model");
1231        let attempts = Arc::new(AtomicUsize::new(0));
1232
1233        insert_provider(
1234            &client,
1235            provider_id,
1236            Arc::new(InvalidRequestProvider::new(attempts.clone())),
1237        );
1238
1239        let err = client
1240            .complete(
1241                &model_id,
1242                vec![],
1243                None,
1244                None,
1245                None,
1246                CancellationToken::new(),
1247            )
1248            .await
1249            .unwrap_err();
1250
1251        assert!(matches!(err, ApiError::InvalidRequest { .. }));
1252        assert_eq!(attempts.load(Ordering::Relaxed), 1);
1253    }
1254
1255    #[tokio::test]
1256    async fn retries_network_errors_when_starting_stream() {
1257        let client = test_client();
1258        let provider_id = ProviderId("flaky-stream-start".to_string());
1259        let model_id = ModelId::new(provider_id.clone(), "stub-model");
1260        let attempts = Arc::new(AtomicUsize::new(0));
1261
1262        insert_provider(
1263            &client,
1264            provider_id,
1265            Arc::new(FlakyStreamStartProvider::new(2, attempts.clone())),
1266        );
1267
1268        let mut stream = client
1269            .stream_complete(
1270                &model_id,
1271                vec![],
1272                None,
1273                None,
1274                None,
1275                CancellationToken::new(),
1276            )
1277            .await
1278            .expect("stream start should retry transient network failures");
1279
1280        let chunk = stream.next().await.expect("stream should yield completion");
1281        match chunk {
1282            StreamChunk::MessageComplete(response) => assert_eq!(response.extract_text(), "ok"),
1283            other => panic!("unexpected stream chunk: {other:?}"),
1284        }
1285
1286        assert_eq!(attempts.load(Ordering::Relaxed), 3);
1287    }
1288}