1use super::common::*;
8use pyo3::types::PyDict;
9use pyo3::Bound;
10use sklears_core::traits::{Fit, Predict, Score, Trained};
11use sklears_linear::{BayesianRidge, BayesianRidgeConfig};
12
/// Hyper-parameters of the Python-facing `BayesianRidge`, stored as supplied
/// by the constructor or `set_params`.
///
/// Unlike `BayesianRidgeConfig`, this also carries `copy_x`, which the
/// `From<PyBayesianRidgeConfig>` conversion drops.
#[derive(Debug, Clone)]
pub struct PyBayesianRidgeConfig {
    /// Iteration cap forwarded to the model builder; validated to be > 0.
    pub max_iter: usize,
    /// Convergence tolerance forwarded to the model builder; validated to be > 0.
    pub tol: f64,
    /// Initial `alpha`; `None` falls back to `BayesianRidgeConfig::default()`
    /// when converted.
    pub alpha_init: Option<f64>,
    /// Initial `lambda`; `None` falls back to `BayesianRidgeConfig::default()`
    /// when converted.
    pub lambda_init: Option<f64>,
    /// Whether the model fits an intercept term.
    pub fit_intercept: bool,
    /// Forwarded to the model builder's `compute_score` option.
    pub compute_score: bool,
    /// Exposed to Python as `copy_X`; not passed on to the Rust model.
    pub copy_x: bool,
}
24
25impl Default for PyBayesianRidgeConfig {
26 fn default() -> Self {
27 Self {
28 max_iter: 300,
29 tol: 1e-3,
30 alpha_init: Some(1.0),
31 lambda_init: Some(1.0),
32 fit_intercept: true,
33 compute_score: false,
34 copy_x: true,
35 }
36 }
37}
38
39impl From<PyBayesianRidgeConfig> for BayesianRidgeConfig {
40 fn from(py_config: PyBayesianRidgeConfig) -> Self {
41 BayesianRidgeConfig {
42 max_iter: py_config.max_iter,
43 tol: py_config.tol,
44 alpha_init: py_config
45 .alpha_init
46 .unwrap_or_else(|| BayesianRidgeConfig::default().alpha_init),
47 lambda_init: py_config
48 .lambda_init
49 .unwrap_or_else(|| BayesianRidgeConfig::default().lambda_init),
50 fit_intercept: py_config.fit_intercept,
51 compute_score: py_config.compute_score,
52 }
53 }
54}
55
56#[pyclass(name = "BayesianRidge")]
156pub struct PyBayesianRidge {
157 py_config: PyBayesianRidgeConfig,
159 fitted_model: Option<BayesianRidge<Trained>>,
161}
162
163#[pymethods]
164impl PyBayesianRidge {
165 #[new]
166 #[pyo3(signature = (max_iter=300, tol=1e-3, alpha_init=1.0, lambda_init=1.0, fit_intercept=true, compute_score=false, copy_x=true))]
167 fn new(
168 max_iter: usize,
169 tol: f64,
170 alpha_init: f64,
171 lambda_init: f64,
172 fit_intercept: bool,
173 compute_score: bool,
174 copy_x: bool,
175 ) -> PyResult<Self> {
176 if max_iter == 0 {
178 return Err(PyValueError::new_err("max_iter must be greater than 0"));
179 }
180 if tol <= 0.0 {
181 return Err(PyValueError::new_err("tol must be positive"));
182 }
183 if alpha_init <= 0.0 {
184 return Err(PyValueError::new_err("alpha_init must be positive"));
185 }
186 if lambda_init <= 0.0 {
187 return Err(PyValueError::new_err("lambda_init must be positive"));
188 }
189
190 let py_config = PyBayesianRidgeConfig {
191 max_iter,
192 tol,
193 alpha_init: Some(alpha_init),
194 lambda_init: Some(lambda_init),
195 fit_intercept,
196 compute_score,
197 copy_x,
198 };
199
200 Ok(Self {
201 py_config,
202 fitted_model: None,
203 })
204 }
205
206 fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
208 let x_array = pyarray_to_core_array2(x)?;
209 let y_array = pyarray_to_core_array1(y)?;
210
211 validate_fit_arrays(&x_array, &y_array)?;
213
214 let model = BayesianRidge::new()
216 .max_iter(self.py_config.max_iter)
217 .tol(self.py_config.tol)
218 .fit_intercept(self.py_config.fit_intercept)
219 .compute_score(self.py_config.compute_score);
220
221 match model.fit(&x_array, &y_array) {
223 Ok(fitted_model) => {
224 self.fitted_model = Some(fitted_model);
225 Ok(())
226 }
227 Err(e) => Err(PyValueError::new_err(format!(
228 "Failed to fit Bayesian Ridge model: {:?}",
229 e
230 ))),
231 }
232 }
233
234 fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
236 let fitted = self
237 .fitted_model
238 .as_ref()
239 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
240
241 let x_array = pyarray_to_core_array2(x)?;
242 validate_predict_array(&x_array)?;
243
244 match fitted.predict(&x_array) {
245 Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
246 Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
247 }
248 }
249
250 #[getter]
252 fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
253 let fitted = self
254 .fitted_model
255 .as_ref()
256 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
257
258 Ok(core_array1_to_py(py, fitted.coef()))
259 }
260
261 #[getter]
263 fn intercept_(&self) -> PyResult<f64> {
264 let fitted = self
265 .fitted_model
266 .as_ref()
267 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
268
269 Ok(fitted.intercept().unwrap_or(0.0))
270 }
271
272 #[getter]
274 fn alpha_(&self) -> PyResult<f64> {
275 let fitted = self
276 .fitted_model
277 .as_ref()
278 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
279
280 Ok(fitted.alpha())
281 }
282
283 #[getter]
285 fn lambda_(&self) -> PyResult<f64> {
286 let fitted = self
287 .fitted_model
288 .as_ref()
289 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
290
291 Ok(fitted.lambda())
292 }
293
294 fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
296 let fitted = self
297 .fitted_model
298 .as_ref()
299 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
300
301 let x_array = pyarray_to_core_array2(x)?;
302 let y_array = pyarray_to_core_array1(y)?;
303
304 match fitted.score(&x_array, &y_array) {
305 Ok(score) => Ok(score),
306 Err(e) => Err(PyValueError::new_err(format!(
307 "Score calculation failed: {:?}",
308 e
309 ))),
310 }
311 }
312
313 #[getter]
315 fn n_features_in_(&self) -> PyResult<usize> {
316 let fitted = self
317 .fitted_model
318 .as_ref()
319 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
320
321 Ok(fitted.coef().len())
323 }
324
325 fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
327 let _deep = deep.unwrap_or(true);
328
329 let dict = PyDict::new(py);
330
331 dict.set_item("max_iter", self.py_config.max_iter)?;
332 dict.set_item("tol", self.py_config.tol)?;
333 dict.set_item("alpha_init", self.py_config.alpha_init)?;
334 dict.set_item("lambda_init", self.py_config.lambda_init)?;
335 dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
336 dict.set_item("compute_score", self.py_config.compute_score)?;
337 dict.set_item("copy_X", self.py_config.copy_x)?;
338
339 Ok(dict.into())
340 }
341
342 fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
344 if let Some(max_iter) = kwargs.get_item("max_iter")? {
346 let max_iter_val: usize = max_iter.extract()?;
347 if max_iter_val == 0 {
348 return Err(PyValueError::new_err("max_iter must be greater than 0"));
349 }
350 self.py_config.max_iter = max_iter_val;
351 }
352 if let Some(tol) = kwargs.get_item("tol")? {
353 let tol_val: f64 = tol.extract()?;
354 if tol_val <= 0.0 {
355 return Err(PyValueError::new_err("tol must be positive"));
356 }
357 self.py_config.tol = tol_val;
358 }
359 if let Some(alpha_init) = kwargs.get_item("alpha_init")? {
360 let alpha_init_val: f64 = alpha_init.extract()?;
361 if alpha_init_val <= 0.0 {
362 return Err(PyValueError::new_err("alpha_init must be positive"));
363 }
364 self.py_config.alpha_init = Some(alpha_init_val);
365 }
366 if let Some(lambda_init) = kwargs.get_item("lambda_init")? {
367 let lambda_init_val: f64 = lambda_init.extract()?;
368 if lambda_init_val <= 0.0 {
369 return Err(PyValueError::new_err("lambda_init must be positive"));
370 }
371 self.py_config.lambda_init = Some(lambda_init_val);
372 }
373 if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
374 self.py_config.fit_intercept = fit_intercept.extract()?;
375 }
376 if let Some(compute_score) = kwargs.get_item("compute_score")? {
377 self.py_config.compute_score = compute_score.extract()?;
378 }
379 if let Some(copy_x) = kwargs.get_item("copy_X")? {
380 self.py_config.copy_x = copy_x.extract()?;
381 }
382
383 self.fitted_model = None;
385
386 Ok(())
387 }
388
389 fn __repr__(&self) -> String {
391 format!(
392 "BayesianRidge(max_iter={}, tol={}, alpha_init={:?}, lambda_init={:?}, fit_intercept={}, compute_score={}, copy_X={})",
393 self.py_config.max_iter,
394 self.py_config.tol,
395 self.py_config.alpha_init,
396 self.py_config.lambda_init,
397 self.py_config.fit_intercept,
398 self.py_config.compute_score,
399 self.py_config.copy_x
400 )
401 }
402}