// sklears_python/linear/ridge.rs

//! Python bindings for Ridge Regression.
//!
//! This module exposes scikit-learn compatible Ridge regression (least squares with an
//! L2 penalty) to Python, backed by the sklears-linear crate.

use super::common::*;
use pyo3::types::PyDict;
use pyo3::Bound;
use pyo3::PyRefMut;
use sklears_core::traits::{Fit, Predict, Score, Trained};
use sklears_linear::{LinearRegression, LinearRegressionConfig, Penalty};

/// Hyper-parameters of the Python `Ridge` estimator, named after scikit-learn's arguments.
///
/// The values produced by [`Default`] are the same as the keyword defaults of the
/// Python-side `Ridge(...)` constructor.
#[derive(Debug, Clone)]
pub struct PyRidgeConfig {
    /// L2 regularisation strength.
    pub alpha: f64,
    /// Whether an intercept term is fitted.
    pub fit_intercept: bool,
    /// scikit-learn's `copy_X` flag (reported to Python under that name).
    pub copy_x: bool,
    /// Iteration cap; `None` means "use the core solver's default".
    pub max_iter: Option<usize>,
    /// Convergence tolerance; `None` means "use the core solver's default".
    pub tol: Option<f64>,
    /// Solver name as supplied from Python (e.g. `"auto"`).
    pub solver: Option<String>,
    /// scikit-learn's `positive` flag, stored as supplied.
    pub positive: bool,
    /// Random seed as supplied from Python, if any.
    pub random_state: Option<i32>,
}

impl Default for PyRidgeConfig {
    /// scikit-learn-style defaults: `alpha = 1.0`, intercept on, `tol = 1e-4`, solver `"auto"`.
    fn default() -> Self {
        PyRidgeConfig {
            alpha: 1.0,
            fit_intercept: true,
            copy_x: true,
            max_iter: None,
            tol: Some(1e-4),
            solver: Some(String::from("auto")),
            positive: false,
            random_state: None,
        }
    }
}

