rusty_commit/providers/mlx.rs

//! MLX Provider - Apple's ML framework for Apple Silicon
//!
//! MLX is Apple's machine learning framework optimized for Apple Silicon.
//! This provider connects to an MLX HTTP server running locally.
//!
//! Setup:
//! 1. Install mlx-lm: `pip install mlx-lm`
//! 2. Start server: `python -m mlx_lm.server --model mlx-community/Llama-3.2-3B-Instruct-4bit`
//! 3. Configure rco: `rco config set RCO_AI_PROVIDER=mlx RCO_API_URL=http://localhost:8080`
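//!
//! Illustrative use from inside the crate (a sketch; assumes an async context
//! with a populated `Config` and the staged `diff` already in hand):
//! ```ignore
//! let provider = MlxProvider::new(&config)?;
//! let message = provider
//!     .generate_commit_message(&diff, None, false, &config)
//!     .await?;
//! ```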

use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use super::{build_prompt, AIProvider};
use crate::config::Config;
use crate::utils::retry::retry_async;

pub struct MlxProvider {
    client: Client,
    api_url: String,
    model: String,
}

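// Request/response types for the OpenAI-compatible `/v1/chat/completions`
// endpoint that `mlx_lm.server` exposes. An illustrative response body these
// deserialize (unknown fields are ignored by serde's default behavior):
//
//   {"choices": [{"message": {"role": "assistant", "content": "feat: ..."}}]}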
#[derive(Serialize)]
struct MlxRequest {
    model: String,
    messages: Vec<MlxMessage>,
    max_tokens: i32,
    temperature: f32,
    stream: bool,
}

#[derive(Serialize, Deserialize, Clone)]
struct MlxMessage {
    role: String,
    content: String,
}

#[derive(Deserialize)]
struct MlxResponse {
    choices: Vec<MlxChoice>,
}

#[derive(Deserialize)]
struct MlxChoice {
    message: MlxMessage,
}

impl MlxProvider {
    pub fn new(config: &Config) -> Result<Self> {
        let client = Client::new();
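        // These fall back to the module docs' defaults: `mlx_lm.server` listens
        // on localhost:8080 out of the box, and "default" is a placeholder model
        // id; the server typically serves the model it was launched with.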
        let api_url = config
            .api_url
            .as_deref()
            .unwrap_or("http://localhost:8080")
            .to_string();
        let model = config
            .model
            .as_deref()
            .unwrap_or("default")
            .to_string();

        Ok(Self {
            client,
            api_url,
            model,
        })
    }

    /// Create a provider from an account configuration, falling back to the
    /// global config. MLX runs locally, so the account's API key is unused.
    #[allow(dead_code)]
    pub fn from_account(
        account: &crate::config::accounts::AccountConfig,
        _api_key: &str,
        config: &Config,
    ) -> Result<Self> {
        let client = Client::new();
        let api_url = account
            .api_url
            .as_deref()
            .or(config.api_url.as_deref())
            .unwrap_or("http://localhost:8080")
            .to_string();
        let model = account
            .model
            .as_deref()
            .or(config.model.as_deref())
            .unwrap_or("default")
            .to_string();

        Ok(Self {
            client,
            api_url,
            model,
        })
    }
}

#[async_trait]
impl AIProvider for MlxProvider {
    async fn generate_commit_message(
        &self,
        diff: &str,
        context: Option<&str>,
        full_gitmoji: bool,
        config: &Config,
    ) -> Result<String> {
        let prompt = build_prompt(diff, context, config, full_gitmoji);

        // MLX uses OpenAI-compatible chat format
        let messages = vec![
            MlxMessage {
                role: "system".to_string(),
                content: "You are an expert at writing clear, concise git commit messages.".to_string(),
            },
            MlxMessage {
                role: "user".to_string(),
                content: prompt,
            },
        ];

        let request = MlxRequest {
            model: self.model.clone(),
            messages,
            max_tokens: config.tokens_max_output.unwrap_or(500) as i32,
            temperature: 0.7,
            stream: false,
        };

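        // retry_async (crate::utils::retry) re-invokes this closure on failure,
        // building a fresh request future per attempt; the exact retry policy
        // (attempt count, backoff) is defined in that module.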
        let mlx_response: MlxResponse = retry_async(|| async {
            let url = format!("{}/v1/chat/completions", self.api_url);
            let response = self
                .client
                .post(&url)
                .json(&request)
                .send()
                .await
                .context("Failed to connect to MLX server")?;

            if !response.status().is_success() {
                let status = response.status();
                let error_text = response.text().await?;
                return Err(anyhow::anyhow!("MLX API error ({}): {}", status, error_text));
            }

            let mlx_response: MlxResponse = response
                .json()
                .await
                .context("Failed to parse MLX response")?;

            Ok(mlx_response)
        })
        .await
        .context("Failed to generate commit message from MLX after retries")?;

        let message = mlx_response
            .choices
            .first()
            .map(|choice| choice.message.content.trim().to_string())
            .context("MLX returned an empty response")?;

        Ok(message)
    }
}

/// ProviderBuilder for MLX
pub struct MlxProviderBuilder;

impl super::registry::ProviderBuilder for MlxProviderBuilder {
    fn name(&self) -> &'static str {
        "mlx"
    }

    fn aliases(&self) -> Vec<&'static str> {
        vec!["mlx-lm", "apple-mlx"]
    }

    fn category(&self) -> super::registry::ProviderCategory {
        super::registry::ProviderCategory::Local
    }

    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
        Ok(Box::new(MlxProvider::new(config)?))
    }

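    // The MLX server runs locally, so no API key is needed.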
    fn requires_api_key(&self) -> bool {
        false
    }

    fn default_model(&self) -> Option<&'static str> {
        Some("mlx-community/Llama-3.2-3B-Instruct-4bit")
    }
}