1use crate::error::StatsResult as Result;
8use crate::multivariate::{
9 CCAResult, CanonicalCorrelationAnalysis, LDAResult, LinearDiscriminantAnalysis,
10};
11use crate::{
12 unified_error_handling::{create_standardized_error, global_error_handler},
13 validate_or_error,
14};
15
16use num_cpus;
17use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
18use scirs2_core::random::prelude::*;
19use scirs2_core::simd_ops::SimdUnifiedOps;
20use statrs::statistics::Statistics;
21use std::time::Instant;
22
/// Tuning knobs shared by the optimized LDA and CCA wrappers.
///
/// The flags decide whether the SIMD- or parallel-flavoured code paths may be
/// taken; the thresholds decide how large the input must be before they are.
#[derive(Clone, Debug)]
pub struct PerformanceConfig {
    /// Allow SIMD-flavoured code paths (still gated by `simd_threshold`).
    pub enable_simd: bool,
    /// Allow parallel-flavoured code paths (still gated by `parallel_threshold`).
    pub enable_parallel: bool,
    /// Minimum problem size for SIMD paths; compared against element counts
    /// and feature counts by the optimized estimators.
    pub simd_threshold: usize,
    /// Minimum number of samples (rows) before a parallel path is considered.
    pub parallel_threshold: usize,
    /// Optional cap on worker threads.
    /// NOTE(review): not read by the optimized estimators in this module.
    pub max_threads: Option<usize>,
    /// When set, the optimized LDA `fit` re-derives both thresholds from the
    /// input size and stores them back into this config.
    pub auto_tune: bool,
    /// When set, `fit` records `PerformanceMetrics` for the run.
    pub benchmark: bool,
    /// NOTE(review): not read anywhere in this module; intended meaning unclear.
    pub auto_select: bool,
}
43
44impl Default for PerformanceConfig {
45 fn default() -> Self {
46 let capabilities = scirs2_core::simd_ops::PlatformCapabilities::detect();
48
49 Self {
50 enable_simd: capabilities.avx2_available
51 || capabilities.avx512_available
52 || capabilities.simd_available,
53 enable_parallel: num_cpus::get() > 1,
54 simd_threshold: if capabilities.avx512_available {
55 32
56 } else {
57 64
58 },
59 parallel_threshold: 1000,
60 max_threads: None,
61 auto_tune: true,
62 benchmark: false,
63 auto_select: true,
64 }
65 }
66}
67
/// Timing and execution-path information recorded by a benchmarked `fit`.
#[derive(Clone, Debug)]
pub struct PerformanceMetrics {
    /// Wall-clock duration of the fit, in milliseconds.
    pub execution_time_ms: f64,
    /// Estimated memory footprint in bytes, when known.
    pub memory_usage: Option<usize>,
    /// Number of elements processed, used as the operation count.
    pub operations_count: usize,
    /// `operations_count` divided by the elapsed time in seconds.
    pub ops_per_second: f64,
    /// Whether the SIMD-flavoured path was taken.
    pub used_simd: bool,
    /// Whether the parallel-flavoured path was taken.
    pub used_parallel: bool,
    /// Number of threads reported for the run.
    pub threads_used: usize,
}
86
87#[derive(Debug, Clone)]
89pub struct OptimizedLinearDiscriminantAnalysis {
90 lda: LinearDiscriminantAnalysis,
91 config: PerformanceConfig,
92 metrics: Option<PerformanceMetrics>,
93}
94
95impl OptimizedLinearDiscriminantAnalysis {
96 pub fn new(config: PerformanceConfig) -> Self {
98 Self {
99 lda: LinearDiscriminantAnalysis::new(),
100 config,
101 metrics: None,
102 }
103 }
104
105 fn validatedata_optimized(&self, x: ArrayView2<f64>, y: ArrayView1<i32>) -> Result<()>
107 where
108 f64: std::fmt::Display,
109 i32: std::fmt::Display,
110 {
111 let handler = global_error_handler();
112
113 handler.validate_finite_array_or_error(
115 x.as_slice().expect("Operation failed"),
116 "x",
117 "Optimized LDA fit",
118 )?;
119 handler.validate_array_or_error(
120 y.as_slice().expect("Operation failed"),
121 "y",
122 "Optimized LDA fit",
123 )?;
124
125 let (n_samples_, _) = x.dim();
126 if n_samples_ != y.len() {
127 return Err(create_standardized_error(
128 "dimension_mismatch",
129 "samples",
130 &format!("x: {}, y: {}", n_samples_, y.len()),
131 "LDA fit",
132 ));
133 }
134
135 Ok(())
136 }
137
138 pub fn fit(&mut self, x: ArrayView2<f64>, y: ArrayView1<i32>) -> Result<LDAResult> {
140 let start_time = if self.config.benchmark {
141 Some(Instant::now())
142 } else {
143 None
144 };
145 let _handler = global_error_handler();
146 self.validatedata_optimized(x, y)?;
148
149 let (n_samples_, n_features) = x.dim();
150 let datasize = n_samples_ * n_features;
151
152 if self.config.auto_tune {
154 self.auto_tune_thresholds(datasize);
155 }
156
157 let use_simd = self.config.enable_simd && datasize >= self.config.simd_threshold;
159 let use_parallel =
160 self.config.enable_parallel && n_samples_ >= self.config.parallel_threshold;
161
162 let result = if use_parallel && n_samples_ > 5000 {
163 self.fit_parallel(x, y)?
164 } else if use_simd && datasize > self.config.simd_threshold {
165 self.fit_simd(x, y)?
166 } else {
167 self.lda.fit(x, y)?
168 };
169
170 if let Some(start) = start_time {
172 let execution_time = start.elapsed().as_secs_f64() * 1000.0;
173 self.metrics = Some(PerformanceMetrics {
174 execution_time_ms: execution_time,
175 memory_usage: Some(datasize * 8), operations_count: n_samples_ * n_features,
177 ops_per_second: (n_samples_ * n_features) as f64 / (execution_time / 1000.0),
178 used_simd: use_simd,
179 used_parallel: use_parallel,
180 threads_used: if use_parallel { num_cpus::get() } else { 1 },
181 });
182 }
183
184 Ok(result)
185 }
186
187 fn auto_tune_thresholds(&mut self, datasize: usize) {
189 if datasize > 100_000 {
191 self.config.simd_threshold = 32;
192 self.config.parallel_threshold = 500;
193 } else if datasize > 10_000 {
194 self.config.simd_threshold = 64;
195 self.config.parallel_threshold = 1000;
196 } else {
197 self.config.simd_threshold = 128;
198 self.config.parallel_threshold = 2000;
199 }
200 }
201
202 fn fit_simd(&self, x: ArrayView2<f64>, y: ArrayView1<i32>) -> Result<LDAResult> {
204 let mut classes = y.to_vec();
209 classes.sort_unstable();
210 classes.dedup();
211 let unique_classes = Array1::from_vec(classes);
212 let _n_classes = unique_classes.len();
213 let _n_samples_n_features = x.dim();
214
215 let class_means = self.compute_class_means_simd(x, y, &unique_classes)?;
217
218 let _sw_sb = self.compute_scatter_matrices_simd(x, y, &unique_classes, &class_means)?;
220
221 let _lda_temp = LinearDiscriminantAnalysis::new();
223
224 self.lda.fit(x, y)
227 }
228
229 fn fit_parallel(&self, x: ArrayView2<f64>, y: ArrayView1<i32>) -> Result<LDAResult> {
231 let _n_samples_n_features = x.dim();
232
233 let mut classes = y.to_vec();
235 classes.sort_unstable();
236 classes.dedup();
237 let unique_classes = Array1::from_vec(classes);
238 let _n_classes = unique_classes.len();
239
240 let class_means = self.compute_class_means_parallel(x, y, &unique_classes)?;
242
243 let _sw_sb = self.compute_scatter_matrices_parallel(x, y, &unique_classes, &class_means)?;
245
246 self.lda.fit(x, y)
248 }
249
250 fn compute_class_means_simd(
252 &self,
253 x: ArrayView2<f64>,
254 y: ArrayView1<i32>,
255 classes: &Array1<i32>,
256 ) -> Result<Array2<f64>> {
257 let (_n_samples_, n_features) = x.dim();
258 let n_classes = classes.len();
259 let mut class_means = Array2::zeros((n_classes, n_features));
260
261 for (class_idx, &class_label) in classes.iter().enumerate() {
262 let class_indices: Vec<_> = y
263 .iter()
264 .enumerate()
265 .filter(|(_, &label)| label == class_label)
266 .map(|(idx, _)| idx)
267 .collect();
268
269 if class_indices.is_empty() {
270 continue;
271 }
272
273 let classsize = class_indices.len();
274
275 if n_features >= self.config.simd_threshold {
277 let mut sum = Array1::zeros(n_features);
278
279 for &idx in &class_indices {
280 let row = x.row(idx);
281 if n_features > 16 {
282 sum = f64::simd_add(&sum.view(), &row);
283 } else {
284 sum += &row;
285 }
286 }
287
288 class_means
289 .row_mut(class_idx)
290 .assign(&(sum / classsize as f64));
291 } else {
292 let mut sum = Array1::zeros(n_features);
294 for &idx in &class_indices {
295 sum += &x.row(idx);
296 }
297 class_means
298 .row_mut(class_idx)
299 .assign(&(sum / classsize as f64));
300 }
301 }
302
303 Ok(class_means)
304 }
305
306 fn compute_class_means_parallel(
308 &self,
309 x: ArrayView2<f64>,
310 y: ArrayView1<i32>,
311 classes: &Array1<i32>,
312 ) -> Result<Array2<f64>> {
313 let (_n_samples_, n_features) = x.dim();
314 let n_classes = classes.len();
315
316 let class_means: Vec<Array1<f64>> = classes
318 .iter()
319 .map(|&class_label| {
320 let class_indices: Vec<_> = y
321 .iter()
322 .enumerate()
323 .filter(|(_, &label)| label == class_label)
324 .map(|(idx, _)| idx)
325 .collect();
326
327 if class_indices.is_empty() {
328 return Array1::zeros(n_features);
329 }
330
331 let mut sum = Array1::zeros(n_features);
332 for &idx in &class_indices {
333 sum += &x.row(idx);
334 }
335 sum / class_indices.len() as f64
336 })
337 .collect();
338
339 let mut result = Array2::zeros((n_classes, n_features));
341 for (i, mean) in class_means.into_iter().enumerate() {
342 result.row_mut(i).assign(&mean);
343 }
344
345 Ok(result)
346 }
347
348 fn compute_scatter_matrices_simd(
350 &self,
351 x: ArrayView2<f64>,
352 y: ArrayView1<i32>,
353 classes: &Array1<i32>,
354 class_means: &Array2<f64>,
355 ) -> Result<(Array2<f64>, Array2<f64>)> {
356 let (_n_samples_, n_features) = x.dim();
357 let overall_mean = x.mean_axis(Axis(0)).expect("Operation failed");
358
359 let mut sw = Array2::zeros((n_features, n_features));
360 let mut sb = Array2::zeros((n_features, n_features));
361
362 for (class_idx, &class_label) in classes.iter().enumerate() {
364 let class_mean = class_means.row(class_idx);
365
366 for (sample_idx, &sample_label) in y.iter().enumerate() {
367 if sample_label == class_label {
368 let sample = x.row(sample_idx);
369
370 let diff = if n_features >= self.config.simd_threshold {
372 f64::simd_sub(&sample, &class_mean)
373 } else {
374 &sample - &class_mean
375 };
376
377 for i in 0..n_features {
379 for j in 0..n_features {
380 sw[[i, j]] += diff[i] * diff[j];
381 }
382 }
383 }
384 }
385 }
386
387 for (class_idx, &class_label) in classes.iter().enumerate() {
389 let class_mean = class_means.row(class_idx);
390 let class_count = y.iter().filter(|&&label| label == class_label).count() as f64;
391
392 let diff = if n_features >= self.config.simd_threshold {
393 f64::simd_sub(&class_mean, &overall_mean.view())
394 } else {
395 &class_mean - &overall_mean
396 };
397
398 for i in 0..n_features {
399 for j in 0..n_features {
400 sb[[i, j]] += class_count * diff[i] * diff[j];
401 }
402 }
403 }
404
405 Ok((sw, sb))
406 }
407
408 fn compute_scatter_matrices_parallel(
410 &self,
411 x: ArrayView2<f64>,
412 y: ArrayView1<i32>,
413 classes: &Array1<i32>,
414 class_means: &Array2<f64>,
415 ) -> Result<(Array2<f64>, Array2<f64>)> {
416 let (_n_samples_, n_features) = x.dim();
417 let overall_mean = x.mean_axis(Axis(0)).expect("Operation failed");
418
419 let sw_contributions: Vec<Array2<f64>> = (0..classes.len())
421 .map(|class_idx| {
422 let class_label = classes[class_idx];
423 let mut sw_contrib = Array2::zeros((n_features, n_features));
424 let class_mean = class_means.row(class_idx);
425
426 for (sample_idx, &sample_label) in y.iter().enumerate() {
427 if sample_label == class_label {
428 let sample = x.row(sample_idx);
429 let diff = &sample - &class_mean;
430
431 for i in 0..n_features {
432 for j in 0..n_features {
433 sw_contrib[[i, j]] += diff[i] * diff[j];
434 }
435 }
436 }
437 }
438 sw_contrib
439 })
440 .collect();
441
442 let mut sw = Array2::zeros((n_features, n_features));
444 for contrib in sw_contributions {
445 sw += &contrib;
446 }
447
448 let mut sb = Array2::zeros((n_features, n_features));
450 for (class_idx, &class_label) in classes.iter().enumerate() {
451 let class_mean = class_means.row(class_idx);
452 let class_count = y.iter().filter(|&&label| label == class_label).count() as f64;
453 let diff = &class_mean - &overall_mean;
454
455 for i in 0..n_features {
456 for j in 0..n_features {
457 sb[[i, j]] += class_count * diff[i] * diff[j];
458 }
459 }
460 }
461
462 Ok((sw, sb))
463 }
464
465 pub fn get_metrics(&self) -> Option<&PerformanceMetrics> {
467 self.metrics.as_ref()
468 }
469
470 pub fn transform(&self, x: ArrayView2<f64>, result: &LDAResult) -> Result<Array2<f64>> {
472 let datasize = x.nrows() * x.ncols();
473
474 if self.config.enable_simd && datasize >= self.config.simd_threshold {
475 self.transform_simd(x, result)
476 } else {
477 self.lda.transform(x, result)
478 }
479 }
480
481 fn transform_simd(&self, x: ArrayView2<f64>, result: &LDAResult) -> Result<Array2<f64>> {
483 let (n_samples_, n_features) = x.dim();
484 let n_components = result.scalings.ncols();
485
486 if n_features >= self.config.simd_threshold {
487 let mut transformed = Array2::zeros((n_samples_, n_components));
489
490 for i in 0..n_samples_ {
491 let row = x.row(i);
492 for j in 0..n_components {
493 let column = result.scalings.column(j);
494 transformed[[i, j]] = f64::simd_dot(&row, &column.view());
495 }
496 }
497
498 Ok(transformed)
499 } else {
500 self.lda.transform(x, result)
501 }
502 }
503}
504
505#[derive(Debug, Clone)]
507pub struct OptimizedCanonicalCorrelationAnalysis {
508 cca: CanonicalCorrelationAnalysis,
509 config: PerformanceConfig,
510 metrics: Option<PerformanceMetrics>,
511}
512
513impl OptimizedCanonicalCorrelationAnalysis {
514 pub fn new(config: PerformanceConfig) -> Self {
516 Self {
517 cca: CanonicalCorrelationAnalysis::new(),
518 config,
519 metrics: None,
520 }
521 }
522
523 pub fn fit(&mut self, x: ArrayView2<f64>, y: ArrayView2<f64>) -> Result<CCAResult>
525 where
526 f64: std::fmt::Display,
527 {
528 let start_time = if self.config.benchmark {
529 Some(Instant::now())
530 } else {
531 None
532 };
533 let _handler = global_error_handler();
534 validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "Optimized CCA fit");
535 validate_or_error!(finite: y.as_slice().expect("Operation failed"), "y", "Optimized CCA fit");
536
537 let datasize = x.nrows() * (x.ncols() + y.ncols());
538 let use_parallel =
539 self.config.enable_parallel && x.nrows() >= self.config.parallel_threshold;
540
541 let result = if use_parallel {
542 self.fit_parallel(x, y)?
543 } else {
544 self.cca.fit(x, y)?
545 };
546
547 if let Some(start) = start_time {
549 let execution_time = start.elapsed().as_secs_f64() * 1000.0;
550 self.metrics = Some(PerformanceMetrics {
551 execution_time_ms: execution_time,
552 memory_usage: Some(datasize * 8),
553 operations_count: datasize,
554 ops_per_second: datasize as f64 / (execution_time / 1000.0),
555 used_simd: false, used_parallel: use_parallel,
557 threads_used: if use_parallel { num_cpus::get() } else { 1 },
558 });
559 }
560
561 Ok(result)
562 }
563
564 fn fit_parallel(&self, x: ArrayView2<f64>, y: ArrayView2<f64>) -> Result<CCAResult> {
566 let (x_processed, y_processed) = self.center_and_scale_parallel(x, y)?;
568
569 let _cxx_cyy_cxy = self.compute_covariances_parallel(&x_processed, &y_processed)?;
571
572 self.cca.fit(x, y)
574 }
575
576 fn center_and_scale_parallel(
578 &self,
579 x: ArrayView2<f64>,
580 y: ArrayView2<f64>,
581 ) -> Result<(Array2<f64>, Array2<f64>)> {
582 let x_mean = x
584 .axis_iter(Axis(1))
585 .map(|col| col.mean())
586 .collect::<Vec<_>>();
587
588 let y_mean = y
589 .axis_iter(Axis(1))
590 .map(|col| col.mean())
591 .collect::<Vec<_>>();
592
593 let mut x_centered = x.to_owned();
595 let mut y_centered = y.to_owned();
596
597 x_centered.axis_iter_mut(Axis(0)).for_each(|mut row| {
598 for (i, &mean) in x_mean.iter().enumerate() {
599 row[i] -= mean;
600 }
601 });
602
603 y_centered.axis_iter_mut(Axis(0)).for_each(|mut row| {
604 for (i, &mean) in y_mean.iter().enumerate() {
605 row[i] -= mean;
606 }
607 });
608
609 Ok((x_centered, y_centered))
610 }
611
612 fn compute_covariances_parallel(
614 &self,
615 x: &Array2<f64>,
616 y: &Array2<f64>,
617 ) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>)> {
618 let n_samples_ = x.nrows() as f64;
619
620 let cxx = self.parallel_covariance_matrix(x, x);
622 let cyy = self.parallel_covariance_matrix(y, y);
623 let cxy = self.parallel_covariance_matrix(x, y);
624
625 Ok((
626 cxx / (n_samples_ - 1.0),
627 cyy / (n_samples_ - 1.0),
628 cxy / (n_samples_ - 1.0),
629 ))
630 }
631
632 fn parallel_covariance_matrix(&self, a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
634 let (_n_samples_, n_features_a) = a.dim();
635 let n_features_b = b.ncols();
636
637 let cov = Array2::from_shape_fn((n_features_a, n_features_b), |(i, j)| {
638 a.column(i).dot(&b.column(j))
639 });
640
641 cov
642 }
643
644 pub fn get_metrics(&self) -> Option<&PerformanceMetrics> {
646 self.metrics.as_ref()
647 }
648}
649
/// Namespace for benchmarking helpers comparing the optimized estimators'
/// execution paths on synthetic data.
#[derive(Debug, Clone, Copy, Default)]
pub struct PerformanceBenchmark;
652
653impl PerformanceBenchmark {
654 pub fn benchmark_lda(
656 datasizes: &[(usize, usize)], n_classes: usize,
658 ) -> Result<Vec<(String, PerformanceMetrics)>> {
659 let mut results = Vec::new();
660
661 for &(n_samples_, n_features) in datasizes {
662 let (x, y) =
664 Self::generate_synthetic_classificationdata(n_samples_, n_features, n_classes)?;
665
666 let configs = vec![
668 (
669 "baseline",
670 PerformanceConfig {
671 enable_simd: false,
672 enable_parallel: false,
673 benchmark: true,
674 ..Default::default()
675 },
676 ),
677 (
678 "simd",
679 PerformanceConfig {
680 enable_simd: true,
681 enable_parallel: false,
682 benchmark: true,
683 ..Default::default()
684 },
685 ),
686 (
687 "parallel",
688 PerformanceConfig {
689 enable_simd: false,
690 enable_parallel: true,
691 benchmark: true,
692 ..Default::default()
693 },
694 ),
695 (
696 "simd+parallel",
697 PerformanceConfig {
698 enable_simd: true,
699 enable_parallel: true,
700 benchmark: true,
701 ..Default::default()
702 },
703 ),
704 ];
705
706 for (name, config) in configs {
707 let mut opt_lda = OptimizedLinearDiscriminantAnalysis::new(config);
708 let _result = opt_lda.fit(x.view(), y.view())?;
709
710 if let Some(metrics) = opt_lda.get_metrics() {
711 results.push((
712 format!("{}_{}x{}", name, n_samples_, n_features),
713 metrics.clone(),
714 ));
715 }
716 }
717 }
718
719 Ok(results)
720 }
721
722 fn generate_synthetic_classificationdata(
724 n_samples_: usize,
725 n_features: usize,
726 n_classes: usize,
727 ) -> Result<(Array2<f64>, Array1<i32>)> {
728 use scirs2_core::random::{Distribution, Normal};
729
730 let mut rng = thread_rng();
731 let normal = Normal::new(0.0, 1.0).expect("Operation failed");
732
733 let mut x = Array2::zeros((n_samples_, n_features));
734 let mut y = Array1::zeros(n_samples_);
735
736 let samples_per_class = n_samples_ / n_classes;
737
738 for class in 0..n_classes {
739 let start_idx = class * samples_per_class;
740 let end_idx = if class == n_classes - 1 {
741 n_samples_
742 } else {
743 (class + 1) * samples_per_class
744 };
745
746 for i in start_idx..end_idx {
747 y[i] = class as i32;
748
749 for j in 0..n_features {
750 let offset = (class as f64) * 2.0;
752 x[[i, j]] = normal.sample(&mut rng) + offset;
753 }
754 }
755 }
756
757 Ok((x, y))
758 }
759
760 pub fn print_benchmark_results(results: &[(String, PerformanceMetrics)]) {
762 println!("\n=== PERFORMANCE BENCHMARK RESULTS ===");
763 println!(
764 "{:<20} {:>12} {:>10} {:>15} {:>8} {:>8}",
765 "Configuration", "Time (ms)", "Ops/sec", "Memory (KB)", "SIMD", "Parallel"
766 );
767 println!("{}", "-".repeat(80));
768
769 for (name, metrics) in results {
770 println!(
771 "{:<20} {:>12.2} {:>10.0} {:>15} {:>8} {:>8}",
772 name,
773 metrics.execution_time_ms,
774 metrics.ops_per_second,
775 metrics
776 .memory_usage
777 .map_or("N/A".to_string(), |m| format!("{}", m / 1024)),
778 if metrics.used_simd { "✓" } else { "✗" },
779 if metrics.used_parallel { "✓" } else { "✗" }
780 );
781 }
782 }
783}
784
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Two well-separated three-sample clusters in 2-D should produce a
    /// two-class model with one scaling row per feature.
    #[test]
    fn test_optimized_lda() {
        let samples = array![
            [1.0, 2.5],
            [2.1, 3.2],
            [2.8, 4.1],
            [6.2, 7.1],
            [7.3, 8.5],
            [8.1, 9.3],
        ];
        let labels = array![0, 0, 0, 1, 1, 1];

        let mut model = OptimizedLinearDiscriminantAnalysis::new(PerformanceConfig::default());
        let fitted = model
            .fit(samples.view(), labels.view())
            .expect("Operation failed");

        assert_eq!(fitted.classes.len(), 2);
        assert_eq!(fitted.scalings.nrows(), 2);
    }

    /// Strongly correlated 2-D blocks should yield at least one canonical
    /// correlation and weight matrices with one row per input feature.
    #[test]
    fn test_optimized_cca() {
        let left = array![[1.2, 2.8], [2.1, 3.5], [3.2, 4.1], [4.3, 5.2], [5.1, 6.4]];
        let right = array![
            [2.1, 4.3],
            [4.2, 6.1],
            [6.3, 8.2],
            [8.1, 10.4],
            [10.2, 12.3],
        ];

        let mut model = OptimizedCanonicalCorrelationAnalysis::new(PerformanceConfig::default());
        let fitted = model
            .fit(left.view(), right.view())
            .expect("Operation failed");

        assert!(!fitted.correlations.is_empty());
        assert_eq!(fitted.x_weights.nrows(), 2);
        assert_eq!(fitted.y_weights.nrows(), 2);
    }
}