1#[cfg(feature = "anthropic")]
3pub mod anthropic;
4#[cfg(feature = "azure")]
5pub mod azure;
6#[cfg(feature = "bedrock")]
7pub mod bedrock;
8#[cfg(feature = "flowise")]
9pub mod flowise;
10#[cfg(feature = "gemini")]
11pub mod gemini;
12#[cfg(feature = "huggingface")]
13pub mod huggingface;
14#[cfg(feature = "mlx")]
15pub mod mlx;
16#[cfg(feature = "nvidia")]
17pub mod nvidia;
18#[cfg(feature = "ollama")]
19pub mod ollama;
20#[cfg(feature = "openai")]
21pub mod openai;
22#[cfg(feature = "perplexity")]
23pub mod perplexity;
24#[cfg(feature = "vertex")]
25pub mod vertex;
26#[cfg(feature = "xai")]
27pub mod xai;
28
29pub mod registry;
31
32use crate::config::accounts::AccountConfig;
33use crate::config::Config;
34use anyhow::{Context, Result};
35use async_trait::async_trait;
36use once_cell::sync::Lazy;
37
38#[async_trait]
39pub trait AIProvider: Send + Sync {
40 async fn generate_commit_message(
41 &self,
42 diff: &str,
43 context: Option<&str>,
44 full_gitmoji: bool,
45 config: &Config,
46 ) -> Result<String>;
47
48 async fn generate_commit_messages(
50 &self,
51 diff: &str,
52 context: Option<&str>,
53 full_gitmoji: bool,
54 config: &Config,
55 count: u8,
56 ) -> Result<Vec<String>> {
57 use futures::stream::StreamExt;
58
59 if count <= 1 {
60 match self
62 .generate_commit_message(diff, context, full_gitmoji, config)
63 .await
64 {
65 Ok(msg) => Ok(vec![msg]),
66 Err(e) => {
67 tracing::warn!("Failed to generate message: {}", e);
68 Ok(vec![])
69 }
70 }
71 } else {
72 let futures = (0..count)
74 .map(|_| self.generate_commit_message(diff, context, full_gitmoji, config));
75 let mut stream = futures::stream::FuturesUnordered::from_iter(futures);
76
77 let mut messages = Vec::with_capacity(count as usize);
78 while let Some(result) = stream.next().await {
79 match result {
80 Ok(msg) => messages.push(msg),
81 Err(e) => tracing::warn!("Failed to generate message: {}", e),
82 }
83 }
84 Ok(messages)
85 }
86 }
87
88 #[cfg(any(feature = "openai", feature = "xai"))]
90 async fn generate_pr_description(
91 &self,
92 commits: &[String],
93 diff: &str,
94 config: &Config,
95 ) -> Result<String> {
96 let commits_text = commits.join("\n");
97 let prompt = format!(
98 "Generate a professional pull request description based on the following commits:\n\n{}\n\nDiff:\n{}\n\nFormat the output as:\n## Summary\n## Changes\n## Testing\n## Breaking Changes\n\nKeep it concise and informative.",
99 commits_text, diff
100 );
101
102 let messages = vec![
103 async_openai::types::chat::ChatCompletionRequestSystemMessage::from(
104 "You are an expert at writing pull request descriptions.",
105 )
106 .into(),
107 async_openai::types::chat::ChatCompletionRequestUserMessage::from(prompt).into(),
108 ];
109
110 let request = async_openai::types::chat::CreateChatCompletionRequestArgs::default()
111 .model(
112 config
113 .model
114 .clone()
115 .unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
116 )
117 .messages(messages)
118 .temperature(0.7)
119 .max_tokens(config.tokens_max_output.unwrap_or(1000) as u16)
120 .build()?;
121
122 let api_key = config
124 .api_key
125 .as_ref()
126 .context("API key not configured. Run: rco config set RCO_API_KEY=<your_key>")?;
127 let api_url = config
128 .api_url
129 .as_deref()
130 .unwrap_or("https://api.openai.com/v1");
131
132 let openai_config = async_openai::config::OpenAIConfig::new()
133 .with_api_key(api_key)
134 .with_api_base(api_url);
135
136 let client = async_openai::Client::with_config(openai_config);
137
138 let response = client.chat().create(request).await?;
139
140 let message = response
141 .choices
142 .first()
143 .and_then(|choice| choice.message.content.as_ref())
144 .context("AI returned an empty response")?
145 .trim()
146 .to_string();
147
148 Ok(message)
149 }
150
151 #[cfg(not(any(feature = "openai", feature = "xai")))]
153 async fn generate_pr_description(
154 &self,
155 _commits: &[String],
156 _diff: &str,
157 _config: &Config,
158 ) -> Result<String> {
159 anyhow::bail!(
160 "PR description generation requires the 'openai' or 'xai' feature to be enabled"
161 );
162 }
163}
164
165pub static PROVIDER_REGISTRY: Lazy<registry::ProviderRegistry> = Lazy::new(|| {
167 let reg = registry::ProviderRegistry::new();
168
169 #[cfg(feature = "openai")]
171 {
172 let _ = reg.register(Box::new(openai::OpenAICompatibleProvider::new()));
173 }
174
175 #[cfg(feature = "anthropic")]
177 {
178 let _ = reg.register(Box::new(anthropic::AnthropicProviderBuilder));
179 }
180
181 #[cfg(feature = "ollama")]
182 {
183 let _ = reg.register(Box::new(ollama::OllamaProviderBuilder));
184 }
185
186 #[cfg(feature = "gemini")]
187 {
188 let _ = reg.register(Box::new(gemini::GeminiProviderBuilder));
189 }
190
191 #[cfg(feature = "azure")]
192 {
193 let _ = reg.register(Box::new(azure::AzureProviderBuilder));
194 }
195
196 #[cfg(feature = "perplexity")]
197 {
198 let _ = reg.register(Box::new(perplexity::PerplexityProviderBuilder));
199 }
200
201 #[cfg(feature = "xai")]
202 {
203 let _ = reg.register(Box::new(xai::XAIProviderBuilder));
204 }
205
206 #[cfg(feature = "huggingface")]
207 {
208 let _ = reg.register(Box::new(huggingface::HuggingFaceProviderBuilder));
209 }
210
211 #[cfg(feature = "bedrock")]
212 {
213 let _ = reg.register(Box::new(bedrock::BedrockProviderBuilder));
214 }
215
216 #[cfg(feature = "vertex")]
217 {
218 let _ = reg.register(Box::new(vertex::VertexProviderBuilder));
219 }
220
221 #[cfg(feature = "mlx")]
222 {
223 let _ = reg.register(Box::new(mlx::MlxProviderBuilder));
224 }
225
226 #[cfg(feature = "nvidia")]
227 {
228 let _ = reg.register(Box::new(nvidia::NvidiaProviderBuilder));
229 }
230
231 #[cfg(feature = "flowise")]
232 {
233 let _ = reg.register(Box::new(flowise::FlowiseProviderBuilder));
234 }
235
236 reg
237});
238
239pub fn create_provider(config: &Config) -> Result<Box<dyn AIProvider>> {
241 let provider_name = config.ai_provider.as_deref().unwrap_or("openai");
242
243 if let Some(provider) = PROVIDER_REGISTRY.create(provider_name, config)? {
245 return Ok(provider);
246 }
247
248 let available: Vec<String> = PROVIDER_REGISTRY
250 .all()
251 .unwrap_or_default()
252 .iter()
253 .map(|e| {
254 let aliases = if e.aliases.is_empty() {
255 String::new()
256 } else {
257 format!(" ({})", e.aliases.join(", "))
258 };
259 format!("- {}{}", e.name, aliases)
260 })
261 .chain(std::iter::once(format!(
262 "- {} OpenAI-compatible providers (deepseek, groq, openrouter, etc.)",
263 PROVIDER_REGISTRY
264 .by_category(registry::ProviderCategory::OpenAICompatible)
265 .map_or(0, |v| v.len())
266 )))
267 .filter(|s| !s.contains("0 OpenAI-compatible"))
268 .collect();
269
270 if available.is_empty() {
271 anyhow::bail!(
272 "No AI provider features enabled. Please enable at least one provider feature:\n\
273 --features openai,anthropic,ollama,gemini,azure,perplexity,xai,huggingface,bedrock,vertex"
274 );
275 }
276
277 anyhow::bail!(
278 "Unsupported or disabled AI provider: {}\n\n\
279 Available providers (based on enabled features):\n{}\n\n\
280 Set with: rco config set RCO_AI_PROVIDER=<provider_name>",
281 provider_name,
282 available.join("\n")
283 )
284}
285
286#[allow(dead_code)]
287pub fn available_providers() -> Vec<&'static str> {
289 let mut providers = PROVIDER_REGISTRY
290 .all()
291 .unwrap_or_default()
292 .iter()
293 .flat_map(|e| std::iter::once(e.name).chain(e.aliases.iter().copied()))
294 .collect::<Vec<_>>();
295
296 #[cfg(feature = "openai")]
297 {
298 providers.extend_from_slice(&[
299 "deepseek",
303 "groq",
304 "openrouter",
305 "together",
306 "deepinfra",
307 "mistral",
308 "github-models",
309 "fireworks",
310 "moonshot",
311 "dashscope",
312 "perplexity",
313 "cohere",
317 "cohere-ai",
318 "ai21",
319 "ai21-labs",
320 "upstage",
321 "upstage-ai",
322 "solar",
323 "solar-pro",
324 "nebius",
328 "nebius-ai",
329 "nebius-studio",
330 "ovh",
331 "ovhcloud",
332 "ovh-ai",
333 "scaleway",
334 "scaleway-ai",
335 "friendli",
336 "friendli-ai",
337 "baseten",
338 "baseten-ai",
339 "chutes",
340 "chutes-ai",
341 "ionet",
342 "io-net",
343 "modelscope",
344 "requesty",
345 "morph",
346 "morph-labs",
347 "synthetic",
348 "nano-gpt",
349 "nanogpt",
350 "zenmux",
351 "v0",
352 "v0-vercel",
353 "iflowcn",
354 "venice",
355 "venice-ai",
356 "cortecs",
357 "cortecs-ai",
358 "kimi-coding",
359 "abacus",
360 "abacus-ai",
361 "bailing",
362 "fastrouter",
363 "inference",
364 "inference-net",
365 "submodel",
366 "zai",
367 "zai-coding",
368 "zhipu-coding",
369 "poe",
370 "poe-ai",
371 "cerebras",
372 "cerebras-ai",
373 "sambanova",
374 "sambanova-ai",
375 "novita",
376 "novita-ai",
377 "predibase",
378 "tensorops",
379 "hyperbolic",
380 "hyperbolic-ai",
381 "kluster",
382 "kluster-ai",
383 "lambda",
384 "lambda-labs",
385 "replicate",
386 "targon",
387 "corcel",
388 "cybernative",
389 "cybernative-ai",
390 "edgen",
391 "gigachat",
392 "gigachat-ai",
393 "hydra",
394 "hydra-ai",
395 "jina",
396 "jina-ai",
397 "lingyi",
398 "lingyiwanwu",
399 "monica",
400 "monica-ai",
401 "pollinations",
402 "pollinations-ai",
403 "rawechat",
404 "shuttleai",
405 "shuttle-ai",
406 "teknium",
407 "theb",
408 "theb-ai",
409 "tryleap",
410 "leap-ai",
411 "lmstudio",
415 "lm-studio",
416 "llamacpp",
417 "llama-cpp",
418 "kobold",
419 "koboldcpp",
420 "textgen",
421 "text-generation",
422 "tabby",
423 "siliconflow",
427 "silicon-flow",
428 "zhipu",
429 "zhipu-ai",
430 "bigmodel",
431 "minimax",
432 "minimax-ai",
433 "glm",
434 "chatglm",
435 "baichuan",
436 "01-ai",
437 "yi",
438 "helicone",
442 "helicone-ai",
443 "workers-ai",
444 "cloudflare-ai",
445 "cloudflare-gateway",
446 "vercel-ai",
447 "vercel-gateway",
448 "302ai",
452 "302-ai",
453 "sap-ai",
454 "sap-ai-core",
455 "aimlapi",
459 "ai-ml-api",
460 ]);
461 }
462
463 providers
464}
465
466#[allow(dead_code)]
468pub fn provider_info(provider: &str) -> Option<String> {
469 PROVIDER_REGISTRY.get(provider).map(|e| {
470 let aliases = if e.aliases.is_empty() {
471 String::new()
472 } else {
473 format!(" (aliases: {})", e.aliases.join(", "))
474 };
475 let model = e
476 .default_model
477 .map(|m| format!(", default model: {}", m))
478 .unwrap_or_default();
479 format!("{}{}{}", e.name, aliases, model)
480 })
481}
482
483pub fn split_prompt(
485 diff: &str,
486 context: Option<&str>,
487 config: &Config,
488 full_gitmoji: bool,
489) -> (String, String) {
490 let system_prompt = build_system_prompt(config, full_gitmoji);
491 let user_prompt = build_user_prompt(diff, context, full_gitmoji, config);
492 (system_prompt, user_prompt)
493}
494
495fn build_system_prompt(config: &Config, full_gitmoji: bool) -> String {
497 let mut prompt = String::new();
498
499 prompt.push_str("You are an expert at writing clear, concise git commit messages.\n\n");
500
501 prompt.push_str("OUTPUT RULES:\n");
503 prompt.push_str("- Return ONLY the commit message, with no additional explanation, markdown formatting, or code blocks\n");
504 prompt.push_str("- Do not include any reasoning, thinking, analysis, <thinking> tags, or XML-like tags in your response\n");
505 prompt.push_str("- Never explain your choices or provide commentary\n");
506 prompt.push_str(
507 "- If you cannot generate a meaningful commit message, return \"chore: update\"\n\n",
508 );
509
510 if config.learn_from_history.unwrap_or(false) {
512 if let Some(style_guidance) = get_style_guidance(config) {
513 prompt.push_str("REPO STYLE (learned from commit history):\n");
514 prompt.push_str(&style_guidance);
515 prompt.push('\n');
516 }
517 }
518
519 if let Some(locale) = &config.language {
521 prompt.push_str(&format!(
522 "- Generate the commit message in {} language\n",
523 locale
524 ));
525 }
526
527 let commit_type = config.commit_type.as_deref().unwrap_or("conventional");
529 match commit_type {
530 "conventional" => {
531 prompt.push_str("- Use conventional commit format: <type>(<scope>): <description>\n");
532 prompt.push_str(
533 "- Types: feat, fix, docs, style, refactor, perf, test, build, ci, chore\n",
534 );
535 if config.omit_scope.unwrap_or(false) {
536 prompt.push_str("- Omit the scope, use format: <type>: <description>\n");
537 }
538 }
539 "gitmoji" => {
540 if full_gitmoji {
541 prompt.push_str("- Use GitMoji format with full emoji specification from https://gitmoji.dev/\n");
542 prompt.push_str("- Common emojis: ✨(feat), 🐛(fix), 📝(docs), 🚀(deploy), ♻️(refactor), ✅(test), 🔧(chore), ⚡(perf), 🎨(style), 📦(build), 👷(ci)\n");
543 prompt.push_str("- For breaking changes, add 💥 after the type\n");
544 } else {
545 prompt.push_str("- Use GitMoji format: <emoji> <type>: <description>\n");
546 prompt.push_str("- Common emojis: 🐛(fix), ✨(feat), 📝(docs), 🚀(deploy), ✅(test), ♻️(refactor), 🔧(chore), ⚡(perf), 🎨(style), 📦(build), 👷(ci)\n");
547 }
548 }
549 _ => {}
550 }
551
552 let max_length = config.description_max_length.unwrap_or(100);
554 prompt.push_str(&format!(
555 "- Keep the description under {} characters\n",
556 max_length
557 ));
558
559 if config.description_capitalize.unwrap_or(true) {
560 prompt.push_str("- Capitalize the first letter of the description\n");
561 }
562
563 if !config.description_add_period.unwrap_or(false) {
564 prompt.push_str("- Do not end the description with a period\n");
565 }
566
567 if config.enable_commit_body.unwrap_or(false) {
569 prompt.push_str("\nCOMMIT BODY (optional):\n");
570 prompt.push_str(
571 "- Add a blank line after the description, then explain WHY the change was made\n",
572 );
573 prompt.push_str("- Use bullet points for multiple changes\n");
574 prompt.push_str("- Wrap body text at 72 characters\n");
575 prompt
576 .push_str("- Focus on motivation and context, not what changed (that's in the diff)\n");
577 }
578
579 prompt
580}
581
582fn get_style_guidance(config: &Config) -> Option<String> {
584 use crate::git;
585 use crate::utils::commit_style::CommitStyleProfile;
586
587 if let Some(cached) = &config.style_profile {
589 return Some(cached.clone());
591 }
592
593 let count = config.history_commits_count.unwrap_or(50);
595
596 match git::get_recent_commit_messages(count) {
597 Ok(commits) => {
598 if commits.is_empty() {
599 return None;
600 }
601
602 let profile = CommitStyleProfile::analyze_from_commits(&commits);
603
604 if profile.is_empty() || commits.len() < 10 {
607 return None;
608 }
609
610 Some(profile.to_prompt_guidance())
611 }
612 Err(e) => {
613 tracing::warn!("Failed to get commit history for style analysis: {}", e);
614 None
615 }
616 }
617}
618
619fn build_user_prompt(
621 diff: &str,
622 context: Option<&str>,
623 _full_gitmoji: bool,
624 _config: &Config,
625) -> String {
626 let mut prompt = String::new();
627
628 if let Some(project_context) = get_project_context() {
630 prompt.push_str(&format!("Project Context: {}\n\n", project_context));
631 }
632
633 let file_summary = extract_file_summary(diff);
635 if !file_summary.is_empty() {
636 prompt.push_str(&format!("Files Changed: {}\n\n", file_summary));
637 }
638
639 if diff.contains("---CHUNK") {
641 let chunk_count = diff.matches("---CHUNK").count();
642 if chunk_count > 1 {
643 prompt.push_str(&format!(
644 "Note: This diff was split into {} chunks due to size. Focus on the overall purpose of the changes across all chunks.\n\n",
645 chunk_count
646 ));
647 } else {
648 prompt.push_str("Note: The diff was split into chunks due to size. Focus on the overall purpose of the changes.\n\n");
649 }
650 }
651
652 if let Some(ctx) = context {
654 prompt.push_str(&format!("Additional context: {}\n\n", ctx));
655 }
656
657 prompt.push_str("Generate a commit message for the following git diff:\n");
658 prompt.push_str("```diff\n");
659 prompt.push_str(diff);
660 prompt.push_str("\n```\n");
661
662 prompt.push_str("\nRemember: Return ONLY the commit message, no explanations or markdown.");
664
665 prompt
666}
667
/// Summarizes the files touched by a unified diff in one line.
///
/// Only `+++ b/<path>` header lines are considered. The summary has the form
/// `N file(s) (2 Rust, 1 TOML) [.rs, .toml]: a.rs, b.rs, Cargo.toml`, where:
///
/// - the parenthesized part counts files per category, highest count first
///   and alphabetically among equal counts;
/// - the bracketed part lists the distinct lowercase extensions in
///   alphabetical order, and is omitted when there are none or more than five;
/// - the trailing path list is only included for three files or fewer.
///
/// Files without an extension are counted as `config` when their path
/// contains `Makefile`, `Dockerfile` or `LICENSE`, and are otherwise left out
/// of the category counts. Returns an empty string when no file header is
/// found.
///
/// The output is deterministic: ordered collections are used so the same diff
/// always yields the same summary.
fn extract_file_summary(diff: &str) -> String {
    let mut files: Vec<&str> = Vec::new();
    let mut extensions: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
    let mut file_types: std::collections::BTreeMap<String, usize> =
        std::collections::BTreeMap::new();

    for line in diff.lines() {
        let path = match line.strip_prefix("+++ b/") {
            Some(path) if path != "/dev/null" => path,
            _ => continue,
        };
        files.push(path);

        match std::path::Path::new(path).extension() {
            Some(ext) => {
                if let Some(ext_str) = ext.to_str() {
                    let ext_lower = ext_str.to_lowercase();
                    *file_types
                        .entry(categorize_file_type(&ext_lower))
                        .or_insert(0) += 1;
                    extensions.insert(ext_lower);
                }
            }
            None => {
                // Well-known extensionless files are grouped under "config".
                if path.contains("Makefile")
                    || path.contains("Dockerfile")
                    || path.contains("LICENSE")
                {
                    *file_types.entry("config".to_string()).or_insert(0) += 1;
                }
            }
        }
    }

    if files.is_empty() {
        return String::new();
    }

    let mut summary = format!("{} file(s)", files.len());

    if !file_types.is_empty() {
        let mut type_list: Vec<(String, usize)> = file_types.into_iter().collect();
        // Highest count first; the category name breaks ties so the order
        // never depends on map iteration order.
        type_list.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        let type_str: Vec<String> = type_list
            .iter()
            .map(|(t, c)| format!("{} {}", c, t))
            .collect();
        summary.push_str(&format!(" ({})", type_str.join(", ")));
    }

    if !extensions.is_empty() && extensions.len() <= 5 {
        let ext_list: Vec<String> = extensions.into_iter().collect();
        summary.push_str(&format!(" [.{}]", ext_list.join(", .")));
    }

    if files.len() <= 3 {
        summary.push_str(&format!(": {}", files.join(", ")));
    }

    summary
}

