rusty_commit/providers/
nvidia.rs1use anyhow::{Context, Result};
14use async_trait::async_trait;
15use reqwest::Client;
16use serde::{Deserialize, Serialize};
17
18use super::{build_prompt, AIProvider};
19use crate::config::Config;
20use crate::utils::retry::retry_async;
21
/// Commit-message provider backed by NVIDIA NIM's chat-completions API.
pub struct NvidiaProvider {
    /// HTTP client used to send requests.
    client: Client,
    /// Base URL; `/chat/completions` is appended when sending requests.
    api_url: String,
    /// Secret API key, sent as a bearer token in the `Authorization` header.
    api_key: String,
    /// Model identifier placed in every request.
    model: String,
}
28
/// JSON body for a `POST {api_url}/chat/completions` request.
#[derive(Serialize)]
struct NvidiaRequest {
    /// Model identifier, taken from the provider's configured model.
    model: String,
    /// Conversation sent to the model (in this file: a system prompt followed by
    /// the user prompt).
    messages: Vec<NvidiaMessage>,
    /// Upper bound on the number of generated tokens.
    max_tokens: i32,
    /// Forwarded as the API's `temperature` sampling parameter.
    temperature: f32,
    /// Forwarded as the API's `top_p` sampling parameter.
    top_p: f32,
    /// Streaming flag; this file always sends `false` and reads a single JSON body.
    stream: bool,
}
38
39#[derive(Serialize, Deserialize, Clone)]
40struct NvidiaMessage {
41 role: String,
42 content: String,
43}
44
/// Subset of the chat-completions response body that this provider reads.
///
/// Only `choices` is declared; serde ignores any other fields in the body.
#[derive(Deserialize)]
struct NvidiaResponse {
    /// Candidate completions; only the first one is used.
    choices: Vec<NvidiaChoice>,
}
49
/// One entry of the response's `choices` array.
#[derive(Deserialize)]
struct NvidiaChoice {
    /// Generated message; its trimmed `content` is returned as the commit message.
    message: NvidiaMessage,
}
54
55impl NvidiaProvider {
56 pub fn new(config: &Config) -> Result<Self> {
57 let client = Client::new();
58 let api_key = config
59 .api_key
60 .as_ref()
61 .context("NVIDIA API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your API key from: https://build.nvidia.com")?;
62
63 let api_url = config
64 .api_url
65 .as_deref()
66 .unwrap_or("https://integrate.api.nvidia.com/v1")
67 .to_string();
68
69 let model = config
70 .model
71 .as_deref()
72 .unwrap_or("meta/llama-3.1-8b-instruct")
73 .to_string();
74
75 Ok(Self {
76 client,
77 api_url,
78 api_key: api_key.clone(),
79 model,
80 })
81 }
82
83 #[allow(dead_code)]
85 pub fn from_account(
86 account: &crate::config::accounts::AccountConfig,
87 api_key: &str,
88 config: &Config,
89 ) -> Result<Self> {
90 let client = Client::new();
91 let api_url = account
92 .api_url
93 .as_deref()
94 .or(config.api_url.as_deref())
95 .unwrap_or("https://integrate.api.nvidia.com/v1")
96 .to_string();
97
98 let model = account
99 .model
100 .as_deref()
101 .or(config.model.as_deref())
102 .unwrap_or("meta/llama-3.1-8b-instruct")
103 .to_string();
104
105 Ok(Self {
106 client,
107 api_url,
108 api_key: api_key.to_string(),
109 model,
110 })
111 }
112}
113
114#[async_trait]
115impl AIProvider for NvidiaProvider {
116 async fn generate_commit_message(
117 &self,
118 diff: &str,
119 context: Option<&str>,
120 full_gitmoji: bool,
121 config: &Config,
122 ) -> Result<String> {
123 let prompt = build_prompt(diff, context, config, full_gitmoji);
124
125 let messages = vec![
126 NvidiaMessage {
127 role: "system".to_string(),
128 content: "You are an expert at writing clear, concise git commit messages."
129 .to_string(),
130 },
131 NvidiaMessage {
132 role: "user".to_string(),
133 content: prompt,
134 },
135 ];
136
137 let request = NvidiaRequest {
138 model: self.model.clone(),
139 messages,
140 max_tokens: config.tokens_max_output.unwrap_or(500) as i32,
141 temperature: 0.7,
142 top_p: 0.7,
143 stream: false,
144 };
145
146 let nvidia_response: NvidiaResponse = retry_async(|| async {
147 let url = format!("{}/chat/completions", self.api_url);
148 let response = self
149 .client
150 .post(&url)
151 .header("Authorization", format!("Bearer {}", self.api_key))
152 .json(&request)
153 .send()
154 .await
155 .context("Failed to connect to NVIDIA NIM API")?;
156
157 if !response.status().is_success() {
158 let error_text = response.text().await?;
159 if error_text.contains("401") || error_text.contains("Unauthorized") {
160 return Err(anyhow::anyhow!(
161 "Invalid NVIDIA API key. Please check your API key configuration."
162 ));
163 }
164 return Err(anyhow::anyhow!("NVIDIA NIM API error: {}", error_text));
165 }
166
167 let nvidia_response: NvidiaResponse = response
168 .json()
169 .await
170 .context("Failed to parse NVIDIA NIM response")?;
171
172 Ok(nvidia_response)
173 })
174 .await
175 .context("Failed to generate commit message from NVIDIA NIM after retries")?;
176
177 let message = nvidia_response
178 .choices
179 .first()
180 .map(|choice| choice.message.content.trim().to_string())
181 .context("NVIDIA NIM returned an empty response")?;
182
183 Ok(message)
184 }
185}
186
/// Registry entry that constructs [`NvidiaProvider`] instances.
///
/// Stateless: the unit value carries no configuration, so it is freely
/// copyable and has a trivial `Default`.
#[derive(Debug, Default, Clone, Copy)]
pub struct NvidiaProviderBuilder;
189
190impl super::registry::ProviderBuilder for NvidiaProviderBuilder {
191 fn name(&self) -> &'static str {
192 "nvidia"
193 }
194
195 fn aliases(&self) -> Vec<&'static str> {
196 vec!["nvidia-nim", "nim", "nvidia-ai"]
197 }
198
199 fn category(&self) -> super::registry::ProviderCategory {
200 super::registry::ProviderCategory::Cloud
201 }
202
203 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
204 Ok(Box::new(NvidiaProvider::new(config)?))
205 }
206
207 fn requires_api_key(&self) -> bool {
208 true
209 }
210
211 fn default_model(&self) -> Option<&'static str> {
212 Some("meta/llama-3.1-8b-instruct")
213 }
214}