1use super::common::*;
8use pyo3::types::PyDict;
9use pyo3::Bound;
10use sklears_core::traits::{Fit, Predict, Score, Trained};
11use sklears_linear::{ARDRegression, ARDRegressionConfig};
12
/// Hyperparameters of the Python-facing `ARDRegression` estimator, as accepted by its
/// constructor and by `set_params`.
///
/// All fields are plain values, so the config is cheap to clone and can be compared
/// with `==` (useful when checking that a cloned estimator kept its settings).
#[derive(Debug, Clone, PartialEq)]
pub struct PyARDRegressionConfig {
    /// Iteration limit; the constructor and `set_params` reject 0.
    pub max_iter: usize,
    /// Tolerance; the constructor and `set_params` require a positive value.
    pub tol: f64,
    /// Initial `alpha`; `None` means "use `ARDRegressionConfig::default()`'s value"
    /// when converting into the core config.
    pub alpha_init: Option<f64>,
    /// Initial `lambda`; `None` means "use `ARDRegressionConfig::default()`'s value"
    /// when converting into the core config.
    pub lambda_init: Option<f64>,
    /// `alpha` threshold; the constructor and `set_params` require a positive value.
    pub threshold_alpha: f64,
    /// Whether an intercept term is fitted.
    pub fit_intercept: bool,
    /// Forwarded as `compute_score` when converting into the core config.
    pub compute_score: bool,
    /// Exposed to Python as `copy_X`; it has no counterpart in the core config.
    pub copy_x: bool,
}
25
26impl Default for PyARDRegressionConfig {
27 fn default() -> Self {
28 Self {
29 max_iter: 300,
30 tol: 1e-3,
31 alpha_init: Some(1.0),
32 lambda_init: Some(1.0),
33 threshold_alpha: 1e10,
34 fit_intercept: true,
35 compute_score: false,
36 copy_x: true,
37 }
38 }
39}
40
41impl From<PyARDRegressionConfig> for ARDRegressionConfig {
42 fn from(py_config: PyARDRegressionConfig) -> Self {
43 ARDRegressionConfig {
44 max_iter: py_config.max_iter,
45 tol: py_config.tol,
46 alpha_init: py_config
47 .alpha_init
48 .unwrap_or_else(|| ARDRegressionConfig::default().alpha_init),
49 lambda_init: py_config
50 .lambda_init
51 .unwrap_or_else(|| ARDRegressionConfig::default().lambda_init),
52 threshold_alpha: py_config.threshold_alpha,
53 fit_intercept: py_config.fit_intercept,
54 compute_score: py_config.compute_score,
55 }
56 }
57}
58
59#[pyclass(name = "ARDRegression")]
156pub struct PyARDRegression {
157 py_config: PyARDRegressionConfig,
159 fitted_model: Option<ARDRegression<Trained>>,
161}
162
163#[pymethods]
164impl PyARDRegression {
165 #[new]
166 #[allow(clippy::too_many_arguments)]
167 #[pyo3(signature = (max_iter=300, tol=1e-3, alpha_init=1.0, lambda_init=1.0, threshold_alpha=1e10, fit_intercept=true, compute_score=false, copy_x=true))]
168 fn new(
169 max_iter: usize,
170 tol: f64,
171 alpha_init: f64,
172 lambda_init: f64,
173 threshold_alpha: f64,
174 fit_intercept: bool,
175 compute_score: bool,
176 copy_x: bool,
177 ) -> PyResult<Self> {
178 if max_iter == 0 {
180 return Err(PyValueError::new_err("max_iter must be greater than 0"));
181 }
182 if tol <= 0.0 {
183 return Err(PyValueError::new_err("tol must be positive"));
184 }
185 if alpha_init <= 0.0 {
186 return Err(PyValueError::new_err("alpha_init must be positive"));
187 }
188 if lambda_init <= 0.0 {
189 return Err(PyValueError::new_err("lambda_init must be positive"));
190 }
191 if threshold_alpha <= 0.0 {
192 return Err(PyValueError::new_err("threshold_alpha must be positive"));
193 }
194
195 let py_config = PyARDRegressionConfig {
196 max_iter,
197 tol,
198 alpha_init: Some(alpha_init),
199 lambda_init: Some(lambda_init),
200 threshold_alpha,
201 fit_intercept,
202 compute_score,
203 copy_x,
204 };
205
206 Ok(Self {
207 py_config,
208 fitted_model: None,
209 })
210 }
211
212 fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
214 let x_array = pyarray_to_core_array2(x)?;
215 let y_array = pyarray_to_core_array1(y)?;
216
217 validate_fit_arrays(&x_array, &y_array)?;
219
220 let model = ARDRegression::new()
222 .max_iter(self.py_config.max_iter)
223 .tol(self.py_config.tol)
224 .threshold_alpha(self.py_config.threshold_alpha)
225 .fit_intercept(self.py_config.fit_intercept);
226
227 match model.fit(&x_array, &y_array) {
229 Ok(fitted_model) => {
230 self.fitted_model = Some(fitted_model);
231 Ok(())
232 }
233 Err(e) => Err(PyValueError::new_err(format!(
234 "Failed to fit ARD regression model: {:?}",
235 e
236 ))),
237 }
238 }
239
240 fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
242 let fitted = self
243 .fitted_model
244 .as_ref()
245 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
246
247 let x_array = pyarray_to_core_array2(x)?;
248 validate_predict_array(&x_array)?;
249
250 match fitted.predict(&x_array) {
251 Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
252 Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
253 }
254 }
255
256 #[getter]
258 fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
259 let fitted = self
260 .fitted_model
261 .as_ref()
262 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
263
264 Ok(core_array1_to_py(py, fitted.coef()))
265 }
266
267 #[getter]
269 fn intercept_(&self) -> PyResult<f64> {
270 let fitted = self
271 .fitted_model
272 .as_ref()
273 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
274
275 Ok(fitted.intercept().unwrap_or(0.0))
276 }
277
278 #[getter]
280 fn alpha_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
281 let fitted = self
282 .fitted_model
283 .as_ref()
284 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
285
286 Ok(core_array1_to_py(py, fitted.alpha()))
287 }
288
289 #[getter]
291 fn lambda_(&self) -> PyResult<f64> {
292 let fitted = self
293 .fitted_model
294 .as_ref()
295 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
296
297 Ok(fitted.lambda())
298 }
299
300 fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
302 let fitted = self
303 .fitted_model
304 .as_ref()
305 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
306
307 let x_array = pyarray_to_core_array2(x)?;
308 let y_array = pyarray_to_core_array1(y)?;
309
310 match fitted.score(&x_array, &y_array) {
311 Ok(score) => Ok(score),
312 Err(e) => Err(PyValueError::new_err(format!(
313 "Score calculation failed: {:?}",
314 e
315 ))),
316 }
317 }
318
319 #[getter]
321 fn n_features_in_(&self) -> PyResult<usize> {
322 let fitted = self
323 .fitted_model
324 .as_ref()
325 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
326
327 Ok(fitted.coef().len())
329 }
330
331 fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
333 let _deep = deep.unwrap_or(true);
334
335 let dict = PyDict::new(py);
336
337 dict.set_item("max_iter", self.py_config.max_iter)?;
338 dict.set_item("tol", self.py_config.tol)?;
339 dict.set_item("alpha_init", self.py_config.alpha_init)?;
340 dict.set_item("lambda_init", self.py_config.lambda_init)?;
341 dict.set_item("threshold_alpha", self.py_config.threshold_alpha)?;
342 dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
343 dict.set_item("compute_score", self.py_config.compute_score)?;
344 dict.set_item("copy_X", self.py_config.copy_x)?;
345
346 Ok(dict.into())
347 }
348
349 fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
351 if let Some(max_iter) = kwargs.get_item("max_iter")? {
353 let max_iter_val: usize = max_iter.extract()?;
354 if max_iter_val == 0 {
355 return Err(PyValueError::new_err("max_iter must be greater than 0"));
356 }
357 self.py_config.max_iter = max_iter_val;
358 }
359 if let Some(tol) = kwargs.get_item("tol")? {
360 let tol_val: f64 = tol.extract()?;
361 if tol_val <= 0.0 {
362 return Err(PyValueError::new_err("tol must be positive"));
363 }
364 self.py_config.tol = tol_val;
365 }
366 if let Some(alpha_init) = kwargs.get_item("alpha_init")? {
367 let alpha_init_val: f64 = alpha_init.extract()?;
368 if alpha_init_val <= 0.0 {
369 return Err(PyValueError::new_err("alpha_init must be positive"));
370 }
371 self.py_config.alpha_init = Some(alpha_init_val);
372 }
373 if let Some(lambda_init) = kwargs.get_item("lambda_init")? {
374 let lambda_init_val: f64 = lambda_init.extract()?;
375 if lambda_init_val <= 0.0 {
376 return Err(PyValueError::new_err("lambda_init must be positive"));
377 }
378 self.py_config.lambda_init = Some(lambda_init_val);
379 }
380 if let Some(threshold_alpha) = kwargs.get_item("threshold_alpha")? {
381 let threshold_alpha_val: f64 = threshold_alpha.extract()?;
382 if threshold_alpha_val <= 0.0 {
383 return Err(PyValueError::new_err("threshold_alpha must be positive"));
384 }
385 self.py_config.threshold_alpha = threshold_alpha_val;
386 }
387 if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
388 self.py_config.fit_intercept = fit_intercept.extract()?;
389 }
390 if let Some(compute_score) = kwargs.get_item("compute_score")? {
391 self.py_config.compute_score = compute_score.extract()?;
392 }
393 if let Some(copy_x) = kwargs.get_item("copy_X")? {
394 self.py_config.copy_x = copy_x.extract()?;
395 }
396
397 self.fitted_model = None;
399
400 Ok(())
401 }
402
403 fn __repr__(&self) -> String {
405 format!(
406 "ARDRegression(max_iter={}, tol={}, alpha_init={:?}, lambda_init={:?}, threshold_alpha={}, fit_intercept={}, compute_score={}, copy_X={})",
407 self.py_config.max_iter,
408 self.py_config.tol,
409 self.py_config.alpha_init,
410 self.py_config.lambda_init,
411 self.py_config.threshold_alpha,
412 self.py_config.fit_intercept,
413 self.py_config.compute_score,
414 self.py_config.copy_x
415 )
416 }
417}