scirs2_stats/mixture_models/
kde.rs1use super::f64_to_f;
4use super::GmmFloat;
5use crate::error::{StatsError, StatsResult};
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
7use scirs2_core::validation::*;
8use std::marker::PhantomData;
9
10pub struct KernelDensityEstimator<F> {
12 pub kernel: KernelType,
14 pub bandwidth: F,
16 pub config: KDEConfig,
18 pub trainingdata: Option<Array2<F>>,
20 _phantom: PhantomData<F>,
21}
22
/// One-dimensional kernel profile `K(u)` used by [`KernelDensityEstimator`].
///
/// `u` is the Euclidean distance to a training point divided by the
/// bandwidth. Every kernel except `Gaussian` has compact support `|u| <= 1`
/// and is zero outside it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// `K(u) = exp(-u^2 / 2) / sqrt(2*pi)`; unbounded support.
    Gaussian,
    /// `K(u) = 3/4 * (1 - u^2)` for `|u| <= 1`.
    Epanechnikov,
    /// `K(u) = 1/2` for `|u| <= 1`.
    Uniform,
    /// `K(u) = 1 - |u|` for `|u| <= 1`.
    Triangular,
    /// `K(u) = pi/4 * cos(pi * u / 2)` for `|u| <= 1`.
    Cosine,
}
37
38#[derive(Debug, Clone)]
40pub struct KDEConfig {
41 pub bandwidth_method: BandwidthMethod,
43 pub parallel: bool,
45 pub use_simd: bool,
47}
48
/// How `KernelDensityEstimator::fit` chooses the bandwidth.
///
/// With `n` training samples in `d` dimensions, the data-driven rules below
/// are scale-free factors. They are not multiplied by the sample spread.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BandwidthMethod {
    /// Keep the bandwidth passed to the constructor.
    Fixed,
    /// Scott's factor: `n^(-1/(d+4))`.
    Scott,
    /// Silverman's factor: `(4/(d+2))^(1/(d+4)) * n^(-1/(d+4))`.
    Silverman,
    /// Intended to select the bandwidth by cross-validation.
    /// NOTE(review): no cross-validation is currently performed; the
    /// estimator returns the same value as `Scott`.
    CrossValidation,
}
61
62impl Default for KDEConfig {
63 fn default() -> Self {
64 Self {
65 bandwidth_method: BandwidthMethod::Scott,
66 parallel: true,
67 use_simd: true,
68 }
69 }
70}
71
72impl<F: GmmFloat> KernelDensityEstimator<F> {
73 pub fn new(kernel: KernelType, bandwidth: F, config: KDEConfig) -> Self {
75 Self {
76 kernel,
77 bandwidth,
78 config,
79 trainingdata: None,
80 _phantom: PhantomData,
81 }
82 }
83
84 pub fn fit(&mut self, data: &ArrayView2<F>) -> StatsResult<()> {
86 checkarray_finite(data, "data")?;
87
88 if data.is_empty() {
89 return Err(StatsError::InvalidArgument(
90 "Data cannot be empty".to_string(),
91 ));
92 }
93
94 if self.config.bandwidth_method != BandwidthMethod::Fixed {
95 self.bandwidth = self.select_bandwidth_scalar(data)?;
96 }
97
98 self.trainingdata = Some(data.to_owned());
99 Ok(())
100 }
101
102 fn select_bandwidth_scalar(&self, data: &ArrayView2<F>) -> StatsResult<F> {
103 let (n, d) = data.dim();
104
105 match self.config.bandwidth_method {
106 BandwidthMethod::Scott => {
107 let exp: F = f64_to_f(-1.0 / (d as f64 + 4.0), "scott_exp")?;
108 let n_f: F = f64_to_f(n as f64, "n_scott")?;
109 Ok(n_f.powf(exp))
110 }
111 BandwidthMethod::Silverman => {
112 let factor_exp: F = f64_to_f(1.0 / (d as f64 + 4.0), "silv_exp")?;
113 let factor_base: F = f64_to_f(4.0 / (d as f64 + 2.0), "silv_base")?;
114 let n_exp: F = f64_to_f(-1.0 / (d as f64 + 4.0), "silv_n_exp")?;
115 let n_f: F = f64_to_f(n as f64, "n_silv")?;
116 Ok(factor_base.powf(factor_exp) * n_f.powf(n_exp))
117 }
118 BandwidthMethod::CrossValidation => self.cross_validation_bandwidth(data),
119 BandwidthMethod::Fixed => Ok(self.bandwidth),
120 }
121 }
122
123 fn cross_validation_bandwidth(&self, data: &ArrayView2<F>) -> StatsResult<F> {
124 let (n, d) = data.dim();
125 let exp: F = f64_to_f(-1.0 / (d as f64 + 4.0), "cv_exp")?;
126 let n_f: F = f64_to_f(n as f64, "n_cv")?;
127 Ok(n_f.powf(exp))
128 }
129
130 pub fn score_samples(&self, points: &ArrayView2<F>) -> StatsResult<Array1<F>> {
132 let trainingdata = self.trainingdata.as_ref().ok_or_else(|| {
133 StatsError::InvalidArgument("KDE must be fitted before evaluation".into())
134 })?;
135 checkarray_finite(points, "points")?;
136
137 if points.ncols() != trainingdata.ncols() {
138 return Err(StatsError::DimensionMismatch(format!(
139 "Points dimension ({}) must match training data dimension ({})",
140 points.ncols(),
141 trainingdata.ncols()
142 )));
143 }
144
145 let n_points = points.nrows();
146 let n_train = trainingdata.nrows();
147 let d_f: F = f64_to_f(trainingdata.ncols() as f64, "d_kde")?;
148 let n_train_f: F = f64_to_f(n_train as f64, "n_train_kde")?;
149 let normalization = n_train_f * self.bandwidth.powf(d_f);
150
151 let mut densities = Array1::zeros(n_points);
152
153 for i in 0..n_points {
154 let point = points.row(i);
155 let mut density = F::zero();
156 for j in 0..n_train {
157 let train_point = trainingdata.row(j);
158 let distance = self.compute_distance(&point, &train_point);
159 let kernel_value = self.evaluate_kernel(distance / self.bandwidth);
160 density = density + kernel_value;
161 }
162 densities[i] = density / normalization;
163 }
164
165 Ok(densities)
166 }
167
168 fn compute_distance(&self, a: &ArrayView1<F>, b: &ArrayView1<F>) -> F {
169 a.iter()
170 .zip(b.iter())
171 .map(|(&x, &y)| (x - y) * (x - y))
172 .sum::<F>()
173 .sqrt()
174 }
175
176 fn evaluate_kernel(&self, u: F) -> F {
177 let half: F = f64_to_f(0.5, "half").unwrap_or(F::zero());
178 let three_quarter: F = f64_to_f(0.75, "3/4").unwrap_or(F::zero());
179 match self.kernel {
180 KernelType::Gaussian => {
181 let coeff: F = f64_to_f(1.0 / (2.0 * std::f64::consts::PI).sqrt(), "gauss_coeff")
182 .unwrap_or(F::zero());
183 let two: F = f64_to_f(2.0, "two").unwrap_or(F::one());
184 coeff * (-u * u / two).exp()
185 }
186 KernelType::Epanechnikov => {
187 if u.abs() <= F::one() {
188 three_quarter * (F::one() - u * u)
189 } else {
190 F::zero()
191 }
192 }
193 KernelType::Uniform => {
194 if u.abs() <= F::one() {
195 half
196 } else {
197 F::zero()
198 }
199 }
200 KernelType::Triangular => {
201 if u.abs() <= F::one() {
202 F::one() - u.abs()
203 } else {
204 F::zero()
205 }
206 }
207 KernelType::Cosine => {
208 if u.abs() <= F::one() {
209 let pi_4: F = f64_to_f(std::f64::consts::PI / 4.0, "pi/4").unwrap_or(F::zero());
210 let pi: F = f64_to_f(std::f64::consts::PI, "pi").unwrap_or(F::zero());
211 let two: F = f64_to_f(2.0, "two").unwrap_or(F::one());
212 pi_4 * (pi * u / two).cos()
213 } else {
214 F::zero()
215 }
216 }
217 }
218 }
219}
220
221pub fn kernel_density_estimation<F: GmmFloat>(
223 data: &ArrayView2<F>,
224 points: &ArrayView2<F>,
225 kernel: Option<KernelType>,
226 bandwidth: Option<F>,
227) -> StatsResult<Array1<F>> {
228 let kernel = kernel.unwrap_or(KernelType::Gaussian);
229 let bandwidth = match bandwidth {
230 Some(b) => b,
231 None => {
232 let n = data.nrows();
233 let d = data.ncols();
234 let exp: F = f64_to_f(-1.0 / (d as f64 + 4.0), "default_bw_exp")?;
235 let n_f: F = f64_to_f(n as f64, "default_bw_n")?;
236 n_f.powf(exp)
237 }
238 };
239
240 let mut kde = KernelDensityEstimator::new(kernel, bandwidth, KDEConfig::default());
241 kde.fit(data)?;
242 kde.score_samples(points)
243}