sklears_feature_selection/evaluation/
relevance_scoring.rs

1//! Relevance scoring methods for feature selection evaluation
2//!
3//! This module implements comprehensive relevance scoring methods to evaluate
4//! how relevant selected features are to the target variable. All implementations
5//! follow the SciRS2 policy using scirs2-core for numerical computations.
6
7use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2};
8use sklears_core::error::{Result as SklResult, SklearsError};
9type Result<T> = SklResult<T>;
10
11impl From<RelevanceError> for SklearsError {
12    fn from(err: RelevanceError) -> Self {
13        SklearsError::FitError(format!("Relevance analysis error: {}", err))
14    }
15}
16use scirs2_core::random::{thread_rng, Rng};
17use std::collections::HashMap;
18use thiserror::Error;
19
20#[derive(Debug, Error)]
21pub enum RelevanceError {
22    #[error("Feature matrix is empty")]
23    EmptyFeatureMatrix,
24    #[error("Target array is empty")]
25    EmptyTarget,
26    #[error("Feature and target lengths do not match")]
27    LengthMismatch,
28    #[error("Invalid feature indices")]
29    InvalidFeatureIndices,
30    #[error("Insufficient variance in data")]
31    InsufficientVariance,
32}
33
34/// Information gain-based relevance scoring
35#[derive(Debug, Clone)]
36pub struct InformationGainScoring {
37    n_bins: usize,
38    use_equal_width: bool,
39}
40
41impl InformationGainScoring {
42    /// Create a new information gain scorer
43    pub fn new(n_bins: usize, use_equal_width: bool) -> Self {
44        Self {
45            n_bins,
46            use_equal_width,
47        }
48    }
49
50    /// Compute information gain for selected features
51    pub fn compute(
52        &self,
53        X: ArrayView2<f64>,
54        y: ArrayView1<f64>,
55        feature_indices: &[usize],
56    ) -> Result<Vec<f64>> {
57        if X.nrows() != y.len() {
58            return Err(RelevanceError::LengthMismatch.into());
59        }
60
61        if X.is_empty() || y.is_empty() {
62            return Err(RelevanceError::EmptyFeatureMatrix.into());
63        }
64
65        let mut scores = Vec::with_capacity(feature_indices.len());
66
67        // Discretize target if it's continuous
68        let discretized_target = self.discretize_target(y)?;
69
70        for &feature_idx in feature_indices {
71            if feature_idx >= X.ncols() {
72                return Err(RelevanceError::InvalidFeatureIndices.into());
73            }
74
75            let feature_column = X.column(feature_idx);
76            let discretized_feature = self.discretize_feature(feature_column)?;
77
78            let ig = self.compute_information_gain(&discretized_feature, &discretized_target)?;
79            scores.push(ig);
80        }
81
82        Ok(scores)
83    }
84
85    /// Discretize target variable
86    fn discretize_target(&self, target: ArrayView1<f64>) -> Result<Array1<i32>> {
87        // Check if target is already discrete (all integers)
88        let is_discrete = target.iter().all(|&x| x.fract() == 0.0);
89
90        if is_discrete {
91            // Convert to integers
92            return Ok(target.mapv(|x| x as i32));
93        }
94
95        // Discretize continuous target
96        self.discretize_continuous(target)
97    }
98
99    /// Discretize feature variable
100    fn discretize_feature(&self, feature: ArrayView1<f64>) -> Result<Array1<i32>> {
101        // Check if feature is already discrete (all integers)
102        let is_discrete = feature.iter().all(|&x| x.fract() == 0.0);
103
104        if is_discrete {
105            // Convert to integers, ensuring non-negative
106            let min_val = feature.iter().fold(f64::INFINITY, |acc, &x| acc.min(x)) as i32;
107            return Ok(feature.mapv(|x| x as i32 - min_val));
108        }
109
110        // Discretize continuous feature
111        self.discretize_continuous(feature)
112    }
113
114    /// Discretize continuous variable using binning
115    fn discretize_continuous(&self, values: ArrayView1<f64>) -> Result<Array1<i32>> {
116        let min_val = values.iter().fold(f64::INFINITY, |acc, &x| acc.min(x));
117        let max_val = values.iter().fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
118
119        if (max_val - min_val).abs() < 1e-10 {
120            // Constant feature
121            return Ok(Array1::zeros(values.len()));
122        }
123
124        if self.use_equal_width {
125            self.equal_width_binning(values, min_val, max_val)
126        } else {
127            self.equal_frequency_binning(values)
128        }
129    }
130
131    /// Equal width binning
132    fn equal_width_binning(
133        &self,
134        values: ArrayView1<f64>,
135        min_val: f64,
136        max_val: f64,
137    ) -> Result<Array1<i32>> {
138        let bin_width = (max_val - min_val) / self.n_bins as f64;
139        let mut discretized = Array1::zeros(values.len());
140
141        for (i, &value) in values.iter().enumerate() {
142            let bin = ((value - min_val) / bin_width).floor() as i32;
143            discretized[i] = bin.min((self.n_bins - 1) as i32).max(0);
144        }
145
146        Ok(discretized)
147    }
148
149    /// Equal frequency binning (quantile-based)
150    fn equal_frequency_binning(&self, values: ArrayView1<f64>) -> Result<Array1<i32>> {
151        let mut sorted_values: Vec<f64> = values.to_vec();
152        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
153
154        let n = values.len();
155        let bin_size = n / self.n_bins;
156        let mut discretized = Array1::zeros(n);
157
158        for (i, &value) in values.iter().enumerate() {
159            // Find which quantile this value belongs to
160            let rank = sorted_values.partition_point(|&x| x < value);
161            let bin = (rank / bin_size.max(1)).min(self.n_bins - 1);
162            discretized[i] = bin as i32;
163        }
164
165        Ok(discretized)
166    }
167
168    /// Compute information gain between feature and target
169    fn compute_information_gain(&self, feature: &Array1<i32>, target: &Array1<i32>) -> Result<f64> {
170        let target_entropy = self.compute_entropy(target)?;
171        let conditional_entropy = self.compute_conditional_entropy(feature, target)?;
172
173        Ok(target_entropy - conditional_entropy)
174    }
175
176    /// Compute entropy of a discrete variable
177    fn compute_entropy(&self, values: &Array1<i32>) -> Result<f64> {
178        let mut counts = HashMap::new();
179        let total = values.len() as f64;
180
181        for &value in values.iter() {
182            *counts.entry(value).or_insert(0) += 1;
183        }
184
185        let mut entropy = 0.0;
186        for count in counts.values() {
187            if *count > 0 {
188                let probability = *count as f64 / total;
189                entropy -= probability * probability.ln();
190            }
191        }
192
193        Ok(entropy)
194    }
195
196    /// Compute conditional entropy H(Y|X)
197    fn compute_conditional_entropy(
198        &self,
199        feature: &Array1<i32>,
200        target: &Array1<i32>,
201    ) -> Result<f64> {
202        let mut joint_counts = HashMap::new();
203        let mut feature_counts = HashMap::new();
204        let total = feature.len() as f64;
205
206        // Count joint occurrences and marginal counts
207        for i in 0..feature.len() {
208            let x_val = feature[i];
209            let y_val = target[i];
210
211            *joint_counts.entry((x_val, y_val)).or_insert(0) += 1;
212            *feature_counts.entry(x_val).or_insert(0) += 1;
213        }
214
215        let mut conditional_entropy = 0.0;
216
217        // For each value of X
218        for (&x_val, &x_count) in feature_counts.iter() {
219            if x_count == 0 {
220                continue;
221            }
222
223            let p_x = x_count as f64 / total;
224
225            // Compute H(Y | X = x_val)
226            let mut entropy_y_given_x = 0.0;
227            for (&(joint_x, _joint_y), &joint_count) in joint_counts.iter() {
228                if joint_x == x_val && joint_count > 0 {
229                    let p_y_given_x = joint_count as f64 / x_count as f64;
230                    entropy_y_given_x -= p_y_given_x * p_y_given_x.ln();
231                }
232            }
233
234            conditional_entropy += p_x * entropy_y_given_x;
235        }
236
237        Ok(conditional_entropy)
238    }
239
240    /// Compute average information gain for selected features
241    pub fn average_score(
242        &self,
243        X: ArrayView2<f64>,
244        y: ArrayView1<f64>,
245        feature_indices: &[usize],
246    ) -> Result<f64> {
247        let scores = self.compute(X, y, feature_indices)?;
248        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
249    }
250}
251
252/// Chi-square-based relevance scoring
253#[derive(Debug, Clone)]
254pub struct ChiSquareScoring {
255    n_bins: usize,
256}
257
258impl ChiSquareScoring {
259    /// Create a new chi-square scorer
260    pub fn new(n_bins: usize) -> Self {
261        Self { n_bins }
262    }
263
264    /// Compute chi-square statistics for selected features
265    pub fn compute(
266        &self,
267        X: ArrayView2<f64>,
268        y: ArrayView1<f64>,
269        feature_indices: &[usize],
270    ) -> Result<Vec<f64>> {
271        if X.nrows() != y.len() {
272            return Err(RelevanceError::LengthMismatch.into());
273        }
274
275        let mut scores = Vec::with_capacity(feature_indices.len());
276
277        // Discretize target
278        let discretized_target = self.discretize_variable(y)?;
279
280        for &feature_idx in feature_indices {
281            if feature_idx >= X.ncols() {
282                return Err(RelevanceError::InvalidFeatureIndices.into());
283            }
284
285            let feature_column = X.column(feature_idx);
286            let discretized_feature = self.discretize_variable(feature_column)?;
287
288            let chi2 = self.compute_chi_square(&discretized_feature, &discretized_target)?;
289            scores.push(chi2);
290        }
291
292        Ok(scores)
293    }
294
295    /// Discretize variable for chi-square test
296    fn discretize_variable(&self, values: ArrayView1<f64>) -> Result<Array1<i32>> {
297        let min_val = values.iter().fold(f64::INFINITY, |acc, &x| acc.min(x));
298        let max_val = values.iter().fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
299
300        if (max_val - min_val).abs() < 1e-10 {
301            return Ok(Array1::zeros(values.len()));
302        }
303
304        let bin_width = (max_val - min_val) / self.n_bins as f64;
305        let mut discretized = Array1::zeros(values.len());
306
307        for (i, &value) in values.iter().enumerate() {
308            let bin = ((value - min_val) / bin_width).floor() as i32;
309            discretized[i] = bin.min((self.n_bins - 1) as i32).max(0);
310        }
311
312        Ok(discretized)
313    }
314
315    /// Compute chi-square statistic
316    fn compute_chi_square(&self, feature: &Array1<i32>, target: &Array1<i32>) -> Result<f64> {
317        // Create contingency table
318        let mut joint_counts = HashMap::new();
319        let mut feature_counts = HashMap::new();
320        let mut target_counts = HashMap::new();
321        let total = feature.len() as f64;
322
323        for i in 0..feature.len() {
324            let x = feature[i];
325            let y = target[i];
326
327            *joint_counts.entry((x, y)).or_insert(0) += 1;
328            *feature_counts.entry(x).or_insert(0) += 1;
329            *target_counts.entry(y).or_insert(0) += 1;
330        }
331
332        let mut chi_square = 0.0;
333
334        for (&(x, y), &observed) in joint_counts.iter() {
335            let x_count = *feature_counts.get(&x).unwrap_or(&0) as f64;
336            let y_count = *target_counts.get(&y).unwrap_or(&0) as f64;
337
338            let expected = (x_count * y_count) / total;
339
340            if expected > 1e-10 {
341                let diff = observed as f64 - expected;
342                chi_square += (diff * diff) / expected;
343            }
344        }
345
346        Ok(chi_square)
347    }
348
349    /// Compute average chi-square score
350    pub fn average_score(
351        &self,
352        X: ArrayView2<f64>,
353        y: ArrayView1<f64>,
354        feature_indices: &[usize],
355    ) -> Result<f64> {
356        let scores = self.compute(X, y, feature_indices)?;
357        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
358    }
359}
360
361/// F-statistic-based relevance scoring
362#[derive(Debug, Clone)]
363pub struct FStatisticScoring {
364    classification: bool,
365}
366
367impl FStatisticScoring {
368    /// Create a new F-statistic scorer
369    pub fn new(classification: bool) -> Self {
370        Self { classification }
371    }
372
373    /// Compute F-statistics for selected features
374    pub fn compute(
375        &self,
376        X: ArrayView2<f64>,
377        y: ArrayView1<f64>,
378        feature_indices: &[usize],
379    ) -> Result<Vec<f64>> {
380        if X.nrows() != y.len() {
381            return Err(RelevanceError::LengthMismatch.into());
382        }
383
384        let mut scores = Vec::with_capacity(feature_indices.len());
385
386        for &feature_idx in feature_indices {
387            if feature_idx >= X.ncols() {
388                return Err(RelevanceError::InvalidFeatureIndices.into());
389            }
390
391            let feature_column = X.column(feature_idx);
392
393            let f_stat = if self.classification {
394                self.compute_f_classif(feature_column, y)?
395            } else {
396                self.compute_f_regression(feature_column, y)?
397            };
398
399            scores.push(f_stat);
400        }
401
402        Ok(scores)
403    }
404
405    /// Compute F-statistic for classification
406    fn compute_f_classif(&self, feature: ArrayView1<f64>, target: ArrayView1<f64>) -> Result<f64> {
407        // Group feature values by class
408        let mut class_groups: HashMap<i32, Vec<f64>> = HashMap::new();
409
410        for i in 0..feature.len() {
411            let class = target[i] as i32;
412            class_groups.entry(class).or_default().push(feature[i]);
413        }
414
415        if class_groups.len() < 2 {
416            return Ok(0.0); // No variation between classes
417        }
418
419        // Compute overall mean
420        let overall_mean = feature.mean().unwrap_or(0.0);
421        let total_n = feature.len() as f64;
422
423        // Compute between-group sum of squares
424        let mut ss_between = 0.0;
425        for group in class_groups.values() {
426            let group_mean = group.iter().sum::<f64>() / group.len() as f64;
427            let n_group = group.len() as f64;
428            ss_between += n_group * (group_mean - overall_mean).powi(2);
429        }
430
431        // Compute within-group sum of squares
432        let mut ss_within = 0.0;
433        for group in class_groups.values() {
434            let group_mean = group.iter().sum::<f64>() / group.len() as f64;
435            for &value in group {
436                ss_within += (value - group_mean).powi(2);
437            }
438        }
439
440        let df_between = (class_groups.len() - 1) as f64;
441        let df_within = (total_n - class_groups.len() as f64).max(1.0);
442
443        if ss_within < 1e-10 {
444            return Ok(f64::INFINITY);
445        }
446
447        let ms_between = ss_between / df_between;
448        let ms_within = ss_within / df_within;
449
450        Ok(ms_between / ms_within)
451    }
452
453    /// Compute F-statistic for regression
454    fn compute_f_regression(
455        &self,
456        feature: ArrayView1<f64>,
457        target: ArrayView1<f64>,
458    ) -> Result<f64> {
459        let n = feature.len() as f64;
460        if n < 3.0 {
461            return Ok(0.0);
462        }
463
464        // Compute correlation coefficient
465        let mean_x = feature.mean().unwrap_or(0.0);
466        let mean_y = target.mean().unwrap_or(0.0);
467
468        let mut sum_xy = 0.0;
469        let mut sum_xx = 0.0;
470        let mut sum_yy = 0.0;
471
472        for i in 0..feature.len() {
473            let dx = feature[i] - mean_x;
474            let dy = target[i] - mean_y;
475            sum_xy += dx * dy;
476            sum_xx += dx * dx;
477            sum_yy += dy * dy;
478        }
479
480        if sum_xx < 1e-10 || sum_yy < 1e-10 {
481            return Ok(0.0);
482        }
483
484        let r = sum_xy / (sum_xx * sum_yy).sqrt();
485        let r_squared = r * r;
486
487        // F-statistic for regression: F = (r²/(1-r²)) * (n-2)
488        if (1.0 - r_squared).abs() < 1e-10 {
489            return Ok(f64::INFINITY);
490        }
491
492        Ok((r_squared / (1.0 - r_squared)) * (n - 2.0))
493    }
494
495    /// Compute average F-statistic score
496    pub fn average_score(
497        &self,
498        X: ArrayView2<f64>,
499        y: ArrayView1<f64>,
500        feature_indices: &[usize],
501    ) -> Result<f64> {
502        let scores = self.compute(X, y, feature_indices)?;
503        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
504    }
505}
506
507/// Correlation-based relevance scoring
508#[derive(Debug, Clone)]
509pub struct CorrelationScoring {
510    use_absolute: bool,
511}
512
513impl CorrelationScoring {
514    /// Create a new correlation scorer
515    pub fn new(use_absolute: bool) -> Self {
516        Self { use_absolute }
517    }
518
519    /// Compute correlations for selected features
520    pub fn compute(
521        &self,
522        X: ArrayView2<f64>,
523        y: ArrayView1<f64>,
524        feature_indices: &[usize],
525    ) -> Result<Vec<f64>> {
526        if X.nrows() != y.len() {
527            return Err(RelevanceError::LengthMismatch.into());
528        }
529
530        let mut scores = Vec::with_capacity(feature_indices.len());
531
532        for &feature_idx in feature_indices {
533            if feature_idx >= X.ncols() {
534                return Err(RelevanceError::InvalidFeatureIndices.into());
535            }
536
537            let feature_column = X.column(feature_idx);
538            let correlation = self.compute_correlation(feature_column, y)?;
539
540            scores.push(if self.use_absolute {
541                correlation.abs()
542            } else {
543                correlation
544            });
545        }
546
547        Ok(scores)
548    }
549
550    /// Compute Pearson correlation coefficient
551    fn compute_correlation(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> Result<f64> {
552        let n = x.len() as f64;
553        if n < 2.0 {
554            return Ok(0.0);
555        }
556
557        let mean_x = x.mean().unwrap_or(0.0);
558        let mean_y = y.mean().unwrap_or(0.0);
559
560        let mut sum_xy = 0.0;
561        let mut sum_x2 = 0.0;
562        let mut sum_y2 = 0.0;
563
564        for i in 0..x.len() {
565            let dx = x[i] - mean_x;
566            let dy = y[i] - mean_y;
567            sum_xy += dx * dy;
568            sum_x2 += dx * dx;
569            sum_y2 += dy * dy;
570        }
571
572        let denom = (sum_x2 * sum_y2).sqrt();
573        if denom < 1e-10 {
574            return Ok(0.0);
575        }
576
577        Ok(sum_xy / denom)
578    }
579
580    /// Compute average correlation score
581    pub fn average_score(
582        &self,
583        X: ArrayView2<f64>,
584        y: ArrayView1<f64>,
585        feature_indices: &[usize],
586    ) -> Result<f64> {
587        let scores = self.compute(X, y, feature_indices)?;
588        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
589    }
590}
591
592/// Relief algorithm-based relevance scoring
593#[derive(Debug, Clone)]
594pub struct ReliefScoring {
595    n_iterations: usize,
596    k_neighbors: usize,
597}
598
599impl ReliefScoring {
600    /// Create a new Relief scorer
601    pub fn new(n_iterations: usize, k_neighbors: usize) -> Self {
602        Self {
603            n_iterations,
604            k_neighbors,
605        }
606    }
607
608    /// Compute Relief scores for selected features
609    pub fn compute(
610        &self,
611        X: ArrayView2<f64>,
612        y: ArrayView1<f64>,
613        feature_indices: &[usize],
614    ) -> Result<Vec<f64>> {
615        if X.nrows() != y.len() {
616            return Err(RelevanceError::LengthMismatch.into());
617        }
618
619        if X.is_empty() || y.is_empty() {
620            return Err(RelevanceError::EmptyFeatureMatrix.into());
621        }
622
623        let mut feature_weights = vec![0.0; feature_indices.len()];
624
625        // Relief algorithm iterations
626        for _ in 0..self.n_iterations {
627            // Randomly select an instance
628            let random_idx = (thread_rng().gen::<f64>() * X.nrows() as f64) as usize % X.nrows();
629
630            let random_instance = X.row(random_idx);
631            let random_class = y[random_idx];
632
633            // Find nearest hit and miss
634            let (nearest_hit_idx, nearest_miss_idx) =
635                self.find_nearest_hit_miss(X, y, random_idx, random_class)?;
636
637            if let Some(hit_idx) = nearest_hit_idx {
638                let hit_instance = X.row(hit_idx);
639                for (i, &feature_idx) in feature_indices.iter().enumerate() {
640                    if feature_idx < X.ncols() {
641                        let diff = (random_instance[feature_idx] - hit_instance[feature_idx]).abs();
642                        feature_weights[i] -= diff;
643                    }
644                }
645            }
646
647            if let Some(miss_idx) = nearest_miss_idx {
648                let miss_instance = X.row(miss_idx);
649                for (i, &feature_idx) in feature_indices.iter().enumerate() {
650                    if feature_idx < X.ncols() {
651                        let diff =
652                            (random_instance[feature_idx] - miss_instance[feature_idx]).abs();
653                        feature_weights[i] += diff;
654                    }
655                }
656            }
657        }
658
659        // Normalize weights
660        let n_iterations = self.n_iterations as f64;
661        for weight in &mut feature_weights {
662            *weight /= n_iterations;
663        }
664
665        Ok(feature_weights)
666    }
667
668    /// Find nearest hit (same class) and miss (different class)
669    fn find_nearest_hit_miss(
670        &self,
671        X: ArrayView2<f64>,
672        y: ArrayView1<f64>,
673        instance_idx: usize,
674        instance_class: f64,
675    ) -> Result<(Option<usize>, Option<usize>)> {
676        let mut nearest_hit_idx = None;
677        let mut nearest_miss_idx = None;
678        let mut min_hit_distance = f64::INFINITY;
679        let mut min_miss_distance = f64::INFINITY;
680
681        let instance = X.row(instance_idx);
682
683        for i in 0..X.nrows() {
684            if i == instance_idx {
685                continue;
686            }
687
688            let other_instance = X.row(i);
689            let other_class = y[i];
690
691            // Compute Euclidean distance
692            let mut distance = 0.0;
693            for j in 0..X.ncols() {
694                let diff = instance[j] - other_instance[j];
695                distance += diff * diff;
696            }
697            distance = distance.sqrt();
698
699            // Check if it's a hit or miss
700            if (other_class - instance_class).abs() < 1e-10 {
701                // Same class (hit)
702                if distance < min_hit_distance {
703                    min_hit_distance = distance;
704                    nearest_hit_idx = Some(i);
705                }
706            } else {
707                // Different class (miss)
708                if distance < min_miss_distance {
709                    min_miss_distance = distance;
710                    nearest_miss_idx = Some(i);
711                }
712            }
713        }
714
715        Ok((nearest_hit_idx, nearest_miss_idx))
716    }
717
718    /// Compute average Relief score
719    pub fn average_score(
720        &self,
721        X: ArrayView2<f64>,
722        y: ArrayView1<f64>,
723        feature_indices: &[usize],
724    ) -> Result<f64> {
725        let scores = self.compute(X, y, feature_indices)?;
726        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
727    }
728}
729
730/// Comprehensive relevance scoring aggregator
731#[derive(Debug, Clone)]
732pub struct RelevanceScoring {
733    classification: bool,
734}
735
736impl RelevanceScoring {
737    /// Create a new relevance scorer
738    pub fn new(classification: bool) -> Self {
739        Self { classification }
740    }
741
742    /// Compute comprehensive relevance assessment
743    pub fn compute(
744        &self,
745        X: ArrayView2<f64>,
746        y: ArrayView1<f64>,
747        feature_indices: &[usize],
748    ) -> Result<RelevanceAssessment> {
749        let information_gain = InformationGainScoring::new(10, true);
750        let chi_square = ChiSquareScoring::new(10);
751        let f_statistic = FStatisticScoring::new(self.classification);
752        let correlation = CorrelationScoring::new(true);
753        let relief = ReliefScoring::new(100, 5);
754
755        let ig_scores = information_gain.compute(X, y, feature_indices)?;
756        let chi2_scores = chi_square.compute(X, y, feature_indices)?;
757        let f_scores = f_statistic.compute(X, y, feature_indices)?;
758        let corr_scores = correlation.compute(X, y, feature_indices)?;
759        let relief_scores = relief.compute(X, y, feature_indices)?;
760
761        Ok(RelevanceAssessment {
762            information_gain_scores: ig_scores,
763            chi_square_scores: chi2_scores,
764            f_statistic_scores: f_scores,
765            correlation_scores: corr_scores,
766            relief_scores,
767            feature_indices: feature_indices.to_vec(),
768            average_information_gain: information_gain.average_score(X, y, feature_indices)?,
769            average_chi_square: chi_square.average_score(X, y, feature_indices)?,
770            average_f_statistic: f_statistic.average_score(X, y, feature_indices)?,
771            average_correlation: correlation.average_score(X, y, feature_indices)?,
772            average_relief: relief.average_score(X, y, feature_indices)?,
773        })
774    }
775}
776
/// Comprehensive relevance assessment results.
///
/// The per-feature score vectors are parallel to `feature_indices`: entry `i`
/// of each vector belongs to feature `feature_indices[i]`.
#[derive(Debug, Clone)]
pub struct RelevanceAssessment {
    /// Information gain per feature.
    pub information_gain_scores: Vec<f64>,
    /// Chi-square statistic per feature.
    pub chi_square_scores: Vec<f64>,
    /// F statistic per feature.
    pub f_statistic_scores: Vec<f64>,
    /// Correlation score per feature.
    pub correlation_scores: Vec<f64>,
    /// Relief weight per feature.
    pub relief_scores: Vec<f64>,
    /// Column indices of the assessed features.
    pub feature_indices: Vec<usize>,
    /// Mean of `information_gain_scores`.
    pub average_information_gain: f64,
    /// Mean of `chi_square_scores`.
    pub average_chi_square: f64,
    /// Mean of `f_statistic_scores`.
    pub average_f_statistic: f64,
    /// Mean of `correlation_scores`.
    pub average_correlation: f64,
    /// Mean of `relief_scores`.
    pub average_relief: f64,
}

impl RelevanceAssessment {
    /// Divisor applied to the unbounded chi-square and F statistics before
    /// they are averaged with the other scores. This is a heuristic scaling,
    /// not a true normalisation.
    const UNBOUNDED_SCORE_DIVISOR: f64 = 10.0;

    /// Generate comprehensive relevance report
    pub fn report(&self) -> String {
        let mut report = String::new();

        report.push_str("=== Feature Relevance Assessment ===\n\n");

        report.push_str(&format!(
            "Number of features analyzed: {}\n\n",
            self.feature_indices.len()
        ));

        report.push_str(&format!(
            "Average Information Gain: {:.4}\n",
            self.average_information_gain
        ));
        report.push_str(&format!(
            "Average Chi-Square: {:.4}\n",
            self.average_chi_square
        ));
        report.push_str(&format!(
            "Average F-Statistic: {:.4}\n",
            self.average_f_statistic
        ));
        report.push_str(&format!(
            "Average Correlation: {:.4}\n",
            self.average_correlation
        ));
        report.push_str(&format!(
            "Average Relief Score: {:.4}\n\n",
            self.average_relief
        ));

        report.push_str("Per-Feature Relevance Scores:\n");
        report.push_str("Feature | InfoGain | Chi2     | F-Stat   | Corr     | Relief\n");
        report.push_str("--------|----------|----------|----------|----------|----------\n");

        for i in 0..self.feature_indices.len() {
            report.push_str(&format!(
                "{:7} | {:8.4} | {:8.4} | {:8.4} | {:8.4} | {:8.4}\n",
                self.feature_indices[i],
                self.information_gain_scores[i],
                self.chi_square_scores[i],
                self.f_statistic_scores[i],
                self.correlation_scores[i],
                self.relief_scores[i]
            ));
        }

        report.push_str(&format!(
            "\nOverall Relevance Assessment: {}\n",
            self.overall_assessment()
        ));

        report
    }

    /// Average of the five scores after dividing chi-square and F by
    /// [`Self::UNBOUNDED_SCORE_DIVISOR`].
    fn combined_score(
        information_gain: f64,
        chi_square: f64,
        f_statistic: f64,
        correlation: f64,
        relief: f64,
    ) -> f64 {
        let scaled = [
            information_gain,
            chi_square / Self::UNBOUNDED_SCORE_DIVISOR,
            f_statistic / Self::UNBOUNDED_SCORE_DIVISOR,
            correlation,
            relief,
        ];
        scaled.iter().sum::<f64>() / scaled.len() as f64
    }

    /// Qualitative verdict derived from the combined average scores.
    fn overall_assessment(&self) -> &'static str {
        let average = Self::combined_score(
            self.average_information_gain,
            self.average_chi_square,
            self.average_f_statistic,
            self.average_correlation,
            self.average_relief,
        );

        match average {
            x if x >= 0.8 => "EXCELLENT: Features show very high relevance to target",
            x if x >= 0.6 => "GOOD: Features show good relevance to target",
            x if x >= 0.4 => "MODERATE: Features show moderate relevance to target",
            x if x >= 0.2 => "POOR: Features show low relevance to target",
            _ => {
                "CRITICAL: Features show very low relevance to target - consider different features"
            }
        }
    }

    /// Get top N features by relevance score (using average of normalized scores).
    ///
    /// Returns up to `n` `(feature index, combined score)` pairs, highest score
    /// first. Ties keep their original order. NaN scores rank last instead of
    /// panicking the sort.
    ///
    /// # Panics
    ///
    /// Panics if any score vector is shorter than `feature_indices`.
    pub fn get_top_features(&self, n: usize) -> Vec<(usize, f64)> {
        let mut feature_scores: Vec<(usize, f64)> = self
            .feature_indices
            .iter()
            .enumerate()
            .map(|(i, &feature_idx)| {
                let score = Self::combined_score(
                    self.information_gain_scores[i],
                    self.chi_square_scores[i],
                    self.f_statistic_scores[i],
                    self.correlation_scores[i],
                    self.relief_scores[i],
                );
                (feature_idx, score)
            })
            .collect();

        // Map NaN to -inf so it sorts last under the total order.
        let rank_key = |score: f64| if score.is_nan() { f64::NEG_INFINITY } else { score };
        feature_scores.sort_by(|a, b| rank_key(b.1).total_cmp(&rank_key(a.1)));

        feature_scores.truncate(n);
        feature_scores
    }
}
896
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// Four samples with two strictly increasing columns and binary labels.
    fn binary_dataset() -> (Array2<f64>, Array1<f64>) {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];
        (X, y)
    }

    /// Asserts that every score is `>= 0.0` (NaN fails).
    fn assert_non_negative(scores: &[f64]) {
        for &score in scores {
            assert!(score >= 0.0);
        }
    }

    #[test]
    fn test_information_gain_scoring() {
        let (X, y) = binary_dataset();
        let scorer = InformationGainScoring::new(3, true);

        let scores = scorer.compute(X.view(), y.view(), &[0, 1]).unwrap();

        assert_eq!(scores.len(), 2);
        assert_non_negative(&scores);
    }

    #[test]
    fn test_chi_square_scoring() {
        let (X, y) = binary_dataset();
        let scorer = ChiSquareScoring::new(3);

        let scores = scorer.compute(X.view(), y.view(), &[0, 1]).unwrap();

        assert_eq!(scores.len(), 2);
        assert_non_negative(&scores);
    }

    #[test]
    fn test_f_statistic_scoring() {
        let (X, y) = binary_dataset();
        let scorer = FStatisticScoring::new(true);

        let scores = scorer.compute(X.view(), y.view(), &[0, 1]).unwrap();

        assert_eq!(scores.len(), 2);
        assert_non_negative(&scores);
    }

    #[test]
    fn test_correlation_scoring() {
        let (X, _) = binary_dataset();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let scorer = CorrelationScoring::new(true);

        let scores = scorer.compute(X.view(), y.view(), &[0, 1]).unwrap();

        assert_eq!(scores.len(), 2);
        assert!(scores.iter().all(|score| (0.0..=1.0).contains(score)));
    }

    #[test]
    fn test_relief_scoring() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 1.0]];
        let y = array![0.0, 0.0, 1.0, 1.0];
        let scorer = ReliefScoring::new(10, 1);

        let scores = scorer.compute(X.view(), y.view(), &[0, 1]).unwrap();

        // Only the length is checked: Relief weights may legitimately be negative.
        assert_eq!(scores.len(), 2);
    }

    #[test]
    fn test_relevance_assessment() {
        let X = array![
            [1.0, 2.0, 10.0],
            [2.0, 3.0, 20.0],
            [3.0, 4.0, 30.0],
            [4.0, 5.0, 40.0],
            [5.0, 6.0, 50.0],
        ];
        let y = array![0.0, 0.0, 1.0, 1.0, 1.0];

        let assessment = RelevanceScoring::new(true)
            .compute(X.view(), y.view(), &[0, 1, 2])
            .unwrap();

        for per_feature in [
            &assessment.information_gain_scores,
            &assessment.chi_square_scores,
            &assessment.f_statistic_scores,
            &assessment.correlation_scores,
            &assessment.relief_scores,
        ] {
            assert_eq!(per_feature.len(), 3);
        }

        assert!(assessment.report().contains("Relevance Assessment"));
        assert_eq!(assessment.get_top_features(2).len(), 2);
    }
}