use super::common::*;
use pyo3::types::PyDict;
use pyo3::Bound;
use pyo3::PyRefMut;
use sklears_core::traits::{Fit, Predict, Score, Trained};
use sklears_linear::{LinearRegression, LinearRegressionConfig, Penalty};
12
/// Hyper-parameters of the Python-facing `Ridge` estimator.
///
/// Every value passed to the constructor or to `set_params` is stored here, so that
/// `get_params` and `__repr__` can report it back. Only `alpha`, `fit_intercept`, `max_iter`
/// and `tol` are carried over when converting into a `LinearRegressionConfig`; the other
/// fields are kept for parameter round-tripping.
#[derive(Debug, Clone, PartialEq)]
pub struct PyRidgeConfig {
    /// L2 penalty strength (becomes `Penalty::L2(alpha)` in the backend config).
    pub alpha: f64,
    /// Whether an intercept term is fitted.
    pub fit_intercept: bool,
    /// Stored `copy_X` flag (not forwarded to the backend).
    pub copy_x: bool,
    /// Iteration cap; `None` falls back to the backend default when converting.
    pub max_iter: Option<usize>,
    /// Convergence tolerance; `None` falls back to the backend default when converting.
    pub tol: Option<f64>,
    /// Requested solver name (stored only, not forwarded to the backend).
    pub solver: Option<String>,
    /// Whether non-negative coefficients were requested (stored only).
    pub positive: bool,
    /// Random seed (stored only, not forwarded to the backend).
    pub random_state: Option<i32>,
}
25
26impl Default for PyRidgeConfig {
27 fn default() -> Self {
28 Self {
29 alpha: 1.0,
30 fit_intercept: true,
31 copy_x: true,
32 max_iter: None,
33 tol: Some(1e-4),
34 solver: Some("auto".to_string()),
35 positive: false,
36 random_state: None,
37 }
38 }
39}
40
41impl From<PyRidgeConfig> for LinearRegressionConfig {
42 fn from(py_config: PyRidgeConfig) -> Self {
43 LinearRegressionConfig {
44 fit_intercept: py_config.fit_intercept,
45 penalty: Penalty::L2(py_config.alpha),
46 max_iter: py_config
47 .max_iter
48 .unwrap_or_else(|| LinearRegressionConfig::default().max_iter),
49 tol: py_config
50 .tol
51 .unwrap_or_else(|| LinearRegressionConfig::default().tol),
52 ..Default::default()
53 }
54 }
55}
56
57#[pyclass(name = "Ridge")]
59pub struct PyRidge {
60 py_config: PyRidgeConfig,
62 fitted_model: Option<LinearRegression<Trained>>,
64}
65
66#[pymethods]
67impl PyRidge {
68 #[new]
69 #[allow(clippy::too_many_arguments)]
70 #[pyo3(signature = (alpha=1.0, fit_intercept=true, copy_x=true, max_iter=None, tol=1e-4, solver="auto", positive=false, random_state=None))]
71 fn new(
72 alpha: f64,
73 fit_intercept: bool,
74 copy_x: bool,
75 max_iter: Option<usize>,
76 tol: f64,
77 solver: &str,
78 positive: bool,
79 random_state: Option<i32>,
80 ) -> Self {
81 let py_config = PyRidgeConfig {
82 alpha,
83 fit_intercept,
84 copy_x,
85 max_iter,
86 tol: Some(tol),
87 solver: Some(solver.to_string()),
88 positive,
89 random_state,
90 };
91
92 Self {
93 py_config,
94 fitted_model: None,
95 }
96 }
97
98 fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
100 let x_array = pyarray_to_core_array2(x)?;
101 let y_array = pyarray_to_core_array1(y)?;
102
103 validate_fit_arrays(&x_array, &y_array)?;
105
106 let config = LinearRegressionConfig::from(self.py_config.clone());
108 let model = LinearRegression::new()
109 .fit_intercept(config.fit_intercept)
110 .regularization(self.py_config.alpha);
111
112 match model.fit(&x_array, &y_array) {
114 Ok(fitted_model) => {
115 self.fitted_model = Some(fitted_model);
116 Ok(())
117 }
118 Err(e) => Err(PyValueError::new_err(format!(
119 "Failed to fit Ridge model: {:?}",
120 e
121 ))),
122 }
123 }
124
125 fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
127 let fitted = self
128 .fitted_model
129 .as_ref()
130 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
131
132 let x_array = pyarray_to_core_array2(x)?;
133 validate_predict_array(&x_array)?;
134
135 match fitted.predict(&x_array) {
136 Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
137 Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
138 }
139 }
140
141 #[getter]
143 fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
144 let fitted = self
145 .fitted_model
146 .as_ref()
147 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
148
149 Ok(core_array1_to_py(py, fitted.coef()))
150 }
151
152 #[getter]
154 fn intercept_(&self) -> PyResult<f64> {
155 let fitted = self
156 .fitted_model
157 .as_ref()
158 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
159
160 Ok(fitted.intercept().unwrap_or(0.0))
161 }
162
163 fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
165 let fitted = self
166 .fitted_model
167 .as_ref()
168 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
169
170 let x_array = pyarray_to_core_array2(x)?;
171 let y_array = pyarray_to_core_array1(y)?;
172
173 match fitted.score(&x_array, &y_array) {
174 Ok(score) => Ok(score),
175 Err(e) => Err(PyValueError::new_err(format!(
176 "Score calculation failed: {:?}",
177 e
178 ))),
179 }
180 }
181
182 #[getter]
184 fn n_features_in_(&self) -> PyResult<usize> {
185 let fitted = self
186 .fitted_model
187 .as_ref()
188 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
189
190 Ok(fitted.coef().len())
192 }
193
194 fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
196 let _deep = deep.unwrap_or(true);
197
198 let dict = PyDict::new(py);
199
200 dict.set_item("alpha", self.py_config.alpha)?;
201 dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
202 dict.set_item("copy_X", self.py_config.copy_x)?;
203 dict.set_item("max_iter", self.py_config.max_iter)?;
204 dict.set_item("tol", self.py_config.tol)?;
205 dict.set_item("solver", &self.py_config.solver)?;
206 dict.set_item("positive", self.py_config.positive)?;
207 dict.set_item("random_state", self.py_config.random_state)?;
208
209 Ok(dict.into())
210 }
211
212 fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
214 if let Some(alpha) = kwargs.get_item("alpha")? {
216 self.py_config.alpha = alpha.extract()?;
217 }
218 if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
219 self.py_config.fit_intercept = fit_intercept.extract()?;
220 }
221 if let Some(copy_x) = kwargs.get_item("copy_X")? {
222 self.py_config.copy_x = copy_x.extract()?;
223 }
224 if let Some(max_iter) = kwargs.get_item("max_iter")? {
225 self.py_config.max_iter = max_iter.extract()?;
226 }
227 if let Some(tol) = kwargs.get_item("tol")? {
228 self.py_config.tol = tol.extract()?;
229 }
230 if let Some(solver) = kwargs.get_item("solver")? {
231 let solver_str: String = solver.extract()?;
232 self.py_config.solver = Some(solver_str);
233 }
234 if let Some(positive) = kwargs.get_item("positive")? {
235 self.py_config.positive = positive.extract()?;
236 }
237 if let Some(random_state) = kwargs.get_item("random_state")? {
238 self.py_config.random_state = random_state.extract()?;
239 }
240
241 self.fitted_model = None;
243
244 Ok(())
245 }
246
247 fn __repr__(&self) -> String {
249 format!(
250 "Ridge(alpha={}, fit_intercept={}, copy_X={}, max_iter={:?}, tol={:?}, solver={:?}, positive={}, random_state={:?})",
251 self.py_config.alpha,
252 self.py_config.fit_intercept,
253 self.py_config.copy_x,
254 self.py_config.max_iter,
255 self.py_config.tol,
256 self.py_config.solver,
257 self.py_config.positive,
258 self.py_config.random_state
259 )
260 }
261}