Skip to main content

rusty_commit/providers/
mod.rs

1// AI Provider modules - conditionally compiled based on features
2#[cfg(feature = "anthropic")]
3pub mod anthropic;
4#[cfg(feature = "azure")]
5pub mod azure;
6#[cfg(feature = "bedrock")]
7pub mod bedrock;
8#[cfg(feature = "gemini")]
9pub mod gemini;
10#[cfg(feature = "huggingface")]
11pub mod huggingface;
12#[cfg(feature = "ollama")]
13pub mod ollama;
14#[cfg(feature = "openai")]
15pub mod openai;
16#[cfg(feature = "perplexity")]
17pub mod perplexity;
18#[cfg(feature = "vertex")]
19pub mod vertex;
20#[cfg(feature = "xai")]
21pub mod xai;
22
23// Provider registry for extensible provider management
24pub mod registry;
25
26use crate::config::accounts::AccountConfig;
27use crate::config::Config;
28use anyhow::{Context, Result};
29use async_trait::async_trait;
30use once_cell::sync::Lazy;
31
32#[async_trait]
33pub trait AIProvider: Send + Sync {
34    async fn generate_commit_message(
35        &self,
36        diff: &str,
37        context: Option<&str>,
38        full_gitmoji: bool,
39        config: &Config,
40    ) -> Result<String>;
41
42    /// Generate multiple commit message variations
43    async fn generate_commit_messages(
44        &self,
45        diff: &str,
46        context: Option<&str>,
47        full_gitmoji: bool,
48        config: &Config,
49        count: u8,
50    ) -> Result<Vec<String>> {
51        use futures::stream::StreamExt;
52
53        if count <= 1 {
54            // For single message, no parallelism needed
55            match self
56                .generate_commit_message(diff, context, full_gitmoji, config)
57                .await
58            {
59                Ok(msg) => Ok(vec![msg]),
60                Err(e) => {
61                    tracing::warn!("Failed to generate message: {}", e);
62                    Ok(vec![])
63                }
64            }
65        } else {
66            // Generate messages in parallel using FuturesUnordered
67            let futures = (0..count)
68                .map(|_| self.generate_commit_message(diff, context, full_gitmoji, config));
69            let mut stream = futures::stream::FuturesUnordered::from_iter(futures);
70
71            let mut messages = Vec::with_capacity(count as usize);
72            while let Some(result) = stream.next().await {
73                match result {
74                    Ok(msg) => messages.push(msg),
75                    Err(e) => tracing::warn!("Failed to generate message: {}", e),
76                }
77            }
78            Ok(messages)
79        }
80    }
81
82    /// Generate a PR description from commits
83    #[cfg(any(feature = "openai", feature = "xai"))]
84    async fn generate_pr_description(
85        &self,
86        commits: &[String],
87        diff: &str,
88        config: &Config,
89    ) -> Result<String> {
90        let commits_text = commits.join("\n");
91        let prompt = format!(
92            "Generate a professional pull request description based on the following commits:\n\n{}\n\nDiff:\n{}\n\nFormat the output as:\n## Summary\n## Changes\n## Testing\n## Breaking Changes\n\nKeep it concise and informative.",
93            commits_text, diff
94        );
95
96        let messages = vec![
97            async_openai::types::chat::ChatCompletionRequestSystemMessage::from(
98                "You are an expert at writing pull request descriptions.",
99            )
100            .into(),
101            async_openai::types::chat::ChatCompletionRequestUserMessage::from(prompt).into(),
102        ];
103
104        let request = async_openai::types::chat::CreateChatCompletionRequestArgs::default()
105            .model(
106                config
107                    .model
108                    .clone()
109                    .unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
110            )
111            .messages(messages)
112            .temperature(0.7)
113            .max_tokens(config.tokens_max_output.unwrap_or(1000) as u16)
114            .build()?;
115
116        // Create a new client for this request
117        let api_key = config
118            .api_key
119            .as_ref()
120            .context("API key not configured. Run: rco config set RCO_API_KEY=<your_key>")?;
121        let api_url = config
122            .api_url
123            .as_deref()
124            .unwrap_or("https://api.openai.com/v1");
125
126        let openai_config = async_openai::config::OpenAIConfig::new()
127            .with_api_key(api_key)
128            .with_api_base(api_url);
129
130        let client = async_openai::Client::with_config(openai_config);
131
132        let response = client.chat().create(request).await?;
133
134        let message = response
135            .choices
136            .first()
137            .and_then(|choice| choice.message.content.as_ref())
138            .context("AI returned an empty response")?
139            .trim()
140            .to_string();
141
142        Ok(message)
143    }
144
145    /// Generate a PR description - stub when OpenAI/xAI features are disabled
146    #[cfg(not(any(feature = "openai", feature = "xai")))]
147    async fn generate_pr_description(
148        &self,
149        _commits: &[String],
150        _diff: &str,
151        _config: &Config,
152    ) -> Result<String> {
153        anyhow::bail!(
154            "PR description generation requires the 'openai' or 'xai' feature to be enabled"
155        );
156    }
157}
158
159/// Global provider registry - automatically populated based on enabled features
160pub static PROVIDER_REGISTRY: Lazy<registry::ProviderRegistry> = Lazy::new(|| {
161    let reg = registry::ProviderRegistry::new();
162
163    // Register OpenAI-compatible providers (require openai feature)
164    #[cfg(feature = "openai")]
165    {
166        let _ = reg.register(Box::new(openai::OpenAICompatibleProvider::new()));
167    }
168
169    // Register dedicated providers
170    #[cfg(feature = "anthropic")]
171    {
172        let _ = reg.register(Box::new(anthropic::AnthropicProviderBuilder));
173    }
174
175    #[cfg(feature = "ollama")]
176    {
177        let _ = reg.register(Box::new(ollama::OllamaProviderBuilder));
178    }
179
180    #[cfg(feature = "gemini")]
181    {
182        let _ = reg.register(Box::new(gemini::GeminiProviderBuilder));
183    }
184
185    #[cfg(feature = "azure")]
186    {
187        let _ = reg.register(Box::new(azure::AzureProviderBuilder));
188    }
189
190    #[cfg(feature = "perplexity")]
191    {
192        let _ = reg.register(Box::new(perplexity::PerplexityProviderBuilder));
193    }
194
195    #[cfg(feature = "xai")]
196    {
197        let _ = reg.register(Box::new(xai::XAIProviderBuilder));
198    }
199
200    #[cfg(feature = "huggingface")]
201    {
202        let _ = reg.register(Box::new(huggingface::HuggingFaceProviderBuilder));
203    }
204
205    #[cfg(feature = "bedrock")]
206    {
207        let _ = reg.register(Box::new(bedrock::BedrockProviderBuilder));
208    }
209
210    #[cfg(feature = "vertex")]
211    {
212        let _ = reg.register(Box::new(vertex::VertexProviderBuilder));
213    }
214
215    reg
216});
217
218/// Create an AI provider instance from configuration
219pub fn create_provider(config: &Config) -> Result<Box<dyn AIProvider>> {
220    let provider_name = config.ai_provider.as_deref().unwrap_or("openai");
221
222    // Try to create from registry
223    if let Some(provider) = PROVIDER_REGISTRY.create(provider_name, config)? {
224        return Ok(provider);
225    }
226
227    // Provider not found - build error message with available providers
228    let available: Vec<String> = PROVIDER_REGISTRY
229        .all()
230        .unwrap_or_default()
231        .iter()
232        .map(|e| {
233            let aliases = if e.aliases.is_empty() {
234                String::new()
235            } else {
236                format!(" ({})", e.aliases.join(", "))
237            };
238            format!("- {}{}", e.name, aliases)
239        })
240        .chain(std::iter::once(format!(
241            "- {} OpenAI-compatible providers (deepseek, groq, openrouter, etc.)",
242            PROVIDER_REGISTRY
243                .by_category(registry::ProviderCategory::OpenAICompatible)
244                .map_or(0, |v| v.len())
245        )))
246        .filter(|s| !s.contains("0 OpenAI-compatible"))
247        .collect();
248
249    if available.is_empty() {
250        anyhow::bail!(
251            "No AI provider features enabled. Please enable at least one provider feature:\n\
252             --features openai,anthropic,ollama,gemini,azure,perplexity,xai,huggingface,bedrock,vertex"
253        );
254    }
255
256    anyhow::bail!(
257        "Unsupported or disabled AI provider: {}\n\n\
258         Available providers (based on enabled features):\n{}\n\n\
259         Set with: rco config set RCO_AI_PROVIDER=<provider_name>",
260        provider_name,
261        available.join("\n")
262    )
263}
264
265#[allow(dead_code)]
266/// Get list of all available provider names
267pub fn available_providers() -> Vec<&'static str> {
268    let mut providers = PROVIDER_REGISTRY
269        .all()
270        .unwrap_or_default()
271        .iter()
272        .flat_map(|e| std::iter::once(e.name).chain(e.aliases.iter().copied()))
273        .collect::<Vec<_>>();
274
275    #[cfg(feature = "openai")]
276    {
277        providers.extend_from_slice(&[
278            "deepseek",
279            "groq",
280            "openrouter",
281            "together",
282            "deepinfra",
283            "mistral",
284            "github-models",
285            "fireworks",
286            "moonshot",
287            "dashscope",
288        ]);
289    }
290
291    providers
292}
293
294/// Get provider info for display
295#[allow(dead_code)]
296pub fn provider_info(provider: &str) -> Option<String> {
297    PROVIDER_REGISTRY.get(provider).map(|e| {
298        let aliases = if e.aliases.is_empty() {
299            String::new()
300        } else {
301            format!(" (aliases: {})", e.aliases.join(", "))
302        };
303        let model = e
304            .default_model
305            .map(|m| format!(", default model: {}", m))
306            .unwrap_or_default();
307        format!("{}{}{}", e.name, aliases, model)
308    })
309}
310
311/// Split the prompt into system and user parts for providers that support it
312pub fn split_prompt(
313    diff: &str,
314    context: Option<&str>,
315    config: &Config,
316    full_gitmoji: bool,
317) -> (String, String) {
318    let system_prompt = build_system_prompt(config, full_gitmoji);
319    let user_prompt = build_user_prompt(diff, context, full_gitmoji);
320    (system_prompt, user_prompt)
321}
322
323/// Build the system prompt part (role definition, rules)
324fn build_system_prompt(config: &Config, full_gitmoji: bool) -> String {
325    let mut prompt = String::new();
326
327    prompt.push_str("You are an expert at writing clear, concise git commit messages.\n\n");
328
329    // Core constraints
330    prompt.push_str("CONSTRAINTS:\n");
331    prompt.push_str("- Return ONLY the commit message, with no additional explanation, markdown formatting, or code blocks\n");
332    prompt.push_str(
333        "- Do not include any reasoning, thinking, analysis, or <thinking> tags in your response\n",
334    );
335    prompt.push_str(
336        "- If you cannot generate a meaningful commit message, return \"chore: update\"\n\n",
337    );
338
339    // Add style guidance from history if enabled
340    if config.learn_from_history.unwrap_or(false) {
341        if let Some(style_guidance) = get_style_guidance(config) {
342            prompt.push_str("REPO STYLE (learned from commit history):\n");
343            prompt.push_str(&style_guidance);
344            prompt.push('\n');
345        }
346    }
347
348    // Add locale if specified
349    if let Some(locale) = &config.language {
350        prompt.push_str(&format!(
351            "- Generate the commit message in {} language\n",
352            locale
353        ));
354    }
355
356    // Add commit type preference
357    let commit_type = config.commit_type.as_deref().unwrap_or("conventional");
358    match commit_type {
359        "conventional" => {
360            prompt.push_str("- Use conventional commit format: <type>(<scope>): <description>\n");
361            prompt.push_str(
362                "- Types: feat, fix, docs, style, refactor, perf, test, build, ci, chore\n",
363            );
364            if config.omit_scope.unwrap_or(false) {
365                prompt.push_str("- Omit the scope, use format: <type>: <description>\n");
366            }
367        }
368        "gitmoji" => {
369            if full_gitmoji {
370                prompt.push_str("- Use GitMoji format with full emoji specification from https://gitmoji.dev/\n");
371            } else {
372                prompt.push_str("- Use GitMoji format: <emoji> <type>: <description>\n");
373                prompt.push_str(
374                    "- Emojis: 🐛(fix), ✨(feat), 📝(docs), 🚀(deploy), ✅(test), ♻️(refactor)\n",
375                );
376            }
377        }
378        _ => {}
379    }
380
381    // Description requirements
382    let max_length = config.description_max_length.unwrap_or(100);
383    prompt.push_str(&format!(
384        "- Keep the description under {} characters\n",
385        max_length
386    ));
387
388    if config.description_capitalize.unwrap_or(true) {
389        prompt.push_str("- Capitalize the first letter of the description\n");
390    }
391
392    if !config.description_add_period.unwrap_or(false) {
393        prompt.push_str("- Do not end the description with a period\n");
394    }
395
396    prompt
397}
398
399/// Get style guidance from commit history analysis
400fn get_style_guidance(config: &Config) -> Option<String> {
401    use crate::git;
402    use crate::utils::commit_style::CommitStyleProfile;
403
404    // Get cached style profile or analyze fresh
405    if let Some(cached) = &config.style_profile {
406        // Use cached profile if available
407        return Some(cached.clone());
408    }
409
410    // Analyze recent commits
411    let count = config.history_commits_count.unwrap_or(10);
412
413    match git::get_recent_commit_messages(count) {
414        Ok(commits) => {
415            if commits.is_empty() {
416                return None;
417            }
418
419            let profile = CommitStyleProfile::analyze_from_commits(&commits);
420
421            if profile.is_empty() {
422                return None;
423            }
424
425            Some(profile.to_prompt_guidance())
426        }
427        Err(e) => {
428            tracing::warn!("Failed to get commit history for style analysis: {}", e);
429            None
430        }
431    }
432}
433
/// Assemble the user prompt: an optional context line followed by the task
/// statement and the diff inside a fenced `diff` block.
///
/// `_full_gitmoji` is accepted for signature symmetry and is not used.
fn build_user_prompt(diff: &str, context: Option<&str>, _full_gitmoji: bool) -> String {
    // Empty when the caller supplied no extra context.
    let preamble = match context {
        Some(ctx) => format!("Additional context: {}\n\n", ctx),
        None => String::new(),
    };

    [
        preamble.as_str(),
        "Generate a commit message for the following git diff:\n",
        "```diff\n",
        diff,
        "\n```\n",
    ]
    .concat()
}
450
451/// Build the combined prompt for providers without system message support
452pub fn build_prompt(
453    diff: &str,
454    context: Option<&str>,
455    config: &Config,
456    full_gitmoji: bool,
457) -> String {
458    let (system, user) = split_prompt(diff, context, config, full_gitmoji);
459    format!("{}\\n\\n---\\n\\n{}", system, user)
460}
461
462/// Create an AI provider from an account configuration
463#[allow(dead_code)]
464pub fn create_provider_for_account(
465    account: &AccountConfig,
466    config: &Config,
467) -> Result<Box<dyn AIProvider>> {
468    use crate::auth::token_storage;
469    use crate::config::secure_storage;
470
471    let provider = account.provider.to_lowercase();
472
473    // Extract credentials from the account's auth method
474    let credentials = match &account.auth {
475        crate::config::accounts::AuthMethod::ApiKey { key_id } => {
476            // Get API key from secure storage using the account's key_id
477            token_storage::get_api_key_for_account(key_id)?
478                .or_else(|| secure_storage::get_secret(key_id).ok().flatten())
479        }
480        crate::config::accounts::AuthMethod::OAuth {
481            provider: _oauth_provider,
482            account_id,
483        } => {
484            // Get OAuth access token from secure storage
485            token_storage::get_tokens_for_account(account_id)?.map(|t| t.access_token)
486        }
487        crate::config::accounts::AuthMethod::EnvVar { name } => std::env::var(name).ok(),
488        crate::config::accounts::AuthMethod::Bearer { token_id } => {
489            // Get bearer token from secure storage
490            token_storage::get_bearer_token_for_account(token_id)?
491                .or_else(|| secure_storage::get_secret(token_id).ok().flatten())
492        }
493    };
494
495    match provider.as_str() {
496        #[cfg(feature = "openai")]
497        "openai" | "codex" => {
498            if let Some(key) = credentials.as_ref() {
499                Ok(Box::new(openai::OpenAIProvider::from_account(
500                    account, key, config,
501                )?))
502            } else {
503                Ok(Box::new(openai::OpenAIProvider::new(config)?))
504            }
505        }
506        #[cfg(feature = "anthropic")]
507        "anthropic" | "claude" | "claude-code" => {
508            if let Some(key) = credentials.as_ref() {
509                Ok(Box::new(anthropic::AnthropicProvider::from_account(
510                    account, key, config,
511                )?))
512            } else {
513                Ok(Box::new(anthropic::AnthropicProvider::new(config)?))
514            }
515        }
516        #[cfg(feature = "ollama")]
517        "ollama" => {
518            if let Some(key) = credentials.as_ref() {
519                Ok(Box::new(ollama::OllamaProvider::from_account(
520                    account, key, config,
521                )?))
522            } else {
523                Ok(Box::new(ollama::OllamaProvider::new(config)?))
524            }
525        }
526        #[cfg(feature = "gemini")]
527        "gemini" => {
528            if let Some(key) = credentials.as_ref() {
529                Ok(Box::new(gemini::GeminiProvider::from_account(
530                    account, key, config,
531                )?))
532            } else {
533                Ok(Box::new(gemini::GeminiProvider::new(config)?))
534            }
535        }
536        #[cfg(feature = "azure")]
537        "azure" | "azure-openai" => {
538            if let Some(key) = credentials.as_ref() {
539                Ok(Box::new(azure::AzureProvider::from_account(
540                    account, key, config,
541                )?))
542            } else {
543                Ok(Box::new(azure::AzureProvider::new(config)?))
544            }
545        }
546        #[cfg(feature = "perplexity")]
547        "perplexity" => {
548            if let Some(key) = credentials.as_ref() {
549                Ok(Box::new(perplexity::PerplexityProvider::from_account(
550                    account, key, config,
551                )?))
552            } else {
553                Ok(Box::new(perplexity::PerplexityProvider::new(config)?))
554            }
555        }
556        #[cfg(feature = "xai")]
557        "xai" | "grok" | "x-ai" => {
558            if let Some(key) = credentials.as_ref() {
559                Ok(Box::new(xai::XAIProvider::from_account(
560                    account, key, config,
561                )?))
562            } else {
563                Ok(Box::new(xai::XAIProvider::new(config)?))
564            }
565        }
566        #[cfg(feature = "huggingface")]
567        "huggingface" | "hf" => {
568            if let Some(key) = credentials.as_ref() {
569                Ok(Box::new(huggingface::HuggingFaceProvider::from_account(
570                    account, key, config,
571                )?))
572            } else {
573                Ok(Box::new(huggingface::HuggingFaceProvider::new(config)?))
574            }
575        }
576        #[cfg(feature = "bedrock")]
577        "bedrock" | "aws-bedrock" | "amazon-bedrock" => Ok(Box::new(
578            bedrock::BedrockProvider::from_account(account, "", config)?,
579        )),
580        #[cfg(feature = "vertex")]
581        "vertex" | "vertex-ai" | "google-vertex" | "gcp-vertex" => Ok(Box::new(
582            vertex::VertexProvider::from_account(account, "", config)?,
583        )),
584        _ => {
585            anyhow::bail!(
586                "Unsupported AI provider for account: {}\n\n\
587                 Account provider: {}\n\
588                 Supported providers: openai, anthropic, ollama, gemini, azure, perplexity, xai, huggingface, bedrock, vertex",
589                account.alias,
590                provider
591            );
592        }
593    }
594}