rusty_commit/providers/
huggingface.rs1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::{build_prompt, AIProvider};
7use crate::config::accounts::AccountConfig;
8use crate::config::Config;
9use crate::utils::retry::retry_async;
10
11pub struct HuggingFaceProvider {
12 client: Client,
13 api_key: String,
14 model: String,
15 api_url: String,
16}
17
18#[derive(Serialize)]
19struct HFRequest {
20 model: String,
21 inputs: String,
22 parameters: HFParameters,
23 options: HFOptions,
24}
25
26#[derive(Serialize)]
27struct HFParameters {
28 temperature: Option<f32>,
29 max_new_tokens: Option<u32>,
30 return_full_text: bool,
31}
32
33#[derive(Serialize)]
34struct HFOptions {
35 use_cache: bool,
36}
37
38#[derive(Deserialize)]
39struct HFResponse {
40 generated_text: Option<String>,
41 error: Option<String>,
42}
43
44impl HuggingFaceProvider {
45 pub fn new(config: &Config) -> Result<Self> {
46 let api_key = config
47 .api_key
48 .as_ref()
49 .context("HuggingFace API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your token from: https://huggingface.co/settings/tokens")?
50 .clone();
51
52 let client = Client::new();
53 let model = config
54 .model
55 .as_deref()
56 .unwrap_or("mistralai/Mistral-7B-Instruct-v0.2")
57 .to_string();
58
59 let api_url = config
61 .api_url
62 .as_deref()
63 .unwrap_or("https://api-inference.huggingface.co");
64
65 Ok(Self {
66 client,
67 api_key,
68 model,
69 api_url: api_url.to_string(),
70 })
71 }
72
73 #[allow(dead_code)]
75 pub fn from_account(account: &AccountConfig, api_key: &str, config: &Config) -> Result<Self> {
76 let client = Client::new();
77 let model = account
78 .model
79 .as_deref()
80 .or(config.model.as_deref())
81 .unwrap_or("mistralai/Mistral-7B-Instruct-v0.2")
82 .to_string();
83
84 let api_url = account
85 .api_url
86 .as_deref()
87 .or(config.api_url.as_deref())
88 .unwrap_or("https://api-inference.huggingface.co")
89 .to_string();
90
91 Ok(Self {
92 client,
93 api_key: api_key.to_string(),
94 model,
95 api_url,
96 })
97 }
98}
99
100#[async_trait]
101impl AIProvider for HuggingFaceProvider {
102 async fn generate_commit_message(
103 &self,
104 diff: &str,
105 context: Option<&str>,
106 full_gitmoji: bool,
107 config: &Config,
108 ) -> Result<String> {
109 let prompt = build_prompt(diff, context, config, full_gitmoji);
111
112 let request = HFRequest {
113 model: self.model.clone(),
114 inputs: prompt,
115 parameters: HFParameters {
116 temperature: Some(0.7),
117 max_new_tokens: Some(config.tokens_max_output.unwrap_or(500)),
118 return_full_text: false,
119 },
120 options: HFOptions { use_cache: true },
121 };
122
123 let url = format!("{}/models/{}", self.api_url, self.model);
125
126 let hf_response: HFResponse = retry_async(|| async {
127 let response = self
128 .client
129 .post(&url)
130 .header("Authorization", format!("Bearer {}", self.api_key))
131 .header("Content-Type", "application/json")
132 .json(&request)
133 .send()
134 .await
135 .context("Failed to connect to HuggingFace")?;
136
137 if !response.status().is_success() {
138 let error_text = response.text().await?;
139 return Err(anyhow::anyhow!("HuggingFace API error: {}", error_text));
140 }
141
142 let hf_response: HFResponse = response
143 .json()
144 .await
145 .context("Failed to parse HuggingFace response")?;
146
147 Ok(hf_response)
148 })
149 .await
150 .context("Failed to generate commit message from HuggingFace after retries")?;
151
152 if let Some(error) = hf_response.error {
154 anyhow::bail!("HuggingFace inference error: {}", error);
155 }
156
157 let message = hf_response
158 .generated_text
159 .context("HuggingFace returned an empty response")?
160 .trim()
161 .to_string();
162
163 Ok(message)
166 }
167}
168
169pub struct HuggingFaceProviderBuilder;
171
172impl super::registry::ProviderBuilder for HuggingFaceProviderBuilder {
173 fn name(&self) -> &'static str {
174 "huggingface"
175 }
176
177 fn aliases(&self) -> Vec<&'static str> {
178 vec!["hf"]
179 }
180
181 fn category(&self) -> super::registry::ProviderCategory {
182 super::registry::ProviderCategory::Standard
183 }
184
185 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
186 Ok(Box::new(HuggingFaceProvider::new(config)?))
187 }
188
189 fn requires_api_key(&self) -> bool {
190 true
191 }
192
193 fn default_model(&self) -> Option<&'static str> {
194 Some("mistralai/Mistral-7B-Instruct-v0.2")
195 }
196}