
rusty_commit/providers/mlx.rs

//! MLX Provider - Apple's ML framework for Apple Silicon
//!
//! MLX is Apple's machine learning framework optimized for Apple Silicon.
//! This provider connects to an MLX HTTP server running locally.
//!
//! Setup:
//! 1. Install mlx-lm: `pip install mlx-lm`
//! 2. Start server: `python -m mlx_lm.server --model mlx-community/Llama-3.2-3B-Instruct-4bit`
//! 3. Configure rco: `rco config set RCO_AI_PROVIDER=mlx RCO_API_URL=http://localhost:8080`
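//!
//! To sanity-check the server before configuring rco, you can hit the
//! OpenAI-compatible endpoint directly (illustrative request; the model
//! name is whatever the server was started with):
//!
//! ```sh
//! curl http://localhost:8080/v1/chat/completions \
//!   -H 'Content-Type: application/json' \
//!   -d '{"model": "default", "messages": [{"role": "user", "content": "hi"}]}'
//! ```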

use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use super::prompt::build_prompt;
use super::AIProvider;
use crate::config::Config;
use crate::utils::retry::retry_async;

pub struct MlxProvider {
    client: Client,
    api_url: String,
    model: String,
}

#[derive(Serialize)]
struct MlxRequest {
    model: String,
    messages: Vec<MlxMessage>,
    max_tokens: i32,
    temperature: f32,
    stream: bool,
}
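
// A serialized request mirrors OpenAI's chat-completions schema; the field
// values below are illustrative, not defaults enforced by the server:
//
// {
//   "model": "default",
//   "messages": [{"role": "user", "content": "..."}],
//   "max_tokens": 500,
//   "temperature": 0.7,
//   "stream": false
// }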

#[derive(Serialize, Deserialize, Clone)]
struct MlxMessage {
    role: String,
    content: String,
}

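// Only the fields this provider reads are modeled; serde's default behavior
// ignores any other fields an OpenAI-compatible server includes (id, usage,
// and so on).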
#[derive(Deserialize)]
struct MlxResponse {
    choices: Vec<MlxChoice>,
}

#[derive(Deserialize)]
struct MlxChoice {
    message: MlxMessage,
}

impl MlxProvider {
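    /// Build a provider from the global config, defaulting to a local
    /// server at `http://localhost:8080` and to the model name "default"
    /// (mlx_lm.server serves whatever model it was started with).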
    pub fn new(config: &Config) -> Result<Self> {
        let client = Client::new();
        let api_url = config
            .api_url
            .as_deref()
            .unwrap_or("http://localhost:8080")
            .to_string();
        let model = config.model.as_deref().unwrap_or("default").to_string();

        Ok(Self {
            client,
            api_url,
            model,
        })
    }

    /// Create a provider from a per-account configuration, letting account
    /// settings override the global config; the API key is ignored because
    /// the MLX server runs locally.
    #[allow(dead_code)]
    pub fn from_account(
        account: &crate::config::accounts::AccountConfig,
        _api_key: &str,
        config: &Config,
    ) -> Result<Self> {
        let client = Client::new();
        let api_url = account
            .api_url
            .as_deref()
            .or(config.api_url.as_deref())
            .unwrap_or("http://localhost:8080")
            .to_string();
        let model = account
            .model
            .as_deref()
            .or(config.model.as_deref())
            .unwrap_or("default")
            .to_string();

        Ok(Self {
            client,
            api_url,
            model,
        })
    }
}

#[async_trait]
impl AIProvider for MlxProvider {
    async fn generate_commit_message(
        &self,
        diff: &str,
        context: Option<&str>,
        full_gitmoji: bool,
        config: &Config,
    ) -> Result<String> {
        let prompt = build_prompt(diff, context, config, full_gitmoji);

        // MLX uses OpenAI-compatible chat format
        let messages = vec![
            MlxMessage {
                role: "system".to_string(),
                content: "You are an expert at writing clear, concise git commit messages."
                    .to_string(),
            },
            MlxMessage {
                role: "user".to_string(),
                content: prompt,
            },
        ];

        let request = MlxRequest {
            model: self.model.clone(),
            messages,
            max_tokens: config.tokens_max_output.unwrap_or(500) as i32,
            temperature: 0.7,
            stream: false,
        };
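
        // POST to the OpenAI-compatible chat endpoint, retrying transient
        // failures (for example, a refused connection while the server is
        // still loading the model) via retry_async.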
        let mlx_response: MlxResponse = retry_async(|| async {
            let url = format!("{}/v1/chat/completions", self.api_url);
            let response = self
                .client
                .post(&url)
                .json(&request)
                .send()
                .await
                .context("Failed to connect to MLX server")?;

            let status = response.status();
            if !status.is_success() {
                let error_text = response.text().await.unwrap_or_default();
                return Err(anyhow::anyhow!("MLX API error ({status}): {error_text}"));
            }

            let mlx_response: MlxResponse = response
                .json()
                .await
                .context("Failed to parse MLX response")?;

            Ok(mlx_response)
        })
        .await
        .context("Failed to generate commit message from MLX after retries")?;

        let message = mlx_response
            .choices
            .first()
            .map(|choice| choice.message.content.trim().to_string())
            .context("MLX response contained no choices")?;

        Ok(message)
    }
}
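
// Typical call path (sketch; constructing `Config` depends on the caller):
//
//     let provider = MlxProvider::new(&config)?;
//     let message = provider
//         .generate_commit_message(&diff, None, false, &config)
//         .await?;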

/// ProviderBuilder for MLX.
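///
/// Registered under the name "mlx" (aliases "mlx-lm", "apple-mlx"); no API
/// key is required because the MLX server runs locally.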
pub struct MlxProviderBuilder;

impl super::registry::ProviderBuilder for MlxProviderBuilder {
    fn name(&self) -> &'static str {
        "mlx"
    }

    fn aliases(&self) -> Vec<&'static str> {
        vec!["mlx-lm", "apple-mlx"]
    }

    fn category(&self) -> super::registry::ProviderCategory {
        super::registry::ProviderCategory::Local
    }

    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
        Ok(Box::new(MlxProvider::new(config)?))
    }

    fn requires_api_key(&self) -> bool {
        false
    }

    fn default_model(&self) -> Option<&'static str> {
        Some("mlx-community/Llama-3.2-3B-Instruct-4bit")
    }
}