spec_ai_core/agent/
factory.rs

1//! Provider Factory
2//!
3//! Creates model provider instances based on configuration.
4
5use crate::agent::model::{ModelProvider, ProviderKind};
6#[cfg(feature = "anthropic")]
7use crate::agent::providers::AnthropicProvider;
8#[cfg(feature = "lmstudio")]
9use crate::agent::providers::LMStudioProvider;
10#[cfg(feature = "mlx")]
11use crate::agent::providers::MLXProvider;
12use crate::agent::providers::MockProvider;
13#[cfg(feature = "ollama")]
14use crate::agent::providers::OllamaProvider;
15#[cfg(feature = "openai")]
16use crate::agent::providers::OpenAIProvider;
17use crate::config::ModelConfig;
18use anyhow::{anyhow, Context, Result};
19use std::sync::Arc;
20
21/// Create a model provider from configuration
22pub fn create_provider(config: &ModelConfig) -> Result<Arc<dyn ModelProvider>> {
23    let provider_kind = ProviderKind::from_str(&config.provider)
24        .ok_or_else(|| anyhow!("Unknown provider: {}", config.provider))?;
25
26    match provider_kind {
27        ProviderKind::Mock => {
28            // Create mock provider with optional custom responses
29            let provider = if let Some(model_name) = &config.model_name {
30                MockProvider::default().with_model_name(model_name.clone())
31            } else {
32                MockProvider::default()
33            };
34            Ok(Arc::new(provider))
35        }
36
37        #[cfg(feature = "openai")]
38        ProviderKind::OpenAI => {
39            // Get API key from config
40            let api_key = if let Some(source) = &config.api_key_source {
41                resolve_api_key(source)?
42            } else {
43                // Default to OPENAI_API_KEY environment variable
44                load_api_key_from_env("OPENAI_API_KEY")?
45            };
46
47            // Create OpenAI provider
48            let mut provider = OpenAIProvider::with_api_key(api_key);
49
50            // Set model if specified in config
51            if let Some(model_name) = &config.model_name {
52                provider = provider.with_model(model_name.clone());
53            }
54
55            Ok(Arc::new(provider))
56        }
57
58        #[cfg(feature = "anthropic")]
59        ProviderKind::Anthropic => {
60            // Get API key from config
61            let api_key = if let Some(source) = &config.api_key_source {
62                resolve_api_key(source)?
63            } else {
64                // Default to ANTHROPIC_API_KEY environment variable
65                load_api_key_from_env("ANTHROPIC_API_KEY")?
66            };
67
68            // Create Anthropic provider
69            let mut provider = AnthropicProvider::with_api_key(api_key);
70
71            // Set model if specified in config
72            if let Some(model_name) = &config.model_name {
73                provider = provider.with_model(model_name.clone());
74            }
75
76            Ok(Arc::new(provider))
77        }
78
79        #[cfg(feature = "ollama")]
80        ProviderKind::Ollama => {
81            // Create Ollama provider with optional custom base URL
82            let mut provider = if let Ok(base_url) = std::env::var("OLLAMA_BASE_URL") {
83                OllamaProvider::with_base_url(base_url)
84            } else {
85                OllamaProvider::new()
86            };
87
88            // Set model if specified in config
89            if let Some(model_name) = &config.model_name {
90                provider = provider.with_model(model_name.clone());
91            }
92
93            Ok(Arc::new(provider))
94        }
95
96        #[cfg(feature = "mlx")]
97        ProviderKind::MLX => {
98            // MLX requires a model name
99            let model_name = config
100                .model_name
101                .as_ref()
102                .ok_or_else(|| anyhow!("MLX provider requires a model_name to be specified"))?;
103
104            // Create MLX provider with default endpoint (localhost:10240)
105            // Users can customize this by setting MLX_ENDPOINT environment variable
106            let provider = if let Ok(endpoint) = std::env::var("MLX_ENDPOINT") {
107                MLXProvider::with_endpoint(endpoint, model_name)
108            } else {
109                MLXProvider::new(model_name)
110            };
111
112            Ok(Arc::new(provider))
113        }
114
115        #[cfg(feature = "lmstudio")]
116        ProviderKind::LMStudio => {
117            let model_name = config.model_name.as_ref().ok_or_else(|| {
118                anyhow!("LM Studio provider requires a model_name to be specified")
119            })?;
120
121            let provider = if let Ok(endpoint) = std::env::var("LMSTUDIO_ENDPOINT") {
122                LMStudioProvider::with_endpoint(endpoint, model_name)
123            } else {
124                LMStudioProvider::new(model_name)
125            };
126
127            Ok(Arc::new(provider))
128        }
129    }
130}
131
132/// Resolve API key from a source string
133///
134/// Supports the following formats:
135/// - `env:VAR_NAME` - Load from environment variable
136/// - `file:PATH` - Load from file
137/// - Any other string - Use as-is (direct API key)
138pub fn resolve_api_key(source: &str) -> Result<String> {
139    if let Some(env_var) = source.strip_prefix("env:") {
140        load_api_key_from_env(env_var)
141    } else if let Some(path) = source.strip_prefix("file:") {
142        load_api_key_from_file(path)
143    } else {
144        // Treat as direct API key
145        Ok(source.to_string())
146    }
147}
148
149/// Load API key from environment variable
150pub fn load_api_key_from_env(env_var: &str) -> Result<String> {
151    std::env::var(env_var).context(format!("Environment variable {} not set", env_var))
152}
153
154/// Load API key from file
155pub fn load_api_key_from_file(path: &str) -> Result<String> {
156    // Handle tilde expansion manually
157    let expanded_path = if let Some(stripped) = path.strip_prefix("~/") {
158        if let Some(home) = std::env::var_os("HOME") {
159            std::path::PathBuf::from(home).join(stripped)
160        } else {
161            std::path::PathBuf::from(path)
162        }
163    } else {
164        std::path::PathBuf::from(path)
165    };
166
167    std::fs::read_to_string(&expanded_path)
168        .context(format!("Failed to read API key from file: {}", path))
169        .map(|s| s.trim().to_string())
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ModelConfig;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises the tests in this module that read or mutate the process environment.
    ///
    /// `std::env::set_var` / `remove_var` are `unsafe` because changing the environment
    /// while another thread reads it is a data race, and the default test harness runs
    /// tests on parallel threads. This lock only orders the tests in this module.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire [`ENV_LOCK`], recovering from poisoning left by a previously failed test
    /// (the guarded value is `()`, so there is no state to be inconsistent).
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_create_mock_provider() {
        let config = ModelConfig {
            provider: "mock".to_string(),
            model_name: Some("test-model".to_string()),
            embeddings_model: None,
            api_key_source: None,
            temperature: 0.8,
        };

        let provider = create_provider(&config).unwrap();
        assert_eq!(provider.kind(), ProviderKind::Mock);
    }

    #[test]
    fn test_create_mock_provider_without_model_name() {
        let config = ModelConfig {
            provider: "mock".to_string(),
            model_name: None,
            embeddings_model: None,
            api_key_source: None,
            temperature: 0.7,
        };

        let provider = create_provider(&config).unwrap();
        assert_eq!(provider.kind(), ProviderKind::Mock);
    }

    #[test]
    fn test_create_unknown_provider() {
        let config = ModelConfig {
            provider: "unknown-provider".to_string(),
            model_name: None,
            embeddings_model: None,
            api_key_source: None,
            temperature: 0.7,
        };

        let result = create_provider(&config);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_api_key_from_env() {
        let _env = lock_env();
        // SAFETY: `_env` holds ENV_LOCK, so no other test in this module reads or writes
        // the environment concurrently (see ENV_LOCK for the limits of that guarantee).
        unsafe { std::env::set_var("TEST_API_KEY", "env-key-value") };
        let key = load_api_key_from_env("TEST_API_KEY");
        // SAFETY: as above. Removed before asserting so a failure cannot leak the variable.
        unsafe { std::env::remove_var("TEST_API_KEY") };
        assert_eq!(key.unwrap(), "env-key-value");
    }

    #[test]
    fn test_load_api_key_env_var_missing() {
        let _env = lock_env();
        let result = load_api_key_from_env("NONEXISTENT_VAR");
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_api_key_direct() {
        let key = resolve_api_key("sk-direct-api-key").unwrap();
        assert_eq!(key, "sk-direct-api-key");
    }

    #[test]
    fn test_resolve_api_key_from_env() {
        let _env = lock_env();
        // SAFETY: `_env` holds ENV_LOCK, so no other test in this module reads or writes
        // the environment concurrently.
        unsafe { std::env::set_var("TEST_RESOLVE_KEY", "env-resolved-value") };
        let key = resolve_api_key("env:TEST_RESOLVE_KEY");
        // SAFETY: as above. Removed before asserting so a failure cannot leak the variable.
        unsafe { std::env::remove_var("TEST_RESOLVE_KEY") };
        assert_eq!(key.unwrap(), "env-resolved-value");
    }

    #[test]
    fn test_resolve_api_key_from_missing_env() {
        let _env = lock_env();
        assert!(resolve_api_key("env:NONEXISTENT_RESOLVE_VAR").is_err());
    }

    #[test]
    fn test_resolve_api_key_from_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("test_api_key.txt");
        std::fs::write(&file_path, "file-api-key-value\n").unwrap();

        let key = resolve_api_key(&format!("file:{}", file_path.display())).unwrap();
        assert_eq!(key, "file-api-key-value");
    }

    #[test]
    fn test_load_api_key_from_file_with_whitespace() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("test_key_whitespace.txt");
        std::fs::write(&file_path, "  api-key-with-spaces  \n").unwrap();

        let key = load_api_key_from_file(file_path.to_str().unwrap()).unwrap();
        assert_eq!(key, "api-key-with-spaces");
    }

    #[test]
    fn test_load_api_key_from_missing_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let missing = temp_dir.path().join("does_not_exist.txt");

        assert!(load_api_key_from_file(missing.to_str().unwrap()).is_err());
    }
}