rusty_commit/providers/
mlx.rs1use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::{Deserialize, Serialize};
15
16use super::{build_prompt, AIProvider};
17use crate::config::Config;
18use crate::utils::retry::retry_async;
19
20pub struct MlxProvider {
21 client: Client,
22 api_url: String,
23 model: String,
24}
25
26#[derive(Serialize)]
27struct MlxRequest {
28 model: String,
29 messages: Vec<MlxMessage>,
30 max_tokens: i32,
31 temperature: f32,
32 stream: bool,
33}
34
35#[derive(Serialize, Deserialize, Clone)]
36struct MlxMessage {
37 role: String,
38 content: String,
39}
40
41#[derive(Deserialize)]
42struct MlxResponse {
43 choices: Vec<MlxChoice>,
44}
45
46#[derive(Deserialize)]
47struct MlxChoice {
48 message: MlxMessage,
49}
50
51impl MlxProvider {
52 pub fn new(config: &Config) -> Result<Self> {
53 let client = Client::new();
54 let api_url = config
55 .api_url
56 .as_deref()
57 .unwrap_or("http://localhost:8080")
58 .to_string();
59 let model = config.model.as_deref().unwrap_or("default").to_string();
60
61 Ok(Self {
62 client,
63 api_url,
64 model,
65 })
66 }
67
68 #[allow(dead_code)]
70 pub fn from_account(
71 account: &crate::config::accounts::AccountConfig,
72 _api_key: &str,
73 config: &Config,
74 ) -> Result<Self> {
75 let client = Client::new();
76 let api_url = account
77 .api_url
78 .as_deref()
79 .or(config.api_url.as_deref())
80 .unwrap_or("http://localhost:8080")
81 .to_string();
82 let model = account
83 .model
84 .as_deref()
85 .or(config.model.as_deref())
86 .unwrap_or("default")
87 .to_string();
88
89 Ok(Self {
90 client,
91 api_url,
92 model,
93 })
94 }
95}
96
97#[async_trait]
98impl AIProvider for MlxProvider {
99 async fn generate_commit_message(
100 &self,
101 diff: &str,
102 context: Option<&str>,
103 full_gitmoji: bool,
104 config: &Config,
105 ) -> Result<String> {
106 let prompt = build_prompt(diff, context, config, full_gitmoji);
107
108 let messages = vec![
110 MlxMessage {
111 role: "system".to_string(),
112 content: "You are an expert at writing clear, concise git commit messages."
113 .to_string(),
114 },
115 MlxMessage {
116 role: "user".to_string(),
117 content: prompt,
118 },
119 ];
120
121 let request = MlxRequest {
122 model: self.model.clone(),
123 messages,
124 max_tokens: config.tokens_max_output.unwrap_or(500) as i32,
125 temperature: 0.7,
126 stream: false,
127 };
128
129 let mlx_response: MlxResponse = retry_async(|| async {
130 let url = format!("{}/v1/chat/completions", self.api_url);
131 let response = self
132 .client
133 .post(&url)
134 .json(&request)
135 .send()
136 .await
137 .context("Failed to connect to MLX server")?;
138
139 if !response.status().is_success() {
140 let error_text = response.text().await?;
141 return Err(anyhow::anyhow!("MLX API error: {}", error_text));
142 }
143
144 let mlx_response: MlxResponse = response
145 .json()
146 .await
147 .context("Failed to parse MLX response")?;
148
149 Ok(mlx_response)
150 })
151 .await
152 .context("Failed to generate commit message from MLX after retries")?;
153
154 let message = mlx_response
155 .choices
156 .first()
157 .map(|choice| choice.message.content.trim().to_string())
158 .context("MLX returned an empty response")?;
159
160 Ok(message)
161 }
162}
163
/// Stateless builder that registers the MLX provider under the name `mlx`.
#[derive(Debug, Clone, Copy, Default)]
pub struct MlxProviderBuilder;
166
167impl super::registry::ProviderBuilder for MlxProviderBuilder {
168 fn name(&self) -> &'static str {
169 "mlx"
170 }
171
172 fn aliases(&self) -> Vec<&'static str> {
173 vec!["mlx-lm", "apple-mlx"]
174 }
175
176 fn category(&self) -> super::registry::ProviderCategory {
177 super::registry::ProviderCategory::Local
178 }
179
180 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
181 Ok(Box::new(MlxProvider::new(config)?))
182 }
183
184 fn requires_api_key(&self) -> bool {
185 false
186 }
187
188 fn default_model(&self) -> Option<&'static str> {
189 Some("mlx-community/Llama-3.2-3B-Instruct-4bit")
190 }
191}