Skip to main content

steer_core/api/
mod.rs

1pub mod claude;
2pub mod error;
3pub mod factory;
4pub mod gemini;
5pub mod openai;
6pub mod provider;
7pub mod sse;
8pub mod util;
9pub mod xai;
10
11use crate::auth::storage::Credential;
12use crate::auth::{AuthSource, ProviderRegistry};
13use crate::config::model::ModelId;
14use crate::config::provider::ProviderId;
15use crate::config::{LlmConfigProvider, ResolvedAuth};
16use crate::error::Result;
17use crate::model_registry::ModelRegistry;
18pub use error::{ApiError, SseParseError, StreamError};
19pub use factory::{create_provider, create_provider_with_directive};
20use futures::StreamExt;
21pub use provider::{CompletionResponse, CompletionStream, Provider, StreamChunk};
22use std::collections::HashMap;
23use std::sync::Arc;
24use std::sync::RwLock;
25use steer_tools::ToolSchema;
26use tokio_util::sync::CancellationToken;
27use tracing::debug;
28use tracing::warn;
29
30use crate::app::SystemContext;
31use crate::app::conversation::Message;
32
/// Upper bound on the total number of attempts (the initial request plus retries) that
/// `Client::stream_complete` makes when a stream reports a transport-level SSE error
/// before any non-error chunk has been yielded.
const STREAM_TRANSPORT_RETRY_MAX_ATTEMPTS: usize = 2;
34
/// Entry point for issuing completion requests against configured LLM providers.
///
/// Clones share the same provider cache and registries, since those live behind `Arc`.
#[derive(Clone)]
pub struct Client {
    /// Cache of constructed providers keyed by provider id. Entries are evicted when a
    /// request fails with an authentication-style error.
    provider_map: Arc<RwLock<HashMap<ProviderId, ProviderEntry>>>,
    /// Resolves authentication (plugin directive or API key) for a provider.
    config_provider: LlmConfigProvider,
    /// Looked up for the provider configuration needed to construct a provider.
    provider_registry: Arc<ProviderRegistry>,
    /// Looked up for the model configuration used to compute effective parameters.
    model_registry: Arc<ModelRegistry>,
}
42
/// A cached provider together with the authentication source it was built from.
#[derive(Clone)]
struct ProviderEntry {
    /// The constructed provider used to issue requests.
    provider: Arc<dyn Provider>,
    /// How the provider was authenticated; consulted after an auth failure to decide
    /// whether an API-key fallback should be attempted.
    auth_source: AuthSource,
}
48
49impl Client {
50    /// Remove a cached provider so that future calls re-create it with fresh credentials.
51    fn invalidate_provider(&self, provider_id: &ProviderId) {
52        let Ok(mut map) = self.provider_map.write() else {
53            warn!(
54                target: "api::client",
55                "Provider cache lock poisoned while invalidating provider"
56            );
57            return;
58        };
59        map.remove(provider_id);
60    }
61
62    /// Determine if an API error should invalidate the cached provider (typically auth failures).
63    fn should_invalidate_provider(error: &ApiError) -> bool {
64        matches!(
65            error,
66            ApiError::AuthenticationFailed { .. } | ApiError::AuthError(_)
67        ) || matches!(
68            error,
69            ApiError::ServerError { status_code, .. } if matches!(status_code, 401 | 403)
70        )
71    }
72
73    /// Create a new Client with all dependencies injected.
74    /// This is the preferred constructor to avoid internal registry loading.
75    pub fn new_with_deps(
76        config_provider: LlmConfigProvider,
77        provider_registry: Arc<ProviderRegistry>,
78        model_registry: Arc<ModelRegistry>,
79    ) -> Self {
80        Self {
81            provider_map: Arc::new(RwLock::new(HashMap::new())),
82            config_provider,
83            provider_registry,
84            model_registry,
85        }
86    }
87
88    #[cfg(any(test, feature = "test-utils"))]
89    pub fn insert_test_provider(&self, provider_id: ProviderId, provider: Arc<dyn Provider>) {
90        match self.provider_map.write() {
91            Ok(mut map) => {
92                map.insert(
93                    provider_id,
94                    ProviderEntry {
95                        provider,
96                        auth_source: AuthSource::None,
97                    },
98                );
99            }
100            Err(_) => {
101                warn!(
102                    target: "api::client",
103                    "Provider cache lock poisoned while inserting test provider"
104                );
105            }
106        }
107    }
108
109    async fn get_or_create_provider_entry(&self, provider_id: ProviderId) -> Result<ProviderEntry> {
110        // First check without holding the lock across await
111        {
112            let map = self.provider_map.read().map_err(|_| {
113                crate::error::Error::Api(ApiError::Configuration(
114                    "Provider cache lock poisoned".to_string(),
115                ))
116            })?;
117            if let Some(entry) = map.get(&provider_id) {
118                return Ok(entry.clone());
119            }
120        }
121
122        // Get the provider config from registry
123        let provider_config = self.provider_registry.get(&provider_id).ok_or_else(|| {
124            crate::error::Error::Api(ApiError::Configuration(format!(
125                "No provider configuration found for {provider_id:?}"
126            )))
127        })?;
128
129        let resolved = self
130            .config_provider
131            .resolve_auth_for_provider(&provider_id)
132            .await?;
133
134        // Now acquire write lock and create provider
135        let mut map = self.provider_map.write().map_err(|_| {
136            crate::error::Error::Api(ApiError::Configuration(
137                "Provider cache lock poisoned".to_string(),
138            ))
139        })?;
140
141        // Check again in case another thread added it
142        if let Some(entry) = map.get(&provider_id) {
143            return Ok(entry.clone());
144        }
145
146        let entry = Self::build_provider_entry(provider_config, &resolved)?;
147
148        map.insert(provider_id, entry.clone());
149        Ok(entry)
150    }
151
152    fn build_provider_entry(
153        provider_config: &crate::config::provider::ProviderConfig,
154        resolved: &ResolvedAuth,
155    ) -> std::result::Result<ProviderEntry, ApiError> {
156        let provider = match resolved {
157            ResolvedAuth::Plugin { directive, .. } => {
158                factory::create_provider_with_directive(provider_config, directive)?
159            }
160            ResolvedAuth::ApiKey { credential, .. } => {
161                factory::create_provider(provider_config, credential)?
162            }
163            ResolvedAuth::None => {
164                return Err(ApiError::Configuration(format!(
165                    "No authentication configured for {:?}",
166                    provider_config.id
167                )));
168            }
169        };
170
171        Ok(ProviderEntry {
172            provider,
173            auth_source: resolved.source(),
174        })
175    }
176
177    async fn fallback_api_key_entry(
178        &self,
179        provider_id: &ProviderId,
180    ) -> std::result::Result<Option<ProviderEntry>, ApiError> {
181        let Some((key, origin)) = self
182            .config_provider
183            .resolve_api_key_for_provider(provider_id)
184            .await?
185        else {
186            return Ok(None);
187        };
188
189        let provider_config = self.provider_registry.get(provider_id).ok_or_else(|| {
190            ApiError::Configuration(format!(
191                "No provider configuration found for {provider_id:?}"
192            ))
193        })?;
194
195        let credential = Credential::ApiKey { value: key };
196        let provider = factory::create_provider(provider_config, &credential)?;
197
198        Ok(Some(ProviderEntry {
199            provider,
200            auth_source: AuthSource::ApiKey { origin },
201        }))
202    }
203
204    /// Complete a prompt with a specific model ID and optional parameters
205    pub async fn complete(
206        &self,
207        model_id: &ModelId,
208        messages: Vec<Message>,
209        system: Option<SystemContext>,
210        tools: Option<Vec<ToolSchema>>,
211        call_options: Option<crate::config::model::ModelParameters>,
212        token: CancellationToken,
213    ) -> std::result::Result<CompletionResponse, ApiError> {
214        // Get provider from model ID
215        let provider_id = model_id.provider.clone();
216        let entry = self
217            .get_or_create_provider_entry(provider_id.clone())
218            .await
219            .map_err(ApiError::from)?;
220        let provider = entry.provider.clone();
221
222        if token.is_cancelled() {
223            return Err(ApiError::Cancelled {
224                provider: provider.name().to_string(),
225            });
226        }
227
228        // Get model config and merge parameters
229        let model_config = self.model_registry.get(model_id);
230        let effective_params = match (model_config, &call_options) {
231            (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
232            (Some(config), None) => config.effective_parameters(None),
233            (None, Some(opts)) => Some(*opts),
234            (None, None) => None,
235        };
236
237        debug!(
238            target: "api::complete",
239            ?model_id,
240            ?call_options,
241            ?effective_params,
242            "Final parameters for model"
243        );
244
245        let result = provider
246            .complete(
247                model_id,
248                messages.clone(),
249                system.clone(),
250                tools.clone(),
251                effective_params,
252                token.clone(),
253            )
254            .await;
255
256        if let Err(ref err) = result
257            && Self::should_invalidate_provider(err)
258        {
259            self.invalidate_provider(&provider_id);
260
261            if matches!(entry.auth_source, AuthSource::Plugin { .. })
262                && let Some(fallback) = self.fallback_api_key_entry(&provider_id).await?
263            {
264                let fallback_result = fallback
265                    .provider
266                    .complete(model_id, messages, system, tools, effective_params, token)
267                    .await;
268                if fallback_result.is_ok() {
269                    let mut map = self.provider_map.write().map_err(|_| {
270                        ApiError::Configuration("Provider cache lock poisoned".to_string())
271                    })?;
272                    map.insert(provider_id, fallback);
273                }
274            }
275        }
276
277        result
278    }
279
280    pub async fn stream_complete(
281        &self,
282        model_id: &ModelId,
283        messages: Vec<Message>,
284        system: Option<SystemContext>,
285        tools: Option<Vec<ToolSchema>>,
286        call_options: Option<crate::config::model::ModelParameters>,
287        token: CancellationToken,
288    ) -> std::result::Result<CompletionStream, ApiError> {
289        let provider_id = model_id.provider.clone();
290        let entry = self
291            .get_or_create_provider_entry(provider_id.clone())
292            .await
293            .map_err(ApiError::from)?;
294        let provider = entry.provider.clone();
295
296        if token.is_cancelled() {
297            return Err(ApiError::Cancelled {
298                provider: provider.name().to_string(),
299            });
300        }
301
302        let model_config = self.model_registry.get(model_id);
303        let effective_params = match (model_config, &call_options) {
304            (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
305            (Some(config), None) => config.effective_parameters(None),
306            (None, Some(opts)) => Some(*opts),
307            (None, None) => None,
308        };
309
310        debug!(
311            target: "api::stream_complete",
312            ?model_id,
313            ?call_options,
314            ?effective_params,
315            "Streaming with parameters"
316        );
317
318        let (initial_stream, provider_for_retry) = match provider
319            .stream_complete(
320                model_id,
321                messages.clone(),
322                system.clone(),
323                tools.clone(),
324                effective_params,
325                token.clone(),
326            )
327            .await
328        {
329            Ok(stream) => (stream, provider),
330            Err(err) => {
331                if Self::should_invalidate_provider(&err) {
332                    self.invalidate_provider(&provider_id);
333
334                    if matches!(entry.auth_source, AuthSource::Plugin { .. }) {
335                        if let Some(fallback) = self.fallback_api_key_entry(&provider_id).await? {
336                            let fallback_provider = fallback.provider.clone();
337                            let fallback_stream = fallback_provider
338                                .stream_complete(
339                                    model_id,
340                                    messages.clone(),
341                                    system.clone(),
342                                    tools.clone(),
343                                    effective_params,
344                                    token.clone(),
345                                )
346                                .await?;
347                            let mut map = self.provider_map.write().map_err(|_| {
348                                ApiError::Configuration("Provider cache lock poisoned".to_string())
349                            })?;
350                            map.insert(provider_id, fallback);
351                            (fallback_stream, fallback_provider)
352                        } else {
353                            return Err(err);
354                        }
355                    } else {
356                        return Err(err);
357                    }
358                } else {
359                    return Err(err);
360                }
361            }
362        };
363
364        let model_id = model_id.clone();
365        let stream = async_stream::stream! {
366            let mut attempt = 1usize;
367            let mut current_stream = Some(initial_stream);
368
369            'outer: loop {
370                let mut saw_output = false;
371                let mut stream = if let Some(stream) = current_stream.take() { stream } else {
372                    if token.is_cancelled() {
373                        yield StreamChunk::Error(StreamError::Cancelled);
374                        break;
375                    }
376
377                    let stream_result = provider_for_retry
378                        .stream_complete(
379                            &model_id,
380                            messages.clone(),
381                            system.clone(),
382                            tools.clone(),
383                            effective_params,
384                            token.clone(),
385                        )
386                        .await;
387                    match stream_result {
388                        Ok(stream) => stream,
389                        Err(err) => {
390                            yield StreamChunk::Error(StreamError::Provider {
391                                provider: provider_for_retry.name().to_string(),
392                                error_type: "stream_retry".to_string(),
393                                message: err.to_string(),
394                            });
395                            break;
396                        }
397                    }
398                };
399
400                while let Some(chunk) = stream.next().await {
401                    if matches!(
402                        &chunk,
403                        StreamChunk::Error(StreamError::SseParse(
404                            SseParseError::Transport { .. }
405                        )) if !saw_output && attempt < STREAM_TRANSPORT_RETRY_MAX_ATTEMPTS
406                    ) {
407                        attempt += 1;
408                        warn!(
409                            target: "api::stream_complete",
410                            ?model_id,
411                            attempt,
412                            max_attempts = STREAM_TRANSPORT_RETRY_MAX_ATTEMPTS,
413                            "Retrying stream after transport error before any output"
414                        );
415                        current_stream = None;
416                        continue 'outer;
417                    }
418
419                    if !matches!(chunk, StreamChunk::Error(_)) {
420                        saw_output = true;
421                    }
422
423                    yield chunk;
424                }
425
426                break;
427            }
428        };
429
430        Ok(Box::pin(stream))
431    }
432
433    pub async fn complete_with_retry(
434        &self,
435        model_id: &ModelId,
436        messages: &[Message],
437        system_prompt: &Option<SystemContext>,
438        tools: &Option<Vec<ToolSchema>>,
439        token: CancellationToken,
440        max_attempts: usize,
441    ) -> std::result::Result<CompletionResponse, ApiError> {
442        let mut attempts = 0;
443
444        // Prepare provider and parameters once
445        let provider_id = model_id.provider.clone();
446        let entry = self
447            .get_or_create_provider_entry(provider_id.clone())
448            .await
449            .map_err(ApiError::from)?;
450        let provider = entry.provider.clone();
451
452        let model_config = self.model_registry.get(model_id);
453        debug!(
454            target: "api::complete_with_retry",
455            ?model_id,
456            ?model_config,
457            "Model config"
458        );
459        let effective_params = model_config.and_then(|cfg| cfg.effective_parameters(None));
460
461        debug!(
462            target: "api::complete_with_retry",
463            ?model_id,
464            ?effective_params,
465            "system: {:?}",
466            system_prompt
467        );
468        debug!(
469            target: "api::complete_with_retry",
470            ?model_id,
471            "messages: {:?}",
472            messages
473        );
474
475        loop {
476            if token.is_cancelled() {
477                return Err(ApiError::Cancelled {
478                    provider: provider.name().to_string(),
479                });
480            }
481
482            match provider
483                .complete(
484                    model_id,
485                    messages.to_vec(),
486                    system_prompt.clone(),
487                    tools.clone(),
488                    effective_params,
489                    token.clone(),
490                )
491                .await
492            {
493                Ok(response) => {
494                    return Ok(response);
495                }
496                Err(error) => {
497                    attempts += 1;
498                    warn!(
499                        "API completion attempt {}/{} failed for model {:?}: {:?}",
500                        attempts, max_attempts, model_id, error
501                    );
502
503                    if Self::should_invalidate_provider(&error) {
504                        self.invalidate_provider(&provider_id);
505                        if matches!(entry.auth_source, AuthSource::Plugin { .. })
506                            && let Some(fallback) =
507                                self.fallback_api_key_entry(&provider_id).await?
508                        {
509                            let fallback_result = fallback
510                                .provider
511                                .complete(
512                                    model_id,
513                                    messages.to_vec(),
514                                    system_prompt.clone(),
515                                    tools.clone(),
516                                    effective_params,
517                                    token.clone(),
518                                )
519                                .await;
520                            if fallback_result.is_ok() {
521                                let mut map = self.provider_map.write().map_err(|_| {
522                                    ApiError::Configuration(
523                                        "Provider cache lock poisoned".to_string(),
524                                    )
525                                })?;
526                                map.insert(provider_id.clone(), fallback);
527                            }
528                        }
529                        return Err(error);
530                    }
531
532                    if attempts >= max_attempts {
533                        return Err(error);
534                    }
535
536                    match error {
537                        ApiError::RateLimited { provider, details } => {
538                            let sleep_duration =
539                                std::time::Duration::from_secs(1 << (attempts - 1));
540                            warn!(
541                                "Rate limited by API: {} {} (retrying in {} seconds)",
542                                provider,
543                                details,
544                                sleep_duration.as_secs()
545                            );
546                            tokio::time::sleep(sleep_duration).await;
547                        }
548                        ApiError::NoChoices { provider } => {
549                            warn!("No choices returned from API: {}", provider);
550                        }
551                        ApiError::ServerError {
552                            provider,
553                            status_code,
554                            details,
555                        } => {
556                            warn!(
557                                "Server error for API: {} {} {}",
558                                provider, status_code, details
559                            );
560                        }
561                        _ => {
562                            // Not retryable
563                            return Err(error);
564                        }
565                    }
566                }
567            }
568        }
569    }
570}
571
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ApiKeyOrigin;
    use crate::config::provider::ProviderId;
    use async_trait::async_trait;
    use tokio_util::sync::CancellationToken;

    /// Which failure the stub provider reports from `complete`.
    #[derive(Clone, Copy)]
    enum StubErrorKind {
        Auth,
        Server401,
    }

    impl StubErrorKind {
        /// Build the `ApiError` this kind stands for.
        fn api_error(self) -> ApiError {
            match self {
                Self::Auth => ApiError::AuthenticationFailed {
                    provider: "stub".to_string(),
                    details: "bad key".to_string(),
                },
                Self::Server401 => ApiError::ServerError {
                    provider: "stub".to_string(),
                    status_code: 401,
                    details: "unauthorized".to_string(),
                },
            }
        }
    }

    /// Provider whose `complete` always fails with the configured error.
    #[derive(Clone)]
    struct StubProvider {
        error_kind: StubErrorKind,
    }

    impl StubProvider {
        fn new(error_kind: StubErrorKind) -> Self {
            Self { error_kind }
        }
    }

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            Err(self.error_kind.api_error())
        }
    }

    /// Client backed by in-memory auth storage and empty registries.
    fn test_client() -> Client {
        let auth_storage = Arc::new(crate::test_utils::InMemoryAuthStorage::new());
        let config_provider = LlmConfigProvider::new(auth_storage).unwrap();
        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));

        Client::new_with_deps(config_provider, provider_registry, model_registry)
    }

    /// Seed the cache with a stub provider recorded as API-key authenticated.
    fn insert_stub_provider(client: &Client, provider_id: ProviderId, error: StubErrorKind) {
        let entry = ProviderEntry {
            provider: Arc::new(StubProvider::new(error)),
            auth_source: AuthSource::ApiKey {
                origin: ApiKeyOrigin::Stored,
            },
        };
        client
            .provider_map
            .write()
            .unwrap()
            .insert(provider_id, entry);
    }

    /// Run one completion against a freshly seeded stub and return the error it produced.
    async fn complete_against_stub(
        client: &Client,
        provider_id: &ProviderId,
        error: StubErrorKind,
    ) -> ApiError {
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        insert_stub_provider(client, provider_id.clone(), error);

        client
            .complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .unwrap_err()
    }

    /// Whether the cache currently holds an entry for `provider_id`.
    fn is_cached(client: &Client, provider_id: &ProviderId) -> bool {
        client
            .provider_map
            .read()
            .unwrap()
            .contains_key(provider_id)
    }

    #[tokio::test]
    async fn invalidates_cached_provider_on_auth_failure() {
        let client = test_client();
        let provider_id = ProviderId("stub-auth".to_string());

        let err = complete_against_stub(&client, &provider_id, StubErrorKind::Auth).await;

        assert!(matches!(err, ApiError::AuthenticationFailed { .. }));
        assert!(!is_cached(&client, &provider_id));
    }

    #[tokio::test]
    async fn invalidates_cached_provider_on_unauthorized_status_code() {
        let client = test_client();
        let provider_id = ProviderId("stub-unauthorized".to_string());

        let err = complete_against_stub(&client, &provider_id, StubErrorKind::Server401).await;

        assert!(matches!(
            err,
            ApiError::ServerError {
                status_code: 401,
                ..
            }
        ));
        assert!(!is_cached(&client, &provider_id));
    }
}
713}