//! `steer_core/api/mod.rs`: LLM provider modules and the [`Client`] that
//! dispatches completion requests to them.

1pub mod claude;
2pub mod error;
3pub mod factory;
4pub mod gemini;
5pub mod openai;
6pub mod provider;
7pub mod sse;
8pub mod util;
9pub mod xai;
10
11use crate::auth::storage::Credential;
12use crate::auth::{AuthSource, ProviderRegistry};
13use crate::config::model::ModelId;
14use crate::config::provider::ProviderId;
15use crate::config::{LlmConfigProvider, ResolvedAuth};
16use crate::error::Result;
17use crate::model_registry::ModelRegistry;
18pub use error::{ApiError, SseParseError, StreamError};
19pub use factory::{create_provider, create_provider_with_directive};
20use futures::StreamExt;
21pub use provider::{CompletionResponse, CompletionStream, Provider, StreamChunk, TokenUsage};
22use std::collections::HashMap;
23use std::sync::Arc;
24use std::sync::RwLock;
25use steer_tools::ToolSchema;
26use tokio_util::sync::CancellationToken;
27use tracing::debug;
28use tracing::warn;
29
30use crate::app::SystemContext;
31use crate::app::conversation::Message;
32
/// Total number of attempts (the initial request included) for a streaming
/// completion whose transport fails before any non-error chunk was yielded.
const STREAM_TRANSPORT_RETRY_MAX_ATTEMPTS: usize = 2;
34
35#[derive(Clone)]
36pub struct Client {
37    provider_map: Arc<RwLock<HashMap<ProviderId, ProviderEntry>>>,
38    config_provider: LlmConfigProvider,
39    provider_registry: Arc<ProviderRegistry>,
40    model_registry: Arc<ModelRegistry>,
41}
42
43#[derive(Clone)]
44struct ProviderEntry {
45    provider: Arc<dyn Provider>,
46    auth_source: AuthSource,
47}
48
49impl Client {
50    /// Remove a cached provider so that future calls re-create it with fresh credentials.
51    fn invalidate_provider(&self, provider_id: &ProviderId) {
52        let Ok(mut map) = self.provider_map.write() else {
53            warn!(
54                target: "api::client",
55                "Provider cache lock poisoned while invalidating provider"
56            );
57            return;
58        };
59        map.remove(provider_id);
60    }
61
62    /// Determine if an API error should invalidate the cached provider (typically auth failures).
63    fn should_invalidate_provider(error: &ApiError) -> bool {
64        matches!(
65            error,
66            ApiError::AuthenticationFailed { .. } | ApiError::AuthError(_)
67        ) || matches!(
68            error,
69            ApiError::ServerError { status_code, .. } if matches!(status_code, 401 | 403)
70        )
71    }
72
73    /// Create a new Client with all dependencies injected.
74    /// This is the preferred constructor to avoid internal registry loading.
75    pub fn new_with_deps(
76        config_provider: LlmConfigProvider,
77        provider_registry: Arc<ProviderRegistry>,
78        model_registry: Arc<ModelRegistry>,
79    ) -> Self {
80        Self {
81            provider_map: Arc::new(RwLock::new(HashMap::new())),
82            config_provider,
83            provider_registry,
84            model_registry,
85        }
86    }
87
88    pub fn model_context_window_tokens(&self, model_id: &ModelId) -> Option<u32> {
89        self.model_registry
90            .get(model_id)
91            .and_then(|model| model.context_window_tokens)
92    }
93
94    #[cfg(any(test, feature = "test-utils"))]
95    pub fn insert_test_provider(&self, provider_id: ProviderId, provider: Arc<dyn Provider>) {
96        match self.provider_map.write() {
97            Ok(mut map) => {
98                map.insert(
99                    provider_id,
100                    ProviderEntry {
101                        provider,
102                        auth_source: AuthSource::None,
103                    },
104                );
105            }
106            Err(_) => {
107                warn!(
108                    target: "api::client",
109                    "Provider cache lock poisoned while inserting test provider"
110                );
111            }
112        }
113    }
114
115    async fn get_or_create_provider_entry(&self, provider_id: ProviderId) -> Result<ProviderEntry> {
116        // First check without holding the lock across await
117        {
118            let map = self.provider_map.read().map_err(|_| {
119                crate::error::Error::Api(ApiError::Configuration(
120                    "Provider cache lock poisoned".to_string(),
121                ))
122            })?;
123            if let Some(entry) = map.get(&provider_id) {
124                return Ok(entry.clone());
125            }
126        }
127
128        // Get the provider config from registry
129        let provider_config = self.provider_registry.get(&provider_id).ok_or_else(|| {
130            crate::error::Error::Api(ApiError::Configuration(format!(
131                "No provider configuration found for {provider_id:?}"
132            )))
133        })?;
134
135        let resolved = self
136            .config_provider
137            .resolve_auth_for_provider(&provider_id)
138            .await?;
139
140        // Now acquire write lock and create provider
141        let mut map = self.provider_map.write().map_err(|_| {
142            crate::error::Error::Api(ApiError::Configuration(
143                "Provider cache lock poisoned".to_string(),
144            ))
145        })?;
146
147        // Check again in case another thread added it
148        if let Some(entry) = map.get(&provider_id) {
149            return Ok(entry.clone());
150        }
151
152        let entry = Self::build_provider_entry(provider_config, &resolved)?;
153
154        map.insert(provider_id, entry.clone());
155        Ok(entry)
156    }
157
158    fn build_provider_entry(
159        provider_config: &crate::config::provider::ProviderConfig,
160        resolved: &ResolvedAuth,
161    ) -> std::result::Result<ProviderEntry, ApiError> {
162        let provider = match resolved {
163            ResolvedAuth::Plugin { directive, .. } => {
164                factory::create_provider_with_directive(provider_config, directive)?
165            }
166            ResolvedAuth::ApiKey { credential, .. } => {
167                factory::create_provider(provider_config, credential)?
168            }
169            ResolvedAuth::None => {
170                return Err(ApiError::Configuration(format!(
171                    "No authentication configured for {:?}",
172                    provider_config.id
173                )));
174            }
175        };
176
177        Ok(ProviderEntry {
178            provider,
179            auth_source: resolved.source(),
180        })
181    }
182
183    async fn fallback_api_key_entry(
184        &self,
185        provider_id: &ProviderId,
186    ) -> std::result::Result<Option<ProviderEntry>, ApiError> {
187        let Some((key, origin)) = self
188            .config_provider
189            .resolve_api_key_for_provider(provider_id)
190            .await?
191        else {
192            return Ok(None);
193        };
194
195        let provider_config = self.provider_registry.get(provider_id).ok_or_else(|| {
196            ApiError::Configuration(format!(
197                "No provider configuration found for {provider_id:?}"
198            ))
199        })?;
200
201        let credential = Credential::ApiKey { value: key };
202        let provider = factory::create_provider(provider_config, &credential)?;
203
204        Ok(Some(ProviderEntry {
205            provider,
206            auth_source: AuthSource::ApiKey { origin },
207        }))
208    }
209
210    /// Complete a prompt with a specific model ID and optional parameters
211    pub async fn complete(
212        &self,
213        model_id: &ModelId,
214        messages: Vec<Message>,
215        system: Option<SystemContext>,
216        tools: Option<Vec<ToolSchema>>,
217        call_options: Option<crate::config::model::ModelParameters>,
218        token: CancellationToken,
219    ) -> std::result::Result<CompletionResponse, ApiError> {
220        // Get provider from model ID
221        let provider_id = model_id.provider.clone();
222        let entry = self
223            .get_or_create_provider_entry(provider_id.clone())
224            .await
225            .map_err(ApiError::from)?;
226        let provider = entry.provider.clone();
227
228        if token.is_cancelled() {
229            return Err(ApiError::Cancelled {
230                provider: provider.name().to_string(),
231            });
232        }
233
234        // Get model config and merge parameters
235        let model_config = self.model_registry.get(model_id);
236        let effective_params = match (model_config, &call_options) {
237            (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
238            (Some(config), None) => config.effective_parameters(None),
239            (None, Some(opts)) => Some(*opts),
240            (None, None) => None,
241        };
242
243        debug!(
244            target: "api::complete",
245            ?model_id,
246            ?call_options,
247            ?effective_params,
248            "Final parameters for model"
249        );
250
251        let result = provider
252            .complete(
253                model_id,
254                messages.clone(),
255                system.clone(),
256                tools.clone(),
257                effective_params,
258                token.clone(),
259            )
260            .await;
261
262        if let Err(ref err) = result
263            && Self::should_invalidate_provider(err)
264        {
265            self.invalidate_provider(&provider_id);
266
267            if matches!(entry.auth_source, AuthSource::Plugin { .. })
268                && let Some(fallback) = self.fallback_api_key_entry(&provider_id).await?
269            {
270                let fallback_result = fallback
271                    .provider
272                    .complete(model_id, messages, system, tools, effective_params, token)
273                    .await;
274                if fallback_result.is_ok() {
275                    let mut map = self.provider_map.write().map_err(|_| {
276                        ApiError::Configuration("Provider cache lock poisoned".to_string())
277                    })?;
278                    map.insert(provider_id, fallback);
279                }
280            }
281        }
282
283        result
284    }
285
286    pub async fn stream_complete(
287        &self,
288        model_id: &ModelId,
289        messages: Vec<Message>,
290        system: Option<SystemContext>,
291        tools: Option<Vec<ToolSchema>>,
292        call_options: Option<crate::config::model::ModelParameters>,
293        token: CancellationToken,
294    ) -> std::result::Result<CompletionStream, ApiError> {
295        let provider_id = model_id.provider.clone();
296        let entry = self
297            .get_or_create_provider_entry(provider_id.clone())
298            .await
299            .map_err(ApiError::from)?;
300        let provider = entry.provider.clone();
301
302        if token.is_cancelled() {
303            return Err(ApiError::Cancelled {
304                provider: provider.name().to_string(),
305            });
306        }
307
308        let model_config = self.model_registry.get(model_id);
309        let effective_params = match (model_config, &call_options) {
310            (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
311            (Some(config), None) => config.effective_parameters(None),
312            (None, Some(opts)) => Some(*opts),
313            (None, None) => None,
314        };
315
316        debug!(
317            target: "api::stream_complete",
318            ?model_id,
319            ?call_options,
320            ?effective_params,
321            "Streaming with parameters"
322        );
323
324        let (initial_stream, provider_for_retry) = match provider
325            .stream_complete(
326                model_id,
327                messages.clone(),
328                system.clone(),
329                tools.clone(),
330                effective_params,
331                token.clone(),
332            )
333            .await
334        {
335            Ok(stream) => (stream, provider),
336            Err(err) => {
337                if Self::should_invalidate_provider(&err) {
338                    self.invalidate_provider(&provider_id);
339
340                    if matches!(entry.auth_source, AuthSource::Plugin { .. }) {
341                        if let Some(fallback) = self.fallback_api_key_entry(&provider_id).await? {
342                            let fallback_provider = fallback.provider.clone();
343                            let fallback_stream = fallback_provider
344                                .stream_complete(
345                                    model_id,
346                                    messages.clone(),
347                                    system.clone(),
348                                    tools.clone(),
349                                    effective_params,
350                                    token.clone(),
351                                )
352                                .await?;
353                            let mut map = self.provider_map.write().map_err(|_| {
354                                ApiError::Configuration("Provider cache lock poisoned".to_string())
355                            })?;
356                            map.insert(provider_id, fallback);
357                            (fallback_stream, fallback_provider)
358                        } else {
359                            return Err(err);
360                        }
361                    } else {
362                        return Err(err);
363                    }
364                } else {
365                    return Err(err);
366                }
367            }
368        };
369
370        let model_id = model_id.clone();
371        let stream = async_stream::stream! {
372            let mut attempt = 1usize;
373            let mut current_stream = Some(initial_stream);
374
375            'outer: loop {
376                let mut saw_output = false;
377                let mut stream = if let Some(stream) = current_stream.take() { stream } else {
378                    if token.is_cancelled() {
379                        yield StreamChunk::Error(StreamError::Cancelled);
380                        break;
381                    }
382
383                    let stream_result = provider_for_retry
384                        .stream_complete(
385                            &model_id,
386                            messages.clone(),
387                            system.clone(),
388                            tools.clone(),
389                            effective_params,
390                            token.clone(),
391                        )
392                        .await;
393                    match stream_result {
394                        Ok(stream) => stream,
395                        Err(err) => {
396                            yield StreamChunk::Error(StreamError::Provider {
397                                provider: provider_for_retry.name().to_string(),
398                                error_type: "stream_retry".to_string(),
399                                message: err.to_string(),
400                            });
401                            break;
402                        }
403                    }
404                };
405
406                while let Some(chunk) = stream.next().await {
407                    if matches!(
408                        &chunk,
409                        StreamChunk::Error(StreamError::SseParse(
410                            SseParseError::Transport { .. }
411                        )) if !saw_output && attempt < STREAM_TRANSPORT_RETRY_MAX_ATTEMPTS
412                    ) {
413                        attempt += 1;
414                        warn!(
415                            target: "api::stream_complete",
416                            ?model_id,
417                            attempt,
418                            max_attempts = STREAM_TRANSPORT_RETRY_MAX_ATTEMPTS,
419                            "Retrying stream after transport error before any output"
420                        );
421                        current_stream = None;
422                        continue 'outer;
423                    }
424
425                    if !matches!(chunk, StreamChunk::Error(_)) {
426                        saw_output = true;
427                    }
428
429                    yield chunk;
430                }
431
432                break;
433            }
434        };
435
436        Ok(Box::pin(stream))
437    }
438
439    pub async fn complete_with_retry(
440        &self,
441        model_id: &ModelId,
442        messages: &[Message],
443        system_prompt: &Option<SystemContext>,
444        tools: &Option<Vec<ToolSchema>>,
445        token: CancellationToken,
446        max_attempts: usize,
447    ) -> std::result::Result<CompletionResponse, ApiError> {
448        let mut attempts = 0;
449
450        // Prepare provider and parameters once
451        let provider_id = model_id.provider.clone();
452        let entry = self
453            .get_or_create_provider_entry(provider_id.clone())
454            .await
455            .map_err(ApiError::from)?;
456        let provider = entry.provider.clone();
457
458        let model_config = self.model_registry.get(model_id);
459        debug!(
460            target: "api::complete_with_retry",
461            ?model_id,
462            ?model_config,
463            "Model config"
464        );
465        let effective_params = model_config.and_then(|cfg| cfg.effective_parameters(None));
466
467        debug!(
468            target: "api::complete_with_retry",
469            ?model_id,
470            ?effective_params,
471            "system: {:?}",
472            system_prompt
473        );
474        debug!(
475            target: "api::complete_with_retry",
476            ?model_id,
477            "messages: {:?}",
478            messages
479        );
480
481        loop {
482            if token.is_cancelled() {
483                return Err(ApiError::Cancelled {
484                    provider: provider.name().to_string(),
485                });
486            }
487
488            match provider
489                .complete(
490                    model_id,
491                    messages.to_vec(),
492                    system_prompt.clone(),
493                    tools.clone(),
494                    effective_params,
495                    token.clone(),
496                )
497                .await
498            {
499                Ok(response) => {
500                    return Ok(response);
501                }
502                Err(error) => {
503                    attempts += 1;
504                    warn!(
505                        "API completion attempt {}/{} failed for model {:?}: {:?}",
506                        attempts, max_attempts, model_id, error
507                    );
508
509                    if Self::should_invalidate_provider(&error) {
510                        self.invalidate_provider(&provider_id);
511                        if matches!(entry.auth_source, AuthSource::Plugin { .. })
512                            && let Some(fallback) =
513                                self.fallback_api_key_entry(&provider_id).await?
514                        {
515                            let fallback_result = fallback
516                                .provider
517                                .complete(
518                                    model_id,
519                                    messages.to_vec(),
520                                    system_prompt.clone(),
521                                    tools.clone(),
522                                    effective_params,
523                                    token.clone(),
524                                )
525                                .await;
526                            if fallback_result.is_ok() {
527                                let mut map = self.provider_map.write().map_err(|_| {
528                                    ApiError::Configuration(
529                                        "Provider cache lock poisoned".to_string(),
530                                    )
531                                })?;
532                                map.insert(provider_id.clone(), fallback);
533                            }
534                        }
535                        return Err(error);
536                    }
537
538                    if attempts >= max_attempts {
539                        return Err(error);
540                    }
541
542                    match error {
543                        ApiError::RateLimited { provider, details } => {
544                            let sleep_duration =
545                                std::time::Duration::from_secs(1 << (attempts - 1));
546                            warn!(
547                                "Rate limited by API: {} {} (retrying in {} seconds)",
548                                provider,
549                                details,
550                                sleep_duration.as_secs()
551                            );
552                            tokio::time::sleep(sleep_duration).await;
553                        }
554                        ApiError::NoChoices { provider } => {
555                            warn!("No choices returned from API: {}", provider);
556                        }
557                        ApiError::ServerError {
558                            provider,
559                            status_code,
560                            details,
561                        } => {
562                            warn!(
563                                "Server error for API: {} {} {}",
564                                provider, status_code, details
565                            );
566                        }
567                        _ => {
568                            // Not retryable
569                            return Err(error);
570                        }
571                    }
572                }
573            }
574        }
575    }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ApiKeyOrigin;
    use crate::config::provider::ProviderId;
    use async_trait::async_trait;
    use tokio_util::sync::CancellationToken;

    /// Error that `StubProvider::complete` returns on every call.
    #[derive(Clone, Copy)]
    enum StubErrorKind {
        /// `ApiError::AuthenticationFailed`.
        Auth,
        /// `ApiError::ServerError` with status 401.
        Server401,
        /// `ApiError::ServerError` with status 403.
        Server403,
        /// `ApiError::ServerError` with status 500 (not an auth failure).
        Server500,
    }

    /// Provider whose `complete` always fails with the configured error.
    #[derive(Clone)]
    struct StubProvider {
        error_kind: StubErrorKind,
    }

    impl StubProvider {
        fn new(error_kind: StubErrorKind) -> Self {
            Self { error_kind }
        }
    }

    #[async_trait]
    impl Provider for StubProvider {
        fn name(&self) -> &'static str {
            "stub"
        }

        async fn complete(
            &self,
            _model_id: &ModelId,
            _messages: Vec<Message>,
            _system: Option<SystemContext>,
            _tools: Option<Vec<ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> std::result::Result<CompletionResponse, ApiError> {
            let err = match self.error_kind {
                StubErrorKind::Auth => ApiError::AuthenticationFailed {
                    provider: "stub".to_string(),
                    details: "bad key".to_string(),
                },
                StubErrorKind::Server401 => ApiError::ServerError {
                    provider: "stub".to_string(),
                    status_code: 401,
                    details: "unauthorized".to_string(),
                },
                StubErrorKind::Server403 => ApiError::ServerError {
                    provider: "stub".to_string(),
                    status_code: 403,
                    details: "forbidden".to_string(),
                },
                StubErrorKind::Server500 => ApiError::ServerError {
                    provider: "stub".to_string(),
                    status_code: 500,
                    details: "internal error".to_string(),
                },
            };
            Err(err)
        }
    }

    fn test_client() -> Client {
        let auth_storage = Arc::new(crate::test_utils::InMemoryAuthStorage::new());
        let config_provider = LlmConfigProvider::new(auth_storage).unwrap();
        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));

        Client::new_with_deps(config_provider, provider_registry, model_registry)
    }

    /// Cache a failing stub under `provider_id`, marked as API-key authenticated
    /// (so no plugin fallback is attempted).
    fn insert_stub_provider(client: &Client, provider_id: ProviderId, error: StubErrorKind) {
        client.provider_map.write().unwrap().insert(
            provider_id,
            ProviderEntry {
                provider: Arc::new(StubProvider::new(error)),
                auth_source: AuthSource::ApiKey {
                    origin: ApiKeyOrigin::Stored,
                },
            },
        );
    }

    /// Whether `client` currently caches a provider for `provider_id`.
    fn is_cached(client: &Client, provider_id: &ProviderId) -> bool {
        client
            .provider_map
            .read()
            .unwrap()
            .contains_key(provider_id)
    }

    /// Build a client whose provider `name` fails with `error`, make one
    /// `complete` call and return the client, provider id and resulting error.
    async fn complete_with_stub(
        name: &str,
        error: StubErrorKind,
    ) -> (Client, ProviderId, ApiError) {
        let client = test_client();
        let provider_id = ProviderId(name.to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");

        insert_stub_provider(&client, provider_id.clone(), error);

        let err = client
            .complete(
                &model_id,
                vec![],
                None,
                None,
                None,
                CancellationToken::new(),
            )
            .await
            .unwrap_err();
        (client, provider_id, err)
    }

    #[tokio::test]
    async fn invalidates_cached_provider_on_auth_failure() {
        let (client, provider_id, err) =
            complete_with_stub("stub-auth", StubErrorKind::Auth).await;

        assert!(matches!(err, ApiError::AuthenticationFailed { .. }));
        assert!(!is_cached(&client, &provider_id));
    }

    #[tokio::test]
    async fn invalidates_cached_provider_on_unauthorized_status_code() {
        let (client, provider_id, err) =
            complete_with_stub("stub-unauthorized", StubErrorKind::Server401).await;

        assert!(matches!(
            err,
            ApiError::ServerError {
                status_code: 401,
                ..
            }
        ));
        assert!(!is_cached(&client, &provider_id));
    }

    #[tokio::test]
    async fn invalidates_cached_provider_on_forbidden_status_code() {
        let (client, provider_id, err) =
            complete_with_stub("stub-forbidden", StubErrorKind::Server403).await;

        assert!(matches!(
            err,
            ApiError::ServerError {
                status_code: 403,
                ..
            }
        ));
        assert!(!is_cached(&client, &provider_id));
    }

    #[tokio::test]
    async fn keeps_cached_provider_on_non_auth_server_error() {
        let (client, provider_id, err) =
            complete_with_stub("stub-server-error", StubErrorKind::Server500).await;

        assert!(matches!(
            err,
            ApiError::ServerError {
                status_code: 500,
                ..
            }
        ));
        assert!(is_cached(&client, &provider_id));
    }

    #[tokio::test]
    async fn complete_with_retry_returns_last_error_when_attempts_exhausted() {
        let client = test_client();
        let provider_id = ProviderId("stub-retry".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Server500);

        let err = client
            .complete_with_retry(&model_id, &[], &None, &None, CancellationToken::new(), 2)
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            ApiError::ServerError {
                status_code: 500,
                ..
            }
        ));
        assert!(is_cached(&client, &provider_id));
    }

    #[tokio::test]
    async fn complete_with_retry_does_not_retry_auth_failures() {
        let client = test_client();
        let provider_id = ProviderId("stub-retry-auth".to_string());
        let model_id = ModelId::new(provider_id.clone(), "stub-model");
        insert_stub_provider(&client, provider_id.clone(), StubErrorKind::Auth);

        let err = client
            .complete_with_retry(&model_id, &[], &None, &None, CancellationToken::new(), 3)
            .await
            .unwrap_err();

        assert!(matches!(err, ApiError::AuthenticationFailed { .. }));
        assert!(!is_cached(&client, &provider_id));
    }
}