1use super::common::*;
8use pyo3::types::PyDict;
9use pyo3::Bound;
10use sklears_core::traits::{Fit, Predict, Score, Trained};
11use sklears_linear::{BayesianRidge, BayesianRidgeConfig};
12
/// Hyper-parameters held by the Python-facing `BayesianRidge` wrapper.
///
/// Most fields are copied one-to-one into the core `BayesianRidgeConfig`
/// (see its `From` impl); `copy_x` is kept only so it can be reported back to
/// Python and has no counterpart in the core config.
#[derive(Debug, Clone, PartialEq)]
pub struct PyBayesianRidgeConfig {
    /// Maximum number of iterations handed to the core model.
    pub max_iter: usize,
    /// Tolerance handed to the core model.
    pub tol: f64,
    /// Initial `alpha` value; `None` means "use `BayesianRidgeConfig::default()`".
    pub alpha_init: Option<f64>,
    /// Initial `lambda` value; `None` means "use `BayesianRidgeConfig::default()`".
    pub lambda_init: Option<f64>,
    /// Whether the core model fits an intercept term.
    pub fit_intercept: bool,
    /// Forwarded to the core model's `compute_score` setting.
    pub compute_score: bool,
    /// Stored and reported (as `copy_X`) for API compatibility only.
    pub copy_x: bool,
}
24
25impl Default for PyBayesianRidgeConfig {
26 fn default() -> Self {
27 Self {
28 max_iter: 300,
29 tol: 1e-3,
30 alpha_init: Some(1.0),
31 lambda_init: Some(1.0),
32 fit_intercept: true,
33 compute_score: false,
34 copy_x: true,
35 }
36 }
37}
38
39impl From<PyBayesianRidgeConfig> for BayesianRidgeConfig {
40 fn from(py_config: PyBayesianRidgeConfig) -> Self {
41 BayesianRidgeConfig {
42 max_iter: py_config.max_iter,
43 tol: py_config.tol,
44 alpha_init: py_config
45 .alpha_init
46 .unwrap_or_else(|| BayesianRidgeConfig::default().alpha_init),
47 lambda_init: py_config
48 .lambda_init
49 .unwrap_or_else(|| BayesianRidgeConfig::default().lambda_init),
50 fit_intercept: py_config.fit_intercept,
51 compute_score: py_config.compute_score,
52 }
53 }
54}
55
56#[pyclass(name = "BayesianRidge")]
156pub struct PyBayesianRidge {
157 py_config: PyBayesianRidgeConfig,
159 fitted_model: Option<BayesianRidge<Trained>>,
161}
162
163#[pymethods]
164impl PyBayesianRidge {
165 #[new]
166 #[pyo3(signature = (max_iter=300, tol=1e-3, alpha_init=1.0, lambda_init=1.0, fit_intercept=true, compute_score=false, copy_x=true))]
167 fn new(
168 max_iter: usize,
169 tol: f64,
170 alpha_init: f64,
171 lambda_init: f64,
172 fit_intercept: bool,
173 compute_score: bool,
174 copy_x: bool,
175 ) -> PyResult<Self> {
176 if max_iter == 0 {
178 return Err(PyValueError::new_err("max_iter must be greater than 0"));
179 }
180 if tol <= 0.0 {
181 return Err(PyValueError::new_err("tol must be positive"));
182 }
183 if alpha_init <= 0.0 {
184 return Err(PyValueError::new_err("alpha_init must be positive"));
185 }
186 if lambda_init <= 0.0 {
187 return Err(PyValueError::new_err("lambda_init must be positive"));
188 }
189
190 let py_config = PyBayesianRidgeConfig {
191 max_iter,
192 tol,
193 alpha_init: Some(alpha_init),
194 lambda_init: Some(lambda_init),
195 fit_intercept,
196 compute_score,
197 copy_x,
198 };
199
200 Ok(Self {
201 py_config,
202 fitted_model: None,
203 })
204 }
205
206 fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
208 let x_array = pyarray_to_core_array2(x)?;
209 let y_array = pyarray_to_core_array1(y)?;
210
211 validate_fit_arrays(&x_array, &y_array)?;
213
214 let model = BayesianRidge::new()
216 .max_iter(self.py_config.max_iter)
217 .tol(self.py_config.tol)
218 .fit_intercept(self.py_config.fit_intercept)
219 .compute_score(self.py_config.compute_score);
220
221 match model.fit(&x_array, &y_array) {
223 Ok(fitted_model) => {
224 self.fitted_model = Some(fitted_model);
225 Ok(())
226 }
227 Err(e) => Err(PyValueError::new_err(format!(
228 "Failed to fit Bayesian Ridge model: {:?}",
229 e
230 ))),
231 }
232 }
233
234 fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
236 let fitted = self
237 .fitted_model
238 .as_ref()
239 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
240
241 let x_array = pyarray_to_core_array2(x)?;
242 validate_predict_array(&x_array)?;
243
244 match fitted.predict(&x_array) {
245 Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
246 Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
247 }
248 }
249
250 #[getter]
252 fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
253 let fitted = self
254 .fitted_model
255 .as_ref()
256 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
257
258 let coef = fitted
259 .coef()
260 .map_err(|e| PyValueError::new_err(format!("Failed to get coefficients: {:?}", e)))?;
261 Ok(core_array1_to_py(py, coef))
262 }
263
264 #[getter]
266 fn intercept_(&self) -> PyResult<f64> {
267 let fitted = self
268 .fitted_model
269 .as_ref()
270 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
271
272 Ok(fitted.intercept().unwrap_or(0.0))
273 }
274
275 #[getter]
277 fn alpha_(&self) -> PyResult<f64> {
278 let fitted = self
279 .fitted_model
280 .as_ref()
281 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
282
283 fitted
284 .alpha()
285 .map_err(|e| PyValueError::new_err(format!("Failed to get alpha: {:?}", e)))
286 }
287
288 #[getter]
290 fn lambda_(&self) -> PyResult<f64> {
291 let fitted = self
292 .fitted_model
293 .as_ref()
294 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
295
296 fitted
297 .lambda()
298 .map_err(|e| PyValueError::new_err(format!("Failed to get lambda: {:?}", e)))
299 }
300
301 fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
303 let fitted = self
304 .fitted_model
305 .as_ref()
306 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
307
308 let x_array = pyarray_to_core_array2(x)?;
309 let y_array = pyarray_to_core_array1(y)?;
310
311 match fitted.score(&x_array, &y_array) {
312 Ok(score) => Ok(score),
313 Err(e) => Err(PyValueError::new_err(format!(
314 "Score calculation failed: {:?}",
315 e
316 ))),
317 }
318 }
319
320 #[getter]
322 fn n_features_in_(&self) -> PyResult<usize> {
323 let fitted = self
324 .fitted_model
325 .as_ref()
326 .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
327
328 let coef = fitted
330 .coef()
331 .map_err(|e| PyValueError::new_err(format!("Failed to get coefficients: {:?}", e)))?;
332 Ok(coef.len())
333 }
334
335 fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
337 let _deep = deep.unwrap_or(true);
338
339 let dict = PyDict::new(py);
340
341 dict.set_item("max_iter", self.py_config.max_iter)?;
342 dict.set_item("tol", self.py_config.tol)?;
343 dict.set_item("alpha_init", self.py_config.alpha_init)?;
344 dict.set_item("lambda_init", self.py_config.lambda_init)?;
345 dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
346 dict.set_item("compute_score", self.py_config.compute_score)?;
347 dict.set_item("copy_X", self.py_config.copy_x)?;
348
349 Ok(dict.into())
350 }
351
352 fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
354 if let Some(max_iter) = kwargs.get_item("max_iter")? {
356 let max_iter_val: usize = max_iter.extract()?;
357 if max_iter_val == 0 {
358 return Err(PyValueError::new_err("max_iter must be greater than 0"));
359 }
360 self.py_config.max_iter = max_iter_val;
361 }
362 if let Some(tol) = kwargs.get_item("tol")? {
363 let tol_val: f64 = tol.extract()?;
364 if tol_val <= 0.0 {
365 return Err(PyValueError::new_err("tol must be positive"));
366 }
367 self.py_config.tol = tol_val;
368 }
369 if let Some(alpha_init) = kwargs.get_item("alpha_init")? {
370 let alpha_init_val: f64 = alpha_init.extract()?;
371 if alpha_init_val <= 0.0 {
372 return Err(PyValueError::new_err("alpha_init must be positive"));
373 }
374 self.py_config.alpha_init = Some(alpha_init_val);
375 }
376 if let Some(lambda_init) = kwargs.get_item("lambda_init")? {
377 let lambda_init_val: f64 = lambda_init.extract()?;
378 if lambda_init_val <= 0.0 {
379 return Err(PyValueError::new_err("lambda_init must be positive"));
380 }
381 self.py_config.lambda_init = Some(lambda_init_val);
382 }
383 if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
384 self.py_config.fit_intercept = fit_intercept.extract()?;
385 }
386 if let Some(compute_score) = kwargs.get_item("compute_score")? {
387 self.py_config.compute_score = compute_score.extract()?;
388 }
389 if let Some(copy_x) = kwargs.get_item("copy_X")? {
390 self.py_config.copy_x = copy_x.extract()?;
391 }
392
393 self.fitted_model = None;
395
396 Ok(())
397 }
398
399 fn __repr__(&self) -> String {
401 format!(
402 "BayesianRidge(max_iter={}, tol={}, alpha_init={:?}, lambda_init={:?}, fit_intercept={}, compute_score={}, copy_X={})",
403 self.py_config.max_iter,
404 self.py_config.tol,
405 self.py_config.alpha_init,
406 self.py_config.lambda_init,
407 self.py_config.fit_intercept,
408 self.py_config.compute_score,
409 self.py_config.copy_x
410 )
411 }
412}