// vtcode_core/llm/factory.rs

1use super::providers::{AnthropicProvider, GeminiProvider, OpenAIProvider, OpenRouterProvider};
2use crate::llm::provider::{LLMError, LLMProvider};
3use std::collections::HashMap;
4
5/// LLM provider factory and registry
6pub struct LLMFactory {
7    providers: HashMap<String, Box<dyn Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync>>,
8}
9
/// Settings handed to a provider constructor registered in [`LLMFactory`].
///
/// Every field is optional. How a `None` is interpreted is up to each
/// provider's `from_config`.
///
/// `Debug` is implemented by hand so that the API key is never printed.
#[derive(Clone, Default, PartialEq, Eq)]
pub struct ProviderConfig {
    /// Optional API key forwarded to the provider constructor.
    pub api_key: Option<String>,
    /// Optional base URL forwarded to the provider constructor.
    pub base_url: Option<String>,
    /// Optional model identifier forwarded to the provider constructor.
    pub model: Option<String>,
}

impl std::fmt::Debug for ProviderConfig {
    // A derived `Debug` would print the secret key verbatim into logs and panic
    // messages; only its presence is shown here.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderConfig")
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .finish()
    }
}
16
17impl LLMFactory {
18    pub fn new() -> Self {
19        let mut factory = Self {
20            providers: HashMap::new(),
21        };
22
23        // Register built-in providers
24        factory.register_provider(
25            "gemini",
26            Box::new(|config: ProviderConfig| {
27                let ProviderConfig {
28                    api_key,
29                    base_url,
30                    model,
31                } = config;
32                Box::new(GeminiProvider::from_config(api_key, model, base_url))
33                    as Box<dyn LLMProvider>
34            }),
35        );
36
37        factory.register_provider(
38            "openai",
39            Box::new(|config: ProviderConfig| {
40                let ProviderConfig {
41                    api_key,
42                    base_url,
43                    model,
44                } = config;
45                Box::new(OpenAIProvider::from_config(api_key, model, base_url))
46                    as Box<dyn LLMProvider>
47            }),
48        );
49
50        factory.register_provider(
51            "anthropic",
52            Box::new(|config: ProviderConfig| {
53                let ProviderConfig {
54                    api_key,
55                    base_url,
56                    model,
57                } = config;
58                Box::new(AnthropicProvider::from_config(api_key, model, base_url))
59                    as Box<dyn LLMProvider>
60            }),
61        );
62
63        factory.register_provider(
64            "openrouter",
65            Box::new(|config: ProviderConfig| {
66                let ProviderConfig {
67                    api_key,
68                    base_url,
69                    model,
70                } = config;
71                Box::new(OpenRouterProvider::from_config(api_key, model, base_url))
72                    as Box<dyn LLMProvider>
73            }),
74        );
75
76        factory
77    }
78
79    /// Register a new provider
80    pub fn register_provider<F>(&mut self, name: &str, factory_fn: F)
81    where
82        F: Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync + 'static,
83    {
84        self.providers
85            .insert(name.to_string(), Box::new(factory_fn));
86    }
87
88    /// Create provider instance
89    pub fn create_provider(
90        &self,
91        provider_name: &str,
92        config: ProviderConfig,
93    ) -> Result<Box<dyn LLMProvider>, LLMError> {
94        let factory_fn = self.providers.get(provider_name).ok_or_else(|| {
95            LLMError::InvalidRequest(format!("Unknown provider: {}", provider_name))
96        })?;
97
98        Ok(factory_fn(config))
99    }
100
101    /// List available providers
102    pub fn list_providers(&self) -> Vec<String> {
103        self.providers.keys().cloned().collect()
104    }
105
106    /// Determine provider name from model string
107    pub fn provider_from_model(&self, model: &str) -> Option<String> {
108        let m = model.to_lowercase();
109        if m.starts_with("gpt-") || m.starts_with("o3") || m.starts_with("o1") {
110            Some("openai".to_string())
111        } else if m.starts_with("claude-") {
112            Some("anthropic".to_string())
113        } else if m.contains("gemini") || m.starts_with("palm") {
114            Some("gemini".to_string())
115        } else if m.contains('/') || m.contains('@') {
116            Some("openrouter".to_string())
117        } else {
118            None
119        }
120    }
121}
122
123impl Default for LLMFactory {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129/// Global factory instance
130use std::sync::{LazyLock, Mutex};
131
132static FACTORY: LazyLock<Mutex<LLMFactory>> = LazyLock::new(|| Mutex::new(LLMFactory::new()));
133
134/// Get global factory instance
135pub fn get_factory() -> &'static Mutex<LLMFactory> {
136    &FACTORY
137}
138
139/// Create provider from model name and API key
140pub fn create_provider_for_model(
141    model: &str,
142    api_key: String,
143) -> Result<Box<dyn LLMProvider>, LLMError> {
144    let factory = get_factory().lock().unwrap();
145    let provider_name = factory.provider_from_model(model).ok_or_else(|| {
146        LLMError::InvalidRequest(format!("Cannot determine provider for model: {}", model))
147    })?;
148    drop(factory);
149
150    create_provider_with_config(&provider_name, Some(api_key), None, Some(model.to_string()))
151}
152
153/// Create provider with full configuration
154pub fn create_provider_with_config(
155    provider_name: &str,
156    api_key: Option<String>,
157    base_url: Option<String>,
158    model: Option<String>,
159) -> Result<Box<dyn LLMProvider>, LLMError> {
160    let factory = get_factory().lock().unwrap();
161    let config = ProviderConfig {
162        api_key,
163        base_url,
164        model,
165    };
166
167    factory.create_provider(provider_name, config)
168}