sklears_mixture/variational.rs

//! Variational Bayesian Gaussian Mixture Models
//!
//! This module implements variational Bayesian inference for Gaussian mixture models,
//! providing automatic model selection and uncertainty quantification.
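//!
//! The variational family follows the usual mean-field factorization: Dirichlet
//! mixing weights plus Normal-Wishart-style component parameters, tracked through
//! the concentration, mean-precision, and degrees-of-freedom fields below.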

use crate::common::CovarianceType;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::random::{Rng, SeedableRng};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};

/// Variational Bayesian Gaussian Mixture Model
///
/// This implementation uses variational inference to perform Bayesian parameter estimation
/// for Gaussian mixture models. Unlike standard EM, this approach provides uncertainty
/// estimates and automatic model selection by effectively "turning off" unnecessary components.
///
/// # Parameters
///
/// * `n_components` - Maximum number of mixture components (actual number determined automatically)
/// * `covariance_type` - Type of covariance parameters
/// * `tol` - Convergence threshold
/// * `reg_covar` - Regularization added to the diagonal of covariance
/// * `max_iter` - Maximum number of variational iterations
/// * `random_state` - Random state for reproducibility
/// * `weight_concentration_prior` - Prior on the weight concentration parameter
/// * `mean_precision_prior` - Prior precision for component means
/// * `degrees_of_freedom_prior` - Prior degrees of freedom for covariance matrices
///
/// # Examples
///
/// ```
/// use sklears_mixture::{VariationalBayesianGMM, CovarianceType};
/// use sklears_core::traits::{Predict, Fit};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [10.0, 10.0], [11.0, 11.0], [12.0, 12.0]];
///
/// let vbgmm = VariationalBayesianGMM::new()
///     .n_components(5)  // Will automatically determine optimal number
///     .covariance_type(CovarianceType::Diagonal)
///     .max_iter(100);
/// let fitted = vbgmm.fit(&X.view(), &()).unwrap();
/// let labels = fitted.predict(&X.view()).unwrap();
/// ```
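///
/// After fitting, `effective_components()` reports how many components kept a
/// non-negligible weight (greater than `1e-3` in this implementation).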
#[derive(Debug, Clone)]
pub struct VariationalBayesianGMM<S = Untrained> {
    pub(crate) state: S,
    pub(crate) n_components: usize,
    pub(crate) covariance_type: CovarianceType,
    pub(crate) tol: f64,
    pub(crate) reg_covar: f64,
    pub(crate) max_iter: usize,
    pub(crate) random_state: Option<u64>,
    pub(crate) weight_concentration_prior: f64,
    pub(crate) mean_precision_prior: f64,
    pub(crate) degrees_of_freedom_prior: f64,
}

/// Trained state for VariationalBayesianGMM
#[derive(Debug, Clone)]
pub struct VariationalBayesianGMMTrained {
    pub(crate) weights: Array1<f64>,
    pub(crate) means: Array2<f64>,
    pub(crate) covariances: Vec<Array2<f64>>,
    pub(crate) weight_concentration: Array1<f64>,
    pub(crate) mean_precision: Array1<f64>,
    pub(crate) degrees_of_freedom: Array1<f64>,
    pub(crate) lower_bound: f64,
    pub(crate) n_iter: usize,
    pub(crate) converged: bool,
    pub(crate) effective_components: usize,
}

impl VariationalBayesianGMM<Untrained> {
    /// Create a new VariationalBayesianGMM instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_components: 1,
            covariance_type: CovarianceType::Full,
            tol: 1e-3,
            reg_covar: 1e-6,
            max_iter: 100,
            random_state: None,
            weight_concentration_prior: 1.0,
            mean_precision_prior: 1.0,
            degrees_of_freedom_prior: 1.0,
        }
    }

    /// Create a new `VariationalBayesianGMM` via the builder pattern (alias for `new`)
    pub fn builder() -> Self {
        Self::new()
    }

    /// Set the maximum number of components
    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Set the covariance type
    pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    /// Set the convergence tolerance
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the regularization parameter
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Set the weight concentration prior
    pub fn weight_concentration_prior(mut self, prior: f64) -> Self {
        self.weight_concentration_prior = prior;
        self
    }

    /// Set the mean precision prior
    pub fn mean_precision_prior(mut self, prior: f64) -> Self {
        self.mean_precision_prior = prior;
        self
    }

    /// Set the degrees of freedom prior
    pub fn degrees_of_freedom_prior(mut self, prior: f64) -> Self {
        self.degrees_of_freedom_prior = prior;
        self
    }

    /// Finalize the builder and return the configured `VariationalBayesianGMM` (a no-op)
    pub fn build(self) -> Self {
        self
    }
}

impl Default for VariationalBayesianGMM<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for VariationalBayesianGMM<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for VariationalBayesianGMM<Untrained> {
    type Fitted = VariationalBayesianGMM<VariationalBayesianGMMTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, _n_features) = X.dim();

        if n_samples < 2 {
            return Err(SklearsError::InvalidInput(
                "Number of samples must be at least 2".to_string(),
            ));
        }

        if self.n_components == 0 {
            return Err(SklearsError::InvalidInput(
                "Number of components must be positive".to_string(),
            ));
        }

        // Initialize parameters
        let (
            mut weight_concentration,
            mut mean_precision,
            mut means,
            mut degrees_of_freedom,
            mut covariances,
        ) = self.initialize_parameters(&X)?;

        let mut lower_bound = f64::NEG_INFINITY;
        let mut converged = false;
        let mut n_iter = 0;

        // Variational EM iterations
        for iteration in 0..self.max_iter {
            n_iter = iteration + 1;

            // E-step: Update responsibilities
            let responsibilities = self.compute_responsibilities(
                &X,
                &weight_concentration,
                &means,
                &covariances,
                &degrees_of_freedom,
            )?;

            // M-step: Update parameters
            let (
                new_weight_concentration,
                new_mean_precision,
                new_means,
                new_degrees_of_freedom,
                new_covariances,
            ) = self.update_parameters(&X, &responsibilities)?;

            // Compute lower bound
            let new_lower_bound = self.compute_lower_bound(
                &X,
                &responsibilities,
                &new_weight_concentration,
                &new_mean_precision,
                &new_means,
                &new_degrees_of_freedom,
                &new_covariances,
            )?;

            // Check convergence
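            // (With an exact variational bound this difference would be
            // non-negative; the absolute value is used because the simplified
            // bound computed above is not guaranteed to be monotone.)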
            if iteration > 0 && (new_lower_bound - lower_bound).abs() < self.tol {
                converged = true;
            }

            weight_concentration = new_weight_concentration;
            mean_precision = new_mean_precision;
            means = new_means;
            degrees_of_freedom = new_degrees_of_freedom;
            covariances = new_covariances;
            lower_bound = new_lower_bound;

            if converged {
                break;
            }
        }

        // Compute final weights from concentration parameters
        let weights = self.compute_weights(&weight_concentration);

        // Count effective components (those with significant weight)
        let effective_components = weights.iter().filter(|&&w| w > 1e-3).count();

        Ok(VariationalBayesianGMM {
            state: VariationalBayesianGMMTrained {
                weights,
                means,
                covariances,
                weight_concentration,
                mean_precision,
                degrees_of_freedom,
                lower_bound,
                n_iter,
                converged,
                effective_components,
            },
            n_components: self.n_components,
            covariance_type: self.covariance_type,
            tol: self.tol,
            reg_covar: self.reg_covar,
            max_iter: self.max_iter,
            random_state: self.random_state,
            weight_concentration_prior: self.weight_concentration_prior,
            mean_precision_prior: self.mean_precision_prior,
            degrees_of_freedom_prior: self.degrees_of_freedom_prior,
        })
    }
}

impl VariationalBayesianGMM<Untrained> {
    /// Initialize variational parameters
    fn initialize_parameters(
        &self,
        X: &Array2<f64>,
    ) -> SklResult<(
        Array1<f64>,
        Array1<f64>,
        Array2<f64>,
        Array1<f64>,
        Vec<Array2<f64>>,
    )> {
        let (_n_samples, n_features) = X.dim();

        // Initialize weight concentration parameters
        let weight_concentration =
            Array1::from_elem(self.n_components, self.weight_concentration_prior);

        // Initialize mean precision parameters
        let mean_precision = Array1::from_elem(self.n_components, self.mean_precision_prior);

        // Initialize means with k-means++-style (farthest-point) seeding
        let means = self.initialize_means(X)?;

        // Initialize degrees of freedom
        let degrees_of_freedom = Array1::from_elem(
            self.n_components,
            self.degrees_of_freedom_prior + n_features as f64,
        );

        // Initialize covariances
        let covariances = self.initialize_covariances(X)?;

        Ok((
            weight_concentration,
            mean_precision,
            means,
            degrees_of_freedom,
            covariances,
        ))
    }

    /// Initialize means with k-means++-style seeding (a deterministic
    /// farthest-point variant: each new mean is the sample farthest from the
    /// means chosen so far)
    fn initialize_means(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
        let (n_samples, n_features) = X.dim();
        let mut means = Array2::zeros((self.n_components, n_features));

        // Use random initialization if random state is provided
        if let Some(seed) = self.random_state {
            let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed);

            // First mean: pick random sample
            let idx = rng.gen_range(0..n_samples);
            means.row_mut(0).assign(&X.row(idx));

            // Subsequent means: pick samples far from existing means
            for i in 1..self.n_components {
                let mut best_distance = 0.0;
                let mut best_idx = 0;

                for j in 0..n_samples {
                    let sample = X.row(j);
                    let mut min_distance = f64::INFINITY;

                    for k in 0..i {
                        let existing_mean = means.row(k);
                        let distance = (&sample - &existing_mean).mapv(|x| x * x).sum();
                        min_distance = min_distance.min(distance);
                    }

                    if min_distance > best_distance {
                        best_distance = min_distance;
                        best_idx = j;
                    }
                }

                means.row_mut(i).assign(&X.row(best_idx));
            }
        } else {
            // Deterministic initialization: evenly spaced samples
            let step = n_samples / self.n_components;

            for (i, mut mean) in means.axis_iter_mut(Axis(0)).enumerate() {
                let sample_idx = if step == 0 {
                    i.min(n_samples - 1)
                } else {
                    (i * step).min(n_samples - 1)
                };
                mean.assign(&X.row(sample_idx));
            }
        }

        Ok(means)
    }

    /// Initialize covariances
    fn initialize_covariances(&self, X: &Array2<f64>) -> SklResult<Vec<Array2<f64>>> {
        let (_, n_features) = X.dim();
        let mut covariances = Vec::new();

        // Estimate global covariance for initialization
        let global_cov = self.estimate_global_covariance(X)?;

        for _ in 0..self.n_components {
            let mut cov = global_cov.clone();

            // Add regularization
            for i in 0..n_features {
                cov[[i, i]] += self.reg_covar;
            }

            covariances.push(cov);
        }

        Ok(covariances)
    }

    /// Estimate global covariance matrix
    fn estimate_global_covariance(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
        let (n_samples, n_features) = X.dim();

        // Compute sample mean
        let mut mean = Array1::zeros(n_features);
        for i in 0..n_features {
            mean[i] = X.column(i).sum() / n_samples as f64;
        }

        // Compute covariance matrix
        let mut cov = Array2::zeros((n_features, n_features));
        for i in 0..n_features {
            for j in 0..n_features {
                let mut sum = 0.0;
                for k in 0..n_samples {
                    sum += (X[[k, i]] - mean[i]) * (X[[k, j]] - mean[j]);
                }
                cov[[i, j]] = sum / (n_samples as f64 - 1.0);
            }
        }

        // Apply covariance type constraints
        match self.covariance_type {
            CovarianceType::Diagonal => {
                for i in 0..n_features {
                    for j in 0..n_features {
                        if i != j {
                            cov[[i, j]] = 0.0;
                        }
                    }
                }
            }
            CovarianceType::Spherical => {
                let trace = cov.diag().sum() / n_features as f64;
                cov.fill(0.0);
                for i in 0..n_features {
                    cov[[i, i]] = trace;
                }
            }
            _ => {} // Full and Tied keep the estimated covariance
        }

        Ok(cov)
    }

    /// Compute responsibilities using current parameters
    fn compute_responsibilities(
        &self,
        X: &Array2<f64>,
        weight_concentration: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
        degrees_of_freedom: &Array1<f64>,
    ) -> SklResult<Array2<f64>> {
        let (n_samples, _) = X.dim();
        let mut responsibilities = Array2::zeros((n_samples, self.n_components));

        // Compute expected log weights
        let expected_log_weights = self.compute_expected_log_weights(weight_concentration);

        // For each sample
        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
            let mut log_prob_norm = f64::NEG_INFINITY;
            let mut log_probs = Vec::new();

            // Compute log probabilities for each component
            for k in 0..self.n_components {
                let mean = means.row(k);
                let cov = &covariances[k];

                // Use Student-t distribution due to uncertainty in parameters
                let log_prob =
                    self.compute_student_t_log_pdf(&sample, &mean, cov, degrees_of_freedom[k])?;
                let weighted_log_prob = expected_log_weights[k] + log_prob;

                log_probs.push(weighted_log_prob);
                log_prob_norm = log_prob_norm.max(weighted_log_prob);
            }

            // Compute responsibilities using log-sum-exp trick
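            //   ln Σ_k exp(a_k) = m + ln Σ_k exp(a_k - m),  m = max_k a_k,
            // so the largest exponent is zero and the sum cannot overflow.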
            let mut sum_exp = 0.0;
            for &log_prob in &log_probs {
                sum_exp += (log_prob - log_prob_norm).exp();
            }
            let log_sum_exp = log_prob_norm + sum_exp.ln();

            for k in 0..self.n_components {
                responsibilities[[i, k]] = (log_probs[k] - log_sum_exp).exp();
            }
        }

        Ok(responsibilities)
    }

    /// Compute expected log weights from concentration parameters
    fn compute_expected_log_weights(&self, weight_concentration: &Array1<f64>) -> Array1<f64> {
        let sum_concentration: f64 = weight_concentration.sum();
        let mut expected_log_weights = Array1::zeros(self.n_components);

        for k in 0..self.n_components {
            // Expected log weight under the Dirichlet posterior:
            //   E[ln π_k] = ψ(α_k) - ψ(Σ_j α_j)
            expected_log_weights[k] = digamma(weight_concentration[k]) - digamma(sum_concentration);
        }

        expected_log_weights
    }

    /// Compute Student-t log PDF (approximation)
    fn compute_student_t_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
        _degrees_of_freedom: f64,
    ) -> SklResult<f64> {
        // For simplicity, use a Gaussian approximation here; a full
        // implementation would use the proper Student-t distribution.
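        // For reference, the exact multivariate Student-t log-density is
        //   ln Γ((ν + d)/2) - ln Γ(ν/2) - (d/2) ln(νπ) - (1/2) ln|Σ|
        //     - ((ν + d)/2) ln(1 + δ²/ν),  δ² = (x - μ)ᵀ Σ⁻¹ (x - μ),
        // which recovers the Gaussian below in the limit ν → ∞.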
        crate::common::gaussian_log_pdf(x, mean, &cov.view())
    }

    /// Update variational parameters
    fn update_parameters(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
    ) -> SklResult<(
        Array1<f64>,
        Array1<f64>,
        Array2<f64>,
        Array1<f64>,
        Vec<Array2<f64>>,
    )> {
        let (n_samples, n_features) = X.dim();

        // Update weight concentration parameters
        let mut weight_concentration = Array1::zeros(self.n_components);
        for k in 0..self.n_components {
            weight_concentration[k] =
                self.weight_concentration_prior + responsibilities.column(k).sum();
        }

        // Update mean precision parameters
        let mut mean_precision = Array1::zeros(self.n_components);
        for k in 0..self.n_components {
            mean_precision[k] = self.mean_precision_prior + responsibilities.column(k).sum();
        }

        // Update means
        let mut means = Array2::zeros((self.n_components, n_features));
        for k in 0..self.n_components {
            let resp_sum = responsibilities.column(k).sum();
            if resp_sum > 0.0 {
                for j in 0..n_features {
                    let mut weighted_sum = 0.0;
                    for i in 0..n_samples {
                        weighted_sum += responsibilities[[i, k]] * X[[i, j]];
                    }
                    means[[k, j]] = weighted_sum / resp_sum;
                }
            }
        }

        // Update degrees of freedom
        let mut degrees_of_freedom = Array1::zeros(self.n_components);
        for k in 0..self.n_components {
            degrees_of_freedom[k] =
                self.degrees_of_freedom_prior + responsibilities.column(k).sum();
        }

        // Update covariances using the responsibility-weighted scatter around
        // the updated means plus diagonal regularization (a simplification of
        // the full Wishart posterior update).
        let mut covariances = Vec::with_capacity(self.n_components);
        for k in 0..self.n_components {
            let resp_sum = responsibilities.column(k).sum();
            let mut cov = Array2::zeros((n_features, n_features));
            if resp_sum > 0.0 {
                for i in 0..n_samples {
                    let diff = &X.row(i) - &means.row(k);
                    for a in 0..n_features {
                        for b in 0..n_features {
                            cov[[a, b]] += responsibilities[[i, k]] * diff[a] * diff[b];
                        }
                    }
                }
                cov.mapv_inplace(|v| v / resp_sum);
            }
            for i in 0..n_features {
                cov[[i, i]] += self.reg_covar;
            }
            covariances.push(cov);
        }

        Ok((
            weight_concentration,
            mean_precision,
            means,
            degrees_of_freedom,
            covariances,
        ))
    }

    /// Compute variational lower bound
    fn compute_lower_bound(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
        weight_concentration: &Array1<f64>,
        _mean_precision: &Array1<f64>,
        means: &Array2<f64>,
        _degrees_of_freedom: &Array1<f64>,
        covariances: &[Array2<f64>],
    ) -> SklResult<f64> {
        // Simplified lower bound computation; a full implementation would also
        // include the entropy and prior-expectation (KL) terms.
        let mut lower_bound = 0.0;

        // Expected log likelihood term
        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
            for k in 0..self.n_components {
                let resp = responsibilities[[i, k]];
                if resp > 0.0 {
                    let mean = means.row(k);
                    let cov = &covariances[k];
                    let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
                    lower_bound += resp * log_prob;
                }
            }
        }

        // KL divergence terms (simplified)
        let expected_log_weights = self.compute_expected_log_weights(weight_concentration);
        for k in 0..self.n_components {
            let resp_sum = responsibilities.column(k).sum();
            if resp_sum > 0.0 {
                lower_bound += resp_sum * expected_log_weights[k];
            }
        }

        Ok(lower_bound)
    }

    /// Compute final weights from concentration parameters
    fn compute_weights(&self, weight_concentration: &Array1<f64>) -> Array1<f64> {
        let sum_concentration: f64 = weight_concentration.sum();
        weight_concentration.mapv(|x| x / sum_concentration)
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for VariationalBayesianGMM<VariationalBayesianGMMTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut predictions = Array1::zeros(n_samples);

        // For each sample, find the component with highest responsibility
        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
            let mut best_component = 0;
            let mut best_log_prob = f64::NEG_INFINITY;

            for k in 0..self.n_components {
                if self.state.weights[k] > 1e-3 {
                    // Only consider effective components
                    let mean = self.state.means.row(k);
                    let cov = &self.state.covariances[k];

                    let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
                    let weighted_log_prob = self.state.weights[k].ln() + log_prob;

                    if weighted_log_prob > best_log_prob {
                        best_log_prob = weighted_log_prob;
                        best_component = k;
                    }
                }
            }

            predictions[i] = best_component as i32;
        }

        Ok(predictions)
    }
}

impl VariationalBayesianGMM<VariationalBayesianGMMTrained> {
    /// Get the fitted weights
    pub fn weights(&self) -> &Array1<f64> {
        &self.state.weights
    }

    /// Get the fitted means
    pub fn means(&self) -> &Array2<f64> {
        &self.state.means
    }

    /// Get the fitted covariances
    pub fn covariances(&self) -> &[Array2<f64>] {
        &self.state.covariances
    }

    /// Get the variational lower bound
    pub fn lower_bound(&self) -> f64 {
        self.state.lower_bound
    }

    /// Get the number of effective components
    pub fn effective_components(&self) -> usize {
        self.state.effective_components
    }

    /// Check if the model converged
    pub fn converged(&self) -> bool {
        self.state.converged
    }

    /// Get the number of iterations performed
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }

    /// Get the weight concentration parameters
    pub fn weight_concentration(&self) -> &Array1<f64> {
        &self.state.weight_concentration
    }

    /// Get the mean precision parameters
    pub fn mean_precision(&self) -> &Array1<f64> {
        &self.state.mean_precision
    }

    /// Get the degrees of freedom parameters
    pub fn degrees_of_freedom(&self) -> &Array1<f64> {
        &self.state.degrees_of_freedom
    }
}

/// Digamma function ψ(x) = d/dx ln Γ(x), approximated for computing expected log weights
fn digamma(x: f64) -> f64 {
    // Leading terms of the asymptotic series ψ(x) ≈ ln x - 1/(2x) - 1/(12x²)
    if x > 6.0 {
        x.ln() - 1.0 / (2.0 * x) - 1.0 / (12.0 * x * x)
    } else {
        // For small x, shift upward with the recurrence ψ(x) = ψ(x + 1) - 1/x
        // until the asymptotic expansion applies, then subtract the shifts back.
        let mut result = x;
        let mut n = 0;
        while result < 6.0 {
            result += 1.0;
            n += 1;
        }
        let asymptotic = result.ln() - 1.0 / (2.0 * result) - 1.0 / (12.0 * result * result);
        asymptotic - (0..n).map(|i| 1.0 / (x + i as f64)).sum::<f64>()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_variational_bayesian_gmm_basic() {
        let X = array![
            [0.0, 0.0],
            [1.0, 1.0],
            [2.0, 2.0],
            [10.0, 10.0],
            [11.0, 11.0],
            [12.0, 12.0]
        ];

        let vbgmm = VariationalBayesianGMM::new()
            .n_components(3)
            .max_iter(10)
            .random_state(42);

        let fitted = vbgmm.fit(&X.view(), &()).unwrap();

        assert!(fitted.converged() || fitted.n_iter() == 10);
        assert!(fitted.effective_components() <= 3);
        assert!(fitted.lower_bound().is_finite());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_variational_bayesian_gmm_prediction() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];

        let vbgmm = VariationalBayesianGMM::new()
            .n_components(2)
            .max_iter(20)
            .random_state(42);

        let fitted = vbgmm.fit(&X.view(), &()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        // Loose sanity check: either the two nearby points share a label, or
        // the first point is at least separated from the far cluster.
        assert!(predictions[0] == predictions[1] || predictions[0] != predictions[2]);
    }

    #[test]
    fn test_variational_bayesian_gmm_builder() {
        let vbgmm = VariationalBayesianGMM::builder()
            .n_components(5)
            .covariance_type(CovarianceType::Diagonal)
            .tol(1e-4)
            .weight_concentration_prior(0.1)
            .mean_precision_prior(0.1)
            .degrees_of_freedom_prior(1.0)
            .build();

        assert_eq!(vbgmm.n_components, 5);
        assert_eq!(vbgmm.covariance_type, CovarianceType::Diagonal);
        assert_relative_eq!(vbgmm.tol, 1e-4);
        assert_relative_eq!(vbgmm.weight_concentration_prior, 0.1);
    }
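
    #[test]
    fn test_compute_weights_normalization() {
        // Additional property-check sketch: weights are the normalized Dirichlet
        // concentration parameters, so they must sum to one and preserve the
        // concentration ratios.
        let vbgmm = VariationalBayesianGMM::new().n_components(3);
        let concentration = array![1.0, 2.0, 3.0];
        let weights = vbgmm.compute_weights(&concentration);
        assert_relative_eq!(weights.sum(), 1.0, epsilon = 1e-12);
        assert_relative_eq!(weights[2], 0.5, epsilon = 1e-12);
    }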

    #[test]
    fn test_digamma_function() {
        // Check the approximation against known values: ψ(1) = -γ ≈ -0.5772 and
        // ψ(2) = 1 - γ ≈ 0.4228, with γ the Euler-Mascheroni constant.
        assert_relative_eq!(digamma(1.0), -0.5772, epsilon = 0.1);
        assert_relative_eq!(digamma(2.0), 0.4228, epsilon = 0.1);
        assert_relative_eq!(digamma(10.0), 2.2517, epsilon = 0.01);
    }
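
    #[test]
    fn test_digamma_recurrence() {
        // Additional sketch test: the approximation should respect the
        // recurrence ψ(x + 1) = ψ(x) + 1/x, which the small-x branch applies
        // by construction.
        for &x in &[0.5, 1.5, 3.0, 8.0] {
            assert_relative_eq!(digamma(x + 1.0), digamma(x) + 1.0 / x, epsilon = 1e-5);
        }
    }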
}