// vtcode_core/llm/factory.rs

1use super::providers::{
2    AnthropicProvider, DeepSeekProvider, GeminiProvider, OpenAIProvider, OpenRouterProvider,
3    XAIProvider,
4};
5use crate::config::core::PromptCachingConfig;
6use crate::llm::provider::{LLMError, LLMProvider};
7use std::collections::HashMap;
8
9/// LLM provider factory and registry
10pub struct LLMFactory {
11    providers: HashMap<String, Box<dyn Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync>>,
12}
13
14#[derive(Debug, Clone)]
15pub struct ProviderConfig {
16    pub api_key: Option<String>,
17    pub base_url: Option<String>,
18    pub model: Option<String>,
19    pub prompt_cache: Option<PromptCachingConfig>,
20}
21
22impl LLMFactory {
23    pub fn new() -> Self {
24        let mut factory = Self {
25            providers: HashMap::new(),
26        };
27
28        // Register built-in providers
29        factory.register_provider(
30            "gemini",
31            Box::new(|config: ProviderConfig| {
32                let ProviderConfig {
33                    api_key,
34                    base_url,
35                    model,
36                    prompt_cache,
37                } = config;
38                Box::new(GeminiProvider::from_config(
39                    api_key,
40                    model,
41                    base_url,
42                    prompt_cache,
43                )) as Box<dyn LLMProvider>
44            }),
45        );
46
47        factory.register_provider(
48            "openai",
49            Box::new(|config: ProviderConfig| {
50                let ProviderConfig {
51                    api_key,
52                    base_url,
53                    model,
54                    prompt_cache,
55                } = config;
56                Box::new(OpenAIProvider::from_config(
57                    api_key,
58                    model,
59                    base_url,
60                    prompt_cache,
61                )) as Box<dyn LLMProvider>
62            }),
63        );
64
65        factory.register_provider(
66            "anthropic",
67            Box::new(|config: ProviderConfig| {
68                let ProviderConfig {
69                    api_key,
70                    base_url,
71                    model,
72                    prompt_cache,
73                } = config;
74                Box::new(AnthropicProvider::from_config(
75                    api_key,
76                    model,
77                    base_url,
78                    prompt_cache,
79                )) as Box<dyn LLMProvider>
80            }),
81        );
82
83        factory.register_provider(
84            "deepseek",
85            Box::new(|config: ProviderConfig| {
86                let ProviderConfig {
87                    api_key,
88                    base_url,
89                    model,
90                    prompt_cache,
91                } = config;
92                Box::new(DeepSeekProvider::from_config(
93                    api_key,
94                    model,
95                    base_url,
96                    prompt_cache,
97                )) as Box<dyn LLMProvider>
98            }),
99        );
100
101        factory.register_provider(
102            "openrouter",
103            Box::new(|config: ProviderConfig| {
104                let ProviderConfig {
105                    api_key,
106                    base_url,
107                    model,
108                    prompt_cache,
109                } = config;
110                Box::new(OpenRouterProvider::from_config(
111                    api_key,
112                    model,
113                    base_url,
114                    prompt_cache,
115                )) as Box<dyn LLMProvider>
116            }),
117        );
118
119        factory.register_provider(
120            "xai",
121            Box::new(|config: ProviderConfig| {
122                let ProviderConfig {
123                    api_key,
124                    base_url,
125                    model,
126                    prompt_cache,
127                } = config;
128                Box::new(XAIProvider::from_config(
129                    api_key,
130                    model,
131                    base_url,
132                    prompt_cache,
133                )) as Box<dyn LLMProvider>
134            }),
135        );
136
137        factory
138    }
139
140    /// Register a new provider
141    pub fn register_provider<F>(&mut self, name: &str, factory_fn: F)
142    where
143        F: Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync + 'static,
144    {
145        self.providers
146            .insert(name.to_string(), Box::new(factory_fn));
147    }
148
149    /// Create provider instance
150    pub fn create_provider(
151        &self,
152        provider_name: &str,
153        config: ProviderConfig,
154    ) -> Result<Box<dyn LLMProvider>, LLMError> {
155        let factory_fn = self.providers.get(provider_name).ok_or_else(|| {
156            LLMError::InvalidRequest(format!("Unknown provider: {}", provider_name))
157        })?;
158
159        Ok(factory_fn(config))
160    }
161
162    /// List available providers
163    pub fn list_providers(&self) -> Vec<String> {
164        self.providers.keys().cloned().collect()
165    }
166
167    /// Determine provider name from model string
168    pub fn provider_from_model(&self, model: &str) -> Option<String> {
169        let m = model.to_lowercase();
170        if m.starts_with("gpt-") || m.starts_with("o3") || m.starts_with("o1") {
171            Some("openai".to_string())
172        } else if m.starts_with("claude-") {
173            Some("anthropic".to_string())
174        } else if m.starts_with("deepseek-") {
175            Some("deepseek".to_string())
176        } else if m.contains("gemini") || m.starts_with("palm") {
177            Some("gemini".to_string())
178        } else if m.starts_with("grok-") || m.starts_with("xai-") {
179            Some("xai".to_string())
180        } else if m.contains('/') || m.contains('@') {
181            Some("openrouter".to_string())
182        } else {
183            None
184        }
185    }
186}
187
188impl Default for LLMFactory {
189    fn default() -> Self {
190        Self::new()
191    }
192}
193
194/// Global factory instance
195use std::sync::{LazyLock, Mutex};
196
197static FACTORY: LazyLock<Mutex<LLMFactory>> = LazyLock::new(|| Mutex::new(LLMFactory::new()));
198
199/// Get global factory instance
200pub fn get_factory() -> &'static Mutex<LLMFactory> {
201    &FACTORY
202}
203
204/// Create provider from model name and API key
205pub fn create_provider_for_model(
206    model: &str,
207    api_key: String,
208    prompt_cache: Option<PromptCachingConfig>,
209) -> Result<Box<dyn LLMProvider>, LLMError> {
210    let factory = get_factory().lock().unwrap();
211    let provider_name = factory.provider_from_model(model).ok_or_else(|| {
212        LLMError::InvalidRequest(format!("Cannot determine provider for model: {}", model))
213    })?;
214    drop(factory);
215
216    create_provider_with_config(
217        &provider_name,
218        Some(api_key),
219        None,
220        Some(model.to_string()),
221        prompt_cache,
222    )
223}
224
225/// Create provider with full configuration
226pub fn create_provider_with_config(
227    provider_name: &str,
228    api_key: Option<String>,
229    base_url: Option<String>,
230    model: Option<String>,
231    prompt_cache: Option<PromptCachingConfig>,
232) -> Result<Box<dyn LLMProvider>, LLMError> {
233    let factory = get_factory().lock().unwrap();
234    let config = ProviderConfig {
235        api_key,
236        base_url,
237        model,
238        prompt_cache,
239    };
240
241    factory.create_provider(provider_name, config)
242}