scirs2_stats/mixture_models/mod.rs
1//! Advanced mixture models and kernel density estimation
2//!
3//! This module provides comprehensive implementations of mixture models and
4//! non-parametric density estimation methods including:
5//! - Gaussian Mixture Models (GMM) with robust EM algorithm
6//! - Variational Bayesian Gaussian Mixture Models
7//! - Online/Streaming EM algorithms
8//! - Robust mixture models with outlier detection
9//! - Model selection criteria (AIC, BIC, ICL)
10//! - Advanced initialization strategies (K-means++, random)
11//! - Kernel Density Estimation with various kernels
12//! - Adaptive bandwidth selection with cross-validation
13//! - Mixture model diagnostics and validation
14
15mod kde;
16mod variational;
17
18pub use kde::*;
19pub use variational::*;
20
21use crate::error::{StatsError, StatsResult};
22use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
23use scirs2_core::numeric::{Float, FromPrimitive, One, Zero};
24use scirs2_core::random::Rng;
25use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
26use std::marker::PhantomData;
27
28// ---------------------------------------------------------------------------
29// Types and configs
30// ---------------------------------------------------------------------------
31
/// Gaussian Mixture Model with EM algorithm
///
/// Generic over the float type `F`. Fit state lives in `parameters`,
/// which is `None` until `fit` has completed successfully.
pub struct GaussianMixtureModel<F> {
    /// Number of components
    pub n_components: usize,
    /// Configuration
    pub config: GMMConfig,
    /// Fitted parameters (populated by `fit`; `None` before fitting)
    pub parameters: Option<GMMParameters<F>>,
    /// Convergence history (total log-likelihood after each EM iteration)
    pub convergence_history: Vec<F>,
    // NOTE(review): `F` already appears in `parameters` and
    // `convergence_history`, so this marker looks redundant; kept as-is
    // since `new` initializes it — confirm before removing.
    _phantom: PhantomData<F>,
}
44
/// Advanced GMM configuration
///
/// Baseline values are supplied by the `Default` impl below. Only a subset
/// of these options is exercised by the code visible in this module
/// (`max_iter`, `tolerance`, `reg_covar`, `init_method`, `seed`,
/// `covariance_type`); the rest are declared for other code paths.
#[derive(Debug, Clone)]
pub struct GMMConfig {
    /// Maximum iterations for EM algorithm
    pub max_iter: usize,
    /// Convergence tolerance for log-likelihood (absolute change between iterations)
    pub tolerance: f64,
    /// Relative tolerance for parameter changes
    pub param_tolerance: f64,
    /// Covariance type
    pub covariance_type: CovarianceType,
    /// Regularization for covariance matrices (added to the diagonal, see
    /// `initialize_covariances`)
    pub reg_covar: f64,
    /// Initialization method
    pub init_method: InitializationMethod,
    /// Number of initialization runs (best result selected)
    pub n_init: usize,
    /// Random seed (`None` => seeded from the thread RNG, non-deterministic)
    pub seed: Option<u64>,
    /// Enable parallel processing
    pub parallel: bool,
    /// Enable SIMD optimizations
    pub use_simd: bool,
    /// Warm start (use existing parameters if available)
    pub warm_start: bool,
    /// Enable robust EM (outlier detection)
    pub robust_em: bool,
    /// Outlier threshold for robust EM
    pub outlier_threshold: f64,
    /// Enable early stopping based on validation likelihood
    pub early_stopping: bool,
    /// Validation fraction for early stopping
    pub validation_fraction: f64,
    /// Patience for early stopping
    pub patience: usize,
}
81
/// Covariance matrix types
///
/// Determines how many free covariance parameters each component carries
/// (see `compute_n_parameters`); note that `Constrained` is counted the
/// same as `Full` there.
#[derive(Debug, Clone, PartialEq)]
pub enum CovarianceType {
    /// Full covariance matrices
    Full,
    /// Diagonal covariance matrices
    Diagonal,
    /// Tied covariance (same for all components)
    Tied,
    /// Spherical covariance (isotropic)
    Spherical,
    /// Factor analysis covariance (low-rank + diagonal)
    Factor {
        /// Number of factors
        n_factors: usize,
    },
    /// Constrained covariance with specific structure
    Constrained {
        /// Constraint type
        constraint: CovarianceConstraint,
    },
}
104
/// Covariance constraints
///
/// Structural restrictions attachable via `CovarianceType::Constrained`.
#[derive(Debug, Clone, PartialEq)]
pub enum CovarianceConstraint {
    /// Minimum eigenvalue constraint
    MinEigenvalue(f64),
    /// Maximum condition number
    MaxCondition(f64),
    /// Sparsity pattern (pairs of matrix indices)
    Sparse(Vec<(usize, usize)>),
}
115
/// Initialization methods
///
/// NOTE(review): `initialize_means` currently implements `Random`,
/// `KMeansPlus`, `FurthestFirst` and `Quantile`; every other variant
/// silently falls back to k-means++ seeding.
#[derive(Debug, Clone, PartialEq)]
pub enum InitializationMethod {
    /// Random initialization
    Random,
    /// K-means++ initialization
    KMeansPlus,
    /// K-means with multiple runs
    KMeans {
        /// Number of k-means runs
        n_runs: usize,
    },
    /// Furthest-first initialization
    FurthestFirst,
    /// User-provided parameters
    Custom,
    /// Quantile-based initialization
    Quantile,
    /// PCA-based initialization
    PCA,
    /// Spectral clustering initialization
    Spectral,
}
139
/// Advanced GMM parameters with diagnostics
///
/// Produced by `store_parameters` at the end of `fit`.
#[derive(Debug, Clone)]
pub struct GMMParameters<F> {
    /// Component weights (mixing coefficients)
    pub weights: Array1<F>,
    /// Component means (one row per component)
    pub means: Array2<F>,
    /// Component covariances (one matrix per component)
    pub covariances: Vec<Array2<F>>,
    /// Total log-likelihood of the training data at the final iteration
    pub log_likelihood: F,
    /// Number of iterations to convergence
    pub n_iter: usize,
    /// Converged flag
    pub converged: bool,
    /// Convergence reason
    pub convergence_reason: ConvergenceReason,
    /// Model selection criteria
    pub model_selection: ModelSelectionCriteria<F>,
    /// Component diagnostics
    pub component_diagnostics: Vec<ComponentDiagnostics<F>>,
    /// Outlier scores (if robust EM was used; currently always `None`
    /// from this module's fit path)
    pub outlier_scores: Option<Array1<F>>,
    /// Responsibility matrix for training data, shape (n_samples, n_components)
    pub responsibilities: Option<Array2<F>>,
    /// Parameter change history (left empty by the current fit path)
    pub parameter_history: Vec<ParameterSnapshot<F>>,
}
168
/// Convergence reasons
///
/// The fit path in this module only emits `LogLikelihoodTolerance` or
/// `MaxIterations`; the other variants are reserved for other code paths.
#[derive(Debug, Clone, PartialEq)]
pub enum ConvergenceReason {
    /// Log-likelihood tolerance reached
    LogLikelihoodTolerance,
    /// Parameter change tolerance reached
    ParameterTolerance,
    /// Maximum iterations reached
    MaxIterations,
    /// Early stopping triggered
    EarlyStopping,
    /// Numerical instability detected
    NumericalInstability,
}
183
/// Model selection criteria
///
/// All criteria follow the "lower is better" convention, computed in
/// `store_parameters` as: AIC = -2·LL + 2p, BIC = -2·LL + p·ln n,
/// HQIC = -2·LL + 2p·ln(ln n), ICL = BIC - 2·Σ r·ln r.
#[derive(Debug, Clone)]
pub struct ModelSelectionCriteria<F> {
    /// Akaike Information Criterion
    pub aic: F,
    /// Bayesian Information Criterion
    pub bic: F,
    /// Integrated Classification Likelihood (entropy-penalized BIC)
    pub icl: F,
    /// Hannan-Quinn Information Criterion
    pub hqic: F,
    /// Cross-validation log-likelihood (currently always `None` here)
    pub cv_log_likelihood: Option<F>,
    /// Number of effective (free) parameters
    pub n_parameters: usize,
}
200
/// Component diagnostics
#[derive(Debug, Clone)]
pub struct ComponentDiagnostics<F> {
    /// Effective sample size (total responsibility mass of the component)
    pub effective_samplesize: F,
    /// Condition number of covariance (diagonal-based estimate)
    pub condition_number: F,
    /// Determinant of covariance (absolute value)
    pub covariance_determinant: F,
    /// Component separation: distance from this component's mean to the
    /// nearest other component mean. Currently computed as Euclidean
    /// distance between means — covariances are ignored (see
    /// `compute_component_separation`).
    pub component_separation: F,
    /// Relative weight change over iterations (currently a placeholder,
    /// always set to zero by `store_parameters`)
    pub weight_stability: F,
}
215
/// Parameter snapshot for tracking changes
///
/// NOTE(review): `store_parameters` currently leaves `parameter_history`
/// empty, so no snapshots are produced by the visible fit path.
#[derive(Debug, Clone)]
pub struct ParameterSnapshot<F> {
    /// Iteration number
    pub iteration: usize,
    /// Log-likelihood at this iteration
    pub log_likelihood: F,
    /// Parameter change norm
    pub parameter_change: F,
    /// Weights at this iteration
    pub weights: Array1<F>,
}
228
impl Default for GMMConfig {
    /// Baseline configuration: 100 EM iterations, full covariances,
    /// k-means++ seeding, a single init run, SIMD/parallel enabled, and
    /// no robustness / early-stopping extras.
    fn default() -> Self {
        Self {
            max_iter: 100,
            tolerance: 1e-3,
            param_tolerance: 1e-4,
            covariance_type: CovarianceType::Full,
            reg_covar: 1e-6, // small diagonal jitter to keep covariances well-behaved
            init_method: InitializationMethod::KMeansPlus,
            n_init: 1,
            seed: None, // non-deterministic unless a seed is supplied
            parallel: true,
            use_simd: true,
            warm_start: false,
            robust_em: false,
            outlier_threshold: 0.01,
            early_stopping: false,
            validation_fraction: 0.1,
            patience: 10,
        }
    }
}
251
252// ---------------------------------------------------------------------------
253// Helper: trait alias
254// ---------------------------------------------------------------------------
255
/// Trait bound alias used throughout this module
///
/// Bundles the numeric, threading, SIMD and formatting capabilities the
/// mixture-model code needs, so signatures can write `F: GmmFloat`
/// instead of repeating the full bound list. A blanket impl below makes
/// every qualifying type a `GmmFloat` automatically.
pub trait GmmFloat:
    Float
    + Zero
    + One
    + Copy
    + Send
    + Sync
    + SimdUnifiedOps
    + FromPrimitive
    + std::fmt::Display
    + std::iter::Sum
    + scirs2_core::ndarray::ScalarOperand
{
}
271
// Blanket implementation: any type satisfying the bound list is a
// `GmmFloat`; no manual impls are needed for concrete float types.
impl<F> GmmFloat for F where
    F: Float
        + Zero
        + One
        + Copy
        + Send
        + Sync
        + SimdUnifiedOps
        + FromPrimitive
        + std::fmt::Display
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
{
}
286
287// ---------------------------------------------------------------------------
288// Helper: convert f64 -> F with proper error
289// ---------------------------------------------------------------------------
290
291fn f64_to_f<F: Float + FromPrimitive>(v: f64, ctx: &str) -> StatsResult<F> {
292    F::from(v).ok_or_else(|| {
293        StatsError::ComputationError(format!("Failed to convert f64 ({v}) to float ({ctx})"))
294    })
295}
296
297// ---------------------------------------------------------------------------
298// GaussianMixtureModel implementation
299// ---------------------------------------------------------------------------
300
301impl<F: GmmFloat> GaussianMixtureModel<F> {
302    /// Create new Gaussian Mixture Model
303    pub fn new(n_components: usize, config: GMMConfig) -> StatsResult<Self> {
304        check_positive(n_components, "n_components")?;
305
306        Ok(Self {
307            n_components,
308            config,
309            parameters: None,
310            convergence_history: Vec::new(),
311            _phantom: PhantomData,
312        })
313    }
314
315    // ------------------------------------------------------------------
316    // Public API
317    // ------------------------------------------------------------------
318
    /// Fit GMM to data using EM algorithm
    ///
    /// Runs up to `config.max_iter` rounds of E-step / M-step updates,
    /// appending each round's total log-likelihood to
    /// `self.convergence_history`. Convergence is declared once the
    /// absolute change in log-likelihood drops below `config.tolerance`
    /// (never on the very first iteration, where the previous value is
    /// `-inf`). On success the fitted parameters are stored in
    /// `self.parameters` and a reference to them is returned.
    ///
    /// # Errors
    /// Fails when `data` contains non-finite values, when there are fewer
    /// samples than components, or when any E/M sub-step fails.
    pub fn fit(&mut self, data: &ArrayView2<F>) -> StatsResult<&GMMParameters<F>> {
        checkarray_finite(data, "data")?;

        let (n_samples, n_features) = data.dim();

        if n_samples < self.n_components {
            return Err(StatsError::InvalidArgument(format!(
                "Number of samples ({n_samples}) must be >= number of components ({})",
                self.n_components
            )));
        }

        // Start from uniform mixing weights; means and covariances come
        // from the configured initialization strategy.
        let inv_k: F = f64_to_f(1.0 / self.n_components as f64, "inv_k")?;
        let mut weights = Array1::from_elem(self.n_components, inv_k);
        let mut means = self.initialize_means(data)?;
        let mut covariances = self.initialize_covariances(data, &means)?;

        let mut log_likelihood = F::neg_infinity();
        let mut converged = false;
        self.convergence_history.clear();

        let n_iter_used;

        for iter_idx in 0..self.config.max_iter {
            // E-step: posterior responsibilities under current parameters.
            let responsibilities = self.e_step(data, &weights, &means, &covariances)?;
            // M-step: re-estimate weights, means, covariances in turn.
            let new_weights = self.m_step_weights(&responsibilities)?;
            let new_means = self.m_step_means(data, &responsibilities)?;
            let new_covariances = self.m_step_covariances(data, &responsibilities, &new_means)?;

            let new_ll =
                self.compute_log_likelihood(data, &new_weights, &new_means, &new_covariances)?;

            self.convergence_history.push(new_ll);

            // `improvement` is +inf on the first pass (previous LL is
            // -inf), hence the `iter_idx > 0` guard below.
            let improvement = new_ll - log_likelihood;
            let tol: F = f64_to_f(self.config.tolerance, "tolerance")?;
            if improvement.abs() < tol && iter_idx > 0 {
                converged = true;
            }

            weights = new_weights;
            means = new_means;
            covariances = new_covariances;
            log_likelihood = new_ll;

            if converged {
                n_iter_used = iter_idx + 1;
                self.store_parameters(
                    weights,
                    means,
                    covariances,
                    log_likelihood,
                    n_iter_used,
                    converged,
                    n_samples,
                    n_features,
                    data,
                )?;
                return self
                    .parameters
                    .as_ref()
                    .ok_or_else(|| StatsError::ComputationError("Parameters not stored".into()));
            }
        }

        // Loop exhausted without meeting the tolerance: store the final
        // state with `converged = false`.
        n_iter_used = self.config.max_iter;
        self.store_parameters(
            weights,
            means,
            covariances,
            log_likelihood,
            n_iter_used,
            false,
            n_samples,
            n_features,
            data,
        )?;

        self.parameters
            .as_ref()
            .ok_or_else(|| StatsError::ComputationError("Parameters not stored".into()))
    }
402
403    /// Predict cluster assignments (hard assignment: argmax of responsibilities)
404    pub fn predict(&self, data: &ArrayView2<F>) -> StatsResult<Array1<usize>> {
405        let params = self.require_fitted()?;
406        let responsibilities =
407            self.e_step(data, &params.weights, &params.means, &params.covariances)?;
408
409        let mut predictions = Array1::zeros(data.nrows());
410        for i in 0..data.nrows() {
411            let mut max_resp = F::neg_infinity();
412            let mut best = 0usize;
413            for k in 0..self.n_components {
414                if responsibilities[[i, k]] > max_resp {
415                    max_resp = responsibilities[[i, k]];
416                    best = k;
417                }
418            }
419            predictions[i] = best;
420        }
421        Ok(predictions)
422    }
423
    /// Predict soft cluster assignment (responsibility matrix)
    ///
    /// Returns an `(n_samples, n_components)` matrix of posterior
    /// component probabilities; requires a fitted model.
    pub fn predict_proba(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
        let params = self.require_fitted()?;
        self.e_step(data, &params.weights, &params.means, &params.covariances)
    }
429
    /// Average log-likelihood per sample
    ///
    /// Total log-likelihood of `data` under the fitted model, divided by
    /// the number of rows; requires a fitted model.
    pub fn score(&self, data: &ArrayView2<F>) -> StatsResult<F> {
        let params = self.require_fitted()?;
        let total_ll =
            self.compute_log_likelihood(data, &params.weights, &params.means, &params.covariances)?;
        let n: F = f64_to_f(data.nrows() as f64, "n_samples")?;
        Ok(total_ll / n)
    }
438
    /// Per-sample log-likelihood
    ///
    /// Delegates to `per_sample_log_likelihood` with the fitted
    /// parameters; requires a fitted model.
    pub fn score_samples(&self, data: &ArrayView2<F>) -> StatsResult<Array1<F>> {
        let params = self.require_fitted()?;
        self.per_sample_log_likelihood(data, &params.weights, &params.means, &params.covariances)
    }
444
    /// Generate random samples from the fitted mixture model
    ///
    /// For each of the `n` draws: a component is picked by inverse-CDF
    /// sampling over the mixture weights, then a point is drawn from that
    /// component's Gaussian as `mean + L·z`, where `L` is the Cholesky
    /// factor of the covariance and `z ~ N(0, I)` comes from the
    /// Box-Muller transform.
    ///
    /// Returns an `(n, n_features)` array; requires a fitted model.
    pub fn sample(&self, n: usize, seed: Option<u64>) -> StatsResult<Array2<F>> {
        let params = self.require_fitted()?;
        let n_features = params.means.ncols();

        use scirs2_core::random::Random;
        let mut init_rng = scirs2_core::random::thread_rng();
        let mut rng = match seed {
            Some(s) => Random::seed(s),
            None => Random::seed(init_rng.random()),
        };

        let mut samples = Array2::zeros((n, n_features));

        for i in 0..n {
            // 1. Choose component according to weights (inverse-CDF walk;
            //    defaults to the last component if float rounding leaves
            //    `u` beyond the cumulative sum).
            let u: f64 = rng.random_f64();
            let mut cumsum = 0.0;
            let mut chosen_k = self.n_components - 1;
            for k in 0..self.n_components {
                let wk = params.weights[k].to_f64().ok_or_else(|| {
                    StatsError::ComputationError("Weight conversion failed".into())
                })?;
                cumsum += wk;
                if u < cumsum {
                    chosen_k = k;
                    break;
                }
            }

            // 2. Sample from the chosen component using Cholesky decomposition
            let mean = params.means.row(chosen_k);
            let cov = &params.covariances[chosen_k];

            // Generate z ~ N(0,I) using Box-Muller: two normals per pair of
            // uniforms; u1 is clamped away from 0 so ln(u1) stays finite.
            let mut z = Array1::<f64>::zeros(n_features);
            for j in (0..n_features).step_by(2) {
                let u1: f64 = rng.random_f64().max(1e-300);
                let u2: f64 = rng.random_f64();
                let r = (-2.0 * u1.ln()).sqrt();
                let theta = 2.0 * std::f64::consts::PI * u2;
                z[j] = r * theta.cos();
                if j + 1 < n_features {
                    z[j + 1] = r * theta.sin();
                }
            }

            // Correlate z through the Cholesky factor and shift by the mean.
            let cov_f64 = cov.mapv(|x| x.to_f64().unwrap_or(0.0));
            let chol = cholesky_lower(&cov_f64)?;
            let sampled = chol.dot(&z);
            for j in 0..n_features {
                let val: F = f64_to_f(sampled[j], "sample_val")?;
                samples[[i, j]] = mean[j] + val;
            }
        }

        Ok(samples)
    }
503
    /// Bayesian Information Criterion for the fitted model
    ///
    /// `_data` is unused: the criterion is cached from the training data
    /// at fit time (see `store_parameters`).
    pub fn bic(&self, _data: &ArrayView2<F>) -> StatsResult<F> {
        let params = self.require_fitted()?;
        Ok(params.model_selection.bic)
    }
509
    /// Akaike Information Criterion for the fitted model
    ///
    /// `_data` is unused: the criterion is cached from the training data
    /// at fit time (see `store_parameters`).
    pub fn aic(&self, _data: &ArrayView2<F>) -> StatsResult<F> {
        let params = self.require_fitted()?;
        Ok(params.model_selection.aic)
    }
515
    /// Number of free parameters in the model
    ///
    /// Cached from fit time; see `compute_n_parameters` for the
    /// per-covariance-type counting rules.
    pub fn n_parameters(&self) -> StatsResult<usize> {
        let params = self.require_fitted()?;
        Ok(params.model_selection.n_parameters)
    }
521
522    // ------------------------------------------------------------------
523    // Internal helpers
524    // ------------------------------------------------------------------
525
    /// Borrow the fitted parameters, failing with `InvalidArgument` if
    /// `fit` has not (successfully) run yet.
    fn require_fitted(&self) -> StatsResult<&GMMParameters<F>> {
        self.parameters
            .as_ref()
            .ok_or_else(|| StatsError::InvalidArgument("Model must be fitted before use".into()))
    }
531
    /// Assemble the final `GMMParameters` and stash them in
    /// `self.parameters`.
    ///
    /// Computes "lower is better" model-selection criteria:
    /// AIC = -2·LL + 2p, BIC = -2·LL + p·ln n, HQIC = -2·LL + 2p·ln(ln n),
    /// ICL = BIC - 2·Σ r·ln r (entropy-penalized BIC), along with simple
    /// per-component diagnostics.
    #[allow(clippy::too_many_arguments)]
    fn store_parameters(
        &mut self,
        weights: Array1<F>,
        means: Array2<F>,
        covariances: Vec<Array2<F>>,
        log_likelihood: F,
        n_iter: usize,
        converged: bool,
        n_samples: usize,
        n_features: usize,
        data: &ArrayView2<F>,
    ) -> StatsResult<()> {
        let n_params = self.compute_n_parameters(n_features);
        let n_f: F = f64_to_f(n_samples as f64, "n_samples")?;
        let p_f: F = f64_to_f(n_params as f64, "n_params")?;
        let two: F = f64_to_f(2.0, "two")?;

        let aic = -two * log_likelihood + two * p_f;
        let bic = -two * log_likelihood + p_f * n_f.ln();
        let hqic = -two * log_likelihood + two * p_f * n_f.ln().ln();

        // One extra E-step on the final parameters; its entropy term
        // (Σ r·ln r, non-positive) turns BIC into ICL below.
        let responsibilities = self.e_step(data, &weights, &means, &covariances)?;
        let entropy = self.responsibility_entropy(&responsibilities);
        let icl = bic - two * entropy;

        let mut diagnostics = Vec::with_capacity(self.n_components);
        for k in 0..self.n_components {
            // Effective sample size = total responsibility mass of k.
            let nk = responsibilities.column(k).sum();
            let cov_f64 = covariances[k].mapv(|x| x.to_f64().unwrap_or(0.0));
            let det = scirs2_linalg::det(&cov_f64.view(), None).unwrap_or(1.0);
            let cond = self.estimate_condition_number(&cov_f64);
            let sep = self.compute_component_separation(k, &means, &covariances);

            diagnostics.push(ComponentDiagnostics {
                effective_samplesize: nk,
                condition_number: f64_to_f(cond, "cond").unwrap_or(F::one()),
                covariance_determinant: f64_to_f(det.abs(), "det").unwrap_or(F::one()),
                component_separation: sep,
                // Placeholder: weight-stability tracking not implemented.
                weight_stability: F::zero(),
            });
        }

        let parameters = GMMParameters {
            weights,
            means,
            covariances,
            log_likelihood,
            n_iter,
            converged,
            convergence_reason: if converged {
                ConvergenceReason::LogLikelihoodTolerance
            } else {
                ConvergenceReason::MaxIterations
            },
            model_selection: ModelSelectionCriteria {
                aic,
                bic,
                icl,
                hqic,
                cv_log_likelihood: None,
                n_parameters: n_params,
            },
            component_diagnostics: diagnostics,
            outlier_scores: None,
            responsibilities: Some(responsibilities),
            parameter_history: Vec::new(),
        };

        self.parameters = Some(parameters);
        Ok(())
    }
604
605    fn compute_n_parameters(&self, d: usize) -> usize {
606        let k = self.n_components;
607        let weight_params = k - 1;
608        let mean_params = k * d;
609        let cov_params = match &self.config.covariance_type {
610            CovarianceType::Full => k * d * (d + 1) / 2,
611            CovarianceType::Diagonal => k * d,
612            CovarianceType::Tied => d * (d + 1) / 2,
613            CovarianceType::Spherical => k,
614            CovarianceType::Factor { n_factors } => k * (d * n_factors + d),
615            CovarianceType::Constrained { .. } => k * d * (d + 1) / 2,
616        };
617        weight_params + mean_params + cov_params
618    }
619
620    fn responsibility_entropy(&self, resp: &Array2<F>) -> F {
621        let mut entropy = F::zero();
622        let eps: F = f64_to_f(1e-300, "eps").unwrap_or(F::min_positive_value());
623        for row in resp.rows() {
624            for &r in row.iter() {
625                if r > eps {
626                    entropy = entropy + r * r.ln();
627                }
628            }
629        }
630        entropy
631    }
632
633    fn estimate_condition_number(&self, cov: &Array2<f64>) -> f64 {
634        let diag: Vec<f64> = (0..cov.nrows()).map(|i| cov[[i, i]].abs()).collect();
635        let max_d = diag.iter().copied().fold(f64::NEG_INFINITY, f64::max);
636        let min_d = diag
637            .iter()
638            .copied()
639            .filter(|&v| v > 1e-300)
640            .fold(f64::INFINITY, f64::min);
641        if min_d > 0.0 {
642            max_d / min_d
643        } else {
644            f64::INFINITY
645        }
646    }
647
648    fn compute_component_separation(&self, k: usize, means: &Array2<F>, _covs: &[Array2<F>]) -> F {
649        let mut min_dist = F::infinity();
650        let mean_k = means.row(k);
651        for j in 0..self.n_components {
652            if j == k {
653                continue;
654            }
655            let mean_j = means.row(j);
656            let d: F = mean_k
657                .iter()
658                .zip(mean_j.iter())
659                .map(|(&a, &b)| (a - b) * (a - b))
660                .sum();
661            let d_sqrt = d.sqrt();
662            if d_sqrt < min_dist {
663                min_dist = d_sqrt;
664            }
665        }
666        min_dist
667    }
668
669    fn initialize_means(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
670        let (n_samples, n_features) = data.dim();
671        let mut means = Array2::zeros((self.n_components, n_features));
672
673        match self.config.init_method {
674            InitializationMethod::Random => {
675                use scirs2_core::random::Random;
676                let mut init_rng = scirs2_core::random::thread_rng();
677                let mut rng = match self.config.seed {
678                    Some(seed) => Random::seed(seed),
679                    None => Random::seed(init_rng.random()),
680                };
681                for i in 0..self.n_components {
682                    let idx = rng.random_range(0..n_samples);
683                    means.row_mut(i).assign(&data.row(idx));
684                }
685            }
686            InitializationMethod::KMeansPlus => {
687                means = self.kmeans_plus_plus_init(data)?;
688            }
689            InitializationMethod::FurthestFirst => {
690                means = self.furthest_first_init(data)?;
691            }
692            InitializationMethod::Quantile => {
693                means = self.quantile_init(data)?;
694            }
695            _ => {
696                means = self.kmeans_plus_plus_init(data)?;
697            }
698        }
699
700        Ok(means)
701    }
702
    /// K-means++ seeding: the first center is a uniformly random data
    /// row; each subsequent center is drawn with probability proportional
    /// to its squared distance to the nearest already-chosen center
    /// (D² sampling).
    fn kmeans_plus_plus_init(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
        use scirs2_core::random::Random;
        let mut init_rng = scirs2_core::random::thread_rng();
        let mut rng = match self.config.seed {
            Some(seed) => Random::seed(seed),
            None => Random::seed(init_rng.random()),
        };

        let (n_samples, n_features) = data.dim();
        let mut means = Array2::zeros((self.n_components, n_features));
        let first_idx = rng.random_range(0..n_samples);
        means.row_mut(0).assign(&data.row(first_idx));

        for i in 1..self.n_components {
            // Squared distance of every sample to its closest chosen center.
            let mut distances = Array1::zeros(n_samples);
            for j in 0..n_samples {
                let mut min_dist = F::infinity();
                for k_idx in 0..i {
                    let dist = self.squared_distance(&data.row(j), &means.row(k_idx));
                    min_dist = min_dist.min(dist);
                }
                distances[j] = min_dist;
            }

            let total_dist: F = distances.sum();
            if total_dist <= F::zero() {
                // Degenerate case (every point coincides with some chosen
                // center): fall back to a uniformly random sample.
                let idx = rng.random_range(0..n_samples);
                means.row_mut(i).assign(&data.row(idx));
                continue;
            }

            // Inverse-CDF sampling over the D² distribution.
            let threshold_f64: f64 = rng.random_f64();
            let threshold_ratio: F = F::from(threshold_f64)
                .ok_or_else(|| StatsError::ComputationError("threshold conversion".into()))?;
            let threshold: F = threshold_ratio * total_dist;
            let mut cumsum = F::zero();
            let mut picked = false;
            for j in 0..n_samples {
                cumsum = cumsum + distances[j];
                if cumsum >= threshold {
                    means.row_mut(i).assign(&data.row(j));
                    picked = true;
                    break;
                }
            }
            if !picked {
                // Float rounding pushed the threshold past the cumulative
                // sum: take the last sample.
                means
                    .row_mut(i)
                    .assign(&data.row(n_samples.saturating_sub(1)));
            }
        }

        Ok(means)
    }
757
758    fn furthest_first_init(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
759        use scirs2_core::random::Random;
760        let mut init_rng = scirs2_core::random::thread_rng();
761        let mut rng = match self.config.seed {
762            Some(s) => Random::seed(s),
763            None => Random::seed(init_rng.random()),
764        };
765
766        let (n_samples, n_features) = data.dim();
767        let mut means = Array2::zeros((self.n_components, n_features));
768        let first_idx = rng.random_range(0..n_samples);
769        means.row_mut(0).assign(&data.row(first_idx));
770
771        for i in 1..self.n_components {
772            let mut best_idx = 0;
773            let mut best_dist = F::neg_infinity();
774            for j in 0..n_samples {
775                let mut min_dist = F::infinity();
776                for k_idx in 0..i {
777                    let d = self.squared_distance(&data.row(j), &means.row(k_idx));
778                    min_dist = min_dist.min(d);
779                }
780                if min_dist > best_dist {
781                    best_dist = min_dist;
782                    best_idx = j;
783                }
784            }
785            means.row_mut(i).assign(&data.row(best_idx));
786        }
787
788        Ok(means)
789    }
790
791    fn quantile_init(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
792        let (n_samples, n_features) = data.dim();
793        let mut means = Array2::zeros((self.n_components, n_features));
794        for i in 0..self.n_components {
795            let frac = (i as f64 + 0.5) / self.n_components as f64;
796            let idx = ((frac * n_samples as f64) as usize).min(n_samples.saturating_sub(1));
797            means.row_mut(i).assign(&data.row(idx));
798        }
799        Ok(means)
800    }
801
802    fn initialize_covariances(
803        &self,
804        data: &ArrayView2<F>,
805        _means: &Array2<F>,
806    ) -> StatsResult<Vec<Array2<F>>> {
807        let n_features = data.ncols();
808        let n_samples = data.nrows();
809        let mut covariances = Vec::with_capacity(self.n_components);
810        let n_f: F = f64_to_f(n_samples as f64, "n_samples_init")?;
811        let reg: F = f64_to_f(self.config.reg_covar, "reg_covar")?;
812
813        let mut data_var = Array1::zeros(n_features);
814        for j in 0..n_features {
815            let col_mean: F = data.column(j).sum() / n_f;
816            let var: F = data
817                .column(j)
818                .iter()
819                .map(|&x| (x - col_mean) * (x - col_mean))
820                .sum::<F>()
821                / n_f;
822            data_var[j] = if var > F::zero() { var } else { F::one() };
823        }
824
825        for _i in 0..self.n_components {
826            let mut cov = Array2::zeros((n_features, n_features));
827            for j in 0..n_features {
828                cov[[j, j]] = data_var[j] + reg;
829            }
830            covariances.push(cov);
831        }
832        Ok(covariances)
833    }
834
    /// E-step: posterior responsibility of every component for every
    /// sample under the given parameters.
    ///
    /// Per-sample normalization is done in log space with the log-sum-exp
    /// trick so tiny densities do not underflow. When every component
    /// reports `-inf` log-density for a sample, its responsibilities fall
    /// back to the uniform distribution.
    ///
    /// Returns an `(n_samples, n_components)` matrix whose rows sum to 1.
    fn e_step(
        &self,
        data: &ArrayView2<F>,
        weights: &Array1<F>,
        means: &Array2<F>,
        covariances: &[Array2<F>],
    ) -> StatsResult<Array2<F>> {
        let n_samples = data.shape()[0];
        let mut responsibilities = Array2::zeros((n_samples, self.n_components));

        for i in 0..n_samples {
            let sample = data.row(i);
            let mut log_probs = Array1::zeros(self.n_components);

            // log(w_k) + log N(x_i | mu_k, Sigma_k) for each component.
            for k in 0..self.n_components {
                let mean = means.row(k);
                let log_prob = self.log_multivariate_normal_pdf(&sample, &mean, &covariances[k])?;
                log_probs[k] = weights[k].ln() + log_prob;
            }

            // log-sum-exp: subtract the max before exponentiating.
            let max_lp = log_probs.iter().copied().fold(F::neg_infinity(), F::max);
            if max_lp == F::neg_infinity() {
                // All components assign zero density: spread uniformly.
                let uni: F = f64_to_f(1.0 / self.n_components as f64, "uniform")?;
                for k in 0..self.n_components {
                    responsibilities[[i, k]] = uni;
                }
                continue;
            }
            let log_sum_exp = (log_probs.mapv(|x| (x - max_lp).exp()).sum()).ln() + max_lp;

            for k in 0..self.n_components {
                responsibilities[[i, k]] = (log_probs[k] - log_sum_exp).exp();
            }
        }
        Ok(responsibilities)
    }
871
872    fn m_step_weights(&self, responsibilities: &Array2<F>) -> StatsResult<Array1<F>> {
873        let n_f: F = f64_to_f(responsibilities.nrows() as f64, "n_samples_m")?;
874        let mut weights = Array1::zeros(self.n_components);
875        for k in 0..self.n_components {
876            weights[k] = responsibilities.column(k).sum() / n_f;
877        }
878        Ok(weights)
879    }
880
881    fn m_step_means(
882        &self,
883        data: &ArrayView2<F>,
884        responsibilities: &Array2<F>,
885    ) -> StatsResult<Array2<F>> {
886        let n_features = data.ncols();
887        let mut means = Array2::zeros((self.n_components, n_features));
888        let eps: F = f64_to_f(1e-10, "eps_m")?;
889
890        for k in 0..self.n_components {
891            let resp_sum = responsibilities.column(k).sum();
892            if resp_sum > eps {
893                for j in 0..n_features {
894                    let weighted_sum: F = data
895                        .column(j)
896                        .iter()
897                        .zip(responsibilities.column(k).iter())
898                        .map(|(&x, &r)| x * r)
899                        .sum();
900                    means[[k, j]] = weighted_sum / resp_sum;
901                }
902            }
903        }
904        Ok(means)
905    }
906
    /// Maximization step for component covariances.
    ///
    /// For each component: accumulate the responsibility-weighted outer
    /// products of (sample - mean), normalize by the total responsibility,
    /// add `reg_covar` to the diagonal, then project onto the configured
    /// covariance family (diagonal / spherical / full).
    fn m_step_covariances(
        &self,
        data: &ArrayView2<F>,
        responsibilities: &Array2<F>,
        means: &Array2<F>,
    ) -> StatsResult<Vec<Array2<F>>> {
        let n_features = data.ncols();
        let mut covariances = Vec::with_capacity(self.n_components);
        // Components whose total responsibility does not exceed eps keep a
        // zero scatter (only the regularization below) to avoid dividing by ~0.
        let eps: F = f64_to_f(1e-10, "eps_cov")?;
        let reg: F = f64_to_f(self.config.reg_covar, "reg_covar")?;

        for k in 0..self.n_components {
            let resp_sum = responsibilities.column(k).sum();
            let mean_k = means.row(k);
            let mut cov = Array2::zeros((n_features, n_features));

            if resp_sum > eps {
                // Weighted scatter: sum_i r_ik (x_i - mu_k)(x_i - mu_k)^T.
                for i in 0..data.nrows() {
                    let diff = &data.row(i) - &mean_k;
                    let resp = responsibilities[[i, k]];
                    for j in 0..n_features {
                        for l in 0..n_features {
                            cov[[j, l]] = cov[[j, l]] + resp * diff[j] * diff[l];
                        }
                    }
                }
                cov = cov / resp_sum;
            }

            // Diagonal regularization keeps the estimate invertible.
            for i in 0..n_features {
                cov[[i, i]] = cov[[i, i]] + reg;
            }

            // Constrain the estimate to the configured covariance structure.
            match self.config.covariance_type {
                CovarianceType::Diagonal => {
                    // Zero out all off-diagonal entries.
                    for i in 0..n_features {
                        for j in 0..n_features {
                            if i != j {
                                cov[[i, j]] = F::zero();
                            }
                        }
                    }
                }
                CovarianceType::Spherical => {
                    // Isotropic: (mean of the diagonal) * I. The averaged
                    // diagonal already includes the regularization added above.
                    let n_feat_f: F = f64_to_f(n_features as f64, "n_feat")?;
                    let trace = cov.diag().sum() / n_feat_f;
                    cov = Array2::eye(n_features) * trace;
                }
                _ => {}
            }

            covariances.push(cov);
        }
        Ok(covariances)
    }
962
    /// Log-density of the multivariate normal N(mean, cov) evaluated at `x`.
    ///
    /// The covariance is converted to f64 for the determinant and inverse
    /// computations (via `scirs2_linalg`). A non-positive determinant is
    /// treated as zero density (returns negative infinity) rather than an
    /// error. Values that fail the F -> f64 conversion are mapped to 0.0.
    fn log_multivariate_normal_pdf(
        &self,
        x: &ArrayView1<F>,
        mean: &ArrayView1<F>,
        cov: &Array2<F>,
    ) -> StatsResult<F> {
        let d = x.len();
        let diff = x - mean;

        // Work in f64 for the linear-algebra routines.
        let cov_f64 = cov.mapv(|v| v.to_f64().unwrap_or(0.0));
        let det = scirs2_linalg::det(&cov_f64.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Determinant computation failed: {e}"))
        })?;

        // Singular or indefinite covariance: treat as zero density.
        if det <= 0.0 {
            return Ok(F::neg_infinity());
        }

        let log_det = det.ln();
        let cov_inv = scirs2_linalg::inv(&cov_f64.view(), None)
            .map_err(|e| StatsError::ComputationError(format!("Matrix inversion failed: {e}")))?;

        // Mahalanobis quadratic form: (x - mu)^T Sigma^{-1} (x - mu).
        let diff_f64 = diff.mapv(|v| v.to_f64().unwrap_or(0.0));
        let quad_form = diff_f64.dot(&cov_inv.dot(&diff_f64));

        // log N(x) = -1/2 [ d ln(2*pi) + ln|Sigma| + Mahalanobis ].
        let log_pdf = -0.5 * (d as f64 * (2.0 * std::f64::consts::PI).ln() + log_det + quad_form);
        f64_to_f(log_pdf, "log_pdf")
    }
