1use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2};
8use sklears_core::error::{Result as SklResult, SklearsError};
/// Module-local shorthand for `sklears_core`'s `Result` (imported above as `SklResult`).
///
/// NOTE(review): the error type is whatever `SklResult` fixes — presumably `SklearsError`,
/// since this module converts `RedundancyError` into `SklearsError` with `.into()`; confirm.
type Result<T> = SklResult<T>;
10
11impl From<RedundancyError> for SklearsError {
12 fn from(err: RedundancyError) -> Self {
13 SklearsError::FitError(format!("Redundancy analysis error: {}", err))
14 }
15}
16use std::collections::HashMap;
17use thiserror::Error;
18
19#[derive(Debug, Error)]
20pub enum RedundancyError {
21 #[error("Feature matrix is empty")]
22 EmptyFeatureMatrix,
23 #[error("Feature indices must be valid")]
24 InvalidFeatureIndices,
25 #[error("Insufficient variance for redundancy computation")]
26 InsufficientVariance,
27 #[error("Correlation matrix computation failed")]
28 CorrelationComputationFailed,
29}
30
31#[derive(Debug, Clone)]
33pub struct CorrelationRedundancy {
34 correlation_threshold: f64,
35 absolute_correlation: bool,
36}
37
38impl CorrelationRedundancy {
39 pub fn new(correlation_threshold: f64, absolute_correlation: bool) -> Self {
41 Self {
42 correlation_threshold,
43 absolute_correlation,
44 }
45 }
46
47 pub fn compute(&self, X: ArrayView2<f64>, feature_indices: &[usize]) -> Result<f64> {
49 if feature_indices.len() < 2 {
50 return Ok(0.0); }
52
53 if X.is_empty() {
54 return Err(RedundancyError::EmptyFeatureMatrix.into());
55 }
56
57 let selected_features = self.extract_features(X, feature_indices)?;
59
60 let correlation_matrix = self.compute_correlation_matrix(&selected_features)?;
62
63 self.calculate_redundancy_score(&correlation_matrix)
65 }
66
67 fn extract_features(
69 &self,
70 X: ArrayView2<f64>,
71 feature_indices: &[usize],
72 ) -> Result<Array2<f64>> {
73 let n_samples = X.nrows();
74 let mut selected = Array2::zeros((n_samples, feature_indices.len()));
75
76 for (col_idx, &feature_idx) in feature_indices.iter().enumerate() {
77 if feature_idx >= X.ncols() {
78 return Err(RedundancyError::InvalidFeatureIndices.into());
79 }
80 selected.column_mut(col_idx).assign(&X.column(feature_idx));
81 }
82
83 Ok(selected)
84 }
85
86 fn compute_correlation_matrix(&self, features: &Array2<f64>) -> Result<Array2<f64>> {
88 let n_features = features.ncols();
89 let mut correlation_matrix = Array2::zeros((n_features, n_features));
90
91 for i in 0..n_features {
92 for j in 0..n_features {
93 if i == j {
94 correlation_matrix[[i, j]] = 1.0;
95 } else {
96 let feature1 = features.column(i);
97 let feature2 = features.column(j);
98 let correlation = self.compute_correlation(feature1, feature2)?;
99 correlation_matrix[[i, j]] = if self.absolute_correlation {
100 correlation.abs()
101 } else {
102 correlation
103 };
104 }
105 }
106 }
107
108 Ok(correlation_matrix)
109 }
110
111 fn compute_correlation(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> Result<f64> {
113 let n = x.len() as f64;
114 if n < 2.0 {
115 return Ok(0.0);
116 }
117
118 let mean_x = x.mean().unwrap_or(0.0);
119 let mean_y = y.mean().unwrap_or(0.0);
120
121 let mut sum_xy = 0.0;
122 let mut sum_x2 = 0.0;
123 let mut sum_y2 = 0.0;
124
125 for i in 0..x.len() {
126 let dx = x[i] - mean_x;
127 let dy = y[i] - mean_y;
128 sum_xy += dx * dy;
129 sum_x2 += dx * dx;
130 sum_y2 += dy * dy;
131 }
132
133 let denom = (sum_x2 * sum_y2).sqrt();
134 if denom < 1e-10 {
135 return Ok(0.0);
136 }
137
138 Ok(sum_xy / denom)
139 }
140
141 fn calculate_redundancy_score(&self, correlation_matrix: &Array2<f64>) -> Result<f64> {
143 let n_features = correlation_matrix.nrows();
144 if n_features < 2 {
145 return Ok(0.0);
146 }
147
148 let mut total_redundancy = 0.0;
149 let mut _pair_count = 0;
150
151 for i in 0..n_features {
153 for j in (i + 1)..n_features {
154 let correlation = correlation_matrix[[i, j]];
155 if correlation.abs() >= self.correlation_threshold {
156 total_redundancy += correlation.abs();
157 _pair_count += 1;
158 }
159 }
160 }
161
162 let total_pairs = (n_features * (n_features - 1)) / 2;
163 Ok(total_redundancy / total_pairs as f64)
164 }
165
166 pub fn identify_redundant_pairs(
168 &self,
169 X: ArrayView2<f64>,
170 feature_indices: &[usize],
171 ) -> Result<Vec<(usize, usize, f64)>> {
172 let selected_features = self.extract_features(X, feature_indices)?;
173 let correlation_matrix = self.compute_correlation_matrix(&selected_features)?;
174
175 let mut redundant_pairs = Vec::new();
176
177 for i in 0..feature_indices.len() {
178 for j in (i + 1)..feature_indices.len() {
179 let correlation = correlation_matrix[[i, j]];
180 if correlation.abs() >= self.correlation_threshold {
181 redundant_pairs.push((feature_indices[i], feature_indices[j], correlation));
182 }
183 }
184 }
185
186 redundant_pairs.sort_by(|a, b| b.2.abs().partial_cmp(&a.2.abs()).unwrap());
188
189 Ok(redundant_pairs)
190 }
191}
192
193#[derive(Debug, Clone)]
195pub struct MutualInformationRedundancy {
196 mi_threshold: f64,
197 n_bins: usize,
198}
199
200impl MutualInformationRedundancy {
201 pub fn new(mi_threshold: f64, n_bins: usize) -> Self {
203 Self {
204 mi_threshold,
205 n_bins,
206 }
207 }
208
209 pub fn compute(&self, X: ArrayView2<f64>, feature_indices: &[usize]) -> Result<f64> {
211 if feature_indices.len() < 2 {
212 return Ok(0.0);
213 }
214
215 let selected_features = self.extract_features(X, feature_indices)?;
216 let mi_matrix = self.compute_mi_matrix(&selected_features)?;
217
218 self.calculate_mi_redundancy_score(&mi_matrix)
219 }
220
221 fn extract_features(
222 &self,
223 X: ArrayView2<f64>,
224 feature_indices: &[usize],
225 ) -> Result<Array2<f64>> {
226 let n_samples = X.nrows();
227 let mut selected = Array2::zeros((n_samples, feature_indices.len()));
228
229 for (col_idx, &feature_idx) in feature_indices.iter().enumerate() {
230 if feature_idx >= X.ncols() {
231 return Err(RedundancyError::InvalidFeatureIndices.into());
232 }
233 selected.column_mut(col_idx).assign(&X.column(feature_idx));
234 }
235
236 Ok(selected)
237 }
238
239 fn compute_mi_matrix(&self, features: &Array2<f64>) -> Result<Array2<f64>> {
241 let n_features = features.ncols();
242 let mut mi_matrix = Array2::zeros((n_features, n_features));
243
244 for i in 0..n_features {
245 for j in 0..n_features {
246 if i == j {
247 mi_matrix[[i, j]] = self.compute_entropy(features.column(i))?;
248 } else {
249 let mi =
250 self.compute_mutual_information(features.column(i), features.column(j))?;
251 mi_matrix[[i, j]] = mi;
252 }
253 }
254 }
255
256 Ok(mi_matrix)
257 }
258
259 fn compute_entropy(&self, feature: ArrayView1<f64>) -> Result<f64> {
261 let histogram = self.create_histogram(feature);
262 let total_count = feature.len() as f64;
263
264 let mut entropy = 0.0;
265 for &count in histogram.values() {
266 if count > 0 {
267 let probability = count as f64 / total_count;
268 entropy -= probability * probability.ln();
269 }
270 }
271
272 Ok(entropy)
273 }
274
275 fn compute_mutual_information(
277 &self,
278 feature1: ArrayView1<f64>,
279 feature2: ArrayView1<f64>,
280 ) -> Result<f64> {
281 let hist1 = self.create_histogram(feature1);
282 let hist2 = self.create_histogram(feature2);
283 let joint_hist = self.create_joint_histogram(feature1, feature2);
284
285 let n = feature1.len() as f64;
286 let mut mi = 0.0;
287
288 for (&(val1, val2), &joint_count) in joint_hist.iter() {
289 if joint_count > 0 {
290 let p_xy = joint_count as f64 / n;
291 let p_x = *hist1.get(&val1).unwrap_or(&0) as f64 / n;
292 let p_y = *hist2.get(&val2).unwrap_or(&0) as f64 / n;
293
294 if p_x > 0.0 && p_y > 0.0 {
295 mi += p_xy * (p_xy / (p_x * p_y)).ln();
296 }
297 }
298 }
299
300 Ok(mi)
301 }
302
303 fn create_histogram(&self, feature: ArrayView1<f64>) -> HashMap<i32, usize> {
305 let min_val = feature.iter().fold(f64::INFINITY, |acc, &x| acc.min(x));
306 let max_val = feature.iter().fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
307 let bin_width = (max_val - min_val) / self.n_bins as f64;
308
309 let mut histogram = HashMap::new();
310
311 for &value in feature.iter() {
312 let bin = if bin_width > 0.0 {
313 ((value - min_val) / bin_width).floor() as i32
314 } else {
315 0
316 };
317 let bin = bin.min((self.n_bins - 1) as i32).max(0);
318 *histogram.entry(bin).or_insert(0) += 1;
319 }
320
321 histogram
322 }
323
324 fn create_joint_histogram(
326 &self,
327 feature1: ArrayView1<f64>,
328 feature2: ArrayView1<f64>,
329 ) -> HashMap<(i32, i32), usize> {
330 let min1 = feature1.iter().fold(f64::INFINITY, |acc, &x| acc.min(x));
331 let max1 = feature1
332 .iter()
333 .fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
334 let min2 = feature2.iter().fold(f64::INFINITY, |acc, &x| acc.min(x));
335 let max2 = feature2
336 .iter()
337 .fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
338
339 let bin_width1 = (max1 - min1) / self.n_bins as f64;
340 let bin_width2 = (max2 - min2) / self.n_bins as f64;
341
342 let mut joint_histogram = HashMap::new();
343
344 for i in 0..feature1.len() {
345 let bin1 = if bin_width1 > 0.0 {
346 ((feature1[i] - min1) / bin_width1).floor() as i32
347 } else {
348 0
349 };
350 let bin2 = if bin_width2 > 0.0 {
351 ((feature2[i] - min2) / bin_width2).floor() as i32
352 } else {
353 0
354 };
355
356 let bin1 = bin1.min((self.n_bins - 1) as i32).max(0);
357 let bin2 = bin2.min((self.n_bins - 1) as i32).max(0);
358
359 *joint_histogram.entry((bin1, bin2)).or_insert(0) += 1;
360 }
361
362 joint_histogram
363 }
364
365 fn calculate_mi_redundancy_score(&self, mi_matrix: &Array2<f64>) -> Result<f64> {
366 let n_features = mi_matrix.nrows();
367 if n_features < 2 {
368 return Ok(0.0);
369 }
370
371 let mut total_redundancy = 0.0;
372 let mut _pair_count = 0;
373
374 for i in 0..n_features {
375 for j in (i + 1)..n_features {
376 let mi = mi_matrix[[i, j]];
377 if mi >= self.mi_threshold {
378 total_redundancy += mi;
379 _pair_count += 1;
380 }
381 }
382 }
383
384 let total_pairs = (n_features * (n_features - 1)) / 2;
385 Ok(total_redundancy / total_pairs as f64)
386 }
387}
388
389#[derive(Debug, Clone)]
391pub struct VarianceInflationFactor {
392 vif_threshold: f64,
393}
394
395impl VarianceInflationFactor {
396 pub fn new(vif_threshold: f64) -> Self {
398 Self { vif_threshold }
399 }
400
401 pub fn compute_all(&self, X: ArrayView2<f64>, feature_indices: &[usize]) -> Result<Vec<f64>> {
403 if feature_indices.len() < 2 {
404 return Ok(vec![1.0; feature_indices.len()]);
405 }
406
407 let selected_features = self.extract_features(X, feature_indices)?;
408 let mut vif_scores = Vec::with_capacity(feature_indices.len());
409
410 for i in 0..feature_indices.len() {
411 let vif = self.compute_single_vif(&selected_features, i)?;
412 vif_scores.push(vif);
413 }
414
415 Ok(vif_scores)
416 }
417
418 fn extract_features(
419 &self,
420 X: ArrayView2<f64>,
421 feature_indices: &[usize],
422 ) -> Result<Array2<f64>> {
423 let n_samples = X.nrows();
424 let mut selected = Array2::zeros((n_samples, feature_indices.len()));
425
426 for (col_idx, &feature_idx) in feature_indices.iter().enumerate() {
427 if feature_idx >= X.ncols() {
428 return Err(RedundancyError::InvalidFeatureIndices.into());
429 }
430 selected.column_mut(col_idx).assign(&X.column(feature_idx));
431 }
432
433 Ok(selected)
434 }
435
436 fn compute_single_vif(&self, features: &Array2<f64>, target_feature_idx: usize) -> Result<f64> {
438 if target_feature_idx >= features.ncols() {
439 return Err(RedundancyError::InvalidFeatureIndices.into());
440 }
441
442 let n_features = features.ncols();
443 if n_features < 2 {
444 return Ok(1.0);
445 }
446
447 let target_feature = features.column(target_feature_idx);
449
450 let mut other_features = Array2::zeros((features.nrows(), n_features - 1));
452 let mut col_idx = 0;
453
454 for i in 0..n_features {
455 if i != target_feature_idx {
456 other_features
457 .column_mut(col_idx)
458 .assign(&features.column(i));
459 col_idx += 1;
460 }
461 }
462
463 let r_squared = self.compute_r_squared(target_feature, other_features.view())?;
465
466 if (1.0 - r_squared).abs() < 1e-10 {
468 return Ok(f64::INFINITY);
469 }
470
471 Ok(1.0 / (1.0 - r_squared))
472 }
473
474 fn compute_r_squared(&self, target: ArrayView1<f64>, features: ArrayView2<f64>) -> Result<f64> {
476 if features.ncols() == 0 {
477 return Ok(0.0);
478 }
479
480 let mut max_r_squared: f64 = 0.0;
482
483 for i in 0..features.ncols() {
484 let feature = features.column(i);
485 let correlation = self.compute_correlation(target, feature)?;
486 max_r_squared = max_r_squared.max(correlation * correlation);
487 }
488
489 Ok(max_r_squared)
490 }
491
492 fn compute_correlation(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> Result<f64> {
493 let n = x.len() as f64;
494 if n < 2.0 {
495 return Ok(0.0);
496 }
497
498 let mean_x = x.mean().unwrap_or(0.0);
499 let mean_y = y.mean().unwrap_or(0.0);
500
501 let mut sum_xy = 0.0;
502 let mut sum_x2 = 0.0;
503 let mut sum_y2 = 0.0;
504
505 for i in 0..x.len() {
506 let dx = x[i] - mean_x;
507 let dy = y[i] - mean_y;
508 sum_xy += dx * dy;
509 sum_x2 += dx * dx;
510 sum_y2 += dy * dy;
511 }
512
513 let denom = (sum_x2 * sum_y2).sqrt();
514 if denom < 1e-10 {
515 return Ok(0.0);
516 }
517
518 Ok(sum_xy / denom)
519 }
520
521 pub fn identify_high_vif_features(
523 &self,
524 X: ArrayView2<f64>,
525 feature_indices: &[usize],
526 ) -> Result<Vec<(usize, f64)>> {
527 let vif_scores = self.compute_all(X, feature_indices)?;
528 let mut high_vif_features = Vec::new();
529
530 for (i, &vif) in vif_scores.iter().enumerate() {
531 if vif >= self.vif_threshold {
532 high_vif_features.push((feature_indices[i], vif));
533 }
534 }
535
536 high_vif_features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
538
539 Ok(high_vif_features)
540 }
541}
542
543#[derive(Debug, Clone)]
545pub struct RedundancyMatrix {
546 correlation_threshold: f64,
547 mi_threshold: f64,
548 vif_threshold: f64,
549 n_bins: usize,
550}
551
552impl RedundancyMatrix {
553 pub fn new(
555 correlation_threshold: f64,
556 mi_threshold: f64,
557 vif_threshold: f64,
558 n_bins: usize,
559 ) -> Self {
560 Self {
561 correlation_threshold,
562 mi_threshold,
563 vif_threshold,
564 n_bins,
565 }
566 }
567
568 pub fn compute(
570 &self,
571 X: ArrayView2<f64>,
572 feature_indices: &[usize],
573 ) -> Result<RedundancyAssessment> {
574 let correlation_redundancy = CorrelationRedundancy::new(self.correlation_threshold, true);
575 let mi_redundancy = MutualInformationRedundancy::new(self.mi_threshold, self.n_bins);
576 let vif_analyzer = VarianceInflationFactor::new(self.vif_threshold);
577
578 let correlation_score = correlation_redundancy.compute(X, feature_indices)?;
579 let redundant_pairs =
580 correlation_redundancy.identify_redundant_pairs(X, feature_indices)?;
581
582 let mi_score = mi_redundancy.compute(X, feature_indices)?;
583
584 let vif_scores = vif_analyzer.compute_all(X, feature_indices)?;
585 let high_vif_features = vif_analyzer.identify_high_vif_features(X, feature_indices)?;
586
587 Ok(RedundancyAssessment {
588 correlation_redundancy_score: correlation_score,
589 mutual_information_redundancy_score: mi_score,
590 average_vif: vif_scores.iter().sum::<f64>() / vif_scores.len() as f64,
591 max_vif: vif_scores.iter().fold(0.0, |acc, &x| acc.max(x)),
592 redundant_correlation_pairs: redundant_pairs,
593 high_vif_features,
594 vif_scores,
595 n_features: feature_indices.len(),
596 })
597 }
598}
599
/// Summary of all redundancy measures computed for one feature set.
///
/// Render it as human-readable text with [`RedundancyAssessment::report`].
#[derive(Debug, Clone)]
pub struct RedundancyAssessment {
    /// Correlation-based redundancy score.
    pub correlation_redundancy_score: f64,
    /// Mutual-information-based redundancy score.
    pub mutual_information_redundancy_score: f64,
    /// Average variance inflation factor over the analysed features.
    pub average_vif: f64,
    /// Largest variance inflation factor over the analysed features.
    pub max_vif: f64,
    /// `(feature_a, feature_b, correlation)` for each highly correlated pair.
    pub redundant_correlation_pairs: Vec<(usize, usize, f64)>,
    /// `(feature, vif)` for each feature flagged with a high VIF.
    pub high_vif_features: Vec<(usize, f64)>,
    /// Per-feature VIF values.
    pub vif_scores: Vec<f64>,
    /// Number of features analysed.
    pub n_features: usize,
}

impl RedundancyAssessment {
    /// Maximum number of pairs / features listed individually in the report.
    const REPORT_LIST_LIMIT: usize = 10;

    /// Renders a multi-line, human-readable summary of the assessment.
    ///
    /// At most ten correlated pairs and ten high-VIF features are listed individually;
    /// any remainder is summarised as a count. Empty lists are omitted entirely.
    pub fn report(&self) -> String {
        let limit = Self::REPORT_LIST_LIMIT;

        let mut pieces: Vec<String> = vec![
            "=== Feature Set Redundancy Assessment ===\n\n".to_string(),
            format!("Number of features analyzed: {}\n\n", self.n_features),
            format!(
                "Correlation Redundancy Score: {:.4}\n",
                self.correlation_redundancy_score
            ),
            format!(
                " Interpretation: {}\n",
                self.interpret_correlation_redundancy()
            ),
            format!(
                "\nMutual Information Redundancy Score: {:.4}\n",
                self.mutual_information_redundancy_score
            ),
            format!(" Interpretation: {}\n", self.interpret_mi_redundancy()),
            "\nVariance Inflation Factor Analysis:\n".to_string(),
            format!(" Average VIF: {:.4}\n", self.average_vif),
            format!(" Maximum VIF: {:.4}\n", self.max_vif),
            format!(" Interpretation: {}\n", self.interpret_vif()),
        ];

        let pair_total = self.redundant_correlation_pairs.len();
        if pair_total > 0 {
            pieces.push(format!(
                "\nHighly Correlated Feature Pairs ({}):\n",
                pair_total
            ));
            pieces.extend(
                self.redundant_correlation_pairs
                    .iter()
                    .take(limit)
                    .enumerate()
                    .map(|(rank, &(first, second, corr))| {
                        format!(
                            " {}. Features {} and {}: correlation = {:.4}\n",
                            rank + 1,
                            first,
                            second,
                            corr
                        )
                    }),
            );
            if pair_total > limit {
                pieces.push(format!(" ... and {} more pairs\n", pair_total - limit));
            }
        }

        let flagged_total = self.high_vif_features.len();
        if flagged_total > 0 {
            pieces.push(format!("\nHigh VIF Features ({}):\n", flagged_total));
            pieces.extend(
                self.high_vif_features
                    .iter()
                    .take(limit)
                    .enumerate()
                    .map(|(rank, &(feature, vif))| {
                        format!(" {}. Feature {}: VIF = {:.4}\n", rank + 1, feature, vif)
                    }),
            );
            if flagged_total > limit {
                pieces.push(format!(
                    " ... and {} more features\n",
                    flagged_total - limit
                ));
            }
        }

        pieces.push(format!(
            "\nOverall Redundancy Assessment: {}\n",
            self.overall_assessment()
        ));

        pieces.concat()
    }

    /// Returns the label of the first `(threshold, label)` entry whose threshold `value`
    /// reaches, scanning in order. Falls back to `fallback` when none matches, which
    /// includes the case where `value` is NaN.
    fn grade(value: f64, levels: &[(f64, &'static str)], fallback: &'static str) -> &'static str {
        levels
            .iter()
            .find(|&&(threshold, _)| value >= threshold)
            .map_or(fallback, |&(_, label)| label)
    }

    /// Verbal grading of `correlation_redundancy_score`.
    fn interpret_correlation_redundancy(&self) -> &'static str {
        Self::grade(
            self.correlation_redundancy_score,
            &[
                (0.8, "Very high correlation redundancy - many highly correlated features"),
                (0.6, "High correlation redundancy - significant feature overlap"),
                (0.4, "Moderate correlation redundancy - some correlated features"),
                (0.2, "Low correlation redundancy - minimal feature overlap"),
            ],
            "Very low correlation redundancy - features are largely independent",
        )
    }

    /// Verbal grading of `mutual_information_redundancy_score`.
    fn interpret_mi_redundancy(&self) -> &'static str {
        Self::grade(
            self.mutual_information_redundancy_score,
            &[
                (0.8, "Very high information redundancy - features share much information"),
                (0.6, "High information redundancy - significant information overlap"),
                (0.4, "Moderate information redundancy - some shared information"),
                (0.2, "Low information redundancy - minimal information overlap"),
            ],
            "Very low information redundancy - features provide unique information",
        )
    }

    /// Verbal grading of `max_vif` (10 / 5 / 2.5 cut-offs).
    fn interpret_vif(&self) -> &'static str {
        Self::grade(
            self.max_vif,
            &[
                (10.0, "Severe multicollinearity - high VIF values detected"),
                (5.0, "Moderate multicollinearity - concerning VIF values"),
                (2.5, "Mild multicollinearity - some elevated VIF values"),
            ],
            "No multicollinearity concerns - acceptable VIF values",
        )
    }

    /// Overall verdict based on how many of four redundancy indicators fire:
    /// correlation score >= 0.6, MI score >= 0.6, max VIF >= 5, and more redundant pairs
    /// than half the feature count.
    fn overall_assessment(&self) -> &'static str {
        let indicators = [
            self.correlation_redundancy_score >= 0.6,
            self.mutual_information_redundancy_score >= 0.6,
            self.max_vif >= 5.0,
            self.redundant_correlation_pairs.len() > self.n_features / 2,
        ];
        let raised = indicators.iter().filter(|&&fired| fired).count();

        match raised {
            0 => "EXCELLENT: Low redundancy across all measures - well-diversified feature set",
            1 => "LOW: Minimal redundancy - feature set appears reasonable",
            2 => "MODERATE: Some redundancy detected - review feature selection carefully",
            3 => "HIGH: High redundancy detected - consider feature reduction strategies",
            _ => "CRITICAL: Very high redundancy detected across all measures - major feature set cleanup needed",
        }
    }
}
749
750#[derive(Debug, Clone)]
752pub struct RedundancyMeasures;
753
754impl RedundancyMeasures {
755 pub fn compute(X: ArrayView2<f64>, feature_indices: &[usize]) -> Result<RedundancyAssessment> {
757 let redundancy_matrix = RedundancyMatrix::new(0.7, 0.1, 5.0, 10);
758 redundancy_matrix.compute(X, feature_indices)
759 }
760
761 pub fn compute_with_params(
763 X: ArrayView2<f64>,
764 feature_indices: &[usize],
765 correlation_threshold: f64,
766 mi_threshold: f64,
767 vif_threshold: f64,
768 n_bins: usize,
769 ) -> Result<RedundancyAssessment> {
770 let redundancy_matrix =
771 RedundancyMatrix::new(correlation_threshold, mi_threshold, vif_threshold, n_bins);
772 redundancy_matrix.compute(X, feature_indices)
773 }
774}
775
// `X` follows the usual convention for a design matrix; the allowance covers every
// test in the module.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_correlation_redundancy() {
        // Columns 0-2 are exact linear functions of the row index.
        let X = array![
            [1.0, 2.0, 1.1, 5.0],
            [2.0, 4.0, 2.1, 6.0],
            [3.0, 6.0, 3.1, 7.0],
            [4.0, 8.0, 4.1, 8.0],
        ];
        let selection = [0, 1, 2];
        let analyzer = CorrelationRedundancy::new(0.5, true);

        let score = analyzer.compute(X.view(), &selection).unwrap();
        assert!(score > 0.0);

        let pairs = analyzer
            .identify_redundant_pairs(X.view(), &selection)
            .unwrap();
        assert!(!pairs.is_empty());
    }

    #[test]
    fn test_mi_redundancy() {
        let X = array![
            [1.0, 1.0, 5.0],
            [2.0, 2.0, 6.0],
            [3.0, 3.0, 7.0],
            [4.0, 4.0, 8.0],
        ];
        let selection = [0, 1, 2];
        let analyzer = MutualInformationRedundancy::new(0.1, 3);

        let score = analyzer.compute(X.view(), &selection).unwrap();
        assert!(score >= 0.0);
    }

    #[test]
    fn test_vif() {
        let X = array![
            [1.0, 2.0, 1.5],
            [2.0, 4.0, 3.0],
            [3.0, 6.0, 4.5],
            [4.0, 8.0, 6.0],
        ];
        let selection = [0, 1, 2];
        let analyzer = VarianceInflationFactor::new(5.0);

        let vif_scores = analyzer.compute_all(X.view(), &selection).unwrap();
        assert_eq!(vif_scores.len(), 3);
        assert!(vif_scores.iter().all(|&vif| vif >= 1.0));

        let flagged = analyzer
            .identify_high_vif_features(X.view(), &selection)
            .unwrap();
        assert!(flagged.len() <= selection.len());
    }

    #[test]
    fn test_redundancy_measures() {
        let X = array![
            [1.0, 2.0, 1.1, 5.0, 0.5],
            [2.0, 4.0, 2.1, 6.0, 1.0],
            [3.0, 6.0, 3.1, 7.0, 1.5],
            [4.0, 8.0, 4.1, 8.0, 2.0],
            [5.0, 10.0, 5.1, 9.0, 2.5],
        ];
        let selection = [0, 1, 2, 3, 4];

        let assessment = RedundancyMeasures::compute(X.view(), &selection).unwrap();
        assert!(assessment.correlation_redundancy_score >= 0.0);
        assert!(assessment.mutual_information_redundancy_score >= 0.0);
        assert!(assessment.average_vif >= 1.0);
        assert!(assessment.max_vif >= 1.0);
        assert_eq!(assessment.vif_scores.len(), 5);

        let text = assessment.report();
        assert!(text.contains("Redundancy Assessment"));
        assert!(text.contains("Overall"));
    }

    #[test]
    fn test_empty_feature_set() {
        // A single feature has no pairs, so the redundancy score is zero.
        let X = array![[1.0], [2.0], [3.0]];
        let selection = [0];
        let analyzer = CorrelationRedundancy::new(0.5, true);

        let score = analyzer.compute(X.view(), &selection).unwrap();
        assert_eq!(score, 0.0);
    }
}