steer_core/api/
factory.rs1use crate::api::error::ApiError;
2use crate::api::provider::Provider;
3use crate::api::{
4 claude::AnthropicClient, gemini::GeminiClient, openai::OpenAIClient, xai::XAIClient,
5};
6use crate::auth::{AuthDirective, Credential};
7use crate::config::provider::{ApiFormat, ProviderConfig};
8use std::sync::Arc;
9
10pub fn create_provider(
16 provider_cfg: &ProviderConfig,
17 credential: &Credential,
18) -> Result<Arc<dyn Provider>, ApiError> {
19 match credential {
20 Credential::ApiKey { value } => match &provider_cfg.api_format {
21 ApiFormat::OpenaiResponses => {
22 let client = if let Some(base_url) = &provider_cfg.base_url {
23 OpenAIClient::with_base_url_mode(
24 value.clone(),
25 Some(base_url.to_string()),
26 crate::api::openai::OpenAIMode::Responses,
27 )?
28 } else {
29 OpenAIClient::with_mode(
30 value.clone(),
31 crate::api::openai::OpenAIMode::Responses,
32 )?
33 };
34 Ok(Arc::new(client))
35 }
36 ApiFormat::OpenaiChat => {
37 let client = if let Some(base_url) = &provider_cfg.base_url {
38 OpenAIClient::with_base_url_mode(
39 value.clone(),
40 Some(base_url.to_string()),
41 crate::api::openai::OpenAIMode::Chat,
42 )?
43 } else {
44 OpenAIClient::with_mode(value.clone(), crate::api::openai::OpenAIMode::Chat)?
45 };
46 Ok(Arc::new(client))
47 }
48 ApiFormat::Anthropic => {
49 if provider_cfg.base_url.is_some() {
51 return Err(ApiError::Configuration(
52 "Base URL override not yet supported for Anthropic API format".to_string(),
53 ));
54 }
55 Ok(Arc::new(AnthropicClient::with_api_key(value)?))
56 }
57 ApiFormat::Google => {
58 if provider_cfg.base_url.is_some() {
60 return Err(ApiError::Configuration(
61 "Base URL override not yet supported for Gemini API format".to_string(),
62 ));
63 }
64 Ok(Arc::new(GeminiClient::new(value)))
65 }
66 ApiFormat::Xai => {
67 let client = if let Some(base_url) = &provider_cfg.base_url {
68 XAIClient::with_base_url(value.clone(), Some(base_url.to_string()))?
69 } else {
70 XAIClient::new(value.clone())?
71 };
72 Ok(Arc::new(client))
73 }
74 },
75 Credential::OAuth2(_) => Err(ApiError::Configuration(
76 "OAuth requires an AuthDirective, not a raw credential".to_string(),
77 )),
78 }
79}
80
81pub fn create_provider_with_directive(
83 provider_cfg: &ProviderConfig,
84 directive: &AuthDirective,
85) -> Result<Arc<dyn Provider>, ApiError> {
86 match directive {
87 AuthDirective::OpenAiResponses(openai) => {
88 if provider_cfg.api_format != ApiFormat::OpenaiResponses {
89 return Err(ApiError::Configuration(
90 "OpenAI OAuth directives require responses API format".to_string(),
91 ));
92 }
93 let base_url = provider_cfg.base_url.as_ref().map(|url| url.to_string());
94 Ok(Arc::new(OpenAIClient::with_directive(
95 openai.clone(),
96 base_url,
97 )?))
98 }
99 AuthDirective::Anthropic(anthropic) => {
100 if provider_cfg.api_format != ApiFormat::Anthropic {
101 return Err(ApiError::Configuration(
102 "Anthropic OAuth directives require Anthropic API format".to_string(),
103 ));
104 }
105 if provider_cfg.base_url.is_some() {
106 return Err(ApiError::Configuration(
107 "Base URL override not yet supported for Anthropic API format".to_string(),
108 ));
109 }
110 Ok(Arc::new(AnthropicClient::with_directive(
111 anthropic.clone(),
112 )?))
113 }
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider::{self, AuthScheme, ProviderId};

    /// A throwaway API-key credential for constructing clients in tests.
    fn api_key_credential() -> Credential {
        Credential::ApiKey {
            value: "test-key".to_string(),
        }
    }

    /// Returns the error message from a `create_provider*` call that must fail.
    ///
    /// Matched by hand rather than via `Result::unwrap_err`, which would
    /// require `Arc<dyn Provider>: Debug`.
    fn expect_err_message(result: Result<Arc<dyn Provider>, ApiError>) -> String {
        match result {
            Err(e) => e.to_string(),
            Ok(_) => panic!("Expected error"),
        }
    }

    #[test]
    fn test_create_openai_provider() {
        let config = ProviderConfig {
            id: provider::openai(),
            name: "OpenAI".to_string(),
            api_format: ApiFormat::OpenaiResponses,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: None,
        };

        let provider = create_provider(&config, &api_key_credential()).unwrap();
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_create_openai_chat_provider() {
        let config = ProviderConfig {
            id: provider::openai(),
            name: "OpenAI".to_string(),
            api_format: ApiFormat::OpenaiChat,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: None,
        };

        let provider = create_provider(&config, &api_key_credential()).unwrap();
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_create_custom_openai_provider() {
        let config = ProviderConfig {
            id: ProviderId("my-provider".to_string()),
            name: "My Provider".to_string(),
            api_format: ApiFormat::OpenaiResponses,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: Some("https://my-api.example.com".parse().unwrap()),
        };

        let provider = create_provider(&config, &api_key_credential()).unwrap();
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_anthropic_api_key_rejects_base_url() {
        let config = ProviderConfig {
            id: provider::anthropic(),
            name: "Anthropic".to_string(),
            api_format: ApiFormat::Anthropic,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: Some("https://proxy.example.com".parse().unwrap()),
        };

        let err_msg = expect_err_message(create_provider(&config, &api_key_credential()));
        assert!(err_msg.contains("Base URL override not yet supported for Anthropic API format"));
    }

    #[test]
    fn test_google_api_key_rejects_base_url() {
        let config = ProviderConfig {
            id: ProviderId("gemini".to_string()),
            name: "Gemini".to_string(),
            api_format: ApiFormat::Google,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: Some("https://proxy.example.com".parse().unwrap()),
        };

        let err_msg = expect_err_message(create_provider(&config, &api_key_credential()));
        assert!(err_msg.contains("Base URL override not yet supported for Gemini API format"));
    }

    #[test]
    fn test_oauth_requires_directive() {
        let config = ProviderConfig {
            id: provider::anthropic(),
            name: "Anthropic".to_string(),
            api_format: ApiFormat::Anthropic,
            auth_schemes: vec![AuthScheme::Oauth2],
            base_url: None,
        };

        let credential = Credential::OAuth2(crate::auth::storage::OAuth2Token {
            access_token: "test-token".to_string(),
            refresh_token: "test-refresh".to_string(),
            expires_at: std::time::SystemTime::now(),
            id_token: None,
        });

        let err_msg = expect_err_message(create_provider(&config, &credential));
        assert!(err_msg.contains("OAuth requires an AuthDirective"));
    }
}