rusty_commit/providers/
mlx.rs1use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::{Deserialize, Serialize};
15
16use super::prompt::build_prompt;
17use super::AIProvider;
18use crate::config::Config;
19use crate::utils::retry::retry_async;
20
21pub struct MlxProvider {
22 client: Client,
23 api_url: String,
24 model: String,
25}
26
27#[derive(Serialize)]
28struct MlxRequest {
29 model: String,
30 messages: Vec<MlxMessage>,
31 max_tokens: i32,
32 temperature: f32,
33 stream: bool,
34}
35
36#[derive(Serialize, Deserialize, Clone)]
37struct MlxMessage {
38 role: String,
39 content: String,
40}
41
42#[derive(Deserialize)]
43struct MlxResponse {
44 choices: Vec<MlxChoice>,
45}
46
47#[derive(Deserialize)]
48struct MlxChoice {
49 message: MlxMessage,
50}
51
52impl MlxProvider {
53 pub fn new(config: &Config) -> Result<Self> {
54 let client = Client::new();
55 let api_url = config
56 .api_url
57 .as_deref()
58 .unwrap_or("http://localhost:8080")
59 .to_string();
60 let model = config.model.as_deref().unwrap_or("default").to_string();
61
62 Ok(Self {
63 client,
64 api_url,
65 model,
66 })
67 }
68
69 #[allow(dead_code)]
71 pub fn from_account(
72 account: &crate::config::accounts::AccountConfig,
73 _api_key: &str,
74 config: &Config,
75 ) -> Result<Self> {
76 let client = Client::new();
77 let api_url = account
78 .api_url
79 .as_deref()
80 .or(config.api_url.as_deref())
81 .unwrap_or("http://localhost:8080")
82 .to_string();
83 let model = account
84 .model
85 .as_deref()
86 .or(config.model.as_deref())
87 .unwrap_or("default")
88 .to_string();
89
90 Ok(Self {
91 client,
92 api_url,
93 model,
94 })
95 }
96}
97
98#[async_trait]
99impl AIProvider for MlxProvider {
100 async fn generate_commit_message(
101 &self,
102 diff: &str,
103 context: Option<&str>,
104 full_gitmoji: bool,
105 config: &Config,
106 ) -> Result<String> {
107 let prompt = build_prompt(diff, context, config, full_gitmoji);
108
109 let messages = vec![
111 MlxMessage {
112 role: "system".to_string(),
113 content: "You are an expert at writing clear, concise git commit messages."
114 .to_string(),
115 },
116 MlxMessage {
117 role: "user".to_string(),
118 content: prompt,
119 },
120 ];
121
122 let request = MlxRequest {
123 model: self.model.clone(),
124 messages,
125 max_tokens: config.tokens_max_output.unwrap_or(500) as i32,
126 temperature: 0.7,
127 stream: false,
128 };
129
130 let mlx_response: MlxResponse = retry_async(|| async {
131 let url = format!("{}/v1/chat/completions", self.api_url);
132 let response = self
133 .client
134 .post(&url)
135 .json(&request)
136 .send()
137 .await
138 .context("Failed to connect to MLX server")?;
139
140 if !response.status().is_success() {
141 let error_text = response.text().await?;
142 return Err(anyhow::anyhow!("MLX API error: {}", error_text));
143 }
144
145 let mlx_response: MlxResponse = response
146 .json()
147 .await
148 .context("Failed to parse MLX response")?;
149
150 Ok(mlx_response)
151 })
152 .await
153 .context("Failed to generate commit message from MLX after retries")?;
154
155 let message = mlx_response
156 .choices
157 .first()
158 .map(|choice| choice.message.content.trim().to_string())
159 .context("MLX returned an empty response")?;
160
161 Ok(message)
162 }
163}
164
165pub struct MlxProviderBuilder;
167
168impl super::registry::ProviderBuilder for MlxProviderBuilder {
169 fn name(&self) -> &'static str {
170 "mlx"
171 }
172
173 fn aliases(&self) -> Vec<&'static str> {
174 vec!["mlx-lm", "apple-mlx"]
175 }
176
177 fn category(&self) -> super::registry::ProviderCategory {
178 super::registry::ProviderCategory::Local
179 }
180
181 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
182 Ok(Box::new(MlxProvider::new(config)?))
183 }
184
185 fn requires_api_key(&self) -> bool {
186 false
187 }
188
189 fn default_model(&self) -> Option<&'static str> {
190 Some("mlx-community/Llama-3.2-3B-Instruct-4bit")
191 }
192}