1use super::common::*;
8use numpy::IntoPyArray;
9use pyo3::types::PyDict;
10use pyo3::Bound;
11use sklears_core::traits::{Fit, Predict, Score, Trained};
12use sklears_linear::{LinearRegression, LinearRegressionConfig, Penalty};
13
/// Hyper-parameters of the Python-facing `Ridge` estimator.
///
/// Only some of these reach the Rust backend: the conversion into `LinearRegressionConfig`
/// reads `alpha`, `fit_intercept`, `max_iter` and `tol`. The remaining fields are accepted for
/// scikit-learn parity and reported back through `get_params`.
#[derive(Debug, Clone)]
pub struct PyRidgeConfig {
    /// L2 regularization strength; becomes `Penalty::L2(alpha)` in `LinearRegressionConfig`.
    pub alpha: f64,
    /// Whether an intercept term is fitted.
    pub fit_intercept: bool,
    /// scikit-learn's `copy_X` flag; not consumed by the `LinearRegressionConfig` conversion.
    pub copy_x: bool,
    /// Iteration cap; `None` keeps the backend config's default.
    pub max_iter: Option<usize>,
    /// Convergence tolerance; `None` keeps the backend config's default.
    pub tol: Option<f64>,
    /// Solver name; not consumed by the `LinearRegressionConfig` conversion.
    pub solver: Option<String>,
    /// scikit-learn's positivity-constraint flag; not consumed by the conversion.
    pub positive: bool,
    /// Seed; not consumed by the `LinearRegressionConfig` conversion.
    pub random_state: Option<i32>,
}
26
27impl Default for PyRidgeConfig {
28 fn default() -> Self {
29 Self {
30 alpha: 1.0,
31 fit_intercept: true,
32 copy_x: true,
33 max_iter: None,
34 tol: Some(1e-4),
35 solver: Some("auto".to_string()),
36 positive: false,
37 random_state: None,
38 }
39 }
40}
41
42impl From<PyRidgeConfig> for LinearRegressionConfig {
43 fn from(py_config: PyRidgeConfig) -> Self {
44 let mut config = LinearRegressionConfig::default();
45 config.fit_intercept = py_config.fit_intercept;
46 config.penalty = Penalty::L2(py_config.alpha);
47 if let Some(max_iter) = py_config.max_iter {
48 config.max_iter = max_iter;
49 }
50 if let Some(tol) = py_config.tol {
51 config.tol = tol;
52 }
53 config
54 }
55}
56
57#[pyclass(name = "Ridge")]
59pub struct PyRidge {
60 py_config: PyRidgeConfig,
62 fitted_model: Option<LinearRegression<Trained>>,
64}
65
66#[pymethods]
67impl PyRidge {
68 #[new]
69 #[pyo3(signature = (alpha=1.0, fit_intercept=true, copy_x=true, max_iter=None, tol=1e-4, solver="auto", positive=false, random_state=None))]
70 fn new(
71 alpha: f64,
72 fit_intercept: bool,
73 copy_x: bool,
74 max_iter: Option<usize>,
75 tol: f64,
76 solver: &str,
77 positive: bool,
78 random_state: Option<i32>,
79 ) -> Self {
80 let py_config = PyRidgeConfig {
81 alpha,
82 fit_intercept,
83 copy_x,
84 max_iter,
85 tol: Some(tol),
86 solver: Some(solver.to_string()),
87 positive,
88 random_state,
89 };
90
91 Self {
92 py_config,
93 fitted_model: None,
94 }
95 }
96
97 fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
99 let x_array = x.as_array().to_owned();
100 let y_array = y.as_array().to_owned();
101
102 validate_fit_arrays(&x_array, &y_array)?;
104
105 let config = LinearRegressionConfig::from(self.py_config.clone());
107 let model = LinearRegression::new()
108 .fit_intercept(config.fit_intercept)
109 .regularization(self.py_config.alpha);
110
111 match model.fit(&x_array, &y_array) {
113 Ok(fitted_model) => {
114 self.fitted_model = Some(fitted_model);
115 Ok(())
116 }
117 Err(e) => Err(PyValueError::new_err(format!(
118 "Failed to fit Ridge model: {:?}",
119 e
120 ))),
121 }
122 }
123
124 fn predict(&self, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
126 let fitted = self
127 .fitted_model
128 .as_ref()
129 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
130
131 let x_array = x.as_array().to_owned();
132 validate_predict_array(&x_array)?;
133
134 match fitted.predict(&x_array) {
135 Ok(predictions) => {
136 let py = unsafe { Python::assume_attached() };
137 Ok(predictions.into_pyarray(py).into())
138 }
139 Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
140 }
141 }
142
143 #[getter]
145 fn coef_(&self) -> PyResult<Py<PyArray1<f64>>> {
146 let fitted = self
147 .fitted_model
148 .as_ref()
149 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
150
151 let py = unsafe { Python::assume_attached() };
152 Ok(fitted.coef().clone().into_pyarray(py).into())
153 }
154
155 #[getter]
157 fn intercept_(&self) -> PyResult<f64> {
158 let fitted = self
159 .fitted_model
160 .as_ref()
161 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
162
163 Ok(fitted.intercept().unwrap_or(0.0))
164 }
165
166 fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
168 let fitted = self
169 .fitted_model
170 .as_ref()
171 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
172
173 let x_array = x.as_array().to_owned();
174 let y_array = y.as_array().to_owned();
175
176 match fitted.score(&x_array, &y_array) {
177 Ok(score) => Ok(score),
178 Err(e) => Err(PyValueError::new_err(format!(
179 "Score calculation failed: {:?}",
180 e
181 ))),
182 }
183 }
184
185 #[getter]
187 fn n_features_in_(&self) -> PyResult<usize> {
188 let fitted = self
189 .fitted_model
190 .as_ref()
191 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
192
193 Ok(fitted.coef().len())
195 }
196
197 fn get_params(&self, deep: Option<bool>) -> PyResult<Py<PyDict>> {
199 let _deep = deep.unwrap_or(true);
200
201 let py = unsafe { Python::assume_attached() };
202 let dict = PyDict::new(py);
203
204 dict.set_item("alpha", self.py_config.alpha)?;
205 dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
206 dict.set_item("copy_X", self.py_config.copy_x)?;
207 dict.set_item("max_iter", self.py_config.max_iter)?;
208 dict.set_item("tol", self.py_config.tol)?;
209 dict.set_item("solver", &self.py_config.solver)?;
210 dict.set_item("positive", self.py_config.positive)?;
211 dict.set_item("random_state", self.py_config.random_state)?;
212
213 Ok(dict.into())
214 }
215
216 fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
218 if let Some(alpha) = kwargs.get_item("alpha")? {
220 self.py_config.alpha = alpha.extract()?;
221 }
222 if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
223 self.py_config.fit_intercept = fit_intercept.extract()?;
224 }
225 if let Some(copy_x) = kwargs.get_item("copy_X")? {
226 self.py_config.copy_x = copy_x.extract()?;
227 }
228 if let Some(max_iter) = kwargs.get_item("max_iter")? {
229 self.py_config.max_iter = max_iter.extract()?;
230 }
231 if let Some(tol) = kwargs.get_item("tol")? {
232 self.py_config.tol = tol.extract()?;
233 }
234 if let Some(solver) = kwargs.get_item("solver")? {
235 let solver_str: String = solver.extract()?;
236 self.py_config.solver = Some(solver_str);
237 }
238 if let Some(positive) = kwargs.get_item("positive")? {
239 self.py_config.positive = positive.extract()?;
240 }
241 if let Some(random_state) = kwargs.get_item("random_state")? {
242 self.py_config.random_state = random_state.extract()?;
243 }
244
245 self.fitted_model = None;
247
248 Ok(())
249 }
250
251 fn __repr__(&self) -> String {
253 format!(
254 "Ridge(alpha={}, fit_intercept={}, copy_X={}, max_iter={:?}, tol={:?}, solver={:?}, positive={}, random_state={:?})",
255 self.py_config.alpha,
256 self.py_config.fit_intercept,
257 self.py_config.copy_x,
258 self.py_config.max_iter,
259 self.py_config.tol,
260 self.py_config.solver,
261 self.py_config.positive,
262 self.py_config.random_state
263 )
264 }
265}