steer_core/api/
factory.rs1use crate::api::error::ApiError;
2use crate::api::provider::Provider;
3use crate::api::{
4 claude::AnthropicClient, gemini::GeminiClient, openai::OpenAIClient, xai::XAIClient,
5};
6use crate::auth::storage::Credential;
7use crate::config::provider::{ApiFormat, ProviderConfig};
8use std::sync::Arc;
9
10pub fn create_provider(
16 provider_cfg: &ProviderConfig,
17 credential: &Credential,
18) -> Result<Arc<dyn Provider>, ApiError> {
19 match credential {
20 Credential::ApiKey { value } => match &provider_cfg.api_format {
21 ApiFormat::OpenaiResponses => {
22 let client = if let Some(base_url) = &provider_cfg.base_url {
23 OpenAIClient::with_base_url_mode(
24 value.clone(),
25 Some(base_url.to_string()),
26 crate::api::openai::OpenAIMode::Responses,
27 )
28 } else {
29 OpenAIClient::with_mode(
30 value.clone(),
31 crate::api::openai::OpenAIMode::Responses,
32 )
33 };
34 Ok(Arc::new(client))
35 }
36 ApiFormat::OpenaiChat => {
37 let client = if let Some(base_url) = &provider_cfg.base_url {
38 OpenAIClient::with_base_url_mode(
39 value.clone(),
40 Some(base_url.to_string()),
41 crate::api::openai::OpenAIMode::Chat,
42 )
43 } else {
44 OpenAIClient::with_mode(value.clone(), crate::api::openai::OpenAIMode::Chat)
45 };
46 Ok(Arc::new(client))
47 }
48 ApiFormat::Anthropic => {
49 if provider_cfg.base_url.is_some() {
51 return Err(ApiError::Configuration(
52 "Base URL override not yet supported for Anthropic API format".to_string(),
53 ));
54 }
55 Ok(Arc::new(AnthropicClient::with_api_key(value)))
56 }
57 ApiFormat::Google => {
58 if provider_cfg.base_url.is_some() {
60 return Err(ApiError::Configuration(
61 "Base URL override not yet supported for Gemini API format".to_string(),
62 ));
63 }
64 Ok(Arc::new(GeminiClient::new(value)))
65 }
66 ApiFormat::Xai => {
67 let client = if let Some(base_url) = &provider_cfg.base_url {
68 XAIClient::with_base_url(value.clone(), Some(base_url.to_string()))
69 } else {
70 XAIClient::new(value.clone())
71 };
72 Ok(Arc::new(client))
73 }
74 },
75 Credential::OAuth2(_) => {
76 match &provider_cfg.api_format {
78 ApiFormat::Anthropic => {
79 Err(ApiError::Configuration(
82 "OAuth support requires auth storage context".to_string(),
83 ))
84 }
85 _ => Err(ApiError::Configuration(format!(
86 "OAuth is not supported for {:?} API format",
87 provider_cfg.api_format
88 ))),
89 }
90 }
91 }
92}
93
94pub fn create_provider_with_storage(
99 provider_cfg: &ProviderConfig,
100 credential: &Credential,
101 storage: Arc<dyn crate::auth::AuthStorage>,
102) -> Result<Arc<dyn Provider>, ApiError> {
103 match credential {
104 Credential::ApiKey { .. } => create_provider(provider_cfg, credential),
105 Credential::OAuth2(_) => match &provider_cfg.api_format {
106 ApiFormat::Anthropic => {
107 if provider_cfg.base_url.is_some() {
108 return Err(ApiError::Configuration(
109 "Base URL override not supported with OAuth authentication".to_string(),
110 ));
111 }
112 Ok(Arc::new(AnthropicClient::with_oauth(storage)))
113 }
114 _ => Err(ApiError::Configuration(format!(
115 "OAuth is not supported for {:?} API format",
116 provider_cfg.api_format
117 ))),
118 },
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider::{self, AuthScheme, ProviderId};

    /// Builds a `ProviderConfig`; `base_url`, if given, is parsed into the
    /// config's URL type.
    fn provider_config(
        id: ProviderId,
        name: &str,
        api_format: ApiFormat,
        auth_schemes: Vec<AuthScheme>,
        base_url: Option<&str>,
    ) -> ProviderConfig {
        ProviderConfig {
            id,
            name: name.to_string(),
            api_format,
            auth_schemes,
            base_url: base_url.map(|url| url.parse().unwrap()),
        }
    }

    /// An API-key credential with a dummy key.
    fn api_key_credential() -> Credential {
        Credential::ApiKey {
            value: "test-key".to_string(),
        }
    }

    /// An OAuth2 credential with dummy tokens.
    fn oauth_credential() -> Credential {
        Credential::OAuth2(crate::auth::storage::OAuth2Token {
            access_token: "test-token".to_string(),
            refresh_token: "test-refresh".to_string(),
            expires_at: std::time::SystemTime::now(),
        })
    }

    /// Returns the error message of a factory call, panicking if it succeeded.
    fn err_message(result: Result<Arc<dyn Provider>, ApiError>) -> String {
        match result {
            Err(e) => e.to_string(),
            Ok(_) => panic!("Expected error"),
        }
    }

    #[test]
    fn test_create_openai_provider() {
        let config = provider_config(
            provider::openai(),
            "OpenAI",
            ApiFormat::OpenaiResponses,
            vec![AuthScheme::ApiKey],
            None,
        );

        let provider = create_provider(&config, &api_key_credential()).unwrap();
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_create_custom_openai_provider() {
        let config = provider_config(
            ProviderId("my-provider".to_string()),
            "My Provider",
            ApiFormat::OpenaiResponses,
            vec![AuthScheme::ApiKey],
            Some("https://my-api.example.com"),
        );

        let provider = create_provider(&config, &api_key_credential()).unwrap();
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_create_openai_chat_provider_with_base_url() {
        let config = provider_config(
            ProviderId("my-chat-provider".to_string()),
            "My Chat Provider",
            ApiFormat::OpenaiChat,
            vec![AuthScheme::ApiKey],
            Some("https://my-api.example.com"),
        );

        assert!(create_provider(&config, &api_key_credential()).is_ok());
    }

    #[test]
    fn test_create_xai_provider_with_base_url() {
        let config = provider_config(
            ProviderId("my-xai".to_string()),
            "My xAI",
            ApiFormat::Xai,
            vec![AuthScheme::ApiKey],
            Some("https://my-api.example.com"),
        );

        assert!(create_provider(&config, &api_key_credential()).is_ok());
    }

    #[test]
    fn test_base_url_rejected_for_anthropic_api_key() {
        let config = provider_config(
            provider::anthropic(),
            "Anthropic",
            ApiFormat::Anthropic,
            vec![AuthScheme::ApiKey],
            Some("https://my-api.example.com"),
        );

        let err_msg = err_message(create_provider(&config, &api_key_credential()));
        assert!(err_msg.contains("Base URL override not yet supported for Anthropic"));
    }

    #[test]
    fn test_base_url_rejected_for_google_api_key() {
        let config = provider_config(
            ProviderId("my-gemini".to_string()),
            "My Gemini",
            ApiFormat::Google,
            vec![AuthScheme::ApiKey],
            Some("https://my-api.example.com"),
        );

        let err_msg = err_message(create_provider(&config, &api_key_credential()));
        assert!(err_msg.contains("Base URL override not yet supported for Gemini"));
    }

    #[test]
    fn test_oauth_requires_storage() {
        let config = provider_config(
            provider::anthropic(),
            "Anthropic",
            ApiFormat::Anthropic,
            vec![AuthScheme::Oauth2],
            None,
        );

        let err_msg = err_message(create_provider(&config, &oauth_credential()));
        assert!(err_msg.contains("OAuth support requires auth storage"));
    }

    #[test]
    fn test_oauth_unsupported_for_other_formats() {
        for api_format in [
            ApiFormat::OpenaiResponses,
            ApiFormat::OpenaiChat,
            ApiFormat::Google,
            ApiFormat::Xai,
        ] {
            let config = provider_config(
                ProviderId("custom".to_string()),
                "Custom",
                api_format,
                vec![AuthScheme::Oauth2],
                None,
            );

            let err_msg = err_message(create_provider(&config, &oauth_credential()));
            assert!(
                err_msg.contains("OAuth is not supported"),
                "unexpected error: {}",
                err_msg
            );
        }
    }
}