Skip to main content

scirs2_stats/mixture_models/
kde.rs

1//! Kernel Density Estimation with multiple kernels and bandwidth selection
2
3use super::f64_to_f;
4use super::GmmFloat;
5use crate::error::{StatsError, StatsResult};
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
7use scirs2_core::validation::*;
8use std::marker::PhantomData;
9
10/// Kernel Density Estimation
11pub struct KernelDensityEstimator<F> {
12    /// Kernel type
13    pub kernel: KernelType,
14    /// Bandwidth
15    pub bandwidth: F,
16    /// Configuration
17    pub config: KDEConfig,
18    /// Training data
19    pub trainingdata: Option<Array2<F>>,
20    _phantom: PhantomData<F>,
21}
22
/// Radial kernel shapes available to the kernel density estimator.
///
/// A kernel is evaluated on the scaled distance `u = ‖x − xᵢ‖ / h` between a
/// query point and a training point.
#[derive(Clone, Debug, PartialEq)]
pub enum KernelType {
    /// Gaussian kernel, proportional to `exp(−u²/2)`; unbounded support.
    Gaussian,
    /// Epanechnikov kernel, proportional to `1 − u²` for `|u| ≤ 1`, zero outside.
    Epanechnikov,
    /// Uniform (top-hat) kernel: constant for `|u| ≤ 1`, zero outside.
    Uniform,
    /// Triangular kernel, proportional to `1 − |u|` for `|u| ≤ 1`, zero outside.
    Triangular,
    /// Cosine kernel, proportional to `cos(πu/2)` for `|u| ≤ 1`, zero outside.
    Cosine,
}
37
38/// KDE configuration
39#[derive(Debug, Clone)]
40pub struct KDEConfig {
41    /// Bandwidth selection method
42    pub bandwidth_method: BandwidthMethod,
43    /// Enable parallel processing
44    pub parallel: bool,
45    /// Use SIMD optimizations
46    pub use_simd: bool,
47}
48
/// How the estimator chooses its bandwidth when it is fitted.
#[derive(Clone, Debug, PartialEq)]
pub enum BandwidthMethod {
    /// Keep the bandwidth supplied by the caller unchanged.
    Fixed,
    /// Scott's rule-of-thumb factor `n^(-1/(d+4))`.
    Scott,
    /// Silverman's rule-of-thumb factor `(4/(d+2))^(1/(d+4)) · n^(-1/(d+4))`.
    Silverman,
    /// Cross-validation. In this module it currently resolves to the same
    /// `n^(-1/(d+4))` factor as Scott; no cross-validation is performed yet.
    CrossValidation,
}
61
62impl Default for KDEConfig {
63    fn default() -> Self {
64        Self {
65            bandwidth_method: BandwidthMethod::Scott,
66            parallel: true,
67            use_simd: true,
68        }
69    }
70}
71
72impl<F: GmmFloat> KernelDensityEstimator<F> {
73    /// Create new KDE
74    pub fn new(kernel: KernelType, bandwidth: F, config: KDEConfig) -> Self {
75        Self {
76            kernel,
77            bandwidth,
78            config,
79            trainingdata: None,
80            _phantom: PhantomData,
81        }
82    }
83
84    /// Fit KDE to data
85    pub fn fit(&mut self, data: &ArrayView2<F>) -> StatsResult<()> {
86        checkarray_finite(data, "data")?;
87
88        if data.is_empty() {
89            return Err(StatsError::InvalidArgument(
90                "Data cannot be empty".to_string(),
91            ));
92        }
93
94        if self.config.bandwidth_method != BandwidthMethod::Fixed {
95            self.bandwidth = self.select_bandwidth_scalar(data)?;
96        }
97
98        self.trainingdata = Some(data.to_owned());
99        Ok(())
100    }
101
102    fn select_bandwidth_scalar(&self, data: &ArrayView2<F>) -> StatsResult<F> {
103        let (n, d) = data.dim();
104
105        match self.config.bandwidth_method {
106            BandwidthMethod::Scott => {
107                let exp: F = f64_to_f(-1.0 / (d as f64 + 4.0), "scott_exp")?;
108                let n_f: F = f64_to_f(n as f64, "n_scott")?;
109                Ok(n_f.powf(exp))
110            }
111            BandwidthMethod::Silverman => {
112                let factor_exp: F = f64_to_f(1.0 / (d as f64 + 4.0), "silv_exp")?;
113                let factor_base: F = f64_to_f(4.0 / (d as f64 + 2.0), "silv_base")?;
114                let n_exp: F = f64_to_f(-1.0 / (d as f64 + 4.0), "silv_n_exp")?;
115                let n_f: F = f64_to_f(n as f64, "n_silv")?;
116                Ok(factor_base.powf(factor_exp) * n_f.powf(n_exp))
117            }
118            BandwidthMethod::CrossValidation => self.cross_validation_bandwidth(data),
119            BandwidthMethod::Fixed => Ok(self.bandwidth),
120        }
121    }
122
123    fn cross_validation_bandwidth(&self, data: &ArrayView2<F>) -> StatsResult<F> {
124        let (n, d) = data.dim();
125        let exp: F = f64_to_f(-1.0 / (d as f64 + 4.0), "cv_exp")?;
126        let n_f: F = f64_to_f(n as f64, "n_cv")?;
127        Ok(n_f.powf(exp))
128    }
129
130    /// Evaluate density at given points
131    pub fn score_samples(&self, points: &ArrayView2<F>) -> StatsResult<Array1<F>> {
132        let trainingdata = self.trainingdata.as_ref().ok_or_else(|| {
133            StatsError::InvalidArgument("KDE must be fitted before evaluation".into())
134        })?;
135        checkarray_finite(points, "points")?;
136
137        if points.ncols() != trainingdata.ncols() {
138            return Err(StatsError::DimensionMismatch(format!(
139                "Points dimension ({}) must match training data dimension ({})",
140                points.ncols(),
141                trainingdata.ncols()
142            )));
143        }
144
145        let n_points = points.nrows();
146        let n_train = trainingdata.nrows();
147        let d_f: F = f64_to_f(trainingdata.ncols() as f64, "d_kde")?;
148        let n_train_f: F = f64_to_f(n_train as f64, "n_train_kde")?;
149        let normalization = n_train_f * self.bandwidth.powf(d_f);
150
151        let mut densities = Array1::zeros(n_points);
152
153        for i in 0..n_points {
154            let point = points.row(i);
155            let mut density = F::zero();
156            for j in 0..n_train {
157                let train_point = trainingdata.row(j);
158                let distance = self.compute_distance(&point, &train_point);
159                let kernel_value = self.evaluate_kernel(distance / self.bandwidth);
160                density = density + kernel_value;
161            }
162            densities[i] = density / normalization;
163        }
164
165        Ok(densities)
166    }
167
168    fn compute_distance(&self, a: &ArrayView1<F>, b: &ArrayView1<F>) -> F {
169        a.iter()
170            .zip(b.iter())
171            .map(|(&x, &y)| (x - y) * (x - y))
172            .sum::<F>()
173            .sqrt()
174    }
175
176    fn evaluate_kernel(&self, u: F) -> F {
177        let half: F = f64_to_f(0.5, "half").unwrap_or(F::zero());
178        let three_quarter: F = f64_to_f(0.75, "3/4").unwrap_or(F::zero());
179        match self.kernel {
180            KernelType::Gaussian => {
181                let coeff: F = f64_to_f(1.0 / (2.0 * std::f64::consts::PI).sqrt(), "gauss_coeff")
182                    .unwrap_or(F::zero());
183                let two: F = f64_to_f(2.0, "two").unwrap_or(F::one());
184                coeff * (-u * u / two).exp()
185            }
186            KernelType::Epanechnikov => {
187                if u.abs() <= F::one() {
188                    three_quarter * (F::one() - u * u)
189                } else {
190                    F::zero()
191                }
192            }
193            KernelType::Uniform => {
194                if u.abs() <= F::one() {
195                    half
196                } else {
197                    F::zero()
198                }
199            }
200            KernelType::Triangular => {
201                if u.abs() <= F::one() {
202                    F::one() - u.abs()
203                } else {
204                    F::zero()
205                }
206            }
207            KernelType::Cosine => {
208                if u.abs() <= F::one() {
209                    let pi_4: F = f64_to_f(std::f64::consts::PI / 4.0, "pi/4").unwrap_or(F::zero());
210                    let pi: F = f64_to_f(std::f64::consts::PI, "pi").unwrap_or(F::zero());
211                    let two: F = f64_to_f(2.0, "two").unwrap_or(F::one());
212                    pi_4 * (pi * u / two).cos()
213                } else {
214                    F::zero()
215                }
216            }
217        }
218    }
219}
220
221/// Evaluate KDE density at query points
222pub fn kernel_density_estimation<F: GmmFloat>(
223    data: &ArrayView2<F>,
224    points: &ArrayView2<F>,
225    kernel: Option<KernelType>,
226    bandwidth: Option<F>,
227) -> StatsResult<Array1<F>> {
228    let kernel = kernel.unwrap_or(KernelType::Gaussian);
229    let bandwidth = match bandwidth {
230        Some(b) => b,
231        None => {
232            let n = data.nrows();
233            let d = data.ncols();
234            let exp: F = f64_to_f(-1.0 / (d as f64 + 4.0), "default_bw_exp")?;
235            let n_f: F = f64_to_f(n as f64, "default_bw_n")?;
236            n_f.powf(exp)
237        }
238    };
239
240    let mut kde = KernelDensityEstimator::new(kernel, bandwidth, KDEConfig::default());
241    kde.fit(data)?;
242    kde.score_samples(points)
243}