// scirs2_stats/mixture_models/variational.rs
//! Variational Bayesian Gaussian Mixture Model

use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive, One, Zero};
use scirs2_core::random::Rng;
use scirs2_core::simd_ops::SimdUnifiedOps;
use std::marker::PhantomData;
10/// Variational Bayesian Gaussian Mixture Model
/// Variational Bayesian Gaussian Mixture Model.
///
/// Holds a truncated mixture of at most `max_components` Gaussian
/// components; `fit` estimates posterior parameters by coordinate-ascent
/// variational inference and records the ELBO trajectory.
pub struct VariationalGMM<F> {
    /// Maximum number of mixture components considered during fitting.
    pub max_components: usize,
    /// Fitting configuration (priors, tolerance, iteration cap, seed).
    pub config: VariationalGMMConfig,
    /// Posterior parameters; `None` until `fit` has been called successfully.
    pub parameters: Option<VariationalGMMParameters<F>>,
    /// Lower bound (ELBO) value recorded after each iteration of the last fit.
    pub lower_bound_history: Vec<F>,
    // NOTE(review): `F` already appears in `parameters` and
    // `lower_bound_history`, so this marker is strictly redundant; kept so
    // the struct layout and construction sites stay unchanged.
    _phantom: PhantomData<F>,
}
22
23/// Configuration for Variational GMM
/// Configuration for Variational GMM fitting.
#[derive(Debug, Clone)]
pub struct VariationalGMMConfig {
    /// Maximum number of variational iterations performed by `fit`.
    pub max_iter: usize,
    /// Convergence tolerance on the absolute change in the lower bound.
    pub tolerance: f64,
    /// Concentration parameter of the Dirichlet prior over mixture weights.
    pub alpha: f64,
    /// Degrees of freedom of the Wishart prior over component precisions.
    pub nu: f64,
    /// Prior mean.
    // NOTE(review): not consulted by the visible fitting code — verify
    // whether it is used elsewhere before relying on it.
    pub mean_prior: Option<Vec<f64>>,
    /// Prior precision matrix.
    // NOTE(review): not consulted by the visible fitting code.
    pub precision_prior: Option<Vec<Vec<f64>>>,
    /// Enable automatic relevance determination.
    // NOTE(review): flag is never read in this file — confirm intended effect.
    pub ard: bool,
    /// Random seed for mean initialization; `None` draws an entropy seed.
    pub seed: Option<u64>,
}
43
44impl Default for VariationalGMMConfig {
45    fn default() -> Self {
46        Self {
47            max_iter: 100,
48            tolerance: 1e-6,
49            alpha: 1.0,
50            nu: 1.0,
51            mean_prior: None,
52            precision_prior: None,
53            ard: true,
54            seed: None,
55        }
56    }
57}
58
59/// Variational GMM parameters
/// Posterior parameters produced by Variational GMM fitting.
#[derive(Debug, Clone)]
pub struct VariationalGMMParameters<F> {
    /// Posterior Dirichlet concentration per component (prior alpha plus
    /// the effective sample count assigned to the component).
    pub weight_concentration: Array1<F>,
    /// Posterior precision scaling of the component means.
    pub mean_precision: Array1<F>,
    /// Component means, one row per component.
    pub means: Array2<F>,
    /// Posterior Wishart degrees of freedom per component.
    pub degrees_of_freedom: Array1<F>,
    /// Posterior Wishart scale matrices, shape (components, features, features).
    pub scale_matrices: Array3<F>,
    /// Final value of the variational lower bound (ELBO).
    pub lower_bound: F,
    /// Number of components whose normalized weight exceeds the pruning
    /// threshold.
    pub effective_components: usize,
    /// Number of iterations actually performed.
    pub n_iter: usize,
    /// Whether the lower bound change fell below the tolerance.
    pub converged: bool,
}
81
82/// Variational GMM result
/// Result returned by `VariationalGMM::fit`.
#[derive(Debug, Clone)]
pub struct VariationalGMMResult<F> {
    /// Final value of the variational lower bound (ELBO).
    pub lower_bound: F,
    /// Number of components with non-negligible posterior weight.
    pub effective_components: usize,
    /// Per-sample posterior component probabilities, shape
    /// (samples, max_components); rows sum to one.
    pub responsibilities: Array2<F>,
    /// Normalized component weights derived from the Dirichlet concentrations.
    pub weights: Array1<F>,
}
94
95impl<F> VariationalGMM<F>
96where
97    F: Float
98        + FromPrimitive
99        + SimdUnifiedOps
100        + Send
101        + Sync
102        + std::fmt::Debug
103        + std::fmt::Display
104        + std::iter::Sum<F>,
105{
106    /// Create new Variational GMM
107    pub fn new(max_components: usize, config: VariationalGMMConfig) -> Self {
108        Self {
109            max_components,
110            config,
111            parameters: None,
112            lower_bound_history: Vec::new(),
113            _phantom: PhantomData,
114        }
115    }
116
117    /// Fit Variational GMM to data
118    pub fn fit(&mut self, data: &ArrayView2<F>) -> StatsResult<VariationalGMMResult<F>> {
119        let (_n_samples, n_features) = data.dim();
120
121        let alpha_f: F = F::from(self.config.alpha)
122            .ok_or_else(|| StatsError::ComputationError("alpha conversion failed".into()))?;
123        let nu_f: F = F::from(self.config.nu)
124            .ok_or_else(|| StatsError::ComputationError("nu conversion failed".into()))?;
125        let n_feat_f: F = F::from(n_features)
126            .ok_or_else(|| StatsError::ComputationError("n_features conversion failed".into()))?;
127
128        let mut weight_concentration = Array1::from_elem(self.max_components, alpha_f);
129        let mean_precision_val = F::one();
130        let mut mean_precision = Array1::from_elem(self.max_components, mean_precision_val);
131        let mut means = self.initialize_means(data)?;
132        let mut degrees_of_freedom = Array1::from_elem(self.max_components, nu_f + n_feat_f);
133        let mut scale_matrices = Array3::zeros((self.max_components, n_features, n_features));
134        for k in 0..self.max_components {
135            for i in 0..n_features {
136                scale_matrices[[k, i, i]] = F::one();
137            }
138        }
139
140        let mut lower_bound = F::neg_infinity();
141        let mut converged = false;
142        let tol: F = F::from(self.config.tolerance)
143            .ok_or_else(|| StatsError::ComputationError("tolerance conversion failed".into()))?;
144
145        for iteration in 0..self.config.max_iter {
146            let responsibilities = self.compute_responsibilities(
147                data,
148                &means,
149                &scale_matrices,
150                &degrees_of_freedom,
151                &weight_concentration,
152            )?;
153
154            let (new_wc, new_mp, new_means, new_dof, new_sm) =
155                self.update_parameters(data, &responsibilities)?;
156
157            let new_lb =
158                self.compute_lower_bound(data, &responsibilities, &new_wc, &new_means, &new_sm)?;
159
160            if iteration > 0 && (new_lb - lower_bound).abs() < tol {
161                converged = true;
162            }
163
164            weight_concentration = new_wc;
165            mean_precision = new_mp;
166            means = new_means;
167            degrees_of_freedom = new_dof;
168            scale_matrices = new_sm;
169            lower_bound = new_lb;
170            self.lower_bound_history.push(lower_bound);
171
172            if converged {
173                break;
174            }
175        }
176
177        let effective_components = self.compute_effective_components(&weight_concentration);
178        let responsibilities = self.compute_responsibilities(
179            data,
180            &means,
181            &scale_matrices,
182            &degrees_of_freedom,
183            &weight_concentration,
184        )?;
185        let weights = self.compute_weights(&weight_concentration);
186
187        let parameters = VariationalGMMParameters {
188            weight_concentration,
189            mean_precision,
190            means,
191            degrees_of_freedom,
192            scale_matrices,
193            lower_bound,
194            effective_components,
195            n_iter: self.lower_bound_history.len(),
196            converged,
197        };
198        self.parameters = Some(parameters);
199
200        Ok(VariationalGMMResult {
201            lower_bound,
202            effective_components,
203            responsibilities,
204            weights,
205        })
206    }
207
208    fn initialize_means(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
209        let (n_samples, n_features) = data.dim();
210        let mut means = Array2::zeros((self.max_components, n_features));
211        use scirs2_core::random::Random;
212        let mut init_rng = scirs2_core::random::thread_rng();
213        let mut rng = match self.config.seed {
214            Some(seed) => Random::seed(seed),
215            None => Random::seed(init_rng.random()),
216        };
217        for i in 0..self.max_components {
218            let idx = rng.random_range(0..n_samples);
219            means.row_mut(i).assign(&data.row(idx));
220        }
221        Ok(means)
222    }
223
224    fn compute_responsibilities(
225        &self,
226        data: &ArrayView2<F>,
227        means: &Array2<F>,
228        scale_matrices: &Array3<F>,
229        degrees_of_freedom: &Array1<F>,
230        weight_concentration: &Array1<F>,
231    ) -> StatsResult<Array2<F>> {
232        let n_samples = data.shape()[0];
233        let mut responsibilities = Array2::zeros((n_samples, self.max_components));
234
235        for i in 0..n_samples {
236            let mut log_probs = Array1::zeros(self.max_components);
237            for k in 0..self.max_components {
238                let log_weight = weight_concentration[k].ln();
239                let log_ll = self.compute_log_likelihood_component(
240                    &data.row(i),
241                    &means.row(k),
242                    &scale_matrices.slice(s![k, .., ..]),
243                    degrees_of_freedom[k],
244                )?;
245                log_probs[k] = log_weight + log_ll;
246            }
247            let log_sum = self.log_sum_exp(&log_probs);
248            for k in 0..self.max_components {
249                responsibilities[[i, k]] = (log_probs[k] - log_sum).exp();
250            }
251        }
252        Ok(responsibilities)
253    }
254
255    fn update_parameters(
256        &self,
257        data: &ArrayView2<F>,
258        responsibilities: &Array2<F>,
259    ) -> StatsResult<(Array1<F>, Array1<F>, Array2<F>, Array1<F>, Array3<F>)> {
260        let (n_samples, n_features) = data.dim();
261
262        let alpha_f: F = F::from(self.config.alpha)
263            .ok_or_else(|| StatsError::ComputationError("alpha conversion".into()))?;
264        let nu_f: F = F::from(self.config.nu)
265            .ok_or_else(|| StatsError::ComputationError("nu conversion".into()))?;
266        let n_feat_f: F = F::from(n_features)
267            .ok_or_else(|| StatsError::ComputationError("n_features conversion".into()))?;
268        let small: F = F::from(0.1)
269            .ok_or_else(|| StatsError::ComputationError("constant conversion".into()))?;
270
271        let mut weight_concentration = Array1::from_elem(self.max_components, alpha_f);
272        let mean_precision = Array1::ones(self.max_components);
273        let mut means = Array2::zeros((self.max_components, n_features));
274        let mut degrees_of_freedom = Array1::from_elem(self.max_components, nu_f + n_feat_f);
275        let mut scale_matrices = Array3::zeros((self.max_components, n_features, n_features));
276
277        for k in 0..self.max_components {
278            let nk: F = responsibilities.column(k).sum();
279            weight_concentration[k] = weight_concentration[k] + nk;
280
281            if nk > F::zero() {
282                for j in 0..n_features {
283                    let mut weighted_sum = F::zero();
284                    for i in 0..n_samples {
285                        weighted_sum = weighted_sum + responsibilities[[i, k]] * data[[i, j]];
286                    }
287                    means[[k, j]] = weighted_sum / nk;
288                }
289                degrees_of_freedom[k] = nu_f + nk;
290                for i in 0..n_features {
291                    scale_matrices[[k, i, i]] = F::one() + small * nk;
292                }
293            }
294        }
295
296        Ok((
297            weight_concentration,
298            mean_precision,
299            means,
300            degrees_of_freedom,
301            scale_matrices,
302        ))
303    }
304
305    fn compute_lower_bound(
306        &self,
307        data: &ArrayView2<F>,
308        responsibilities: &Array2<F>,
309        weight_concentration: &Array1<F>,
310        means: &Array2<F>,
311        scale_matrices: &Array3<F>,
312    ) -> StatsResult<F> {
313        let n_samples = data.shape()[0];
314        let mut lower_bound = F::zero();
315        let ten: F = F::from(10.0)
316            .ok_or_else(|| StatsError::ComputationError("constant conversion".into()))?;
317        let small_kl: F = F::from(0.01)
318            .ok_or_else(|| StatsError::ComputationError("constant conversion".into()))?;
319
320        for i in 0..n_samples {
321            for k in 0..self.max_components {
322                if responsibilities[[i, k]] > F::zero() {
323                    let log_ll = self.compute_log_likelihood_component(
324                        &data.row(i),
325                        &means.row(k),
326                        &scale_matrices.slice(s![k, .., ..]),
327                        ten,
328                    )?;
329                    lower_bound = lower_bound + responsibilities[[i, k]] * log_ll;
330                }
331            }
332        }
333
334        for k in 0..self.max_components {
335            let w = weight_concentration[k];
336            if w > F::zero() {
337                lower_bound = lower_bound - w * w.ln() * small_kl;
338            }
339        }
340
341        Ok(lower_bound)
342    }
343
344    fn compute_effective_components(&self, wc: &Array1<F>) -> usize {
345        let total: F = wc.sum();
346        let threshold = F::from(0.01).unwrap_or(F::zero());
347        wc.iter().filter(|&&w| w / total > threshold).count()
348    }
349
350    fn compute_weights(&self, wc: &Array1<F>) -> Array1<F> {
351        let total: F = wc.sum();
352        wc.mapv(|w| w / total)
353    }
354
355    fn compute_log_likelihood_component(
356        &self,
357        point: &ArrayView1<F>,
358        mean: &ArrayView1<F>,
359        _scale_matrix: &scirs2_core::ndarray::ArrayBase<
360            scirs2_core::ndarray::ViewRepr<&F>,
361            scirs2_core::ndarray::Dim<[usize; 2]>,
362        >,
363        _degrees_of_freedom: F,
364    ) -> StatsResult<F> {
365        let half: F = F::from(0.5)
366            .ok_or_else(|| StatsError::ComputationError("constant conversion".into()))?;
367        let mut sum_sq = F::zero();
368        for (x, m) in point.iter().zip(mean.iter()) {
369            let diff = *x - *m;
370            sum_sq = sum_sq + diff * diff;
371        }
372        Ok(-half * sum_sq)
373    }
374
375    fn log_sum_exp(&self, logvalues: &Array1<F>) -> F {
376        let max_val = logvalues.iter().fold(F::neg_infinity(), |a, &b| a.max(b));
377        if max_val == F::neg_infinity() {
378            return F::neg_infinity();
379        }
380        let sum: F = logvalues.iter().map(|&x| (x - max_val).exp()).sum();
381        max_val + sum.ln()
382    }
383}