vtcode_core/llm/
factory.rs1use super::providers::{
2 AnthropicProvider, GeminiProvider, OpenAIProvider, OpenRouterProvider, XAIProvider,
3};
4use crate::llm::provider::{LLMError, LLMProvider};
5use std::collections::HashMap;
6
7pub struct LLMFactory {
9 providers: HashMap<String, Box<dyn Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync>>,
10}
11
/// Optional settings handed to a provider builder.
///
/// Every field is optional, so `ProviderConfig::default()` yields a config
/// with nothing set and struct-update syntax (`..Default::default()`) can be
/// used to fill in only the fields of interest.
#[derive(Clone, Default)]
pub struct ProviderConfig {
    /// API key passed to the provider, if any.
    pub api_key: Option<String>,
    /// Base URL passed to the provider, if any.
    pub base_url: Option<String>,
    /// Model identifier passed to the provider, if any.
    pub model: Option<String>,
}

/// Hand-written so that the API key never appears in `{:?}` output (and
/// therefore not in logs or panic messages); only its presence is shown.
impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderConfig")
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .finish()
    }
}
18
19impl LLMFactory {
20 pub fn new() -> Self {
21 let mut factory = Self {
22 providers: HashMap::new(),
23 };
24
25 factory.register_provider(
27 "gemini",
28 Box::new(|config: ProviderConfig| {
29 let ProviderConfig {
30 api_key,
31 base_url,
32 model,
33 } = config;
34 Box::new(GeminiProvider::from_config(api_key, model, base_url))
35 as Box<dyn LLMProvider>
36 }),
37 );
38
39 factory.register_provider(
40 "openai",
41 Box::new(|config: ProviderConfig| {
42 let ProviderConfig {
43 api_key,
44 base_url,
45 model,
46 } = config;
47 Box::new(OpenAIProvider::from_config(api_key, model, base_url))
48 as Box<dyn LLMProvider>
49 }),
50 );
51
52 factory.register_provider(
53 "anthropic",
54 Box::new(|config: ProviderConfig| {
55 let ProviderConfig {
56 api_key,
57 base_url,
58 model,
59 } = config;
60 Box::new(AnthropicProvider::from_config(api_key, model, base_url))
61 as Box<dyn LLMProvider>
62 }),
63 );
64
65 factory.register_provider(
66 "openrouter",
67 Box::new(|config: ProviderConfig| {
68 let ProviderConfig {
69 api_key,
70 base_url,
71 model,
72 } = config;
73 Box::new(OpenRouterProvider::from_config(api_key, model, base_url))
74 as Box<dyn LLMProvider>
75 }),
76 );
77
78 factory.register_provider(
79 "xai",
80 Box::new(|config: ProviderConfig| {
81 let ProviderConfig {
82 api_key,
83 base_url,
84 model,
85 } = config;
86 Box::new(XAIProvider::from_config(api_key, model, base_url)) as Box<dyn LLMProvider>
87 }),
88 );
89
90 factory
91 }
92
93 pub fn register_provider<F>(&mut self, name: &str, factory_fn: F)
95 where
96 F: Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync + 'static,
97 {
98 self.providers
99 .insert(name.to_string(), Box::new(factory_fn));
100 }
101
102 pub fn create_provider(
104 &self,
105 provider_name: &str,
106 config: ProviderConfig,
107 ) -> Result<Box<dyn LLMProvider>, LLMError> {
108 let factory_fn = self.providers.get(provider_name).ok_or_else(|| {
109 LLMError::InvalidRequest(format!("Unknown provider: {}", provider_name))
110 })?;
111
112 Ok(factory_fn(config))
113 }
114
115 pub fn list_providers(&self) -> Vec<String> {
117 self.providers.keys().cloned().collect()
118 }
119
120 pub fn provider_from_model(&self, model: &str) -> Option<String> {
122 let m = model.to_lowercase();
123 if m.starts_with("gpt-") || m.starts_with("o3") || m.starts_with("o1") {
124 Some("openai".to_string())
125 } else if m.starts_with("claude-") {
126 Some("anthropic".to_string())
127 } else if m.contains("gemini") || m.starts_with("palm") {
128 Some("gemini".to_string())
129 } else if m.starts_with("grok-") || m.starts_with("xai-") {
130 Some("xai".to_string())
131 } else if m.contains('/') || m.contains('@') {
132 Some("openrouter".to_string())
133 } else {
134 None
135 }
136 }
137}
138
139impl Default for LLMFactory {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145use std::sync::{LazyLock, Mutex};
147
148static FACTORY: LazyLock<Mutex<LLMFactory>> = LazyLock::new(|| Mutex::new(LLMFactory::new()));
149
150pub fn get_factory() -> &'static Mutex<LLMFactory> {
152 &FACTORY
153}
154
155pub fn create_provider_for_model(
157 model: &str,
158 api_key: String,
159) -> Result<Box<dyn LLMProvider>, LLMError> {
160 let factory = get_factory().lock().unwrap();
161 let provider_name = factory.provider_from_model(model).ok_or_else(|| {
162 LLMError::InvalidRequest(format!("Cannot determine provider for model: {}", model))
163 })?;
164 drop(factory);
165
166 create_provider_with_config(&provider_name, Some(api_key), None, Some(model.to_string()))
167}
168
169pub fn create_provider_with_config(
171 provider_name: &str,
172 api_key: Option<String>,
173 base_url: Option<String>,
174 model: Option<String>,
175) -> Result<Box<dyn LLMProvider>, LLMError> {
176 let factory = get_factory().lock().unwrap();
177 let config = ProviderConfig {
178 api_key,
179 base_url,
180 model,
181 };
182
183 factory.create_provider(provider_name, config)
184}