sklears_python/linear/
lasso.rs

1//! Python bindings for Lasso Regression
2//!
3//! This module provides Python bindings for Lasso Regression,
4//! offering scikit-learn compatible interfaces with L1 regularization
5//! using the sklears-linear crate.
6
7use super::common::*;
8use numpy::IntoPyArray;
9use pyo3::types::PyDict;
10use pyo3::Bound;
11use sklears_core::traits::{Fit, Predict, Score, Trained};
12use sklears_linear::{LinearRegression, LinearRegressionConfig, Penalty};
13
/// Python-specific configuration wrapper for Lasso
///
/// Stores the constructor arguments of the Python-facing `Lasso` class so
/// they can be reported back by `get_params` and `__repr__`. Only a subset
/// of the fields is carried over when converting into
/// `LinearRegressionConfig`; the rest are kept purely for parameter
/// round-tripping in this file.
#[derive(Debug, Clone)]
pub struct PyLassoConfig {
    /// Regularization strength; becomes the `Penalty::L1` value on conversion.
    pub alpha: f64,
    /// Whether an intercept term is fitted; forwarded to the underlying model.
    pub fit_intercept: bool,
    /// Exposed to Python as `copy_X`. Stored and reported only; nothing in
    /// this file reads it while fitting.
    pub copy_x: bool,
    /// Iteration limit; copied into `LinearRegressionConfig::max_iter` on conversion.
    pub max_iter: usize,
    /// Tolerance; copied into `LinearRegressionConfig::tol` on conversion.
    pub tol: f64,
    /// Warm-start flag; copied into `LinearRegressionConfig::warm_start` on conversion.
    pub warm_start: bool,
    /// Stored and reported only; not consumed anywhere in this file.
    pub positive: bool,
    /// Optional seed. Stored and reported only; not consumed anywhere in this file.
    pub random_state: Option<i32>,
    /// Selection mode as a free-form string (default `"cyclic"`). Stored and
    /// reported only; the value is not validated in this file.
    pub selection: String,
}
27
28impl Default for PyLassoConfig {
29    fn default() -> Self {
30        Self {
31            alpha: 1.0,
32            fit_intercept: true,
33            copy_x: true,
34            max_iter: 1000,
35            tol: 1e-4,
36            warm_start: false,
37            positive: false,
38            random_state: None,
39            selection: "cyclic".to_string(),
40        }
41    }
42}
43
44impl From<PyLassoConfig> for LinearRegressionConfig {
45    fn from(py_config: PyLassoConfig) -> Self {
46        let mut config = LinearRegressionConfig::default();
47        config.fit_intercept = py_config.fit_intercept;
48        config.penalty = Penalty::L1(py_config.alpha);
49        config.max_iter = py_config.max_iter;
50        config.tol = py_config.tol;
51        config.warm_start = py_config.warm_start;
52        config
53    }
54}
55
/// Python wrapper for Lasso regression
///
/// Exposed to Python under the class name `Lasso`. The wrapper keeps the
/// user-supplied parameters and, once `fit` has succeeded, the trained model;
/// accessors that need a trained model return an error while it is `None`.
#[pyclass(name = "Lasso")]
pub struct PyLasso {
    /// Python-specific configuration
    py_config: PyLassoConfig,
    /// Trained model instance using the actual sklears-linear implementation.
    /// `None` until `fit` succeeds, and reset to `None` by `set_params`.
    fitted_model: Option<LinearRegression<Trained>>,
}
64
65#[pymethods]
66impl PyLasso {
67    #[new]
68    #[pyo3(signature = (alpha=1.0, fit_intercept=true, copy_x=true, max_iter=1000, tol=1e-4, warm_start=false, positive=false, random_state=None, selection="cyclic"))]
69    fn new(
70        alpha: f64,
71        fit_intercept: bool,
72        copy_x: bool,
73        max_iter: usize,
74        tol: f64,
75        warm_start: bool,
76        positive: bool,
77        random_state: Option<i32>,
78        selection: &str,
79    ) -> Self {
80        let py_config = PyLassoConfig {
81            alpha,
82            fit_intercept,
83            copy_x,
84            max_iter,
85            tol,
86            warm_start,
87            positive,
88            random_state,
89            selection: selection.to_string(),
90        };
91
92        Self {
93            py_config,
94            fitted_model: None,
95        }
96    }
97
98    /// Fit the Lasso regression model
99    fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
100        let x_array = x.as_array().to_owned();
101        let y_array = y.as_array().to_owned();
102
103        // Validate input arrays
104        validate_fit_arrays(&x_array, &y_array)?;
105
106        // Create sklears-linear model with Lasso configuration
107        let model = LinearRegression::lasso(self.py_config.alpha)
108            .fit_intercept(self.py_config.fit_intercept);
109
110        // Fit the model using sklears-linear's implementation
111        match model.fit(&x_array, &y_array) {
112            Ok(fitted_model) => {
113                self.fitted_model = Some(fitted_model);
114                Ok(())
115            }
116            Err(e) => Err(PyValueError::new_err(format!(
117                "Failed to fit Lasso model: {:?}",
118                e
119            ))),
120        }
121    }
122
123    /// Predict using the fitted model
124    fn predict(&self, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
125        let fitted = self
126            .fitted_model
127            .as_ref()
128            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
129
130        let x_array = x.as_array().to_owned();
131        validate_predict_array(&x_array)?;
132
133        match fitted.predict(&x_array) {
134            Ok(predictions) => {
135                let py = unsafe { Python::assume_attached() };
136                Ok(predictions.into_pyarray(py).into())
137            }
138            Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
139        }
140    }
141
142    /// Get model coefficients
143    #[getter]
144    fn coef_(&self) -> PyResult<Py<PyArray1<f64>>> {
145        let fitted = self
146            .fitted_model
147            .as_ref()
148            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
149
150        let py = unsafe { Python::assume_attached() };
151        Ok(fitted.coef().clone().into_pyarray(py).into())
152    }
153
154    /// Get model intercept
155    #[getter]
156    fn intercept_(&self) -> PyResult<f64> {
157        let fitted = self
158            .fitted_model
159            .as_ref()
160            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
161
162        Ok(fitted.intercept().unwrap_or(0.0))
163    }
164
165    /// Calculate R² score
166    fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
167        let fitted = self
168            .fitted_model
169            .as_ref()
170            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
171
172        let x_array = x.as_array().to_owned();
173        let y_array = y.as_array().to_owned();
174
175        match fitted.score(&x_array, &y_array) {
176            Ok(score) => Ok(score),
177            Err(e) => Err(PyValueError::new_err(format!(
178                "Score calculation failed: {:?}",
179                e
180            ))),
181        }
182    }
183
184    /// Get number of features
185    #[getter]
186    fn n_features_in_(&self) -> PyResult<usize> {
187        let fitted = self
188            .fitted_model
189            .as_ref()
190            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
191
192        // Infer number of features from coefficient array length
193        Ok(fitted.coef().len())
194    }
195
196    /// Return parameters for this estimator (sklearn compatibility)
197    fn get_params(&self, deep: Option<bool>) -> PyResult<Py<PyDict>> {
198        let _deep = deep.unwrap_or(true);
199
200        let py = unsafe { Python::assume_attached() };
201        let dict = PyDict::new(py);
202
203        dict.set_item("alpha", self.py_config.alpha)?;
204        dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
205        dict.set_item("copy_X", self.py_config.copy_x)?;
206        dict.set_item("max_iter", self.py_config.max_iter)?;
207        dict.set_item("tol", self.py_config.tol)?;
208        dict.set_item("warm_start", self.py_config.warm_start)?;
209        dict.set_item("positive", self.py_config.positive)?;
210        dict.set_item("random_state", self.py_config.random_state)?;
211        dict.set_item("selection", &self.py_config.selection)?;
212
213        Ok(dict.into())
214    }
215
216    /// Set parameters for this estimator (sklearn compatibility)
217    fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
218        // Update configuration parameters
219        if let Some(alpha) = kwargs.get_item("alpha")? {
220            self.py_config.alpha = alpha.extract()?;
221        }
222        if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
223            self.py_config.fit_intercept = fit_intercept.extract()?;
224        }
225        if let Some(copy_x) = kwargs.get_item("copy_X")? {
226            self.py_config.copy_x = copy_x.extract()?;
227        }
228        if let Some(max_iter) = kwargs.get_item("max_iter")? {
229            self.py_config.max_iter = max_iter.extract()?;
230        }
231        if let Some(tol) = kwargs.get_item("tol")? {
232            self.py_config.tol = tol.extract()?;
233        }
234        if let Some(warm_start) = kwargs.get_item("warm_start")? {
235            self.py_config.warm_start = warm_start.extract()?;
236        }
237        if let Some(positive) = kwargs.get_item("positive")? {
238            self.py_config.positive = positive.extract()?;
239        }
240        if let Some(random_state) = kwargs.get_item("random_state")? {
241            self.py_config.random_state = random_state.extract()?;
242        }
243        if let Some(selection) = kwargs.get_item("selection")? {
244            let selection_str: String = selection.extract()?;
245            self.py_config.selection = selection_str;
246        }
247
248        // Clear fitted model since config changed
249        self.fitted_model = None;
250
251        Ok(())
252    }
253
254    /// String representation
255    fn __repr__(&self) -> String {
256        format!(
257            "Lasso(alpha={}, fit_intercept={}, copy_X={}, max_iter={}, tol={}, warm_start={}, positive={}, random_state={:?}, selection='{}')",
258            self.py_config.alpha,
259            self.py_config.fit_intercept,
260            self.py_config.copy_x,
261            self.py_config.max_iter,
262            self.py_config.tol,
263            self.py_config.warm_start,
264            self.py_config.positive,
265            self.py_config.random_state,
266            self.py_config.selection
267        )
268    }
269}