// scirs2_stats/mixture_models.rs
//! Advanced mixture models and kernel density estimation
//!
//! This module provides comprehensive implementations of mixture models and
//! non-parametric density estimation methods including:
//! - Gaussian Mixture Models (GMM) with robust EM algorithm
//! - Dirichlet Process Gaussian Mixture Models (DPGMM)
//! - Variational Bayesian Gaussian Mixture Models
//! - Online/Streaming EM algorithms
//! - Robust mixture models with outlier detection
//! - Model selection criteria (AIC, BIC, ICL)
//! - Advanced initialization strategies
//! - Kernel Density Estimation with various kernels
//! - Adaptive bandwidth selection with cross-validation
//! - SIMD and parallel optimizations
//! - Mixture model diagnostics and validation
use std::marker::PhantomData;

use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive, One, Zero};
use scirs2_core::random::Rng;
use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
23
/// Gaussian Mixture Model fitted with the Expectation-Maximization (EM) algorithm.
///
/// Construct with [`GaussianMixtureModel::new`], then call `fit` on a
/// `(n_samples, n_features)` data matrix. `predict` and `score_samples`
/// require a prior successful `fit`.
pub struct GaussianMixtureModel<F> {
    /// Number of mixture components, K.
    pub n_components: usize,
    /// EM, initialization and execution configuration (see [`GMMConfig`]).
    pub config: GMMConfig,
    /// Fitted parameters; `None` until `fit` has completed successfully.
    pub parameters: Option<GMMParameters<F>>,
    /// Per-iteration log-likelihood values recorded by the most recent `fit` call.
    pub convergence_history: Vec<F>,
    // Ties the scalar type `F` to the struct; carries no runtime data.
    _phantom: PhantomData<F>,
}
36
/// Advanced GMM configuration
///
/// Groups all tuning knobs for the EM fit: convergence control, covariance
/// structure, initialization, execution strategy, and optional robustness /
/// early-stopping behavior. Use `GMMConfig::default()` for sensible defaults.
#[derive(Debug, Clone)]
pub struct GMMConfig {
    /// Maximum iterations for EM algorithm
    pub max_iter: usize,
    /// Convergence tolerance for log-likelihood (EM stops when the absolute
    /// improvement falls below this value)
    pub tolerance: f64,
    /// Relative tolerance for parameter changes
    pub param_tolerance: f64,
    /// Covariance type
    pub covariance_type: CovarianceType,
    /// Regularization for covariance matrices (added to the diagonal to keep
    /// them positive definite)
    pub reg_covar: f64,
    /// Initialization method
    pub init_method: InitializationMethod,
    /// Number of initialization runs (best result selected)
    pub n_init: usize,
    /// Random seed (deterministic initialization when `Some`)
    pub seed: Option<u64>,
    /// Enable parallel processing
    pub parallel: bool,
    /// Enable SIMD optimizations
    pub use_simd: bool,
    /// Warm start (use existing parameters if available)
    pub warm_start: bool,
    /// Enable robust EM (outlier detection)
    pub robust_em: bool,
    /// Outlier threshold for robust EM
    pub outlier_threshold: f64,
    /// Enable early stopping based on validation likelihood
    pub early_stopping: bool,
    /// Validation fraction for early stopping
    pub validation_fraction: f64,
    /// Patience for early stopping
    pub patience: usize,
}
73
/// Covariance matrix types
///
/// Determines how each component's covariance is constrained during the
/// M-step: `Diagonal` zeroes off-diagonal entries and `Spherical` collapses
/// to an isotropic matrix. The remaining variants currently keep the full
/// matrix (the `Factor` / `Constrained` structures are not yet enforced).
#[derive(Debug, Clone, PartialEq)]
pub enum CovarianceType {
    /// Full covariance matrices
    Full,
    /// Diagonal covariance matrices
    Diagonal,
    /// Tied covariance (same for all components)
    Tied,
    /// Spherical covariance (isotropic)
    Spherical,
    /// Factor analysis covariance (low-rank + diagonal)
    Factor {
        /// Number of factors
        n_factors: usize,
    },
    /// Constrained covariance with specific structure
    Constrained {
        /// Constraint type
        constraint: CovarianceConstraint,
    },
}
96
/// Covariance constraints
///
/// NOTE(review): these constraints are declared but are not yet consulted by
/// the EM implementation in this module — confirm before relying on them.
#[derive(Debug, Clone, PartialEq)]
pub enum CovarianceConstraint {
    /// Minimum eigenvalue constraint
    MinEigenvalue(f64),
    /// Maximum condition number
    MaxCondition(f64),
    /// Sparsity pattern given as (row, column) index pairs
    Sparse(Vec<(usize, usize)>),
}
107
/// Initialization methods
///
/// Only `Random` and `KMeansPlus` are currently implemented by
/// `GaussianMixtureModel::initialize_means`; selecting any other variant
/// returns an `InvalidArgument` error at fit time.
#[derive(Debug, Clone, PartialEq)]
pub enum InitializationMethod {
    /// Random initialization (means drawn from the data points)
    Random,
    /// K-means++ initialization
    KMeansPlus,
    /// K-means with multiple runs
    KMeans {
        /// Number of k-means runs
        n_runs: usize,
    },
    /// Furthest-first initialization
    FurthestFirst,
    /// User-provided parameters
    Custom,
    /// Quantile-based initialization
    Quantile,
    /// PCA-based initialization
    PCA,
    /// Spectral clustering initialization
    Spectral,
}
131
/// Advanced GMM parameters with diagnostics
///
/// Produced by `GaussianMixtureModel::fit`. The weight/mean/covariance fields
/// are the fitted mixture; several diagnostic fields (`model_selection`,
/// `component_diagnostics`) are currently filled with placeholder values by
/// `fit`.
#[derive(Debug, Clone)]
pub struct GMMParameters<F> {
    /// Component weights (mixing coefficients), length K, summing to ~1
    pub weights: Array1<F>,
    /// Component means, shape (K, n_features)
    pub means: Array2<F>,
    /// Component covariances, one (n_features, n_features) matrix per component
    pub covariances: Vec<Array2<F>>,
    /// Log-likelihood of the training data at the final iteration
    pub log_likelihood: F,
    /// Number of iterations to convergence
    pub n_iter: usize,
    /// Converged flag
    pub converged: bool,
    /// Convergence reason
    pub convergence_reason: ConvergenceReason,
    /// Model selection criteria
    pub model_selection: ModelSelectionCriteria<F>,
    /// Component diagnostics
    pub component_diagnostics: Vec<ComponentDiagnostics<F>>,
    /// Outlier scores (if robust EM was used)
    pub outlier_scores: Option<Array1<F>>,
    /// Responsibility matrix for training data
    pub responsibilities: Option<Array2<F>>,
    /// Parameter change history
    pub parameter_history: Vec<ParameterSnapshot<F>>,
}
160
/// Convergence reasons
///
/// The current `fit` implementation only ever reports
/// `LogLikelihoodTolerance` or `MaxIterations`; the remaining variants are
/// reserved for the other stopping criteria described in the config.
#[derive(Debug, Clone, PartialEq)]
pub enum ConvergenceReason {
    /// Log-likelihood tolerance reached
    LogLikelihoodTolerance,
    /// Parameter change tolerance reached
    ParameterTolerance,
    /// Maximum iterations reached
    MaxIterations,
    /// Early stopping triggered
    EarlyStopping,
    /// Numerical instability detected
    NumericalInstability,
}
175
/// Model selection criteria
///
/// NOTE(review): when produced by `GaussianMixtureModel::fit`, all criterion
/// values are currently zero placeholders and `n_parameters` uses a
/// simplified count — do not compare models on these fields yet.
#[derive(Debug, Clone)]
pub struct ModelSelectionCriteria<F> {
    /// Akaike Information Criterion
    pub aic: F,
    /// Bayesian Information Criterion
    pub bic: F,
    /// Integrated Classification Likelihood
    pub icl: F,
    /// Hannan-Quinn Information Criterion
    pub hqic: F,
    /// Cross-validation log-likelihood
    pub cv_log_likelihood: Option<F>,
    /// Number of effective parameters
    pub n_parameters: usize,
}
192
/// Component diagnostics
///
/// Per-component health indicators. NOTE(review): `fit` currently fills these
/// with placeholder values (zeros/ones) rather than computed diagnostics.
#[derive(Debug, Clone)]
pub struct ComponentDiagnostics<F> {
    /// Effective sample size (sum of responsibilities for this component)
    pub effective_samplesize: F,
    /// Condition number of covariance
    pub condition_number: F,
    /// Determinant of covariance
    pub covariance_determinant: F,
    /// Component separation (minimum Mahalanobis distance to other components)
    pub component_separation: F,
    /// Relative weight change over iterations
    pub weight_stability: F,
}
207
/// Parameter snapshot for tracking changes
///
/// Captures the state of the EM fit at one iteration so that convergence
/// behavior can be inspected after fitting.
#[derive(Debug, Clone)]
pub struct ParameterSnapshot<F> {
    /// Iteration number
    pub iteration: usize,
    /// Log-likelihood at this iteration
    pub log_likelihood: F,
    /// Parameter change norm
    pub parameter_change: F,
    /// Weights at this iteration
    pub weights: Array1<F>,
}
220
221impl Default for GMMConfig {
222    fn default() -> Self {
223        Self {
224            max_iter: 100,
225            tolerance: 1e-3,
226            param_tolerance: 1e-4,
227            covariance_type: CovarianceType::Full,
228            reg_covar: 1e-6,
229            init_method: InitializationMethod::KMeansPlus,
230            n_init: 1,
231            seed: None,
232            parallel: true,
233            use_simd: true,
234            warm_start: false,
235            robust_em: false,
236            outlier_threshold: 0.01,
237            early_stopping: false,
238            validation_fraction: 0.1,
239            patience: 10,
240        }
241    }
242}
243
244impl<F> GaussianMixtureModel<F>
245where
246    F: Float
247        + Zero
248        + One
249        + Copy
250        + Send
251        + Sync
252        + SimdUnifiedOps
253        + FromPrimitive
254        + std::fmt::Display
255        + std::iter::Sum<F>
256        + scirs2_core::ndarray::ScalarOperand,
257{
258    /// Create new Gaussian Mixture Model
259    pub fn new(_ncomponents: usize, config: GMMConfig) -> StatsResult<Self> {
260        check_positive(_ncomponents, "_n_components")?;
261
262        Ok(Self {
263            n_components: _ncomponents,
264            config,
265            parameters: None,
266            convergence_history: Vec::new(),
267            _phantom: PhantomData,
268        })
269    }
270
271    /// Fit GMM to data using EM algorithm
272    pub fn fit(&mut self, data: &ArrayView2<F>) -> StatsResult<&GMMParameters<F>> {
273        checkarray_finite(data, "data")?;
274
275        let (n_samples, n_features) = data.dim();
276
277        if n_samples < self.n_components {
278            return Err(StatsError::InvalidArgument(format!(
279                "Number of samples ({}) must be >= number of components ({})",
280                n_samples, self.n_components
281            )));
282        }
283
284        // Initialize parameters
285        let mut weights = Array1::from_elem(
286            self.n_components,
287            F::one() / F::from(self.n_components).expect("Failed to convert to float"),
288        );
289        let mut means = self.initialize_means(data)?;
290        let mut covariances = self.initialize_covariances(data, &means)?;
291
292        let mut log_likelihood = F::neg_infinity();
293        let mut converged = false;
294        self.convergence_history.clear();
295
296        // EM algorithm
297        for _iter in 0..self.config.max_iter {
298            // E-step: compute responsibilities
299            let responsibilities = self.e_step(data, &weights, &means, &covariances)?;
300
301            // M-step: update parameters
302            let new_weights = self.m_step_weights(&responsibilities)?;
303            let new_means = self.m_step_means(data, &responsibilities)?;
304            let new_covariances = self.m_step_covariances(data, &responsibilities, &new_means)?;
305
306            // Compute log-likelihood
307            let new_log_likelihood =
308                self.compute_log_likelihood(data, &new_weights, &new_means, &new_covariances)?;
309
310            // Check for convergence
311            let improvement = new_log_likelihood - log_likelihood;
312            self.convergence_history.push(new_log_likelihood);
313
314            if improvement.abs()
315                < F::from(self.config.tolerance).expect("Failed to convert to float")
316            {
317                converged = true;
318            }
319
320            // Update parameters
321            weights = new_weights;
322            means = new_means;
323            covariances = new_covariances;
324            log_likelihood = new_log_likelihood;
325
326            if converged {
327                break;
328            }
329        }
330
331        let parameters = GMMParameters {
332            weights,
333            means,
334            covariances,
335            log_likelihood,
336            n_iter: self.convergence_history.len(),
337            converged,
338            convergence_reason: if converged {
339                ConvergenceReason::LogLikelihoodTolerance
340            } else {
341                ConvergenceReason::MaxIterations
342            },
343            model_selection: ModelSelectionCriteria {
344                aic: F::zero(),  // Placeholder - would compute AIC in full implementation
345                bic: F::zero(),  // Placeholder - would compute BIC in full implementation
346                icl: F::zero(),  // Placeholder - would compute ICL in full implementation
347                hqic: F::zero(), // Placeholder - would compute HQIC in full implementation
348                cv_log_likelihood: None,
349                n_parameters: self.n_components * 2, // Simplified parameter count
350            },
351            component_diagnostics: vec![
352                ComponentDiagnostics {
353                    effective_samplesize: F::zero(),
354                    condition_number: F::one(),
355                    covariance_determinant: F::one(),
356                    component_separation: F::zero(),
357                    weight_stability: F::zero(),
358                };
359                self.n_components
360            ],
361            outlier_scores: None,
362            responsibilities: None,
363            parameter_history: Vec::new(),
364        };
365
366        self.parameters = Some(parameters);
367        Ok(self.parameters.as_ref().expect("Operation failed"))
368    }
369
370    /// Initialize means using chosen method
371    fn initialize_means(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
372        let (n_samples_, n_features) = data.dim();
373        let mut means = Array2::zeros((self.n_components, n_features));
374
375        match self.config.init_method {
376            InitializationMethod::Random => {
377                // Random selection from data points
378                use scirs2_core::random::Random;
379                let mut init_rng = scirs2_core::random::thread_rng();
380                let mut rng = match self.config.seed {
381                    Some(seed) => Random::seed(seed),
382                    None => Random::seed(init_rng.random()),
383                };
384
385                for i in 0..self.n_components {
386                    let idx = rng.random_range(0..n_samples_);
387                    means.row_mut(i).assign(&data.row(idx));
388                }
389            }
390            InitializationMethod::KMeansPlus => {
391                // K-means++ initialization
392                means = self.kmeans_plus_plus_init(data)?;
393            }
394            InitializationMethod::Custom => {
395                // Would be provided by user in full implementation
396                return Err(StatsError::InvalidArgument(
397                    "Custom initialization not implemented".to_string(),
398                ));
399            }
400            InitializationMethod::KMeans { n_runs: _ } => {
401                // K-means with multiple runs initialization
402                return Err(StatsError::InvalidArgument(
403                    "K-means initialization not implemented".to_string(),
404                ));
405            }
406            InitializationMethod::FurthestFirst => {
407                // Furthest-first initialization
408                return Err(StatsError::InvalidArgument(
409                    "Furthest-first initialization not implemented".to_string(),
410                ));
411            }
412            InitializationMethod::Quantile => {
413                // Quantile-based initialization
414                return Err(StatsError::InvalidArgument(
415                    "Quantile initialization not implemented".to_string(),
416                ));
417            }
418            InitializationMethod::PCA => {
419                // PCA-based initialization
420                return Err(StatsError::InvalidArgument(
421                    "PCA initialization not implemented".to_string(),
422                ));
423            }
424            InitializationMethod::Spectral => {
425                // Spectral clustering initialization
426                return Err(StatsError::InvalidArgument(
427                    "Spectral initialization not implemented".to_string(),
428                ));
429            }
430        }
431
432        Ok(means)
433    }
434
435    /// K-means++ initialization
436    fn kmeans_plus_plus_init(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
437        use scirs2_core::random::Random;
438        let mut init_rng = scirs2_core::random::thread_rng();
439        let mut rng = match self.config.seed {
440            Some(seed) => Random::seed(seed),
441            None => Random::seed(init_rng.random()),
442        };
443
444        let (n_samples_, n_features) = data.dim();
445        let mut means = Array2::zeros((self.n_components, n_features));
446
447        // Choose first center randomly
448        let first_idx = rng.random_range(0..n_samples_);
449        means.row_mut(0).assign(&data.row(first_idx));
450
451        // Choose remaining centers
452        for i in 1..self.n_components {
453            let mut distances = Array1::zeros(n_samples_);
454
455            for j in 0..n_samples_ {
456                let mut min_dist = F::infinity();
457                for k in 0..i {
458                    let dist = self.squared_distance(&data.row(j), &means.row(k));
459                    min_dist = min_dist.min(dist);
460                }
461                distances[j] = min_dist;
462            }
463
464            // Choose next center with probability proportional to squared distance
465            let total_dist: F = distances.sum();
466            let mut cumsum = F::zero();
467            let threshold: F = F::from(rng.random_f64()).expect("Operation failed") * total_dist;
468
469            for j in 0..n_samples_ {
470                cumsum = cumsum + distances[j];
471                if cumsum >= threshold {
472                    means.row_mut(i).assign(&data.row(j));
473                    break;
474                }
475            }
476        }
477
478        Ok(means)
479    }
480
481    /// Initialize covariances
482    fn initialize_covariances(
483        &self,
484        data: &ArrayView2<F>,
485        _means: &Array2<F>,
486    ) -> StatsResult<Vec<Array2<F>>> {
487        let n_features = data.ncols();
488        let mut covariances = Vec::with_capacity(self.n_components);
489
490        for _i in 0..self.n_components {
491            let cov = match self.config.covariance_type {
492                CovarianceType::Full => {
493                    // Initialize as identity with regularization
494                    Array2::eye(n_features)
495                        * F::from(self.config.reg_covar).expect("Failed to convert to float")
496                }
497                CovarianceType::Diagonal => {
498                    // Diagonal covariance
499                    Array2::eye(n_features)
500                        * F::from(self.config.reg_covar).expect("Failed to convert to float")
501                }
502                CovarianceType::Tied => {
503                    // All components share the same covariance
504                    Array2::eye(n_features)
505                        * F::from(self.config.reg_covar).expect("Failed to convert to float")
506                }
507                CovarianceType::Spherical => {
508                    // Isotropic covariance
509                    Array2::eye(n_features)
510                        * F::from(self.config.reg_covar).expect("Failed to convert to float")
511                }
512                CovarianceType::Factor { .. } => {
513                    // Factor analysis covariance - simplified to identity for now
514                    Array2::eye(n_features)
515                        * F::from(self.config.reg_covar).expect("Failed to convert to float")
516                }
517                CovarianceType::Constrained { .. } => {
518                    // Constrained covariance - simplified to identity for now
519                    Array2::eye(n_features)
520                        * F::from(self.config.reg_covar).expect("Failed to convert to float")
521                }
522            };
523            covariances.push(cov);
524        }
525
526        Ok(covariances)
527    }
528
529    /// E-step: compute responsibilities
530    fn e_step(
531        &self,
532        data: &ArrayView2<F>,
533        weights: &Array1<F>,
534        means: &Array2<F>,
535        covariances: &[Array2<F>],
536    ) -> StatsResult<Array2<F>> {
537        let n_samples_ = data.shape()[0];
538        let mut responsibilities = Array2::zeros((n_samples_, self.n_components));
539
540        for i in 0..n_samples_ {
541            let sample = data.row(i);
542            let mut log_probs = Array1::zeros(self.n_components);
543
544            // Compute log probabilities for each component
545            for k in 0..self.n_components {
546                let mean = means.row(k);
547                let log_prob = self.log_multivariate_normal_pdf(&sample, &mean, &covariances[k])?;
548                log_probs[k] = weights[k].ln() + log_prob;
549            }
550
551            // Normalize using log-sum-exp trick
552            let max_log_prob = log_probs.iter().copied().fold(F::neg_infinity(), F::max);
553            let log_sum_exp =
554                (log_probs.mapv(|x| (x - max_log_prob).exp()).sum()).ln() + max_log_prob;
555
556            for k in 0..self.n_components {
557                responsibilities[[i, k]] = (log_probs[k] - log_sum_exp).exp();
558            }
559        }
560
561        Ok(responsibilities)
562    }
563
564    /// M-step: update weights
565    fn m_step_weights(&self, responsibilities: &Array2<F>) -> StatsResult<Array1<F>> {
566        let n_samples_ = responsibilities.nrows();
567        let mut weights = Array1::zeros(self.n_components);
568
569        for k in 0..self.n_components {
570            weights[k] = responsibilities.column(k).sum()
571                / F::from(n_samples_).expect("Failed to convert to float");
572        }
573
574        Ok(weights)
575    }
576
577    /// M-step: update means
578    fn m_step_means(
579        &self,
580        data: &ArrayView2<F>,
581        responsibilities: &Array2<F>,
582    ) -> StatsResult<Array2<F>> {
583        let (_, n_features) = data.dim();
584        let mut means = Array2::zeros((self.n_components, n_features));
585
586        for k in 0..self.n_components {
587            let resp_sum = responsibilities.column(k).sum();
588
589            if resp_sum > F::from(1e-10).expect("Failed to convert constant to float") {
590                for j in 0..n_features {
591                    let weighted_sum = data
592                        .column(j)
593                        .iter()
594                        .zip(responsibilities.column(k).iter())
595                        .map(|(&x, &r)| x * r)
596                        .sum::<F>();
597                    means[[k, j]] = weighted_sum / resp_sum;
598                }
599            }
600        }
601
602        Ok(means)
603    }
604
605    /// M-step: update covariances
606    fn m_step_covariances(
607        &self,
608        data: &ArrayView2<F>,
609        responsibilities: &Array2<F>,
610        means: &Array2<F>,
611    ) -> StatsResult<Vec<Array2<F>>> {
612        let (n_samples_, n_features) = data.dim();
613        let mut covariances = Vec::with_capacity(self.n_components);
614
615        for k in 0..self.n_components {
616            let resp_sum = responsibilities.column(k).sum();
617            let mean_k = means.row(k);
618
619            let mut cov = Array2::zeros((n_features, n_features));
620
621            if resp_sum > F::from(1e-10).expect("Failed to convert constant to float") {
622                for i in 0..n_samples_ {
623                    let diff = &data.row(i) - &mean_k;
624                    let resp = responsibilities[[i, k]];
625
626                    for j in 0..n_features {
627                        for l in 0..n_features {
628                            cov[[j, l]] = cov[[j, l]] + resp * diff[j] * diff[l];
629                        }
630                    }
631                }
632
633                cov = cov / resp_sum;
634            }
635
636            // Add regularization
637            for i in 0..n_features {
638                cov[[i, i]] = cov[[i, i]]
639                    + F::from(self.config.reg_covar).expect("Failed to convert to float");
640            }
641
642            // Apply covariance type constraints
643            match self.config.covariance_type {
644                CovarianceType::Diagonal => {
645                    // Keep only diagonal elements
646                    for i in 0..n_features {
647                        for j in 0..n_features {
648                            if i != j {
649                                cov[[i, j]] = F::zero();
650                            }
651                        }
652                    }
653                }
654                CovarianceType::Spherical => {
655                    // Make isotropic
656                    let trace =
657                        cov.diag().sum() / F::from(n_features).expect("Failed to convert to float");
658                    cov = Array2::eye(n_features) * trace;
659                }
660                _ => {} // Full and Tied types keep the full covariance
661            }
662
663            covariances.push(cov);
664        }
665
666        Ok(covariances)
667    }
668
669    /// Compute multivariate normal log PDF
670    fn log_multivariate_normal_pdf(
671        &self,
672        x: &ArrayView1<F>,
673        mean: &ArrayView1<F>,
674        cov: &Array2<F>,
675    ) -> StatsResult<F> {
676        let d = x.len();
677        let diff = x - mean;
678
679        // Compute log determinant and inverse of covariance
680        let cov_f64 = cov.mapv(|x| x.to_f64().expect("Operation failed"));
681        let det = scirs2_linalg::det(&cov_f64.view(), None).map_err(|e| {
682            StatsError::ComputationError(format!("Determinant computation failed: {}", e))
683        })?;
684
685        if det <= 0.0 {
686            return Ok(F::neg_infinity());
687        }
688
689        let log_det = det.ln();
690        let cov_inv = scirs2_linalg::inv(&cov_f64.view(), None)
691            .map_err(|e| StatsError::ComputationError(format!("Matrix inversion failed: {}", e)))?;
692
693        // Compute quadratic form
694        let diff_f64 = diff.mapv(|x| x.to_f64().expect("Operation failed"));
695        let quad_form = diff_f64.dot(&cov_inv.dot(&diff_f64));
696
697        let log_pdf = -0.5 * (d as f64 * (2.0 * std::f64::consts::PI).ln() + log_det + quad_form);
698        Ok(F::from(log_pdf).expect("Failed to convert to float"))
699    }
700
701    /// Compute log-likelihood
702    fn compute_log_likelihood(
703        &self,
704        data: &ArrayView2<F>,
705        weights: &Array1<F>,
706        means: &Array2<F>,
707        covariances: &[Array2<F>],
708    ) -> StatsResult<F> {
709        let n_samples_ = data.nrows();
710        let mut total_log_likelihood = F::zero();
711
712        for i in 0..n_samples_ {
713            let sample = data.row(i);
714            let mut log_probs = Array1::zeros(self.n_components);
715
716            for k in 0..self.n_components {
717                let mean = means.row(k);
718                let log_prob = self.log_multivariate_normal_pdf(&sample, &mean, &covariances[k])?;
719                log_probs[k] = weights[k].ln() + log_prob;
720            }
721
722            // Log-sum-exp
723            let max_log_prob = log_probs.iter().copied().fold(F::neg_infinity(), F::max);
724            let log_sum_exp =
725                (log_probs.mapv(|x| (x - max_log_prob).exp()).sum()).ln() + max_log_prob;
726
727            total_log_likelihood = total_log_likelihood + log_sum_exp;
728        }
729
730        Ok(total_log_likelihood)
731    }
732
733    /// Compute squared Euclidean distance
734    fn squared_distance(&self, a: &ArrayView1<F>, b: &ArrayView1<F>) -> F {
735        a.iter()
736            .zip(b.iter())
737            .map(|(&x, &y)| (x - y) * (x - y))
738            .sum()
739    }
740
741    /// Predict cluster assignments
742    pub fn predict(&self, data: &ArrayView2<F>) -> StatsResult<Array1<usize>> {
743        let params = self.parameters.as_ref().ok_or_else(|| {
744            StatsError::InvalidArgument("Model must be fitted before prediction".to_string())
745        })?;
746
747        let responsibilities =
748            self.e_step(data, &params.weights, &params.means, &params.covariances)?;
749
750        let mut predictions = Array1::zeros(data.nrows());
751        for i in 0..data.nrows() {
752            let mut max_resp = F::neg_infinity();
753            let mut best_component = 0;
754
755            for k in 0..self.n_components {
756                if responsibilities[[i, k]] > max_resp {
757                    max_resp = responsibilities[[i, k]];
758                    best_component = k;
759                }
760            }
761
762            predictions[i] = best_component;
763        }
764
765        Ok(predictions)
766    }
767
768    /// Compute probability density for new data points
769    pub fn score_samples(&self, data: &ArrayView2<F>) -> StatsResult<Array1<F>> {
770        let params = self.parameters.as_ref().ok_or_else(|| {
771            StatsError::InvalidArgument("Model must be fitted before scoring".to_string())
772        })?;
773
774        let mut scores = Array1::zeros(data.nrows());
775
776        for i in 0..data.nrows() {
777            let sample = data.row(i);
778            let mut log_probs = Array1::zeros(self.n_components);
779
780            for k in 0..self.n_components {
781                let mean = params.means.row(k);
782                let log_prob =
783                    self.log_multivariate_normal_pdf(&sample, &mean, &params.covariances[k])?;
784                log_probs[k] = params.weights[k].ln() + log_prob;
785            }
786
787            // Log-sum-exp
788            let max_log_prob = log_probs.iter().copied().fold(F::neg_infinity(), F::max);
789            let log_sum_exp =
790                (log_probs.mapv(|x| (x - max_log_prob).exp()).sum()).ln() + max_log_prob;
791
792            scores[i] = log_sum_exp.exp();
793        }
794
795        Ok(scores)
796    }
797}
798
/// Kernel Density Estimation
///
/// Non-parametric density estimator. `fit` stores the training data (and may
/// replace `bandwidth` when an automatic selection method is configured).
pub struct KernelDensityEstimator<F> {
    /// Kernel type
    pub kernel: KernelType,
    /// Bandwidth (smoothing parameter); overwritten by `fit` unless
    /// `config.bandwidth_method` is `Fixed`
    pub bandwidth: F,
    /// Configuration
    pub config: KDEConfig,
    /// Training data; `None` until `fit` has been called
    pub trainingdata: Option<Array2<F>>,
    // Ties the scalar type `F` to the struct; carries no runtime data.
    _phantom: PhantomData<F>,
}
811
/// Kernel types for KDE
///
/// Selects the kernel function used to weight neighboring training points
/// when evaluating the density estimate.
#[derive(Debug, Clone, PartialEq)]
pub enum KernelType {
    /// Gaussian kernel
    Gaussian,
    /// Epanechnikov kernel
    Epanechnikov,
    /// Uniform kernel
    Uniform,
    /// Triangular kernel
    Triangular,
    /// Cosine kernel
    Cosine,
}
826
/// KDE configuration
///
/// Execution and bandwidth-selection settings for [`KernelDensityEstimator`].
/// `KDEConfig::default()` selects Scott's rule with parallel/SIMD enabled.
#[derive(Debug, Clone)]
pub struct KDEConfig {
    /// Bandwidth selection method
    pub bandwidth_method: BandwidthMethod,
    /// Enable parallel processing
    pub parallel: bool,
    /// Use SIMD optimizations
    pub use_simd: bool,
}
837
/// Bandwidth selection methods
///
/// Any variant other than `Fixed` causes `fit` to recompute the bandwidth
/// automatically from the training data.
#[derive(Debug, Clone, PartialEq)]
pub enum BandwidthMethod {
    /// Fixed bandwidth (user-specified)
    Fixed,
    /// Scott's rule of thumb
    Scott,
    /// Silverman's rule of thumb
    Silverman,
    /// Cross-validation
    CrossValidation,
}
850
851impl Default for KDEConfig {
852    fn default() -> Self {
853        Self {
854            bandwidth_method: BandwidthMethod::Scott,
855            parallel: true,
856            use_simd: true,
857        }
858    }
859}
860
861impl<F> KernelDensityEstimator<F>
862where
863    F: Float
864        + Zero
865        + One
866        + Copy
867        + Send
868        + Sync
869        + SimdUnifiedOps
870        + FromPrimitive
871        + std::fmt::Display
872        + std::iter::Sum<F>
873        + scirs2_core::ndarray::ScalarOperand,
874{
875    /// Create new KDE
876    pub fn new(kernel: KernelType, bandwidth: F, config: KDEConfig) -> Self {
877        Self {
878            kernel,
879            bandwidth,
880            config,
881            trainingdata: None,
882            _phantom: PhantomData,
883        }
884    }
885
886    /// Fit KDE to data
887    pub fn fit(&mut self, data: &ArrayView2<F>) -> StatsResult<()> {
888        checkarray_finite(data, "data")?;
889
890        if data.is_empty() {
891            return Err(StatsError::InvalidArgument(
892                "Data cannot be empty".to_string(),
893            ));
894        }
895
896        // Update bandwidth if using automatic selection
897        if self.config.bandwidth_method != BandwidthMethod::Fixed {
898            self.bandwidth = self.select_bandwidth_scalar(data)?;
899        }
900
901        self.trainingdata = Some(data.to_owned());
902        Ok(())
903    }
904
905    /// Select scalar bandwidth automatically
906    fn select_bandwidth_scalar(&self, data: &ArrayView2<F>) -> StatsResult<F> {
907        let (n, d) = data.dim();
908
909        match self.config.bandwidth_method {
910            BandwidthMethod::Scott => {
911                // Scott's rule: h = n^(-1/(d+4))
912                let h = F::from(n as f64)
913                    .expect("Operation failed")
914                    .powf(F::from(-1.0 / (d as f64 + 4.0)).expect("Operation failed"));
915                Ok(h)
916            }
917            BandwidthMethod::Silverman => {
918                // Silverman's rule: h = (4/(d+2))^(1/(d+4)) * n^(-1/(d+4))
919                let factor = F::from(4.0 / (d as f64 + 2.0))
920                    .expect("Operation failed")
921                    .powf(F::from(1.0 / (d as f64 + 4.0)).expect("Operation failed"));
922                let n_factor = F::from(n as f64)
923                    .expect("Operation failed")
924                    .powf(F::from(-1.0 / (d as f64 + 4.0)).expect("Operation failed"));
925                Ok(factor * n_factor)
926            }
927            BandwidthMethod::CrossValidation => {
928                // Simplified cross-validation
929                self.cross_validation_bandwidth(data)
930            }
931            BandwidthMethod::Fixed => Ok(self.bandwidth),
932        }
933    }
934
935    /// Cross-validation bandwidth selection (simplified)
936    fn cross_validation_bandwidth(&self, data: &ArrayView2<F>) -> StatsResult<F> {
937        // Simplified implementation - full CV would try multiple bandwidths
938        let (n, d) = data.dim();
939        let h = F::from(n as f64)
940            .expect("Operation failed")
941            .powf(F::from(-1.0 / (d as f64 + 4.0)).expect("Operation failed"));
942        Ok(h)
943    }
944
945    /// Evaluate density at given points
946    pub fn score_samples(&self, points: &ArrayView2<F>) -> StatsResult<Array1<F>> {
947        let trainingdata = self.trainingdata.as_ref().ok_or_else(|| {
948            StatsError::InvalidArgument("KDE must be fitted before evaluation".to_string())
949        })?;
950
951        checkarray_finite(points, "points")?;
952
953        if points.ncols() != trainingdata.ncols() {
954            return Err(StatsError::DimensionMismatch(format!(
955                "Points dimension ({}) must match training data dimension ({})",
956                points.ncols(),
957                trainingdata.ncols()
958            )));
959        }
960
961        let n_points = points.nrows();
962        let n_train = trainingdata.nrows();
963        let mut densities = Array1::zeros(n_points);
964
965        for i in 0..n_points {
966            let point = points.row(i);
967            let mut density = F::zero();
968
969            for j in 0..n_train {
970                let train_point = trainingdata.row(j);
971                let distance = self.compute_distance(&point, &train_point);
972                let bandwidth_val = self.bandwidth;
973                let kernel_value = self.evaluate_kernel(distance / bandwidth_val);
974                density = density + kernel_value;
975            }
976
977            // Normalize
978            let bandwidth_val = self.bandwidth;
979            let normalization = F::from(n_train as f64).expect("Failed to convert to float")
980                * bandwidth_val.powf(F::from(trainingdata.ncols()).expect("Operation failed"));
981            densities[i] = density / normalization;
982        }
983
984        Ok(densities)
985    }
986
987    /// Compute distance between two points
988    fn compute_distance(&self, a: &ArrayView1<F>, b: &ArrayView1<F>) -> F {
989        // Euclidean distance
990        a.iter()
991            .zip(b.iter())
992            .map(|(&x, &y)| (x - y) * (x - y))
993            .sum::<F>()
994            .sqrt()
995    }
996
997    /// Evaluate kernel function
998    fn evaluate_kernel(&self, u: F) -> F {
999        match self.kernel {
1000            KernelType::Gaussian => {
1001                // (2π)^(-1/2) * exp(-u²/2)
1002                let coeff =
1003                    F::from(1.0 / (2.0 * std::f64::consts::PI).sqrt()).expect("Operation failed");
1004                coeff * (-u * u / F::from(2.0).expect("Failed to convert constant to float")).exp()
1005            }
1006            KernelType::Epanechnikov => {
1007                // (3/4) * (1 - u²) for |u| ≤ 1
1008                if u.abs() <= F::one() {
1009                    F::from(0.75).expect("Failed to convert constant to float") * (F::one() - u * u)
1010                } else {
1011                    F::zero()
1012                }
1013            }
1014            KernelType::Uniform => {
1015                // 1/2 for |u| ≤ 1
1016                if u.abs() <= F::one() {
1017                    F::from(0.5).expect("Failed to convert constant to float")
1018                } else {
1019                    F::zero()
1020                }
1021            }
1022            KernelType::Triangular => {
1023                // (1 - |u|) for |u| ≤ 1
1024                if u.abs() <= F::one() {
1025                    F::one() - u.abs()
1026                } else {
1027                    F::zero()
1028                }
1029            }
1030            KernelType::Cosine => {
1031                // (π/4) * cos(πu/2) for |u| ≤ 1
1032                if u.abs() <= F::one() {
1033                    let coeff =
1034                        F::from(std::f64::consts::PI / 4.0).expect("Failed to convert to float");
1035                    let arg = F::from(std::f64::consts::PI).expect("Failed to convert to float")
1036                        * u
1037                        / F::from(2.0).expect("Failed to convert constant to float");
1038                    coeff * arg.cos()
1039                } else {
1040                    F::zero()
1041                }
1042            }
1043        }
1044    }
1045}
1046
1047/// Convenience functions
1048#[allow(dead_code)]
1049pub fn gaussian_mixture_model<F>(
1050    data: &ArrayView2<F>,
1051    n_components: usize,
1052    config: Option<GMMConfig>,
1053) -> StatsResult<GMMParameters<F>>
1054where
1055    F: Float
1056        + Zero
1057        + One
1058        + Copy
1059        + Send
1060        + Sync
1061        + SimdUnifiedOps
1062        + FromPrimitive
1063        + std::fmt::Display
1064        + std::iter::Sum<F>
1065        + scirs2_core::ndarray::ScalarOperand,
1066{
1067    let config = config.unwrap_or_default();
1068    let mut gmm = GaussianMixtureModel::new(n_components, config)?;
1069    Ok(gmm.fit(data)?.clone())
1070}
1071
1072#[allow(dead_code)]
1073pub fn kernel_density_estimation<F>(
1074    data: &ArrayView2<F>,
1075    points: &ArrayView2<F>,
1076    kernel: Option<KernelType>,
1077    bandwidth: Option<F>,
1078) -> StatsResult<Array1<F>>
1079where
1080    F: Float
1081        + Zero
1082        + One
1083        + Copy
1084        + Send
1085        + Sync
1086        + SimdUnifiedOps
1087        + FromPrimitive
1088        + std::fmt::Display
1089        + std::iter::Sum<F>
1090        + scirs2_core::ndarray::ScalarOperand,
1091{
1092    let kernel = kernel.unwrap_or(KernelType::Gaussian);
1093    let bandwidth = bandwidth.unwrap_or_else(|| {
1094        // Use Scott's rule as default
1095        let n = data.nrows();
1096        let d = data.ncols();
1097        F::from(n as f64)
1098            .expect("Operation failed")
1099            .powf(F::from(-1.0 / (d as f64 + 4.0)).expect("Operation failed"))
1100    });
1101
1102    let mut kde = KernelDensityEstimator::new(kernel, bandwidth, KDEConfig::default());
1103    kde.fit(data)?;
1104    kde.score_samples(points)
1105}
1106
1107/// Advanced model selection for GMM
1108#[allow(dead_code)]
1109pub fn gmm_model_selection<F>(
1110    data: &ArrayView2<F>,
1111    min_components: usize,
1112    max_components: usize,
1113    config: Option<GMMConfig>,
1114) -> StatsResult<(usize, GMMParameters<F>)>
1115where
1116    F: Float
1117        + Zero
1118        + One
1119        + Copy
1120        + Send
1121        + Sync
1122        + SimdUnifiedOps
1123        + FromPrimitive
1124        + std::fmt::Display
1125        + std::iter::Sum<F>
1126        + scirs2_core::ndarray::ScalarOperand,
1127{
1128    let config = config.unwrap_or_default();
1129    let mut best_n_components = min_components;
1130    let mut best_bic = F::infinity();
1131    let mut best_params: Option<GMMParameters<F>> = None;
1132
1133    for n_comp in min_components..=max_components {
1134        let mut gmm = GaussianMixtureModel::new(n_comp, config.clone())?;
1135        let params = gmm.fit(data)?;
1136
1137        if params.model_selection.bic < best_bic {
1138            best_bic = params.model_selection.bic;
1139            best_n_components = n_comp;
1140            best_params = Some(params.clone());
1141        }
1142    }
1143
1144    Ok((best_n_components, best_params.expect("Operation failed")))
1145}
1146
/// Robust Gaussian Mixture Model with outlier detection
///
/// Wraps a `GaussianMixtureModel` fitted with robust EM enabled and adds
/// score-based outlier flagging via `detect_outliers`.
pub struct RobustGMM<F> {
    /// Base GMM
    pub gmm: GaussianMixtureModel<F>,
    /// Outlier detection threshold (also mirrored into the GMM config as f64)
    pub outlier_threshold: F,
    /// Contamination rate (expected fraction of outliers); used to pick the
    /// adaptive score threshold in `detect_outliers`
    pub contamination: F,
    // PhantomData is technically redundant (F appears in other fields) but
    // is kept as part of the existing layout/API.
    _phantom: PhantomData<F>,
}
1157
1158impl<F> RobustGMM<F>
1159where
1160    F: Float
1161        + Zero
1162        + One
1163        + Copy
1164        + Send
1165        + Sync
1166        + SimdUnifiedOps
1167        + FromPrimitive
1168        + std::fmt::Display
1169        + std::iter::Sum<F>
1170        + scirs2_core::ndarray::ScalarOperand,
1171{
1172    /// Create new Robust GMM
1173    pub fn new(
1174        n_components: usize,
1175        outlier_threshold: F,
1176        contamination: F,
1177        mut config: GMMConfig,
1178    ) -> StatsResult<Self> {
1179        // Enable robust EM
1180        config.robust_em = true;
1181        config.outlier_threshold = outlier_threshold.to_f64().expect("Operation failed");
1182
1183        let gmm = GaussianMixtureModel::new(n_components, config)?;
1184
1185        Ok(Self {
1186            gmm,
1187            outlier_threshold,
1188            contamination,
1189            _phantom: PhantomData,
1190        })
1191    }
1192
1193    /// Fit robust GMM with outlier detection
1194    pub fn fit(&mut self, data: &ArrayView2<F>) -> StatsResult<&GMMParameters<F>> {
1195        // First fit the regular GMM
1196        self.gmm.fit(data)?;
1197
1198        // Compute outlier scores for robust EM
1199        let outlier_scores = self.compute_outlier_scores(data)?;
1200
1201        // Update parameters with outlier scores
1202        if let Some(ref mut params) = self.gmm.parameters {
1203            params.outlier_scores = Some(outlier_scores);
1204        }
1205
1206        Ok(self.gmm.parameters.as_ref().expect("Operation failed"))
1207    }
1208
1209    /// Compute outlier scores based on negative log-likelihood
1210    fn compute_outlier_scores(&self, data: &ArrayView2<F>) -> StatsResult<Array1<F>> {
1211        let params =
1212            self.gmm.parameters.as_ref().ok_or_else(|| {
1213                StatsError::InvalidArgument("Model must be fitted first".to_string())
1214            })?;
1215
1216        let (n_samples, _) = data.dim();
1217        let mut outlier_scores = Array1::zeros(n_samples);
1218
1219        for (i, sample) in data.rows().into_iter().enumerate() {
1220            // Compute log-likelihood for this sample under the fitted model
1221            let mut log_likelihood = F::neg_infinity();
1222
1223            for j in 0..self.gmm.n_components {
1224                let weight = params.weights[j];
1225                let mean = params.means.row(j);
1226                let cov = &params.covariances[j];
1227
1228                // Compute log probability density for this component
1229                let log_prob = self.log_multivariate_normal_pdf(&sample, &mean, cov)?;
1230                let weighted_log_prob = weight.ln() + log_prob;
1231
1232                // Log-sum-exp to combine components
1233                if log_likelihood == F::neg_infinity() {
1234                    log_likelihood = weighted_log_prob;
1235                } else {
1236                    log_likelihood =
1237                        log_likelihood + (weighted_log_prob - log_likelihood).exp().ln_1p();
1238                }
1239            }
1240
1241            // Outlier score is negative log-likelihood (higher = more outlier-like)
1242            outlier_scores[i] = -log_likelihood;
1243        }
1244
1245        Ok(outlier_scores)
1246    }
1247
1248    /// Compute log probability density function for multivariate normal distribution
1249    fn log_multivariate_normal_pdf(
1250        &self,
1251        x: &ArrayView1<F>,
1252        mean: &ArrayView1<F>,
1253        cov: &Array2<F>,
1254    ) -> StatsResult<F> {
1255        let diff = x - mean;
1256        let k = F::from(x.len()).expect("Operation failed");
1257
1258        // Simple case: assume diagonal covariance for numerical stability
1259        let mut log_prob = F::zero();
1260        let pi = F::from(std::f64::consts::PI).expect("Failed to convert to float");
1261
1262        for i in 0..x.len() {
1263            let variance = cov[[i, i]];
1264            if variance <= F::zero() {
1265                return Err(StatsError::ComputationError(
1266                    "Invalid covariance".to_string(),
1267                ));
1268            }
1269
1270            let term = diff[i] * diff[i] / variance;
1271            log_prob = log_prob
1272                - (F::from(0.5).expect("Failed to convert constant to float") * term)
1273                - (F::from(0.5).expect("Failed to convert constant to float") * variance.ln())
1274                - (F::from(0.5).expect("Failed to convert constant to float")
1275                    * (F::from(2.0).expect("Failed to convert constant to float") * pi).ln());
1276        }
1277
1278        Ok(log_prob)
1279    }
1280
1281    /// Detect outliers in data
1282    pub fn detect_outliers(&self, data: &ArrayView2<F>) -> StatsResult<Array1<bool>> {
1283        let params = self.gmm.parameters.as_ref().ok_or_else(|| {
1284            StatsError::InvalidArgument("Model must be fitted before outlier detection".to_string())
1285        })?;
1286
1287        let outlier_scores = params.outlier_scores.as_ref().ok_or_else(|| {
1288            StatsError::InvalidArgument(
1289                "Robust EM must be enabled for outlier detection".to_string(),
1290            )
1291        })?;
1292
1293        // Determine outlier threshold based on contamination rate
1294        let mut sorted_scores = outlier_scores.to_owned();
1295        sorted_scores
1296            .as_slice_mut()
1297            .expect("Operation failed")
1298            .sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
1299
1300        let threshold_idx = ((F::one() - self.contamination)
1301            * F::from(sorted_scores.len()).expect("Operation failed"))
1302        .to_usize()
1303        .expect("Operation failed")
1304        .min(sorted_scores.len() - 1);
1305        let adaptive_threshold = sorted_scores[threshold_idx];
1306
1307        let outliers = outlier_scores.mapv(|score| score > adaptive_threshold);
1308        Ok(outliers)
1309    }
1310}
1311
/// Streaming/Online Gaussian Mixture Model
///
/// Maintains exponentially-decayed running estimates of the mixture
/// parameters, refreshed batch-by-batch via `partial_fit`.
pub struct StreamingGMM<F> {
    /// Base GMM
    pub gmm: GaussianMixtureModel<F>,
    /// Learning rate for online updates (weight given to each new batch)
    pub learning_rate: F,
    /// Decay factor applied to the accumulated statistics on each update
    pub decay_factor: F,
    /// Number of samples processed (0 until the first batch is fitted)
    pub n_samples_seen: usize,
    /// Running component means (populated after the first `partial_fit`)
    pub running_means: Option<Array2<F>>,
    // NOTE(review): running covariances are captured from the first fit but
    // never refreshed by `online_update` — only weights and means are
    // updated online.
    pub running_covariances: Option<Vec<Array2<F>>>,
    /// Running mixture weights (renormalized after every update)
    pub running_weights: Option<Array1<F>>,
    _phantom: PhantomData<F>,
}
1328
impl<F> StreamingGMM<F>
where
    F: Float
        + Zero
        + One
        + Copy
        + Send
        + Sync
        + SimdUnifiedOps
        + FromPrimitive
        + std::fmt::Display
        + std::iter::Sum<F>
        + scirs2_core::ndarray::ScalarOperand,
{
    /// Create new Streaming GMM
    ///
    /// `learning_rate` weights each incoming batch's statistics;
    /// `decay_factor` down-weights the accumulated running statistics on
    /// every online update.
    ///
    /// # Errors
    /// Propagates construction errors from `GaussianMixtureModel::new`.
    pub fn new(
        n_components: usize,
        learning_rate: F,
        decay_factor: F,
        config: GMMConfig,
    ) -> StatsResult<Self> {
        let gmm = GaussianMixtureModel::new(n_components, config)?;

        Ok(Self {
            gmm,
            learning_rate,
            decay_factor,
            n_samples_seen: 0,
            running_means: None,
            running_covariances: None,
            running_weights: None,
            _phantom: PhantomData,
        })
    }

    /// Update model with new batch of data
    ///
    /// The first batch triggers a full EM fit that seeds the running
    /// statistics; subsequent batches apply the exponential-decay online
    /// update (weights and means only).
    ///
    /// # Errors
    /// Propagates fitting/E-step errors from the underlying GMM.
    pub fn partial_fit(&mut self, batch: &ArrayView2<F>) -> StatsResult<()> {
        let batchsize = batch.nrows();

        if self.n_samples_seen == 0 {
            // Initialize with first batch
            self.gmm.fit(batch)?;
            let params = self.gmm.parameters.as_ref().expect("Operation failed");
            self.running_means = Some(params.means.clone());
            self.running_covariances = Some(params.covariances.clone());
            self.running_weights = Some(params.weights.clone());
        } else {
            // Online update
            self.online_update(batch)?;
        }

        self.n_samples_seen += batchsize;
        Ok(())
    }

    /// Perform online parameter update
    ///
    /// NOTE(review): covariances are not refreshed here — only weights and
    /// means receive the decayed update; `running_covariances` stays at its
    /// first-fit value. Confirm whether covariance tracking is intended.
    fn online_update(&mut self, batch: &ArrayView2<F>) -> StatsResult<()> {
        // Immutable borrow of the fitted parameters; ends after the batch
        // statistics below are computed (before the mutable borrows).
        let params = self.gmm.parameters.as_ref().expect("Operation failed");

        // E-step on new batch
        let responsibilities =
            self.gmm
                .e_step(batch, &params.weights, &params.means, &params.covariances)?;

        // Compute batch statistics
        let batch_weights = self.gmm.m_step_weights(&responsibilities)?;
        let batch_means = self.gmm.m_step_means(batch, &responsibilities)?;

        // Update running statistics with exponential decay
        let lr = self.learning_rate;
        let decay = self.decay_factor;

        if let (Some(ref mut r_weights), Some(ref mut r_means)) =
            (&mut self.running_weights, &mut self.running_means)
        {
            // Update weights: w = decay * w_old + lr * w_batch
            *r_weights = r_weights.mapv(|x| x * decay) + batch_weights.mapv(|x| x * lr);

            // Normalize weights so they remain a probability simplex
            let weight_sum = r_weights.sum();
            if weight_sum > F::zero() {
                *r_weights = r_weights.mapv(|x| x / weight_sum);
            }

            // Update means with the same decayed convex-style combination
            *r_means = r_means.mapv(|x| x * decay) + batch_means.mapv(|x| x * lr);
        }

        // Copy the running statistics back into the model parameters
        if let Some(params_mut) = &mut self.gmm.parameters {
            params_mut.weights = self
                .running_weights
                .as_ref()
                .expect("Operation failed")
                .clone();
            params_mut.means = self
                .running_means
                .as_ref()
                .expect("Operation failed")
                .clone();
        }

        Ok(())
    }

    /// Get current model parameters
    ///
    /// Returns `None` before the first `partial_fit` call.
    pub fn get_parameters(&self) -> Option<&GMMParameters<F>> {
        self.gmm.parameters.as_ref()
    }
}
1439
1440/// Hierarchical clustering-based mixture model initialization
1441#[allow(dead_code)]
1442pub fn hierarchical_gmm_init<F>(
1443    data: &ArrayView2<F>,
1444    n_components: usize,
1445    config: GMMConfig,
1446) -> StatsResult<GMMParameters<F>>
1447where
1448    F: Float
1449        + Zero
1450        + One
1451        + Copy
1452        + Send
1453        + Sync
1454        + SimdUnifiedOps
1455        + FromPrimitive
1456        + std::fmt::Display
1457        + std::iter::Sum<F>
1458        + scirs2_core::ndarray::ScalarOperand,
1459{
1460    // Simplified hierarchical clustering for initialization
1461    // Full implementation would use proper hierarchical clustering
1462
1463    let mut init_config = config.clone();
1464    init_config.init_method = InitializationMethod::FurthestFirst;
1465
1466    gaussian_mixture_model(data, n_components, Some(init_config))
1467}
1468
1469/// Cross-validation for GMM hyperparameter tuning
1470#[allow(dead_code)]
1471pub fn gmm_cross_validation<F>(
1472    data: &ArrayView2<F>,
1473    n_components: usize,
1474    n_folds: usize,
1475    config: GMMConfig,
1476) -> StatsResult<F>
1477where
1478    F: Float
1479        + Zero
1480        + One
1481        + Copy
1482        + Send
1483        + Sync
1484        + SimdUnifiedOps
1485        + FromPrimitive
1486        + std::fmt::Display
1487        + std::iter::Sum<F>
1488        + scirs2_core::ndarray::ScalarOperand,
1489{
1490    let (n_samples_, _) = data.dim();
1491    let foldsize = n_samples_ / n_folds;
1492    let mut cv_scores = Vec::with_capacity(n_folds);
1493
1494    for fold in 0..n_folds {
1495        let val_start = fold * foldsize;
1496        let val_end = if fold == n_folds - 1 {
1497            n_samples_
1498        } else {
1499            (fold + 1) * foldsize
1500        };
1501
1502        // Create training and validation sets
1503        let mut train_indices = Vec::new();
1504        for i in 0..n_samples_ {
1505            if i < val_start || i >= val_end {
1506                train_indices.push(i);
1507            }
1508        }
1509
1510        let traindata = Array2::from_shape_fn((train_indices.len(), data.ncols()), |(i, j)| {
1511            data[[train_indices[i], j]]
1512        });
1513
1514        let valdata = data.slice(scirs2_core::ndarray::s![val_start..val_end, ..]);
1515
1516        // Fit model on training data
1517        let mut gmm = GaussianMixtureModel::new(n_components, config.clone())?;
1518        let params = gmm.fit(&traindata.view())?.clone();
1519
1520        // Evaluate on validation data
1521        let val_likelihood = gmm.compute_log_likelihood(
1522            &valdata,
1523            &params.weights,
1524            &params.means,
1525            &params.covariances,
1526        )?;
1527
1528        cv_scores.push(val_likelihood);
1529    }
1530
1531    // Return average CV score
1532    let avg_score =
1533        cv_scores.iter().copied().sum::<F>() / F::from(cv_scores.len()).expect("Operation failed");
1534    Ok(avg_score)
1535}
1536
1537/// Performance benchmarking for mixture models
1538#[allow(dead_code)]
1539pub fn benchmark_mixture_models<F>(
1540    data: &ArrayView2<F>,
1541    methods: &[(
1542        &str,
1543        Box<dyn Fn(&ArrayView2<F>) -> StatsResult<GMMParameters<F>>>,
1544    )],
1545) -> StatsResult<Vec<(String, std::time::Duration, F)>>
1546where
1547    F: Float
1548        + Zero
1549        + One
1550        + Copy
1551        + Send
1552        + Sync
1553        + SimdUnifiedOps
1554        + FromPrimitive
1555        + std::fmt::Display
1556        + std::iter::Sum<F>
1557        + scirs2_core::ndarray::ScalarOperand,
1558{
1559    let mut results = Vec::new();
1560
1561    for (name, method) in methods {
1562        let start_time = std::time::Instant::now();
1563        let params = method(data)?;
1564        let duration = start_time.elapsed();
1565
1566        results.push((name.to_string(), duration, params.log_likelihood));
1567    }
1568
1569    Ok(results)
1570}
1571
1572#[cfg(test)]
1573#[path = "mixture_models_tests.rs"]
1574mod tests;