sklears_naive_bayes/
lib.rs

1#![allow(dead_code)]
2#![allow(non_snake_case)]
3#![allow(missing_docs)]
4#![allow(deprecated)]
5#![allow(clippy::all)]
6#![allow(clippy::pedantic)]
7#![allow(clippy::nursery)]
8#![allow(unused_imports)]
9#![allow(unused_variables)]
10#![allow(unused_mut)]
11#![allow(unused_assignments)]
12#![allow(unused_doc_comments)]
13#![allow(unused_parens)]
14#![allow(unused_comparisons)]
15//! Naive Bayes classifiers
16//!
17//! This module provides various Naive Bayes classifiers for different
18//! types of features, compatible with scikit-learn's naive_bayes module.
19
// SciRS2 Policy Compliance - Use scirs2-core for ndarray types
21use scirs2_core::ndarray::{Array1, Array2};
22use scirs2_core::numeric::Float;
23
24mod adaptive_smoothing;
25mod api_builder;
26// TODO: Migrate to scirs2-linalg (uses nalgebra types)
27//mod attention_naive_bayes;
28mod bayesian_network;
29mod benchmarks;
30mod bernoulli;
31mod beta;
32mod bioinformatics;
33mod categorical;
34mod causal_inference;
35mod complement;
36mod computer_vision;
37// TODO: Migrate to scirs2-linalg (uses nalgebra types)
38//mod continual_learning;
39mod deep_generative;
40mod dirichlet_process;
41mod ensemble;
42mod exponential_family;
43mod feature_engineering;
44// TODO: Migrate to scirs2-linalg (uses nalgebra types)
45//mod federated;
46// TODO: Migrate to scirs2-linalg (uses nalgebra types)
47//mod finance;
48mod flexible;
49mod gamma;
50mod gaussian;
51mod hierarchical;
52mod kernel_methods;
53mod mixed;
54// TODO: ndarray 0.17 HRTB trait bound issues
55//mod model_selection;
56mod multilabel;
57mod multinomial;
58mod neural_naive_bayes;
59mod nonparametric;
60mod online_learning;
61mod optimizations;
62mod parameter_estimation;
63mod performance;
64mod plugin_architecture;
65mod poisson;
66// TODO: Migrate to scirs2-linalg (uses nalgebra types)
67//mod quantum;
68mod semi_naive;
69mod smoothing;
70mod temporal;
71mod text_classification;
72mod tree_augmented;
73mod type_safe_prob;
74mod uncertainty;
75mod validation;
76mod variational_bayes;
77
78#[allow(non_snake_case)]
79#[cfg(test)]
80mod property_tests;
81
82pub use adaptive_smoothing::{
83    AdaptiveSmoothing, AdaptiveSmoothingMethod, HyperparameterOptimizer,
84    InformationCriterion as AdaptiveInformationCriterion, ScoringMethod,
85};
86pub use api_builder::{
87    naive_bayes, naive_bayes_preset, FluentBernoulliNB, FluentCategoricalNB, FluentComplementNB,
88    FluentGaussianNB, FluentMultinomialNB, FluentNaiveBayesModel, FluentPoissonNB,
89    NaiveBayesBuilder, NaiveBayesPreset, SerializableNBParams,
90};
91// TODO: Migrate to scirs2-linalg (uses nalgebra types)
92// pub use attention_naive_bayes::{
93//     AttentionNBBuilder, AttentionNBConfig, AttentionNBError, AttentionNaiveBayes, AttentionType,
94//     ImportanceScoring,
95// };
96pub use bayesian_network::{
97    BANConfig, BANError, BayesianNetwork, BayesianNetworkAugmentedNB, NetworkEdge,
98    StructureLearningMethod,
99};
100pub use benchmarks::{BenchmarkConfig, BenchmarkResults, NaiveBayesBenchmark};
101pub use bernoulli::BernoulliNB;
102pub use beta::BetaNB;
103pub use bioinformatics::{
104    AminoAcid, BioinformaticsError, BiomarkerDiscoveryNB, GeneExpressionConfig, GeneExpressionNB,
105    GenomicNBConfig, GenomicNaiveBayes, GenomicSequence, Nucleotide, PhylogeneticConfig,
106    PhylogeneticNB, ProteinStructureConfig, ProteinStructureNB, SecondaryStructure,
107    SequenceMetadata, SequenceType,
108};
109pub use categorical::CategoricalNB;
110pub use causal_inference::{
111    CausalDiscovery, CausalDiscoveryAlgorithm, CausalDiscoveryConfig, CausalGraph,
112    CausalInferenceError, CausalNaiveBayes, CounterfactualReasoning, DoCalculus,
113    InstrumentalVariables,
114};
115pub use complement::ComplementNB;
116pub use computer_vision::{
117    utils as cv_utils, ColorSpace, ComputerVisionError, ImageData, ImageMetadata, ImageNBConfig,
118    ImageNaiveBayes, NeighborhoodStats, SpatialModel, SpatialNBConfig, SpatialNaiveBayes,
119};
120// TODO: Migrate to scirs2-linalg (uses nalgebra types)
121// pub use continual_learning::{
122//     ContinualLearningConfig, ContinualLearningError, ContinualLearningNB,
123//     ContinualLearningNBBuilder, ContinualLearningStrategy, DriftDetectionMethod, MemoryStrategy,
124//     TaskMetadata,
125// };
126pub use deep_generative::{
127    DeepGenerativeConfig, DeepGenerativeError, DeepGenerativeNaiveBayes, FlowConfig, NPEConfig,
128    NeuralPosteriorEstimator, NormalizingFlow, VAEConfig, VariationalAutoencoder,
129};
130pub use dirichlet_process::{
131    BaseDistributionParams, DirichletComponent, DirichletProcessConfig, DirichletProcessError,
132    DirichletProcessNB, InferenceMethod, StickBreaking,
133};
134pub use ensemble::{
135    AdaBoostNaiveBayes, AveragingNaiveBayes, BaggedNaiveBayes, BoostingStrategy, BootstrapConfig,
136    NaiveBayesEstimator, StackingNaiveBayes, VotingNaiveBayes, VotingStrategy,
137};
138pub use exponential_family::{
139    ExponentialFamily, ExponentialFamilyNB, ExponentialFamilyNBConfig, NaturalParameters,
140    ParameterEstimationMethod, SufficientStatistics,
141};
142pub use feature_engineering::{
143    AutoFeatureTransformer, AutoPipelineConfig, AutoTransformConfig,
144    AutomatedPreprocessingPipeline, DataValidationConfig, FeatureInteractionDetector,
145    FeatureSelectionMethod, FeatureSelectionResults, FeatureType as AutoFeatureType,
146    ImbalanceHandlingMethod, InteractionMethod, InteractionResults, MissingValueStrategy,
147    OutlierDetectionMethod, StatisticalTest, TransformMethod,
148};
149// TODO: Migrate to scirs2-linalg (uses nalgebra types)
150// pub use federated::{
151//     AggregationMethod, AggregationStrategy, BudgetAllocationStrategy, ClientId, ClientModel,
152//     CommunicationParams, CompressionStrategy, FederatedError, FederatedNaiveBayes,
153//     FederationParams, FederationStatistics, GlobalModel, LocalUpdate, PrivacyParams, PrivateUpdate,
154// };
155// TODO: Migrate to scirs2-linalg (uses nalgebra types)
156// pub use finance::{
157//     CreditData, CreditRisk, CreditScoringNB, FinanceError, FinancialFeatures,
158//     FinancialTimeSeriesNB, FraudDetectionNB, FraudLabel, PortfolioCategory,
159//     PortfolioClassificationNB, PortfolioData, RiskAssessmentNB, RiskLevel, TransactionData,
160// };
161pub use flexible::{
162    Distribution, DistributionParams, FlexibleNB, FlexibleNBConfig, SelectionMethod,
163};
164pub use gamma::GammaNB;
165pub use gaussian::GaussianNB;
166pub use hierarchical::{
167    ClassHierarchy, HierarchicalConfig, HierarchicalError, HierarchicalNB, HierarchyNode,
168    PredictionStrategy,
169};
170pub use kernel_methods::{
171    BandwidthMethod as KernelBandwidthMethod, GaussianProcess, KDEConfig,
172    KernelDensityEstimator as AdvancedKDE, KernelError, KernelNaiveBayes, KernelParameterLearner,
173    KernelType as AdvancedKernelType, RKHSFeatureSelector, ReproducingKernelHilbertSpace,
174    ScoringMetric as KernelScoringMetric,
175};
176pub use mixed::{FeatureDistribution, MixedNB};
177// TODO: ndarray 0.17 HRTB trait bound issues
178// pub use model_selection::{
179//     BayesianModelComparison, BayesianModelSelector, CVResults, CVStrategy, InformationCriterion,
180//     ModelComparison, ModelSelectionResults, NaiveBayesModelSelector, NestedModelComparison,
181//     NestedModelValidation, ParameterGrid, ParameterValue, ScoringMetric,
182// };
183pub use multilabel::{
184    AdvancedChainClassifier, ChainOrderingStrategy, LabelCorrelationAnalysis, LabelDependencyGraph,
185    LabelHierarchy, MultiLabelNB, MultiLabelStrategy,
186};
187pub use multinomial::MultinomialNB;
188pub use neural_naive_bayes::{
189    ActivationFunction, NeuralLayer, NeuralNBBuilder, NeuralNBConfig, NeuralNBError,
190    NeuralNaiveBayes,
191};
192pub use nonparametric::{
193    BandwidthMethod, KernelDensityEstimator, KernelType, NonparametricNB, NonparametricNBConfig,
194};
195pub use online_learning::{
196    CSVChunkReader,
197    ConceptDriftDetector,
198    // Out-of-core learning types
199    DataChunkIterator,
200    MemoryChunkIterator,
201    OnlineGaussianStats,
202    OnlineLearningConfig,
203    OnlineLearningError,
204    OnlineMultinomialStats,
205    OnlineNaiveBayes,
206    OutOfCoreNaiveBayes,
207    StreamingBuffer,
208};
209pub use optimizations::{
210    cache_optimization, memory_efficient, numerical_stability as advanced_numerical_stability,
211    parallel, profile_guided_optimization, simd, unsafe_optimizations,
212};
213pub use parameter_estimation::{
214    CrossValidationSelector, EmpiricalBayesEstimator, GaussianEstimator, MultinomialEstimator,
215    ParameterEstimator,
216};
217pub use performance::{
218    CompressedNBModel, CompressionMetadata, LazyLoadedModel, MemoryMappedNBModel,
219    MemoryOptimizedOps, MemoryStats, SparseMatrix, SparseRow, VectorizedOps,
220};
221pub use plugin_architecture::{
222    get_global_registry, register_distribution, register_middleware, register_model_selector,
223    register_parameter_estimator, register_smoothing_method, ComposableSmoothingMethod,
224    DataCharacteristics, DataType, EstimationResult, ExtensibleParameterEstimator,
225    FlexibleModelSelector, ModelCandidate, PluggableDistribution, PluginRegistry,
226    PredictionContext, PredictionMiddleware, SelectionCriterion,
227};
228pub use poisson::PoissonNB;
229// TODO: Migrate to scirs2-linalg (uses nalgebra types)
230// pub use quantum::{
231//     EntanglementPattern, GateType, HybridQuantumClassicalNB, QuantumAdvantageMetrics,
232//     QuantumCircuitParams, QuantumError, QuantumFeatureMap, QuantumNaiveBayes, QuantumState,
233//     RotationAxis,
234// };
235pub use semi_naive::{
236    DependencySelectionMethod, FeatureDependency, SemiNaiveBayes, SemiNaiveBayesConfig,
237};
238pub use smoothing::{
239    enhanced_log, log_sum_exp, normalize_log_probs, numerical_stability, SmoothingMethod,
240};
241pub use temporal::{
242    CovarianceType, HMMConfig, HMMNaiveBayes, StreamingTemporalNB, TemporalConfig,
243    TemporalFeatureExtractor, TemporalNaiveBayes,
244};
245pub use text_classification::{
246    LDATopicModel, NGramExtractor, TextMultinomialNB, TextMultinomialNBConfig, TextPreprocessor,
247    TfIdfTransformer, TopicAugmentedConfig, TopicAugmentedTextClassifier,
248};
249pub use tree_augmented::{DependencyTree, TANConfig, TANError, TreeAugmentedNB, TreeEdge};
250pub use type_safe_prob::{
251    const_utils, distribution_types, feature_types, utils as prob_utils, FixedSizeModel,
252    ProbabilisticModel as TypeSafeProbabilisticModel, TypedProbability, TypedProbabilityOps,
253    ValidateFeatureType,
254};
255pub use uncertainty::{
256    BayesianUncertainty, CalibrationMethod, EnsemblePredictions, ModelUncertaintyPropagation,
257    ReliabilityTests, StandardUncertainty, UncertaintyDecomposition, UncertaintyMeasures,
258    UncertaintyQuantification,
259};
260pub use validation::{
261    CVResults as ValidationCVResults, CVStrategy as ValidationCVStrategy, CalibrationMetrics,
262    GoodnessOfFitResults, ModelCriticismResults, ModelDegradationMetrics, PPCResults,
263    PredictionStabilityMetrics, PredictiveAccuracyAssessment, PredictiveAccuracyAssessor,
264    ProbabilisticModel, ProbabilisticValidator, ResidualAnalysis,
265    ScoringMetric as ValidationScoringMetric, TemporalValidationResults, ValidationError,
266};
267pub use variational_bayes::{
268    ELBOComponents, FeatureType, VariationalBayesConfig, VariationalBayesNB, VariationalError,
269    VariationalInferenceMethod, VariationalParameters,
270};
271
272/// Base trait for Naive Bayes classifiers
273pub trait NaiveBayesMixin {
274    /// Get log prior probabilities
275    fn class_log_prior(&self) -> &Array1<f64>;
276
277    /// Get feature log probabilities
278    fn feature_log_prob(&self) -> &Array2<f64>;
279
280    /// Get the classes
281    fn classes(&self) -> &Array1<i32>;
282}
283
284/// Helper function to compute class counts and priors
285fn compute_class_prior(y: &Array1<i32>, classes: &Array1<i32>) -> (Array1<f64>, Array1<f64>) {
286    let n_samples = y.len() as f64;
287    let mut class_count = Array1::zeros(classes.len());
288
289    for &label in y.iter() {
290        for (i, &class) in classes.iter().enumerate() {
291            if label == class {
292                class_count[i] += 1.0;
293            }
294        }
295    }
296
297    let class_prior = &class_count / n_samples;
298    (class_count, class_prior)
299}
300
301/// Helper function to compute log probability
302fn safe_log<F: Float>(x: F) -> F {
303    let epsilon = F::from(1e-10).unwrap();
304    (x + epsilon).ln()
305}