// sklears_python/linear/lasso.rs

//! Python bindings for Lasso regression.
//!
//! This module exposes the L1-regularized (Lasso) linear regression from the
//! sklears-linear crate to Python, with a scikit-learn compatible interface.
6
use super::common::*;
use pyo3::types::PyDict;
use pyo3::Bound;
use pyo3::PyRefMut;
use sklears_core::traits::{Fit, Predict, Score, Trained};
use sklears_linear::{LinearRegression, LinearRegressionConfig, Penalty};
12
/// Hyperparameters for the Python-facing `Lasso` estimator.
///
/// The fields mirror the constructor arguments of `sklearn.linear_model.Lasso`
/// (`copy_x` is reported to Python as `copy_X`). Values are stored as given;
/// this type performs no validation of its own.
#[derive(Clone, Debug)]
pub struct PyLassoConfig {
    /// Regularization strength; becomes the weight of the L1 penalty.
    pub alpha: f64,
    /// Whether an intercept term is fitted.
    pub fit_intercept: bool,
    /// scikit-learn's `copy_X` flag.
    pub copy_x: bool,
    /// scikit-learn's `max_iter` (iteration cap for the solver).
    pub max_iter: usize,
    /// scikit-learn's `tol` (convergence tolerance).
    pub tol: f64,
    /// scikit-learn's `warm_start` flag.
    pub warm_start: bool,
    /// scikit-learn's `positive` flag (non-negative coefficients).
    pub positive: bool,
    /// scikit-learn's `random_state`; `None` means no seed was given.
    pub random_state: Option<i32>,
    /// scikit-learn's `selection` strategy name, e.g. `"cyclic"`.
    pub selection: String,
}
26
27impl Default for PyLassoConfig {
28    fn default() -> Self {
29        Self {
30            alpha: 1.0,
31            fit_intercept: true,
32            copy_x: true,
33            max_iter: 1000,
34            tol: 1e-4,
35            warm_start: false,
36            positive: false,
37            random_state: None,
38            selection: "cyclic".to_string(),
39        }
40    }
41}
42
43impl From<PyLassoConfig> for LinearRegressionConfig {
44    fn from(py_config: PyLassoConfig) -> Self {
45        LinearRegressionConfig {
46            fit_intercept: py_config.fit_intercept,
47            penalty: Penalty::L1(py_config.alpha),
48            max_iter: py_config.max_iter,
49            tol: py_config.tol,
50            warm_start: py_config.warm_start,
51            ..Default::default()
52        }
53    }
54}
55
56/// Python wrapper for Lasso regression
57#[pyclass(name = "Lasso")]
58pub struct PyLasso {
59    /// Python-specific configuration
60    py_config: PyLassoConfig,
61    /// Trained model instance using the actual sklears-linear implementation
62    fitted_model: Option<LinearRegression<Trained>>,
63}
64
65#[pymethods]
66impl PyLasso {
67    #[new]
68    #[allow(clippy::too_many_arguments)]
69    #[pyo3(signature = (alpha=1.0, fit_intercept=true, copy_x=true, max_iter=1000, tol=1e-4, warm_start=false, positive=false, random_state=None, selection="cyclic"))]
70    fn new(
71        alpha: f64,
72        fit_intercept: bool,
73        copy_x: bool,
74        max_iter: usize,
75        tol: f64,
76        warm_start: bool,
77        positive: bool,
78        random_state: Option<i32>,
79        selection: &str,
80    ) -> Self {
81        let py_config = PyLassoConfig {
82            alpha,
83            fit_intercept,
84            copy_x,
85            max_iter,
86            tol,
87            warm_start,
88            positive,
89            random_state,
90            selection: selection.to_string(),
91        };
92
93        Self {
94            py_config,
95            fitted_model: None,
96        }
97    }
98
99    /// Fit the Lasso regression model
100    fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
101        let x_array = pyarray_to_core_array2(x)?;
102        let y_array = pyarray_to_core_array1(y)?;
103
104        // Validate input arrays
105        validate_fit_arrays(&x_array, &y_array)?;
106
107        // Create sklears-linear model with Lasso configuration
108        let model = LinearRegression::lasso(self.py_config.alpha)
109            .fit_intercept(self.py_config.fit_intercept);
110
111        // Fit the model using sklears-linear's implementation
112        match model.fit(&x_array, &y_array) {
113            Ok(fitted_model) => {
114                self.fitted_model = Some(fitted_model);
115                Ok(())
116            }
117            Err(e) => Err(PyValueError::new_err(format!(
118                "Failed to fit Lasso model: {:?}",
119                e
120            ))),
121        }
122    }
123
124    /// Predict using the fitted model
125    fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
126        let fitted = self
127            .fitted_model
128            .as_ref()
129            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
130
131        let x_array = pyarray_to_core_array2(x)?;
132        validate_predict_array(&x_array)?;
133
134        match fitted.predict(&x_array) {
135            Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
136            Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
137        }
138    }
139
140    /// Get model coefficients
141    #[getter]
142    fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
143        let fitted = self
144            .fitted_model
145            .as_ref()
146            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
147
148        Ok(core_array1_to_py(py, fitted.coef()))
149    }
150
151    /// Get model intercept
152    #[getter]
153    fn intercept_(&self) -> PyResult<f64> {
154        let fitted = self
155            .fitted_model
156            .as_ref()
157            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
158
159        Ok(fitted.intercept().unwrap_or(0.0))
160    }
161
162    /// Calculate R² score
163    fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
164        let fitted = self
165            .fitted_model
166            .as_ref()
167            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
168
169        let x_array = pyarray_to_core_array2(x)?;
170        let y_array = pyarray_to_core_array1(y)?;
171
172        match fitted.score(&x_array, &y_array) {
173            Ok(score) => Ok(score),
174            Err(e) => Err(PyValueError::new_err(format!(
175                "Score calculation failed: {:?}",
176                e
177            ))),
178        }
179    }
180
181    /// Get number of features
182    #[getter]
183    fn n_features_in_(&self) -> PyResult<usize> {
184        let fitted = self
185            .fitted_model
186            .as_ref()
187            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
188
189        // Infer number of features from coefficient array length
190        Ok(fitted.coef().len())
191    }
192
193    /// Return parameters for this estimator (sklearn compatibility)
194    fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
195        let _deep = deep.unwrap_or(true);
196
197        let dict = PyDict::new(py);
198
199        dict.set_item("alpha", self.py_config.alpha)?;
200        dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
201        dict.set_item("copy_X", self.py_config.copy_x)?;
202        dict.set_item("max_iter", self.py_config.max_iter)?;
203        dict.set_item("tol", self.py_config.tol)?;
204        dict.set_item("warm_start", self.py_config.warm_start)?;
205        dict.set_item("positive", self.py_config.positive)?;
206        dict.set_item("random_state", self.py_config.random_state)?;
207        dict.set_item("selection", &self.py_config.selection)?;
208
209        Ok(dict.into())
210    }
211
212    /// Set parameters for this estimator (sklearn compatibility)
213    fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
214        // Update configuration parameters
215        if let Some(alpha) = kwargs.get_item("alpha")? {
216            self.py_config.alpha = alpha.extract()?;
217        }
218        if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
219            self.py_config.fit_intercept = fit_intercept.extract()?;
220        }
221        if let Some(copy_x) = kwargs.get_item("copy_X")? {
222            self.py_config.copy_x = copy_x.extract()?;
223        }
224        if let Some(max_iter) = kwargs.get_item("max_iter")? {
225            self.py_config.max_iter = max_iter.extract()?;
226        }
227        if let Some(tol) = kwargs.get_item("tol")? {
228            self.py_config.tol = tol.extract()?;
229        }
230        if let Some(warm_start) = kwargs.get_item("warm_start")? {
231            self.py_config.warm_start = warm_start.extract()?;
232        }
233        if let Some(positive) = kwargs.get_item("positive")? {
234            self.py_config.positive = positive.extract()?;
235        }
236        if let Some(random_state) = kwargs.get_item("random_state")? {
237            self.py_config.random_state = random_state.extract()?;
238        }
239        if let Some(selection) = kwargs.get_item("selection")? {
240            let selection_str: String = selection.extract()?;
241            self.py_config.selection = selection_str;
242        }
243
244        // Clear fitted model since config changed
245        self.fitted_model = None;
246
247        Ok(())
248    }
249
250    /// String representation
251    fn __repr__(&self) -> String {
252        format!(
253            "Lasso(alpha={}, fit_intercept={}, copy_X={}, max_iter={}, tol={}, warm_start={}, positive={}, random_state={:?}, selection='{}')",
254            self.py_config.alpha,
255            self.py_config.fit_intercept,
256            self.py_config.copy_x,
257            self.py_config.max_iter,
258            self.py_config.tol,
259            self.py_config.warm_start,
260            self.py_config.positive,
261            self.py_config.random_state,
262            self.py_config.selection
263        )
264    }
265}