Skip to main content

scirs2_stats/multivariate/
discriminant_analysis.rs

//! Discriminant Analysis
//!
//! This module provides implementations of Linear Discriminant Analysis (LDA) and
//! Quadratic Discriminant Analysis (QDA) for classification and dimensionality reduction.
5
6use crate::error::{StatsError, StatsResult as Result};
7use crate::error_handling_v2::ErrorCode;
8use crate::{unified_error_handling::global_error_handler, validate_or_error};
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
10
11/// Linear Discriminant Analysis (LDA)
12///
13/// LDA is a dimensionality reduction technique that finds linear combinations of features
14/// that best separate different classes. It assumes that all classes have the same
15/// covariance structure.
16#[derive(Debug, Clone)]
17pub struct LinearDiscriminantAnalysis {
18    /// Solver type for eigenvalue decomposition
19    pub solver: LDASolver,
20    /// Whether to shrink the covariance estimate
21    pub shrinkage: Option<f64>,
22    /// Number of components to keep (None = automatic)
23    pub n_components: Option<usize>,
24    /// Prior probabilities for each class (None = empirical)
25    pub priors: Option<Array1<f64>>,
26    /// Store training fit results
27    pub store_covariance: bool,
28}
29
/// Strategy used by LDA to solve its generalized eigenvalue problem.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LDASolver {
    /// Singular value decomposition of the whitened between-class matrix; the most stable choice.
    Svd,
    /// Direct eigenvalue decomposition; usually quicker on small problems.
    Eigen,
}
38
39/// Result of Linear Discriminant Analysis
40#[derive(Debug, Clone)]
41pub struct LDAResult {
42    /// Linear discriminant coefficients (scalings)
43    pub scalings: Array2<f64>,
44    /// Intercepts for each class
45    pub intercept: Array1<f64>,
46    /// Pooled covariance matrix
47    pub covariance: Option<Array2<f64>>,
48    /// Class means
49    pub means: Array2<f64>,
50    /// Prior probabilities for each class
51    pub priors: Array1<f64>,
52    /// Class labels
53    pub classes: Array1<i32>,
54    /// Explained variance ratio for each component
55    pub explained_variance_ratio: Array1<f64>,
56    /// Number of features used for training
57    pub n_features: usize,
58}
59
60impl Default for LinearDiscriminantAnalysis {
61    fn default() -> Self {
62        Self {
63            solver: LDASolver::Svd,
64            shrinkage: None,
65            n_components: None,
66            priors: None,
67            store_covariance: true,
68        }
69    }
70}
71
72impl LinearDiscriminantAnalysis {
73    /// Create a new LDA instance
74    pub fn new() -> Self {
75        Self::default()
76    }
77
78    /// Set the solver type
79    pub fn with_solver(mut self, solver: LDASolver) -> Self {
80        self.solver = solver;
81        self
82    }
83
84    /// Set shrinkage parameter for covariance regularization
85    pub fn with_shrinkage(mut self, shrinkage: f64) -> Self {
86        self.shrinkage = Some(shrinkage);
87        self
88    }
89
90    /// Set number of components to keep
91    pub fn with_n_components(mut self, n_components: usize) -> Self {
92        self.n_components = Some(n_components);
93        self
94    }
95
96    /// Set prior probabilities
97    pub fn with_priors(mut self, priors: Array1<f64>) -> Self {
98        self.priors = Some(priors);
99        self
100    }
101
102    /// Set whether to store covariance matrix
103    pub fn with_store_covariance(mut self, store: bool) -> Self {
104        self.store_covariance = store;
105        self
106    }
107
108    /// Fit the LDA model
109    pub fn fit(&self, x: ArrayView2<f64>, y: ArrayView1<i32>) -> Result<LDAResult> {
110        let handler = global_error_handler();
111        validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "LDA fit");
112
113        let (n_samples, n_features) = x.dim();
114        let n_targets = y.len();
115
116        if n_samples != n_targets {
117            return Err(handler
118                .create_validation_error(
119                    ErrorCode::E2001,
120                    "LDA fit",
121                    "samplesize_mismatch",
122                    format!("x: {}, y: {}", n_samples, n_targets),
123                    "Number of samples in X and y must be equal",
124                )
125                .error);
126        }
127
128        if n_samples < 2 {
129            return Err(handler
130                .create_validation_error(
131                    ErrorCode::E2003,
132                    "LDA fit",
133                    "n_samples",
134                    n_samples,
135                    "LDA requires at least 2 samples",
136                )
137                .error);
138        }
139
140        // Get unique classes and validate
141        let unique_classes = self.get_unique_classes(y)?;
142        let n_classes = unique_classes.len();
143
144        if n_classes < 2 {
145            return Err(handler
146                .create_validation_error(
147                    ErrorCode::E1001,
148                    "LDA fit",
149                    "n_classes",
150                    n_classes,
151                    "LDA requires at least 2 classes",
152                )
153                .error);
154        }
155
156        if n_features >= n_samples && self.solver == LDASolver::Eigen {
157            return Err(handler
158                .create_error(
159                    ErrorCode::E1001,
160                    "LDA fit",
161                    "Use SVD solver when n_features >= n_samples for numerical stability",
162                )
163                .error);
164        }
165
166        // Compute class statistics
167        let (class_means, class_priors, class_counts) =
168            self.compute_class_statistics(x, y, &unique_classes)?;
169
170        // Compute within-class and between-class scatter matrices
171        let (sw, sb) = self.compute_scatter_matrices(x, y, &unique_classes, &class_means)?;
172
173        // Apply shrinkage if specified
174        let sw_regularized = if let Some(shrinkage) = self.shrinkage {
175            self.apply_shrinkage(&sw, shrinkage)?
176        } else {
177            sw
178        };
179
180        // Solve generalized eigenvalue problem
181        let (scalings, explained_variance_ratio) =
182            self.solve_eigenvalue_problem(&sw_regularized, &sb)?;
183
184        // Limit number of components
185        let n_components = self
186            .n_components
187            .unwrap_or(n_classes - 1)
188            .min(n_classes - 1)
189            .min(n_features);
190
191        let final_scalings = scalings
192            .slice(scirs2_core::ndarray::s![.., ..n_components])
193            .to_owned();
194        let final_explained_variance = explained_variance_ratio
195            .slice(scirs2_core::ndarray::s![..n_components])
196            .to_owned();
197
198        // Compute intercept
199        let intercept = self.compute_intercept(&class_means, &final_scalings, &class_priors)?;
200
201        Ok(LDAResult {
202            scalings: final_scalings,
203            intercept,
204            covariance: if self.store_covariance {
205                Some(sw_regularized)
206            } else {
207                None
208            },
209            means: class_means,
210            priors: class_priors,
211            classes: unique_classes,
212            explained_variance_ratio: final_explained_variance,
213            n_features,
214        })
215    }
216
217    /// Get unique classes from target array
218    fn get_unique_classes(&self, y: ArrayView1<i32>) -> Result<Array1<i32>> {
219        let mut classes = y.to_vec();
220        classes.sort_unstable();
221        classes.dedup();
222        Ok(Array1::from_vec(classes))
223    }
224
225    /// Compute class means, priors, and counts
226    fn compute_class_statistics(
227        &self,
228        x: ArrayView2<f64>,
229        y: ArrayView1<i32>,
230        classes: &Array1<i32>,
231    ) -> Result<(Array2<f64>, Array1<f64>, Array1<usize>)> {
232        let (n_samples, n_features) = x.dim();
233        let n_classes = classes.len();
234
235        let mut class_means = Array2::zeros((n_classes, n_features));
236        let mut class_counts = Array1::zeros(n_classes);
237
238        // Compute class means and counts
239        for (i, &class_label) in classes.iter().enumerate() {
240            let class_indices: Vec<_> = y
241                .iter()
242                .enumerate()
243                .filter(|(_, &label)| label == class_label)
244                .map(|(idx, _)| idx)
245                .collect();
246
247            if class_indices.is_empty() {
248                return Err(StatsError::InvalidArgument(format!(
249                    "Class {} has no samples",
250                    class_label
251                )));
252            }
253
254            class_counts[i] = class_indices.len();
255
256            // Compute mean for this class
257            let mut sum = Array1::zeros(n_features);
258            for &idx in &class_indices {
259                sum += &x.row(idx);
260            }
261            class_means
262                .row_mut(i)
263                .assign(&(sum / class_indices.len() as f64));
264        }
265
266        // Compute priors
267        let class_priors = if let Some(ref priors) = self.priors {
268            if priors.len() != n_classes {
269                return Err(StatsError::InvalidArgument(format!(
270                    "Priors length ({}) must equal number of classes ({})",
271                    priors.len(),
272                    n_classes
273                )));
274            }
275            priors.clone()
276        } else {
277            // Empirical priors
278            class_counts.mapv(|count| count as f64 / n_samples as f64)
279        };
280
281        Ok((class_means, class_priors, class_counts.mapv(|x| x)))
282    }
283
284    /// Compute within-class (Sw) and between-class (Sb) scatter matrices
285    fn compute_scatter_matrices(
286        &self,
287        x: ArrayView2<f64>,
288        y: ArrayView1<i32>,
289        classes: &Array1<i32>,
290        class_means: &Array2<f64>,
291    ) -> Result<(Array2<f64>, Array2<f64>)> {
292        let (_n_samples, n_features) = x.dim();
293        let _n_classes = classes.len();
294
295        // Overall mean
296        let overall_mean = x.mean_axis(Axis(0)).expect("Operation failed");
297
298        // Initialize scatter matrices
299        let mut sw = Array2::zeros((n_features, n_features));
300        let mut sb = Array2::zeros((n_features, n_features));
301
302        // Compute within-class scatter
303        for (class_idx, &class_label) in classes.iter().enumerate() {
304            let class_mean = class_means.row(class_idx);
305
306            for (sample_idx, &sample_label) in y.iter().enumerate() {
307                if sample_label == class_label {
308                    let sample = x.row(sample_idx);
309                    let diff = &sample - &class_mean;
310
311                    // Outer product: diff^T * diff
312                    for i in 0..n_features {
313                        for j in 0..n_features {
314                            sw[[i, j]] += diff[i] * diff[j];
315                        }
316                    }
317                }
318            }
319        }
320
321        // Compute between-class scatter
322        for (class_idx, _) in classes.iter().enumerate() {
323            let class_mean = class_means.row(class_idx);
324            let class_count = y
325                .iter()
326                .filter(|&&label| label == classes[class_idx])
327                .count() as f64;
328            let diff = &class_mean - &overall_mean;
329
330            // Weighted outer product
331            for i in 0..n_features {
332                for j in 0..n_features {
333                    sb[[i, j]] += class_count * diff[i] * diff[j];
334                }
335            }
336        }
337
338        Ok((sw, sb))
339    }
340
341    /// Apply shrinkage regularization to covariance matrix
342    fn apply_shrinkage(&self, sw: &Array2<f64>, shrinkage: f64) -> Result<Array2<f64>> {
343        let n_features = sw.nrows();
344        let trace = (0..n_features).map(|i| sw[[i, i]]).sum::<f64>();
345        let scaled_identity = Array2::eye(n_features) * (trace / n_features as f64);
346
347        Ok((1.0 - shrinkage) * sw + shrinkage * scaled_identity)
348    }
349
350    /// Solve the generalized eigenvalue problem Sb * v = λ * Sw * v
351    fn solve_eigenvalue_problem(
352        &self,
353        sw: &Array2<f64>,
354        sb: &Array2<f64>,
355    ) -> Result<(Array2<f64>, Array1<f64>)> {
356        match self.solver {
357            LDASolver::Svd => self.solve_svd(sw, sb),
358            LDASolver::Eigen => self.solve_eigen(sw, sb),
359        }
360    }
361
362    /// SVD-based solver (more numerically stable)
363    fn solve_svd(&self, sw: &Array2<f64>, sb: &Array2<f64>) -> Result<(Array2<f64>, Array1<f64>)> {
364        // Cholesky decomposition of Sw = L * L^T
365        let l = scirs2_linalg::cholesky(&sw.view(), None).map_err(|e| {
366            StatsError::ComputationError(format!(
367                "Cholesky decomposition failed: {}. Try using shrinkage.",
368                e
369            ))
370        })?;
371
372        // Solve L * M = Sb for M
373        let l_inv = scirs2_linalg::inv(&l.view(), None).map_err(|e| {
374            StatsError::ComputationError(format!("Failed to invert Cholesky factor: {}", e))
375        })?;
376
377        let m = l_inv.dot(sb).dot(&l_inv.t());
378
379        // SVD of M using scirs2_linalg
380        let (u, s, _vt) = scirs2_linalg::svd(&m.view(), true, None)
381            .map_err(|e| StatsError::ComputationError(format!("SVD failed: {}", e)))?;
382
383        // Transform back: scalings = L^{-T} * U
384        let scalings = l_inv.t().dot(&u);
385
386        // Sort by eigenvalues (singular values in descending order)
387        let mut eigen_pairs: Vec<_> = s.iter().cloned().zip(scalings.columns()).collect();
388        eigen_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("Operation failed"));
389
390        let eigenvalues: Vec<f64> = eigen_pairs.iter().map(|(val_, _)| *val_).collect();
391        let eigenvectors: Array2<f64> = Array2::from_shape_vec(
392            (scalings.nrows(), eigenvalues.len()),
393            eigen_pairs
394                .iter()
395                .flat_map(|(_, vec)| vec.iter().cloned())
396                .collect(),
397        )
398        .map_err(|e| {
399            StatsError::ComputationError(format!("Failed to construct eigenvector matrix: {}", e))
400        })?;
401
402        // Compute explained variance ratio
403        let total_variance: f64 = eigenvalues.iter().sum();
404        let explained_variance_ratio = if total_variance > 1e-10 {
405            Array1::from_vec(
406                eigenvalues
407                    .iter()
408                    .map(|&val| val / total_variance)
409                    .collect(),
410            )
411        } else {
412            Array1::zeros(eigenvalues.len())
413        };
414
415        Ok((eigenvectors, explained_variance_ratio))
416    }
417
418    /// Eigenvalue-based solver
419    fn solve_eigen(
420        &self,
421        sw: &Array2<f64>,
422        sb: &Array2<f64>,
423    ) -> Result<(Array2<f64>, Array1<f64>)> {
424        // Compute Sw^{-1} * Sb
425        let sw_inv = scirs2_linalg::inv(&sw.view(), None).map_err(|e| {
426            StatsError::ComputationError(format!(
427                "Failed to invert within-class scatter matrix: {}. Try using shrinkage.",
428                e
429            ))
430        })?;
431
432        let a = sw_inv.dot(sb);
433
434        // Eigenvalue decomposition using scirs2_linalg
435        // Note: Using eigh_f64_lapack for symmetric eigenvalue decomposition
436        let (eigenvalues, eigenvectors) =
437            scirs2_linalg::eigh_f64_lapack(&a.view()).map_err(|e| {
438                StatsError::ComputationError(format!("Eigenvalue decomposition failed: {}", e))
439            })?;
440
441        // Sort in descending order
442        let mut eigen_pairs: Vec<_> = eigenvalues
443            .iter()
444            .cloned()
445            .zip(eigenvectors.columns())
446            .collect();
447        eigen_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("Operation failed"));
448
449        let sorted_eigenvalues: Vec<f64> = eigen_pairs.iter().map(|(val_, _)| *val_).collect();
450        let sorted_eigenvectors: Array2<f64> = Array2::from_shape_vec(
451            (eigenvectors.nrows(), sorted_eigenvalues.len()),
452            eigen_pairs
453                .iter()
454                .flat_map(|(_, vec)| vec.iter().cloned())
455                .collect(),
456        )
457        .map_err(|e| {
458            StatsError::ComputationError(format!("Failed to construct eigenvector matrix: {}", e))
459        })?;
460
461        // Compute explained variance ratio
462        let total_variance: f64 = sorted_eigenvalues.iter().filter(|&&val| val > 0.0).sum();
463        let explained_variance_ratio = if total_variance > 1e-10 {
464            Array1::from_vec(
465                sorted_eigenvalues
466                    .iter()
467                    .map(|&val| if val > 0.0 { val / total_variance } else { 0.0 })
468                    .collect(),
469            )
470        } else {
471            Array1::zeros(sorted_eigenvalues.len())
472        };
473
474        Ok((sorted_eigenvectors, explained_variance_ratio))
475    }
476
477    /// Compute intercept for decision function
478    fn compute_intercept(
479        &self,
480        class_means: &Array2<f64>,
481        scalings: &Array2<f64>,
482        priors: &Array1<f64>,
483    ) -> Result<Array1<f64>> {
484        let n_classes = class_means.nrows();
485        let mut intercept = Array1::zeros(n_classes);
486
487        for i in 0..n_classes {
488            let class_mean = class_means.row(i);
489            let projected_mean = scalings.t().dot(&class_mean.to_owned());
490            let prior_term = priors[i].ln();
491
492            // Intercept = log(prior) - 0.5 * mean^T * Sigma^{-1} * mean
493            intercept[i] = prior_term - 0.5 * projected_mean.dot(&projected_mean);
494        }
495
496        Ok(intercept)
497    }
498
499    /// Transform data to discriminant space
500    pub fn transform(&self, x: ArrayView2<f64>, result: &LDAResult) -> Result<Array2<f64>> {
501        let handler = global_error_handler();
502        validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "LDA transform");
503
504        if x.ncols() != result.n_features {
505            return Err(handler
506                .create_validation_error(
507                    ErrorCode::E2001,
508                    "LDA transform",
509                    "n_features",
510                    format!("input: {}, expected: {}", x.ncols(), result.n_features),
511                    "Number of features must match training data",
512                )
513                .error);
514        }
515
516        Ok(x.dot(&result.scalings))
517    }
518
519    /// Predict class labels
520    pub fn predict(&self, x: ArrayView2<f64>, result: &LDAResult) -> Result<Array1<i32>> {
521        let scores = self.decision_function(x, result)?;
522        let mut predictions = Array1::zeros(x.nrows());
523
524        for (i, row) in scores.rows().into_iter().enumerate() {
525            let max_idx = row
526                .iter()
527                .enumerate()
528                .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Operation failed"))
529                .map(|(idx, _)| idx)
530                .expect("Operation failed");
531            predictions[i] = result.classes[max_idx];
532        }
533
534        Ok(predictions)
535    }
536
537    /// Compute decision function scores
538    pub fn decision_function(&self, x: ArrayView2<f64>, result: &LDAResult) -> Result<Array2<f64>> {
539        let projected = self.transform(x, result)?;
540        let n_samples = projected.nrows();
541        let n_classes = result.classes.len();
542
543        let mut scores = Array2::zeros((n_samples, n_classes));
544
545        for i in 0..n_samples {
546            let sample = projected.row(i);
547            for j in 0..n_classes {
548                let class_mean = result.means.row(j);
549                let projected_class_mean = result.scalings.t().dot(&class_mean.to_owned());
550
551                // Linear discriminant function
552                scores[[i, j]] = sample.dot(&projected_class_mean) + result.intercept[j];
553            }
554        }
555
556        Ok(scores)
557    }
558
559    /// Compute prediction probabilities using softmax
560    pub fn predict_proba(&self, x: ArrayView2<f64>, result: &LDAResult) -> Result<Array2<f64>> {
561        let scores = self.decision_function(x, result)?;
562        let mut probabilities = Array2::zeros(scores.dim());
563
564        for (i, mut row) in probabilities.rows_mut().into_iter().enumerate() {
565            let score_row = scores.row(i);
566            let max_score = score_row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
567
568            // Compute softmax (numerically stable)
569            let mut sum_exp = 0.0;
570            for (j, &score) in score_row.iter().enumerate() {
571                let exp_score = (score - max_score).exp();
572                row[j] = exp_score;
573                sum_exp += exp_score;
574            }
575
576            // Normalize
577            if sum_exp > 1e-10 {
578                row /= sum_exp;
579            } else {
580                // Uniform distribution if all scores are very negative
581                let len = row.len();
582                row.fill(1.0 / len as f64);
583            }
584        }
585
586        Ok(probabilities)
587    }
588}
589
590/// Quadratic Discriminant Analysis (QDA)
591///
592/// QDA is similar to LDA but allows different covariance matrices for each class.
593/// This makes it more flexible but requires more parameters.
594#[derive(Debug, Clone)]
595pub struct QuadraticDiscriminantAnalysis {
596    /// Prior probabilities for each class (None = empirical)
597    pub priors: Option<Array1<f64>>,
598    /// Regularization parameter for covariance matrices
599    pub reg_param: f64,
600    /// Store covariances during training
601    pub store_covariance: bool,
602}
603
604/// Result of Quadratic Discriminant Analysis
605#[derive(Debug, Clone)]
606pub struct QDAResult {
607    /// Covariance matrices for each class
608    pub covariances: Option<Vec<Array2<f64>>>,
609    /// Class means
610    pub means: Array2<f64>,
611    /// Prior probabilities for each class
612    pub priors: Array1<f64>,
613    /// Class labels
614    pub classes: Array1<i32>,
615    /// Number of features used for training
616    pub n_features: usize,
617}
618
619impl Default for QuadraticDiscriminantAnalysis {
620    fn default() -> Self {
621        Self {
622            priors: None,
623            reg_param: 0.0,
624            store_covariance: true,
625        }
626    }
627}
628
629impl QuadraticDiscriminantAnalysis {
630    /// Create a new QDA instance
631    pub fn new() -> Self {
632        Self::default()
633    }
634
635    /// Set prior probabilities
636    pub fn with_priors(mut self, priors: Array1<f64>) -> Self {
637        self.priors = Some(priors);
638        self
639    }
640
641    /// Set regularization parameter
642    pub fn with_reg_param(mut self, reg_param: f64) -> Self {
643        self.reg_param = reg_param;
644        self
645    }
646
647    /// Set whether to store covariance matrices
648    pub fn with_store_covariance(mut self, store: bool) -> Self {
649        self.store_covariance = store;
650        self
651    }
652
653    /// Fit the QDA model
654    pub fn fit(&self, x: ArrayView2<f64>, y: ArrayView1<i32>) -> Result<QDAResult> {
655        let handler = global_error_handler();
656        validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "QDA fit");
657
658        let (n_samples, n_features) = x.dim();
659
660        if n_samples != y.len() {
661            return Err(handler
662                .create_validation_error(
663                    ErrorCode::E2001,
664                    "QDA fit",
665                    "samplesize_mismatch",
666                    format!("x: {}, y: {}", n_samples, y.len()),
667                    "Number of samples in X and y must be equal",
668                )
669                .error);
670        }
671
672        // Get unique classes
673        let mut classes = y.to_vec();
674        classes.sort_unstable();
675        classes.dedup();
676        let unique_classes = Array1::from_vec(classes);
677        let n_classes = unique_classes.len();
678
679        if n_classes < 2 {
680            return Err(handler
681                .create_validation_error(
682                    ErrorCode::E1001,
683                    "QDA fit",
684                    "n_classes",
685                    n_classes,
686                    "QDA requires at least 2 classes",
687                )
688                .error);
689        }
690
691        // Compute class statistics
692        let mut class_means = Array2::zeros((n_classes, n_features));
693        let mut class_covariances = Vec::with_capacity(n_classes);
694        let mut class_counts = Array1::zeros(n_classes);
695
696        for (class_idx, &class_label) in unique_classes.iter().enumerate() {
697            let class_indices: Vec<_> = y
698                .iter()
699                .enumerate()
700                .filter(|(_, &label)| label == class_label)
701                .map(|(idx, _)| idx)
702                .collect();
703
704            let classsize = class_indices.len();
705            if classsize < 2 {
706                return Err(handler
707                    .create_validation_error(
708                        ErrorCode::E2003,
709                        "QDA fit",
710                        "classsize",
711                        classsize,
712                        "Each class must have at least 2 samples for covariance estimation",
713                    )
714                    .error);
715            }
716
717            class_counts[class_idx] = classsize;
718
719            // Compute class mean
720            let mut classdata = Array2::zeros((classsize, n_features));
721            for (i, &sample_idx) in class_indices.iter().enumerate() {
722                classdata.row_mut(i).assign(&x.row(sample_idx));
723            }
724
725            let class_mean = classdata.mean_axis(Axis(0)).expect("Operation failed");
726            class_means.row_mut(class_idx).assign(&class_mean);
727
728            // Compute class covariance
729            let mut centered = classdata;
730            for mut row in centered.rows_mut() {
731                row -= &class_mean;
732            }
733
734            let mut cov = centered.t().dot(&centered) / (classsize - 1) as f64;
735
736            // Apply regularization
737            if self.reg_param > 0.0 {
738                let trace = (0..n_features).map(|i| cov[[i, i]]).sum::<f64>();
739                let identity_term: Array2<f64> =
740                    Array2::eye(n_features) * (self.reg_param * trace / n_features as f64);
741                cov = cov + identity_term;
742            }
743
744            class_covariances.push(cov);
745        }
746
747        // Compute priors
748        let class_priors = if let Some(ref priors) = self.priors {
749            if priors.len() != n_classes {
750                return Err(StatsError::InvalidArgument(format!(
751                    "Priors length ({}) must equal number of classes ({})",
752                    priors.len(),
753                    n_classes
754                )));
755            }
756            priors.clone()
757        } else {
758            class_counts.mapv(|count| count as f64 / n_samples as f64)
759        };
760
761        Ok(QDAResult {
762            covariances: if self.store_covariance {
763                Some(class_covariances)
764            } else {
765                None
766            },
767            means: class_means,
768            priors: class_priors,
769            classes: unique_classes,
770            n_features,
771        })
772    }
773
774    /// Predict class labels
775    pub fn predict(&self, x: ArrayView2<f64>, result: &QDAResult) -> Result<Array1<i32>> {
776        let scores = self.decision_function(x, result)?;
777        let mut predictions = Array1::zeros(x.nrows());
778
779        for (i, row) in scores.rows().into_iter().enumerate() {
780            let max_idx = row
781                .iter()
782                .enumerate()
783                .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Operation failed"))
784                .map(|(idx, _)| idx)
785                .expect("Operation failed");
786            predictions[i] = result.classes[max_idx];
787        }
788
789        Ok(predictions)
790    }
791
792    /// Compute decision function scores
793    pub fn decision_function(&self, x: ArrayView2<f64>, result: &QDAResult) -> Result<Array2<f64>> {
794        let handler = global_error_handler();
795        validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "QDA decision_function");
796
797        if x.ncols() != result.n_features {
798            return Err(handler
799                .create_validation_error(
800                    ErrorCode::E2001,
801                    "QDA decision_function",
802                    "n_features",
803                    format!("input: {}, expected: {}", x.ncols(), result.n_features),
804                    "Number of features must match training data",
805                )
806                .error);
807        }
808
809        if result.covariances.is_none() {
810            return Err(StatsError::InvalidArgument(
811                "Covariances not stored during training. Set store_covariance=true.".to_string(),
812            ));
813        }
814
815        let covariances = result.covariances.as_ref().expect("Operation failed");
816        let n_samples = x.nrows();
817        let n_classes = result.classes.len();
818        let mut scores = Array2::zeros((n_samples, n_classes));
819
820        for class_idx in 0..n_classes {
821            let class_mean = result.means.row(class_idx);
822            let class_cov = &covariances[class_idx];
823
824            // Compute inverse and determinant
825            let cov_inv = scirs2_linalg::inv(&class_cov.view(), None).map_err(|e| {
826                StatsError::ComputationError(format!(
827                    "Failed to invert covariance matrix for class {}: {}",
828                    class_idx, e
829                ))
830            })?;
831
832            let det_cov = scirs2_linalg::det(&class_cov.view(), None).map_err(|e| {
833                StatsError::ComputationError(format!(
834                    "Failed to compute determinant for class {}: {}",
835                    class_idx, e
836                ))
837            })?;
838
839            if det_cov <= 0.0 {
840                return Err(StatsError::ComputationError(format!(
841                    "Covariance matrix for class {} is not positive definite",
842                    class_idx
843                )));
844            }
845
846            let log_det_term = -0.5 * det_cov.ln();
847            let prior_term = result.priors[class_idx].ln();
848
849            for sample_idx in 0..n_samples {
850                let sample = x.row(sample_idx);
851                let diff = &sample - &class_mean;
852
853                // Quadratic form: (x - μ)^T Σ^{-1} (x - μ)
854                let quad_form = diff.dot(&cov_inv.dot(&diff.to_owned()));
855
856                scores[[sample_idx, class_idx]] = prior_term + log_det_term - 0.5 * quad_form;
857            }
858        }
859
860        Ok(scores)
861    }
862
863    /// Compute prediction probabilities
864    pub fn predict_proba(&self, x: ArrayView2<f64>, result: &QDAResult) -> Result<Array2<f64>> {
865        let scores = self.decision_function(x, result)?;
866        let mut probabilities = Array2::zeros(scores.dim());
867
868        for (i, mut row) in probabilities.rows_mut().into_iter().enumerate() {
869            let score_row = scores.row(i);
870            let max_score = score_row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
871
872            // Compute softmax (numerically stable)
873            let mut sum_exp = 0.0;
874            for (j, &score) in score_row.iter().enumerate() {
875                let exp_score = (score - max_score).exp();
876                row[j] = exp_score;
877                sum_exp += exp_score;
878            }
879
880            // Normalize
881            if sum_exp > 1e-10 {
882                row /= sum_exp;
883            } else {
884                let len = row.len();
885                row.fill(1.0 / len as f64);
886            }
887        }
888
889        Ok(probabilities)
890    }
891}
892
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// Two well-separated 2-D clusters of three points each, with non-degenerate spread in
    /// both dimensions.
    fn separated_clusters_2d() -> (Array2<f64>, Array1<i32>) {
        let features = array![
            [1.0, 2.5],
            [2.1, 3.2],
            [2.8, 4.1],
            [6.2, 7.1],
            [7.3, 8.5],
            [8.1, 9.3],
        ];
        let labels = array![0, 0, 0, 1, 1, 1];
        (features, labels)
    }

    #[test]
    fn test_lda_basic() {
        let (features, labels) = separated_clusters_2d();
        let model = LinearDiscriminantAnalysis::new();
        let fitted = model
            .fit(features.view(), labels.view())
            .expect("Operation failed");

        assert_eq!(fitted.classes, array![0, 1]);
        assert_eq!(fitted.means.dim(), (2, 2));

        let predicted = model
            .predict(features.view(), &fitted)
            .expect("Operation failed");
        assert_eq!(predicted.len(), features.nrows());
    }

    #[test]
    fn test_qda_basic() {
        let (features, labels) = separated_clusters_2d();
        let model = QuadraticDiscriminantAnalysis::new();
        let fitted = model
            .fit(features.view(), labels.view())
            .expect("Operation failed");

        assert_eq!(fitted.classes, array![0, 1]);
        assert_eq!(fitted.means.dim(), (2, 2));

        let predicted = model
            .predict(features.view(), &fitted)
            .expect("Operation failed");
        assert_eq!(predicted.len(), features.nrows());
    }

    #[test]
    fn test_lda_transform() {
        // Non-degenerate 3-D data with independent variation along every axis.
        let features = array![
            [1.2, 2.8, 3.1],
            [2.1, 3.5, 4.2],
            [2.9, 4.1, 5.3],
            [6.1, 7.2, 8.5],
            [7.2, 8.3, 9.1],
            [8.3, 9.1, 10.2],
        ];
        let labels = array![0, 0, 0, 1, 1, 1];

        let model = LinearDiscriminantAnalysis::new();
        let fitted = model
            .fit(features.view(), labels.view())
            .expect("Operation failed");

        let projected = model
            .transform(features.view(), &fitted)
            .expect("Operation failed");
        assert_eq!(projected.nrows(), features.nrows());
        assert!(projected.ncols() <= fitted.classes.len() - 1);
    }
}
964        let transformed = lda.transform(x.view(), &result).expect("Operation failed");
965        assert_eq!(transformed.nrows(), 6);
966        assert!(transformed.ncols() <= result.classes.len() - 1);
967    }
968}