991
992    fn compute_log_likelihood(
993        &self,
994        data: &ArrayView2<F>,
995        weights: &Array1<F>,
996        means: &Array2<F>,
997        covariances: &[Array2<F>],
998    ) -> StatsResult<F> {
999        let per_sample = self.per_sample_log_likelihood(data, weights, means, covariances)?;
1000        Ok(per_sample.sum())
1001    }
1002
1003    fn per_sample_log_likelihood(
1004        &self,
1005        data: &ArrayView2<F>,
1006        weights: &Array1<F>,
1007        means: &Array2<F>,
1008        covariances: &[Array2<F>],
1009    ) -> StatsResult<Array1<F>> {
1010        let n_samples = data.nrows();
1011        let mut scores = Array1::zeros(n_samples);
1012
1013        for i in 0..n_samples {
1014            let sample = data.row(i);
1015            let mut log_probs = Array1::zeros(self.n_components);
1016
1017            for k in 0..self.n_components {
1018                let mean = means.row(k);
1019                let log_prob = self.log_multivariate_normal_pdf(&sample, &mean, &covariances[k])?;
1020                log_probs[k] = weights[k].ln() + log_prob;
1021            }
1022
1023            let max_lp = log_probs.iter().copied().fold(F::neg_infinity(), F::max);
1024            let log_sum_exp = (log_probs.mapv(|x| (x - max_lp).exp()).sum()).ln() + max_lp;
1025            scores[i] = log_sum_exp;
1026        }
1027        Ok(scores)
1028    }
1029
1030    fn squared_distance(&self, a: &ArrayView1<F>, b: &ArrayView1<F>) -> F {
1031        a.iter()
1032            .zip(b.iter())
1033            .map(|(&x, &y)| (x - y) * (x - y))
1034            .sum()
1035    }
1036}
1037
1038// ---------------------------------------------------------------------------
1039// Cholesky decomposition helper (pure Rust, lower-triangular)
1040// ---------------------------------------------------------------------------
1041
1042fn cholesky_lower(a: &Array2<f64>) -> StatsResult<Array2<f64>> {
1043    let n = a.nrows();
1044    if n != a.ncols() {
1045        return Err(StatsError::DimensionMismatch(
1046            "Cholesky requires a square matrix".into(),
1047        ));
1048    }
1049    let mut l = Array2::<f64>::zeros((n, n));
1050    for i in 0..n {
1051        for j in 0..=i {
1052            let mut sum = 0.0;
1053            for k in 0..j {
1054                sum += l[[i, k]] * l[[j, k]];
1055            }
1056            if i == j {
1057                let diag = a[[i, i]] - sum;
1058                if diag <= 0.0 {
1059                    l[[i, j]] = (diag.abs() + 1e-10).sqrt();
1060                } else {
1061                    l[[i, j]] = diag.sqrt();
1062                }
1063            } else if l[[j, j]].abs() > 1e-300 {
1064                l[[i, j]] = (a[[i, j]] - sum) / l[[j, j]];
1065            }
1066        }
1067    }
1068    Ok(l)
1069}
1070
1071// ---------------------------------------------------------------------------
1072// Convenience functions
1073// ---------------------------------------------------------------------------
1074
1075/// Fit a GMM and return its parameters
1076pub fn gaussian_mixture_model<F: GmmFloat>(
1077    data: &ArrayView2<F>,
1078    n_components: usize,
1079    config: Option<GMMConfig>,
1080) -> StatsResult<GMMParameters<F>> {
1081    let config = config.unwrap_or_default();
1082    let mut gmm = GaussianMixtureModel::new(n_components, config)?;
1083    Ok(gmm.fit(data)?.clone())
1084}
1085
1086/// Advanced model selection for GMM: try min..=max components, return best by BIC
1087pub fn gmm_model_selection<F: GmmFloat>(
1088    data: &ArrayView2<F>,
1089    min_components: usize,
1090    max_components: usize,
1091    config: Option<GMMConfig>,
1092) -> StatsResult<(usize, GMMParameters<F>)> {
1093    let config = config.unwrap_or_default();
1094    let mut best_n = min_components;
1095    let mut best_bic = F::infinity();
1096    let mut best_params: Option<GMMParameters<F>> = None;
1097
1098    for n_comp in min_components..=max_components {
1099        let mut gmm = GaussianMixtureModel::new(n_comp, config.clone())?;
1100        let params = gmm.fit(data)?;
1101
1102        if params.model_selection.bic < best_bic {
1103            best_bic = params.model_selection.bic;
1104            best_n = n_comp;
1105            best_params = Some(params.clone());
1106        }
1107    }
1108
1109    let params = best_params.ok_or_else(|| {
1110        StatsError::ComputationError("No valid model found during selection".into())
1111    })?;
1112    Ok((best_n, params))
1113}
1114
1115/// Select the optimal number of components by BIC or AIC.
1116///
1117/// Returns `(best_k, scores)` where `scores[i]` is the criterion value for k = 1..=max_k.
1118pub fn select_n_components<F: GmmFloat>(
1119    data: &ArrayView2<F>,
1120    max_k: usize,
1121    criterion: &str,
1122) -> StatsResult<(usize, Vec<f64>)> {
1123    if max_k == 0 {
1124        return Err(StatsError::InvalidArgument("max_k must be >= 1".into()));
1125    }
1126
1127    let mut scores = Vec::with_capacity(max_k);
1128    let mut best_k = 1usize;
1129    let mut best_score = f64::INFINITY;
1130
1131    for k in 1..=max_k {
1132        let config = GMMConfig {
1133            max_iter: 100,
1134            ..Default::default()
1135        };
1136        let mut gmm = GaussianMixtureModel::<F>::new(k, config)?;
1137        let params = gmm.fit(data)?;
1138
1139        let score_f64 = match criterion {
1140            "aic" | "AIC" => params.model_selection.aic.to_f64().unwrap_or(f64::INFINITY),
1141            _ => params.model_selection.bic.to_f64().unwrap_or(f64::INFINITY),
1142        };
1143
1144        scores.push(score_f64);
1145
1146        if score_f64 < best_score {
1147            best_score = score_f64;
1148            best_k = k;
1149        }
1150    }
1151
1152    Ok((best_k, scores))
1153}
1154
1155// ---------------------------------------------------------------------------
1156// RobustGMM
1157// ---------------------------------------------------------------------------
1158
/// Robust Gaussian Mixture Model with outlier detection
pub struct RobustGMM<F> {
    /// Base GMM
    pub gmm: GaussianMixtureModel<F>,
    /// Outlier detection threshold
    pub outlier_threshold: F,
    /// Contamination rate (expected fraction of outliers)
    pub contamination: F,
    // NOTE(review): PhantomData is redundant here (F already appears in the
    // fields above); kept for API/layout stability.
    _phantom: PhantomData<F>,
}
1169
impl<F: GmmFloat> RobustGMM<F> {
    /// Create new Robust GMM
    ///
    /// Forces `robust_em` on in the supplied config and mirrors
    /// `outlier_threshold` into it as an f64 (0.01 if the conversion fails).
    pub fn new(
        n_components: usize,
        outlier_threshold: F,
        contamination: F,
        mut config: GMMConfig,
    ) -> StatsResult<Self> {
        config.robust_em = true;
        config.outlier_threshold = outlier_threshold.to_f64().unwrap_or(0.01);

        let gmm = GaussianMixtureModel::new(n_components, config)?;
        Ok(Self {
            gmm,
            outlier_threshold,
            contamination,
            _phantom: PhantomData,
        })
    }

