Skip to main content

sklears_python/
naive_bayes.rs

//! Python bindings for Naive Bayes classifiers
//!
//! This module provides PyO3-based Python bindings for sklears Naive Bayes algorithms:
//! Gaussian, Multinomial, Bernoulli, and Complement Naive Bayes.
5
use crate::linear::common::core_array2_to_py;
use crate::utils::{numpy_to_ndarray1, numpy_to_ndarray2};
use numpy::{PyArray1, PyArray2};
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use scirs2_core::Array1;
use sklears_core::traits::{Fit, Predict, PredictProba, Trained, Untrained};
use sklears_naive_bayes::{BernoulliNB, ComplementNB, GaussianNB, MultinomialNB};
14
15/// Python wrapper for Gaussian Naive Bayes
16#[pyclass(name = "GaussianNB")]
17pub struct PyGaussianNB {
18    inner: Option<GaussianNB<Untrained>>,
19    trained: Option<GaussianNB<Trained>>,
20}
21
22#[pymethods]
23impl PyGaussianNB {
24    #[new]
25    #[pyo3(signature = (priors=None, var_smoothing=1e-9))]
26    fn new(priors: Option<Vec<f64>>, var_smoothing: f64) -> PyResult<Self> {
27        let mut nb = GaussianNB::new().var_smoothing(var_smoothing);
28
29        if let Some(prior_values) = priors {
30            let priors_array = Array1::from_vec(prior_values);
31            nb = nb.priors(priors_array);
32        }
33
34        Ok(Self {
35            inner: Some(nb),
36            trained: None,
37        })
38    }
39
40    /// Fit the Gaussian Naive Bayes classifier
41    fn fit(&mut self, x: &Bound<'_, PyArray2<f64>>, y: &Bound<'_, PyArray1<f64>>) -> PyResult<()> {
42        let x_array = numpy_to_ndarray2(x)?;
43        let y_array = numpy_to_ndarray1(y)?;
44
45        // Convert y to integer vector for classification
46        let y_int: Vec<i32> = y_array.iter().map(|&val| val as i32).collect();
47        let y_int_array = Array1::from_vec(y_int);
48
49        let model = self.inner.take().ok_or_else(|| {
50            PyRuntimeError::new_err("Model has already been fitted or was not initialized")
51        })?;
52
53        match model.fit(&x_array, &y_int_array) {
54            Ok(trained_model) => {
55                self.trained = Some(trained_model);
56                Ok(())
57            }
58            Err(e) => Err(PyRuntimeError::new_err(format!(
59                "Failed to fit model: {}",
60                e
61            ))),
62        }
63    }
64
65    /// Make predictions using the fitted model
66    fn predict<'py>(
67        &self,
68        py: Python<'py>,
69        x: &Bound<'py, PyArray2<f64>>,
70    ) -> PyResult<Py<PyArray1<f64>>> {
71        let trained_model = self.trained.as_ref().ok_or_else(|| {
72            PyRuntimeError::new_err("Model must be fitted before making predictions")
73        })?;
74
75        let x_array = numpy_to_ndarray2(x)?;
76
77        match trained_model.predict(&x_array) {
78            Ok(predictions) => {
79                let predictions_f64: Vec<f64> = predictions.iter().map(|&v| v as f64).collect();
80                Ok(PyArray1::from_vec(py, predictions_f64).unbind())
81            }
82            Err(e) => Err(PyRuntimeError::new_err(format!("Prediction failed: {}", e))),
83        }
84    }
85
86    /// Predict class probabilities
87    fn predict_proba<'py>(
88        &self,
89        py: Python<'py>,
90        x: &Bound<'py, PyArray2<f64>>,
91    ) -> PyResult<Py<PyArray2<f64>>> {
92        let trained_model = self.trained.as_ref().ok_or_else(|| {
93            PyRuntimeError::new_err("Model must be fitted before making predictions")
94        })?;
95
96        let x_array = numpy_to_ndarray2(x)?;
97
98        match trained_model.predict_proba(&x_array) {
99            Ok(probabilities) => Ok(core_array2_to_py(py, &probabilities)?),
100            Err(e) => Err(PyRuntimeError::new_err(format!(
101                "Probability prediction failed: {}",
102                e
103            ))),
104        }
105    }
106
107    fn __repr__(&self) -> String {
108        if self.trained.is_some() {
109            "GaussianNB(fitted=True)".to_string()
110        } else {
111            "GaussianNB(fitted=False)".to_string()
112        }
113    }
114}
115
116/// Python wrapper for Multinomial Naive Bayes
117#[pyclass(name = "MultinomialNB")]
118pub struct PyMultinomialNB {
119    inner: Option<MultinomialNB<Untrained>>,
120    trained: Option<MultinomialNB<Trained>>,
121}
122
123#[pymethods]
124impl PyMultinomialNB {
125    #[new]
126    #[pyo3(signature = (alpha=1.0, fit_prior=true, class_prior=None))]
127    fn new(alpha: f64, fit_prior: bool, class_prior: Option<Vec<f64>>) -> PyResult<Self> {
128        let mut nb = MultinomialNB::new().alpha(alpha).fit_prior(fit_prior);
129
130        if let Some(prior_values) = class_prior {
131            let priors_array = Array1::from_vec(prior_values);
132            nb = nb.class_prior(priors_array);
133        }
134
135        Ok(Self {
136            inner: Some(nb),
137            trained: None,
138        })
139    }
140
141    /// Fit the Multinomial Naive Bayes classifier
142    fn fit(&mut self, x: &Bound<'_, PyArray2<f64>>, y: &Bound<'_, PyArray1<f64>>) -> PyResult<()> {
143        let x_array = numpy_to_ndarray2(x)?;
144        let y_array = numpy_to_ndarray1(y)?;
145
146        // Convert y to integer vector for classification
147        let y_int: Vec<i32> = y_array.iter().map(|&val| val as i32).collect();
148        let y_int_array = Array1::from_vec(y_int);
149
150        let model = self.inner.take().ok_or_else(|| {
151            PyRuntimeError::new_err("Model has already been fitted or was not initialized")
152        })?;
153
154        match model.fit(&x_array, &y_int_array) {
155            Ok(trained_model) => {
156                self.trained = Some(trained_model);
157                Ok(())
158            }
159            Err(e) => Err(PyRuntimeError::new_err(format!(
160                "Failed to fit model: {}",
161                e
162            ))),
163        }
164    }
165
166    /// Make predictions using the fitted model
167    fn predict<'py>(
168        &self,
169        py: Python<'py>,
170        x: &Bound<'py, PyArray2<f64>>,
171    ) -> PyResult<Py<PyArray1<f64>>> {
172        let trained_model = self.trained.as_ref().ok_or_else(|| {
173            PyRuntimeError::new_err("Model must be fitted before making predictions")
174        })?;
175
176        let x_array = numpy_to_ndarray2(x)?;
177
178        match trained_model.predict(&x_array) {
179            Ok(predictions) => {
180                let predictions_f64: Vec<f64> = predictions.iter().map(|&v| v as f64).collect();
181                Ok(PyArray1::from_vec(py, predictions_f64).unbind())
182            }
183            Err(e) => Err(PyRuntimeError::new_err(format!("Prediction failed: {}", e))),
184        }
185    }
186
187    /// Predict class probabilities
188    fn predict_proba<'py>(
189        &self,
190        py: Python<'py>,
191        x: &Bound<'py, PyArray2<f64>>,
192    ) -> PyResult<Py<PyArray2<f64>>> {
193        let trained_model = self.trained.as_ref().ok_or_else(|| {
194            PyRuntimeError::new_err("Model must be fitted before making predictions")
195        })?;
196
197        let x_array = numpy_to_ndarray2(x)?;
198
199        match trained_model.predict_proba(&x_array) {
200            Ok(probabilities) => Ok(core_array2_to_py(py, &probabilities)?),
201            Err(e) => Err(PyRuntimeError::new_err(format!(
202                "Probability prediction failed: {}",
203                e
204            ))),
205        }
206    }
207
208    fn __repr__(&self) -> String {
209        if self.trained.is_some() {
210            "MultinomialNB(fitted=True)".to_string()
211        } else {
212            "MultinomialNB(fitted=False)".to_string()
213        }
214    }
215}
216
217/// Python wrapper for Bernoulli Naive Bayes
218#[pyclass(name = "BernoulliNB")]
219pub struct PyBernoulliNB {
220    inner: Option<BernoulliNB<Untrained>>,
221    trained: Option<BernoulliNB<Trained>>,
222}
223
224#[pymethods]
225impl PyBernoulliNB {
226    #[new]
227    #[pyo3(signature = (alpha=1.0, binarize=0.0, fit_prior=true, class_prior=None))]
228    fn new(
229        alpha: f64,
230        binarize: f64,
231        fit_prior: bool,
232        class_prior: Option<Vec<f64>>,
233    ) -> PyResult<Self> {
234        let mut nb = BernoulliNB::new()
235            .alpha(alpha)
236            .binarize(Some(binarize))
237            .fit_prior(fit_prior);
238
239        if let Some(prior_values) = class_prior {
240            let priors_array = Array1::from_vec(prior_values);
241            nb = nb.class_prior(priors_array);
242        }
243
244        Ok(Self {
245            inner: Some(nb),
246            trained: None,
247        })
248    }
249
250    /// Fit the Bernoulli Naive Bayes classifier
251    fn fit(&mut self, x: &Bound<'_, PyArray2<f64>>, y: &Bound<'_, PyArray1<f64>>) -> PyResult<()> {
252        let x_array = numpy_to_ndarray2(x)?;
253        let y_array = numpy_to_ndarray1(y)?;
254
255        // Convert y to integer vector for classification
256        let y_int: Vec<i32> = y_array.iter().map(|&val| val as i32).collect();
257        let y_int_array = Array1::from_vec(y_int);
258
259        let model = self.inner.take().ok_or_else(|| {
260            PyRuntimeError::new_err("Model has already been fitted or was not initialized")
261        })?;
262
263        match model.fit(&x_array, &y_int_array) {
264            Ok(trained_model) => {
265                self.trained = Some(trained_model);
266                Ok(())
267            }
268            Err(e) => Err(PyRuntimeError::new_err(format!(
269                "Failed to fit model: {}",
270                e
271            ))),
272        }
273    }
274
275    /// Make predictions using the fitted model
276    fn predict<'py>(
277        &self,
278        py: Python<'py>,
279        x: &Bound<'py, PyArray2<f64>>,
280    ) -> PyResult<Py<PyArray1<f64>>> {
281        let trained_model = self.trained.as_ref().ok_or_else(|| {
282            PyRuntimeError::new_err("Model must be fitted before making predictions")
283        })?;
284
285        let x_array = numpy_to_ndarray2(x)?;
286
287        match trained_model.predict(&x_array) {
288            Ok(predictions) => {
289                let predictions_f64: Vec<f64> = predictions.iter().map(|&v| v as f64).collect();
290                Ok(PyArray1::from_vec(py, predictions_f64).unbind())
291            }
292            Err(e) => Err(PyRuntimeError::new_err(format!("Prediction failed: {}", e))),
293        }
294    }
295
296    /// Predict class probabilities
297    fn predict_proba<'py>(
298        &self,
299        py: Python<'py>,
300        x: &Bound<'py, PyArray2<f64>>,
301    ) -> PyResult<Py<PyArray2<f64>>> {
302        let trained_model = self.trained.as_ref().ok_or_else(|| {
303            PyRuntimeError::new_err("Model must be fitted before making predictions")
304        })?;
305
306        let x_array = numpy_to_ndarray2(x)?;
307
308        match trained_model.predict_proba(&x_array) {
309            Ok(probabilities) => Ok(core_array2_to_py(py, &probabilities)?),
310            Err(e) => Err(PyRuntimeError::new_err(format!(
311                "Probability prediction failed: {}",
312                e
313            ))),
314        }
315    }
316
317    fn __repr__(&self) -> String {
318        if self.trained.is_some() {
319            "BernoulliNB(fitted=True)".to_string()
320        } else {
321            "BernoulliNB(fitted=False)".to_string()
322        }
323    }
324}
325
326/// Python wrapper for Complement Naive Bayes
327#[pyclass(name = "ComplementNB")]
328pub struct PyComplementNB {
329    inner: Option<ComplementNB<Untrained>>,
330    trained: Option<ComplementNB<Trained>>,
331}
332
333#[pymethods]
334impl PyComplementNB {
335    #[new]
336    #[pyo3(signature = (alpha=1.0, fit_prior=true, class_prior=None, norm=false))]
337    fn new(
338        alpha: f64,
339        fit_prior: bool,
340        class_prior: Option<Vec<f64>>,
341        norm: bool,
342    ) -> PyResult<Self> {
343        let mut nb = ComplementNB::new()
344            .alpha(alpha)
345            .fit_prior(fit_prior)
346            .norm(norm);
347
348        if let Some(prior_values) = class_prior {
349            let priors_array = Array1::from_vec(prior_values);
350            nb = nb.class_prior(priors_array);
351        }
352
353        Ok(Self {
354            inner: Some(nb),
355            trained: None,
356        })
357    }
358
359    /// Fit the Complement Naive Bayes classifier
360    fn fit(&mut self, x: &Bound<'_, PyArray2<f64>>, y: &Bound<'_, PyArray1<f64>>) -> PyResult<()> {
361        let x_array = numpy_to_ndarray2(x)?;
362        let y_array = numpy_to_ndarray1(y)?;
363
364        // Convert y to integer vector for classification
365        let y_int: Vec<i32> = y_array.iter().map(|&val| val as i32).collect();
366        let y_int_array = Array1::from_vec(y_int);
367
368        let model = self.inner.take().ok_or_else(|| {
369            PyRuntimeError::new_err("Model has already been fitted or was not initialized")
370        })?;
371
372        match model.fit(&x_array, &y_int_array) {
373            Ok(trained_model) => {
374                self.trained = Some(trained_model);
375                Ok(())
376            }
377            Err(e) => Err(PyRuntimeError::new_err(format!(
378                "Failed to fit model: {}",
379                e
380            ))),
381        }
382    }
383
384    /// Make predictions using the fitted model
385    fn predict<'py>(
386        &self,
387        py: Python<'py>,
388        x: &Bound<'py, PyArray2<f64>>,
389    ) -> PyResult<Py<PyArray1<f64>>> {
390        let trained_model = self.trained.as_ref().ok_or_else(|| {
391            PyRuntimeError::new_err("Model must be fitted before making predictions")
392        })?;
393
394        let x_array = numpy_to_ndarray2(x)?;
395
396        match trained_model.predict(&x_array) {
397            Ok(predictions) => {
398                let predictions_f64: Vec<f64> = predictions.iter().map(|&v| v as f64).collect();
399                Ok(PyArray1::from_vec(py, predictions_f64).unbind())
400            }
401            Err(e) => Err(PyRuntimeError::new_err(format!("Prediction failed: {}", e))),
402        }
403    }
404
405    fn __repr__(&self) -> String {
406        if self.trained.is_some() {
407            "ComplementNB(fitted=True)".to_string()
408        } else {
409            "ComplementNB(fitted=False)".to_string()
410        }
411    }
412}