1#[cfg(feature = "anthropic")]
3pub mod anthropic;
4#[cfg(feature = "azure")]
5pub mod azure;
6#[cfg(feature = "bedrock")]
7pub mod bedrock;
8#[cfg(feature = "gemini")]
9pub mod gemini;
10#[cfg(feature = "huggingface")]
11pub mod huggingface;
12#[cfg(feature = "flowise")]
13pub mod flowise;
14#[cfg(feature = "mlx")]
15pub mod mlx;
16#[cfg(feature = "nvidia")]
17pub mod nvidia;
18#[cfg(feature = "ollama")]
19pub mod ollama;
20#[cfg(feature = "openai")]
21pub mod openai;
22#[cfg(feature = "perplexity")]
23pub mod perplexity;
24#[cfg(feature = "vertex")]
25pub mod vertex;
26#[cfg(feature = "xai")]
27pub mod xai;
28
29pub mod registry;
31
32use crate::config::accounts::AccountConfig;
33use crate::config::Config;
34use anyhow::{Context, Result};
35use async_trait::async_trait;
36use once_cell::sync::Lazy;
37
38#[async_trait]
39pub trait AIProvider: Send + Sync {
40 async fn generate_commit_message(
41 &self,
42 diff: &str,
43 context: Option<&str>,
44 full_gitmoji: bool,
45 config: &Config,
46 ) -> Result<String>;
47
48 async fn generate_commit_messages(
50 &self,
51 diff: &str,
52 context: Option<&str>,
53 full_gitmoji: bool,
54 config: &Config,
55 count: u8,
56 ) -> Result<Vec<String>> {
57 use futures::stream::StreamExt;
58
59 if count <= 1 {
60 match self
62 .generate_commit_message(diff, context, full_gitmoji, config)
63 .await
64 {
65 Ok(msg) => Ok(vec![msg]),
66 Err(e) => {
67 tracing::warn!("Failed to generate message: {}", e);
68 Ok(vec![])
69 }
70 }
71 } else {
72 let futures = (0..count)
74 .map(|_| self.generate_commit_message(diff, context, full_gitmoji, config));
75 let mut stream = futures::stream::FuturesUnordered::from_iter(futures);
76
77 let mut messages = Vec::with_capacity(count as usize);
78 while let Some(result) = stream.next().await {
79 match result {
80 Ok(msg) => messages.push(msg),
81 Err(e) => tracing::warn!("Failed to generate message: {}", e),
82 }
83 }
84 Ok(messages)
85 }
86 }
87
88 #[cfg(any(feature = "openai", feature = "xai"))]
90 async fn generate_pr_description(
91 &self,
92 commits: &[String],
93 diff: &str,
94 config: &Config,
95 ) -> Result<String> {
96 let commits_text = commits.join("\n");
97 let prompt = format!(
98 "Generate a professional pull request description based on the following commits:\n\n{}\n\nDiff:\n{}\n\nFormat the output as:\n## Summary\n## Changes\n## Testing\n## Breaking Changes\n\nKeep it concise and informative.",
99 commits_text, diff
100 );
101
102 let messages = vec![
103 async_openai::types::chat::ChatCompletionRequestSystemMessage::from(
104 "You are an expert at writing pull request descriptions.",
105 )
106 .into(),
107 async_openai::types::chat::ChatCompletionRequestUserMessage::from(prompt).into(),
108 ];
109
110 let request = async_openai::types::chat::CreateChatCompletionRequestArgs::default()
111 .model(
112 config
113 .model
114 .clone()
115 .unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
116 )
117 .messages(messages)
118 .temperature(0.7)
119 .max_tokens(config.tokens_max_output.unwrap_or(1000) as u16)
120 .build()?;
121
122 let api_key = config
124 .api_key
125 .as_ref()
126 .context("API key not configured. Run: rco config set RCO_API_KEY=<your_key>")?;
127 let api_url = config
128 .api_url
129 .as_deref()
130 .unwrap_or("https://api.openai.com/v1");
131
132 let openai_config = async_openai::config::OpenAIConfig::new()
133 .with_api_key(api_key)
134 .with_api_base(api_url);
135
136 let client = async_openai::Client::with_config(openai_config);
137
138 let response = client.chat().create(request).await?;
139
140 let message = response
141 .choices
142 .first()
143 .and_then(|choice| choice.message.content.as_ref())
144 .context("AI returned an empty response")?
145 .trim()
146 .to_string();
147
148 Ok(message)
149 }
150
151 #[cfg(not(any(feature = "openai", feature = "xai")))]
153 async fn generate_pr_description(
154 &self,
155 _commits: &[String],
156 _diff: &str,
157 _config: &Config,
158 ) -> Result<String> {
159 anyhow::bail!(
160 "PR description generation requires the 'openai' or 'xai' feature to be enabled"
161 );
162 }
163}
164
165pub static PROVIDER_REGISTRY: Lazy<registry::ProviderRegistry> = Lazy::new(|| {
167 let reg = registry::ProviderRegistry::new();
168
169 #[cfg(feature = "openai")]
171 {
172 let _ = reg.register(Box::new(openai::OpenAICompatibleProvider::new()));
173 }
174
175 #[cfg(feature = "anthropic")]
177 {
178 let _ = reg.register(Box::new(anthropic::AnthropicProviderBuilder));
179 }
180
181 #[cfg(feature = "ollama")]
182 {
183 let _ = reg.register(Box::new(ollama::OllamaProviderBuilder));
184 }
185
186 #[cfg(feature = "gemini")]
187 {
188 let _ = reg.register(Box::new(gemini::GeminiProviderBuilder));
189 }
190
191 #[cfg(feature = "azure")]
192 {
193 let _ = reg.register(Box::new(azure::AzureProviderBuilder));
194 }
195
196 #[cfg(feature = "perplexity")]
197 {
198 let _ = reg.register(Box::new(perplexity::PerplexityProviderBuilder));
199 }
200
201 #[cfg(feature = "xai")]
202 {
203 let _ = reg.register(Box::new(xai::XAIProviderBuilder));
204 }
205
206 #[cfg(feature = "huggingface")]
207 {
208 let _ = reg.register(Box::new(huggingface::HuggingFaceProviderBuilder));
209 }
210
211 #[cfg(feature = "bedrock")]
212 {
213 let _ = reg.register(Box::new(bedrock::BedrockProviderBuilder));
214 }
215
216 #[cfg(feature = "vertex")]
217 {
218 let _ = reg.register(Box::new(vertex::VertexProviderBuilder));
219 }
220
221 #[cfg(feature = "mlx")]
222 {
223 let _ = reg.register(Box::new(mlx::MlxProviderBuilder));
224 }
225
226 #[cfg(feature = "nvidia")]
227 {
228 let _ = reg.register(Box::new(nvidia::NvidiaProviderBuilder));
229 }
230
231 #[cfg(feature = "flowise")]
232 {
233 let _ = reg.register(Box::new(flowise::FlowiseProviderBuilder));
234 }
235
236 reg
237});
238
239pub fn create_provider(config: &Config) -> Result<Box<dyn AIProvider>> {
241 let provider_name = config.ai_provider.as_deref().unwrap_or("openai");
242
243 if let Some(provider) = PROVIDER_REGISTRY.create(provider_name, config)? {
245 return Ok(provider);
246 }
247
248 let available: Vec<String> = PROVIDER_REGISTRY
250 .all()
251 .unwrap_or_default()
252 .iter()
253 .map(|e| {
254 let aliases = if e.aliases.is_empty() {
255 String::new()
256 } else {
257 format!(" ({})", e.aliases.join(", "))
258 };
259 format!("- {}{}", e.name, aliases)
260 })
261 .chain(std::iter::once(format!(
262 "- {} OpenAI-compatible providers (deepseek, groq, openrouter, etc.)",
263 PROVIDER_REGISTRY
264 .by_category(registry::ProviderCategory::OpenAICompatible)
265 .map_or(0, |v| v.len())
266 )))
267 .filter(|s| !s.contains("0 OpenAI-compatible"))
268 .collect();
269
270 if available.is_empty() {
271 anyhow::bail!(
272 "No AI provider features enabled. Please enable at least one provider feature:\n\
273 --features openai,anthropic,ollama,gemini,azure,perplexity,xai,huggingface,bedrock,vertex"
274 );
275 }
276
277 anyhow::bail!(
278 "Unsupported or disabled AI provider: {}\n\n\
279 Available providers (based on enabled features):\n{}\n\n\
280 Set with: rco config set RCO_AI_PROVIDER=<provider_name>",
281 provider_name,
282 available.join("\n")
283 )
284}
285
286#[allow(dead_code)]
287pub fn available_providers() -> Vec<&'static str> {
289 let mut providers = PROVIDER_REGISTRY
290 .all()
291 .unwrap_or_default()
292 .iter()
293 .flat_map(|e| std::iter::once(e.name).chain(e.aliases.iter().copied()))
294 .collect::<Vec<_>>();
295
296 #[cfg(feature = "openai")]
297 {
298 providers.extend_from_slice(&[
299 "deepseek",
301 "groq",
302 "openrouter",
303 "together",
304 "deepinfra",
305 "mistral",
306 "github-models",
307 "fireworks",
308 "moonshot",
309 "dashscope",
310 "aimlapi",
312 "cohere",
314 "ai21",
315 "cloudflare",
316 "siliconflow",
317 "zhipu",
318 "minimax",
319 "upstage",
320 "nebius",
321 "ovh",
322 "scaleway",
323 "friendli",
324 "baseten",
325 "chutes",
326 "ionet",
327 "modelscope",
328 "requesty",
329 "morph",
330 "synthetic",
331 "nano-gpt",
332 "zenmux",
333 "v0",
334 "iflowcn",
335 "venice",
336 "cortecs",
337 "kimi-coding",
338 "abacus",
339 "bailing",
340 "fastrouter",
341 "inference",
342 "submodel",
343 "zai",
344 "zai-coding",
345 "zhipu-coding",
346 "poe",
347 "cerebras",
348 "lmstudio",
349 "sambanova",
350 "novita",
351 "predibase",
352 "tensorops",
353 "hyperbolic",
354 "kluster",
355 "lambda",
356 "replicate",
357 "targon",
358 "corcel",
359 "cybernative",
360 "edgen",
361 "gigachat",
362 "hydra",
363 "jina",
364 "lingyi",
365 "monica",
366 "pollinations",
367 "rawechat",
368 "shuttleai",
369 "teknium",
370 "theb",
371 "tryleap",
372 "workers-ai",
373 ]);
374 }
375
376 providers
377}
378
379#[allow(dead_code)]
381pub fn provider_info(provider: &str) -> Option<String> {
382 PROVIDER_REGISTRY.get(provider).map(|e| {
383 let aliases = if e.aliases.is_empty() {
384 String::new()
385 } else {
386 format!(" (aliases: {})", e.aliases.join(", "))
387 };
388 let model = e
389 .default_model
390 .map(|m| format!(", default model: {}", m))
391 .unwrap_or_default();
392 format!("{}{}{}", e.name, aliases, model)
393 })
394}
395
396pub fn split_prompt(
398 diff: &str,
399 context: Option<&str>,
400 config: &Config,
401 full_gitmoji: bool,
402) -> (String, String) {
403 let system_prompt = build_system_prompt(config, full_gitmoji);
404 let user_prompt = build_user_prompt(diff, context, full_gitmoji, config);
405 (system_prompt, user_prompt)
406}
407
408fn build_system_prompt(config: &Config, full_gitmoji: bool) -> String {
410 let mut prompt = String::new();
411
412 prompt.push_str("You are an expert at writing clear, concise git commit messages.\n\n");
413
414 prompt.push_str("OUTPUT RULES:\n");
416 prompt.push_str("- Return ONLY the commit message, with no additional explanation, markdown formatting, or code blocks\n");
417 prompt.push_str("- Do not include any reasoning, thinking, analysis, <thinking> tags, or XML-like tags in your response\n");
418 prompt.push_str("- Never explain your choices or provide commentary\n");
419 prompt.push_str("- If you cannot generate a meaningful commit message, return \"chore: update\"\n\n");
420
421 if config.learn_from_history.unwrap_or(false) {
423 if let Some(style_guidance) = get_style_guidance(config) {
424 prompt.push_str("REPO STYLE (learned from commit history):\n");
425 prompt.push_str(&style_guidance);
426 prompt.push('\n');
427 }
428 }
429
430 if let Some(locale) = &config.language {
432 prompt.push_str(&format!(
433 "- Generate the commit message in {} language\n",
434 locale
435 ));
436 }
437
438 let commit_type = config.commit_type.as_deref().unwrap_or("conventional");
440 match commit_type {
441 "conventional" => {
442 prompt.push_str("- Use conventional commit format: <type>(<scope>): <description>\n");
443 prompt.push_str(
444 "- Types: feat, fix, docs, style, refactor, perf, test, build, ci, chore\n",
445 );
446 if config.omit_scope.unwrap_or(false) {
447 prompt.push_str("- Omit the scope, use format: <type>: <description>\n");
448 }
449 }
450 "gitmoji" => {
451 if full_gitmoji {
452 prompt.push_str("- Use GitMoji format with full emoji specification from https://gitmoji.dev/\n");
453 prompt.push_str("- Common emojis: ✨(feat), 🐛(fix), 📝(docs), 🚀(deploy), ♻️(refactor), ✅(test), 🔧(chore), ⚡(perf), 🎨(style), 📦(build), 👷(ci)\n");
454 prompt.push_str("- For breaking changes, add 💥 after the type\n");
455 } else {
456 prompt.push_str("- Use GitMoji format: <emoji> <type>: <description>\n");
457 prompt.push_str("- Common emojis: 🐛(fix), ✨(feat), 📝(docs), 🚀(deploy), ✅(test), ♻️(refactor), 🔧(chore), ⚡(perf), 🎨(style), 📦(build), 👷(ci)\n");
458 }
459 }
460 _ => {}
461 }
462
463 let max_length = config.description_max_length.unwrap_or(100);
465 prompt.push_str(&format!(
466 "- Keep the description under {} characters\n",
467 max_length
468 ));
469
470 if config.description_capitalize.unwrap_or(true) {
471 prompt.push_str("- Capitalize the first letter of the description\n");
472 }
473
474 if !config.description_add_period.unwrap_or(false) {
475 prompt.push_str("- Do not end the description with a period\n");
476 }
477
478 if config.enable_commit_body.unwrap_or(false) {
480 prompt.push_str("\nCOMMIT BODY (optional):\n");
481 prompt.push_str("- Add a blank line after the description, then explain WHY the change was made\n");
482 prompt.push_str("- Use bullet points for multiple changes\n");
483 prompt.push_str("- Wrap body text at 72 characters\n");
484 prompt.push_str("- Focus on motivation and context, not what changed (that's in the diff)\n");
485 }
486
487 prompt
488}
489
490fn get_style_guidance(config: &Config) -> Option<String> {
492 use crate::git;
493 use crate::utils::commit_style::CommitStyleProfile;
494
495 if let Some(cached) = &config.style_profile {
497 return Some(cached.clone());
499 }
500
501 let count = config.history_commits_count.unwrap_or(50);
503
504 match git::get_recent_commit_messages(count) {
505 Ok(commits) => {
506 if commits.is_empty() {
507 return None;
508 }
509
510 let profile = CommitStyleProfile::analyze_from_commits(&commits);
511
512 if profile.is_empty() || commits.len() < 10 {
515 return None;
516 }
517
518 Some(profile.to_prompt_guidance())
519 }
520 Err(e) => {
521 tracing::warn!("Failed to get commit history for style analysis: {}", e);
522 None
523 }
524 }
525}
526
527fn build_user_prompt(diff: &str, context: Option<&str>, _full_gitmoji: bool, _config: &Config) -> String {
529 let mut prompt = String::new();
530
531 if let Some(project_context) = get_project_context() {
533 prompt.push_str(&format!("Project Context: {}\n\n", project_context));
534 }
535
536 let file_summary = extract_file_summary(diff);
538 if !file_summary.is_empty() {
539 prompt.push_str(&format!("Files Changed: {}\n\n", file_summary));
540 }
541
542 if diff.contains("---CHUNK") {
544 let chunk_count = diff.matches("---CHUNK").count();
545 if chunk_count > 1 {
546 prompt.push_str(&format!(
547 "Note: This diff was split into {} chunks due to size. Focus on the overall purpose of the changes across all chunks.\n\n",
548 chunk_count
549 ));
550 } else {
551 prompt.push_str("Note: The diff was split into chunks due to size. Focus on the overall purpose of the changes.\n\n");
552 }
553 }
554
555 if let Some(ctx) = context {
557 prompt.push_str(&format!("Additional context: {}\n\n", ctx));
558 }
559
560 prompt.push_str("Generate a commit message for the following git diff:\n");
561 prompt.push_str("```diff\n");
562 prompt.push_str(diff);
563 prompt.push_str("\n```\n");
564
565 prompt.push_str("\nRemember: Return ONLY the commit message, no explanations or markdown.");
567
568 prompt
569}
570
571fn extract_file_summary(diff: &str) -> String {
573 let mut files: Vec<String> = Vec::new();
574 let mut extensions: std::collections::HashSet<String> = std::collections::HashSet::new();
575 let mut file_types: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
576
577 for line in diff.lines() {
578 if line.starts_with("+++ b/") {
579 let path = line.strip_prefix("+++ b/").unwrap_or(line);
580 if path != "/dev/null" {
581 files.push(path.to_string());
582 if let Some(ext) = std::path::Path::new(path).extension() {
584 if let Some(ext_str) = ext.to_str() {
585 let ext_lower = ext_str.to_lowercase();
586 extensions.insert(ext_lower.clone());
587
588 let category = categorize_file_type(&ext_lower);
590 *file_types.entry(category).or_insert(0) += 1;
591 }
592 } else {
593 if path.contains("Makefile") || path.contains("Dockerfile") || path.contains("LICENSE") {
595 *file_types.entry("config".to_string()).or_insert(0) += 1;
596 }
597 }
598 }
599 }
600 }
601
602 if files.is_empty() {
603 return String::new();
604 }
605
606 let mut summary = format!("{} file(s)", files.len());
608
609 if !file_types.is_empty() {
611 let mut type_list: Vec<_> = file_types.into_iter().collect();
612 type_list.sort_by(|a, b| b.1.cmp(&a.1)); let type_str: Vec<_> = type_list.iter()
615 .map(|(t, c)| format!("{} {}", c, t))
616 .collect();
617 summary.push_str(&format!(" ({})", type_str.join(", ")));
618 }
619
620 if !extensions.is_empty() && extensions.len() <= 5 {
622 let ext_list: Vec<_> = extensions.into_iter().collect();
623 summary.push_str(&format!(" [.{}]", ext_list.join(", .")));
624 }
625
626 if files.len() <= 3 {
628 summary.push_str(&format!(": {}", files.join(", ")));
629 }
630
631 summary
632}
633
/// Maps a file extension to a human-readable language or format label.
///
/// Matching is exact and case-sensitive, so callers pass lowercase extensions.
/// Unknown extensions map to `"Other"`.
fn categorize_file_type(ext: &str) -> String {
    // (label, extensions) pairs; every extension appears at most once.
    const CATEGORIES: &[(&str, &[&str])] = &[
        // Programming languages
        ("Rust", &["rs"]),
        ("Python", &["py"]),
        ("JavaScript", &["js"]),
        ("TypeScript", &["ts"]),
        ("React", &["jsx", "tsx"]),
        ("Go", &["go"]),
        ("Java", &["java"]),
        ("Kotlin", &["kt"]),
        ("Swift", &["swift"]),
        ("C/C++", &["c", "cpp", "cc", "h", "hpp"]),
        ("Ruby", &["rb"]),
        ("PHP", &["php"]),
        ("C#", &["cs"]),
        ("Scala", &["scala"]),
        ("R", &["r"]),
        ("Objective-C", &["m"]),
        ("Lua", &["lua"]),
        ("Perl", &["pl"]),
        // Web
        ("HTML", &["html", "htm"]),
        ("CSS", &["css", "scss", "sass", "less"]),
        ("Vue", &["vue"]),
        ("Svelte", &["svelte"]),
        // Data formats
        ("JSON", &["json"]),
        ("YAML", &["yaml", "yml"]),
        ("TOML", &["toml"]),
        ("XML", &["xml"]),
        ("CSV", &["csv"]),
        ("SQL", &["sql"]),
        // Documentation
        ("Markdown", &["md", "markdown"]),
        ("reStructuredText", &["rst"]),
        ("Text", &["txt"]),
        // Scripts and build files
        ("Shell", &["sh", "bash", "zsh", "fish"]),
        ("PowerShell", &["ps1"]),
        ("Batch", &["bat", "cmd"]),
        ("Docker", &["dockerfile"]),
        ("Make", &["makefile", "mk"]),
        ("CMake", &["cmake"]),
    ];

    CATEGORIES
        .iter()
        .find(|(_, exts)| exts.iter().any(|&candidate| candidate == ext))
        .map_or("Other", |&(label, _)| label)
        .to_string()
}
688
689fn get_project_context() -> Option<String> {
691 use std::path::Path;
692
693 if let Ok(repo_root) = crate::git::get_repo_root() {
695 let context_path = Path::new(&repo_root).join(".rco").join("context.txt");
696 if context_path.exists() {
697 if let Ok(content) = std::fs::read_to_string(&context_path) {
698 let trimmed = content.trim();
699 if !trimmed.is_empty() {
700 return Some(trimmed.to_string());
701 }
702 }
703 }
704
705 let readme_path = Path::new(&repo_root).join("README.md");
707 if readme_path.exists() {
708 if let Ok(content) = std::fs::read_to_string(&readme_path) {
709 for line in content.lines() {
711 let trimmed = line.trim();
712 if !trimmed.is_empty() && !trimmed.starts_with('#') {
713 let context = if let Some(idx) = trimmed.find('.') {
715 trimmed[..idx + 1].to_string()
716 } else {
717 trimmed.chars().take(100).collect()
718 };
719 if !context.is_empty() {
720 return Some(context);
721 }
722 }
723 }
724 }
725 }
726
727 let cargo_path = Path::new(&repo_root).join("Cargo.toml");
729 if cargo_path.exists() {
730 if let Ok(content) = std::fs::read_to_string(&cargo_path) {
731 let mut in_package = false;
733 for line in content.lines() {
734 let trimmed = line.trim();
735 if trimmed == "[package]" {
736 in_package = true;
737 } else if trimmed.starts_with('[') && trimmed != "[package]" {
738 in_package = false;
739 } else if in_package && trimmed.starts_with("description") {
740 if let Some(idx) = trimmed.find('=') {
741 let desc = trimmed[idx+1..].trim().trim_matches('"');
742 if !desc.is_empty() {
743 return Some(format!("Rust project: {}", desc));
744 }
745 }
746 }
747 }
748 }
749 }
750
751 let package_path = Path::new(&repo_root).join("package.json");
753 if package_path.exists() {
754 if let Ok(content) = std::fs::read_to_string(&package_path) {
755 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
756 if let Some(desc) = json.get("description").and_then(|d| d.as_str()) {
757 if !desc.is_empty() {
758 return Some(format!("Node.js project: {}", desc));
759 }
760 }
761 }
762 }
763 }
764 }
765
766 None
767}
768
769pub fn build_prompt(
771 diff: &str,
772 context: Option<&str>,
773 config: &Config,
774 full_gitmoji: bool,
775) -> String {
776 let (system, user) = split_prompt(diff, context, config, full_gitmoji);
777 format!("{}\n\n---\n\n{}", system, user)
778}
779
780#[allow(dead_code)]
782pub fn create_provider_for_account(
783 account: &AccountConfig,
784 config: &Config,
785) -> Result<Box<dyn AIProvider>> {
786 use crate::auth::token_storage;
787 use crate::config::secure_storage;
788
789 let provider = account.provider.to_lowercase();
790
791 let credentials = match &account.auth {
793 crate::config::accounts::AuthMethod::ApiKey { key_id } => {
794 token_storage::get_api_key_for_account(key_id)?
796 .or_else(|| secure_storage::get_secret(key_id).ok().flatten())
797 }
798 crate::config::accounts::AuthMethod::OAuth {
799 provider: _oauth_provider,
800 account_id,
801 } => {
802 token_storage::get_tokens_for_account(account_id)?.map(|t| t.access_token)
804 }
805 crate::config::accounts::AuthMethod::EnvVar { name } => std::env::var(name).ok(),
806 crate::config::accounts::AuthMethod::Bearer { token_id } => {
807 token_storage::get_bearer_token_for_account(token_id)?
809 .or_else(|| secure_storage::get_secret(token_id).ok().flatten())
810 }
811 };
812
813 match provider.as_str() {
814 #[cfg(feature = "openai")]
815 "openai" | "codex" => {
816 if let Some(key) = credentials.as_ref() {
817 Ok(Box::new(openai::OpenAIProvider::from_account(
818 account, key, config,
819 )?))
820 } else {
821 Ok(Box::new(openai::OpenAIProvider::new(config)?))
822 }
823 }
824 #[cfg(feature = "anthropic")]
825 "anthropic" | "claude" | "claude-code" => {
826 if let Some(key) = credentials.as_ref() {
827 Ok(Box::new(anthropic::AnthropicProvider::from_account(
828 account, key, config,
829 )?))
830 } else {
831 Ok(Box::new(anthropic::AnthropicProvider::new(config)?))
832 }
833 }
834 #[cfg(feature = "ollama")]
835 "ollama" => {
836 if let Some(key) = credentials.as_ref() {
837 Ok(Box::new(ollama::OllamaProvider::from_account(
838 account, key, config,
839 )?))
840 } else {
841 Ok(Box::new(ollama::OllamaProvider::new(config)?))
842 }
843 }
844 #[cfg(feature = "gemini")]
845 "gemini" => {
846 if let Some(key) = credentials.as_ref() {
847 Ok(Box::new(gemini::GeminiProvider::from_account(
848 account, key, config,
849 )?))
850 } else {
851 Ok(Box::new(gemini::GeminiProvider::new(config)?))
852 }
853 }
854 #[cfg(feature = "azure")]
855 "azure" | "azure-openai" => {
856 if let Some(key) = credentials.as_ref() {
857 Ok(Box::new(azure::AzureProvider::from_account(
858 account, key, config,
859 )?))
860 } else {
861 Ok(Box::new(azure::AzureProvider::new(config)?))
862 }
863 }
864 #[cfg(feature = "perplexity")]
865 "perplexity" => {
866 if let Some(key) = credentials.as_ref() {
867 Ok(Box::new(perplexity::PerplexityProvider::from_account(
868 account, key, config,
869 )?))
870 } else {
871 Ok(Box::new(perplexity::PerplexityProvider::new(config)?))
872 }
873 }
874 #[cfg(feature = "xai")]
875 "xai" | "grok" | "x-ai" => {
876 if let Some(key) = credentials.as_ref() {
877 Ok(Box::new(xai::XAIProvider::from_account(
878 account, key, config,
879 )?))
880 } else {
881 Ok(Box::new(xai::XAIProvider::new(config)?))
882 }
883 }
884 #[cfg(feature = "huggingface")]
885 "huggingface" | "hf" => {
886 if let Some(key) = credentials.as_ref() {
887 Ok(Box::new(huggingface::HuggingFaceProvider::from_account(
888 account, key, config,
889 )?))
890 } else {
891 Ok(Box::new(huggingface::HuggingFaceProvider::new(config)?))
892 }
893 }
894 #[cfg(feature = "bedrock")]
895 "bedrock" | "aws-bedrock" | "amazon-bedrock" => Ok(Box::new(
896 bedrock::BedrockProvider::from_account(account, "", config)?,
897 )),
898 #[cfg(feature = "vertex")]
899 "vertex" | "vertex-ai" | "google-vertex" | "gcp-vertex" => Ok(Box::new(
900 vertex::VertexProvider::from_account(account, "", config)?,
901 )),
902 #[cfg(feature = "mlx")]
903 "mlx" | "mlx-lm" | "apple-mlx" => {
904 if let Some(_key) = credentials.as_ref() {
905 Ok(Box::new(mlx::MlxProvider::from_account(
906 account, "", config,
907 )?))
908 } else {
909 Ok(Box::new(mlx::MlxProvider::new(config)?))
910 }
911 }
912 #[cfg(feature = "nvidia")]
913 "nvidia" | "nvidia-nim" | "nim" | "nvidia-ai" => {
914 if let Some(key) = credentials.as_ref() {
915 Ok(Box::new(nvidia::NvidiaProvider::from_account(
916 account, key, config,
917 )?))
918 } else {
919 Ok(Box::new(nvidia::NvidiaProvider::new(config)?))
920 }
921 }
922 #[cfg(feature = "flowise")]
923 "flowise" | "flowise-ai" => {
924 if let Some(_key) = credentials.as_ref() {
925 Ok(Box::new(flowise::FlowiseProvider::from_account(
926 account, "", config,
927 )?))
928 } else {
929 Ok(Box::new(flowise::FlowiseProvider::new(config)?))
930 }
931 }
932 _ => {
933 anyhow::bail!(
934 "Unsupported AI provider for account: {}\n\n\
935 Account provider: {}\n\
936 Supported providers: openai, anthropic, ollama, gemini, azure, perplexity, xai, huggingface, bedrock, vertex",
937 account.alias,
938 provider
939 );
940 }
941 }
942}