1pub mod claude;
2pub mod error;
3pub mod factory;
4pub mod gemini;
5pub mod openai;
6pub mod provider;
7pub mod util;
8pub mod xai;
9
10use crate::auth::ProviderRegistry;
11use crate::auth::storage::{Credential, CredentialType};
12use crate::config::model::ModelId;
13use crate::config::provider::ProviderId;
14use crate::config::{ApiAuth, LlmConfigProvider};
15use crate::error::Result;
16use crate::model_registry::ModelRegistry;
17pub use error::ApiError;
18pub use factory::{create_provider, create_provider_with_storage};
19pub use provider::{CompletionResponse, Provider};
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::sync::RwLock;
23use steer_tools::ToolSchema;
24use tokio_util::sync::CancellationToken;
25use tracing::debug;
26use tracing::warn;
27
28use crate::app::conversation::Message;
29
/// Facade over the configured LLM providers.
///
/// Cloning is cheap: every clone shares the same provider cache through the
/// inner `Arc`, so a provider built through one clone is visible to all others.
#[derive(Clone)]
pub struct Client {
    /// Cache of provider instances keyed by provider id; filled on first use
    /// by `get_or_create_provider`.
    provider_map: Arc<RwLock<HashMap<ProviderId, Arc<dyn Provider>>>>,
    /// Source of per-provider authentication and of the credential storage handle.
    config_provider: LlmConfigProvider,
    /// Looked up by provider id to obtain the configuration handed to the factory.
    provider_registry: Arc<ProviderRegistry>,
    /// Looked up by model id to derive the effective call parameters.
    model_registry: Arc<ModelRegistry>,
}
37
38impl Client {
39 pub fn new_with_deps(
42 config_provider: LlmConfigProvider,
43 provider_registry: Arc<ProviderRegistry>,
44 model_registry: Arc<ModelRegistry>,
45 ) -> Self {
46 Self {
47 provider_map: Arc::new(RwLock::new(HashMap::new())),
48 config_provider,
49 provider_registry,
50 model_registry,
51 }
52 }
53
54 async fn get_or_create_provider(&self, provider_id: ProviderId) -> Result<Arc<dyn Provider>> {
55 {
57 let map = self.provider_map.read().unwrap();
58 if let Some(provider) = map.get(&provider_id) {
59 return Ok(provider.clone());
60 }
61 }
62
63 let provider_config = self.provider_registry.get(&provider_id).ok_or_else(|| {
65 crate::error::Error::Api(ApiError::Configuration(format!(
66 "No provider configuration found for {provider_id:?}"
67 )))
68 })?;
69
70 let credential = match self
72 .config_provider
73 .get_auth_for_provider(&provider_id)
74 .await?
75 {
76 Some(ApiAuth::OAuth) => {
77 self.config_provider
79 .auth_storage()
80 .get_credential(&provider_id.storage_key(), CredentialType::OAuth2)
81 .await
82 .map_err(|e| {
83 crate::error::Error::Api(ApiError::Configuration(format!(
84 "Failed to get OAuth credential: {e}"
85 )))
86 })?
87 .ok_or_else(|| {
88 crate::error::Error::Api(ApiError::Configuration(
89 "OAuth credential not found in storage".to_string(),
90 ))
91 })?
92 }
93 Some(ApiAuth::Key(key)) => Credential::ApiKey { value: key },
94 None => {
95 return Err(crate::error::Error::Api(ApiError::Configuration(format!(
96 "No authentication configured for {provider_id:?}"
97 ))));
98 }
99 };
100
101 let mut map = self.provider_map.write().unwrap();
103
104 if let Some(provider) = map.get(&provider_id) {
106 return Ok(provider.clone());
107 }
108
109 let provider_instance = if matches!(&credential, Credential::OAuth2(_)) {
111 factory::create_provider_with_storage(
112 provider_config,
113 &credential,
114 self.config_provider.auth_storage().clone(),
115 )
116 .map_err(crate::error::Error::Api)?
117 } else {
118 factory::create_provider(provider_config, &credential)
119 .map_err(crate::error::Error::Api)?
120 };
121
122 map.insert(provider_id, provider_instance.clone());
123 Ok(provider_instance)
124 }
125
126 pub async fn complete(
128 &self,
129 model_id: &ModelId,
130 messages: Vec<Message>,
131 system: Option<String>,
132 tools: Option<Vec<ToolSchema>>,
133 call_options: Option<crate::config::model::ModelParameters>,
134 token: CancellationToken,
135 ) -> std::result::Result<CompletionResponse, ApiError> {
136 let provider_id = model_id.0.clone();
138 let provider = self
139 .get_or_create_provider(provider_id)
140 .await
141 .map_err(ApiError::from)?;
142
143 if token.is_cancelled() {
144 return Err(ApiError::Cancelled {
145 provider: provider.name().to_string(),
146 });
147 }
148
149 let model_config = self.model_registry.get(model_id);
151 let effective_params = match (model_config, &call_options) {
152 (Some(config), Some(opts)) => config.effective_parameters(Some(opts)),
153 (Some(config), None) => config.effective_parameters(None),
154 (None, Some(opts)) => Some(*opts),
155 (None, None) => None,
156 };
157
158 debug!(
159 target: "api::complete",
160 ?model_id,
161 ?call_options,
162 ?effective_params,
163 "Final parameters for model"
164 );
165
166 provider
167 .complete(model_id, messages, system, tools, effective_params, token)
168 .await
169 }
170
171 pub async fn complete_with_retry(
172 &self,
173 model_id: &ModelId,
174 messages: &[Message],
175 system_prompt: &Option<String>,
176 tools: &Option<Vec<ToolSchema>>,
177 token: CancellationToken,
178 max_attempts: usize,
179 ) -> std::result::Result<CompletionResponse, ApiError> {
180 let mut attempts = 0;
181
182 let provider_id = model_id.0.clone();
184 let provider = self
185 .get_or_create_provider(provider_id.clone())
186 .await
187 .map_err(ApiError::from)?;
188
189 let model_config = self.model_registry.get(model_id);
190 debug!(
191 target: "api::complete_with_retry",
192 ?model_id,
193 ?model_config,
194 "Model config"
195 );
196 let effective_params = model_config.and_then(|cfg| cfg.effective_parameters(None));
197
198 debug!(
199 target: "api::complete_with_retry",
200 ?model_id,
201 ?effective_params,
202 "system: {:?}",
203 system_prompt
204 );
205 debug!(
206 target: "api::complete_with_retry",
207 ?model_id,
208 "messages: {:?}",
209 messages
210 );
211
212 loop {
213 if token.is_cancelled() {
214 return Err(ApiError::Cancelled {
215 provider: provider.name().to_string(),
216 });
217 }
218
219 match provider
220 .complete(
221 model_id,
222 messages.to_vec(),
223 system_prompt.clone(),
224 tools.clone(),
225 effective_params,
226 token.clone(),
227 )
228 .await
229 {
230 Ok(response) => {
231 return Ok(response);
232 }
233 Err(error) => {
234 attempts += 1;
235 warn!(
236 "API completion attempt {}/{} failed for model {:?}: {:?}",
237 attempts, max_attempts, model_id, error
238 );
239
240 if attempts >= max_attempts {
241 return Err(error);
242 }
243
244 match error {
245 ApiError::RateLimited { provider, details } => {
246 let sleep_duration =
247 std::time::Duration::from_secs(1 << (attempts - 1));
248 warn!(
249 "Rate limited by API: {} {} (retrying in {} seconds)",
250 provider,
251 details,
252 sleep_duration.as_secs()
253 );
254 tokio::time::sleep(sleep_duration).await;
255 }
256 ApiError::NoChoices { provider } => {
257 warn!("No choices returned from API: {}", provider);
258 }
259 ApiError::ServerError {
260 provider,
261 status_code,
262 details,
263 } => {
264 warn!(
265 "Server error for API: {} {} {}",
266 provider, status_code, details
267 );
268 }
269 _ => {
270 return Err(error);
272 }
273 }
274 }
275 }
276 }
277 }
278}