rusty_commit/providers/
huggingface.rs1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::prompt::build_prompt;
7use super::AIProvider;
8use crate::config::accounts::AccountConfig;
9use crate::config::Config;
10use crate::utils::retry::retry_async;
11
12pub struct HuggingFaceProvider {
13 client: Client,
14 api_key: String,
15 model: String,
16 api_url: String,
17}
18
19#[derive(Serialize)]
20struct HFRequest {
21 model: String,
22 inputs: String,
23 parameters: HFParameters,
24 options: HFOptions,
25}
26
27#[derive(Serialize)]
28struct HFParameters {
29 temperature: Option<f32>,
30 max_new_tokens: Option<u32>,
31 return_full_text: bool,
32}
33
34#[derive(Serialize)]
35struct HFOptions {
36 use_cache: bool,
37}
38
39#[derive(Deserialize)]
40struct HFResponse {
41 generated_text: Option<String>,
42 error: Option<String>,
43}
44
45impl HuggingFaceProvider {
46 pub fn new(config: &Config) -> Result<Self> {
47 let api_key = config
48 .api_key
49 .as_ref()
50 .context("HuggingFace API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your token from: https://huggingface.co/settings/tokens")?
51 .clone();
52
53 let client = Client::new();
54 let model = config
55 .model
56 .as_deref()
57 .unwrap_or("mistralai/Mistral-7B-Instruct-v0.2")
58 .to_string();
59
60 let api_url = config
62 .api_url
63 .as_deref()
64 .unwrap_or("https://api-inference.huggingface.co");
65
66 Ok(Self {
67 client,
68 api_key,
69 model,
70 api_url: api_url.to_string(),
71 })
72 }
73
74 #[allow(dead_code)]
76 pub fn from_account(account: &AccountConfig, api_key: &str, config: &Config) -> Result<Self> {
77 let client = Client::new();
78 let model = account
79 .model
80 .as_deref()
81 .or(config.model.as_deref())
82 .unwrap_or("mistralai/Mistral-7B-Instruct-v0.2")
83 .to_string();
84
85 let api_url = account
86 .api_url
87 .as_deref()
88 .or(config.api_url.as_deref())
89 .unwrap_or("https://api-inference.huggingface.co")
90 .to_string();
91
92 Ok(Self {
93 client,
94 api_key: api_key.to_string(),
95 model,
96 api_url,
97 })
98 }
99}
100
101#[async_trait]
102impl AIProvider for HuggingFaceProvider {
103 async fn generate_commit_message(
104 &self,
105 diff: &str,
106 context: Option<&str>,
107 full_gitmoji: bool,
108 config: &Config,
109 ) -> Result<String> {
110 let prompt = build_prompt(diff, context, config, full_gitmoji);
112
113 let request = HFRequest {
114 model: self.model.clone(),
115 inputs: prompt,
116 parameters: HFParameters {
117 temperature: Some(0.7),
118 max_new_tokens: Some(config.tokens_max_output.unwrap_or(500)),
119 return_full_text: false,
120 },
121 options: HFOptions { use_cache: true },
122 };
123
124 let url = format!("{}/models/{}", self.api_url, self.model);
126
127 let hf_response: HFResponse = retry_async(|| async {
128 let response = self
129 .client
130 .post(&url)
131 .header("Authorization", format!("Bearer {}", self.api_key))
132 .header("Content-Type", "application/json")
133 .json(&request)
134 .send()
135 .await
136 .context("Failed to connect to HuggingFace")?;
137
138 if !response.status().is_success() {
139 let error_text = response.text().await?;
140 return Err(anyhow::anyhow!("HuggingFace API error: {}", error_text));
141 }
142
143 let hf_response: HFResponse = response
144 .json()
145 .await
146 .context("Failed to parse HuggingFace response")?;
147
148 Ok(hf_response)
149 })
150 .await
151 .context("Failed to generate commit message from HuggingFace after retries")?;
152
153 if let Some(error) = hf_response.error {
155 anyhow::bail!("HuggingFace inference error: {}", error);
156 }
157
158 let message = hf_response
159 .generated_text
160 .context("HuggingFace returned an empty response")?
161 .trim()
162 .to_string();
163
164 Ok(message)
167 }
168}
169
170pub struct HuggingFaceProviderBuilder;
172
173impl super::registry::ProviderBuilder for HuggingFaceProviderBuilder {
174 fn name(&self) -> &'static str {
175 "huggingface"
176 }
177
178 fn aliases(&self) -> Vec<&'static str> {
179 vec!["hf"]
180 }
181
182 fn category(&self) -> super::registry::ProviderCategory {
183 super::registry::ProviderCategory::Standard
184 }
185
186 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
187 Ok(Box::new(HuggingFaceProvider::new(config)?))
188 }
189
190 fn requires_api_key(&self) -> bool {
191 true
192 }
193
194 fn default_model(&self) -> Option<&'static str> {
195 Some("mistralai/Mistral-7B-Instruct-v0.2")
196 }
197}