//! HuggingFace provider for rusty_commit.
//!
//! File: `rusty_commit/providers/huggingface.rs`

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::prompt::build_prompt;
7use super::AIProvider;
8use crate::config::accounts::AccountConfig;
9use crate::config::Config;
10use crate::utils::retry::retry_async;
11
/// AI provider that generates commit messages through the HuggingFace
/// Inference API (or a dedicated inference endpoint when `api_url` is
/// overridden in the configuration).
pub struct HuggingFaceProvider {
    client: Client,   // reused reqwest HTTP client
    api_key: String,  // bearer token sent in the Authorization header
    model: String,    // model id, e.g. "mistralai/Mistral-7B-Instruct-v0.2"
    api_url: String,  // base URL; defaults to the serverless Inference API
}
18
/// JSON body of a text-generation request.
#[derive(Serialize)]
struct HFRequest {
    // NOTE(review): the model is also encoded in the request URL
    // (`{api_url}/models/{model}`); presumably the API ignores this body
    // field — confirm it is actually needed.
    model: String,
    inputs: String,          // the full prompt (no system-message support)
    parameters: HFParameters,
    options: HFOptions,
}
26
/// Generation parameters forwarded to the model.
#[derive(Serialize)]
struct HFParameters {
    temperature: Option<f32>,
    max_new_tokens: Option<u32>,
    // false → response contains only the completion, not prompt + completion
    return_full_text: bool,
}
33
/// Inference API request options (separate from model parameters).
#[derive(Serialize)]
struct HFOptions {
    use_cache: bool, // allow the API to serve a cached result for identical inputs
}
38
/// One result object from the text-generation endpoint.
///
/// NOTE(review): the serverless Inference API returns a JSON *array* of
/// these objects (`[{"generated_text": …}]`) on success, so deserializing
/// the body as a single `HFResponse` will likely fail — verify against a
/// live response.
#[derive(Deserialize)]
struct HFResponse {
    generated_text: Option<String>,
    error: Option<String>, // inference-level error, may appear even on HTTP 200
}
44
45impl HuggingFaceProvider {
46    pub fn new(config: &Config) -> Result<Self> {
47        let api_key = config
48            .api_key
49            .as_ref()
50            .context("HuggingFace API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your token from: https://huggingface.co/settings/tokens")?
51            .clone();
52
53        let client = Client::new();
54        let model = config
55            .model
56            .as_deref()
57            .unwrap_or("mistralai/Mistral-7B-Instruct-v0.2")
58            .to_string();
59
60        // Determine if this is an inference API call or dedicated endpoint
61        let api_url = config
62            .api_url
63            .as_deref()
64            .unwrap_or("https://api-inference.huggingface.co");
65
66        Ok(Self {
67            client,
68            api_key,
69            model,
70            api_url: api_url.to_string(),
71        })
72    }
73
74    /// Create provider from account configuration
75    #[allow(dead_code)]
76    pub fn from_account(account: &AccountConfig, api_key: &str, config: &Config) -> Result<Self> {
77        let client = Client::new();
78        let model = account
79            .model
80            .as_deref()
81            .or(config.model.as_deref())
82            .unwrap_or("mistralai/Mistral-7B-Instruct-v0.2")
83            .to_string();
84
85        let api_url = account
86            .api_url
87            .as_deref()
88            .or(config.api_url.as_deref())
89            .unwrap_or("https://api-inference.huggingface.co")
90            .to_string();
91
92        Ok(Self {
93            client,
94            api_key: api_key.to_string(),
95            model,
96            api_url,
97        })
98    }
99}
100
101#[async_trait]
102impl AIProvider for HuggingFaceProvider {
103    async fn generate_commit_message(
104        &self,
105        diff: &str,
106        context: Option<&str>,
107        full_gitmoji: bool,
108        config: &Config,
109    ) -> Result<String> {
110        // HuggingFace Inference API uses a single prompt (no system message support)
111        let prompt = build_prompt(diff, context, config, full_gitmoji);
112
113        let request = HFRequest {
114            model: self.model.clone(),
115            inputs: prompt,
116            parameters: HFParameters {
117                temperature: Some(0.7),
118                max_new_tokens: Some(config.tokens_max_output.unwrap_or(500)),
119                return_full_text: false,
120            },
121            options: HFOptions { use_cache: true },
122        };
123
124        // Use Inference API endpoint
125        let url = format!("{}/models/{}", self.api_url, self.model);
126
127        let hf_response: HFResponse = retry_async(|| async {
128            let response = self
129                .client
130                .post(&url)
131                .header("Authorization", format!("Bearer {}", self.api_key))
132                .header("Content-Type", "application/json")
133                .json(&request)
134                .send()
135                .await
136                .context("Failed to connect to HuggingFace")?;
137
138            if !response.status().is_success() {
139                let error_text = response.text().await?;
140                return Err(anyhow::anyhow!("HuggingFace API error: {}", error_text));
141            }
142
143            let hf_response: HFResponse = response
144                .json()
145                .await
146                .context("Failed to parse HuggingFace response")?;
147
148            Ok(hf_response)
149        })
150        .await
151        .context("Failed to generate commit message from HuggingFace after retries")?;
152
153        // Handle error response
154        if let Some(error) = hf_response.error {
155            anyhow::bail!("HuggingFace inference error: {}", error);
156        }
157
158        let message = hf_response
159            .generated_text
160            .context("HuggingFace returned an empty response")?
161            .trim()
162            .to_string();
163
164        // Clean up the response - remove the prompt if it's included
165        // HF models often return the full prompt + completion
166        Ok(message)
167    }
168}
169
170/// ProviderBuilder for HuggingFace
171pub struct HuggingFaceProviderBuilder;
172
173impl super::registry::ProviderBuilder for HuggingFaceProviderBuilder {
174    fn name(&self) -> &'static str {
175        "huggingface"
176    }
177
178    fn aliases(&self) -> Vec<&'static str> {
179        vec!["hf"]
180    }
181
182    fn category(&self) -> super::registry::ProviderCategory {
183        super::registry::ProviderCategory::Standard
184    }
185
186    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
187        Ok(Box::new(HuggingFaceProvider::new(config)?))
188    }
189
190    fn requires_api_key(&self) -> bool {
191        true
192    }
193
194    fn default_model(&self) -> Option<&'static str> {
195        Some("mistralai/Mistral-7B-Instruct-v0.2")
196    }
197}