// vtcode_core/llm/factory.rs

1use super::providers::{
2    AnthropicProvider, GeminiProvider, OpenAIProvider, OpenRouterProvider, XAIProvider,
3};
4use crate::llm::provider::{LLMError, LLMProvider};
5use std::collections::HashMap;
6
7/// LLM provider factory and registry
8pub struct LLMFactory {
9    providers: HashMap<String, Box<dyn Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync>>,
10}
11
/// Settings handed to a provider constructor when a provider is created.
///
/// Every field is optional; `None` values are forwarded unchanged to the
/// provider's `from_config` constructor.
///
/// `Debug` is implemented by hand so that the API key is never printed.
#[derive(Clone, Default)]
pub struct ProviderConfig {
    /// Credential passed to the provider.
    pub api_key: Option<String>,
    /// Optional base URL passed to the provider (presumably overrides its default endpoint).
    pub base_url: Option<String>,
    /// Model identifier passed to the provider.
    pub model: Option<String>,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same layout as the derived impl, but the secret is masked so `{:?}`
        // formatting (logs, error reports, panic messages) cannot leak it.
        f.debug_struct("ProviderConfig")
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .finish()
    }
}
18
19impl LLMFactory {
20    pub fn new() -> Self {
21        let mut factory = Self {
22            providers: HashMap::new(),
23        };
24
25        // Register built-in providers
26        factory.register_provider(
27            "gemini",
28            Box::new(|config: ProviderConfig| {
29                let ProviderConfig {
30                    api_key,
31                    base_url,
32                    model,
33                } = config;
34                Box::new(GeminiProvider::from_config(api_key, model, base_url))
35                    as Box<dyn LLMProvider>
36            }),
37        );
38
39        factory.register_provider(
40            "openai",
41            Box::new(|config: ProviderConfig| {
42                let ProviderConfig {
43                    api_key,
44                    base_url,
45                    model,
46                } = config;
47                Box::new(OpenAIProvider::from_config(api_key, model, base_url))
48                    as Box<dyn LLMProvider>
49            }),
50        );
51
52        factory.register_provider(
53            "anthropic",
54            Box::new(|config: ProviderConfig| {
55                let ProviderConfig {
56                    api_key,
57                    base_url,
58                    model,
59                } = config;
60                Box::new(AnthropicProvider::from_config(api_key, model, base_url))
61                    as Box<dyn LLMProvider>
62            }),
63        );
64
65        factory.register_provider(
66            "openrouter",
67            Box::new(|config: ProviderConfig| {
68                let ProviderConfig {
69                    api_key,
70                    base_url,
71                    model,
72                } = config;
73                Box::new(OpenRouterProvider::from_config(api_key, model, base_url))
74                    as Box<dyn LLMProvider>
75            }),
76        );
77
78        factory.register_provider(
79            "xai",
80            Box::new(|config: ProviderConfig| {
81                let ProviderConfig {
82                    api_key,
83                    base_url,
84                    model,
85                } = config;
86                Box::new(XAIProvider::from_config(api_key, model, base_url)) as Box<dyn LLMProvider>
87            }),
88        );
89
90        factory
91    }
92
93    /// Register a new provider
94    pub fn register_provider<F>(&mut self, name: &str, factory_fn: F)
95    where
96        F: Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync + 'static,
97    {
98        self.providers
99            .insert(name.to_string(), Box::new(factory_fn));
100    }
101
102    /// Create provider instance
103    pub fn create_provider(
104        &self,
105        provider_name: &str,
106        config: ProviderConfig,
107    ) -> Result<Box<dyn LLMProvider>, LLMError> {
108        let factory_fn = self.providers.get(provider_name).ok_or_else(|| {
109            LLMError::InvalidRequest(format!("Unknown provider: {}", provider_name))
110        })?;
111
112        Ok(factory_fn(config))
113    }
114
115    /// List available providers
116    pub fn list_providers(&self) -> Vec<String> {
117        self.providers.keys().cloned().collect()
118    }
119
120    /// Determine provider name from model string
121    pub fn provider_from_model(&self, model: &str) -> Option<String> {
122        let m = model.to_lowercase();
123        if m.starts_with("gpt-") || m.starts_with("o3") || m.starts_with("o1") {
124            Some("openai".to_string())
125        } else if m.starts_with("claude-") {
126            Some("anthropic".to_string())
127        } else if m.contains("gemini") || m.starts_with("palm") {
128            Some("gemini".to_string())
129        } else if m.starts_with("grok-") || m.starts_with("xai-") {
130            Some("xai".to_string())
131        } else if m.contains('/') || m.contains('@') {
132            Some("openrouter".to_string())
133        } else {
134            None
135        }
136    }
137}
138
139impl Default for LLMFactory {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145/// Global factory instance
146use std::sync::{LazyLock, Mutex};
147
148static FACTORY: LazyLock<Mutex<LLMFactory>> = LazyLock::new(|| Mutex::new(LLMFactory::new()));
149
150/// Get global factory instance
151pub fn get_factory() -> &'static Mutex<LLMFactory> {
152    &FACTORY
153}
154
155/// Create provider from model name and API key
156pub fn create_provider_for_model(
157    model: &str,
158    api_key: String,
159) -> Result<Box<dyn LLMProvider>, LLMError> {
160    let factory = get_factory().lock().unwrap();
161    let provider_name = factory.provider_from_model(model).ok_or_else(|| {
162        LLMError::InvalidRequest(format!("Cannot determine provider for model: {}", model))
163    })?;
164    drop(factory);
165
166    create_provider_with_config(&provider_name, Some(api_key), None, Some(model.to_string()))
167}
168
169/// Create provider with full configuration
170pub fn create_provider_with_config(
171    provider_name: &str,
172    api_key: Option<String>,
173    base_url: Option<String>,
174    model: Option<String>,
175) -> Result<Box<dyn LLMProvider>, LLMError> {
176    let factory = get_factory().lock().unwrap();
177    let config = ProviderConfig {
178        api_key,
179        base_url,
180        model,
181    };
182
183    factory.create_provider(provider_name, config)
184}