spec_ai_core/agent/
factory.rs1use crate::agent::model::{ModelProvider, ProviderKind};
6#[cfg(feature = "anthropic")]
7use crate::agent::providers::AnthropicProvider;
8#[cfg(feature = "lmstudio")]
9use crate::agent::providers::LMStudioProvider;
10#[cfg(feature = "mlx")]
11use crate::agent::providers::MLXProvider;
12use crate::agent::providers::MockProvider;
13#[cfg(feature = "ollama")]
14use crate::agent::providers::OllamaProvider;
15#[cfg(feature = "openai")]
16use crate::agent::providers::OpenAIProvider;
17use crate::config::ModelConfig;
18use anyhow::{anyhow, Context, Result};
19use std::sync::Arc;
20
21pub fn create_provider(config: &ModelConfig) -> Result<Arc<dyn ModelProvider>> {
23 let provider_kind = ProviderKind::from_str(&config.provider)
24 .ok_or_else(|| anyhow!("Unknown provider: {}", config.provider))?;
25
26 match provider_kind {
27 ProviderKind::Mock => {
28 let provider = if let Some(model_name) = &config.model_name {
30 MockProvider::default().with_model_name(model_name.clone())
31 } else {
32 MockProvider::default()
33 };
34 Ok(Arc::new(provider))
35 }
36
37 #[cfg(feature = "openai")]
38 ProviderKind::OpenAI => {
39 let api_key = if let Some(source) = &config.api_key_source {
41 resolve_api_key(source)?
42 } else {
43 load_api_key_from_env("OPENAI_API_KEY")?
45 };
46
47 let mut provider = OpenAIProvider::with_api_key(api_key);
49
50 if let Some(model_name) = &config.model_name {
52 provider = provider.with_model(model_name.clone());
53 }
54
55 Ok(Arc::new(provider))
56 }
57
58 #[cfg(feature = "anthropic")]
59 ProviderKind::Anthropic => {
60 let api_key = if let Some(source) = &config.api_key_source {
62 resolve_api_key(source)?
63 } else {
64 load_api_key_from_env("ANTHROPIC_API_KEY")?
66 };
67
68 let mut provider = AnthropicProvider::with_api_key(api_key);
70
71 if let Some(model_name) = &config.model_name {
73 provider = provider.with_model(model_name.clone());
74 }
75
76 Ok(Arc::new(provider))
77 }
78
79 #[cfg(feature = "ollama")]
80 ProviderKind::Ollama => {
81 let mut provider = if let Ok(base_url) = std::env::var("OLLAMA_BASE_URL") {
83 OllamaProvider::with_base_url(base_url)
84 } else {
85 OllamaProvider::new()
86 };
87
88 if let Some(model_name) = &config.model_name {
90 provider = provider.with_model(model_name.clone());
91 }
92
93 Ok(Arc::new(provider))
94 }
95
96 #[cfg(feature = "mlx")]
97 ProviderKind::MLX => {
98 let model_name = config
100 .model_name
101 .as_ref()
102 .ok_or_else(|| anyhow!("MLX provider requires a model_name to be specified"))?;
103
104 let provider = if let Ok(endpoint) = std::env::var("MLX_ENDPOINT") {
107 MLXProvider::with_endpoint(endpoint, model_name)
108 } else {
109 MLXProvider::new(model_name)
110 };
111
112 Ok(Arc::new(provider))
113 }
114
115 #[cfg(feature = "lmstudio")]
116 ProviderKind::LMStudio => {
117 let model_name = config.model_name.as_ref().ok_or_else(|| {
118 anyhow!("LM Studio provider requires a model_name to be specified")
119 })?;
120
121 let provider = if let Ok(endpoint) = std::env::var("LMSTUDIO_ENDPOINT") {
122 LMStudioProvider::with_endpoint(endpoint, model_name)
123 } else {
124 LMStudioProvider::new(model_name)
125 };
126
127 Ok(Arc::new(provider))
128 }
129 }
130}
131
132pub fn resolve_api_key(source: &str) -> Result<String> {
139 if let Some(env_var) = source.strip_prefix("env:") {
140 load_api_key_from_env(env_var)
141 } else if let Some(path) = source.strip_prefix("file:") {
142 load_api_key_from_file(path)
143 } else {
144 Ok(source.to_string())
146 }
147}
148
149pub fn load_api_key_from_env(env_var: &str) -> Result<String> {
151 std::env::var(env_var).context(format!("Environment variable {} not set", env_var))
152}
153
154pub fn load_api_key_from_file(path: &str) -> Result<String> {
156 let expanded_path = if let Some(stripped) = path.strip_prefix("~/") {
158 if let Some(home) = std::env::var_os("HOME") {
159 std::path::PathBuf::from(home).join(stripped)
160 } else {
161 std::path::PathBuf::from(path)
162 }
163 } else {
164 std::path::PathBuf::from(path)
165 };
166
167 std::fs::read_to_string(&expanded_path)
168 .context(format!("Failed to read API key from file: {}", path))
169 .map(|s| s.trim().to_string())
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ModelConfig;

    /// Sets an environment variable and removes it again when dropped.
    ///
    /// Cleanup therefore happens even if an assertion in the test panics.
    struct EnvVarGuard {
        name: &'static str,
    }

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            // SAFETY: `set_var` is unsafe because another thread may read the
            // environment at the same time. Each test uses a variable name that no
            // other test touches, but the harness runs tests on parallel threads.
            // This is an accepted test-only risk, not a proof of soundness; run with
            // `--test-threads=1` to rule it out.
            unsafe {
                std::env::set_var(name, value);
            }
            Self { name }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same caveat as in `EnvVarGuard::set`.
            unsafe {
                std::env::remove_var(self.name);
            }
        }
    }

    #[test]
    fn test_create_mock_provider() {
        let config = ModelConfig {
            provider: "mock".to_string(),
            model_name: Some("test-model".to_string()),
            embeddings_model: None,
            api_key_source: None,
            temperature: 0.8,
        };

        let provider = create_provider(&config).unwrap();
        assert_eq!(provider.kind(), ProviderKind::Mock);
    }

    #[test]
    fn test_create_mock_provider_without_model_name() {
        let config = ModelConfig {
            provider: "mock".to_string(),
            model_name: None,
            embeddings_model: None,
            api_key_source: None,
            temperature: 0.7,
        };

        let provider = create_provider(&config).unwrap();
        assert_eq!(provider.kind(), ProviderKind::Mock);
    }

    #[test]
    fn test_create_unknown_provider() {
        let config = ModelConfig {
            provider: "unknown-provider".to_string(),
            model_name: None,
            embeddings_model: None,
            api_key_source: None,
            temperature: 0.7,
        };

        let result = create_provider(&config);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_api_key_from_env() {
        let _guard = EnvVarGuard::set("TEST_API_KEY", "env-key-value");
        let key = load_api_key_from_env("TEST_API_KEY").unwrap();
        assert_eq!(key, "env-key-value");
    }

    #[test]
    fn test_load_api_key_env_var_missing() {
        // Specific enough that it will not exist in a developer or CI environment.
        let result = load_api_key_from_env("SPEC_AI_FACTORY_TEST_UNSET_VAR");
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_api_key_direct() {
        let key = resolve_api_key("sk-direct-api-key").unwrap();
        assert_eq!(key, "sk-direct-api-key");
    }

    #[test]
    fn test_resolve_api_key_from_env() {
        let _guard = EnvVarGuard::set("TEST_RESOLVE_KEY", "env-resolved-value");
        let key = resolve_api_key("env:TEST_RESOLVE_KEY").unwrap();
        assert_eq!(key, "env-resolved-value");
    }

    #[test]
    fn test_resolve_api_key_from_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("test_api_key.txt");
        std::fs::write(&file_path, "file-api-key-value\n").unwrap();

        let key = resolve_api_key(&format!("file:{}", file_path.display())).unwrap();
        assert_eq!(key, "file-api-key-value");
    }

    #[test]
    fn test_resolve_api_key_missing_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let missing_path = temp_dir.path().join("does_not_exist.txt");

        let result = resolve_api_key(&format!("file:{}", missing_path.display()));
        assert!(result.is_err());
    }

    #[test]
    fn test_load_api_key_from_file_with_whitespace() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("test_key_whitespace.txt");
        std::fs::write(&file_path, "  api-key-with-spaces  \n").unwrap();

        let key = load_api_key_from_file(file_path.to_str().unwrap()).unwrap();
        assert_eq!(key, "api-key-with-spaces");
    }
}