// rusty_commit/providers/nvidia.rs
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};

use super::prompt::build_prompt;
use super::AIProvider;
use crate::config::Config;
use crate::utils::retry::retry_async;

23pub struct NvidiaProvider {
24 client: Client,
25 api_url: String,
26 api_key: String,
27 model: String,
28}
29
/// JSON body posted to `{api_url}/chat/completions`.
///
/// Field declaration order is the serialized key order.
#[derive(Serialize)]
struct NvidiaRequest {
    /// Model identifier to run the completion with.
    model: String,
    /// Conversation sent to the model (system prompt followed by the user prompt).
    messages: Vec<NvidiaMessage>,
    /// Upper bound on generated tokens.
    max_tokens: i32,
    /// Sampling temperature.
    temperature: f32,
    /// Nucleus-sampling cutoff.
    top_p: f32,
    /// Always `false` here: the response is read as a single JSON document.
    stream: bool,
}

/// A single chat message; used both in requests and in parsed responses.
#[derive(Serialize, Deserialize, Clone)]
struct NvidiaMessage {
    /// Speaker role, e.g. `"system"` or `"user"` in requests.
    role: String,
    /// Message text.
    content: String,
}

/// Subset of the `/chat/completions` response body that this provider reads.
#[derive(Deserialize)]
struct NvidiaResponse {
    /// Candidate completions; only the first one is used.
    choices: Vec<NvidiaChoice>,
}

/// One entry of [`NvidiaResponse::choices`].
#[derive(Deserialize)]
struct NvidiaChoice {
    /// The generated message; its `content` becomes the commit message.
    message: NvidiaMessage,
}

56impl NvidiaProvider {
57 pub fn new(config: &Config) -> Result<Self> {
58 let client = Client::new();
59 let api_key = config
60 .api_key
61 .as_ref()
62 .context("NVIDIA API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your API key from: https://build.nvidia.com")?;
63
64 let api_url = config
65 .api_url
66 .as_deref()
67 .unwrap_or("https://integrate.api.nvidia.com/v1")
68 .to_string();
69
70 let model = config
71 .model
72 .as_deref()
73 .unwrap_or("meta/llama-3.1-8b-instruct")
74 .to_string();
75
76 Ok(Self {
77 client,
78 api_url,
79 api_key: api_key.clone(),
80 model,
81 })
82 }
83
84 #[allow(dead_code)]
86 pub fn from_account(
87 account: &crate::config::accounts::AccountConfig,
88 api_key: &str,
89 config: &Config,
90 ) -> Result<Self> {
91 let client = Client::new();
92 let api_url = account
93 .api_url
94 .as_deref()
95 .or(config.api_url.as_deref())
96 .unwrap_or("https://integrate.api.nvidia.com/v1")
97 .to_string();
98
99 let model = account
100 .model
101 .as_deref()
102 .or(config.model.as_deref())
103 .unwrap_or("meta/llama-3.1-8b-instruct")
104 .to_string();
105
106 Ok(Self {
107 client,
108 api_url,
109 api_key: api_key.to_string(),
110 model,
111 })
112 }
113}
114
115#[async_trait]
116impl AIProvider for NvidiaProvider {
117 async fn generate_commit_message(
118 &self,
119 diff: &str,
120 context: Option<&str>,
121 full_gitmoji: bool,
122 config: &Config,
123 ) -> Result<String> {
124 let prompt = build_prompt(diff, context, config, full_gitmoji);
125
126 let messages = vec![
127 NvidiaMessage {
128 role: "system".to_string(),
129 content: "You are an expert at writing clear, concise git commit messages."
130 .to_string(),
131 },
132 NvidiaMessage {
133 role: "user".to_string(),
134 content: prompt,
135 },
136 ];
137
138 let request = NvidiaRequest {
139 model: self.model.clone(),
140 messages,
141 max_tokens: config.tokens_max_output.unwrap_or(500) as i32,
142 temperature: 0.7,
143 top_p: 0.7,
144 stream: false,
145 };
146
147 let nvidia_response: NvidiaResponse = retry_async(|| async {
148 let url = format!("{}/chat/completions", self.api_url);
149 let response = self
150 .client
151 .post(&url)
152 .header("Authorization", format!("Bearer {}", self.api_key))
153 .json(&request)
154 .send()
155 .await
156 .context("Failed to connect to NVIDIA NIM API")?;
157
158 if !response.status().is_success() {
159 let error_text = response.text().await?;
160 if error_text.contains("401") || error_text.contains("Unauthorized") {
161 return Err(anyhow::anyhow!(
162 "Invalid NVIDIA API key. Please check your API key configuration."
163 ));
164 }
165 return Err(anyhow::anyhow!("NVIDIA NIM API error: {}", error_text));
166 }
167
168 let nvidia_response: NvidiaResponse = response
169 .json()
170 .await
171 .context("Failed to parse NVIDIA NIM response")?;
172
173 Ok(nvidia_response)
174 })
175 .await
176 .context("Failed to generate commit message from NVIDIA NIM after retries")?;
177
178 let message = nvidia_response
179 .choices
180 .first()
181 .map(|choice| choice.message.content.trim().to_string())
182 .context("NVIDIA NIM returned an empty response")?;
183
184 Ok(message)
185 }
186}
187
/// Registry entry that creates [`NvidiaProvider`] instances from a [`Config`].
pub struct NvidiaProviderBuilder;

191impl super::registry::ProviderBuilder for NvidiaProviderBuilder {
192 fn name(&self) -> &'static str {
193 "nvidia"
194 }
195
196 fn aliases(&self) -> Vec<&'static str> {
197 vec!["nvidia-nim", "nim", "nvidia-ai"]
198 }
199
200 fn category(&self) -> super::registry::ProviderCategory {
201 super::registry::ProviderCategory::Cloud
202 }
203
204 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
205 Ok(Box::new(NvidiaProvider::new(config)?))
206 }
207
208 fn requires_api_key(&self) -> bool {
209 true
210 }
211
212 fn default_model(&self) -> Option<&'static str> {
213 Some("meta/llama-3.1-8b-instruct")
214 }
215}