Skip to main content

rusty_commit/providers/
mod.rs

1// AI Provider modules - conditionally compiled based on features
2#[cfg(feature = "anthropic")]
3pub mod anthropic;
4#[cfg(feature = "azure")]
5pub mod azure;
6#[cfg(feature = "bedrock")]
7pub mod bedrock;
8#[cfg(feature = "flowise")]
9pub mod flowise;
10#[cfg(feature = "gemini")]
11pub mod gemini;
12#[cfg(feature = "huggingface")]
13pub mod huggingface;
14#[cfg(feature = "mlx")]
15pub mod mlx;
16#[cfg(feature = "nvidia")]
17pub mod nvidia;
18#[cfg(feature = "ollama")]
19pub mod ollama;
20#[cfg(feature = "openai")]
21pub mod openai;
22#[cfg(feature = "perplexity")]
23pub mod perplexity;
24#[cfg(feature = "vertex")]
25pub mod vertex;
26#[cfg(feature = "xai")]
27pub mod xai;
28
29// Provider registry for extensible provider management
30pub mod registry;
31
32// Prompt building utilities
33pub mod prompt;
34
35use crate::config::accounts::AccountConfig;
36use crate::config::Config;
37use anyhow::{Context, Result};
38use async_trait::async_trait;
39use once_cell::sync::Lazy;
40
41#[async_trait]
42pub trait AIProvider: Send + Sync {
43    async fn generate_commit_message(
44        &self,
45        diff: &str,
46        context: Option<&str>,
47        full_gitmoji: bool,
48        config: &Config,
49    ) -> Result<String>;
50
51    /// Generate multiple commit message variations
52    async fn generate_commit_messages(
53        &self,
54        diff: &str,
55        context: Option<&str>,
56        full_gitmoji: bool,
57        config: &Config,
58        count: u8,
59    ) -> Result<Vec<String>> {
60        use futures::stream::StreamExt;
61
62        if count <= 1 {
63            // For single message, no parallelism needed
64            match self
65                .generate_commit_message(diff, context, full_gitmoji, config)
66                .await
67            {
68                Ok(msg) => Ok(vec![msg]),
69                Err(e) => {
70                    tracing::warn!("Failed to generate message: {}", e);
71                    Ok(vec![])
72                }
73            }
74        } else {
75            // Generate messages in parallel using FuturesUnordered
76            let futures = (0..count)
77                .map(|_| self.generate_commit_message(diff, context, full_gitmoji, config));
78            let mut stream = futures::stream::FuturesUnordered::from_iter(futures);
79
80            let mut messages = Vec::with_capacity(count as usize);
81            while let Some(result) = stream.next().await {
82                match result {
83                    Ok(msg) => messages.push(msg),
84                    Err(e) => tracing::warn!("Failed to generate message: {}", e),
85                }
86            }
87            Ok(messages)
88        }
89    }
90
91    /// Generate a PR description from commits
92    #[cfg(any(feature = "openai", feature = "xai"))]
93    async fn generate_pr_description(
94        &self,
95        commits: &[String],
96        diff: &str,
97        config: &Config,
98    ) -> Result<String> {
99        let commits_text = commits.join("\n");
100        let prompt = format!(
101            "Generate a professional pull request description based on the following commits:\n\n{}\n\nDiff:\n{}\n\nFormat the output as:\n## Summary\n## Changes\n## Testing\n## Breaking Changes\n\nKeep it concise and informative.",
102            commits_text, diff
103        );
104
105        let messages = vec![
106            async_openai::types::chat::ChatCompletionRequestSystemMessage::from(
107                "You are an expert at writing pull request descriptions.",
108            )
109            .into(),
110            async_openai::types::chat::ChatCompletionRequestUserMessage::from(prompt).into(),
111        ];
112
113        let request = async_openai::types::chat::CreateChatCompletionRequestArgs::default()
114            .model(
115                config
116                    .model
117                    .clone()
118                    .unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
119            )
120            .messages(messages)
121            .temperature(0.7)
122            .max_tokens(config.tokens_max_output.unwrap_or(1000) as u16)
123            .build()?;
124
125        // Create a new client for this request
126        let api_key = config
127            .api_key
128            .as_ref()
129            .context("API key not configured. Run: rco config set RCO_API_KEY=<your_key>")?;
130        let api_url = config
131            .api_url
132            .as_deref()
133            .unwrap_or("https://api.openai.com/v1");
134
135        let openai_config = async_openai::config::OpenAIConfig::new()
136            .with_api_key(api_key)
137            .with_api_base(api_url);
138
139        let client = async_openai::Client::with_config(openai_config);
140
141        let response = client.chat().create(request).await?;
142
143        let message = response
144            .choices
145            .first()
146            .and_then(|choice| choice.message.content.as_ref())
147            .context("AI returned an empty response")?
148            .trim()
149            .to_string();
150
151        Ok(message)
152    }
153
154    /// Generate a PR description - stub when OpenAI/xAI features are disabled
155    #[cfg(not(any(feature = "openai", feature = "xai")))]
156    async fn generate_pr_description(
157        &self,
158        _commits: &[String],
159        _diff: &str,
160        _config: &Config,
161    ) -> Result<String> {
162        anyhow::bail!(
163            "PR description generation requires the 'openai' or 'xai' feature to be enabled"
164        );
165    }
166}
167
168/// Global provider registry - automatically populated based on enabled features
169pub static PROVIDER_REGISTRY: Lazy<registry::ProviderRegistry> = Lazy::new(|| {
170    let reg = registry::ProviderRegistry::new();
171
172    // Register OpenAI-compatible providers (require openai feature)
173    #[cfg(feature = "openai")]
174    {
175        let _ = reg.register(Box::new(openai::OpenAICompatibleProvider::new()));
176    }
177
178    // Register dedicated providers
179    #[cfg(feature = "anthropic")]
180    {
181        let _ = reg.register(Box::new(anthropic::AnthropicProviderBuilder));
182    }
183
184    #[cfg(feature = "ollama")]
185    {
186        let _ = reg.register(Box::new(ollama::OllamaProviderBuilder));
187    }
188
189    #[cfg(feature = "gemini")]
190    {
191        let _ = reg.register(Box::new(gemini::GeminiProviderBuilder));
192    }
193
194    #[cfg(feature = "azure")]
195    {
196        let _ = reg.register(Box::new(azure::AzureProviderBuilder));
197    }
198
199    #[cfg(feature = "perplexity")]
200    {
201        let _ = reg.register(Box::new(perplexity::PerplexityProviderBuilder));
202    }
203
204    #[cfg(feature = "xai")]
205    {
206        let _ = reg.register(Box::new(xai::XAIProviderBuilder));
207    }
208
209    #[cfg(feature = "huggingface")]
210    {
211        let _ = reg.register(Box::new(huggingface::HuggingFaceProviderBuilder));
212    }
213
214    #[cfg(feature = "bedrock")]
215    {
216        let _ = reg.register(Box::new(bedrock::BedrockProviderBuilder));
217    }
218
219    #[cfg(feature = "vertex")]
220    {
221        let _ = reg.register(Box::new(vertex::VertexProviderBuilder));
222    }
223
224    #[cfg(feature = "mlx")]
225    {
226        let _ = reg.register(Box::new(mlx::MlxProviderBuilder));
227    }
228
229    #[cfg(feature = "nvidia")]
230    {
231        let _ = reg.register(Box::new(nvidia::NvidiaProviderBuilder));
232    }
233
234    #[cfg(feature = "flowise")]
235    {
236        let _ = reg.register(Box::new(flowise::FlowiseProviderBuilder));
237    }
238
239    reg
240});
241
242/// Create an AI provider instance from configuration
243pub fn create_provider(config: &Config) -> Result<Box<dyn AIProvider>> {
244    let provider_name = config.ai_provider.as_deref().unwrap_or("openai");
245
246    // Try to create from registry
247    if let Some(provider) = PROVIDER_REGISTRY.create(provider_name, config)? {
248        return Ok(provider);
249    }
250
251    // Provider not found - build error message with available providers
252    let available: Vec<String> = PROVIDER_REGISTRY
253        .all()
254        .unwrap_or_default()
255        .iter()
256        .map(|e| {
257            let aliases = if e.aliases.is_empty() {
258                String::new()
259            } else {
260                format!(" ({})", e.aliases.join(", "))
261            };
262            format!("- {}{}", e.name, aliases)
263        })
264        .chain(std::iter::once(format!(
265            "- {} OpenAI-compatible providers (deepseek, groq, openrouter, etc.)",
266            PROVIDER_REGISTRY
267                .by_category(registry::ProviderCategory::OpenAICompatible)
268                .map_or(0, |v| v.len())
269        )))
270        .filter(|s| !s.contains("0 OpenAI-compatible"))
271        .collect();
272
273    if available.is_empty() {
274        anyhow::bail!(
275            "No AI provider features enabled. Please enable at least one provider feature:\n\
276             --features openai,anthropic,ollama,gemini,azure,perplexity,xai,huggingface,bedrock,vertex"
277        );
278    }
279
280    anyhow::bail!(
281        "Unsupported or disabled AI provider: {}\n\n\
282         Available providers (based on enabled features):\n{}\n\n\
283         Set with: rco config set RCO_AI_PROVIDER=<provider_name>",
284        provider_name,
285        available.join("\n")
286    )
287}
288
289#[allow(dead_code)]
290/// Get list of all available provider names
291pub fn available_providers() -> Vec<&'static str> {
292    let mut providers = PROVIDER_REGISTRY
293        .all()
294        .unwrap_or_default()
295        .iter()
296        .flat_map(|e| std::iter::once(e.name).chain(e.aliases.iter().copied()))
297        .collect::<Vec<_>>();
298
299    #[cfg(feature = "openai")]
300    {
301        providers.extend_from_slice(&[
302            // ═════════════════════════════════════════════════════════════════
303            // Major Cloud Providers
304            // ═════════════════════════════════════════════════════════════════
305            "deepseek",
306            "groq",
307            "openrouter",
308            "together",
309            "deepinfra",
310            "mistral",
311            "github-models",
312            "fireworks",
313            "moonshot",
314            "dashscope",
315            "perplexity",
316            // ═════════════════════════════════════════════════════════════════
317            // Enterprise & Specialized
318            // ═════════════════════════════════════════════════════════════════
319            "cohere",
320            "cohere-ai",
321            "ai21",
322            "ai21-labs",
323            "upstage",
324            "upstage-ai",
325            "solar",
326            "solar-pro",
327            // ═════════════════════════════════════════════════════════════════
328            // GPU Cloud & Inference Providers
329            // ═════════════════════════════════════════════════════════════════
330            "nebius",
331            "nebius-ai",
332            "nebius-studio",
333            "ovh",
334            "ovhcloud",
335            "ovh-ai",
336            "scaleway",
337            "scaleway-ai",
338            "friendli",
339            "friendli-ai",
340            "baseten",
341            "baseten-ai",
342            "chutes",
343            "chutes-ai",
344            "ionet",
345            "io-net",
346            "modelscope",
347            "requesty",
348            "morph",
349            "morph-labs",
350            "synthetic",
351            "nano-gpt",
352            "nanogpt",
353            "zenmux",
354            "v0",
355            "v0-vercel",
356            "iflowcn",
357            "venice",
358            "venice-ai",
359            "cortecs",
360            "cortecs-ai",
361            "kimi-coding",
362            "abacus",
363            "abacus-ai",
364            "bailing",
365            "fastrouter",
366            "inference",
367            "inference-net",
368            "submodel",
369            "zai",
370            "zai-coding",
371            "zhipu-coding",
372            "poe",
373            "poe-ai",
374            "cerebras",
375            "cerebras-ai",
376            "sambanova",
377            "sambanova-ai",
378            "novita",
379            "novita-ai",
380            "predibase",
381            "tensorops",
382            "hyperbolic",
383            "hyperbolic-ai",
384            "kluster",
385            "kluster-ai",
386            "lambda",
387            "lambda-labs",
388            "replicate",
389            "targon",
390            "corcel",
391            "cybernative",
392            "cybernative-ai",
393            "edgen",
394            "gigachat",
395            "gigachat-ai",
396            "hydra",
397            "hydra-ai",
398            "jina",
399            "jina-ai",
400            "lingyi",
401            "lingyiwanwu",
402            "monica",
403            "monica-ai",
404            "pollinations",
405            "pollinations-ai",
406            "rawechat",
407            "shuttleai",
408            "shuttle-ai",
409            "teknium",
410            "theb",
411            "theb-ai",
412            "tryleap",
413            "leap-ai",
414            // ═════════════════════════════════════════════════════════════════
415            // Local/Self-hosted Providers
416            // ═════════════════════════════════════════════════════════════════
417            "lmstudio",
418            "lm-studio",
419            "llamacpp",
420            "llama-cpp",
421            "kobold",
422            "koboldcpp",
423            "textgen",
424            "text-generation",
425            "tabby",
426            // ═════════════════════════════════════════════════════════════════
427            // China-based Providers
428            // ═════════════════════════════════════════════════════════════════
429            "siliconflow",
430            "silicon-flow",
431            "zhipu",
432            "zhipu-ai",
433            "bigmodel",
434            "minimax",
435            "minimax-ai",
436            "glm",
437            "chatglm",
438            "baichuan",
439            "01-ai",
440            "yi",
441            // ═════════════════════════════════════════════════════════════════
442            // AI Gateway & Proxy Services
443            // ═════════════════════════════════════════════════════════════════
444            "helicone",
445            "helicone-ai",
446            "workers-ai",
447            "cloudflare-ai",
448            "cloudflare-gateway",
449            "vercel-ai",
450            "vercel-gateway",
451            // ═════════════════════════════════════════════════════════════════
452            // Specialized Providers
453            // ═════════════════════════════════════════════════════════════════
454            "302ai",
455            "302-ai",
456            "sap-ai",
457            "sap-ai-core",
458            // ═════════════════════════════════════════════════════════════════
459            // Additional Providers from OpenCommit
460            // ═════════════════════════════════════════════════════════════════
461            "aimlapi",
462            "ai-ml-api",
463        ]);
464    }
465
466    providers
467}
468
469/// Get provider info for display
470#[allow(dead_code)]
471pub fn provider_info(provider: &str) -> Option<String> {
472    PROVIDER_REGISTRY.get(provider).map(|e| {
473        let aliases = if e.aliases.is_empty() {
474            String::new()
475        } else {
476            format!(" (aliases: {})", e.aliases.join(", "))
477        };
478        let model = e
479            .default_model
480            .map(|m| format!(", default model: {}", m))
481            .unwrap_or_default();
482        format!("{}{}{}", e.name, aliases, model)
483    })
484}
485
486
487
488/// Create an AI provider from an account configuration
489#[allow(dead_code)]
490pub fn create_provider_for_account(
491    account: &AccountConfig,
492    config: &Config,
493) -> Result<Box<dyn AIProvider>> {
494    use crate::auth::token_storage;
495    use crate::config::secure_storage;
496
497    let provider = account.provider.to_lowercase();
498
499    // Extract credentials from the account's auth method
500    let credentials = match &account.auth {
501        crate::config::accounts::AuthMethod::ApiKey { key_id } => {
502            // Get API key from secure storage using the account's key_id
503            token_storage::get_api_key_for_account(key_id)?
504                .or_else(|| secure_storage::get_secret(key_id).ok().flatten())
505        }
506        crate::config::accounts::AuthMethod::OAuth {
507            provider: _oauth_provider,
508            account_id,
509        } => {
510            // Get OAuth access token from secure storage
511            token_storage::get_tokens_for_account(account_id)?.map(|t| t.access_token)
512        }
513        crate::config::accounts::AuthMethod::EnvVar { name } => std::env::var(name).ok(),
514        crate::config::accounts::AuthMethod::Bearer { token_id } => {
515            // Get bearer token from secure storage
516            token_storage::get_bearer_token_for_account(token_id)?
517                .or_else(|| secure_storage::get_secret(token_id).ok().flatten())
518        }
519    };
520
521    match provider.as_str() {
522        #[cfg(feature = "openai")]
523        "openai" | "codex" => {
524            if let Some(key) = credentials.as_ref() {
525                Ok(Box::new(openai::OpenAIProvider::from_account(
526                    account, key, config,
527                )?))
528            } else {
529                Ok(Box::new(openai::OpenAIProvider::new(config)?))
530            }
531        }
532        #[cfg(feature = "anthropic")]
533        "anthropic" | "claude" | "claude-code" => {
534            if let Some(key) = credentials.as_ref() {
535                Ok(Box::new(anthropic::AnthropicProvider::from_account(
536                    account, key, config,
537                )?))
538            } else {
539                Ok(Box::new(anthropic::AnthropicProvider::new(config)?))
540            }
541        }
542        #[cfg(feature = "ollama")]
543        "ollama" => {
544            if let Some(key) = credentials.as_ref() {
545                Ok(Box::new(ollama::OllamaProvider::from_account(
546                    account, key, config,
547                )?))
548            } else {
549                Ok(Box::new(ollama::OllamaProvider::new(config)?))
550            }
551        }
552        #[cfg(feature = "gemini")]
553        "gemini" => {
554            if let Some(key) = credentials.as_ref() {
555                Ok(Box::new(gemini::GeminiProvider::from_account(
556                    account, key, config,
557                )?))
558            } else {
559                Ok(Box::new(gemini::GeminiProvider::new(config)?))
560            }
561        }
562        #[cfg(feature = "azure")]
563        "azure" | "azure-openai" => {
564            if let Some(key) = credentials.as_ref() {
565                Ok(Box::new(azure::AzureProvider::from_account(
566                    account, key, config,
567                )?))
568            } else {
569                Ok(Box::new(azure::AzureProvider::new(config)?))
570            }
571        }
572        #[cfg(feature = "perplexity")]
573        "perplexity" => {
574            if let Some(key) = credentials.as_ref() {
575                Ok(Box::new(perplexity::PerplexityProvider::from_account(
576                    account, key, config,
577                )?))
578            } else {
579                Ok(Box::new(perplexity::PerplexityProvider::new(config)?))
580            }
581        }
582        #[cfg(feature = "xai")]
583        "xai" | "grok" | "x-ai" => {
584            if let Some(key) = credentials.as_ref() {
585                Ok(Box::new(xai::XAIProvider::from_account(
586                    account, key, config,
587                )?))
588            } else {
589                Ok(Box::new(xai::XAIProvider::new(config)?))
590            }
591        }
592        #[cfg(feature = "huggingface")]
593        "huggingface" | "hf" => {
594            if let Some(key) = credentials.as_ref() {
595                Ok(Box::new(huggingface::HuggingFaceProvider::from_account(
596                    account, key, config,
597                )?))
598            } else {
599                Ok(Box::new(huggingface::HuggingFaceProvider::new(config)?))
600            }
601        }
602        #[cfg(feature = "bedrock")]
603        "bedrock" | "aws-bedrock" | "amazon-bedrock" => Ok(Box::new(
604            bedrock::BedrockProvider::from_account(account, "", config)?,
605        )),
606        #[cfg(feature = "vertex")]
607        "vertex" | "vertex-ai" | "google-vertex" | "gcp-vertex" => Ok(Box::new(
608            vertex::VertexProvider::from_account(account, "", config)?,
609        )),
610        #[cfg(feature = "mlx")]
611        "mlx" | "mlx-lm" | "apple-mlx" => {
612            if let Some(_key) = credentials.as_ref() {
613                Ok(Box::new(mlx::MlxProvider::from_account(
614                    account, "", config,
615                )?))
616            } else {
617                Ok(Box::new(mlx::MlxProvider::new(config)?))
618            }
619        }
620        #[cfg(feature = "nvidia")]
621        "nvidia" | "nvidia-nim" | "nim" | "nvidia-ai" => {
622            if let Some(key) = credentials.as_ref() {
623                Ok(Box::new(nvidia::NvidiaProvider::from_account(
624                    account, key, config,
625                )?))
626            } else {
627                Ok(Box::new(nvidia::NvidiaProvider::new(config)?))
628            }
629        }
630        #[cfg(feature = "flowise")]
631        "flowise" | "flowise-ai" => {
632            if let Some(_key) = credentials.as_ref() {
633                Ok(Box::new(flowise::FlowiseProvider::from_account(
634                    account, "", config,
635                )?))
636            } else {
637                Ok(Box::new(flowise::FlowiseProvider::new(config)?))
638            }
639        }
640        _ => {
641            anyhow::bail!(
642                "Unsupported AI provider for account: {}\n\n\
643                 Account provider: {}\n\
644                 Supported providers: openai, anthropic, ollama, gemini, azure, perplexity, xai, huggingface, bedrock, vertex",
645                account.alias,
646                provider
647            );
648        }
649    }
650}