sklears_python/linear/
ridge.rs

1//! Python bindings for Ridge Regression
2//!
3//! This module provides Python bindings for Ridge Regression,
4//! offering scikit-learn compatible interfaces with L2 regularization
5//! using the sklears-linear crate.
6
7use super::common::*;
8use numpy::IntoPyArray;
9use pyo3::types::PyDict;
10use pyo3::Bound;
11use sklears_core::traits::{Fit, Predict, Score, Trained};
12use sklears_linear::{LinearRegression, LinearRegressionConfig, Penalty};
13
/// Hyper-parameters for the Python `Ridge` estimator, stored exactly as supplied from Python.
///
/// Converted into a [`LinearRegressionConfig`] (with an L2 penalty) when needed; fields the
/// conversion does not consume are kept so they can be reported back through `get_params`.
#[derive(Clone, Debug)]
pub struct PyRidgeConfig {
    /// L2 regularization strength (becomes `Penalty::L2(alpha)` on conversion).
    pub alpha: f64,
    /// Whether an intercept term is estimated.
    pub fit_intercept: bool,
    /// Python-visible as `copy_X`; stored for API parity.
    pub copy_x: bool,
    /// Iteration cap; `None` keeps the library default on conversion.
    pub max_iter: Option<usize>,
    /// Convergence tolerance; `None` keeps the library default on conversion.
    pub tol: Option<f64>,
    /// Solver name as given from Python (e.g. `"auto"`); stored for API parity.
    pub solver: Option<String>,
    /// Positivity-constraint flag; stored for API parity.
    pub positive: bool,
    /// Seed value; stored for API parity.
    pub random_state: Option<i32>,
}
26
27impl Default for PyRidgeConfig {
28    fn default() -> Self {
29        Self {
30            alpha: 1.0,
31            fit_intercept: true,
32            copy_x: true,
33            max_iter: None,
34            tol: Some(1e-4),
35            solver: Some("auto".to_string()),
36            positive: false,
37            random_state: None,
38        }
39    }
40}
41
42impl From<PyRidgeConfig> for LinearRegressionConfig {
43    fn from(py_config: PyRidgeConfig) -> Self {
44        let mut config = LinearRegressionConfig::default();
45        config.fit_intercept = py_config.fit_intercept;
46        config.penalty = Penalty::L2(py_config.alpha);
47        if let Some(max_iter) = py_config.max_iter {
48            config.max_iter = max_iter;
49        }
50        if let Some(tol) = py_config.tol {
51            config.tol = tol;
52        }
53        config
54    }
55}
56
57/// Python wrapper for Ridge regression
58#[pyclass(name = "Ridge")]
59pub struct PyRidge {
60    /// Python-specific configuration
61    py_config: PyRidgeConfig,
62    /// Trained model instance using the actual sklears-linear implementation
63    fitted_model: Option<LinearRegression<Trained>>,
64}
65
66#[pymethods]
67impl PyRidge {
68    #[new]
69    #[pyo3(signature = (alpha=1.0, fit_intercept=true, copy_x=true, max_iter=None, tol=1e-4, solver="auto", positive=false, random_state=None))]
70    fn new(
71        alpha: f64,
72        fit_intercept: bool,
73        copy_x: bool,
74        max_iter: Option<usize>,
75        tol: f64,
76        solver: &str,
77        positive: bool,
78        random_state: Option<i32>,
79    ) -> Self {
80        let py_config = PyRidgeConfig {
81            alpha,
82            fit_intercept,
83            copy_x,
84            max_iter,
85            tol: Some(tol),
86            solver: Some(solver.to_string()),
87            positive,
88            random_state,
89        };
90
91        Self {
92            py_config,
93            fitted_model: None,
94        }
95    }
96
97    /// Fit the Ridge regression model
98    fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
99        let x_array = x.as_array().to_owned();
100        let y_array = y.as_array().to_owned();
101
102        // Validate input arrays
103        validate_fit_arrays(&x_array, &y_array)?;
104
105        // Create sklears-linear model with Ridge configuration
106        let config = LinearRegressionConfig::from(self.py_config.clone());
107        let model = LinearRegression::new()
108            .fit_intercept(config.fit_intercept)
109            .regularization(self.py_config.alpha);
110
111        // Fit the model using sklears-linear's implementation
112        match model.fit(&x_array, &y_array) {
113            Ok(fitted_model) => {
114                self.fitted_model = Some(fitted_model);
115                Ok(())
116            }
117            Err(e) => Err(PyValueError::new_err(format!(
118                "Failed to fit Ridge model: {:?}",
119                e
120            ))),
121        }
122    }
123
124    /// Predict using the fitted model
125    fn predict(&self, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
126        let fitted = self
127            .fitted_model
128            .as_ref()
129            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
130
131        let x_array = x.as_array().to_owned();
132        validate_predict_array(&x_array)?;
133
134        match fitted.predict(&x_array) {
135            Ok(predictions) => {
136                let py = unsafe { Python::assume_attached() };
137                Ok(predictions.into_pyarray(py).into())
138            }
139            Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
140        }
141    }
142
143    /// Get model coefficients
144    #[getter]
145    fn coef_(&self) -> PyResult<Py<PyArray1<f64>>> {
146        let fitted = self
147            .fitted_model
148            .as_ref()
149            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
150
151        let py = unsafe { Python::assume_attached() };
152        Ok(fitted.coef().clone().into_pyarray(py).into())
153    }
154
155    /// Get model intercept
156    #[getter]
157    fn intercept_(&self) -> PyResult<f64> {
158        let fitted = self
159            .fitted_model
160            .as_ref()
161            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
162
163        Ok(fitted.intercept().unwrap_or(0.0))
164    }
165
166    /// Calculate R² score
167    fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
168        let fitted = self
169            .fitted_model
170            .as_ref()
171            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
172
173        let x_array = x.as_array().to_owned();
174        let y_array = y.as_array().to_owned();
175
176        match fitted.score(&x_array, &y_array) {
177            Ok(score) => Ok(score),
178            Err(e) => Err(PyValueError::new_err(format!(
179                "Score calculation failed: {:?}",
180                e
181            ))),
182        }
183    }
184
185    /// Get number of features
186    #[getter]
187    fn n_features_in_(&self) -> PyResult<usize> {
188        let fitted = self
189            .fitted_model
190            .as_ref()
191            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
192
193        // Infer number of features from coefficient array length
194        Ok(fitted.coef().len())
195    }
196
197    /// Return parameters for this estimator (sklearn compatibility)
198    fn get_params(&self, deep: Option<bool>) -> PyResult<Py<PyDict>> {
199        let _deep = deep.unwrap_or(true);
200
201        let py = unsafe { Python::assume_attached() };
202        let dict = PyDict::new(py);
203
204        dict.set_item("alpha", self.py_config.alpha)?;
205        dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
206        dict.set_item("copy_X", self.py_config.copy_x)?;
207        dict.set_item("max_iter", self.py_config.max_iter)?;
208        dict.set_item("tol", self.py_config.tol)?;
209        dict.set_item("solver", &self.py_config.solver)?;
210        dict.set_item("positive", self.py_config.positive)?;
211        dict.set_item("random_state", self.py_config.random_state)?;
212
213        Ok(dict.into())
214    }
215
216    /// Set parameters for this estimator (sklearn compatibility)
217    fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
218        // Update configuration parameters
219        if let Some(alpha) = kwargs.get_item("alpha")? {
220            self.py_config.alpha = alpha.extract()?;
221        }
222        if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
223            self.py_config.fit_intercept = fit_intercept.extract()?;
224        }
225        if let Some(copy_x) = kwargs.get_item("copy_X")? {
226            self.py_config.copy_x = copy_x.extract()?;
227        }
228        if let Some(max_iter) = kwargs.get_item("max_iter")? {
229            self.py_config.max_iter = max_iter.extract()?;
230        }
231        if let Some(tol) = kwargs.get_item("tol")? {
232            self.py_config.tol = tol.extract()?;
233        }
234        if let Some(solver) = kwargs.get_item("solver")? {
235            let solver_str: String = solver.extract()?;
236            self.py_config.solver = Some(solver_str);
237        }
238        if let Some(positive) = kwargs.get_item("positive")? {
239            self.py_config.positive = positive.extract()?;
240        }
241        if let Some(random_state) = kwargs.get_item("random_state")? {
242            self.py_config.random_state = random_state.extract()?;
243        }
244
245        // Clear fitted model since config changed
246        self.fitted_model = None;
247
248        Ok(())
249    }
250
251    /// String representation
252    fn __repr__(&self) -> String {
253        format!(
254            "Ridge(alpha={}, fit_intercept={}, copy_X={}, max_iter={:?}, tol={:?}, solver={:?}, positive={}, random_state={:?})",
255            self.py_config.alpha,
256            self.py_config.fit_intercept,
257            self.py_config.copy_x,
258            self.py_config.max_iter,
259            self.py_config.tol,
260            self.py_config.solver,
261            self.py_config.positive,
262            self.py_config.random_state
263        )
264    }
265}