1#![allow(dead_code)]
2#![allow(non_snake_case)]
3#![allow(missing_docs)]
4#![allow(deprecated)]
5#![allow(clippy::all)]
6#![allow(clippy::pedantic)]
7#![allow(clippy::nursery)]
8#![allow(unused_imports)]
9#![allow(unused_variables)]
10#![allow(unused_mut)]
11#![allow(unused_assignments)]
12#![allow(unused_doc_comments)]
13#![allow(unused_parens)]
14#![allow(unused_comparisons)]
15use scirs2_core::ndarray::{Array1, Array2};
22use scirs2_core::numeric::Float;
23
24mod adaptive_smoothing;
25mod api_builder;
26mod bayesian_network;
29mod benchmarks;
30mod bernoulli;
31mod beta;
32mod bioinformatics;
33mod categorical;
34mod causal_inference;
35mod complement;
36mod computer_vision;
37mod deep_generative;
40mod dirichlet_process;
41mod ensemble;
42mod exponential_family;
43mod feature_engineering;
44mod flexible;
49mod gamma;
50mod gaussian;
51mod hierarchical;
52mod kernel_methods;
53mod mixed;
54mod multilabel;
57mod multinomial;
58mod neural_naive_bayes;
59mod nonparametric;
60mod online_learning;
61mod optimizations;
62mod parameter_estimation;
63mod performance;
64mod plugin_architecture;
65mod poisson;
66mod semi_naive;
69mod smoothing;
70mod temporal;
71mod text_classification;
72mod tree_augmented;
73mod type_safe_prob;
74mod uncertainty;
75mod validation;
76mod variational_bayes;
77
78#[allow(non_snake_case)]
79#[cfg(test)]
80mod property_tests;
81
82pub use adaptive_smoothing::{
83 AdaptiveSmoothing, AdaptiveSmoothingMethod, HyperparameterOptimizer,
84 InformationCriterion as AdaptiveInformationCriterion, ScoringMethod,
85};
86pub use api_builder::{
87 naive_bayes, naive_bayes_preset, FluentBernoulliNB, FluentCategoricalNB, FluentComplementNB,
88 FluentGaussianNB, FluentMultinomialNB, FluentNaiveBayesModel, FluentPoissonNB,
89 NaiveBayesBuilder, NaiveBayesPreset, SerializableNBParams,
90};
91pub use bayesian_network::{
97 BANConfig, BANError, BayesianNetwork, BayesianNetworkAugmentedNB, NetworkEdge,
98 StructureLearningMethod,
99};
100pub use benchmarks::{BenchmarkConfig, BenchmarkResults, NaiveBayesBenchmark};
101pub use bernoulli::BernoulliNB;
102pub use beta::BetaNB;
103pub use bioinformatics::{
104 AminoAcid, BioinformaticsError, BiomarkerDiscoveryNB, GeneExpressionConfig, GeneExpressionNB,
105 GenomicNBConfig, GenomicNaiveBayes, GenomicSequence, Nucleotide, PhylogeneticConfig,
106 PhylogeneticNB, ProteinStructureConfig, ProteinStructureNB, SecondaryStructure,
107 SequenceMetadata, SequenceType,
108};
109pub use categorical::CategoricalNB;
110pub use causal_inference::{
111 CausalDiscovery, CausalDiscoveryAlgorithm, CausalDiscoveryConfig, CausalGraph,
112 CausalInferenceError, CausalNaiveBayes, CounterfactualReasoning, DoCalculus,
113 InstrumentalVariables,
114};
115pub use complement::ComplementNB;
116pub use computer_vision::{
117 utils as cv_utils, ColorSpace, ComputerVisionError, ImageData, ImageMetadata, ImageNBConfig,
118 ImageNaiveBayes, NeighborhoodStats, SpatialModel, SpatialNBConfig, SpatialNaiveBayes,
119};
120pub use deep_generative::{
127 DeepGenerativeConfig, DeepGenerativeError, DeepGenerativeNaiveBayes, FlowConfig, NPEConfig,
128 NeuralPosteriorEstimator, NormalizingFlow, VAEConfig, VariationalAutoencoder,
129};
130pub use dirichlet_process::{
131 BaseDistributionParams, DirichletComponent, DirichletProcessConfig, DirichletProcessError,
132 DirichletProcessNB, InferenceMethod, StickBreaking,
133};
134pub use ensemble::{
135 AdaBoostNaiveBayes, AveragingNaiveBayes, BaggedNaiveBayes, BoostingStrategy, BootstrapConfig,
136 NaiveBayesEstimator, StackingNaiveBayes, VotingNaiveBayes, VotingStrategy,
137};
138pub use exponential_family::{
139 ExponentialFamily, ExponentialFamilyNB, ExponentialFamilyNBConfig, NaturalParameters,
140 ParameterEstimationMethod, SufficientStatistics,
141};
142pub use feature_engineering::{
143 AutoFeatureTransformer, AutoPipelineConfig, AutoTransformConfig,
144 AutomatedPreprocessingPipeline, DataValidationConfig, FeatureInteractionDetector,
145 FeatureSelectionMethod, FeatureSelectionResults, FeatureType as AutoFeatureType,
146 ImbalanceHandlingMethod, InteractionMethod, InteractionResults, MissingValueStrategy,
147 OutlierDetectionMethod, StatisticalTest, TransformMethod,
148};
149pub use flexible::{
162 Distribution, DistributionParams, FlexibleNB, FlexibleNBConfig, SelectionMethod,
163};
164pub use gamma::GammaNB;
165pub use gaussian::GaussianNB;
166pub use hierarchical::{
167 ClassHierarchy, HierarchicalConfig, HierarchicalError, HierarchicalNB, HierarchyNode,
168 PredictionStrategy,
169};
170pub use kernel_methods::{
171 BandwidthMethod as KernelBandwidthMethod, GaussianProcess, KDEConfig,
172 KernelDensityEstimator as AdvancedKDE, KernelError, KernelNaiveBayes, KernelParameterLearner,
173 KernelType as AdvancedKernelType, RKHSFeatureSelector, ReproducingKernelHilbertSpace,
174 ScoringMetric as KernelScoringMetric,
175};
176pub use mixed::{FeatureDistribution, MixedNB};
177pub use multilabel::{
184 AdvancedChainClassifier, ChainOrderingStrategy, LabelCorrelationAnalysis, LabelDependencyGraph,
185 LabelHierarchy, MultiLabelNB, MultiLabelStrategy,
186};
187pub use multinomial::MultinomialNB;
188pub use neural_naive_bayes::{
189 ActivationFunction, NeuralLayer, NeuralNBBuilder, NeuralNBConfig, NeuralNBError,
190 NeuralNaiveBayes,
191};
192pub use nonparametric::{
193 BandwidthMethod, KernelDensityEstimator, KernelType, NonparametricNB, NonparametricNBConfig,
194};
195pub use online_learning::{
196 CSVChunkReader,
197 ConceptDriftDetector,
198 DataChunkIterator,
200 MemoryChunkIterator,
201 OnlineGaussianStats,
202 OnlineLearningConfig,
203 OnlineLearningError,
204 OnlineMultinomialStats,
205 OnlineNaiveBayes,
206 OutOfCoreNaiveBayes,
207 StreamingBuffer,
208};
209pub use optimizations::{
210 cache_optimization, memory_efficient, numerical_stability as advanced_numerical_stability,
211 parallel, profile_guided_optimization, simd, unsafe_optimizations,
212};
213pub use parameter_estimation::{
214 CrossValidationSelector, EmpiricalBayesEstimator, GaussianEstimator, MultinomialEstimator,
215 ParameterEstimator,
216};
217pub use performance::{
218 CompressedNBModel, CompressionMetadata, LazyLoadedModel, MemoryMappedNBModel,
219 MemoryOptimizedOps, MemoryStats, SparseMatrix, SparseRow, VectorizedOps,
220};
221pub use plugin_architecture::{
222 get_global_registry, register_distribution, register_middleware, register_model_selector,
223 register_parameter_estimator, register_smoothing_method, ComposableSmoothingMethod,
224 DataCharacteristics, DataType, EstimationResult, ExtensibleParameterEstimator,
225 FlexibleModelSelector, ModelCandidate, PluggableDistribution, PluginRegistry,
226 PredictionContext, PredictionMiddleware, SelectionCriterion,
227};
228pub use poisson::PoissonNB;
229pub use semi_naive::{
236 DependencySelectionMethod, FeatureDependency, SemiNaiveBayes, SemiNaiveBayesConfig,
237};
238pub use smoothing::{
239 enhanced_log, log_sum_exp, normalize_log_probs, numerical_stability, SmoothingMethod,
240};
241pub use temporal::{
242 CovarianceType, HMMConfig, HMMNaiveBayes, StreamingTemporalNB, TemporalConfig,
243 TemporalFeatureExtractor, TemporalNaiveBayes,
244};
245pub use text_classification::{
246 LDATopicModel, NGramExtractor, TextMultinomialNB, TextMultinomialNBConfig, TextPreprocessor,
247 TfIdfTransformer, TopicAugmentedConfig, TopicAugmentedTextClassifier,
248};
249pub use tree_augmented::{DependencyTree, TANConfig, TANError, TreeAugmentedNB, TreeEdge};
250pub use type_safe_prob::{
251 const_utils, distribution_types, feature_types, utils as prob_utils, FixedSizeModel,
252 ProbabilisticModel as TypeSafeProbabilisticModel, TypedProbability, TypedProbabilityOps,
253 ValidateFeatureType,
254};
255pub use uncertainty::{
256 BayesianUncertainty, CalibrationMethod, EnsemblePredictions, ModelUncertaintyPropagation,
257 ReliabilityTests, StandardUncertainty, UncertaintyDecomposition, UncertaintyMeasures,
258 UncertaintyQuantification,
259};
260pub use validation::{
261 CVResults as ValidationCVResults, CVStrategy as ValidationCVStrategy, CalibrationMetrics,
262 GoodnessOfFitResults, ModelCriticismResults, ModelDegradationMetrics, PPCResults,
263 PredictionStabilityMetrics, PredictiveAccuracyAssessment, PredictiveAccuracyAssessor,
264 ProbabilisticModel, ProbabilisticValidator, ResidualAnalysis,
265 ScoringMetric as ValidationScoringMetric, TemporalValidationResults, ValidationError,
266};
267pub use variational_bayes::{
268 ELBOComponents, FeatureType, VariationalBayesConfig, VariationalBayesNB, VariationalError,
269 VariationalInferenceMethod, VariationalParameters,
270};
271
272pub trait NaiveBayesMixin {
274 fn class_log_prior(&self) -> &Array1<f64>;
276
277 fn feature_log_prob(&self) -> &Array2<f64>;
279
280 fn classes(&self) -> &Array1<i32>;
282}
283
284fn compute_class_prior(y: &Array1<i32>, classes: &Array1<i32>) -> (Array1<f64>, Array1<f64>) {
286 let n_samples = y.len() as f64;
287 let mut class_count = Array1::zeros(classes.len());
288
289 for &label in y.iter() {
290 for (i, &class) in classes.iter().enumerate() {
291 if label == class {
292 class_count[i] += 1.0;
293 }
294 }
295 }
296
297 let class_prior = &class_count / n_samples;
298 (class_count, class_prior)
299}
300
301fn safe_log<F: Float>(x: F) -> F {
303 let epsilon = F::from(1e-10).unwrap();
304 (x + epsilon).ln()
305}