41impl From<PyRidgeConfig> for LinearRegressionConfig {
42    fn from(py_config: PyRidgeConfig) -> Self {
43        LinearRegressionConfig {
44            fit_intercept: py_config.fit_intercept,
45            penalty: Penalty::L2(py_config.alpha),
46            max_iter: py_config
47                .max_iter
48                .unwrap_or_else(|| LinearRegressionConfig::default().max_iter),
49            tol: py_config
50                .tol
51                .unwrap_or_else(|| LinearRegressionConfig::default().tol),
52            ..Default::default()
53        }
54    }
55}
56
57/// Python wrapper for Ridge regression
58#[pyclass(name = "Ridge")]
59pub struct PyRidge {
60    /// Python-specific configuration
61    py_config: PyRidgeConfig,
62    /// Trained model instance using the actual sklears-linear implementation
63    fitted_model: Option<LinearRegression<Trained>>,
64}
65
66#[pymethods]
67impl PyRidge {
68    #[new]
69    #[allow(clippy::too_many_arguments)]
70    #[pyo3(signature = (alpha=1.0, fit_intercept=true, copy_x=true, max_iter=None, tol=1e-4, solver="auto", positive=false, random_state=None))]
71    fn new(
72        alpha: f64,
73        fit_intercept: bool,
74        copy_x: bool,
75        max_iter: Option<usize>,
76        tol: f64,
77        solver: &str,
78        positive: bool,
79        random_state: Option<i32>,
80    ) -> Self {
81        let py_config = PyRidgeConfig {
82            alpha,
83            fit_intercept,
84            copy_x,
85            max_iter,
86            tol: Some(tol),
87            solver: Some(solver.to_string()),
88            positive,
89            random_state,
90        };
91
92        Self {
93            py_config,
94            fitted_model: None,
95        }
96    }
97
98    /// Fit the Ridge regression model
99    fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
100        let x_array = pyarray_to_core_array2(x)?;
101        let y_array = pyarray_to_core_array1(y)?;
102
103        // Validate input arrays
104        validate_fit_arrays(&x_array, &y_array)?;
105
106        // Create sklears-linear model with Ridge configuration
107        let config = LinearRegressionConfig::from(self.py_config.clone());
108        let model = LinearRegression::new()
109            .fit_intercept(config.fit_intercept)
110            .regularization(self.py_config.alpha);
111
112        // Fit the model using sklears-linear's implementation
113        match model.fit(&x_array, &y_array) {
114            Ok(fitted_model) => {
115                self.fitted_model = Some(fitted_model);
116                Ok(())
117            }
118            Err(e) => Err(PyValueError::new_err(format!(
119                "Failed to fit Ridge model: {:?}",
120                e
121            ))),
122        }
123    }
124
125    /// Predict using the fitted model
126    fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
127        let fitted = self
128            .fitted_model
129            .as_ref()
130            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
131
132        let x_array = pyarray_to_core_array2(x)?;
133        validate_predict_array(&x_array)?;
134
135        match fitted.predict(&x_array) {
136            Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
137            Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
138        }
139    }
140
141    /// Get model coefficients
142    #[getter]
143    fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
144        let fitted = self
145            .fitted_model
146            .as_ref()
147            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
148
149        Ok(core_array1_to_py(py, fitted.coef()))
150    }
151
152    /// Get model intercept
153    #[getter]
154    fn intercept_(&self) -> PyResult<f64> {
155        let fitted = self
156            .fitted_model
157            .as_ref()
158            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
159
160        Ok(fitted.intercept().unwrap_or(0.0))
161    }
162
163    /// Calculate R² score
164    fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
165        let fitted = self
166            .fitted_model
167            .as_ref()
168            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
169
170        let x_array = pyarray_to_core_array2(x)?;
171        let y_array = pyarray_to_core_array1(y)?;
172
173        match fitted.score(&x_array, &y_array) {
174            Ok(score) => Ok(score),
175            Err(e) => Err(PyValueError::new_err(format!(
176                "Score calculation failed: {:?}",
177                e
178            ))),
179        }
180    }
181
182    /// Get number of features
183    #[getter]
184    fn n_features_in_(&self) -> PyResult<usize> {
185        let fitted = self
186            .fitted_model
187            .as_ref()
188            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
189
190        // Infer number of features from coefficient array length
191        Ok(fitted.coef().len())
192    }
193
194    /// Return parameters for this estimator (sklearn compatibility)
195    fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
196        let _deep = deep.unwrap_or(true);
197
198        let dict = PyDict::new(py);
199
200        dict.set_item("alpha", self.py_config.alpha)?;
201        dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
202        dict.set_item("copy_X", self.py_config.copy_x)?;
203        dict.set_item("max_iter", self.py_config.max_iter)?;
204        dict.set_item("tol", self.py_config.tol)?;
205        dict.set_item("solver", &self.py_config.solver)?;
206        dict.set_item("positive", self.py_config.positive)?;
207        dict.set_item("random_state", self.py_config.random_state)?;
208
209        Ok(dict.into())
210    }
211
212    /// Set parameters for this estimator (sklearn compatibility)
213    fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
214        // Update configuration parameters
215        if let Some(alpha) = kwargs.get_item("alpha")? {
216            self.py_config.alpha = alpha.extract()?;
217        }
218        if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
219            self.py_config.fit_intercept = fit_intercept.extract()?;
220        }
221        if let Some(copy_x) = kwargs.get_item("copy_X")? {
222            self.py_config.copy_x = copy_x.extract()?;
223        }
224        if let Some(max_iter) = kwargs.get_item("max_iter")? {
225            self.py_config.max_iter = max_iter.extract()?;
226        }
227        if let Some(tol) = kwargs.get_item("tol")? {
228            self.py_config.tol = tol.extract()?;
229        }
230        if let Some(solver) = kwargs.get_item("solver")? {
231            let solver_str: String = solver.extract()?;
232            self.py_config.solver = Some(solver_str);
233        }
234        if let Some(positive) = kwargs.get_item("positive")? {
235            self.py_config.positive = positive.extract()?;
236        }
237        if let Some(random_state) = kwargs.get_item("random_state")? {
238            self.py_config.random_state = random_state.extract()?;
239        }
240
241        // Clear fitted model since config changed
242        self.fitted_model = None;
243
244        Ok(())
245    }
246
247    /// String representation
248    fn __repr__(&self) -> String {
249        format!(
250            "Ridge(alpha={}, fit_intercept={}, copy_X={}, max_iter={:?}, tol={:?}, solver={:?}, positive={}, random_state={:?})",
251            self.py_config.alpha,
252            self.py_config.fit_intercept,
253            self.py_config.copy_x,
254            self.py_config.max_iter,
255            self.py_config.tol,
256            self.py_config.solver,
257            self.py_config.positive,
258            self.py_config.random_state
259        )
260    }
261}