1#[cfg(feature = "anthropic")]
3pub mod anthropic;
4#[cfg(feature = "azure")]
5pub mod azure;
6#[cfg(feature = "bedrock")]
7pub mod bedrock;
8#[cfg(feature = "flowise")]
9pub mod flowise;
10#[cfg(feature = "gemini")]
11pub mod gemini;
12#[cfg(feature = "huggingface")]
13pub mod huggingface;
14#[cfg(feature = "mlx")]
15pub mod mlx;
16#[cfg(feature = "nvidia")]
17pub mod nvidia;
18#[cfg(feature = "ollama")]
19pub mod ollama;
20#[cfg(feature = "openai")]
21pub mod openai;
22#[cfg(feature = "perplexity")]
23pub mod perplexity;
24#[cfg(feature = "vertex")]
25pub mod vertex;
26#[cfg(feature = "xai")]
27pub mod xai;
28
29pub mod registry;
31
32pub mod prompt;
34
35use crate::config::accounts::AccountConfig;
36use crate::config::Config;
37use anyhow::{Context, Result};
38use async_trait::async_trait;
39use once_cell::sync::Lazy;
40
41#[async_trait]
42pub trait AIProvider: Send + Sync {
43 async fn generate_commit_message(
44 &self,
45 diff: &str,
46 context: Option<&str>,
47 full_gitmoji: bool,
48 config: &Config,
49 ) -> Result<String>;
50
51 async fn generate_commit_messages(
53 &self,
54 diff: &str,
55 context: Option<&str>,
56 full_gitmoji: bool,
57 config: &Config,
58 count: u8,
59 ) -> Result<Vec<String>> {
60 use futures::stream::StreamExt;
61
62 if count <= 1 {
63 match self
65 .generate_commit_message(diff, context, full_gitmoji, config)
66 .await
67 {
68 Ok(msg) => Ok(vec![msg]),
69 Err(e) => {
70 tracing::warn!("Failed to generate message: {}", e);
71 Ok(vec![])
72 }
73 }
74 } else {
75 let futures = (0..count)
77 .map(|_| self.generate_commit_message(diff, context, full_gitmoji, config));
78 let mut stream = futures::stream::FuturesUnordered::from_iter(futures);
79
80 let mut messages = Vec::with_capacity(count as usize);
81 while let Some(result) = stream.next().await {
82 match result {
83 Ok(msg) => messages.push(msg),
84 Err(e) => tracing::warn!("Failed to generate message: {}", e),
85 }
86 }
87 Ok(messages)
88 }
89 }
90
91 #[cfg(any(feature = "openai", feature = "xai"))]
93 async fn generate_pr_description(
94 &self,
95 commits: &[String],
96 diff: &str,
97 config: &Config,
98 ) -> Result<String> {
99 let commits_text = commits.join("\n");
100 let prompt = format!(
101 "Generate a professional pull request description based on the following commits:\n\n{}\n\nDiff:\n{}\n\nFormat the output as:\n## Summary\n## Changes\n## Testing\n## Breaking Changes\n\nKeep it concise and informative.",
102 commits_text, diff
103 );
104
105 let messages = vec![
106 async_openai::types::chat::ChatCompletionRequestSystemMessage::from(
107 "You are an expert at writing pull request descriptions.",
108 )
109 .into(),
110 async_openai::types::chat::ChatCompletionRequestUserMessage::from(prompt).into(),
111 ];
112
113 let request = async_openai::types::chat::CreateChatCompletionRequestArgs::default()
114 .model(
115 config
116 .model
117 .clone()
118 .unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
119 )
120 .messages(messages)
121 .temperature(0.7)
122 .max_tokens(config.tokens_max_output.unwrap_or(1000) as u16)
123 .build()?;
124
125 let api_key = config
127 .api_key
128 .as_ref()
129 .context("API key not configured. Run: rco config set RCO_API_KEY=<your_key>")?;
130 let api_url = config
131 .api_url
132 .as_deref()
133 .unwrap_or("https://api.openai.com/v1");
134
135 let openai_config = async_openai::config::OpenAIConfig::new()
136 .with_api_key(api_key)
137 .with_api_base(api_url);
138
139 let client = async_openai::Client::with_config(openai_config);
140
141 let response = client.chat().create(request).await?;
142
143 let message = response
144 .choices
145 .first()
146 .and_then(|choice| choice.message.content.as_ref())
147 .context("AI returned an empty response")?
148 .trim()
149 .to_string();
150
151 Ok(message)
152 }
153
154 #[cfg(not(any(feature = "openai", feature = "xai")))]
156 async fn generate_pr_description(
157 &self,
158 _commits: &[String],
159 _diff: &str,
160 _config: &Config,
161 ) -> Result<String> {
162 anyhow::bail!(
163 "PR description generation requires the 'openai' or 'xai' feature to be enabled"
164 );
165 }
166}
167
168pub static PROVIDER_REGISTRY: Lazy<registry::ProviderRegistry> = Lazy::new(|| {
170 let reg = registry::ProviderRegistry::new();
171
172 #[cfg(feature = "openai")]
174 {
175 let _ = reg.register(Box::new(openai::OpenAICompatibleProvider::new()));
176 }
177
178 #[cfg(feature = "anthropic")]
180 {
181 let _ = reg.register(Box::new(anthropic::AnthropicProviderBuilder));
182 }
183
184 #[cfg(feature = "ollama")]
185 {
186 let _ = reg.register(Box::new(ollama::OllamaProviderBuilder));
187 }
188
189 #[cfg(feature = "gemini")]
190 {
191 let _ = reg.register(Box::new(gemini::GeminiProviderBuilder));
192 }
193
194 #[cfg(feature = "azure")]
195 {
196 let _ = reg.register(Box::new(azure::AzureProviderBuilder));
197 }
198
199 #[cfg(feature = "perplexity")]
200 {
201 let _ = reg.register(Box::new(perplexity::PerplexityProviderBuilder));
202 }
203
204 #[cfg(feature = "xai")]
205 {
206 let _ = reg.register(Box::new(xai::XAIProviderBuilder));
207 }
208
209 #[cfg(feature = "huggingface")]
210 {
211 let _ = reg.register(Box::new(huggingface::HuggingFaceProviderBuilder));
212 }
213
214 #[cfg(feature = "bedrock")]
215 {
216 let _ = reg.register(Box::new(bedrock::BedrockProviderBuilder));
217 }
218
219 #[cfg(feature = "vertex")]
220 {
221 let _ = reg.register(Box::new(vertex::VertexProviderBuilder));
222 }
223
224 #[cfg(feature = "mlx")]
225 {
226 let _ = reg.register(Box::new(mlx::MlxProviderBuilder));
227 }
228
229 #[cfg(feature = "nvidia")]
230 {
231 let _ = reg.register(Box::new(nvidia::NvidiaProviderBuilder));
232 }
233
234 #[cfg(feature = "flowise")]
235 {
236 let _ = reg.register(Box::new(flowise::FlowiseProviderBuilder));
237 }
238
239 reg
240});
241
242pub fn create_provider(config: &Config) -> Result<Box<dyn AIProvider>> {
244 let provider_name = config.ai_provider.as_deref().unwrap_or("openai");
245
246 if let Some(provider) = PROVIDER_REGISTRY.create(provider_name, config)? {
248 return Ok(provider);
249 }
250
251 let available: Vec<String> = PROVIDER_REGISTRY
253 .all()
254 .unwrap_or_default()
255 .iter()
256 .map(|e| {
257 let aliases = if e.aliases.is_empty() {
258 String::new()
259 } else {
260 format!(" ({})", e.aliases.join(", "))
261 };
262 format!("- {}{}", e.name, aliases)
263 })
264 .chain(std::iter::once(format!(
265 "- {} OpenAI-compatible providers (deepseek, groq, openrouter, etc.)",
266 PROVIDER_REGISTRY
267 .by_category(registry::ProviderCategory::OpenAICompatible)
268 .map_or(0, |v| v.len())
269 )))
270 .filter(|s| !s.contains("0 OpenAI-compatible"))
271 .collect();
272
273 if available.is_empty() {
274 anyhow::bail!(
275 "No AI provider features enabled. Please enable at least one provider feature:\n\
276 --features openai,anthropic,ollama,gemini,azure,perplexity,xai,huggingface,bedrock,vertex"
277 );
278 }
279
280 anyhow::bail!(
281 "Unsupported or disabled AI provider: {}\n\n\
282 Available providers (based on enabled features):\n{}\n\n\
283 Set with: rco config set RCO_AI_PROVIDER=<provider_name>",
284 provider_name,
285 available.join("\n")
286 )
287}
288
289#[allow(dead_code)]
290pub fn available_providers() -> Vec<&'static str> {
292 let mut providers = PROVIDER_REGISTRY
293 .all()
294 .unwrap_or_default()
295 .iter()
296 .flat_map(|e| std::iter::once(e.name).chain(e.aliases.iter().copied()))
297 .collect::<Vec<_>>();
298
299 #[cfg(feature = "openai")]
300 {
301 providers.extend_from_slice(&[
302 "deepseek",
306 "groq",
307 "openrouter",
308 "together",
309 "deepinfra",
310 "mistral",
311 "github-models",
312 "fireworks",
313 "moonshot",
314 "dashscope",
315 "perplexity",
316 "cohere",
320 "cohere-ai",
321 "ai21",
322 "ai21-labs",
323 "upstage",
324 "upstage-ai",
325 "solar",
326 "solar-pro",
327 "nebius",
331 "nebius-ai",
332 "nebius-studio",
333 "ovh",
334 "ovhcloud",
335 "ovh-ai",
336 "scaleway",
337 "scaleway-ai",
338 "friendli",
339 "friendli-ai",
340 "baseten",
341 "baseten-ai",
342 "chutes",
343 "chutes-ai",
344 "ionet",
345 "io-net",
346 "modelscope",
347 "requesty",
348 "morph",
349 "morph-labs",
350 "synthetic",
351 "nano-gpt",
352 "nanogpt",
353 "zenmux",
354 "v0",
355 "v0-vercel",
356 "iflowcn",
357 "venice",
358 "venice-ai",
359 "cortecs",
360 "cortecs-ai",
361 "kimi-coding",
362 "abacus",
363 "abacus-ai",
364 "bailing",
365 "fastrouter",
366 "inference",
367 "inference-net",
368 "submodel",
369 "zai",
370 "zai-coding",
371 "zhipu-coding",
372 "poe",
373 "poe-ai",
374 "cerebras",
375 "cerebras-ai",
376 "sambanova",
377 "sambanova-ai",
378 "novita",
379 "novita-ai",
380 "predibase",
381 "tensorops",
382 "hyperbolic",
383 "hyperbolic-ai",
384 "kluster",
385 "kluster-ai",
386 "lambda",
387 "lambda-labs",
388 "replicate",
389 "targon",
390 "corcel",
391 "cybernative",
392 "cybernative-ai",
393 "edgen",
394 "gigachat",
395 "gigachat-ai",
396 "hydra",
397 "hydra-ai",
398 "jina",
399 "jina-ai",
400 "lingyi",
401 "lingyiwanwu",
402 "monica",
403 "monica-ai",
404 "pollinations",
405 "pollinations-ai",
406 "rawechat",
407 "shuttleai",
408 "shuttle-ai",
409 "teknium",
410 "theb",
411 "theb-ai",
412 "tryleap",
413 "leap-ai",
414 "lmstudio",
418 "lm-studio",
419 "llamacpp",
420 "llama-cpp",
421 "kobold",
422 "koboldcpp",
423 "textgen",
424 "text-generation",
425 "tabby",
426 "siliconflow",
430 "silicon-flow",
431 "zhipu",
432 "zhipu-ai",
433 "bigmodel",
434 "minimax",
435 "minimax-ai",
436 "glm",
437 "chatglm",
438 "baichuan",
439 "01-ai",
440 "yi",
441 "helicone",
445 "helicone-ai",
446 "workers-ai",
447 "cloudflare-ai",
448 "cloudflare-gateway",
449 "vercel-ai",
450 "vercel-gateway",
451 "302ai",
455 "302-ai",
456 "sap-ai",
457 "sap-ai-core",
458 "aimlapi",
462 "ai-ml-api",
463 ]);
464 }
465
466 providers
467}
468
469#[allow(dead_code)]
471pub fn provider_info(provider: &str) -> Option<String> {
472 PROVIDER_REGISTRY.get(provider).map(|e| {
473 let aliases = if e.aliases.is_empty() {
474 String::new()
475 } else {
476 format!(" (aliases: {})", e.aliases.join(", "))
477 };
478 let model = e
479 .default_model
480 .map(|m| format!(", default model: {}", m))
481 .unwrap_or_default();
482 format!("{}{}{}", e.name, aliases, model)
483 })
484}
485
486#[allow(dead_code)]
488pub fn create_provider_for_account(
489 account: &AccountConfig,
490 config: &Config,
491) -> Result<Box<dyn AIProvider>> {
492 use crate::auth::token_storage;
493 use crate::config::secure_storage;
494
495 let provider = account.provider.to_lowercase();
496
497 let credentials = match &account.auth {
499 crate::config::accounts::AuthMethod::ApiKey { key_id } => {
500 token_storage::get_api_key_for_account(key_id)?
502 .or_else(|| secure_storage::get_secret(key_id).ok().flatten())
503 }
504 crate::config::accounts::AuthMethod::OAuth {
505 provider: _oauth_provider,
506 account_id,
507 } => {
508 token_storage::get_tokens_for_account(account_id)?.map(|t| t.access_token)
510 }
511 crate::config::accounts::AuthMethod::EnvVar { name } => std::env::var(name).ok(),
512 crate::config::accounts::AuthMethod::Bearer { token_id } => {
513 token_storage::get_bearer_token_for_account(token_id)?
515 .or_else(|| secure_storage::get_secret(token_id).ok().flatten())
516 }
517 };
518
519 match provider.as_str() {
520 #[cfg(feature = "openai")]
521 "openai" | "codex" => {
522 if let Some(key) = credentials.as_ref() {
523 Ok(Box::new(openai::OpenAIProvider::from_account(
524 account, key, config,
525 )?))
526 } else {
527 Ok(Box::new(openai::OpenAIProvider::new(config)?))
528 }
529 }
530 #[cfg(feature = "anthropic")]
531 "anthropic" | "claude" | "claude-code" => {
532 if let Some(key) = credentials.as_ref() {
533 Ok(Box::new(anthropic::AnthropicProvider::from_account(
534 account, key, config,
535 )?))
536 } else {
537 Ok(Box::new(anthropic::AnthropicProvider::new(config)?))
538 }
539 }
540 #[cfg(feature = "ollama")]
541 "ollama" => {
542 if let Some(key) = credentials.as_ref() {
543 Ok(Box::new(ollama::OllamaProvider::from_account(
544 account, key, config,
545 )?))
546 } else {
547 Ok(Box::new(ollama::OllamaProvider::new(config)?))
548 }
549 }
550 #[cfg(feature = "gemini")]
551 "gemini" => {
552 if let Some(key) = credentials.as_ref() {
553 Ok(Box::new(gemini::GeminiProvider::from_account(
554 account, key, config,
555 )?))
556 } else {
557 Ok(Box::new(gemini::GeminiProvider::new(config)?))
558 }
559 }
560 #[cfg(feature = "azure")]
561 "azure" | "azure-openai" => {
562 if let Some(key) = credentials.as_ref() {
563 Ok(Box::new(azure::AzureProvider::from_account(
564 account, key, config,
565 )?))
566 } else {
567 Ok(Box::new(azure::AzureProvider::new(config)?))
568 }
569 }
570 #[cfg(feature = "perplexity")]
571 "perplexity" => {
572 if let Some(key) = credentials.as_ref() {
573 Ok(Box::new(perplexity::PerplexityProvider::from_account(
574 account, key, config,
575 )?))
576 } else {
577 Ok(Box::new(perplexity::PerplexityProvider::new(config)?))
578 }
579 }
580 #[cfg(feature = "xai")]
581 "xai" | "grok" | "x-ai" => {
582 if let Some(key) = credentials.as_ref() {
583 Ok(Box::new(xai::XAIProvider::from_account(
584 account, key, config,
585 )?))
586 } else {
587 Ok(Box::new(xai::XAIProvider::new(config)?))
588 }
589 }
590 #[cfg(feature = "huggingface")]
591 "huggingface" | "hf" => {
592 if let Some(key) = credentials.as_ref() {
593 Ok(Box::new(huggingface::HuggingFaceProvider::from_account(
594 account, key, config,
595 )?))
596 } else {
597 Ok(Box::new(huggingface::HuggingFaceProvider::new(config)?))
598 }
599 }
600 #[cfg(feature = "bedrock")]
601 "bedrock" | "aws-bedrock" | "amazon-bedrock" => Ok(Box::new(
602 bedrock::BedrockProvider::from_account(account, "", config)?,
603 )),
604 #[cfg(feature = "vertex")]
605 "vertex" | "vertex-ai" | "google-vertex" | "gcp-vertex" => Ok(Box::new(
606 vertex::VertexProvider::from_account(account, "", config)?,
607 )),
608 #[cfg(feature = "mlx")]
609 "mlx" | "mlx-lm" | "apple-mlx" => {
610 if let Some(_key) = credentials.as_ref() {
611 Ok(Box::new(mlx::MlxProvider::from_account(
612 account, "", config,
613 )?))
614 } else {
615 Ok(Box::new(mlx::MlxProvider::new(config)?))
616 }
617 }
618 #[cfg(feature = "nvidia")]
619 "nvidia" | "nvidia-nim" | "nim" | "nvidia-ai" => {
620 if let Some(key) = credentials.as_ref() {
621 Ok(Box::new(nvidia::NvidiaProvider::from_account(
622 account, key, config,
623 )?))
624 } else {
625 Ok(Box::new(nvidia::NvidiaProvider::new(config)?))
626 }
627 }
628 #[cfg(feature = "flowise")]
629 "flowise" | "flowise-ai" => {
630 if let Some(_key) = credentials.as_ref() {
631 Ok(Box::new(flowise::FlowiseProvider::from_account(
632 account, "", config,
633 )?))
634 } else {
635 Ok(Box::new(flowise::FlowiseProvider::new(config)?))
636 }
637 }
638 _ => {
639 anyhow::bail!(
640 "Unsupported AI provider for account: {}\n\n\
641 Account provider: {}\n\
642 Supported providers: openai, anthropic, ollama, gemini, azure, perplexity, xai, huggingface, bedrock, vertex",
643 account.alias,
644 provider
645 );
646 }
647 }
648}