Skip to main content

sklears_python/linear/
ard_regression.rs

1//! Python bindings for ARD Regression
2//!
3//! This module provides Python bindings for Automatic Relevance Determination (ARD) Regression,
4//! offering scikit-learn compatible interfaces with automatic feature selection
5//! and uncertainty quantification using the sklears-linear crate.
6
7use super::common::*;
8use pyo3::types::PyDict;
9use pyo3::Bound;
10use sklears_core::traits::{Fit, Predict, Score, Trained};
11use sklears_linear::{ARDRegression, ARDRegressionConfig};
12
/// Settings captured from the Python `ARDRegression` constructor.
///
/// Each field mirrors one keyword argument of the Python-facing class. The
/// `From` impl in this module turns it into sklears-linear's `ARDRegressionConfig`.
#[derive(Debug, Clone)]
pub struct PyARDRegressionConfig {
    /// Upper bound on solver iterations.
    pub max_iter: usize,
    /// Convergence tolerance.
    pub tol: f64,
    /// Initial per-feature weight precision; `None` falls back to the
    /// `ARDRegressionConfig::default()` value during conversion.
    pub alpha_init: Option<f64>,
    /// Initial noise precision; `None` falls back to the
    /// `ARDRegressionConfig::default()` value during conversion.
    pub lambda_init: Option<f64>,
    /// Precision above which a feature's weight is treated as zero.
    pub threshold_alpha: f64,
    /// Whether an intercept term is fitted.
    pub fit_intercept: bool,
    /// scikit-learn's `compute_score` flag.
    pub compute_score: bool,
    /// scikit-learn's `copy_X` flag; not carried over by the `From` conversion.
    pub copy_x: bool,
}
25
26impl Default for PyARDRegressionConfig {
27    fn default() -> Self {
28        Self {
29            max_iter: 300,
30            tol: 1e-3,
31            alpha_init: Some(1.0),
32            lambda_init: Some(1.0),
33            threshold_alpha: 1e10,
34            fit_intercept: true,
35            compute_score: false,
36            copy_x: true,
37        }
38    }
39}
40
41impl From<PyARDRegressionConfig> for ARDRegressionConfig {
42    fn from(py_config: PyARDRegressionConfig) -> Self {
43        ARDRegressionConfig {
44            max_iter: py_config.max_iter,
45            tol: py_config.tol,
46            alpha_init: py_config
47                .alpha_init
48                .unwrap_or_else(|| ARDRegressionConfig::default().alpha_init),
49            lambda_init: py_config
50                .lambda_init
51                .unwrap_or_else(|| ARDRegressionConfig::default().lambda_init),
52            threshold_alpha: py_config.threshold_alpha,
53            fit_intercept: py_config.fit_intercept,
54            compute_score: py_config.compute_score,
55        }
56    }
57}
58
59/// Bayesian ARD regression.
60///
61/// Fit the weights of a regression model, using an ARD prior. The weights of
62/// the regression model are assumed to be drawn from an isotropic Gaussian
63/// distribution with precision lambda. The shrinkage is data-dependent,
64/// and the parameters of the prior are estimated from the data using empirical
65/// Bayes approach.
66///
67/// Parameters
68/// ----------
69/// max_iter : int, default=300
70///     Maximum number of iterations.
71///
72/// tol : float, default=1e-3
73///     Stop the algorithm if w has converged.
74///
75/// alpha_init : float, default=1.0
76///     Initial value for alpha (per-feature precisions).
77///     If not provided, alpha_init is 1.0.
78///
79/// lambda_init : float, default=1.0
80///     Initial value for lambda (precision of the noise).
81///     If not provided, lambda_init is 1.0.
82///
83/// threshold_alpha : float, default=1e10
84///     Threshold for removing (pruning) weights with high precision from
85///     the computation: features with precision higher than this threshold
86///     are considered to have zero weight.
87///
88/// fit_intercept : bool, default=True
89///     Whether to calculate the intercept for this model. If set
90///     to False, no intercept will be used in calculations
91///     (i.e. data is expected to be centered).
92///
93/// compute_score : bool, default=False
94///     If True, compute the objective function at each step of the model.
95///
96/// copy_X : bool, default=True
97///     If True, X will be copied; else, it may be overwritten.
98///
99/// Attributes
100/// ----------
101/// coef_ : array-like of shape (n_features,)
102///     Coefficients of the regression model (mean of distribution)
103///
104/// alpha_ : array-like of shape (n_features,)
105///     estimated precision of the weights.
106///
107/// lambda_ : float
108///     estimated precision of the noise.
109///
110/// sigma_ : array-like of shape (n_features, n_features)
111///     estimated variance-covariance matrix of the weights
112///
113/// scores_ : array-like of shape (n_iter_+1,)
114///     if computed, value of the objective function (to be maximized)
115///     at each iteration of the optimization.
116///
117/// intercept_ : float
118///     Independent term in decision function. Set to 0.0 if
119///     `fit_intercept = False`.
120///
121/// n_features_in_ : int
122///     Number of features seen during fit.
123///
124/// Examples
125/// --------
126/// >>> from sklears_python import ARDRegression
127/// >>> import numpy as np
128/// ```text
129/// >>> X = np.array([[1], [2], [3], [4], [5]])
130/// >>> y = np.array([1, 2, 3, 4, 5])
131/// >>> reg = ARDRegression()
132/// >>> reg.fit(X, y)
133/// ARDRegression()
134/// >>> reg.predict([[3]])
135/// array([3.])
136/// ```
137///
138/// Notes
139/// -----
140/// ARD performs feature selection by setting the weights of many features
141/// to zero, as they are deemed irrelevant. This is particularly useful when
142/// the number of features is much larger than the number of samples.
143///
144/// For polynomial regression, it is recommended to "center" the data by
145/// subtracting its mean before fitting the ARD model.
146///
147/// References
148/// ----------
149/// D. J. C. MacKay, Bayesian nonlinear modeling for the prediction
150/// competition, ASHRAE Transactions, 1994.
151///
152/// R. Salakhutdinov, Lecture notes on Statistical Machine Learning,
153/// <http://www.cs.toronto.edu/~rsalakhu/sta4273/notes/Lecture2.pdf>
154/// Their beta is our `lambda_`, and their alpha is our `alpha_`.
155/// ARD is a little different: only `lambda_` is inferred; `alpha_`
156/// is fixed by the user.
157#[pyclass(name = "ARDRegression")]
158pub struct PyARDRegression {
159    /// Python-specific configuration
160    py_config: PyARDRegressionConfig,
161    /// Trained model instance using the actual sklears-linear implementation
162    fitted_model: Option<ARDRegression<Trained>>,
163}
164
165#[pymethods]
166impl PyARDRegression {
167    #[new]
168    #[allow(clippy::too_many_arguments)]
169    #[pyo3(signature = (max_iter=300, tol=1e-3, alpha_init=1.0, lambda_init=1.0, threshold_alpha=1e10, fit_intercept=true, compute_score=false, copy_x=true))]
170    fn new(
171        max_iter: usize,
172        tol: f64,
173        alpha_init: f64,
174        lambda_init: f64,
175        threshold_alpha: f64,
176        fit_intercept: bool,
177        compute_score: bool,
178        copy_x: bool,
179    ) -> PyResult<Self> {
180        // Validate parameters
181        if max_iter == 0 {
182            return Err(PyValueError::new_err("max_iter must be greater than 0"));
183        }
184        if tol <= 0.0 {
185            return Err(PyValueError::new_err("tol must be positive"));
186        }
187        if alpha_init <= 0.0 {
188            return Err(PyValueError::new_err("alpha_init must be positive"));
189        }
190        if lambda_init <= 0.0 {
191            return Err(PyValueError::new_err("lambda_init must be positive"));
192        }
193        if threshold_alpha <= 0.0 {
194            return Err(PyValueError::new_err("threshold_alpha must be positive"));
195        }
196
197        let py_config = PyARDRegressionConfig {
198            max_iter,
199            tol,
200            alpha_init: Some(alpha_init),
201            lambda_init: Some(lambda_init),
202            threshold_alpha,
203            fit_intercept,
204            compute_score,
205            copy_x,
206        };
207
208        Ok(Self {
209            py_config,
210            fitted_model: None,
211        })
212    }
213
214    /// Fit the ARD regression model
215    fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
216        let x_array = pyarray_to_core_array2(x)?;
217        let y_array = pyarray_to_core_array1(y)?;
218
219        // Validate input arrays
220        validate_fit_arrays(&x_array, &y_array)?;
221
222        // Create sklears-linear model with ARD configuration
223        let model = ARDRegression::new()
224            .max_iter(self.py_config.max_iter)
225            .tol(self.py_config.tol)
226            .threshold_alpha(self.py_config.threshold_alpha)
227            .fit_intercept(self.py_config.fit_intercept);
228
229        // Fit the model using sklears-linear's implementation
230        match model.fit(&x_array, &y_array) {
231            Ok(fitted_model) => {
232                self.fitted_model = Some(fitted_model);
233                Ok(())
234            }
235            Err(e) => Err(PyValueError::new_err(format!(
236                "Failed to fit ARD regression model: {:?}",
237                e
238            ))),
239        }
240    }
241
242    /// Predict using the fitted model
243    fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
244        let fitted = self
245            .fitted_model
246            .as_ref()
247            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
248
249        let x_array = pyarray_to_core_array2(x)?;
250        validate_predict_array(&x_array)?;
251
252        match fitted.predict(&x_array) {
253            Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
254            Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
255        }
256    }
257
258    /// Get model coefficients
259    #[getter]
260    fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
261        let fitted = self
262            .fitted_model
263            .as_ref()
264            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
265
266        let coef = fitted
267            .coef()
268            .map_err(|e| PyValueError::new_err(format!("Failed to get coefficients: {:?}", e)))?;
269        Ok(core_array1_to_py(py, coef))
270    }
271
272    /// Get model intercept
273    #[getter]
274    fn intercept_(&self) -> PyResult<f64> {
275        let fitted = self
276            .fitted_model
277            .as_ref()
278            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
279
280        Ok(fitted.intercept().unwrap_or(0.0))
281    }
282
283    /// Get estimated per-feature precisions (alpha)
284    #[getter]
285    fn alpha_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
286        let fitted = self
287            .fitted_model
288            .as_ref()
289            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
290
291        let alpha = fitted
292            .alpha()
293            .map_err(|e| PyValueError::new_err(format!("Failed to get alpha: {:?}", e)))?;
294        Ok(core_array1_to_py(py, alpha))
295    }
296
297    /// Get estimated precision of noise (lambda)
298    #[getter]
299    fn lambda_(&self) -> PyResult<f64> {
300        let fitted = self
301            .fitted_model
302            .as_ref()
303            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
304
305        fitted
306            .lambda()
307            .map_err(|e| PyValueError::new_err(format!("Failed to get lambda: {:?}", e)))
308    }
309
310    /// Calculate R² score
311    fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
312        let fitted = self
313            .fitted_model
314            .as_ref()
315            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
316
317        let x_array = pyarray_to_core_array2(x)?;
318        let y_array = pyarray_to_core_array1(y)?;
319
320        match fitted.score(&x_array, &y_array) {
321            Ok(score) => Ok(score),
322            Err(e) => Err(PyValueError::new_err(format!(
323                "Score calculation failed: {:?}",
324                e
325            ))),
326        }
327    }
328
329    /// Get number of features
330    #[getter]
331    fn n_features_in_(&self) -> PyResult<usize> {
332        let fitted = self
333            .fitted_model
334            .as_ref()
335            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
336
337        // Infer number of features from coefficient array length
338        let coef = fitted
339            .coef()
340            .map_err(|e| PyValueError::new_err(format!("Failed to get coefficients: {:?}", e)))?;
341        Ok(coef.len())
342    }
343
344    /// Return parameters for this estimator (sklearn compatibility)
345    fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
346        let _deep = deep.unwrap_or(true);
347
348        let dict = PyDict::new(py);
349
350        dict.set_item("max_iter", self.py_config.max_iter)?;
351        dict.set_item("tol", self.py_config.tol)?;
352        dict.set_item("alpha_init", self.py_config.alpha_init)?;
353        dict.set_item("lambda_init", self.py_config.lambda_init)?;
354        dict.set_item("threshold_alpha", self.py_config.threshold_alpha)?;
355        dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
356        dict.set_item("compute_score", self.py_config.compute_score)?;
357        dict.set_item("copy_X", self.py_config.copy_x)?;
358
359        Ok(dict.into())
360    }
361
362    /// Set parameters for this estimator (sklearn compatibility)
363    fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
364        // Update configuration parameters
365        if let Some(max_iter) = kwargs.get_item("max_iter")? {
366            let max_iter_val: usize = max_iter.extract()?;
367            if max_iter_val == 0 {
368                return Err(PyValueError::new_err("max_iter must be greater than 0"));
369            }
370            self.py_config.max_iter = max_iter_val;
371        }
372        if let Some(tol) = kwargs.get_item("tol")? {
373            let tol_val: f64 = tol.extract()?;
374            if tol_val <= 0.0 {
375                return Err(PyValueError::new_err("tol must be positive"));
376            }
377            self.py_config.tol = tol_val;
378        }
379        if let Some(alpha_init) = kwargs.get_item("alpha_init")? {
380            let alpha_init_val: f64 = alpha_init.extract()?;
381            if alpha_init_val <= 0.0 {
382                return Err(PyValueError::new_err("alpha_init must be positive"));
383            }
384            self.py_config.alpha_init = Some(alpha_init_val);
385        }
386        if let Some(lambda_init) = kwargs.get_item("lambda_init")? {
387            let lambda_init_val: f64 = lambda_init.extract()?;
388            if lambda_init_val <= 0.0 {
389                return Err(PyValueError::new_err("lambda_init must be positive"));
390            }
391            self.py_config.lambda_init = Some(lambda_init_val);
392        }
393        if let Some(threshold_alpha) = kwargs.get_item("threshold_alpha")? {
394            let threshold_alpha_val: f64 = threshold_alpha.extract()?;
395            if threshold_alpha_val <= 0.0 {
396                return Err(PyValueError::new_err("threshold_alpha must be positive"));
397            }
398            self.py_config.threshold_alpha = threshold_alpha_val;
399        }
400        if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
401            self.py_config.fit_intercept = fit_intercept.extract()?;
402        }
403        if let Some(compute_score) = kwargs.get_item("compute_score")? {
404            self.py_config.compute_score = compute_score.extract()?;
405        }
406        if let Some(copy_x) = kwargs.get_item("copy_X")? {
407            self.py_config.copy_x = copy_x.extract()?;
408        }
409
410        // Clear fitted model since config changed
411        self.fitted_model = None;
412
413        Ok(())
414    }
415
416    /// String representation
417    fn __repr__(&self) -> String {
418        format!(
419            "ARDRegression(max_iter={}, tol={}, alpha_init={:?}, lambda_init={:?}, threshold_alpha={}, fit_intercept={}, compute_score={}, copy_X={})",
420            self.py_config.max_iter,
421            self.py_config.tol,
422            self.py_config.alpha_init,
423            self.py_config.lambda_init,
424            self.py_config.threshold_alpha,
425            self.py_config.fit_intercept,
426            self.py_config.compute_score,
427            self.py_config.copy_x
428        )
429    }
430}