1#[cfg(feature = "anthropic")]
3pub mod anthropic;
4#[cfg(feature = "azure")]
5pub mod azure;
6#[cfg(feature = "bedrock")]
7pub mod bedrock;
8#[cfg(feature = "flowise")]
9pub mod flowise;
10#[cfg(feature = "gemini")]
11pub mod gemini;
12#[cfg(feature = "huggingface")]
13pub mod huggingface;
14#[cfg(feature = "mlx")]
15pub mod mlx;
16#[cfg(feature = "nvidia")]
17pub mod nvidia;
18#[cfg(feature = "ollama")]
19pub mod ollama;
20#[cfg(feature = "openai")]
21pub mod openai;
22#[cfg(feature = "perplexity")]
23pub mod perplexity;
24#[cfg(feature = "vertex")]
25pub mod vertex;
26#[cfg(feature = "xai")]
27pub mod xai;
28
29pub mod registry;
31
32use crate::config::accounts::AccountConfig;
33use crate::config::Config;
34use anyhow::{Context, Result};
35use async_trait::async_trait;
36use once_cell::sync::Lazy;
37
38#[async_trait]
39pub trait AIProvider: Send + Sync {
40 async fn generate_commit_message(
41 &self,
42 diff: &str,
43 context: Option<&str>,
44 full_gitmoji: bool,
45 config: &Config,
46 ) -> Result<String>;
47
48 async fn generate_commit_messages(
50 &self,
51 diff: &str,
52 context: Option<&str>,
53 full_gitmoji: bool,
54 config: &Config,
55 count: u8,
56 ) -> Result<Vec<String>> {
57 use futures::stream::StreamExt;
58
59 if count <= 1 {
60 match self
62 .generate_commit_message(diff, context, full_gitmoji, config)
63 .await
64 {
65 Ok(msg) => Ok(vec![msg]),
66 Err(e) => {
67 tracing::warn!("Failed to generate message: {}", e);
68 Ok(vec![])
69 }
70 }
71 } else {
72 let futures = (0..count)
74 .map(|_| self.generate_commit_message(diff, context, full_gitmoji, config));
75 let mut stream = futures::stream::FuturesUnordered::from_iter(futures);
76
77 let mut messages = Vec::with_capacity(count as usize);
78 while let Some(result) = stream.next().await {
79 match result {
80 Ok(msg) => messages.push(msg),
81 Err(e) => tracing::warn!("Failed to generate message: {}", e),
82 }
83 }
84 Ok(messages)
85 }
86 }
87
88 #[cfg(any(feature = "openai", feature = "xai"))]
90 async fn generate_pr_description(
91 &self,
92 commits: &[String],
93 diff: &str,
94 config: &Config,
95 ) -> Result<String> {
96 let commits_text = commits.join("\n");
97 let prompt = format!(
98 "Generate a professional pull request description based on the following commits:\n\n{}\n\nDiff:\n{}\n\nFormat the output as:\n## Summary\n## Changes\n## Testing\n## Breaking Changes\n\nKeep it concise and informative.",
99 commits_text, diff
100 );
101
102 let messages = vec![
103 async_openai::types::chat::ChatCompletionRequestSystemMessage::from(
104 "You are an expert at writing pull request descriptions.",
105 )
106 .into(),
107 async_openai::types::chat::ChatCompletionRequestUserMessage::from(prompt).into(),
108 ];
109
110 let request = async_openai::types::chat::CreateChatCompletionRequestArgs::default()
111 .model(
112 config
113 .model
114 .clone()
115 .unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
116 )
117 .messages(messages)
118 .temperature(0.7)
119 .max_tokens(config.tokens_max_output.unwrap_or(1000) as u16)
120 .build()?;
121
122 let api_key = config
124 .api_key
125 .as_ref()
126 .context("API key not configured. Run: rco config set RCO_API_KEY=<your_key>")?;
127 let api_url = config
128 .api_url
129 .as_deref()
130 .unwrap_or("https://api.openai.com/v1");
131
132 let openai_config = async_openai::config::OpenAIConfig::new()
133 .with_api_key(api_key)
134 .with_api_base(api_url);
135
136 let client = async_openai::Client::with_config(openai_config);
137
138 let response = client.chat().create(request).await?;
139
140 let message = response
141 .choices
142 .first()
143 .and_then(|choice| choice.message.content.as_ref())
144 .context("AI returned an empty response")?
145 .trim()
146 .to_string();
147
148 Ok(message)
149 }
150
151 #[cfg(not(any(feature = "openai", feature = "xai")))]
153 async fn generate_pr_description(
154 &self,
155 _commits: &[String],
156 _diff: &str,
157 _config: &Config,
158 ) -> Result<String> {
159 anyhow::bail!(
160 "PR description generation requires the 'openai' or 'xai' feature to be enabled"
161 );
162 }
163}
164
165pub static PROVIDER_REGISTRY: Lazy<registry::ProviderRegistry> = Lazy::new(|| {
167 let reg = registry::ProviderRegistry::new();
168
169 #[cfg(feature = "openai")]
171 {
172 let _ = reg.register(Box::new(openai::OpenAICompatibleProvider::new()));
173 }
174
175 #[cfg(feature = "anthropic")]
177 {
178 let _ = reg.register(Box::new(anthropic::AnthropicProviderBuilder));
179 }
180
181 #[cfg(feature = "ollama")]
182 {
183 let _ = reg.register(Box::new(ollama::OllamaProviderBuilder));
184 }
185
186 #[cfg(feature = "gemini")]
187 {
188 let _ = reg.register(Box::new(gemini::GeminiProviderBuilder));
189 }
190
191 #[cfg(feature = "azure")]
192 {
193 let _ = reg.register(Box::new(azure::AzureProviderBuilder));
194 }
195
196 #[cfg(feature = "perplexity")]
197 {
198 let _ = reg.register(Box::new(perplexity::PerplexityProviderBuilder));
199 }
200
201 #[cfg(feature = "xai")]
202 {
203 let _ = reg.register(Box::new(xai::XAIProviderBuilder));
204 }
205
206 #[cfg(feature = "huggingface")]
207 {
208 let _ = reg.register(Box::new(huggingface::HuggingFaceProviderBuilder));
209 }
210
211 #[cfg(feature = "bedrock")]
212 {
213 let _ = reg.register(Box::new(bedrock::BedrockProviderBuilder));
214 }
215
216 #[cfg(feature = "vertex")]
217 {
218 let _ = reg.register(Box::new(vertex::VertexProviderBuilder));
219 }
220
221 #[cfg(feature = "mlx")]
222 {
223 let _ = reg.register(Box::new(mlx::MlxProviderBuilder));
224 }
225
226 #[cfg(feature = "nvidia")]
227 {
228 let _ = reg.register(Box::new(nvidia::NvidiaProviderBuilder));
229 }
230
231 #[cfg(feature = "flowise")]
232 {
233 let _ = reg.register(Box::new(flowise::FlowiseProviderBuilder));
234 }
235
236 reg
237});
238
239pub fn create_provider(config: &Config) -> Result<Box<dyn AIProvider>> {
241 let provider_name = config.ai_provider.as_deref().unwrap_or("openai");
242
243 if let Some(provider) = PROVIDER_REGISTRY.create(provider_name, config)? {
245 return Ok(provider);
246 }
247
248 let available: Vec<String> = PROVIDER_REGISTRY
250 .all()
251 .unwrap_or_default()
252 .iter()
253 .map(|e| {
254 let aliases = if e.aliases.is_empty() {
255 String::new()
256 } else {
257 format!(" ({})", e.aliases.join(", "))
258 };
259 format!("- {}{}", e.name, aliases)
260 })
261 .chain(std::iter::once(format!(
262 "- {} OpenAI-compatible providers (deepseek, groq, openrouter, etc.)",
263 PROVIDER_REGISTRY
264 .by_category(registry::ProviderCategory::OpenAICompatible)
265 .map_or(0, |v| v.len())
266 )))
267 .filter(|s| !s.contains("0 OpenAI-compatible"))
268 .collect();
269
270 if available.is_empty() {
271 anyhow::bail!(
272 "No AI provider features enabled. Please enable at least one provider feature:\n\
273 --features openai,anthropic,ollama,gemini,azure,perplexity,xai,huggingface,bedrock,vertex"
274 );
275 }
276
277 anyhow::bail!(
278 "Unsupported or disabled AI provider: {}\n\n\
279 Available providers (based on enabled features):\n{}\n\n\
280 Set with: rco config set RCO_AI_PROVIDER=<provider_name>",
281 provider_name,
282 available.join("\n")
283 )
284}
285
286#[allow(dead_code)]
287pub fn available_providers() -> Vec<&'static str> {
289 let mut providers = PROVIDER_REGISTRY
290 .all()
291 .unwrap_or_default()
292 .iter()
293 .flat_map(|e| std::iter::once(e.name).chain(e.aliases.iter().copied()))
294 .collect::<Vec<_>>();
295
296 #[cfg(feature = "openai")]
297 {
298 providers.extend_from_slice(&[
299 "deepseek",
301 "groq",
302 "openrouter",
303 "together",
304 "deepinfra",
305 "mistral",
306 "github-models",
307 "fireworks",
308 "moonshot",
309 "dashscope",
310 "aimlapi",
312 "cohere",
314 "ai21",
315 "cloudflare",
316 "siliconflow",
317 "zhipu",
318 "minimax",
319 "upstage",
320 "nebius",
321 "ovh",
322 "scaleway",
323 "friendli",
324 "baseten",
325 "chutes",
326 "ionet",
327 "modelscope",
328 "requesty",
329 "morph",
330 "synthetic",
331 "nano-gpt",
332 "zenmux",
333 "v0",
334 "iflowcn",
335 "venice",
336 "cortecs",
337 "kimi-coding",
338 "abacus",
339 "bailing",
340 "fastrouter",
341 "inference",
342 "submodel",
343 "zai",
344 "zai-coding",
345 "zhipu-coding",
346 "poe",
347 "cerebras",
348 "lmstudio",
349 "sambanova",
350 "novita",
351 "predibase",
352 "tensorops",
353 "hyperbolic",
354 "kluster",
355 "lambda",
356 "replicate",
357 "targon",
358 "corcel",
359 "cybernative",
360 "edgen",
361 "gigachat",
362 "hydra",
363 "jina",
364 "lingyi",
365 "monica",
366 "pollinations",
367 "rawechat",
368 "shuttleai",
369 "teknium",
370 "theb",
371 "tryleap",
372 "workers-ai",
373 ]);
374 }
375
376 providers
377}
378
379#[allow(dead_code)]
381pub fn provider_info(provider: &str) -> Option<String> {
382 PROVIDER_REGISTRY.get(provider).map(|e| {
383 let aliases = if e.aliases.is_empty() {
384 String::new()
385 } else {
386 format!(" (aliases: {})", e.aliases.join(", "))
387 };
388 let model = e
389 .default_model
390 .map(|m| format!(", default model: {}", m))
391 .unwrap_or_default();
392 format!("{}{}{}", e.name, aliases, model)
393 })
394}
395
396pub fn split_prompt(
398 diff: &str,
399 context: Option<&str>,
400 config: &Config,
401 full_gitmoji: bool,
402) -> (String, String) {
403 let system_prompt = build_system_prompt(config, full_gitmoji);
404 let user_prompt = build_user_prompt(diff, context, full_gitmoji, config);
405 (system_prompt, user_prompt)
406}
407
408fn build_system_prompt(config: &Config, full_gitmoji: bool) -> String {
410 let mut prompt = String::new();
411
412 prompt.push_str("You are an expert at writing clear, concise git commit messages.\n\n");
413
414 prompt.push_str("OUTPUT RULES:\n");
416 prompt.push_str("- Return ONLY the commit message, with no additional explanation, markdown formatting, or code blocks\n");
417 prompt.push_str("- Do not include any reasoning, thinking, analysis, <thinking> tags, or XML-like tags in your response\n");
418 prompt.push_str("- Never explain your choices or provide commentary\n");
419 prompt.push_str(
420 "- If you cannot generate a meaningful commit message, return \"chore: update\"\n\n",
421 );
422
423 if config.learn_from_history.unwrap_or(false) {
425 if let Some(style_guidance) = get_style_guidance(config) {
426 prompt.push_str("REPO STYLE (learned from commit history):\n");
427 prompt.push_str(&style_guidance);
428 prompt.push('\n');
429 }
430 }
431
432 if let Some(locale) = &config.language {
434 prompt.push_str(&format!(
435 "- Generate the commit message in {} language\n",
436 locale
437 ));
438 }
439
440 let commit_type = config.commit_type.as_deref().unwrap_or("conventional");
442 match commit_type {
443 "conventional" => {
444 prompt.push_str("- Use conventional commit format: <type>(<scope>): <description>\n");
445 prompt.push_str(
446 "- Types: feat, fix, docs, style, refactor, perf, test, build, ci, chore\n",
447 );
448 if config.omit_scope.unwrap_or(false) {
449 prompt.push_str("- Omit the scope, use format: <type>: <description>\n");
450 }
451 }
452 "gitmoji" => {
453 if full_gitmoji {
454 prompt.push_str("- Use GitMoji format with full emoji specification from https://gitmoji.dev/\n");
455 prompt.push_str("- Common emojis: ✨(feat), 🐛(fix), 📝(docs), 🚀(deploy), ♻️(refactor), ✅(test), 🔧(chore), ⚡(perf), 🎨(style), 📦(build), 👷(ci)\n");
456 prompt.push_str("- For breaking changes, add 💥 after the type\n");
457 } else {
458 prompt.push_str("- Use GitMoji format: <emoji> <type>: <description>\n");
459 prompt.push_str("- Common emojis: 🐛(fix), ✨(feat), 📝(docs), 🚀(deploy), ✅(test), ♻️(refactor), 🔧(chore), ⚡(perf), 🎨(style), 📦(build), 👷(ci)\n");
460 }
461 }
462 _ => {}
463 }
464
465 let max_length = config.description_max_length.unwrap_or(100);
467 prompt.push_str(&format!(
468 "- Keep the description under {} characters\n",
469 max_length
470 ));
471
472 if config.description_capitalize.unwrap_or(true) {
473 prompt.push_str("- Capitalize the first letter of the description\n");
474 }
475
476 if !config.description_add_period.unwrap_or(false) {
477 prompt.push_str("- Do not end the description with a period\n");
478 }
479
480 if config.enable_commit_body.unwrap_or(false) {
482 prompt.push_str("\nCOMMIT BODY (optional):\n");
483 prompt.push_str(
484 "- Add a blank line after the description, then explain WHY the change was made\n",
485 );
486 prompt.push_str("- Use bullet points for multiple changes\n");
487 prompt.push_str("- Wrap body text at 72 characters\n");
488 prompt
489 .push_str("- Focus on motivation and context, not what changed (that's in the diff)\n");
490 }
491
492 prompt
493}
494
495fn get_style_guidance(config: &Config) -> Option<String> {
497 use crate::git;
498 use crate::utils::commit_style::CommitStyleProfile;
499
500 if let Some(cached) = &config.style_profile {
502 return Some(cached.clone());
504 }
505
506 let count = config.history_commits_count.unwrap_or(50);
508
509 match git::get_recent_commit_messages(count) {
510 Ok(commits) => {
511 if commits.is_empty() {
512 return None;
513 }
514
515 let profile = CommitStyleProfile::analyze_from_commits(&commits);
516
517 if profile.is_empty() || commits.len() < 10 {
520 return None;
521 }
522
523 Some(profile.to_prompt_guidance())
524 }
525 Err(e) => {
526 tracing::warn!("Failed to get commit history for style analysis: {}", e);
527 None
528 }
529 }
530}
531
532fn build_user_prompt(
534 diff: &str,
535 context: Option<&str>,
536 _full_gitmoji: bool,
537 _config: &Config,
538) -> String {
539 let mut prompt = String::new();
540
541 if let Some(project_context) = get_project_context() {
543 prompt.push_str(&format!("Project Context: {}\n\n", project_context));
544 }
545
546 let file_summary = extract_file_summary(diff);
548 if !file_summary.is_empty() {
549 prompt.push_str(&format!("Files Changed: {}\n\n", file_summary));
550 }
551
552 if diff.contains("---CHUNK") {
554 let chunk_count = diff.matches("---CHUNK").count();
555 if chunk_count > 1 {
556 prompt.push_str(&format!(
557 "Note: This diff was split into {} chunks due to size. Focus on the overall purpose of the changes across all chunks.\n\n",
558 chunk_count
559 ));
560 } else {
561 prompt.push_str("Note: The diff was split into chunks due to size. Focus on the overall purpose of the changes.\n\n");
562 }
563 }
564
565 if let Some(ctx) = context {
567 prompt.push_str(&format!("Additional context: {}\n\n", ctx));
568 }
569
570 prompt.push_str("Generate a commit message for the following git diff:\n");
571 prompt.push_str("```diff\n");
572 prompt.push_str(diff);
573 prompt.push_str("\n```\n");
574
575 prompt.push_str("\nRemember: Return ONLY the commit message, no explanations or markdown.");
577
578 prompt
579}
580
/// One-line summary of the files touched by a unified git diff.
///
/// The summary is built from these parts, in order:
/// - `"<n> file(s)"`;
/// - ` (<count> <Category>, ...)`, ordered by descending count with ties alphabetical;
/// - ` [.ext, ...]`, sorted, only when 1–5 distinct extensions occur;
/// - `: path, ...`, only when at most three files changed.
///
/// Returns `""` when no file headers are found.
///
/// Paths come from `+++ b/<path>` headers. For deleted files the new side is
/// `+++ /dev/null`, so the path is taken from the preceding `--- a/<path>` line
/// instead. A path whose header appears more than once is counted once. All
/// ordering is deterministic, so the same diff always yields the same summary.
fn extract_file_summary(diff: &str) -> String {
    use std::collections::{BTreeMap, BTreeSet, HashSet};

    let mut files: Vec<&str> = Vec::new();
    let mut seen: HashSet<&str> = HashSet::new();
    // Old-side path of the current header pair; used when the new side is /dev/null.
    let mut old_path: Option<&str> = None;

    for line in diff.lines() {
        if let Some(old) = line.strip_prefix("--- ") {
            old_path = old.strip_prefix("a/");
        } else if let Some(new) = line.strip_prefix("+++ ") {
            let path = if new == "/dev/null" {
                old_path
            } else {
                new.strip_prefix("b/")
            };
            old_path = None;
            if let Some(path) = path {
                if seen.insert(path) {
                    files.push(path);
                }
            }
        }
    }

    if files.is_empty() {
        return String::new();
    }

    // Sorted containers keep the summary stable across runs.
    let mut extensions: BTreeSet<String> = BTreeSet::new();
    let mut file_types: BTreeMap<String, usize> = BTreeMap::new();

    for &path in &files {
        match std::path::Path::new(path)
            .extension()
            .and_then(|ext| ext.to_str())
        {
            Some(ext) => {
                let ext = ext.to_lowercase();
                *file_types.entry(categorize_file_type(&ext)).or_insert(0) += 1;
                extensions.insert(ext);
            }
            None => {
                // A few well-known extensionless files are grouped as "config".
                if path.contains("Makefile")
                    || path.contains("Dockerfile")
                    || path.contains("LICENSE")
                {
                    *file_types.entry("config".to_string()).or_insert(0) += 1;
                }
            }
        }
    }

    let mut summary = format!("{} file(s)", files.len());

    if !file_types.is_empty() {
        let mut type_counts: Vec<(String, usize)> = file_types.into_iter().collect();
        // Most common first; the stable sort keeps BTreeMap's alphabetical order for ties.
        type_counts.sort_by(|a, b| b.1.cmp(&a.1));
        let parts: Vec<String> = type_counts
            .iter()
            .map(|(category, count)| format!("{} {}", count, category))
            .collect();
        summary.push_str(&format!(" ({})", parts.join(", ")));
    }

    if !extensions.is_empty() && extensions.len() <= 5 {
        let ext_list: Vec<String> = extensions.into_iter().collect();
        summary.push_str(&format!(" [.{}]", ext_list.join(", .")));
    }

    if files.len() <= 3 {
        summary.push_str(&format!(": {}", files.join(", ")));
    }

    summary
}