    /// Fit robust GMM with outlier detection
    ///
    /// Fits the underlying GMM, computes per-sample outlier scores, and
    /// stores them in the fitted parameters for later `detect_outliers` calls.
    pub fn fit(&mut self, data: &ArrayView2<F>) -> StatsResult<&GMMParameters<F>> {
        self.gmm.fit(data)?;
        let outlier_scores = self.compute_outlier_scores(data)?;

        if let Some(ref mut params) = self.gmm.parameters {
            params.outlier_scores = Some(outlier_scores);
        }

        self.gmm.parameters.as_ref().ok_or_else(|| {
            StatsError::ComputationError("Parameters not stored after robust fit".into())
        })
    }

    /// Outlier score per sample: the negated mixture log-likelihood, so
    /// lower-density samples receive higher scores.
    fn compute_outlier_scores(&self, data: &ArrayView2<F>) -> StatsResult<Array1<F>> {
        let params = self.gmm.require_fitted()?;
        let per_sample_ll = self.gmm.per_sample_log_likelihood(
            data,
            &params.weights,
            &params.means,
            &params.covariances,
        )?;
        Ok(per_sample_ll.mapv(|x| -x))
    }

    /// Detect outliers in data based on contamination rate
    ///
    /// Thresholds the outlier scores recorded during `fit` at the
    /// (1 - contamination) empirical quantile; samples whose score strictly
    /// exceeds that threshold are flagged. The `_data` argument is unused —
    /// detection operates only on the stored scores.
    pub fn detect_outliers(&self, _data: &ArrayView2<F>) -> StatsResult<Array1<bool>> {
        let params = self.gmm.require_fitted()?;

        let outlier_scores = params.outlier_scores.as_ref().ok_or_else(|| {
            StatsError::InvalidArgument("Robust EM must be enabled for outlier detection".into())
        })?;

        // Sort ascending; incomparable pairs (NaN) are treated as equal.
        let mut sorted: Vec<F> = outlier_scores.iter().copied().collect();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        // Index of the (1 - contamination) quantile, clamped into range.
        // NOTE(review): indexing below assumes at least one score exists; an
        // empty score vector would panic here — confirm fit always stores >=1.
        let threshold_idx_f =
            (F::one() - self.contamination) * f64_to_f(sorted.len() as f64, "sorted_len")?;
        let threshold_idx = threshold_idx_f
            .to_usize()
            .unwrap_or(sorted.len().saturating_sub(1))
            .min(sorted.len().saturating_sub(1));
        let adaptive_threshold = sorted[threshold_idx];

        let outliers = outlier_scores.mapv(|score| score > adaptive_threshold);
        Ok(outliers)
    }
}
1238
1239// ---------------------------------------------------------------------------
1240// StreamingGMM
1241// ---------------------------------------------------------------------------
1242
/// Streaming/Online Gaussian Mixture Model
pub struct StreamingGMM<F> {
    /// Base GMM
    pub gmm: GaussianMixtureModel<F>,
    /// Learning rate for online updates
    pub learning_rate: F,
    /// Decay factor for old data
    pub decay_factor: F,
    /// Number of samples processed
    pub n_samples_seen: usize,
    /// Running statistics
    pub running_means: Option<Array2<F>>,
    /// Running per-component covariances (set from the first full fit; not
    /// refreshed by subsequent online updates)
    pub running_covariances: Option<Vec<Array2<F>>>,
    /// Running mixture weights (renormalized after each online update)
    pub running_weights: Option<Array1<F>>,
    // NOTE(review): PhantomData is redundant (F appears in the fields above);
    // kept for API/layout stability.
    _phantom: PhantomData<F>,
}
1259
impl<F: GmmFloat> StreamingGMM<F> {
    /// Create new Streaming GMM
    ///
    /// No fitting happens here; the running statistics stay `None` until the
    /// first `partial_fit` call.
    pub fn new(
        n_components: usize,
        learning_rate: F,
        decay_factor: F,
        config: GMMConfig,
    ) -> StatsResult<Self> {
        let gmm = GaussianMixtureModel::new(n_components, config)?;
        Ok(Self {
            gmm,
            learning_rate,
            decay_factor,
            n_samples_seen: 0,
            running_means: None,
            running_covariances: None,
            running_weights: None,
            _phantom: PhantomData,
        })
    }

