vtcode_core/llm/
factory.rs

1use super::providers::{
2    AnthropicProvider, DeepSeekProvider, GeminiProvider, MoonshotProvider, OpenAIProvider,
3    OpenRouterProvider, XAIProvider, ZAIProvider,
4};
5use crate::config::core::PromptCachingConfig;
6use crate::llm::provider::{LLMError, LLMProvider};
7use std::collections::HashMap;
8
9/// LLM provider factory and registry
10pub struct LLMFactory {
11    providers: HashMap<String, Box<dyn Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync>>,
12}
13
14#[derive(Debug, Clone)]
15pub struct ProviderConfig {
16    pub api_key: Option<String>,
17    pub base_url: Option<String>,
18    pub model: Option<String>,
19    pub prompt_cache: Option<PromptCachingConfig>,
20}
21
22impl LLMFactory {
23    pub fn new() -> Self {
24        let mut factory = Self {
25            providers: HashMap::new(),
26        };
27
28        // Register built-in providers
29        factory.register_provider(
30            "gemini",
31            Box::new(|config: ProviderConfig| {
32                let ProviderConfig {
33                    api_key,
34                    base_url,
35                    model,
36                    prompt_cache,
37                } = config;
38                Box::new(GeminiProvider::from_config(
39                    api_key,
40                    model,
41                    base_url,
42                    prompt_cache,
43                )) as Box<dyn LLMProvider>
44            }),
45        );
46
47        factory.register_provider(
48            "openai",
49            Box::new(|config: ProviderConfig| {
50                let ProviderConfig {
51                    api_key,
52                    base_url,
53                    model,
54                    prompt_cache,
55                } = config;
56                Box::new(OpenAIProvider::from_config(
57                    api_key,
58                    model,
59                    base_url,
60                    prompt_cache,
61                )) as Box<dyn LLMProvider>
62            }),
63        );
64
65        factory.register_provider(
66            "anthropic",
67            Box::new(|config: ProviderConfig| {
68                let ProviderConfig {
69                    api_key,
70                    base_url,
71                    model,
72                    prompt_cache,
73                } = config;
74                Box::new(AnthropicProvider::from_config(
75                    api_key,
76                    model,
77                    base_url,
78                    prompt_cache,
79                )) as Box<dyn LLMProvider>
80            }),
81        );
82
83        factory.register_provider(
84            "deepseek",
85            Box::new(|config: ProviderConfig| {
86                let ProviderConfig {
87                    api_key,
88                    base_url,
89                    model,
90                    prompt_cache,
91                } = config;
92                Box::new(DeepSeekProvider::from_config(
93                    api_key,
94                    model,
95                    base_url,
96                    prompt_cache,
97                )) as Box<dyn LLMProvider>
98            }),
99        );
100
101        factory.register_provider(
102            "openrouter",
103            Box::new(|config: ProviderConfig| {
104                let ProviderConfig {
105                    api_key,
106                    base_url,
107                    model,
108                    prompt_cache,
109                } = config;
110                Box::new(OpenRouterProvider::from_config(
111                    api_key,
112                    model,
113                    base_url,
114                    prompt_cache,
115                )) as Box<dyn LLMProvider>
116            }),
117        );
118
119        factory.register_provider(
120            "moonshot",
121            Box::new(|config: ProviderConfig| {
122                let ProviderConfig {
123                    api_key,
124                    base_url,
125                    model,
126                    prompt_cache,
127                } = config;
128                Box::new(MoonshotProvider::from_config(
129                    api_key,
130                    model,
131                    base_url,
132                    prompt_cache,
133                )) as Box<dyn LLMProvider>
134            }),
135        );
136
137        factory.register_provider(
138            "xai",
139            Box::new(|config: ProviderConfig| {
140                let ProviderConfig {
141                    api_key,
142                    base_url,
143                    model,
144                    prompt_cache,
145                } = config;
146                Box::new(XAIProvider::from_config(
147                    api_key,
148                    model,
149                    base_url,
150                    prompt_cache,
151                )) as Box<dyn LLMProvider>
152            }),
153        );
154
155        factory.register_provider(
156            "zai",
157            Box::new(|config: ProviderConfig| {
158                let ProviderConfig {
159                    api_key,
160                    base_url,
161                    model,
162                    prompt_cache,
163                } = config;
164                Box::new(ZAIProvider::from_config(
165                    api_key,
166                    model,
167                    base_url,
168                    prompt_cache,
169                )) as Box<dyn LLMProvider>
170            }),
171        );
172
173        factory
174    }
175
176    /// Register a new provider
177    pub fn register_provider<F>(&mut self, name: &str, factory_fn: F)
178    where
179        F: Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync + 'static,
180    {
181        self.providers
182            .insert(name.to_string(), Box::new(factory_fn));
183    }
184
185    /// Create provider instance
186    pub fn create_provider(
187        &self,
188        provider_name: &str,
189        config: ProviderConfig,
190    ) -> Result<Box<dyn LLMProvider>, LLMError> {
191        let factory_fn = self.providers.get(provider_name).ok_or_else(|| {
192            LLMError::InvalidRequest(format!("Unknown provider: {}", provider_name))
193        })?;
194
195        Ok(factory_fn(config))
196    }
197
198    /// List available providers
199    pub fn list_providers(&self) -> Vec<String> {
200        self.providers.keys().cloned().collect()
201    }
202
203    /// Determine provider name from model string
204    pub fn provider_from_model(&self, model: &str) -> Option<String> {
205        let m = model.to_lowercase();
206        if m.starts_with("gpt-") || m.starts_with("o3") || m.starts_with("o1") {
207            Some("openai".to_string())
208        } else if m.starts_with("claude-") {
209            Some("anthropic".to_string())
210        } else if m.starts_with("deepseek-") {
211            Some("deepseek".to_string())
212        } else if m.contains("gemini") || m.starts_with("palm") {
213            Some("gemini".to_string())
214        } else if m.starts_with("grok-") || m.starts_with("xai-") {
215            Some("xai".to_string())
216        } else if m.starts_with("glm-") {
217            Some("zai".to_string())
218        } else if m.starts_with("moonshot-") || m.starts_with("kimi-") {
219            Some("moonshot".to_string())
220        } else if m.contains('/') || m.contains('@') {
221            Some("openrouter".to_string())
222        } else {
223            None
224        }
225    }
226}
227
228impl Default for LLMFactory {
229    fn default() -> Self {
230        Self::new()
231    }
232}
233
234/// Global factory instance
235use std::sync::{LazyLock, Mutex};
236
237static FACTORY: LazyLock<Mutex<LLMFactory>> = LazyLock::new(|| Mutex::new(LLMFactory::new()));
238
239/// Get global factory instance
240pub fn get_factory() -> &'static Mutex<LLMFactory> {
241    &FACTORY
242}
243
244/// Create provider from model name and API key
245pub fn create_provider_for_model(
246    model: &str,
247    api_key: String,
248    prompt_cache: Option<PromptCachingConfig>,
249) -> Result<Box<dyn LLMProvider>, LLMError> {
250    let factory = get_factory().lock().unwrap();
251    let provider_name = factory.provider_from_model(model).ok_or_else(|| {
252        LLMError::InvalidRequest(format!("Cannot determine provider for model: {}", model))
253    })?;
254    drop(factory);
255
256    create_provider_with_config(
257        &provider_name,
258        Some(api_key),
259        None,
260        Some(model.to_string()),
261        prompt_cache,
262    )
263}
264
265/// Create provider with full configuration
266pub fn create_provider_with_config(
267    provider_name: &str,
268    api_key: Option<String>,
269    base_url: Option<String>,
270    model: Option<String>,
271    prompt_cache: Option<PromptCachingConfig>,
272) -> Result<Box<dyn LLMProvider>, LLMError> {
273    let factory = get_factory().lock().unwrap();
274    let config = ProviderConfig {
275        api_key,
276        base_url,
277        model,
278        prompt_cache,
279    };
280
281    factory.create_provider(provider_name, config)
282}