1use crate::error::{StatsError, StatsResult as Result};
7use crate::error_handling_v2::ErrorCode;
8use crate::{unified_error_handling::global_error_handler, validate_or_error};
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
10
/// Configuration for Linear Discriminant Analysis (LDA).
///
/// Build it with [`LinearDiscriminantAnalysis::new`] plus the `with_*` builder methods, then
/// call `fit` to obtain an [`LDAResult`] used by `transform` / `predict` / `predict_proba`.
#[derive(Debug, Clone)]
pub struct LinearDiscriminantAnalysis {
    /// Solver used to compute the discriminant directions.
    pub solver: LDASolver,
    /// Optional shrinkage intensity that blends the within-class scatter matrix with a
    /// scaled identity matrix; `None` disables shrinkage.
    pub shrinkage: Option<f64>,
    /// Number of discriminant components to keep. During fitting it is capped at
    /// `n_classes - 1` and `n_features`; `None` keeps that maximum.
    pub n_components: Option<usize>,
    /// Optional class priors, one per class in ascending label order. When `None`, the
    /// empirical class frequencies are used.
    pub priors: Option<Array1<f64>>,
    /// Whether the fitted result keeps the (possibly shrunk) within-class scatter matrix.
    pub store_covariance: bool,
}
29
/// Solver used by [`LinearDiscriminantAnalysis`] to obtain the discriminant directions.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LDASolver {
    /// Cholesky whitening of the within-class scatter followed by a singular value
    /// decomposition of the whitened between-class scatter.
    Svd,
    /// Eigendecomposition-based solver. `fit` rejects it when `n_features >= n_samples`.
    Eigen,
}
38
/// Fitted LDA model produced by `LinearDiscriminantAnalysis::fit`.
#[derive(Debug, Clone)]
pub struct LDAResult {
    /// Projection matrix of shape `(n_features, n_components)`; `transform` returns
    /// `x · scalings`.
    pub scalings: Array2<f64>,
    /// Per-class constant term added to the linear discriminant scores.
    pub intercept: Array1<f64>,
    /// Within-class scatter matrix (after optional shrinkage), kept only when
    /// `store_covariance` was set. Note: this is a scatter (sum of outer products of
    /// class-centred samples), not a covariance normalized by the sample count.
    pub covariance: Option<Array2<f64>>,
    /// Class means, one row per class, in the order of `classes`.
    pub means: Array2<f64>,
    /// Class priors, in the order of `classes`.
    pub priors: Array1<f64>,
    /// Distinct class labels seen during fitting, sorted ascending.
    pub classes: Array1<i32>,
    /// Share of the discriminant eigenvalue mass captured by each retained component.
    pub explained_variance_ratio: Array1<f64>,
    /// Number of features seen during fitting.
    pub n_features: usize,
}
59
60impl Default for LinearDiscriminantAnalysis {
61 fn default() -> Self {
62 Self {
63 solver: LDASolver::Svd,
64 shrinkage: None,
65 n_components: None,
66 priors: None,
67 store_covariance: true,
68 }
69 }
70}
71
72impl LinearDiscriminantAnalysis {
73 pub fn new() -> Self {
75 Self::default()
76 }
77
78 pub fn with_solver(mut self, solver: LDASolver) -> Self {
80 self.solver = solver;
81 self
82 }
83
84 pub fn with_shrinkage(mut self, shrinkage: f64) -> Self {
86 self.shrinkage = Some(shrinkage);
87 self
88 }
89
90 pub fn with_n_components(mut self, n_components: usize) -> Self {
92 self.n_components = Some(n_components);
93 self
94 }
95
96 pub fn with_priors(mut self, priors: Array1<f64>) -> Self {
98 self.priors = Some(priors);
99 self
100 }
101
102 pub fn with_store_covariance(mut self, store: bool) -> Self {
104 self.store_covariance = store;
105 self
106 }
107
108 pub fn fit(&self, x: ArrayView2<f64>, y: ArrayView1<i32>) -> Result<LDAResult> {
110 let handler = global_error_handler();
111 validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "LDA fit");
112
113 let (n_samples, n_features) = x.dim();
114 let n_targets = y.len();
115
116 if n_samples != n_targets {
117 return Err(handler
118 .create_validation_error(
119 ErrorCode::E2001,
120 "LDA fit",
121 "samplesize_mismatch",
122 format!("x: {}, y: {}", n_samples, n_targets),
123 "Number of samples in X and y must be equal",
124 )
125 .error);
126 }
127
128 if n_samples < 2 {
129 return Err(handler
130 .create_validation_error(
131 ErrorCode::E2003,
132 "LDA fit",
133 "n_samples",
134 n_samples,
135 "LDA requires at least 2 samples",
136 )
137 .error);
138 }
139
140 let unique_classes = self.get_unique_classes(y)?;
142 let n_classes = unique_classes.len();
143
144 if n_classes < 2 {
145 return Err(handler
146 .create_validation_error(
147 ErrorCode::E1001,
148 "LDA fit",
149 "n_classes",
150 n_classes,
151 "LDA requires at least 2 classes",
152 )
153 .error);
154 }
155
156 if n_features >= n_samples && self.solver == LDASolver::Eigen {
157 return Err(handler
158 .create_error(
159 ErrorCode::E1001,
160 "LDA fit",
161 "Use SVD solver when n_features >= n_samples for numerical stability",
162 )
163 .error);
164 }
165
166 let (class_means, class_priors, class_counts) =
168 self.compute_class_statistics(x, y, &unique_classes)?;
169
170 let (sw, sb) = self.compute_scatter_matrices(x, y, &unique_classes, &class_means)?;
172
173 let sw_regularized = if let Some(shrinkage) = self.shrinkage {
175 self.apply_shrinkage(&sw, shrinkage)?
176 } else {
177 sw
178 };
179
180 let (scalings, explained_variance_ratio) =
182 self.solve_eigenvalue_problem(&sw_regularized, &sb)?;
183
184 let n_components = self
186 .n_components
187 .unwrap_or(n_classes - 1)
188 .min(n_classes - 1)
189 .min(n_features);
190
191 let final_scalings = scalings
192 .slice(scirs2_core::ndarray::s![.., ..n_components])
193 .to_owned();
194 let final_explained_variance = explained_variance_ratio
195 .slice(scirs2_core::ndarray::s![..n_components])
196 .to_owned();
197
198 let intercept = self.compute_intercept(&class_means, &final_scalings, &class_priors)?;
200
201 Ok(LDAResult {
202 scalings: final_scalings,
203 intercept,
204 covariance: if self.store_covariance {
205 Some(sw_regularized)
206 } else {
207 None
208 },
209 means: class_means,
210 priors: class_priors,
211 classes: unique_classes,
212 explained_variance_ratio: final_explained_variance,
213 n_features,
214 })
215 }
216
217 fn get_unique_classes(&self, y: ArrayView1<i32>) -> Result<Array1<i32>> {
219 let mut classes = y.to_vec();
220 classes.sort_unstable();
221 classes.dedup();
222 Ok(Array1::from_vec(classes))
223 }
224
225 fn compute_class_statistics(
227 &self,
228 x: ArrayView2<f64>,
229 y: ArrayView1<i32>,
230 classes: &Array1<i32>,
231 ) -> Result<(Array2<f64>, Array1<f64>, Array1<usize>)> {
232 let (n_samples, n_features) = x.dim();
233 let n_classes = classes.len();
234
235 let mut class_means = Array2::zeros((n_classes, n_features));
236 let mut class_counts = Array1::zeros(n_classes);
237
238 for (i, &class_label) in classes.iter().enumerate() {
240 let class_indices: Vec<_> = y
241 .iter()
242 .enumerate()
243 .filter(|(_, &label)| label == class_label)
244 .map(|(idx, _)| idx)
245 .collect();
246
247 if class_indices.is_empty() {
248 return Err(StatsError::InvalidArgument(format!(
249 "Class {} has no samples",
250 class_label
251 )));
252 }
253
254 class_counts[i] = class_indices.len();
255
256 let mut sum = Array1::zeros(n_features);
258 for &idx in &class_indices {
259 sum += &x.row(idx);
260 }
261 class_means
262 .row_mut(i)
263 .assign(&(sum / class_indices.len() as f64));
264 }
265
266 let class_priors = if let Some(ref priors) = self.priors {
268 if priors.len() != n_classes {
269 return Err(StatsError::InvalidArgument(format!(
270 "Priors length ({}) must equal number of classes ({})",
271 priors.len(),
272 n_classes
273 )));
274 }
275 priors.clone()
276 } else {
277 class_counts.mapv(|count| count as f64 / n_samples as f64)
279 };
280
281 Ok((class_means, class_priors, class_counts.mapv(|x| x)))
282 }
283
284 fn compute_scatter_matrices(
286 &self,
287 x: ArrayView2<f64>,
288 y: ArrayView1<i32>,
289 classes: &Array1<i32>,
290 class_means: &Array2<f64>,
291 ) -> Result<(Array2<f64>, Array2<f64>)> {
292 let (_n_samples, n_features) = x.dim();
293 let _n_classes = classes.len();
294
295 let overall_mean = x.mean_axis(Axis(0)).expect("Operation failed");
297
298 let mut sw = Array2::zeros((n_features, n_features));
300 let mut sb = Array2::zeros((n_features, n_features));
301
302 for (class_idx, &class_label) in classes.iter().enumerate() {
304 let class_mean = class_means.row(class_idx);
305
306 for (sample_idx, &sample_label) in y.iter().enumerate() {
307 if sample_label == class_label {
308 let sample = x.row(sample_idx);
309 let diff = &sample - &class_mean;
310
311 for i in 0..n_features {
313 for j in 0..n_features {
314 sw[[i, j]] += diff[i] * diff[j];
315 }
316 }
317 }
318 }
319 }
320
321 for (class_idx, _) in classes.iter().enumerate() {
323 let class_mean = class_means.row(class_idx);
324 let class_count = y
325 .iter()
326 .filter(|&&label| label == classes[class_idx])
327 .count() as f64;
328 let diff = &class_mean - &overall_mean;
329
330 for i in 0..n_features {
332 for j in 0..n_features {
333 sb[[i, j]] += class_count * diff[i] * diff[j];
334 }
335 }
336 }
337
338 Ok((sw, sb))
339 }
340
341 fn apply_shrinkage(&self, sw: &Array2<f64>, shrinkage: f64) -> Result<Array2<f64>> {
343 let n_features = sw.nrows();
344 let trace = (0..n_features).map(|i| sw[[i, i]]).sum::<f64>();
345 let scaled_identity = Array2::eye(n_features) * (trace / n_features as f64);
346
347 Ok((1.0 - shrinkage) * sw + shrinkage * scaled_identity)
348 }
349
350 fn solve_eigenvalue_problem(
352 &self,
353 sw: &Array2<f64>,
354 sb: &Array2<f64>,
355 ) -> Result<(Array2<f64>, Array1<f64>)> {
356 match self.solver {
357 LDASolver::Svd => self.solve_svd(sw, sb),
358 LDASolver::Eigen => self.solve_eigen(sw, sb),
359 }
360 }
361
362 fn solve_svd(&self, sw: &Array2<f64>, sb: &Array2<f64>) -> Result<(Array2<f64>, Array1<f64>)> {
364 let l = scirs2_linalg::cholesky(&sw.view(), None).map_err(|e| {
366 StatsError::ComputationError(format!(
367 "Cholesky decomposition failed: {}. Try using shrinkage.",
368 e
369 ))
370 })?;
371
372 let l_inv = scirs2_linalg::inv(&l.view(), None).map_err(|e| {
374 StatsError::ComputationError(format!("Failed to invert Cholesky factor: {}", e))
375 })?;
376
377 let m = l_inv.dot(sb).dot(&l_inv.t());
378
379 let (u, s, _vt) = scirs2_linalg::svd(&m.view(), true, None)
381 .map_err(|e| StatsError::ComputationError(format!("SVD failed: {}", e)))?;
382
383 let scalings = l_inv.t().dot(&u);
385
386 let mut eigen_pairs: Vec<_> = s.iter().cloned().zip(scalings.columns()).collect();
388 eigen_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("Operation failed"));
389
390 let eigenvalues: Vec<f64> = eigen_pairs.iter().map(|(val_, _)| *val_).collect();
391 let eigenvectors: Array2<f64> = Array2::from_shape_vec(
392 (scalings.nrows(), eigenvalues.len()),
393 eigen_pairs
394 .iter()
395 .flat_map(|(_, vec)| vec.iter().cloned())
396 .collect(),
397 )
398 .map_err(|e| {
399 StatsError::ComputationError(format!("Failed to construct eigenvector matrix: {}", e))
400 })?;
401
402 let total_variance: f64 = eigenvalues.iter().sum();
404 let explained_variance_ratio = if total_variance > 1e-10 {
405 Array1::from_vec(
406 eigenvalues
407 .iter()
408 .map(|&val| val / total_variance)
409 .collect(),
410 )
411 } else {
412 Array1::zeros(eigenvalues.len())
413 };
414
415 Ok((eigenvectors, explained_variance_ratio))
416 }
417
418 fn solve_eigen(
420 &self,
421 sw: &Array2<f64>,
422 sb: &Array2<f64>,
423 ) -> Result<(Array2<f64>, Array1<f64>)> {
424 let sw_inv = scirs2_linalg::inv(&sw.view(), None).map_err(|e| {
426 StatsError::ComputationError(format!(
427 "Failed to invert within-class scatter matrix: {}. Try using shrinkage.",
428 e
429 ))
430 })?;
431
432 let a = sw_inv.dot(sb);
433
434 let (eigenvalues, eigenvectors) =
437 scirs2_linalg::eigh_f64_lapack(&a.view()).map_err(|e| {
438 StatsError::ComputationError(format!("Eigenvalue decomposition failed: {}", e))
439 })?;
440
441 let mut eigen_pairs: Vec<_> = eigenvalues
443 .iter()
444 .cloned()
445 .zip(eigenvectors.columns())
446 .collect();
447 eigen_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("Operation failed"));
448
449 let sorted_eigenvalues: Vec<f64> = eigen_pairs.iter().map(|(val_, _)| *val_).collect();
450 let sorted_eigenvectors: Array2<f64> = Array2::from_shape_vec(
451 (eigenvectors.nrows(), sorted_eigenvalues.len()),
452 eigen_pairs
453 .iter()
454 .flat_map(|(_, vec)| vec.iter().cloned())
455 .collect(),
456 )
457 .map_err(|e| {
458 StatsError::ComputationError(format!("Failed to construct eigenvector matrix: {}", e))
459 })?;
460
461 let total_variance: f64 = sorted_eigenvalues.iter().filter(|&&val| val > 0.0).sum();
463 let explained_variance_ratio = if total_variance > 1e-10 {
464 Array1::from_vec(
465 sorted_eigenvalues
466 .iter()
467 .map(|&val| if val > 0.0 { val / total_variance } else { 0.0 })
468 .collect(),
469 )
470 } else {
471 Array1::zeros(sorted_eigenvalues.len())
472 };
473
474 Ok((sorted_eigenvectors, explained_variance_ratio))
475 }
476
477 fn compute_intercept(
479 &self,
480 class_means: &Array2<f64>,
481 scalings: &Array2<f64>,
482 priors: &Array1<f64>,
483 ) -> Result<Array1<f64>> {
484 let n_classes = class_means.nrows();
485 let mut intercept = Array1::zeros(n_classes);
486
487 for i in 0..n_classes {
488 let class_mean = class_means.row(i);
489 let projected_mean = scalings.t().dot(&class_mean.to_owned());
490 let prior_term = priors[i].ln();
491
492 intercept[i] = prior_term - 0.5 * projected_mean.dot(&projected_mean);
494 }
495
496 Ok(intercept)
497 }
498
499 pub fn transform(&self, x: ArrayView2<f64>, result: &LDAResult) -> Result<Array2<f64>> {
501 let handler = global_error_handler();
502 validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "LDA transform");
503
504 if x.ncols() != result.n_features {
505 return Err(handler
506 .create_validation_error(
507 ErrorCode::E2001,
508 "LDA transform",
509 "n_features",
510 format!("input: {}, expected: {}", x.ncols(), result.n_features),
511 "Number of features must match training data",
512 )
513 .error);
514 }
515
516 Ok(x.dot(&result.scalings))
517 }
518
519 pub fn predict(&self, x: ArrayView2<f64>, result: &LDAResult) -> Result<Array1<i32>> {
521 let scores = self.decision_function(x, result)?;
522 let mut predictions = Array1::zeros(x.nrows());
523
524 for (i, row) in scores.rows().into_iter().enumerate() {
525 let max_idx = row
526 .iter()
527 .enumerate()
528 .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Operation failed"))
529 .map(|(idx, _)| idx)
530 .expect("Operation failed");
531 predictions[i] = result.classes[max_idx];
532 }
533
534 Ok(predictions)
535 }
536
537 pub fn decision_function(&self, x: ArrayView2<f64>, result: &LDAResult) -> Result<Array2<f64>> {
539 let projected = self.transform(x, result)?;
540 let n_samples = projected.nrows();
541 let n_classes = result.classes.len();
542
543 let mut scores = Array2::zeros((n_samples, n_classes));
544
545 for i in 0..n_samples {
546 let sample = projected.row(i);
547 for j in 0..n_classes {
548 let class_mean = result.means.row(j);
549 let projected_class_mean = result.scalings.t().dot(&class_mean.to_owned());
550
551 scores[[i, j]] = sample.dot(&projected_class_mean) + result.intercept[j];
553 }
554 }
555
556 Ok(scores)
557 }
558
559 pub fn predict_proba(&self, x: ArrayView2<f64>, result: &LDAResult) -> Result<Array2<f64>> {
561 let scores = self.decision_function(x, result)?;
562 let mut probabilities = Array2::zeros(scores.dim());
563
564 for (i, mut row) in probabilities.rows_mut().into_iter().enumerate() {
565 let score_row = scores.row(i);
566 let max_score = score_row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
567
568 let mut sum_exp = 0.0;
570 for (j, &score) in score_row.iter().enumerate() {
571 let exp_score = (score - max_score).exp();
572 row[j] = exp_score;
573 sum_exp += exp_score;
574 }
575
576 if sum_exp > 1e-10 {
578 row /= sum_exp;
579 } else {
580 let len = row.len();
582 row.fill(1.0 / len as f64);
583 }
584 }
585
586 Ok(probabilities)
587 }
588}
589
/// Configuration for Quadratic Discriminant Analysis (QDA).
///
/// Every class gets its own mean and covariance estimate; build with
/// [`QuadraticDiscriminantAnalysis::new`] and the `with_*` methods, then call `fit`.
#[derive(Debug, Clone)]
pub struct QuadraticDiscriminantAnalysis {
    /// Optional class priors, one per class in ascending label order. When `None`, the
    /// empirical class frequencies are used.
    pub priors: Option<Array1<f64>>,
    /// Ridge strength. When positive, `reg_param * trace(Σ_k) / n_features` is added to
    /// the diagonal of each class covariance `Σ_k`.
    pub reg_param: f64,
    /// Whether the fitted result keeps the per-class covariances. They are required by
    /// `decision_function` (and therefore by `predict` / `predict_proba`).
    pub store_covariance: bool,
}
603
/// Fitted QDA model produced by `QuadraticDiscriminantAnalysis::fit`.
#[derive(Debug, Clone)]
pub struct QDAResult {
    /// Per-class sample covariance matrices (optionally ridge-regularized), in the order
    /// of `classes`. `None` when `store_covariance` was false.
    pub covariances: Option<Vec<Array2<f64>>>,
    /// Class means, one row per class, in the order of `classes`.
    pub means: Array2<f64>,
    /// Class priors, in the order of `classes`.
    pub priors: Array1<f64>,
    /// Distinct class labels seen during fitting, sorted ascending.
    pub classes: Array1<i32>,
    /// Number of features seen during fitting.
    pub n_features: usize,
}
618
619impl Default for QuadraticDiscriminantAnalysis {
620 fn default() -> Self {
621 Self {
622 priors: None,
623 reg_param: 0.0,
624 store_covariance: true,
625 }
626 }
627}
628
629impl QuadraticDiscriminantAnalysis {
630 pub fn new() -> Self {
632 Self::default()
633 }
634
635 pub fn with_priors(mut self, priors: Array1<f64>) -> Self {
637 self.priors = Some(priors);
638 self
639 }
640
641 pub fn with_reg_param(mut self, reg_param: f64) -> Self {
643 self.reg_param = reg_param;
644 self
645 }
646
647 pub fn with_store_covariance(mut self, store: bool) -> Self {
649 self.store_covariance = store;
650 self
651 }
652
653 pub fn fit(&self, x: ArrayView2<f64>, y: ArrayView1<i32>) -> Result<QDAResult> {
655 let handler = global_error_handler();
656 validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "QDA fit");
657
658 let (n_samples, n_features) = x.dim();
659
660 if n_samples != y.len() {
661 return Err(handler
662 .create_validation_error(
663 ErrorCode::E2001,
664 "QDA fit",
665 "samplesize_mismatch",
666 format!("x: {}, y: {}", n_samples, y.len()),
667 "Number of samples in X and y must be equal",
668 )
669 .error);
670 }
671
672 let mut classes = y.to_vec();
674 classes.sort_unstable();
675 classes.dedup();
676 let unique_classes = Array1::from_vec(classes);
677 let n_classes = unique_classes.len();
678
679 if n_classes < 2 {
680 return Err(handler
681 .create_validation_error(
682 ErrorCode::E1001,
683 "QDA fit",
684 "n_classes",
685 n_classes,
686 "QDA requires at least 2 classes",
687 )
688 .error);
689 }
690
691 let mut class_means = Array2::zeros((n_classes, n_features));
693 let mut class_covariances = Vec::with_capacity(n_classes);
694 let mut class_counts = Array1::zeros(n_classes);
695
696 for (class_idx, &class_label) in unique_classes.iter().enumerate() {
697 let class_indices: Vec<_> = y
698 .iter()
699 .enumerate()
700 .filter(|(_, &label)| label == class_label)
701 .map(|(idx, _)| idx)
702 .collect();
703
704 let classsize = class_indices.len();
705 if classsize < 2 {
706 return Err(handler
707 .create_validation_error(
708 ErrorCode::E2003,
709 "QDA fit",
710 "classsize",
711 classsize,
712 "Each class must have at least 2 samples for covariance estimation",
713 )
714 .error);
715 }
716
717 class_counts[class_idx] = classsize;
718
719 let mut classdata = Array2::zeros((classsize, n_features));
721 for (i, &sample_idx) in class_indices.iter().enumerate() {
722 classdata.row_mut(i).assign(&x.row(sample_idx));
723 }
724
725 let class_mean = classdata.mean_axis(Axis(0)).expect("Operation failed");
726 class_means.row_mut(class_idx).assign(&class_mean);
727
728 let mut centered = classdata;
730 for mut row in centered.rows_mut() {
731 row -= &class_mean;
732 }
733
734 let mut cov = centered.t().dot(¢ered) / (classsize - 1) as f64;
735
736 if self.reg_param > 0.0 {
738 let trace = (0..n_features).map(|i| cov[[i, i]]).sum::<f64>();
739 let identity_term: Array2<f64> =
740 Array2::eye(n_features) * (self.reg_param * trace / n_features as f64);
741 cov = cov + identity_term;
742 }
743
744 class_covariances.push(cov);
745 }
746
747 let class_priors = if let Some(ref priors) = self.priors {
749 if priors.len() != n_classes {
750 return Err(StatsError::InvalidArgument(format!(
751 "Priors length ({}) must equal number of classes ({})",
752 priors.len(),
753 n_classes
754 )));
755 }
756 priors.clone()
757 } else {
758 class_counts.mapv(|count| count as f64 / n_samples as f64)
759 };
760
761 Ok(QDAResult {
762 covariances: if self.store_covariance {
763 Some(class_covariances)
764 } else {
765 None
766 },
767 means: class_means,
768 priors: class_priors,
769 classes: unique_classes,
770 n_features,
771 })
772 }
773
774 pub fn predict(&self, x: ArrayView2<f64>, result: &QDAResult) -> Result<Array1<i32>> {
776 let scores = self.decision_function(x, result)?;
777 let mut predictions = Array1::zeros(x.nrows());
778
779 for (i, row) in scores.rows().into_iter().enumerate() {
780 let max_idx = row
781 .iter()
782 .enumerate()
783 .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Operation failed"))
784 .map(|(idx, _)| idx)
785 .expect("Operation failed");
786 predictions[i] = result.classes[max_idx];
787 }
788
789 Ok(predictions)
790 }
791
792 pub fn decision_function(&self, x: ArrayView2<f64>, result: &QDAResult) -> Result<Array2<f64>> {
794 let handler = global_error_handler();
795 validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "QDA decision_function");
796
797 if x.ncols() != result.n_features {
798 return Err(handler
799 .create_validation_error(
800 ErrorCode::E2001,
801 "QDA decision_function",
802 "n_features",
803 format!("input: {}, expected: {}", x.ncols(), result.n_features),
804 "Number of features must match training data",
805 )
806 .error);
807 }
808
809 if result.covariances.is_none() {
810 return Err(StatsError::InvalidArgument(
811 "Covariances not stored during training. Set store_covariance=true.".to_string(),
812 ));
813 }
814
815 let covariances = result.covariances.as_ref().expect("Operation failed");
816 let n_samples = x.nrows();
817 let n_classes = result.classes.len();
818 let mut scores = Array2::zeros((n_samples, n_classes));
819
820 for class_idx in 0..n_classes {
821 let class_mean = result.means.row(class_idx);
822 let class_cov = &covariances[class_idx];
823
824 let cov_inv = scirs2_linalg::inv(&class_cov.view(), None).map_err(|e| {
826 StatsError::ComputationError(format!(
827 "Failed to invert covariance matrix for class {}: {}",
828 class_idx, e
829 ))
830 })?;
831
832 let det_cov = scirs2_linalg::det(&class_cov.view(), None).map_err(|e| {
833 StatsError::ComputationError(format!(
834 "Failed to compute determinant for class {}: {}",
835 class_idx, e
836 ))
837 })?;
838
839 if det_cov <= 0.0 {
840 return Err(StatsError::ComputationError(format!(
841 "Covariance matrix for class {} is not positive definite",
842 class_idx
843 )));
844 }
845
846 let log_det_term = -0.5 * det_cov.ln();
847 let prior_term = result.priors[class_idx].ln();
848
849 for sample_idx in 0..n_samples {
850 let sample = x.row(sample_idx);
851 let diff = &sample - &class_mean;
852
853 let quad_form = diff.dot(&cov_inv.dot(&diff.to_owned()));
855
856 scores[[sample_idx, class_idx]] = prior_term + log_det_term - 0.5 * quad_form;
857 }
858 }
859
860 Ok(scores)
861 }
862
863 pub fn predict_proba(&self, x: ArrayView2<f64>, result: &QDAResult) -> Result<Array2<f64>> {
865 let scores = self.decision_function(x, result)?;
866 let mut probabilities = Array2::zeros(scores.dim());
867
868 for (i, mut row) in probabilities.rows_mut().into_iter().enumerate() {
869 let score_row = scores.row(i);
870 let max_score = score_row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
871
872 let mut sum_exp = 0.0;
874 for (j, &score) in score_row.iter().enumerate() {
875 let exp_score = (score - max_score).exp();
876 row[j] = exp_score;
877 sum_exp += exp_score;
878 }
879
880 if sum_exp > 1e-10 {
882 row /= sum_exp;
883 } else {
884 let len = row.len();
885 row.fill(1.0 / len as f64);
886 }
887 }
888
889 Ok(probabilities)
890 }
891}
892
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Six 2-D points forming two well-separated groups labelled 0 and 1.
    fn two_groups_2d() -> (Array2<f64>, Array1<i32>) {
        let features = array![
            [1.0, 2.5],
            [2.1, 3.2],
            [2.8, 4.1],
            [6.2, 7.1],
            [7.3, 8.5],
            [8.1, 9.3],
        ];
        let labels = array![0, 0, 0, 1, 1, 1];
        (features, labels)
    }

    #[test]
    fn test_lda_basic() {
        let (features, labels) = two_groups_2d();
        let model = LinearDiscriminantAnalysis::new();
        let fitted = model
            .fit(features.view(), labels.view())
            .expect("Operation failed");

        assert_eq!(fitted.classes, array![0, 1]);
        assert_eq!(fitted.means.dim(), (2, 2));

        let predicted = model
            .predict(features.view(), &fitted)
            .expect("Operation failed");
        assert_eq!(predicted.len(), 6);
    }

    #[test]
    fn test_qda_basic() {
        let (features, labels) = two_groups_2d();
        let model = QuadraticDiscriminantAnalysis::new();
        let fitted = model
            .fit(features.view(), labels.view())
            .expect("Operation failed");

        assert_eq!(fitted.classes, array![0, 1]);
        assert_eq!(fitted.means.dim(), (2, 2));

        let predicted = model
            .predict(features.view(), &fitted)
            .expect("Operation failed");
        assert_eq!(predicted.len(), 6);
    }

    #[test]
    fn test_lda_transform() {
        let features = array![
            [1.2, 2.8, 3.1],
            [2.1, 3.5, 4.2],
            [2.9, 4.1, 5.3],
            [6.1, 7.2, 8.5],
            [7.2, 8.3, 9.1],
            [8.3, 9.1, 10.2],
        ];
        let labels = array![0, 0, 0, 1, 1, 1];

        let model = LinearDiscriminantAnalysis::new();
        let fitted = model
            .fit(features.view(), labels.view())
            .expect("Operation failed");

        let projected = model
            .transform(features.view(), &fitted)
            .expect("Operation failed");
        let max_components = fitted.classes.len() - 1;
        assert_eq!(projected.nrows(), 6);
        assert!(projected.ncols() <= max_components);
    }
}
968}