sklears_python/linear/
elastic_net.rs

//! Python bindings for ElasticNet Regression
//!
//! This module provides Python bindings for ElasticNet Regression,
//! offering scikit-learn compatible interfaces with combined L1+L2 regularization
//! using the sklears-linear crate.

use super::common::*;
use pyo3::types::PyDict;
use pyo3::Bound;
use sklears_core::traits::{Fit, Predict, Score, Trained};
use sklears_linear::{LinearRegression, LinearRegressionConfig, Penalty};

/// Python-specific configuration wrapper for ElasticNet
#[derive(Debug, Clone)]
pub struct PyElasticNetConfig {
    pub alpha: f64,
    pub l1_ratio: f64,
    pub fit_intercept: bool,
    pub copy_x: bool,
    pub max_iter: usize,
    pub tol: f64,
    pub warm_start: bool,
    pub positive: bool,
    pub random_state: Option<i32>,
    pub selection: String,
}

impl Default for PyElasticNetConfig {
    fn default() -> Self {
        Self {
            alpha: 1.0,
            l1_ratio: 0.5,
            fit_intercept: true,
            copy_x: true,
            max_iter: 1000,
            tol: 1e-4,
            warm_start: false,
            positive: false,
            random_state: None,
            selection: "cyclic".to_string(),
        }
    }
}

impl From<PyElasticNetConfig> for LinearRegressionConfig {
    fn from(py_config: PyElasticNetConfig) -> Self {
        // ElasticNet combines L1 and L2 penalties
        LinearRegressionConfig {
            fit_intercept: py_config.fit_intercept,
            penalty: Penalty::ElasticNet {
                alpha: py_config.alpha,
                l1_ratio: py_config.l1_ratio,
            },
            max_iter: py_config.max_iter,
            tol: py_config.tol,
            warm_start: py_config.warm_start,
            ..Default::default()
        }
    }
}
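
// The `Penalty::ElasticNet` variant built above corresponds to the penalty
// term in the docstring below. As a minimal illustrative sketch (these
// helpers are hypothetical and not used by the bindings), the penalty value
// and the mapping from separate L1/L2 weights `a` and `b` look like this:
#[allow(dead_code)]
fn elastic_net_penalty(w: &[f64], alpha: f64, l1_ratio: f64) -> f64 {
    // alpha * l1_ratio * ||w||_1 + 0.5 * alpha * (1 - l1_ratio) * ||w||_2^2
    let l1: f64 = w.iter().map(|wi| wi.abs()).sum();
    let l2_sq: f64 = w.iter().map(|wi| wi * wi).sum();
    alpha * l1_ratio * l1 + 0.5 * alpha * (1.0 - l1_ratio) * l2_sq
}

#[allow(dead_code)]
fn alpha_l1_ratio_from_separate(a: f64, b: f64) -> (f64, f64) {
    // a * L1 + b * L2 is equivalent to alpha = a + b, l1_ratio = a / (a + b)
    (a + b, a / (a + b))
}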

/// Linear regression with combined L1 and L2 priors as regularizer.
///
/// Minimizes the objective function:
///
///     1 / (2 * n_samples) * ||y - Xw||^2_2
///     + alpha * l1_ratio * ||w||_1
///     + 0.5 * alpha * (1 - l1_ratio) * ||w||^2_2
///
/// If you are interested in controlling the L1 and L2 penalty
/// separately, keep in mind that this is equivalent to:
///
///     a * L1 + b * L2
///
/// where:
///
///     alpha = a + b and l1_ratio = a / (a + b)
///
/// The parameter l1_ratio corresponds to alpha in the glmnet R package
/// while alpha corresponds to the lambda parameter in glmnet.
/// Specifically, l1_ratio = 1 is the lasso penalty. Currently, l1_ratio
/// <= 0.01 is not reliable, unless you supply your own sequence of alpha.
///
/// Parameters
/// ----------
/// alpha : float, default=1.0
///     Constant that multiplies the penalty terms. Defaults to 1.0.
///     See the notes for the exact mathematical meaning of this
///     parameter. ``alpha = 0`` is equivalent to an ordinary least square,
///     solved by the :class:`LinearRegression` object. For numerical
///     reasons, using ``alpha = 0`` with the ``Lasso`` object is not advised.
///     Given this, you should use the :class:`LinearRegression` object.
///
/// l1_ratio : float, default=0.5
///     The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
///     ``l1_ratio = 0`` the penalty is an L2 penalty. For ``l1_ratio = 1`` it
///     is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
///     combination of L1 and L2.
///
/// fit_intercept : bool, default=True
///     Whether to calculate the intercept for this model. If set
///     to False, no intercept will be used in calculations
///     (i.e. data is expected to be centered).
///
/// copy_X : bool, default=True
///     If ``True``, X will be copied; else, it may be overwritten.
///
/// max_iter : int, default=1000
///     The maximum number of iterations for the optimization algorithm.
///
/// tol : float, default=1e-4
///     The tolerance for the optimization: if the updates are
///     smaller than ``tol``, the optimization code checks the
///     dual gap for optimality and continues until it is smaller
///     than ``tol``, see Notes below.
///
/// warm_start : bool, default=False
///     When set to ``True``, reuse the solution of the previous call to fit as
///     initialization, otherwise, just erase the previous solution.
///     See :term:`the Glossary <warm_start>`.
///
/// positive : bool, default=False
///     When set to ``True``, forces the coefficients to be positive.
///
/// random_state : int, RandomState instance, default=None
///     The seed of the pseudo random number generator that selects a random
///     feature to update. Used when ``selection`` == 'random'.
///     Pass an int for reproducible output across multiple function calls.
///     See :term:`Glossary <random_state>`.
///
/// selection : {'cyclic', 'random'}, default='cyclic'
///     If set to 'random', a random coefficient is updated every iteration
///     rather than looping over features sequentially by default. This
///     (setting to 'random') often leads to significantly faster convergence
///     especially when tol is higher than 1e-4.
///
/// Attributes
/// ----------
/// coef_ : ndarray of shape (n_features,) or (n_targets, n_features)
///     Parameter vector (w in the cost function formula).
///
/// sparse_coef_ : sparse matrix of shape (n_features,) or \
///         (n_targets, n_features)
///     Sparse representation of the fitted ``coef_``.
///
/// intercept_ : float or ndarray of shape (n_targets,)
///     Independent term in decision function.
///
/// n_features_in_ : int
///     Number of features seen during :term:`fit`.
///
/// n_iter_ : list of int
///     Number of iterations run by the coordinate descent solver to reach
///     the specified tolerance.
///
/// Examples
/// --------
/// >>> from sklears_python import ElasticNet
/// >>> from sklearn.datasets import make_regression
/// >>> X, y = make_regression(n_features=2, random_state=0)
/// >>> regr = ElasticNet(random_state=0)
/// >>> regr.fit(X, y)
/// ElasticNet(random_state=0)
/// >>> print(regr.coef_)
/// [18.83816119 64.55968437]
/// >>> print(regr.intercept_)
/// 1.451...
/// >>> print(regr.predict([[0, 0]]))
/// [1.451...]
///
/// Notes
/// -----
/// To avoid unnecessary memory duplication the X argument of the fit method
/// should be directly passed as a Fortran-contiguous NumPy array.
///
/// The precise stopping criteria based on `tol` are the following: First,
/// check that the maximum coordinate update, i.e. :math:`\\max_j |w_j^{new} -
/// w_j^{old}|` is smaller than `tol` times the maximum absolute coefficient,
/// :math:`\\max_j |w_j|`. If so, then additionally check whether the dual gap
/// is smaller than `tol` times :math:`||y||_2^2 / n_\\text{samples}`.
#[pyclass(name = "ElasticNet")]
pub struct PyElasticNet {
    /// Python-specific configuration
    py_config: PyElasticNetConfig,
    /// Trained model instance using the actual sklears-linear implementation
    fitted_model: Option<LinearRegression<Trained>>,
}
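
// A minimal sketch of the `tol` stopping rule described in the Notes above.
// This helper is hypothetical and for illustration only; the real check,
// including the dual gap test, lives inside the sklears-linear coordinate
// descent solver.
#[allow(dead_code)]
fn coordinate_updates_converged(w_old: &[f64], w_new: &[f64], tol: f64) -> bool {
    // Stop once max_j |w_j^{new} - w_j^{old}| < tol * max_j |w_j|.
    let max_update = w_old
        .iter()
        .zip(w_new)
        .map(|(old, new)| (new - old).abs())
        .fold(0.0_f64, f64::max);
    let max_coef = w_new.iter().map(|wi| wi.abs()).fold(0.0_f64, f64::max);
    max_update < tol * max_coef
}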

#[pymethods]
impl PyElasticNet {
    #[new]
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (alpha=1.0, l1_ratio=0.5, fit_intercept=true, copy_x=true, max_iter=1000, tol=1e-4, warm_start=false, positive=false, random_state=None, selection="cyclic"))]
    fn new(
        alpha: f64,
        l1_ratio: f64,
        fit_intercept: bool,
        copy_x: bool,
        max_iter: usize,
        tol: f64,
        warm_start: bool,
        positive: bool,
        random_state: Option<i32>,
        selection: &str,
    ) -> PyResult<Self> {
        // Validate l1_ratio
        if !(0.0..=1.0).contains(&l1_ratio) {
            return Err(PyValueError::new_err(
                "l1_ratio must be between 0 and 1 (inclusive)",
            ));
        }

        // Validate alpha
        if alpha < 0.0 {
            return Err(PyValueError::new_err("alpha must be non-negative"));
        }

        // Validate selection against the documented {'cyclic', 'random'} options
        if selection != "cyclic" && selection != "random" {
            return Err(PyValueError::new_err(
                "selection must be either 'cyclic' or 'random'",
            ));
        }

        let py_config = PyElasticNetConfig {
            alpha,
            l1_ratio,
            fit_intercept,
            copy_x,
            max_iter,
            tol,
            warm_start,
            positive,
            random_state,
            selection: selection.to_string(),
        };

        Ok(Self {
            py_config,
            fitted_model: None,
        })
    }

    /// Fit the ElasticNet regression model
    fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
        let x_array = pyarray_to_core_array2(x)?;
        let y_array = pyarray_to_core_array1(y)?;

        // Validate input arrays
        validate_fit_arrays(&x_array, &y_array)?;

        // Create sklears-linear model with ElasticNet configuration.
        // Note: only alpha, l1_ratio, and fit_intercept are forwarded here;
        // max_iter, tol, warm_start, positive, and selection from the Python
        // config are not yet applied to the underlying solver.
        let model = LinearRegression::elastic_net(self.py_config.alpha, self.py_config.l1_ratio)
            .fit_intercept(self.py_config.fit_intercept);

        // Fit the model using sklears-linear's implementation
        match model.fit(&x_array, &y_array) {
            Ok(fitted_model) => {
                self.fitted_model = Some(fitted_model);
                Ok(())
            }
            Err(e) => Err(PyValueError::new_err(format!(
                "Failed to fit ElasticNet model: {:?}",
                e
            ))),
        }
    }

    /// Predict using the fitted model
    fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        let x_array = pyarray_to_core_array2(x)?;
        validate_predict_array(&x_array)?;

        match fitted.predict(&x_array) {
            Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
            Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
        }
    }

    /// Get model coefficients
    #[getter]
    fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        Ok(core_array1_to_py(py, fitted.coef()))
    }

    /// Get model intercept
    #[getter]
    fn intercept_(&self) -> PyResult<f64> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        // Models fit with fit_intercept=False have no intercept term; report
        // 0.0 in that case, matching scikit-learn's behavior.
        Ok(fitted.intercept().unwrap_or(0.0))
    }

    /// Calculate R² score
    fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        let x_array = pyarray_to_core_array2(x)?;
        let y_array = pyarray_to_core_array1(y)?;

        match fitted.score(&x_array, &y_array) {
            Ok(score) => Ok(score),
            Err(e) => Err(PyValueError::new_err(format!(
                "Score calculation failed: {:?}",
                e
            ))),
        }
    }

    /// Get number of features
    #[getter]
    fn n_features_in_(&self) -> PyResult<usize> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;

        // Infer number of features from coefficient array length
        Ok(fitted.coef().len())
    }

    /// Return parameters for this estimator (sklearn compatibility)
    fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
        let _deep = deep.unwrap_or(true);

        let dict = PyDict::new(py);

        dict.set_item("alpha", self.py_config.alpha)?;
        dict.set_item("l1_ratio", self.py_config.l1_ratio)?;
        dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
        dict.set_item("copy_X", self.py_config.copy_x)?;
        dict.set_item("max_iter", self.py_config.max_iter)?;
        dict.set_item("tol", self.py_config.tol)?;
        dict.set_item("warm_start", self.py_config.warm_start)?;
        dict.set_item("positive", self.py_config.positive)?;
        dict.set_item("random_state", self.py_config.random_state)?;
        dict.set_item("selection", &self.py_config.selection)?;

        Ok(dict.into())
    }

    /// Set parameters for this estimator (sklearn compatibility)
    fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
        // Update configuration parameters
        if let Some(alpha) = kwargs.get_item("alpha")? {
            let alpha_val: f64 = alpha.extract()?;
            if alpha_val < 0.0 {
                return Err(PyValueError::new_err("alpha must be non-negative"));
            }
            self.py_config.alpha = alpha_val;
        }
        if let Some(l1_ratio) = kwargs.get_item("l1_ratio")? {
            let l1_ratio_val: f64 = l1_ratio.extract()?;
            if !(0.0..=1.0).contains(&l1_ratio_val) {
                return Err(PyValueError::new_err(
                    "l1_ratio must be between 0 and 1 (inclusive)",
                ));
            }
            self.py_config.l1_ratio = l1_ratio_val;
        }
        if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
            self.py_config.fit_intercept = fit_intercept.extract()?;
        }
        if let Some(copy_x) = kwargs.get_item("copy_X")? {
            self.py_config.copy_x = copy_x.extract()?;
        }
        if let Some(max_iter) = kwargs.get_item("max_iter")? {
            self.py_config.max_iter = max_iter.extract()?;
        }
        if let Some(tol) = kwargs.get_item("tol")? {
            self.py_config.tol = tol.extract()?;
        }
        if let Some(warm_start) = kwargs.get_item("warm_start")? {
            self.py_config.warm_start = warm_start.extract()?;
        }
        if let Some(positive) = kwargs.get_item("positive")? {
            self.py_config.positive = positive.extract()?;
        }
        if let Some(random_state) = kwargs.get_item("random_state")? {
            self.py_config.random_state = random_state.extract()?;
        }
        if let Some(selection) = kwargs.get_item("selection")? {
            let selection_str: String = selection.extract()?;
            self.py_config.selection = selection_str;
        }

        // Clear fitted model since config changed
        self.fitted_model = None;

        Ok(())
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!(
            "ElasticNet(alpha={}, l1_ratio={}, fit_intercept={}, copy_X={}, max_iter={}, tol={}, warm_start={}, positive={}, random_state={:?}, selection='{}')",
            self.py_config.alpha,
            self.py_config.l1_ratio,
            self.py_config.fit_intercept,
            self.py_config.copy_x,
            self.py_config.max_iter,
            self.py_config.tol,
            self.py_config.warm_start,
            self.py_config.positive,
            self.py_config.random_state,
            self.py_config.selection
        )
    }
}
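
// A minimal round-trip sketch (a hypothetical test module, not part of the
// bindings, assuming an embedded interpreter initialized via pyo3's
// `prepare_freethreaded_python`) showing how get_params/set_params interact
// with the stored config: set_params validates values, updates the config,
// and discards any previously fitted model.
#[cfg(test)]
mod param_roundtrip_sketch {
    use super::*;

    #[test]
    fn set_params_updates_config_and_clears_fit() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let mut model = PyElasticNet::new(
                1.0, 0.5, true, true, 1000, 1e-4, false, false, None, "cyclic",
            )?;

            let kwargs = PyDict::new(py);
            kwargs.set_item("alpha", 0.1)?;
            kwargs.set_item("l1_ratio", 0.9)?;
            model.set_params(&kwargs)?;

            // The Python-side config reflects the new values...
            assert_eq!(model.py_config.alpha, 0.1);
            assert_eq!(model.py_config.l1_ratio, 0.9);
            // ...and any previously fitted model has been discarded.
            assert!(model.fitted_model.is_none());
            Ok(())
        })
    }
}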