    /// Update model with new batch of data
    ///
    /// The first batch performs a full EM fit and seeds the running
    /// statistics; every subsequent batch applies a decayed online update.
    pub fn partial_fit(&mut self, batch: &ArrayView2<F>) -> StatsResult<()> {
        let batchsize = batch.nrows();

        if self.n_samples_seen == 0 {
            // First batch: full fit, then snapshot the fitted parameters as
            // the running state.
            self.gmm.fit(batch)?;
            let params = self.gmm.require_fitted()?;
            self.running_means = Some(params.means.clone());
            self.running_covariances = Some(params.covariances.clone());
            self.running_weights = Some(params.weights.clone());
        } else {
            self.online_update(batch)?;
        }

        self.n_samples_seen += batchsize;
        Ok(())
    }

    /// One online EM step: E-step on the batch with the current parameters,
    /// batch-level M-step for weights and means, then an exponential blend
    /// `running = decay * running + lr * batch` (weights renormalized).
    ///
    /// NOTE(review): covariances are not updated online — only weights and
    /// means change; covariances keep their initial full-fit values. Confirm
    /// this is intended rather than an omission.
    fn online_update(&mut self, batch: &ArrayView2<F>) -> StatsResult<()> {
        let params = self.gmm.require_fitted()?;

        let responsibilities =
            self.gmm
                .e_step(batch, &params.weights, &params.means, &params.covariances)?;

        let batch_weights = self.gmm.m_step_weights(&responsibilities)?;
        let batch_means = self.gmm.m_step_means(batch, &responsibilities)?;

        let lr = self.learning_rate;
        let decay = self.decay_factor;

        if let (Some(ref mut r_weights), Some(ref mut r_means)) =
            (&mut self.running_weights, &mut self.running_means)
        {
            // Exponentially decayed blend of old and batch statistics.
            *r_weights = r_weights.mapv(|x| x * decay) + batch_weights.mapv(|x| x * lr);
            // Renormalize so the weights stay a probability distribution.
            let weight_sum = r_weights.sum();
            if weight_sum > F::zero() {
                *r_weights = r_weights.mapv(|x| x / weight_sum);
            }
            *r_means = r_means.mapv(|x| x * decay) + batch_means.mapv(|x| x * lr);
        }

        // Publish the running statistics back into the stored parameters.
        if let Some(ref mut p) = self.gmm.parameters {
            if let Some(ref rw) = self.running_weights {
                p.weights = rw.clone();
            }
            if let Some(ref rm) = self.running_means {
                p.means = rm.clone();
            }
        }

        Ok(())
    }