/// Human-readable category for a lowercase file extension (`"Other"` when unknown).
fn categorize_file_type(ext: &str) -> String {
    match ext {
        // Programming languages
        "rs" => "Rust",
        "py" => "Python",
        "js" => "JavaScript",
        "ts" => "TypeScript",
        "jsx" | "tsx" => "React",
        "go" => "Go",
        "java" => "Java",
        "kt" => "Kotlin",
        "swift" => "Swift",
        "c" | "cpp" | "cc" | "h" | "hpp" => "C/C++",
        "rb" => "Ruby",
        "php" => "PHP",
        "cs" => "C#",
        "scala" => "Scala",
        "r" => "R",
        "m" => "Objective-C",
        "lua" => "Lua",
        "pl" => "Perl",

        // Web
        "html" | "htm" => "HTML",
        "css" | "scss" | "sass" | "less" => "CSS",
        "vue" => "Vue",
        "svelte" => "Svelte",

        // Data and configuration
        "json" => "JSON",
        "yaml" | "yml" => "YAML",
        "toml" => "TOML",
        "xml" => "XML",
        "csv" => "CSV",
        "sql" => "SQL",

        // Documentation
        "md" | "markdown" => "Markdown",
        "rst" => "reStructuredText",
        "txt" => "Text",

        // Scripts and build tooling
        "sh" | "bash" | "zsh" | "fish" => "Shell",
        "ps1" => "PowerShell",
        "bat" | "cmd" => "Batch",
        "dockerfile" => "Docker",
        "makefile" | "mk" => "Make",
        "cmake" => "CMake",

        _ => "Other",
    }
    .to_string()
}
703
704fn get_project_context() -> Option<String> {
706 use std::path::Path;
707
708 if let Ok(repo_root) = crate::git::get_repo_root() {
710 let context_path = Path::new(&repo_root).join(".rco").join("context.txt");
711 if context_path.exists() {
712 if let Ok(content) = std::fs::read_to_string(&context_path) {
713 let trimmed = content.trim();
714 if !trimmed.is_empty() {
715 return Some(trimmed.to_string());
716 }
717 }
718 }
719
720 let readme_path = Path::new(&repo_root).join("README.md");
722 if readme_path.exists() {
723 if let Ok(content) = std::fs::read_to_string(&readme_path) {
724 for line in content.lines() {
726 let trimmed = line.trim();
727 if !trimmed.is_empty() && !trimmed.starts_with('#') {
728 let context = if let Some(idx) = trimmed.find('.') {
730 trimmed[..idx + 1].to_string()
731 } else {
732 trimmed.chars().take(100).collect()
733 };
734 if !context.is_empty() {
735 return Some(context);
736 }
737 }
738 }
739 }
740 }
741
742 let cargo_path = Path::new(&repo_root).join("Cargo.toml");
744 if cargo_path.exists() {
745 if let Ok(content) = std::fs::read_to_string(&cargo_path) {
746 let mut in_package = false;
748 for line in content.lines() {
749 let trimmed = line.trim();
750 if trimmed == "[package]" {
751 in_package = true;
752 } else if trimmed.starts_with('[') && trimmed != "[package]" {
753 in_package = false;
754 } else if in_package && trimmed.starts_with("description") {
755 if let Some(idx) = trimmed.find('=') {
756 let desc = trimmed[idx + 1..].trim().trim_matches('"');
757 if !desc.is_empty() {
758 return Some(format!("Rust project: {}", desc));
759 }
760 }
761 }
762 }
763 }
764 }
765
766 let package_path = Path::new(&repo_root).join("package.json");
768 if package_path.exists() {
769 if let Ok(content) = std::fs::read_to_string(&package_path) {
770 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
771 if let Some(desc) = json.get("description").and_then(|d| d.as_str()) {
772 if !desc.is_empty() {
773 return Some(format!("Node.js project: {}", desc));
774 }
775 }
776 }
777 }
778 }
779 }
780
781 None
782}
783
784pub fn build_prompt(
786 diff: &str,
787 context: Option<&str>,
788 config: &Config,
789 full_gitmoji: bool,
790) -> String {
791 let (system, user) = split_prompt(diff, context, config, full_gitmoji);
792 format!("{}\n\n---\n\n{}", system, user)
793}
794
795#[allow(dead_code)]
797pub fn create_provider_for_account(
798 account: &AccountConfig,
799 config: &Config,
800) -> Result<Box<dyn AIProvider>> {
801 use crate::auth::token_storage;
802 use crate::config::secure_storage;
803
804 let provider = account.provider.to_lowercase();
805
806 let credentials = match &account.auth {
808 crate::config::accounts::AuthMethod::ApiKey { key_id } => {
809 token_storage::get_api_key_for_account(key_id)?
811 .or_else(|| secure_storage::get_secret(key_id).ok().flatten())
812 }
813 crate::config::accounts::AuthMethod::OAuth {
814 provider: _oauth_provider,
815 account_id,
816 } => {
817 token_storage::get_tokens_for_account(account_id)?.map(|t| t.access_token)
819 }
820 crate::config::accounts::AuthMethod::EnvVar { name } => std::env::var(name).ok(),
821 crate::config::accounts::AuthMethod::Bearer { token_id } => {
822 token_storage::get_bearer_token_for_account(token_id)?
824 .or_else(|| secure_storage::get_secret(token_id).ok().flatten())
825 }
826 };
827
828 match provider.as_str() {
829 #[cfg(feature = "openai")]
830 "openai" | "codex" => {
831 if let Some(key) = credentials.as_ref() {
832 Ok(Box::new(openai::OpenAIProvider::from_account(
833 account, key, config,
834 )?))
835 } else {
836 Ok(Box::new(openai::OpenAIProvider::new(config)?))
837 }
838 }
839 #[cfg(feature = "anthropic")]
840 "anthropic" | "claude" | "claude-code" => {
841 if let Some(key) = credentials.as_ref() {
842 Ok(Box::new(anthropic::AnthropicProvider::from_account(
843 account, key, config,
844 )?))
845 } else {
846 Ok(Box::new(anthropic::AnthropicProvider::new(config)?))
847 }
848 }
849 #[cfg(feature = "ollama")]
850 "ollama" => {
851 if let Some(key) = credentials.as_ref() {
852 Ok(Box::new(ollama::OllamaProvider::from_account(
853 account, key, config,
854 )?))
855 } else {
856 Ok(Box::new(ollama::OllamaProvider::new(config)?))
857 }
858 }
859 #[cfg(feature = "gemini")]
860 "gemini" => {
861 if let Some(key) = credentials.as_ref() {
862 Ok(Box::new(gemini::GeminiProvider::from_account(
863 account, key, config,
864 )?))
865 } else {
866 Ok(Box::new(gemini::GeminiProvider::new(config)?))
867 }
868 }
869 #[cfg(feature = "azure")]
870 "azure" | "azure-openai" => {
871 if let Some(key) = credentials.as_ref() {
872 Ok(Box::new(azure::AzureProvider::from_account(
873 account, key, config,
874 )?))
875 } else {
876 Ok(Box::new(azure::AzureProvider::new(config)?))
877 }
878 }
879 #[cfg(feature = "perplexity")]
880 "perplexity" => {
881 if let Some(key) = credentials.as_ref() {
882 Ok(Box::new(perplexity::PerplexityProvider::from_account(
883 account, key, config,
884 )?))
885 } else {
886 Ok(Box::new(perplexity::PerplexityProvider::new(config)?))
887 }
888 }
889 #[cfg(feature = "xai")]
890 "xai" | "grok" | "x-ai" => {
891 if let Some(key) = credentials.as_ref() {
892 Ok(Box::new(xai::XAIProvider::from_account(
893 account, key, config,
894 )?))
895 } else {
896 Ok(Box::new(xai::XAIProvider::new(config)?))
897 }
898 }
899 #[cfg(feature = "huggingface")]
900 "huggingface" | "hf" => {
901 if let Some(key) = credentials.as_ref() {
902 Ok(Box::new(huggingface::HuggingFaceProvider::from_account(
903 account, key, config,
904 )?))
905 } else {
906 Ok(Box::new(huggingface::HuggingFaceProvider::new(config)?))
907 }
908 }
909 #[cfg(feature = "bedrock")]
910 "bedrock" | "aws-bedrock" | "amazon-bedrock" => Ok(Box::new(
911 bedrock::BedrockProvider::from_account(account, "", config)?,
912 )),
913 #[cfg(feature = "vertex")]
914 "vertex" | "vertex-ai" | "google-vertex" | "gcp-vertex" => Ok(Box::new(
915 vertex::VertexProvider::from_account(account, "", config)?,
916 )),
917 #[cfg(feature = "mlx")]
918 "mlx" | "mlx-lm" | "apple-mlx" => {
919 if let Some(_key) = credentials.as_ref() {
920 Ok(Box::new(mlx::MlxProvider::from_account(
921 account, "", config,
922 )?))
923 } else {
924 Ok(Box::new(mlx::MlxProvider::new(config)?))
925 }
926 }
927 #[cfg(feature = "nvidia")]
928 "nvidia" | "nvidia-nim" | "nim" | "nvidia-ai" => {
929 if let Some(key) = credentials.as_ref() {
930 Ok(Box::new(nvidia::NvidiaProvider::from_account(
931 account, key, config,
932 )?))
933 } else {
934 Ok(Box::new(nvidia::NvidiaProvider::new(config)?))
935 }
936 }
937 #[cfg(feature = "flowise")]
938 "flowise" | "flowise-ai" => {
939 if let Some(_key) = credentials.as_ref() {
940 Ok(Box::new(flowise::FlowiseProvider::from_account(
941 account, "", config,
942 )?))
943 } else {
944 Ok(Box::new(flowise::FlowiseProvider::new(config)?))
945 }
946 }
947 _ => {
948 anyhow::bail!(
949 "Unsupported AI provider for account: {}\n\n\
950 Account provider: {}\n\
951 Supported providers: openai, anthropic, ollama, gemini, azure, perplexity, xai, huggingface, bedrock, vertex",
952 account.alias,
953 provider
954 );
955 }
956 }
957}