vtcode_core/llm/
factory.rs1use super::providers::{AnthropicProvider, GeminiProvider, OpenAIProvider, OpenRouterProvider};
2use crate::llm::provider::{LLMError, LLMProvider};
3use std::collections::HashMap;
4
5pub struct LLMFactory {
7 providers: HashMap<String, Box<dyn Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync>>,
8}
9
/// Settings handed to a provider constructor registered in `LLMFactory`.
///
/// All fields are optional. How a `None` value is interpreted is up to each
/// provider's `from_config` implementation. `Default` yields an all-`None`
/// config, so callers can write
/// `ProviderConfig { model: Some(m), ..Default::default() }`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProviderConfig {
    /// API key used to authenticate with the provider, if any.
    pub api_key: Option<String>,
    /// Override for the provider's endpoint URL, if any.
    pub base_url: Option<String>,
    /// Model identifier to request, if any.
    pub model: Option<String>,
}
16
17impl LLMFactory {
18 pub fn new() -> Self {
19 let mut factory = Self {
20 providers: HashMap::new(),
21 };
22
23 factory.register_provider(
25 "gemini",
26 Box::new(|config: ProviderConfig| {
27 let ProviderConfig {
28 api_key,
29 base_url,
30 model,
31 } = config;
32 Box::new(GeminiProvider::from_config(api_key, model, base_url))
33 as Box<dyn LLMProvider>
34 }),
35 );
36
37 factory.register_provider(
38 "openai",
39 Box::new(|config: ProviderConfig| {
40 let ProviderConfig {
41 api_key,
42 base_url,
43 model,
44 } = config;
45 Box::new(OpenAIProvider::from_config(api_key, model, base_url))
46 as Box<dyn LLMProvider>
47 }),
48 );
49
50 factory.register_provider(
51 "anthropic",
52 Box::new(|config: ProviderConfig| {
53 let ProviderConfig {
54 api_key,
55 base_url,
56 model,
57 } = config;
58 Box::new(AnthropicProvider::from_config(api_key, model, base_url))
59 as Box<dyn LLMProvider>
60 }),
61 );
62
63 factory.register_provider(
64 "openrouter",
65 Box::new(|config: ProviderConfig| {
66 let ProviderConfig {
67 api_key,
68 base_url,
69 model,
70 } = config;
71 Box::new(OpenRouterProvider::from_config(api_key, model, base_url))
72 as Box<dyn LLMProvider>
73 }),
74 );
75
76 factory
77 }
78
79 pub fn register_provider<F>(&mut self, name: &str, factory_fn: F)
81 where
82 F: Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync + 'static,
83 {
84 self.providers
85 .insert(name.to_string(), Box::new(factory_fn));
86 }
87
88 pub fn create_provider(
90 &self,
91 provider_name: &str,
92 config: ProviderConfig,
93 ) -> Result<Box<dyn LLMProvider>, LLMError> {
94 let factory_fn = self.providers.get(provider_name).ok_or_else(|| {
95 LLMError::InvalidRequest(format!("Unknown provider: {}", provider_name))
96 })?;
97
98 Ok(factory_fn(config))
99 }
100
101 pub fn list_providers(&self) -> Vec<String> {
103 self.providers.keys().cloned().collect()
104 }
105
106 pub fn provider_from_model(&self, model: &str) -> Option<String> {
108 let m = model.to_lowercase();
109 if m.starts_with("gpt-") || m.starts_with("o3") || m.starts_with("o1") {
110 Some("openai".to_string())
111 } else if m.starts_with("claude-") {
112 Some("anthropic".to_string())
113 } else if m.contains("gemini") || m.starts_with("palm") {
114 Some("gemini".to_string())
115 } else if m.contains('/') || m.contains('@') {
116 Some("openrouter".to_string())
117 } else {
118 None
119 }
120 }
121}
122
123impl Default for LLMFactory {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129use std::sync::{LazyLock, Mutex};
131
132static FACTORY: LazyLock<Mutex<LLMFactory>> = LazyLock::new(|| Mutex::new(LLMFactory::new()));
133
134pub fn get_factory() -> &'static Mutex<LLMFactory> {
136 &FACTORY
137}
138
139pub fn create_provider_for_model(
141 model: &str,
142 api_key: String,
143) -> Result<Box<dyn LLMProvider>, LLMError> {
144 let factory = get_factory().lock().unwrap();
145 let provider_name = factory.provider_from_model(model).ok_or_else(|| {
146 LLMError::InvalidRequest(format!("Cannot determine provider for model: {}", model))
147 })?;
148 drop(factory);
149
150 create_provider_with_config(&provider_name, Some(api_key), None, Some(model.to_string()))
151}
152
153pub fn create_provider_with_config(
155 provider_name: &str,
156 api_key: Option<String>,
157 base_url: Option<String>,
158 model: Option<String>,
159) -> Result<Box<dyn LLMProvider>, LLMError> {
160 let factory = get_factory().lock().unwrap();
161 let config = ProviderConfig {
162 api_key,
163 base_url,
164 model,
165 };
166
167 factory.create_provider(provider_name, config)
168}