    /// Get current model parameters
    ///
    /// Returns `None` before the first `partial_fit` call.
    pub fn get_parameters(&self) -> Option<&GMMParameters<F>> {
        self.gmm.parameters.as_ref()
    }
}
1340
1341// ---------------------------------------------------------------------------
1342// Hierarchical GMM init
1343// ---------------------------------------------------------------------------
1344
1345/// Hierarchical clustering-based mixture model initialization
1346pub fn hierarchical_gmm_init<F: GmmFloat>(
1347    data: &ArrayView2<F>,
1348    n_components: usize,
1349    config: GMMConfig,
1350) -> StatsResult<GMMParameters<F>> {
1351    let mut init_config = config;
1352    init_config.init_method = InitializationMethod::FurthestFirst;
1353    gaussian_mixture_model(data, n_components, Some(init_config))
1354}
1355
1356// ---------------------------------------------------------------------------
1357// GMM cross-validation
1358// ---------------------------------------------------------------------------
1359
1360/// Cross-validation for GMM hyperparameter tuning
1361pub fn gmm_cross_validation<F: GmmFloat>(
1362    data: &ArrayView2<F>,
1363    n_components: usize,
1364    n_folds: usize,
1365    config: GMMConfig,
1366) -> StatsResult<F> {
1367    let n_samples = data.nrows();
1368    if n_folds < 2 || n_folds > n_samples {
1369        return Err(StatsError::InvalidArgument(format!(
1370            "n_folds ({n_folds}) must be in [2, n_samples ({n_samples})]"
1371        )));
1372    }
1373    let foldsize = n_samples / n_folds;
1374    let mut cv_scores = Vec::with_capacity(n_folds);
1375
1376    for fold in 0..n_folds {
1377        let val_start = fold * foldsize;
1378        let val_end = if fold == n_folds - 1 {
1379            n_samples
1380        } else {
1381            (fold + 1) * foldsize
1382        };
1383
1384        let mut train_indices = Vec::new();
1385        for i in 0..n_samples {
1386            if i < val_start || i >= val_end {
1387                train_indices.push(i);
1388            }
1389        }
1390
1391        let traindata = Array2::from_shape_fn((train_indices.len(), data.ncols()), |(i, j)| {
1392            data[[train_indices[i], j]]
1393        });
1394        let valdata = data.slice(s![val_start..val_end, ..]);
1395
1396        let mut gmm = GaussianMixtureModel::new(n_components, config.clone())?;
1397        let params = gmm.fit(&traindata.view())?.clone();
1398
1399        let val_ll = gmm.compute_log_likelihood(
1400            &valdata,
1401            &params.weights,
1402            &params.means,
1403            &params.covariances,
1404        )?;
1405        cv_scores.push(val_ll);
1406    }
1407
1408    let n_folds_f: F = f64_to_f(cv_scores.len() as f64, "cv_n")?;
1409    let avg_score: F = cv_scores.iter().copied().sum::<F>() / n_folds_f;
1410    Ok(avg_score)
1411}
1412
1413// ---------------------------------------------------------------------------
1414// Benchmark helper
1415// ---------------------------------------------------------------------------
1416
1417/// Performance benchmarking for mixture models
1418pub fn benchmark_mixture_models<F: GmmFloat>(
1419    data: &ArrayView2<F>,
1420    methods: &[(
1421        &str,
1422        Box<dyn Fn(&ArrayView2<F>) -> StatsResult<GMMParameters<F>>>,
1423    )],
1424) -> StatsResult<Vec<(String, std::time::Duration, F)>> {
1425    let mut results = Vec::new();
1426    for (name, method) in methods {
1427        let start_time = std::time::Instant::now();
1428        let params = method(data)?;
1429        let duration = start_time.elapsed();
1430        results.push((name.to_string(), duration, params.log_likelihood));
1431    }
1432    Ok(results)
1433}
1434
1435// ---------------------------------------------------------------------------
1436// Tests
1437// ---------------------------------------------------------------------------
1438
1439#[cfg(test)]
1440mod tests;