Skip to main content

sklears_python/
naive_bayes.rs

1//! Python bindings for Naive Bayes classifiers
2//!
3//! This module provides PyO3-based Python bindings for sklears Naive Bayes algorithms,
4//! including Gaussian, Multinomial, Bernoulli, Complement, and Categorical Naive Bayes.
5
6use crate::utils::{numpy_to_ndarray1, numpy_to_ndarray2};
7use numpy::{IntoPyArray, PyArray1, PyArray2};
8use pyo3::exceptions::PyRuntimeError;
9use pyo3::prelude::*;
10use scirs2_core::Array1;
11use sklears_core::traits::{Fit, Predict, PredictProba, Trained, Untrained};
12use sklears_naive_bayes::{BernoulliNB, ComplementNB, GaussianNB, MultinomialNB};
13
14/// Python wrapper for Gaussian Naive Bayes
15#[pyclass(name = "GaussianNB")]
16pub struct PyGaussianNB {
17    inner: Option<GaussianNB<Untrained>>,
18    trained: Option<GaussianNB<Trained>>,
19}
20
21#[pymethods]
22impl PyGaussianNB {
23    #[new]
24    #[pyo3(signature = (priors=None, var_smoothing=1e-9))]
25    fn new(priors: Option<Vec<f64>>, var_smoothing: f64) -> PyResult<Self> {
26        let mut nb = GaussianNB::new().var_smoothing(var_smoothing);
27
28        if let Some(prior_values) = priors {
29            let priors_array = Array1::from_vec(prior_values);
30            nb = nb.priors(priors_array);
31        }
32
33        Ok(Self {
34            inner: Some(nb),
35            trained: None,
36        })
37    }
38
39    /// Fit the Gaussian Naive Bayes classifier
40    fn fit(&mut self, x: &Bound<'_, PyArray2<f64>>, y: &Bound<'_, PyArray1<f64>>) -> PyResult<()> {
41        let x_array = numpy_to_ndarray2(x)?;
42        let y_array = numpy_to_ndarray1(y)?;
43
44        // Convert y to integer vector for classification
45        let y_int: Vec<i32> = y_array.iter().map(|&val| val as i32).collect();
46        let y_int_array = Array1::from_vec(y_int);
47
48        let model = self.inner.take().ok_or_else(|| {
49            PyRuntimeError::new_err("Model has already been fitted or was not initialized")
50        })?;
51
52        match model.fit(&x_array, &y_int_array) {
53            Ok(trained_model) => {
54                self.trained = Some(trained_model);
55                Ok(())
56            }
57            Err(e) => Err(PyRuntimeError::new_err(format!(
58                "Failed to fit model: {}",
59                e
60            ))),
61        }
62    }
63
64    /// Make predictions using the fitted model
65    fn predict<'py>(
66        &self,
67        py: Python<'py>,
68        x: &Bound<'py, PyArray2<f64>>,
69    ) -> PyResult<Py<PyArray1<f64>>> {
70        let trained_model = self.trained.as_ref().ok_or_else(|| {
71            PyRuntimeError::new_err("Model must be fitted before making predictions")
72        })?;
73
74        let x_array = numpy_to_ndarray2(x)?;
75
76        match trained_model.predict(&x_array) {
77            Ok(predictions) => {
78                let predictions_f64: Vec<f64> = predictions.iter().map(|&v| v as f64).collect();
79                Ok(PyArray1::from_vec(py, predictions_f64).unbind())
80            }
81            Err(e) => Err(PyRuntimeError::new_err(format!("Prediction failed: {}", e))),
82        }
83    }
84
85    /// Predict class probabilities
86    fn predict_proba<'py>(
87        &self,
88        py: Python<'py>,
89        x: &Bound<'py, PyArray2<f64>>,
90    ) -> PyResult<Py<PyArray2<f64>>> {
91        let trained_model = self.trained.as_ref().ok_or_else(|| {
92            PyRuntimeError::new_err("Model must be fitted before making predictions")
93        })?;
94
95        let x_array = numpy_to_ndarray2(x)?;
96
97        match trained_model.predict_proba(&x_array) {
98            Ok(probabilities) => Ok(probabilities.into_pyarray(py).unbind()),
99            Err(e) => Err(PyRuntimeError::new_err(format!(
100                "Probability prediction failed: {}",
101                e
102            ))),
103        }
104    }
105
106    fn __repr__(&self) -> String {
107        if self.trained.is_some() {
108            "GaussianNB(fitted=True)".to_string()
109        } else {
110            "GaussianNB(fitted=False)".to_string()
111        }
112    }
113}
114
115/// Python wrapper for Multinomial Naive Bayes
116#[pyclass(name = "MultinomialNB")]
117pub struct PyMultinomialNB {
118    inner: Option<MultinomialNB<Untrained>>,
119    trained: Option<MultinomialNB<Trained>>,
120}
121
122#[pymethods]
123impl PyMultinomialNB {
124    #[new]
125    #[pyo3(signature = (alpha=1.0, fit_prior=true, class_prior=None))]
126    fn new(alpha: f64, fit_prior: bool, class_prior: Option<Vec<f64>>) -> PyResult<Self> {
127        let mut nb = MultinomialNB::new().alpha(alpha).fit_prior(fit_prior);
128
129        if let Some(prior_values) = class_prior {
130            let priors_array = Array1::from_vec(prior_values);
131            nb = nb.class_prior(priors_array);
132        }
133
134        Ok(Self {
135            inner: Some(nb),
136            trained: None,
137        })
138    }
139
140    /// Fit the Multinomial Naive Bayes classifier
141    fn fit(&mut self, x: &Bound<'_, PyArray2<f64>>, y: &Bound<'_, PyArray1<f64>>) -> PyResult<()> {
142        let x_array = numpy_to_ndarray2(x)?;
143        let y_array = numpy_to_ndarray1(y)?;
144
145        // Convert y to integer vector for classification
146        let y_int: Vec<i32> = y_array.iter().map(|&val| val as i32).collect();
147        let y_int_array = Array1::from_vec(y_int);
148
149        let model = self.inner.take().ok_or_else(|| {
150            PyRuntimeError::new_err("Model has already been fitted or was not initialized")
151        })?;
152
153        match model.fit(&x_array, &y_int_array) {
154            Ok(trained_model) => {
155                self.trained = Some(trained_model);
156                Ok(())
157            }
158            Err(e) => Err(PyRuntimeError::new_err(format!(
159                "Failed to fit model: {}",
160                e
161            ))),
162        }
163    }
164
165    /// Make predictions using the fitted model
166    fn predict<'py>(
167        &self,
168        py: Python<'py>,
169        x: &Bound<'py, PyArray2<f64>>,
170    ) -> PyResult<Py<PyArray1<f64>>> {
171        let trained_model = self.trained.as_ref().ok_or_else(|| {
172            PyRuntimeError::new_err("Model must be fitted before making predictions")
173        })?;
174
175        let x_array = numpy_to_ndarray2(x)?;
176
177        match trained_model.predict(&x_array) {
178            Ok(predictions) => {
179                let predictions_f64: Vec<f64> = predictions.iter().map(|&v| v as f64).collect();
180                Ok(PyArray1::from_vec(py, predictions_f64).unbind())
181            }
182            Err(e) => Err(PyRuntimeError::new_err(format!("Prediction failed: {}", e))),
183        }
184    }
185
186    /// Predict class probabilities
187    fn predict_proba<'py>(
188        &self,
189        py: Python<'py>,
190        x: &Bound<'py, PyArray2<f64>>,
191    ) -> PyResult<Py<PyArray2<f64>>> {
192        let trained_model = self.trained.as_ref().ok_or_else(|| {
193            PyRuntimeError::new_err("Model must be fitted before making predictions")
194        })?;
195
196        let x_array = numpy_to_ndarray2(x)?;
197
198        match trained_model.predict_proba(&x_array) {
199            Ok(probabilities) => Ok(probabilities.into_pyarray(py).unbind()),
200            Err(e) => Err(PyRuntimeError::new_err(format!(
201                "Probability prediction failed: {}",
202                e
203            ))),
204        }
205    }
206
207    fn __repr__(&self) -> String {
208        if self.trained.is_some() {
209            "MultinomialNB(fitted=True)".to_string()
210        } else {
211            "MultinomialNB(fitted=False)".to_string()
212        }
213    }
214}
215
216/// Python wrapper for Bernoulli Naive Bayes
217#[pyclass(name = "BernoulliNB")]
218pub struct PyBernoulliNB {
219    inner: Option<BernoulliNB<Untrained>>,
220    trained: Option<BernoulliNB<Trained>>,
221}
222
223#[pymethods]
224impl PyBernoulliNB {
225    #[new]
226    #[pyo3(signature = (alpha=1.0, binarize=0.0, fit_prior=true, class_prior=None))]
227    fn new(
228        alpha: f64,
229        binarize: f64,
230        fit_prior: bool,
231        class_prior: Option<Vec<f64>>,
232    ) -> PyResult<Self> {
233        let mut nb = BernoulliNB::new()
234            .alpha(alpha)
235            .binarize(Some(binarize))
236            .fit_prior(fit_prior);
237
238        if let Some(prior_values) = class_prior {
239            let priors_array = Array1::from_vec(prior_values);
240            nb = nb.class_prior(priors_array);
241        }
242
243        Ok(Self {
244            inner: Some(nb),
245            trained: None,
246        })
247    }
248
249    /// Fit the Bernoulli Naive Bayes classifier
250    fn fit(&mut self, x: &Bound<'_, PyArray2<f64>>, y: &Bound<'_, PyArray1<f64>>) -> PyResult<()> {
251        let x_array = numpy_to_ndarray2(x)?;
252        let y_array = numpy_to_ndarray1(y)?;
253
254        // Convert y to integer vector for classification
255        let y_int: Vec<i32> = y_array.iter().map(|&val| val as i32).collect();
256        let y_int_array = Array1::from_vec(y_int);
257
258        let model = self.inner.take().ok_or_else(|| {
259            PyRuntimeError::new_err("Model has already been fitted or was not initialized")
260        })?;
261
262        match model.fit(&x_array, &y_int_array) {
263            Ok(trained_model) => {
264                self.trained = Some(trained_model);
265                Ok(())
266            }
267            Err(e) => Err(PyRuntimeError::new_err(format!(
268                "Failed to fit model: {}",
269                e
270            ))),
271        }
272    }
273
274    /// Make predictions using the fitted model
275    fn predict<'py>(
276        &self,
277        py: Python<'py>,
278        x: &Bound<'py, PyArray2<f64>>,
279    ) -> PyResult<Py<PyArray1<f64>>> {
280        let trained_model = self.trained.as_ref().ok_or_else(|| {
281            PyRuntimeError::new_err("Model must be fitted before making predictions")
282        })?;
283
284        let x_array = numpy_to_ndarray2(x)?;
285
286        match trained_model.predict(&x_array) {
287            Ok(predictions) => {
288                let predictions_f64: Vec<f64> = predictions.iter().map(|&v| v as f64).collect();
289                Ok(PyArray1::from_vec(py, predictions_f64).unbind())
290            }
291            Err(e) => Err(PyRuntimeError::new_err(format!("Prediction failed: {}", e))),
292        }
293    }
294
295    /// Predict class probabilities
296    fn predict_proba<'py>(
297        &self,
298        py: Python<'py>,
299        x: &Bound<'py, PyArray2<f64>>,
300    ) -> PyResult<Py<PyArray2<f64>>> {
301        let trained_model = self.trained.as_ref().ok_or_else(|| {
302            PyRuntimeError::new_err("Model must be fitted before making predictions")
303        })?;
304
305        let x_array = numpy_to_ndarray2(x)?;
306
307        match trained_model.predict_proba(&x_array) {
308            Ok(probabilities) => Ok(probabilities.into_pyarray(py).unbind()),
309            Err(e) => Err(PyRuntimeError::new_err(format!(
310                "Probability prediction failed: {}",
311                e
312            ))),
313        }
314    }
315
316    fn __repr__(&self) -> String {
317        if self.trained.is_some() {
318            "BernoulliNB(fitted=True)".to_string()
319        } else {
320            "BernoulliNB(fitted=False)".to_string()
321        }
322    }
323}
324
325/// Python wrapper for Complement Naive Bayes
326#[pyclass(name = "ComplementNB")]
327pub struct PyComplementNB {
328    inner: Option<ComplementNB<Untrained>>,
329    trained: Option<ComplementNB<Trained>>,
330}
331
332#[pymethods]
333impl PyComplementNB {
334    #[new]
335    #[pyo3(signature = (alpha=1.0, fit_prior=true, class_prior=None, norm=false))]
336    fn new(
337        alpha: f64,
338        fit_prior: bool,
339        class_prior: Option<Vec<f64>>,
340        norm: bool,
341    ) -> PyResult<Self> {
342        let mut nb = ComplementNB::new()
343            .alpha(alpha)
344            .fit_prior(fit_prior)
345            .norm(norm);
346
347        if let Some(prior_values) = class_prior {
348            let priors_array = Array1::from_vec(prior_values);
349            nb = nb.class_prior(priors_array);
350        }
351
352        Ok(Self {
353            inner: Some(nb),
354            trained: None,
355        })
356    }
357
358    /// Fit the Complement Naive Bayes classifier
359    fn fit(&mut self, x: &Bound<'_, PyArray2<f64>>, y: &Bound<'_, PyArray1<f64>>) -> PyResult<()> {
360        let x_array = numpy_to_ndarray2(x)?;
361        let y_array = numpy_to_ndarray1(y)?;
362
363        // Convert y to integer vector for classification
364        let y_int: Vec<i32> = y_array.iter().map(|&val| val as i32).collect();
365        let y_int_array = Array1::from_vec(y_int);
366
367        let model = self.inner.take().ok_or_else(|| {
368            PyRuntimeError::new_err("Model has already been fitted or was not initialized")
369        })?;
370
371        match model.fit(&x_array, &y_int_array) {
372            Ok(trained_model) => {
373                self.trained = Some(trained_model);
374                Ok(())
375            }
376            Err(e) => Err(PyRuntimeError::new_err(format!(
377                "Failed to fit model: {}",
378                e
379            ))),
380        }
381    }
382
383    /// Make predictions using the fitted model
384    fn predict<'py>(
385        &self,
386        py: Python<'py>,
387        x: &Bound<'py, PyArray2<f64>>,
388    ) -> PyResult<Py<PyArray1<f64>>> {
389        let trained_model = self.trained.as_ref().ok_or_else(|| {
390            PyRuntimeError::new_err("Model must be fitted before making predictions")
391        })?;
392
393        let x_array = numpy_to_ndarray2(x)?;
394
395        match trained_model.predict(&x_array) {
396            Ok(predictions) => {
397                let predictions_f64: Vec<f64> = predictions.iter().map(|&v| v as f64).collect();
398                Ok(PyArray1::from_vec(py, predictions_f64).unbind())
399            }
400            Err(e) => Err(PyRuntimeError::new_err(format!("Prediction failed: {}", e))),
401        }
402    }
403
404    fn __repr__(&self) -> String {
405        if self.trained.is_some() {
406            "ComplementNB(fitted=True)".to_string()
407        } else {
408            "ComplementNB(fitted=False)".to_string()
409        }
410    }
411}