rusty_commit/providers/nvidia.rs

//! NVIDIA NIM Provider - Enterprise GPU Inference
//!
//! NVIDIA NIM (NVIDIA Inference Microservices) provides optimized inference
//! for LLMs on NVIDIA GPUs. Supports both self-hosted and cloud deployments.
//!
//! Setup:
//! 1. Get API key from: https://build.nvidia.com
//! 2. Configure rco:
//!    `rco config set RCO_AI_PROVIDER=nvidia RCO_API_KEY=<key> RCO_MODEL=meta/llama-3.1-8b-instruct`
//!
//! Docs: https://docs.nvidia.com/nim/
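//!
//! Minimal usage sketch (not compiled as a doctest; the `rusty_commit` crate
//! path and re-exports below are assumptions based on this repository's layout):
//!
//! ```ignore
//! use rusty_commit::config::Config;
//! use rusty_commit::providers::{nvidia::NvidiaProvider, AIProvider};
//!
//! async fn suggest_commit(diff: &str, config: &Config) -> anyhow::Result<String> {
//!     let provider = NvidiaProvider::new(config)?;
//!     provider
//!         .generate_commit_message(diff, None, false, config)
//!         .await
//! }
//! ```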

use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use super::{build_prompt, AIProvider};
use crate::config::Config;
use crate::utils::retry::retry_async;

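/// Chat-completions client for NVIDIA NIM's OpenAI-compatible API.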
pub struct NvidiaProvider {
    client: Client,
    api_url: String,
    api_key: String,
    model: String,
}

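/// Request body for the NIM `/chat/completions` endpoint.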
#[derive(Serialize)]
struct NvidiaRequest {
    model: String,
    messages: Vec<NvidiaMessage>,
    max_tokens: i32,
    temperature: f32,
    top_p: f32,
    stream: bool,
}

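/// A chat message; used for both request messages and the assistant reply.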
#[derive(Serialize, Deserialize, Clone)]
struct NvidiaMessage {
    role: String,
    content: String,
}

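/// Response body; only the fields needed to extract the message are deserialized.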
#[derive(Deserialize)]
struct NvidiaResponse {
    choices: Vec<NvidiaChoice>,
}

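/// A single completion choice returned by the API.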
#[derive(Deserialize)]
struct NvidiaChoice {
    message: NvidiaMessage,
}

impl NvidiaProvider {
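    /// Build a provider from the global config, falling back to the hosted
    /// NIM endpoint and `meta/llama-3.1-8b-instruct` when unset.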
    pub fn new(config: &Config) -> Result<Self> {
        let client = Client::new();
        let api_key = config
            .api_key
            .as_ref()
            .context("NVIDIA API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your API key from: https://build.nvidia.com")?;

        let api_url = config
            .api_url
            .as_deref()
            .unwrap_or("https://integrate.api.nvidia.com/v1")
            .to_string();

        let model = config
            .model
            .as_deref()
            .unwrap_or("meta/llama-3.1-8b-instruct")
            .to_string();

        Ok(Self {
            client,
            api_url,
            api_key: api_key.clone(),
            model,
        })
    }

    /// Create provider from account configuration
    #[allow(dead_code)]
    pub fn from_account(
        account: &crate::config::accounts::AccountConfig,
        api_key: &str,
        config: &Config,
    ) -> Result<Self> {
        let client = Client::new();
        let api_url = account
            .api_url
            .as_deref()
            .or(config.api_url.as_deref())
            .unwrap_or("https://integrate.api.nvidia.com/v1")
            .to_string();

        let model = account
            .model
            .as_deref()
            .or(config.model.as_deref())
            .unwrap_or("meta/llama-3.1-8b-instruct")
            .to_string();

        Ok(Self {
            client,
            api_url,
            api_key: api_key.to_string(),
            model,
        })
    }
}

#[async_trait]
impl AIProvider for NvidiaProvider {
    async fn generate_commit_message(
        &self,
        diff: &str,
        context: Option<&str>,
        full_gitmoji: bool,
        config: &Config,
    ) -> Result<String> {
        let prompt = build_prompt(diff, context, config, full_gitmoji);

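        // A short system message frames the task; the user message carries the
        // prompt built from the diff and optional context.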
        let messages = vec![
            NvidiaMessage {
                role: "system".to_string(),
                content: "You are an expert at writing clear, concise git commit messages."
                    .to_string(),
            },
            NvidiaMessage {
                role: "user".to_string(),
                content: prompt,
            },
        ];

        let request = NvidiaRequest {
            model: self.model.clone(),
            messages,
            max_tokens: config.tokens_max_output.unwrap_or(500) as i32,
            temperature: 0.7,
            top_p: 0.7,
            stream: false,
        };

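        // Send the request with retries, surfacing auth failures and other HTTP
        // errors with descriptive messages before parsing the JSON body.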
        let nvidia_response: NvidiaResponse = retry_async(|| async {
            let url = format!("{}/chat/completions", self.api_url);
            let response = self
                .client
                .post(&url)
                .header("Authorization", format!("Bearer {}", self.api_key))
                .json(&request)
                .send()
                .await
                .context("Failed to connect to NVIDIA NIM API")?;

            let status = response.status();
            if !status.is_success() {
                let error_text = response.text().await?;
                // 401 / Unauthorized almost always means a bad or missing API key.
                if status == reqwest::StatusCode::UNAUTHORIZED
                    || error_text.contains("Unauthorized")
                {
                    return Err(anyhow::anyhow!(
                        "Invalid NVIDIA API key. Please check your API key configuration."
                    ));
                }
                return Err(anyhow::anyhow!("NVIDIA NIM API error ({}): {}", status, error_text));
            }

            let nvidia_response: NvidiaResponse = response
                .json()
                .await
                .context("Failed to parse NVIDIA NIM response")?;

            Ok(nvidia_response)
        })
        .await
        .context("Failed to generate commit message from NVIDIA NIM after retries")?;

        let message = nvidia_response
            .choices
            .first()
            .map(|choice| choice.message.content.trim().to_string())
            .context("NVIDIA NIM returned an empty response")?;

        Ok(message)
    }
}

/// ProviderBuilder for NVIDIA NIM
pub struct NvidiaProviderBuilder;

impl super::registry::ProviderBuilder for NvidiaProviderBuilder {
    fn name(&self) -> &'static str {
        "nvidia"
    }

    fn aliases(&self) -> Vec<&'static str> {
        vec!["nvidia-nim", "nim", "nvidia-ai"]
    }

    fn category(&self) -> super::registry::ProviderCategory {
        super::registry::ProviderCategory::Cloud
    }

    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
        Ok(Box::new(NvidiaProvider::new(config)?))
    }

    fn requires_api_key(&self) -> bool {
        true
    }

    fn default_model(&self) -> Option<&'static str> {
        Some("meta/llama-3.1-8b-instruct")
    }
}
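
#[cfg(test)]
mod tests {
    use super::super::registry::ProviderBuilder as _;
    use super::*;

    // Minimal smoke-test sketch for the builder's registry metadata; it only
    // exercises the trait methods implemented above and never calls the API.
    #[test]
    fn builder_metadata() {
        let builder = NvidiaProviderBuilder;
        assert_eq!(builder.name(), "nvidia");
        assert!(builder.aliases().contains(&"nim"));
        assert!(builder.requires_api_key());
        assert_eq!(builder.default_model(), Some("meta/llama-3.1-8b-instruct"));
    }
}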