1use super::providers::{
2 AnthropicProvider, DeepSeekProvider, GeminiProvider, MoonshotProvider, OpenAIProvider,
3 OpenRouterProvider, XAIProvider, ZAIProvider,
4};
5use crate::config::core::PromptCachingConfig;
6use crate::llm::provider::{LLMError, LLMProvider};
7use std::collections::HashMap;
8
9pub struct LLMFactory {
11 providers: HashMap<String, Box<dyn Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync>>,
12}
13
14#[derive(Debug, Clone)]
15pub struct ProviderConfig {
16 pub api_key: Option<String>,
17 pub base_url: Option<String>,
18 pub model: Option<String>,
19 pub prompt_cache: Option<PromptCachingConfig>,
20}
21
22impl LLMFactory {
23 pub fn new() -> Self {
24 let mut factory = Self {
25 providers: HashMap::new(),
26 };
27
28 factory.register_provider(
30 "gemini",
31 Box::new(|config: ProviderConfig| {
32 let ProviderConfig {
33 api_key,
34 base_url,
35 model,
36 prompt_cache,
37 } = config;
38 Box::new(GeminiProvider::from_config(
39 api_key,
40 model,
41 base_url,
42 prompt_cache,
43 )) as Box<dyn LLMProvider>
44 }),
45 );
46
47 factory.register_provider(
48 "openai",
49 Box::new(|config: ProviderConfig| {
50 let ProviderConfig {
51 api_key,
52 base_url,
53 model,
54 prompt_cache,
55 } = config;
56 Box::new(OpenAIProvider::from_config(
57 api_key,
58 model,
59 base_url,
60 prompt_cache,
61 )) as Box<dyn LLMProvider>
62 }),
63 );
64
65 factory.register_provider(
66 "anthropic",
67 Box::new(|config: ProviderConfig| {
68 let ProviderConfig {
69 api_key,
70 base_url,
71 model,
72 prompt_cache,
73 } = config;
74 Box::new(AnthropicProvider::from_config(
75 api_key,
76 model,
77 base_url,
78 prompt_cache,
79 )) as Box<dyn LLMProvider>
80 }),
81 );
82
83 factory.register_provider(
84 "deepseek",
85 Box::new(|config: ProviderConfig| {
86 let ProviderConfig {
87 api_key,
88 base_url,
89 model,
90 prompt_cache,
91 } = config;
92 Box::new(DeepSeekProvider::from_config(
93 api_key,
94 model,
95 base_url,
96 prompt_cache,
97 )) as Box<dyn LLMProvider>
98 }),
99 );
100
101 factory.register_provider(
102 "openrouter",
103 Box::new(|config: ProviderConfig| {
104 let ProviderConfig {
105 api_key,
106 base_url,
107 model,
108 prompt_cache,
109 } = config;
110 Box::new(OpenRouterProvider::from_config(
111 api_key,
112 model,
113 base_url,
114 prompt_cache,
115 )) as Box<dyn LLMProvider>
116 }),
117 );
118
119 factory.register_provider(
120 "moonshot",
121 Box::new(|config: ProviderConfig| {
122 let ProviderConfig {
123 api_key,
124 base_url,
125 model,
126 prompt_cache,
127 } = config;
128 Box::new(MoonshotProvider::from_config(
129 api_key,
130 model,
131 base_url,
132 prompt_cache,
133 )) as Box<dyn LLMProvider>
134 }),
135 );
136
137 factory.register_provider(
138 "xai",
139 Box::new(|config: ProviderConfig| {
140 let ProviderConfig {
141 api_key,
142 base_url,
143 model,
144 prompt_cache,
145 } = config;
146 Box::new(XAIProvider::from_config(
147 api_key,
148 model,
149 base_url,
150 prompt_cache,
151 )) as Box<dyn LLMProvider>
152 }),
153 );
154
155 factory.register_provider(
156 "zai",
157 Box::new(|config: ProviderConfig| {
158 let ProviderConfig {
159 api_key,
160 base_url,
161 model,
162 prompt_cache,
163 } = config;
164 Box::new(ZAIProvider::from_config(
165 api_key,
166 model,
167 base_url,
168 prompt_cache,
169 )) as Box<dyn LLMProvider>
170 }),
171 );
172
173 factory
174 }
175
176 pub fn register_provider<F>(&mut self, name: &str, factory_fn: F)
178 where
179 F: Fn(ProviderConfig) -> Box<dyn LLMProvider> + Send + Sync + 'static,
180 {
181 self.providers
182 .insert(name.to_string(), Box::new(factory_fn));
183 }
184
185 pub fn create_provider(
187 &self,
188 provider_name: &str,
189 config: ProviderConfig,
190 ) -> Result<Box<dyn LLMProvider>, LLMError> {
191 let factory_fn = self.providers.get(provider_name).ok_or_else(|| {
192 LLMError::InvalidRequest(format!("Unknown provider: {}", provider_name))
193 })?;
194
195 Ok(factory_fn(config))
196 }
197
198 pub fn list_providers(&self) -> Vec<String> {
200 self.providers.keys().cloned().collect()
201 }
202
203 pub fn provider_from_model(&self, model: &str) -> Option<String> {
205 let m = model.to_lowercase();
206 if m.starts_with("gpt-") || m.starts_with("o3") || m.starts_with("o1") {
207 Some("openai".to_string())
208 } else if m.starts_with("claude-") {
209 Some("anthropic".to_string())
210 } else if m.starts_with("deepseek-") {
211 Some("deepseek".to_string())
212 } else if m.contains("gemini") || m.starts_with("palm") {
213 Some("gemini".to_string())
214 } else if m.starts_with("grok-") || m.starts_with("xai-") {
215 Some("xai".to_string())
216 } else if m.starts_with("glm-") {
217 Some("zai".to_string())
218 } else if m.starts_with("moonshot-") {
219 Some("moonshot".to_string())
220 } else if m.contains('/') || m.contains('@') {
221 Some("openrouter".to_string())
222 } else {
223 None
224 }
225 }
226}
227
228impl Default for LLMFactory {
229 fn default() -> Self {
230 Self::new()
231 }
232}
233
234use std::sync::{LazyLock, Mutex};
236
237static FACTORY: LazyLock<Mutex<LLMFactory>> = LazyLock::new(|| Mutex::new(LLMFactory::new()));
238
239pub fn get_factory() -> &'static Mutex<LLMFactory> {
241 &FACTORY
242}
243
244pub fn create_provider_for_model(
246 model: &str,
247 api_key: String,
248 prompt_cache: Option<PromptCachingConfig>,
249) -> Result<Box<dyn LLMProvider>, LLMError> {
250 let factory = get_factory().lock().unwrap();
251 let provider_name = factory.provider_from_model(model).ok_or_else(|| {
252 LLMError::InvalidRequest(format!("Cannot determine provider for model: {}", model))
253 })?;
254 drop(factory);
255
256 create_provider_with_config(
257 &provider_name,
258 Some(api_key),
259 None,
260 Some(model.to_string()),
261 prompt_cache,
262 )
263}
264
265pub fn create_provider_with_config(
267 provider_name: &str,
268 api_key: Option<String>,
269 base_url: Option<String>,
270 model: Option<String>,
271 prompt_cache: Option<PromptCachingConfig>,
272) -> Result<Box<dyn LLMProvider>, LLMError> {
273 let factory = get_factory().lock().unwrap();
274 let config = ProviderConfig {
275 api_key,
276 base_url,
277 model,
278 prompt_cache,
279 };
280
281 factory.create_provider(provider_name, config)
282}