1use super::providers::{
2 AnthropicProvider, DeepSeekProvider, GeminiProvider, OpenAIProvider, OpenRouterProvider,
3 XAIProvider,
4};
5use crate::config::core::PromptCachingConfig;
6use crate::llm::provider::{LLMError, LLMProvider};
7use std::collections::HashMap;
8
9pub struct LLMFactory {
11 providers: HashMap<String, Box<dyn Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync>>,
12}
13
14#[derive(Debug, Clone)]
15pub struct ProviderConfig {
16 pub api_key: Option<String>,
17 pub base_url: Option<String>,
18 pub model: Option<String>,
19 pub prompt_cache: Option<PromptCachingConfig>,
20}
21
22impl LLMFactory {
23 pub fn new() -> Self {
24 let mut factory = Self {
25 providers: HashMap::new(),
26 };
27
28 factory.register_provider(
30 "gemini",
31 Box::new(|config: ProviderConfig| {
32 let ProviderConfig {
33 api_key,
34 base_url,
35 model,
36 prompt_cache,
37 } = config;
38 Box::new(GeminiProvider::from_config(
39 api_key,
40 model,
41 base_url,
42 prompt_cache,
43 )) as Box<dyn LLMProvider>
44 }),
45 );
46
47 factory.register_provider(
48 "openai",
49 Box::new(|config: ProviderConfig| {
50 let ProviderConfig {
51 api_key,
52 base_url,
53 model,
54 prompt_cache,
55 } = config;
56 Box::new(OpenAIProvider::from_config(
57 api_key,
58 model,
59 base_url,
60 prompt_cache,
61 )) as Box<dyn LLMProvider>
62 }),
63 );
64
65 factory.register_provider(
66 "anthropic",
67 Box::new(|config: ProviderConfig| {
68 let ProviderConfig {
69 api_key,
70 base_url,
71 model,
72 prompt_cache,
73 } = config;
74 Box::new(AnthropicProvider::from_config(
75 api_key,
76 model,
77 base_url,
78 prompt_cache,
79 )) as Box<dyn LLMProvider>
80 }),
81 );
82
83 factory.register_provider(
84 "deepseek",
85 Box::new(|config: ProviderConfig| {
86 let ProviderConfig {
87 api_key,
88 base_url,
89 model,
90 prompt_cache,
91 } = config;
92 Box::new(DeepSeekProvider::from_config(
93 api_key,
94 model,
95 base_url,
96 prompt_cache,
97 )) as Box<dyn LLMProvider>
98 }),
99 );
100
101 factory.register_provider(
102 "openrouter",
103 Box::new(|config: ProviderConfig| {
104 let ProviderConfig {
105 api_key,
106 base_url,
107 model,
108 prompt_cache,
109 } = config;
110 Box::new(OpenRouterProvider::from_config(
111 api_key,
112 model,
113 base_url,
114 prompt_cache,
115 )) as Box<dyn LLMProvider>
116 }),
117 );
118
119 factory.register_provider(
120 "xai",
121 Box::new(|config: ProviderConfig| {
122 let ProviderConfig {
123 api_key,
124 base_url,
125 model,
126 prompt_cache,
127 } = config;
128 Box::new(XAIProvider::from_config(
129 api_key,
130 model,
131 base_url,
132 prompt_cache,
133 )) as Box<dyn LLMProvider>
134 }),
135 );
136
137 factory
138 }
139
140 pub fn register_provider<F>(&mut self, name: &str, factory_fn: F)
142 where
143 F: Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync + 'static,
144 {
145 self.providers
146 .insert(name.to_string(), Box::new(factory_fn));
147 }
148
149 pub fn create_provider(
151 &self,
152 provider_name: &str,
153 config: ProviderConfig,
154 ) -> Result<Box<dyn LLMProvider>, LLMError> {
155 let factory_fn = self.providers.get(provider_name).ok_or_else(|| {
156 LLMError::InvalidRequest(format!("Unknown provider: {}", provider_name))
157 })?;
158
159 Ok(factory_fn(config))
160 }
161
162 pub fn list_providers(&self) -> Vec<String> {
164 self.providers.keys().cloned().collect()
165 }
166
167 pub fn provider_from_model(&self, model: &str) -> Option<String> {
169 let m = model.to_lowercase();
170 if m.starts_with("gpt-") || m.starts_with("o3") || m.starts_with("o1") {
171 Some("openai".to_string())
172 } else if m.starts_with("claude-") {
173 Some("anthropic".to_string())
174 } else if m.starts_with("deepseek-") {
175 Some("deepseek".to_string())
176 } else if m.contains("gemini") || m.starts_with("palm") {
177 Some("gemini".to_string())
178 } else if m.starts_with("grok-") || m.starts_with("xai-") {
179 Some("xai".to_string())
180 } else if m.contains('/') || m.contains('@') {
181 Some("openrouter".to_string())
182 } else {
183 None
184 }
185 }
186}
187
188impl Default for LLMFactory {
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
194use std::sync::{LazyLock, Mutex};
196
197static FACTORY: LazyLock<Mutex<LLMFactory>> = LazyLock::new(|| Mutex::new(LLMFactory::new()));
198
199pub fn get_factory() -> &'static Mutex<LLMFactory> {
201 &FACTORY
202}
203
204pub fn create_provider_for_model(
206 model: &str,
207 api_key: String,
208 prompt_cache: Option<PromptCachingConfig>,
209) -> Result<Box<dyn LLMProvider>, LLMError> {
210 let factory = get_factory().lock().unwrap();
211 let provider_name = factory.provider_from_model(model).ok_or_else(|| {
212 LLMError::InvalidRequest(format!("Cannot determine provider for model: {}", model))
213 })?;
214 drop(factory);
215
216 create_provider_with_config(
217 &provider_name,
218 Some(api_key),
219 None,
220 Some(model.to_string()),
221 prompt_cache,
222 )
223}
224
225pub fn create_provider_with_config(
227 provider_name: &str,
228 api_key: Option<String>,
229 base_url: Option<String>,
230 model: Option<String>,
231 prompt_cache: Option<PromptCachingConfig>,
232) -> Result<Box<dyn LLMProvider>, LLMError> {
233 let factory = get_factory().lock().unwrap();
234 let config = ProviderConfig {
235 api_key,
236 base_url,
237 model,
238 prompt_cache,
239 };
240
241 factory.create_provider(provider_name, config)
242}