1use super::common::*;
8use pyo3::types::PyDict;
9use pyo3::Bound;
10use sklears_core::traits::{Fit, Predict, Score, Trained};
11use sklears_linear::{ARDRegression, ARDRegressionConfig};
12
/// Hyperparameters of the Python-facing ARD regressor, as supplied from Python.
///
/// Field names follow the Rust convention; `copy_x` is surfaced to Python as `copy_X`.
#[derive(Debug, Clone)]
pub struct PyARDRegressionConfig {
    /// Iteration cap handed to the estimator builder.
    pub max_iter: usize,
    /// Convergence tolerance handed to the estimator builder.
    pub tol: f64,
    /// Initial alpha; `None` means "use the core configuration's default" on conversion.
    pub alpha_init: Option<f64>,
    /// Initial lambda; `None` means "use the core configuration's default" on conversion.
    pub lambda_init: Option<f64>,
    /// Alpha threshold handed to the estimator builder.
    pub threshold_alpha: f64,
    /// Whether the estimator fits an intercept term.
    pub fit_intercept: bool,
    /// Copied into the core configuration's `compute_score` on conversion.
    pub compute_score: bool,
    /// Kept for scikit-learn API parity (`copy_X`).
    pub copy_x: bool,
}

impl Default for PyARDRegressionConfig {
    /// Same values as the keyword defaults of the Python constructor.
    fn default() -> Self {
        // alpha_init and lambda_init share the same starting value.
        let unit_init = Some(1.0);
        Self {
            max_iter: 300,
            tol: 0.001,
            alpha_init: unit_init,
            lambda_init: unit_init,
            threshold_alpha: 1.0e10,
            fit_intercept: true,
            compute_score: false,
            copy_x: true,
        }
    }
}
40
41impl From<PyARDRegressionConfig> for ARDRegressionConfig {
42 fn from(py_config: PyARDRegressionConfig) -> Self {
43 ARDRegressionConfig {
44 max_iter: py_config.max_iter,
45 tol: py_config.tol,
46 alpha_init: py_config
47 .alpha_init
48 .unwrap_or_else(|| ARDRegressionConfig::default().alpha_init),
49 lambda_init: py_config
50 .lambda_init
51 .unwrap_or_else(|| ARDRegressionConfig::default().lambda_init),
52 threshold_alpha: py_config.threshold_alpha,
53 fit_intercept: py_config.fit_intercept,
54 compute_score: py_config.compute_score,
55 }
56 }
57}
58
59#[pyclass(name = "ARDRegression")]
158pub struct PyARDRegression {
159 py_config: PyARDRegressionConfig,
161 fitted_model: Option<ARDRegression<Trained>>,
163}
164
165#[pymethods]
166impl PyARDRegression {
167 #[new]
168 #[allow(clippy::too_many_arguments)]
169 #[pyo3(signature = (max_iter=300, tol=1e-3, alpha_init=1.0, lambda_init=1.0, threshold_alpha=1e10, fit_intercept=true, compute_score=false, copy_x=true))]
170 fn new(
171 max_iter: usize,
172 tol: f64,
173 alpha_init: f64,
174 lambda_init: f64,
175 threshold_alpha: f64,
176 fit_intercept: bool,
177 compute_score: bool,
178 copy_x: bool,
179 ) -> PyResult<Self> {
180 if max_iter == 0 {
182 return Err(PyValueError::new_err("max_iter must be greater than 0"));
183 }
184 if tol <= 0.0 {
185 return Err(PyValueError::new_err("tol must be positive"));
186 }
187 if alpha_init <= 0.0 {
188 return Err(PyValueError::new_err("alpha_init must be positive"));
189 }
190 if lambda_init <= 0.0 {
191 return Err(PyValueError::new_err("lambda_init must be positive"));
192 }
193 if threshold_alpha <= 0.0 {
194 return Err(PyValueError::new_err("threshold_alpha must be positive"));
195 }
196
197 let py_config = PyARDRegressionConfig {
198 max_iter,
199 tol,
200 alpha_init: Some(alpha_init),
201 lambda_init: Some(lambda_init),
202 threshold_alpha,
203 fit_intercept,
204 compute_score,
205 copy_x,
206 };
207
208 Ok(Self {
209 py_config,
210 fitted_model: None,
211 })
212 }
213
214 fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
216 let x_array = pyarray_to_core_array2(x)?;
217 let y_array = pyarray_to_core_array1(y)?;
218
219 validate_fit_arrays(&x_array, &y_array)?;
221
222 let model = ARDRegression::new()
224 .max_iter(self.py_config.max_iter)
225 .tol(self.py_config.tol)
226 .threshold_alpha(self.py_config.threshold_alpha)
227 .fit_intercept(self.py_config.fit_intercept);
228
229 match model.fit(&x_array, &y_array) {
231 Ok(fitted_model) => {
232 self.fitted_model = Some(fitted_model);
233 Ok(())
234 }
235 Err(e) => Err(PyValueError::new_err(format!(
236 "Failed to fit ARD regression model: {:?}",
237 e
238 ))),
239 }
240 }
241
242 fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
244 let fitted = self
245 .fitted_model
246 .as_ref()
247 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
248
249 let x_array = pyarray_to_core_array2(x)?;
250 validate_predict_array(&x_array)?;
251
252 match fitted.predict(&x_array) {
253 Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
254 Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
255 }
256 }
257
258 #[getter]
260 fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
261 let fitted = self
262 .fitted_model
263 .as_ref()
264 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
265
266 let coef = fitted
267 .coef()
268 .map_err(|e| PyValueError::new_err(format!("Failed to get coefficients: {:?}", e)))?;
269 Ok(core_array1_to_py(py, coef))
270 }
271
272 #[getter]
274 fn intercept_(&self) -> PyResult<f64> {
275 let fitted = self
276 .fitted_model
277 .as_ref()
278 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
279
280 Ok(fitted.intercept().unwrap_or(0.0))
281 }
282
283 #[getter]
285 fn alpha_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
286 let fitted = self
287 .fitted_model
288 .as_ref()
289 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
290
291 let alpha = fitted
292 .alpha()
293 .map_err(|e| PyValueError::new_err(format!("Failed to get alpha: {:?}", e)))?;
294 Ok(core_array1_to_py(py, alpha))
295 }
296
297 #[getter]
299 fn lambda_(&self) -> PyResult<f64> {
300 let fitted = self
301 .fitted_model
302 .as_ref()
303 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
304
305 fitted
306 .lambda()
307 .map_err(|e| PyValueError::new_err(format!("Failed to get lambda: {:?}", e)))
308 }
309
310 fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
312 let fitted = self
313 .fitted_model
314 .as_ref()
315 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
316
317 let x_array = pyarray_to_core_array2(x)?;
318 let y_array = pyarray_to_core_array1(y)?;
319
320 match fitted.score(&x_array, &y_array) {
321 Ok(score) => Ok(score),
322 Err(e) => Err(PyValueError::new_err(format!(
323 "Score calculation failed: {:?}",
324 e
325 ))),
326 }
327 }
328
329 #[getter]
331 fn n_features_in_(&self) -> PyResult<usize> {
332 let fitted = self
333 .fitted_model
334 .as_ref()
335 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
336
337 let coef = fitted
339 .coef()
340 .map_err(|e| PyValueError::new_err(format!("Failed to get coefficients: {:?}", e)))?;
341 Ok(coef.len())
342 }
343
344 fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
346 let _deep = deep.unwrap_or(true);
347
348 let dict = PyDict::new(py);
349
350 dict.set_item("max_iter", self.py_config.max_iter)?;
351 dict.set_item("tol", self.py_config.tol)?;
352 dict.set_item("alpha_init", self.py_config.alpha_init)?;
353 dict.set_item("lambda_init", self.py_config.lambda_init)?;
354 dict.set_item("threshold_alpha", self.py_config.threshold_alpha)?;
355 dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
356 dict.set_item("compute_score", self.py_config.compute_score)?;
357 dict.set_item("copy_X", self.py_config.copy_x)?;
358
359 Ok(dict.into())
360 }
361
362 fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
364 if let Some(max_iter) = kwargs.get_item("max_iter")? {
366 let max_iter_val: usize = max_iter.extract()?;
367 if max_iter_val == 0 {
368 return Err(PyValueError::new_err("max_iter must be greater than 0"));
369 }
370 self.py_config.max_iter = max_iter_val;
371 }
372 if let Some(tol) = kwargs.get_item("tol")? {
373 let tol_val: f64 = tol.extract()?;
374 if tol_val <= 0.0 {
375 return Err(PyValueError::new_err("tol must be positive"));
376 }
377 self.py_config.tol = tol_val;
378 }
379 if let Some(alpha_init) = kwargs.get_item("alpha_init")? {
380 let alpha_init_val: f64 = alpha_init.extract()?;
381 if alpha_init_val <= 0.0 {
382 return Err(PyValueError::new_err("alpha_init must be positive"));
383 }
384 self.py_config.alpha_init = Some(alpha_init_val);
385 }
386 if let Some(lambda_init) = kwargs.get_item("lambda_init")? {
387 let lambda_init_val: f64 = lambda_init.extract()?;
388 if lambda_init_val <= 0.0 {
389 return Err(PyValueError::new_err("lambda_init must be positive"));
390 }
391 self.py_config.lambda_init = Some(lambda_init_val);
392 }
393 if let Some(threshold_alpha) = kwargs.get_item("threshold_alpha")? {
394 let threshold_alpha_val: f64 = threshold_alpha.extract()?;
395 if threshold_alpha_val <= 0.0 {
396 return Err(PyValueError::new_err("threshold_alpha must be positive"));
397 }
398 self.py_config.threshold_alpha = threshold_alpha_val;
399 }
400 if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
401 self.py_config.fit_intercept = fit_intercept.extract()?;
402 }
403 if let Some(compute_score) = kwargs.get_item("compute_score")? {
404 self.py_config.compute_score = compute_score.extract()?;
405 }
406 if let Some(copy_x) = kwargs.get_item("copy_X")? {
407 self.py_config.copy_x = copy_x.extract()?;
408 }
409
410 self.fitted_model = None;
412
413 Ok(())
414 }
415
416 fn __repr__(&self) -> String {
418 format!(
419 "ARDRegression(max_iter={}, tol={}, alpha_init={:?}, lambda_init={:?}, threshold_alpha={}, fit_intercept={}, compute_score={}, copy_X={})",
420 self.py_config.max_iter,
421 self.py_config.tol,
422 self.py_config.alpha_init,
423 self.py_config.lambda_init,
424 self.py_config.threshold_alpha,
425 self.py_config.fit_intercept,
426 self.py_config.compute_score,
427 self.py_config.copy_x
428 )
429 }
430}