#[cfg(feature = "anthropic")]
pub mod anthropic;
#[cfg(feature = "azure")]
pub mod azure;
#[cfg(feature = "bedrock")]
pub mod bedrock;
#[cfg(feature = "flowise")]
pub mod flowise;
#[cfg(feature = "gemini")]
pub mod gemini;
#[cfg(feature = "huggingface")]
pub mod huggingface;
#[cfg(feature = "mlx")]
pub mod mlx;
#[cfg(feature = "nvidia")]
pub mod nvidia;
#[cfg(feature = "ollama")]
pub mod ollama;
#[cfg(feature = "openai")]
pub mod openai;
#[cfg(feature = "perplexity")]
pub mod perplexity;
#[cfg(feature = "vertex")]
pub mod vertex;
#[cfg(feature = "xai")]
pub mod xai;

pub mod registry;

pub mod prompt;

use crate::config::accounts::AccountConfig;
use crate::config::Config;
use anyhow::{Context, Result};
use async_trait::async_trait;
use once_cell::sync::Lazy;

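/// Common interface implemented by every enabled AI backend.
///
/// Implementors only need to provide `generate_commit_message`; the
/// multi-message and pull-request helpers below have default implementations
/// built on top of it.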
#[async_trait]
pub trait AIProvider: Send + Sync {
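    /// Generates a single commit message for the given diff, with optional
    /// extra context and optional full-gitmoji style.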
    async fn generate_commit_message(
        &self,
        diff: &str,
        context: Option<&str>,
        full_gitmoji: bool,
        config: &Config,
    ) -> Result<String>;

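    /// Generates up to `count` candidate commit messages.
    ///
    /// For `count > 1` the requests run concurrently via `FuturesUnordered`,
    /// so the order of the returned messages is not guaranteed. Individual
    /// failures are logged and skipped, so the result may contain fewer than
    /// `count` entries (or none at all).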
    async fn generate_commit_messages(
        &self,
        diff: &str,
        context: Option<&str>,
        full_gitmoji: bool,
        config: &Config,
        count: u8,
    ) -> Result<Vec<String>> {
        use futures::stream::StreamExt;

        if count <= 1 {
            match self
                .generate_commit_message(diff, context, full_gitmoji, config)
                .await
            {
                Ok(msg) => Ok(vec![msg]),
                Err(e) => {
                    tracing::warn!("Failed to generate message: {}", e);
                    Ok(vec![])
                }
            }
        } else {
            let futures = (0..count)
                .map(|_| self.generate_commit_message(diff, context, full_gitmoji, config));
            let mut stream = futures::stream::FuturesUnordered::from_iter(futures);

            let mut messages = Vec::with_capacity(count as usize);
            while let Some(result) = stream.next().await {
                match result {
                    Ok(msg) => messages.push(msg),
                    Err(e) => tracing::warn!("Failed to generate message: {}", e),
                }
            }
            Ok(messages)
        }
    }

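    /// Generates a pull-request description from a list of commit messages and
    /// a diff.
    ///
    /// The default implementation talks to an OpenAI-compatible chat endpoint
    /// (`config.api_url`, falling back to the public OpenAI API) and therefore
    /// requires the `openai` or `xai` feature; without either, the fallback
    /// below simply returns an error.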
    #[cfg(any(feature = "openai", feature = "xai"))]
    async fn generate_pr_description(
        &self,
        commits: &[String],
        diff: &str,
        config: &Config,
    ) -> Result<String> {
        let commits_text = commits.join("\n");
        let prompt = format!(
            "Generate a professional pull request description based on the following commits:\n\n{}\n\nDiff:\n{}\n\nFormat the output as:\n## Summary\n## Changes\n## Testing\n## Breaking Changes\n\nKeep it concise and informative.",
            commits_text, diff
        );

        let messages = vec![
            async_openai::types::chat::ChatCompletionRequestSystemMessage::from(
                "You are an expert at writing pull request descriptions.",
            )
            .into(),
            async_openai::types::chat::ChatCompletionRequestUserMessage::from(prompt).into(),
        ];

        let request = async_openai::types::chat::CreateChatCompletionRequestArgs::default()
            .model(
                config
                    .model
                    .clone()
                    .unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
            )
            .messages(messages)
            .temperature(0.7)
            .max_tokens(config.tokens_max_output.unwrap_or(1000) as u16)
            .build()?;

        let api_key = config
            .api_key
            .as_ref()
            .context("API key not configured. Run: rco config set RCO_API_KEY=<your_key>")?;
        let api_url = config
            .api_url
            .as_deref()
            .unwrap_or("https://api.openai.com/v1");

        let openai_config = async_openai::config::OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(api_url);

        let client = async_openai::Client::with_config(openai_config);

        let response = client.chat().create(request).await?;

        let message = response
            .choices
            .first()
            .and_then(|choice| choice.message.content.as_ref())
            .context("AI returned an empty response")?
            .trim()
            .to_string();

        Ok(message)
    }

    #[cfg(not(any(feature = "openai", feature = "xai")))]
    async fn generate_pr_description(
        &self,
        _commits: &[String],
        _diff: &str,
        _config: &Config,
    ) -> Result<String> {
        anyhow::bail!(
            "PR description generation requires the 'openai' or 'xai' feature to be enabled"
        );
    }
}

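/// Global registry of provider builders, populated lazily on first use with a
/// builder for every provider feature that was enabled at compile time.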
pub static PROVIDER_REGISTRY: Lazy<registry::ProviderRegistry> = Lazy::new(|| {
    let reg = registry::ProviderRegistry::new();

    #[cfg(feature = "openai")]
    {
        let _ = reg.register(Box::new(openai::OpenAICompatibleProvider::new()));
    }

    #[cfg(feature = "anthropic")]
    {
        let _ = reg.register(Box::new(anthropic::AnthropicProviderBuilder));
    }

    #[cfg(feature = "ollama")]
    {
        let _ = reg.register(Box::new(ollama::OllamaProviderBuilder));
    }

    #[cfg(feature = "gemini")]
    {
        let _ = reg.register(Box::new(gemini::GeminiProviderBuilder));
    }

    #[cfg(feature = "azure")]
    {
        let _ = reg.register(Box::new(azure::AzureProviderBuilder));
    }

    #[cfg(feature = "perplexity")]
    {
        let _ = reg.register(Box::new(perplexity::PerplexityProviderBuilder));
    }

    #[cfg(feature = "xai")]
    {
        let _ = reg.register(Box::new(xai::XAIProviderBuilder));
    }

    #[cfg(feature = "huggingface")]
    {
        let _ = reg.register(Box::new(huggingface::HuggingFaceProviderBuilder));
    }

    #[cfg(feature = "bedrock")]
    {
        let _ = reg.register(Box::new(bedrock::BedrockProviderBuilder));
    }

    #[cfg(feature = "vertex")]
    {
        let _ = reg.register(Box::new(vertex::VertexProviderBuilder));
    }

    #[cfg(feature = "mlx")]
    {
        let _ = reg.register(Box::new(mlx::MlxProviderBuilder));
    }

    #[cfg(feature = "nvidia")]
    {
        let _ = reg.register(Box::new(nvidia::NvidiaProviderBuilder));
    }

    #[cfg(feature = "flowise")]
    {
        let _ = reg.register(Box::new(flowise::FlowiseProviderBuilder));
    }

    reg
});

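/// Creates the provider selected by `config.ai_provider` (defaulting to
/// "openai"), or fails with a message listing the providers available under
/// the currently enabled features.
///
/// # Example
///
/// A minimal sketch of the intended call pattern; how the `Config` is obtained
/// here is hypothetical and depends on the caller:
///
/// ```ignore
/// let config = load_config()?; // hypothetical helper: obtain a populated Config
/// let provider = create_provider(&config)?;
/// let message = provider
///     .generate_commit_message("diff --git a/src/main.rs ...", None, false, &config)
///     .await?;
/// ```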
pub fn create_provider(config: &Config) -> Result<Box<dyn AIProvider>> {
    let provider_name = config.ai_provider.as_deref().unwrap_or("openai");

    if let Some(provider) = PROVIDER_REGISTRY.create(provider_name, config)? {
        return Ok(provider);
    }

    let available: Vec<String> = PROVIDER_REGISTRY
        .all()
        .unwrap_or_default()
        .iter()
        .map(|e| {
            let aliases = if e.aliases.is_empty() {
                String::new()
            } else {
                format!(" ({})", e.aliases.join(", "))
            };
            format!("- {}{}", e.name, aliases)
        })
        .chain(std::iter::once(format!(
            "- {} OpenAI-compatible providers (deepseek, groq, openrouter, etc.)",
            PROVIDER_REGISTRY
                .by_category(registry::ProviderCategory::OpenAICompatible)
                .map_or(0, |v| v.len())
        )))
        .filter(|s| !s.contains("0 OpenAI-compatible"))
        .collect();

    if available.is_empty() {
        anyhow::bail!(
            "No AI provider features enabled. Please enable at least one provider feature:\n\
             --features openai,anthropic,ollama,gemini,azure,perplexity,xai,huggingface,bedrock,vertex,mlx,nvidia,flowise"
        );
    }

    anyhow::bail!(
        "Unsupported or disabled AI provider: {}\n\n\
         Available providers (based on enabled features):\n{}\n\n\
         Set with: rco config set RCO_AI_PROVIDER=<provider_name>",
        provider_name,
        available.join("\n")
    )
}

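/// Returns every provider name and alias that can be selected with the current
/// feature set, including the OpenAI-compatible aliases when the `openai`
/// feature is enabled.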
#[allow(dead_code)]
pub fn available_providers() -> Vec<&'static str> {
    let mut providers = PROVIDER_REGISTRY
        .all()
        .unwrap_or_default()
        .iter()
        .flat_map(|e| std::iter::once(e.name).chain(e.aliases.iter().copied()))
        .collect::<Vec<_>>();

    #[cfg(feature = "openai")]
    {
        providers.extend_from_slice(&[
302 "deepseek",
306 "groq",
307 "openrouter",
308 "together",
309 "deepinfra",
310 "mistral",
311 "github-models",
312 "fireworks",
313 "moonshot",
314 "dashscope",
315 "perplexity",
316 "cohere",
320 "cohere-ai",
321 "ai21",
322 "ai21-labs",
323 "upstage",
324 "upstage-ai",
325 "solar",
326 "solar-pro",
327 "nebius",
331 "nebius-ai",
332 "nebius-studio",
333 "ovh",
334 "ovhcloud",
335 "ovh-ai",
336 "scaleway",
337 "scaleway-ai",
338 "friendli",
339 "friendli-ai",
340 "baseten",
341 "baseten-ai",
342 "chutes",
343 "chutes-ai",
344 "ionet",
345 "io-net",
346 "modelscope",
347 "requesty",
348 "morph",
349 "morph-labs",
350 "synthetic",
351 "nano-gpt",
352 "nanogpt",
353 "zenmux",
354 "v0",
355 "v0-vercel",
356 "iflowcn",
357 "venice",
358 "venice-ai",
359 "cortecs",
360 "cortecs-ai",
361 "kimi-coding",
362 "abacus",
363 "abacus-ai",
364 "bailing",
365 "fastrouter",
366 "inference",
367 "inference-net",
368 "submodel",
369 "zai",
370 "zai-coding",
371 "zhipu-coding",
372 "poe",
373 "poe-ai",
374 "cerebras",
375 "cerebras-ai",
376 "sambanova",
377 "sambanova-ai",
378 "novita",
379 "novita-ai",
380 "predibase",
381 "tensorops",
382 "hyperbolic",
383 "hyperbolic-ai",
384 "kluster",
385 "kluster-ai",
386 "lambda",
387 "lambda-labs",
388 "replicate",
389 "targon",
390 "corcel",
391 "cybernative",
392 "cybernative-ai",
393 "edgen",
394 "gigachat",
395 "gigachat-ai",
396 "hydra",
397 "hydra-ai",
398 "jina",
399 "jina-ai",
400 "lingyi",
401 "lingyiwanwu",
402 "monica",
403 "monica-ai",
404 "pollinations",
405 "pollinations-ai",
406 "rawechat",
407 "shuttleai",
408 "shuttle-ai",
409 "teknium",
410 "theb",
411 "theb-ai",
412 "tryleap",
413 "leap-ai",
414 "lmstudio",
418 "lm-studio",
419 "llamacpp",
420 "llama-cpp",
421 "kobold",
422 "koboldcpp",
423 "textgen",
424 "text-generation",
425 "tabby",
426 "siliconflow",
430 "silicon-flow",
431 "zhipu",
432 "zhipu-ai",
433 "bigmodel",
434 "minimax",
435 "minimax-ai",
436 "glm",
437 "chatglm",
438 "baichuan",
439 "01-ai",
440 "yi",
441 "helicone",
445 "helicone-ai",
446 "workers-ai",
447 "cloudflare-ai",
448 "cloudflare-gateway",
449 "vercel-ai",
450 "vercel-gateway",
451 "302ai",
455 "302-ai",
456 "sap-ai",
457 "sap-ai-core",
458 "aimlapi",
462 "ai-ml-api",
463 ]);
    }

    providers
}

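/// Returns a one-line, human-readable summary for a registered provider: its
/// name, aliases, and default model (if any).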
#[allow(dead_code)]
pub fn provider_info(provider: &str) -> Option<String> {
    PROVIDER_REGISTRY.get(provider).map(|e| {
        let aliases = if e.aliases.is_empty() {
            String::new()
        } else {
            format!(" (aliases: {})", e.aliases.join(", "))
        };
        let model = e
            .default_model
            .map(|m| format!(", default model: {}", m))
            .unwrap_or_default();
        format!("{}{}{}", e.name, aliases, model)
    })
}

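/// Creates a provider for a specific configured account.
///
/// Credentials are resolved from the account's auth method (stored API key,
/// OAuth access token, environment variable, or bearer token); when none are
/// found, the provider's plain configuration-based constructor is used instead.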
#[allow(dead_code)]
pub fn create_provider_for_account(
    account: &AccountConfig,
    config: &Config,
) -> Result<Box<dyn AIProvider>> {
    use crate::auth::token_storage;
    use crate::config::secure_storage;

    let provider = account.provider.to_lowercase();

    let credentials = match &account.auth {
        crate::config::accounts::AuthMethod::ApiKey { key_id } => {
            token_storage::get_api_key_for_account(key_id)?
                .or_else(|| secure_storage::get_secret(key_id).ok().flatten())
        }
        crate::config::accounts::AuthMethod::OAuth {
            provider: _oauth_provider,
            account_id,
        } => {
            token_storage::get_tokens_for_account(account_id)?.map(|t| t.access_token)
        }
        crate::config::accounts::AuthMethod::EnvVar { name } => std::env::var(name).ok(),
        crate::config::accounts::AuthMethod::Bearer { token_id } => {
            token_storage::get_bearer_token_for_account(token_id)?
                .or_else(|| secure_storage::get_secret(token_id).ok().flatten())
        }
    };

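    // Providers that take an explicit key use it when one was resolved above.
    // The Bedrock, Vertex, MLX, and Flowise arms pass an empty placeholder key
    // and rely on account/config data alone.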
    match provider.as_str() {
        #[cfg(feature = "openai")]
        "openai" | "codex" => {
            if let Some(key) = credentials.as_ref() {
                Ok(Box::new(openai::OpenAIProvider::from_account(
                    account, key, config,
                )?))
            } else {
                Ok(Box::new(openai::OpenAIProvider::new(config)?))
            }
        }
        #[cfg(feature = "anthropic")]
        "anthropic" | "claude" | "claude-code" => {
            if let Some(key) = credentials.as_ref() {
                Ok(Box::new(anthropic::AnthropicProvider::from_account(
                    account, key, config,
                )?))
            } else {
                Ok(Box::new(anthropic::AnthropicProvider::new(config)?))
            }
        }
        #[cfg(feature = "ollama")]
        "ollama" => {
            if let Some(key) = credentials.as_ref() {
                Ok(Box::new(ollama::OllamaProvider::from_account(
                    account, key, config,
                )?))
            } else {
                Ok(Box::new(ollama::OllamaProvider::new(config)?))
            }
        }
        #[cfg(feature = "gemini")]
        "gemini" => {
            if let Some(key) = credentials.as_ref() {
                Ok(Box::new(gemini::GeminiProvider::from_account(
                    account, key, config,
                )?))
            } else {
                Ok(Box::new(gemini::GeminiProvider::new(config)?))
            }
        }
        #[cfg(feature = "azure")]
        "azure" | "azure-openai" => {
            if let Some(key) = credentials.as_ref() {
                Ok(Box::new(azure::AzureProvider::from_account(
                    account, key, config,
                )?))
            } else {
                Ok(Box::new(azure::AzureProvider::new(config)?))
            }
        }
        #[cfg(feature = "perplexity")]
        "perplexity" => {
            if let Some(key) = credentials.as_ref() {
                Ok(Box::new(perplexity::PerplexityProvider::from_account(
                    account, key, config,
                )?))
            } else {
                Ok(Box::new(perplexity::PerplexityProvider::new(config)?))
            }
        }
        #[cfg(feature = "xai")]
        "xai" | "grok" | "x-ai" => {
            if let Some(key) = credentials.as_ref() {
                Ok(Box::new(xai::XAIProvider::from_account(
                    account, key, config,
                )?))
            } else {
                Ok(Box::new(xai::XAIProvider::new(config)?))
            }
        }
        #[cfg(feature = "huggingface")]
        "huggingface" | "hf" => {
            if let Some(key) = credentials.as_ref() {
                Ok(Box::new(huggingface::HuggingFaceProvider::from_account(
                    account, key, config,
                )?))
            } else {
                Ok(Box::new(huggingface::HuggingFaceProvider::new(config)?))
            }
        }
        #[cfg(feature = "bedrock")]
        "bedrock" | "aws-bedrock" | "amazon-bedrock" => Ok(Box::new(
            bedrock::BedrockProvider::from_account(account, "", config)?,
        )),
        #[cfg(feature = "vertex")]
        "vertex" | "vertex-ai" | "google-vertex" | "gcp-vertex" => Ok(Box::new(
            vertex::VertexProvider::from_account(account, "", config)?,
        )),
        #[cfg(feature = "mlx")]
        "mlx" | "mlx-lm" | "apple-mlx" => {
            if let Some(_key) = credentials.as_ref() {
                Ok(Box::new(mlx::MlxProvider::from_account(
                    account, "", config,
                )?))
            } else {
                Ok(Box::new(mlx::MlxProvider::new(config)?))
            }
        }
        #[cfg(feature = "nvidia")]
        "nvidia" | "nvidia-nim" | "nim" | "nvidia-ai" => {
            if let Some(key) = credentials.as_ref() {
                Ok(Box::new(nvidia::NvidiaProvider::from_account(
                    account, key, config,
                )?))
            } else {
                Ok(Box::new(nvidia::NvidiaProvider::new(config)?))
            }
        }
        #[cfg(feature = "flowise")]
        "flowise" | "flowise-ai" => {
            if let Some(_key) = credentials.as_ref() {
                Ok(Box::new(flowise::FlowiseProvider::from_account(
                    account, "", config,
                )?))
            } else {
                Ok(Box::new(flowise::FlowiseProvider::new(config)?))
            }
        }
        _ => {
            anyhow::bail!(
                "Unsupported AI provider for account: {}\n\n\
                 Account provider: {}\n\
                 Supported providers: openai, anthropic, ollama, gemini, azure, perplexity, xai, huggingface, bedrock, vertex, mlx, nvidia, flowise",
                account.alias,
                provider
            );
        }
    }
}