//! HuggingFace provider for rusty_commit (`rusty_commit/providers/huggingface.rs`):
//! generates commit messages via the HuggingFace Inference API.
1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::{build_prompt, AIProvider};
7use crate::config::accounts::AccountConfig;
8use crate::config::Config;
9use crate::utils::retry::retry_async;
10
/// AI provider backed by the HuggingFace Inference API (or a dedicated endpoint).
pub struct HuggingFaceProvider {
    // Reused HTTP client for all requests made by this provider.
    client: Client,
    // Bearer token sent in the `Authorization` header.
    api_key: String,
    // Model repository id, e.g. "mistralai/Mistral-7B-Instruct-v0.2".
    model: String,
    // Base URL; defaults to the public serverless Inference API host.
    api_url: String,
}
17
/// JSON request body for the HuggingFace text-generation endpoint.
#[derive(Serialize)]
struct HFRequest {
    // Model id; also embedded in the request URL path by the caller.
    model: String,
    // The full prompt text (HF takes a single "inputs" string).
    inputs: String,
    parameters: HFParameters,
    options: HFOptions,
}
25
/// Generation parameters nested under `parameters` in the request body.
#[derive(Serialize)]
struct HFParameters {
    // Sampling temperature; None lets the server use its default.
    temperature: Option<f32>,
    // Cap on newly generated tokens (excludes the prompt).
    max_new_tokens: Option<u32>,
    // When false, the response contains only the completion, not the prompt.
    return_full_text: bool,
}
32
/// Request options nested under `options` in the request body.
#[derive(Serialize)]
struct HFOptions {
    // Allow the server to return cached results for identical inputs.
    use_cache: bool,
}
37
/// Response payload from the text-generation endpoint.
///
/// NOTE(review): the serverless Inference API commonly wraps results in a
/// JSON array (`[{"generated_text": ...}]`), while this models a bare object —
/// confirm against the endpoints actually in use (dedicated endpoints may
/// return a single object).
#[derive(Deserialize)]
struct HFResponse {
    // The model's completion; absent when the server reports an error.
    generated_text: Option<String>,
    // Server-side error message delivered inside a 200 response body.
    error: Option<String>,
}
43
44impl HuggingFaceProvider {
45    pub fn new(config: &Config) -> Result<Self> {
46        let api_key = config
47            .api_key
48            .as_ref()
49            .context("HuggingFace API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your token from: https://huggingface.co/settings/tokens")?
50            .clone();
51
52        let client = Client::new();
53        let model = config
54            .model
55            .as_deref()
56            .unwrap_or("mistralai/Mistral-7B-Instruct-v0.2")
57            .to_string();
58
59        // Determine if this is an inference API call or dedicated endpoint
60        let api_url = config
61            .api_url
62            .as_deref()
63            .unwrap_or("https://api-inference.huggingface.co");
64
65        Ok(Self {
66            client,
67            api_key,
68            model,
69            api_url: api_url.to_string(),
70        })
71    }
72
73    /// Create provider from account configuration
74    #[allow(dead_code)]
75    pub fn from_account(account: &AccountConfig, api_key: &str, config: &Config) -> Result<Self> {
76        let client = Client::new();
77        let model = account
78            .model
79            .as_deref()
80            .or(config.model.as_deref())
81            .unwrap_or("mistralai/Mistral-7B-Instruct-v0.2")
82            .to_string();
83
84        let api_url = account
85            .api_url
86            .as_deref()
87            .or(config.api_url.as_deref())
88            .unwrap_or("https://api-inference.huggingface.co")
89            .to_string();
90
91        Ok(Self {
92            client,
93            api_key: api_key.to_string(),
94            model,
95            api_url,
96        })
97    }
98}
99
100#[async_trait]
101impl AIProvider for HuggingFaceProvider {
102    async fn generate_commit_message(
103        &self,
104        diff: &str,
105        context: Option<&str>,
106        full_gitmoji: bool,
107        config: &Config,
108    ) -> Result<String> {
109        // HuggingFace Inference API uses a single prompt (no system message support)
110        let prompt = build_prompt(diff, context, config, full_gitmoji);
111
112        let request = HFRequest {
113            model: self.model.clone(),
114            inputs: prompt,
115            parameters: HFParameters {
116                temperature: Some(0.7),
117                max_new_tokens: Some(config.tokens_max_output.unwrap_or(500)),
118                return_full_text: false,
119            },
120            options: HFOptions { use_cache: true },
121        };
122
123        // Use Inference API endpoint
124        let url = format!("{}/models/{}", self.api_url, self.model);
125
126        let hf_response: HFResponse = retry_async(|| async {
127            let response = self
128                .client
129                .post(&url)
130                .header("Authorization", format!("Bearer {}", self.api_key))
131                .header("Content-Type", "application/json")
132                .json(&request)
133                .send()
134                .await
135                .context("Failed to connect to HuggingFace")?;
136
137            if !response.status().is_success() {
138                let error_text = response.text().await?;
139                return Err(anyhow::anyhow!("HuggingFace API error: {}", error_text));
140            }
141
142            let hf_response: HFResponse = response
143                .json()
144                .await
145                .context("Failed to parse HuggingFace response")?;
146
147            Ok(hf_response)
148        })
149        .await
150        .context("Failed to generate commit message from HuggingFace after retries")?;
151
152        // Handle error response
153        if let Some(error) = hf_response.error {
154            anyhow::bail!("HuggingFace inference error: {}", error);
155        }
156
157        let message = hf_response
158            .generated_text
159            .context("HuggingFace returned an empty response")?
160            .trim()
161            .to_string();
162
163        // Clean up the response - remove the prompt if it's included
164        // HF models often return the full prompt + completion
165        Ok(message)
166    }
167}
168
/// Registry builder that constructs [`HuggingFaceProvider`] instances;
/// registered under the name "huggingface" (alias "hf").
pub struct HuggingFaceProviderBuilder;
171
172impl super::registry::ProviderBuilder for HuggingFaceProviderBuilder {
173    fn name(&self) -> &'static str {
174        "huggingface"
175    }
176
177    fn aliases(&self) -> Vec<&'static str> {
178        vec!["hf"]
179    }
180
181    fn category(&self) -> super::registry::ProviderCategory {
182        super::registry::ProviderCategory::Standard
183    }
184
185    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
186        Ok(Box::new(HuggingFaceProvider::new(config)?))
187    }
188
189    fn requires_api_key(&self) -> bool {
190        true
191    }
192
193    fn default_model(&self) -> Option<&'static str> {
194        Some("mistralai/Mistral-7B-Instruct-v0.2")
195    }
196}