sklears_feature_selection/evaluation/
stability_measures.rs1use scirs2_core::ndarray::Array2;
8use sklears_core::error::{Result as SklResult, SklearsError};
9type Result<T> = SklResult<T>;
10use std::collections::HashSet;
11use thiserror::Error;
12
/// Error conditions related to feature-selection stability analysis.
///
/// NOTE(review): the functions visible in this file report failures through
/// `SklearsError::FitError` rather than through this enum — confirm where (or
/// whether) it is consumed.
#[derive(Debug, Error)]
pub enum StabilityError {
    /// Feature sets that were required to have equal lengths did not.
    #[error("Feature sets must have the same length")]
    FeatureSetLengthMismatch,
    /// A feature index was rejected; the payload is the offending index.
    #[error("Invalid feature index: {0}")]
    InvalidFeatureIndex(usize),
    /// Too little data was supplied to run a stability analysis.
    #[error("Insufficient data for stability analysis")]
    InsufficientData,
}
22
23#[derive(Debug, Clone)]
25pub struct JaccardSimilarity;
26
27impl JaccardSimilarity {
28 pub fn compute(set1: &[usize], set2: &[usize]) -> Result<f64> {
30 let s1: HashSet<_> = set1.iter().collect();
31 let s2: HashSet<_> = set2.iter().collect();
32
33 let intersection = s1.intersection(&s2).count() as f64;
34 let union = s1.union(&s2).count() as f64;
35
36 if union == 0.0 {
37 return Ok(1.0); }
39
40 Ok(intersection / union)
41 }
42
43 pub fn average_similarity(feature_sets: &[Vec<usize>]) -> Result<f64> {
45 if feature_sets.len() < 2 {
46 return Err(SklearsError::FitError(
47 "At least two feature sets required".to_string(),
48 ));
49 }
50
51 let mut total_similarity = 0.0;
52 let mut comparisons = 0;
53
54 for i in 0..feature_sets.len() {
55 for j in (i + 1)..feature_sets.len() {
56 total_similarity += Self::compute(&feature_sets[i], &feature_sets[j])?;
57 comparisons += 1;
58 }
59 }
60
61 Ok(total_similarity / comparisons as f64)
62 }
63}
64
65#[derive(Debug, Clone)]
67pub struct DiceSimilarity;
68
69impl DiceSimilarity {
70 pub fn compute(set1: &[usize], set2: &[usize]) -> Result<f64> {
72 let s1: HashSet<_> = set1.iter().collect();
73 let s2: HashSet<_> = set2.iter().collect();
74
75 let intersection = s1.intersection(&s2).count() as f64;
76 let total_size = (s1.len() + s2.len()) as f64;
77
78 if total_size == 0.0 {
79 return Ok(1.0); }
81
82 Ok(2.0 * intersection / total_size)
83 }
84
85 pub fn average_similarity(feature_sets: &[Vec<usize>]) -> Result<f64> {
87 if feature_sets.len() < 2 {
88 return Err(SklearsError::FitError(
89 "At least two feature sets required".to_string(),
90 ));
91 }
92
93 let mut total_similarity = 0.0;
94 let mut comparisons = 0;
95
96 for i in 0..feature_sets.len() {
97 for j in (i + 1)..feature_sets.len() {
98 total_similarity += Self::compute(&feature_sets[i], &feature_sets[j])?;
99 comparisons += 1;
100 }
101 }
102
103 Ok(total_similarity / comparisons as f64)
104 }
105}
106
107#[derive(Debug, Clone)]
109pub struct OverlapCoefficient;
110
111impl OverlapCoefficient {
112 pub fn compute(set1: &[usize], set2: &[usize]) -> Result<f64> {
114 let s1: HashSet<_> = set1.iter().collect();
115 let s2: HashSet<_> = set2.iter().collect();
116
117 let intersection = s1.intersection(&s2).count() as f64;
118 let min_size = std::cmp::min(s1.len(), s2.len()) as f64;
119
120 if min_size == 0.0 {
121 return Ok(1.0); }
123
124 Ok(intersection / min_size)
125 }
126
127 pub fn average_coefficient(feature_sets: &[Vec<usize>]) -> Result<f64> {
129 if feature_sets.len() < 2 {
130 return Err(SklearsError::FitError(
131 "At least two feature sets required".to_string(),
132 ));
133 }
134
135 let mut total_coefficient = 0.0;
136 let mut comparisons = 0;
137
138 for i in 0..feature_sets.len() {
139 for j in (i + 1)..feature_sets.len() {
140 total_coefficient += Self::compute(&feature_sets[i], &feature_sets[j])?;
141 comparisons += 1;
142 }
143 }
144
145 Ok(total_coefficient / comparisons as f64)
146 }
147}
148
149#[derive(Debug, Clone)]
151pub struct ConsistencyIndex;
152
153impl ConsistencyIndex {
154 pub fn compute(set1: &[usize], set2: &[usize], total_features: usize) -> Result<f64> {
159 if total_features == 0 {
160 return Err(SklearsError::FitError(
161 "Total features must be positive".to_string(),
162 ));
163 }
164
165 let s1: HashSet<_> = set1.iter().collect();
166 let s2: HashSet<_> = set2.iter().collect();
167
168 let k1 = s1.len() as f64;
169 let k2 = s2.len() as f64;
170 let r = s1.intersection(&s2).count() as f64;
171 let n = total_features as f64;
172
173 let expected_overlap = (k1 * k2) / n;
174 let numerator = r - expected_overlap;
175 let denominator = std::cmp::min(s1.len(), s2.len()) as f64 - expected_overlap;
176
177 if denominator.abs() < 1e-10 {
178 return Ok(1.0); }
180
181 Ok(numerator / denominator)
182 }
183
184 pub fn average_consistency(feature_sets: &[Vec<usize>], total_features: usize) -> Result<f64> {
186 if feature_sets.len() < 2 {
187 return Err(SklearsError::FitError(
188 "At least two feature sets required".to_string(),
189 ));
190 }
191
192 let mut total_consistency = 0.0;
193 let mut comparisons = 0;
194
195 for i in 0..feature_sets.len() {
196 for j in (i + 1)..feature_sets.len() {
197 total_consistency +=
198 Self::compute(&feature_sets[i], &feature_sets[j], total_features)?;
199 comparisons += 1;
200 }
201 }
202
203 Ok(total_consistency / comparisons as f64)
204 }
205}
206
207#[derive(Debug, Clone)]
209pub struct StabilityMeasures {
210 pub jaccard_similarity: f64,
211 pub dice_similarity: f64,
212 pub overlap_coefficient: f64,
213 pub consistency_index: f64,
214 pub pairwise_stability: f64,
215 pub relative_stability_index: f64,
216}
217
218impl StabilityMeasures {
219 pub fn compute(feature_sets: &[Vec<usize>], total_features: usize) -> Result<Self> {
221 if feature_sets.len() < 2 {
222 return Err(SklearsError::FitError(
223 "At least two feature sets required".to_string(),
224 ));
225 }
226
227 let jaccard_similarity = JaccardSimilarity::average_similarity(feature_sets)?;
228 let dice_similarity = DiceSimilarity::average_similarity(feature_sets)?;
229 let overlap_coefficient = OverlapCoefficient::average_coefficient(feature_sets)?;
230 let consistency_index =
231 ConsistencyIndex::average_consistency(feature_sets, total_features)?;
232 let pairwise_stability = Self::compute_pairwise_stability(feature_sets)?;
233 let relative_stability_index =
234 Self::compute_relative_stability_index(feature_sets, total_features)?;
235
236 Ok(Self {
237 jaccard_similarity,
238 dice_similarity,
239 overlap_coefficient,
240 consistency_index,
241 pairwise_stability,
242 relative_stability_index,
243 })
244 }
245
246 fn compute_pairwise_stability(feature_sets: &[Vec<usize>]) -> Result<f64> {
248 if feature_sets.len() < 2 {
249 return Ok(0.0);
250 }
251
252 let n_sets = feature_sets.len();
253 let mut stability_matrix = Array2::zeros((n_sets, n_sets));
254
255 for i in 0..n_sets {
256 for j in 0..n_sets {
257 if i != j {
258 stability_matrix[[i, j]] =
259 JaccardSimilarity::compute(&feature_sets[i], &feature_sets[j])?;
260 } else {
261 stability_matrix[[i, j]] = 1.0;
262 }
263 }
264 }
265
266 let mut sum = 0.0;
268 let mut count = 0;
269
270 for i in 0..n_sets {
271 for j in 0..n_sets {
272 if i != j {
273 sum += stability_matrix[[i, j]];
274 count += 1;
275 }
276 }
277 }
278
279 Ok(sum / count as f64)
280 }
281
282 fn compute_relative_stability_index(
284 feature_sets: &[Vec<usize>],
285 total_features: usize,
286 ) -> Result<f64> {
287 let average_set_size =
288 feature_sets.iter().map(|s| s.len()).sum::<usize>() as f64 / feature_sets.len() as f64;
289 let max_possible_overlap = average_set_size.min(total_features as f64 - average_set_size);
290
291 if max_possible_overlap < 1e-10 {
292 return Ok(0.0);
293 }
294
295 let observed_overlap =
296 JaccardSimilarity::average_similarity(feature_sets)? * average_set_size;
297
298 Ok(observed_overlap / max_possible_overlap)
299 }
300
301 pub fn report(&self) -> String {
303 let mut report = String::new();
304
305 report.push_str("=== Feature Selection Stability Report ===\n\n");
306
307 report.push_str(&format!(
308 "Jaccard Similarity: {:.4}\n",
309 self.jaccard_similarity
310 ));
311 report.push_str(&format!(" Interpretation: {}\n", self.interpret_jaccard()));
312
313 report.push_str(&format!("\nDice Similarity: {:.4}\n", self.dice_similarity));
314 report.push_str(&format!(" Interpretation: {}\n", self.interpret_dice()));
315
316 report.push_str(&format!(
317 "\nOverlap Coefficient: {:.4}\n",
318 self.overlap_coefficient
319 ));
320 report.push_str(&format!(" Interpretation: {}\n", self.interpret_overlap()));
321
322 report.push_str(&format!(
323 "\nConsistency Index: {:.4}\n",
324 self.consistency_index
325 ));
326 report.push_str(&format!(
327 " Interpretation: {}\n",
328 self.interpret_consistency()
329 ));
330
331 report.push_str(&format!(
332 "\nPairwise Stability: {:.4}\n",
333 self.pairwise_stability
334 ));
335 report.push_str(&format!(
336 " Interpretation: {}\n",
337 self.interpret_pairwise()
338 ));
339
340 report.push_str(&format!(
341 "\nRelative Stability Index: {:.4}\n",
342 self.relative_stability_index
343 ));
344 report.push_str(&format!(
345 " Interpretation: {}\n",
346 self.interpret_relative()
347 ));
348
349 report.push_str(&format!(
350 "\nOverall Assessment: {}\n",
351 self.overall_assessment()
352 ));
353
354 report
355 }
356
357 fn interpret_jaccard(&self) -> &'static str {
358 match self.jaccard_similarity {
359 x if x >= 0.8 => "Excellent stability - feature sets are highly consistent",
360 x if x >= 0.6 => "Good stability - reasonable consistency in feature selection",
361 x if x >= 0.4 => "Moderate stability - some variability in feature selection",
362 x if x >= 0.2 => "Poor stability - high variability in feature selection",
363 _ => "Very poor stability - feature selection is highly inconsistent",
364 }
365 }
366
367 fn interpret_dice(&self) -> &'static str {
368 match self.dice_similarity {
369 x if x >= 0.8 => "Excellent overlap between feature sets",
370 x if x >= 0.6 => "Good overlap between feature sets",
371 x if x >= 0.4 => "Moderate overlap between feature sets",
372 x if x >= 0.2 => "Poor overlap between feature sets",
373 _ => "Very poor overlap between feature sets",
374 }
375 }
376
377 fn interpret_overlap(&self) -> &'static str {
378 match self.overlap_coefficient {
379 x if x >= 0.8 => "High subset consistency - smaller sets are well contained",
380 x if x >= 0.6 => "Good subset consistency",
381 x if x >= 0.4 => "Moderate subset consistency",
382 x if x >= 0.2 => "Poor subset consistency",
383 _ => "Very poor subset consistency",
384 }
385 }
386
387 fn interpret_consistency(&self) -> &'static str {
388 match self.consistency_index {
389 x if x >= 0.8 => "Excellent consistency above random chance",
390 x if x >= 0.6 => "Good consistency above random chance",
391 x if x >= 0.4 => "Moderate consistency above random chance",
392 x if x >= 0.0 => "Some consistency above random chance",
393 _ => "Consistency below random chance - concerning",
394 }
395 }
396
397 fn interpret_pairwise(&self) -> &'static str {
398 match self.pairwise_stability {
399 x if x >= 0.8 => "Excellent pairwise stability across all comparisons",
400 x if x >= 0.6 => "Good pairwise stability",
401 x if x >= 0.4 => "Moderate pairwise stability",
402 x if x >= 0.2 => "Poor pairwise stability",
403 _ => "Very poor pairwise stability",
404 }
405 }
406
407 fn interpret_relative(&self) -> &'static str {
408 match self.relative_stability_index {
409 x if x >= 0.8 => "Excellent relative stability considering set sizes",
410 x if x >= 0.6 => "Good relative stability",
411 x if x >= 0.4 => "Moderate relative stability",
412 x if x >= 0.2 => "Poor relative stability",
413 _ => "Very poor relative stability",
414 }
415 }
416
417 fn overall_assessment(&self) -> &'static str {
418 let average = (self.jaccard_similarity
419 + self.dice_similarity
420 + self.overlap_coefficient
421 + self.consistency_index
422 + self.pairwise_stability
423 + self.relative_stability_index)
424 / 6.0;
425
426 match average {
427 x if x >= 0.8 => "EXCELLENT: Feature selection is highly stable and reliable",
428 x if x >= 0.6 => "GOOD: Feature selection shows good stability",
429 x if x >= 0.4 => {
430 "MODERATE: Feature selection has moderate stability - consider parameter tuning"
431 }
432 x if x >= 0.2 => "POOR: Feature selection is unstable - review methodology",
433 _ => "CRITICAL: Feature selection is highly unstable - major concerns",
434 }
435 }
436}
437
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point scores.
    const TOLERANCE: f64 = 1e-10;

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// Asserts that `value` lies in the closed interval `[0, 1]`.
    fn assert_in_unit_interval(value: f64) {
        assert!((0.0..=1.0).contains(&value));
    }

    #[test]
    fn test_jaccard_similarity() {
        let first = vec![0, 1, 2];
        let second = vec![1, 2, 3];

        // Two of four distinct features are shared.
        assert_close(JaccardSimilarity::compute(&first, &second).unwrap(), 0.5);

        // A set is perfectly similar to itself.
        assert_close(JaccardSimilarity::compute(&first, &first).unwrap(), 1.0);

        // Two empty sets are treated as identical.
        let no_features: Vec<usize> = Vec::new();
        let also_no_features: Vec<usize> = Vec::new();
        assert_close(
            JaccardSimilarity::compute(&no_features, &also_no_features).unwrap(),
            1.0,
        );
    }

    #[test]
    fn test_dice_similarity() {
        let first = vec![0, 1, 2];
        let second = vec![1, 2, 3];

        assert_close(DiceSimilarity::compute(&first, &second).unwrap(), 2.0 / 3.0);
    }

    #[test]
    fn test_overlap_coefficient() {
        let smaller = vec![0, 1, 2];
        let larger = vec![1, 2, 3, 4];

        assert_close(
            OverlapCoefficient::compute(&smaller, &larger).unwrap(),
            2.0 / 3.0,
        );
    }

    #[test]
    fn test_consistency_index() {
        let first = vec![0, 1, 2];
        let second = vec![1, 2, 3];
        let total_features = 10;

        let consistency = ConsistencyIndex::compute(&first, &second, total_features).unwrap();
        assert!(consistency > -1.0 && consistency <= 1.0);
    }

    #[test]
    fn test_stability_measures() {
        let feature_sets = vec![vec![0, 1, 2], vec![1, 2, 3], vec![0, 2, 4]];
        let total_features = 10;

        let measures = StabilityMeasures::compute(&feature_sets, total_features).unwrap();

        assert_in_unit_interval(measures.jaccard_similarity);
        assert_in_unit_interval(measures.dice_similarity);
        assert_in_unit_interval(measures.overlap_coefficient);
        assert_in_unit_interval(measures.pairwise_stability);

        let report = measures.report();
        assert!(report.contains("Stability Report"));
        assert!(report.contains("Overall Assessment"));
    }

    #[test]
    fn test_average_similarities() {
        let feature_sets = vec![vec![0, 1, 2], vec![1, 2, 3], vec![2, 3, 4]];

        assert_in_unit_interval(JaccardSimilarity::average_similarity(&feature_sets).unwrap());
        assert_in_unit_interval(DiceSimilarity::average_similarity(&feature_sets).unwrap());
        assert_in_unit_interval(OverlapCoefficient::average_coefficient(&feature_sets).unwrap());
    }
}