rusty_commit/providers/
mlx.rs1use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::{Deserialize, Serialize};
15
16use super::{build_prompt, AIProvider};
17use crate::config::Config;
18use crate::utils::retry::retry_async;
19
20pub struct MlxProvider {
21 client: Client,
22 api_url: String,
23 model: String,
24}
25
26#[derive(Serialize)]
27struct MlxRequest {
28 model: String,
29 messages: Vec<MlxMessage>,
30 max_tokens: i32,
31 temperature: f32,
32 stream: bool,
33}
34
35#[derive(Serialize, Deserialize, Clone)]
36struct MlxMessage {
37 role: String,
38 content: String,
39}
40
41#[derive(Deserialize)]
42struct MlxResponse {
43 choices: Vec<MlxChoice>,
44}
45
46#[derive(Deserialize)]
47struct MlxChoice {
48 message: MlxMessage,
49}
50
51impl MlxProvider {
52 pub fn new(config: &Config) -> Result<Self> {
53 let client = Client::new();
54 let api_url = config
55 .api_url
56 .as_deref()
57 .unwrap_or("http://localhost:8080")
58 .to_string();
59 let model = config
60 .model
61 .as_deref()
62 .unwrap_or("default")
63 .to_string();
64
65 Ok(Self {
66 client,
67 api_url,
68 model,
69 })
70 }
71
72 #[allow(dead_code)]
74 pub fn from_account(
75 account: &crate::config::accounts::AccountConfig,
76 _api_key: &str,
77 config: &Config,
78 ) -> Result<Self> {
79 let client = Client::new();
80 let api_url = account
81 .api_url
82 .as_deref()
83 .or(config.api_url.as_deref())
84 .unwrap_or("http://localhost:8080")
85 .to_string();
86 let model = account
87 .model
88 .as_deref()
89 .or(config.model.as_deref())
90 .unwrap_or("default")
91 .to_string();
92
93 Ok(Self {
94 client,
95 api_url,
96 model,
97 })
98 }
99}
100
101#[async_trait]
102impl AIProvider for MlxProvider {
103 async fn generate_commit_message(
104 &self,
105 diff: &str,
106 context: Option<&str>,
107 full_gitmoji: bool,
108 config: &Config,
109 ) -> Result<String> {
110 let prompt = build_prompt(diff, context, config, full_gitmoji);
111
112 let messages = vec![
114 MlxMessage {
115 role: "system".to_string(),
116 content: "You are an expert at writing clear, concise git commit messages.".to_string(),
117 },
118 MlxMessage {
119 role: "user".to_string(),
120 content: prompt,
121 },
122 ];
123
124 let request = MlxRequest {
125 model: self.model.clone(),
126 messages,
127 max_tokens: config.tokens_max_output.unwrap_or(500) as i32,
128 temperature: 0.7,
129 stream: false,
130 };
131
132 let mlx_response: MlxResponse = retry_async(|| async {
133 let url = format!("{}/v1/chat/completions", self.api_url);
134 let response = self
135 .client
136 .post(&url)
137 .json(&request)
138 .send()
139 .await
140 .context("Failed to connect to MLX server")?;
141
142 if !response.status().is_success() {
143 let error_text = response.text().await?;
144 return Err(anyhow::anyhow!("MLX API error: {}", error_text));
145 }
146
147 let mlx_response: MlxResponse = response
148 .json()
149 .await
150 .context("Failed to parse MLX response")?;
151
152 Ok(mlx_response)
153 })
154 .await
155 .context("Failed to generate commit message from MLX after retries")?;
156
157 let message = mlx_response
158 .choices
159 .first()
160 .map(|choice| choice.message.content.trim().to_string())
161 .context("MLX returned an empty response")?;
162
163 Ok(message)
164 }
165}
166
167pub struct MlxProviderBuilder;
169
170impl super::registry::ProviderBuilder for MlxProviderBuilder {
171 fn name(&self) -> &'static str {
172 "mlx"
173 }
174
175 fn aliases(&self) -> Vec<&'static str> {
176 vec!["mlx-lm", "apple-mlx"]
177 }
178
179 fn category(&self) -> super::registry::ProviderCategory {
180 super::registry::ProviderCategory::Local
181 }
182
183 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
184 Ok(Box::new(MlxProvider::new(config)?))
185 }
186
187 fn requires_api_key(&self) -> bool {
188 false
189 }
190
191 fn default_model(&self) -> Option<&'static str> {
192 Some("mlx-community/Llama-3.2-3B-Instruct-4bit")
193 }
194}