/// Maps a lowercase file extension (without the dot) to a display category
/// such as `"Rust"` or `"YAML"`. Unrecognized extensions map to `"Other"`.
fn categorize_file_type(ext: &str) -> String {
    match ext {
        // Programming languages
        "rs" => "Rust",
        "py" => "Python",
        "js" => "JavaScript",
        "ts" => "TypeScript",
        "jsx" | "tsx" => "React",
        "go" => "Go",
        "java" => "Java",
        "kt" => "Kotlin",
        "swift" => "Swift",
        "c" | "cpp" | "cc" | "h" | "hpp" => "C/C++",
        "rb" => "Ruby",
        "php" => "PHP",
        "cs" => "C#",
        "scala" => "Scala",
        "r" => "R",
        "m" => "Objective-C",
        "lua" => "Lua",
        "pl" => "Perl",

        // Web
        "html" | "htm" => "HTML",
        "css" | "scss" | "sass" | "less" => "CSS",
        "vue" => "Vue",
        "svelte" => "Svelte",

        // Data and configuration
        "json" => "JSON",
        "yaml" | "yml" => "YAML",
        "toml" => "TOML",
        "xml" => "XML",
        "csv" => "CSV",
        "sql" => "SQL",

        // Documentation
        "md" | "markdown" => "Markdown",
        "rst" => "reStructuredText",
        "txt" => "Text",

        // Scripts and build files
        "sh" | "bash" | "zsh" | "fish" => "Shell",
        "ps1" => "PowerShell",
        "bat" | "cmd" => "Batch",
        "dockerfile" => "Docker",
        "makefile" | "mk" => "Make",
        "cmake" => "CMake",

        _ => "Other",
    }
    .to_string()
}
790
791fn get_project_context() -> Option<String> {
793 use std::path::Path;
794
795 if let Ok(repo_root) = crate::git::get_repo_root() {
797 let context_path = Path::new(&repo_root).join(".rco").join("context.txt");
798 if context_path.exists() {
799 if let Ok(content) = std::fs::read_to_string(&context_path) {
800 let trimmed = content.trim();
801 if !trimmed.is_empty() {
802 return Some(trimmed.to_string());
803 }
804 }
805 }
806
807 let readme_path = Path::new(&repo_root).join("README.md");
809 if readme_path.exists() {
810 if let Ok(content) = std::fs::read_to_string(&readme_path) {
811 for line in content.lines() {
813 let trimmed = line.trim();
814 if !trimmed.is_empty() && !trimmed.starts_with('#') {
815 let context = if let Some(idx) = trimmed.find('.') {
817 trimmed[..idx + 1].to_string()
818 } else {
819 trimmed.chars().take(100).collect()
820 };
821 if !context.is_empty() {
822 return Some(context);
823 }
824 }
825 }
826 }
827 }
828
829 let cargo_path = Path::new(&repo_root).join("Cargo.toml");
831 if cargo_path.exists() {
832 if let Ok(content) = std::fs::read_to_string(&cargo_path) {
833 let mut in_package = false;
835 for line in content.lines() {
836 let trimmed = line.trim();
837 if trimmed == "[package]" {
838 in_package = true;
839 } else if trimmed.starts_with('[') && trimmed != "[package]" {
840 in_package = false;
841 } else if in_package && trimmed.starts_with("description") {
842 if let Some(idx) = trimmed.find('=') {
843 let desc = trimmed[idx + 1..].trim().trim_matches('"');
844 if !desc.is_empty() {
845 return Some(format!("Rust project: {}", desc));
846 }
847 }
848 }
849 }
850 }
851 }
852
853 let package_path = Path::new(&repo_root).join("package.json");
855 if package_path.exists() {
856 if let Ok(content) = std::fs::read_to_string(&package_path) {
857 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
858 if let Some(desc) = json.get("description").and_then(|d| d.as_str()) {
859 if !desc.is_empty() {
860 return Some(format!("Node.js project: {}", desc));
861 }
862 }
863 }
864 }
865 }
866 }
867
868 None
869}
870
871pub fn build_prompt(
873 diff: &str,
874 context: Option<&str>,
875 config: &Config,
876 full_gitmoji: bool,
877) -> String {
878 let (system, user) = split_prompt(diff, context, config, full_gitmoji);
879 format!("{}\n\n---\n\n{}", system, user)
880}
881
882#[allow(dead_code)]
884pub fn create_provider_for_account(
885 account: &AccountConfig,
886 config: &Config,
887) -> Result<Box<dyn AIProvider>> {
888 use crate::auth::token_storage;
889 use crate::config::secure_storage;
890
891 let provider = account.provider.to_lowercase();
892
893 let credentials = match &account.auth {
895 crate::config::accounts::AuthMethod::ApiKey { key_id } => {
896 token_storage::get_api_key_for_account(key_id)?
898 .or_else(|| secure_storage::get_secret(key_id).ok().flatten())
899 }
900 crate::config::accounts::AuthMethod::OAuth {
901 provider: _oauth_provider,
902 account_id,
903 } => {
904 token_storage::get_tokens_for_account(account_id)?.map(|t| t.access_token)
906 }
907 crate::config::accounts::AuthMethod::EnvVar { name } => std::env::var(name).ok(),
908 crate::config::accounts::AuthMethod::Bearer { token_id } => {
909 token_storage::get_bearer_token_for_account(token_id)?
911 .or_else(|| secure_storage::get_secret(token_id).ok().flatten())
912 }
913 };
914
915 match provider.as_str() {
916 #[cfg(feature = "openai")]
917 "openai" | "codex" => {
918 if let Some(key) = credentials.as_ref() {
919 Ok(Box::new(openai::OpenAIProvider::from_account(
920 account, key, config,
921 )?))
922 } else {
923 Ok(Box::new(openai::OpenAIProvider::new(config)?))
924 }
925 }
926 #[cfg(feature = "anthropic")]
927 "anthropic" | "claude" | "claude-code" => {
928 if let Some(key) = credentials.as_ref() {
929 Ok(Box::new(anthropic::AnthropicProvider::from_account(
930 account, key, config,
931 )?))
932 } else {
933 Ok(Box::new(anthropic::AnthropicProvider::new(config)?))
934 }
935 }
936 #[cfg(feature = "ollama")]
937 "ollama" => {
938 if let Some(key) = credentials.as_ref() {
939 Ok(Box::new(ollama::OllamaProvider::from_account(
940 account, key, config,
941 )?))
942 } else {
943 Ok(Box::new(ollama::OllamaProvider::new(config)?))
944 }
945 }
946 #[cfg(feature = "gemini")]
947 "gemini" => {
948 if let Some(key) = credentials.as_ref() {
949 Ok(Box::new(gemini::GeminiProvider::from_account(
950 account, key, config,
951 )?))
952 } else {
953 Ok(Box::new(gemini::GeminiProvider::new(config)?))
954 }
955 }
956 #[cfg(feature = "azure")]
957 "azure" | "azure-openai" => {
958 if let Some(key) = credentials.as_ref() {
959 Ok(Box::new(azure::AzureProvider::from_account(
960 account, key, config,
961 )?))
962 } else {
963 Ok(Box::new(azure::AzureProvider::new(config)?))
964 }
965 }
966 #[cfg(feature = "perplexity")]
967 "perplexity" => {
968 if let Some(key) = credentials.as_ref() {
969 Ok(Box::new(perplexity::PerplexityProvider::from_account(
970 account, key, config,
971 )?))
972 } else {
973 Ok(Box::new(perplexity::PerplexityProvider::new(config)?))
974 }
975 }
976 #[cfg(feature = "xai")]
977 "xai" | "grok" | "x-ai" => {
978 if let Some(key) = credentials.as_ref() {
979 Ok(Box::new(xai::XAIProvider::from_account(
980 account, key, config,
981 )?))
982 } else {
983 Ok(Box::new(xai::XAIProvider::new(config)?))
984 }
985 }
986 #[cfg(feature = "huggingface")]
987 "huggingface" | "hf" => {
988 if let Some(key) = credentials.as_ref() {
989 Ok(Box::new(huggingface::HuggingFaceProvider::from_account(
990 account, key, config,
991 )?))
992 } else {
993 Ok(Box::new(huggingface::HuggingFaceProvider::new(config)?))
994 }
995 }
996 #[cfg(feature = "bedrock")]
997 "bedrock" | "aws-bedrock" | "amazon-bedrock" => Ok(Box::new(
998 bedrock::BedrockProvider::from_account(account, "", config)?,
999 )),
1000 #[cfg(feature = "vertex")]
1001 "vertex" | "vertex-ai" | "google-vertex" | "gcp-vertex" => Ok(Box::new(
1002 vertex::VertexProvider::from_account(account, "", config)?,
1003 )),
1004 #[cfg(feature = "mlx")]
1005 "mlx" | "mlx-lm" | "apple-mlx" => {
1006 if let Some(_key) = credentials.as_ref() {
1007 Ok(Box::new(mlx::MlxProvider::from_account(
1008 account, "", config,
1009 )?))
1010 } else {
1011 Ok(Box::new(mlx::MlxProvider::new(config)?))
1012 }
1013 }
1014 #[cfg(feature = "nvidia")]
1015 "nvidia" | "nvidia-nim" | "nim" | "nvidia-ai" => {
1016 if let Some(key) = credentials.as_ref() {
1017 Ok(Box::new(nvidia::NvidiaProvider::from_account(
1018 account, key, config,
1019 )?))
1020 } else {
1021 Ok(Box::new(nvidia::NvidiaProvider::new(config)?))
1022 }
1023 }
1024 #[cfg(feature = "flowise")]
1025 "flowise" | "flowise-ai" => {
1026 if let Some(_key) = credentials.as_ref() {
1027 Ok(Box::new(flowise::FlowiseProvider::from_account(
1028 account, "", config,
1029 )?))
1030 } else {
1031 Ok(Box::new(flowise::FlowiseProvider::new(config)?))
1032 }
1033 }
1034 _ => {
1035 anyhow::bail!(
1036 "Unsupported AI provider for account: {}\n\n\
1037 Account provider: {}\n\
1038 Supported providers: openai, anthropic, ollama, gemini, azure, perplexity, xai, huggingface, bedrock, vertex",
1039 account.alias,
1040 provider
1041 );
1042 }
1043 }
1044}