sklears_inspection/
testing.rs

1//! Comprehensive Testing Framework
2//!
3//! This module provides comprehensive testing capabilities including property-based tests,
4//! fidelity tests, consistency tests, and robustness validation for explanation methods.
5
6use crate::{Float, SklResult};
7// ✅ SciRS2 Policy Compliant Import
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
9use scirs2_core::random::{Rng, SeedableRng};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Configuration for property-based testing.
///
/// Read by `TestingSuite::test_explanation_properties`, its random test-data
/// generator, and the explanation-stability helper (which reuses `tolerance`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PropertyTestConfig {
    /// Number of random test cases to generate
    pub num_test_cases: usize,
    /// Random seed for reproducibility (`None` seeds from the thread RNG)
    pub seed: Option<u64>,
    /// Tolerance for floating point comparisons (e.g. the "all zeros" check)
    pub tolerance: Float,
    /// Upper limit on the number of features in generated test data
    /// (the generator additionally applies a hard cap of 20)
    pub max_features: usize,
    /// Upper limit on the number of samples in generated test data
    /// (the generator additionally applies a hard cap of 100)
    pub max_samples: usize,
    /// Enable verbose logging
    /// NOTE(review): not read anywhere in this module.
    pub verbose: bool,
}
29
30impl Default for PropertyTestConfig {
31    fn default() -> Self {
32        Self {
33            num_test_cases: 100,
34            seed: Some(42),
35            tolerance: 1e-6,
36            max_features: 100,
37            max_samples: 1000,
38            verbose: false,
39        }
40    }
41}
42
/// Test result for property-based tests.
///
/// Produced by `TestingSuite::test_explanation_properties`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PropertyTestResult {
    /// Test name, as supplied by the caller
    pub test_name: String,
    /// Number of test cases run
    pub cases_run: usize,
    /// Number of cases whose explanation succeeded without any property violation
    pub cases_passed: usize,
    /// True when there were no violations and no explanation failures
    pub passed: bool,
    /// One message per case whose explanation function returned an error
    pub failure_messages: Vec<String>,
    /// Property violations found across all cases
    pub violations: Vec<PropertyViolation>,
    /// Wall-clock execution time in milliseconds
    pub execution_time_ms: f64,
}
61
/// Details of a single property violation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PropertyViolation {
    /// Identifier of the violated property (e.g. `no_nan_values`)
    pub property: String,
    /// Human-readable description of the violation
    pub description: String,
    /// Identifies the input that caused the violation
    /// (the property tests store `"Case <index>"` here)
    pub input_data: String,
    /// Expected behavior
    pub expected: String,
    /// Actual behavior
    pub actual: String,
}
76
/// Fidelity test configuration.
///
/// Read by `TestingSuite::test_explanation_fidelity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FidelityTestConfig {
    /// Minimum required fidelity score
    /// NOTE(review): not read by `TestingSuite` in this module; callers compare against it.
    pub min_fidelity: Float,
    /// Maximum number of test rows evaluated for fidelity
    pub num_samples: usize,
    /// Per-feature perturbation probability: despite the name, the fidelity test
    /// compares it against a uniform random draw to decide whether to perturb a feature
    pub perturbation_magnitude: Float,
    /// Random seed (`None` seeds from the thread RNG)
    pub seed: Option<u64>,
}
89
90impl Default for FidelityTestConfig {
91    fn default() -> Self {
92        Self {
93            min_fidelity: 0.8,
94            num_samples: 100,
95            perturbation_magnitude: 0.1,
96            seed: Some(42),
97        }
98    }
99}
100
/// Consistency test configuration.
///
/// `TestingSuite::test_method_consistency` reads only `num_test_cases`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsistencyTestConfig {
    /// Names of the methods to compare for consistency
    /// NOTE(review): not read by `TestingSuite` in this module.
    pub methods: Vec<String>,
    /// Tolerance for consistency checks
    /// NOTE(review): not read by `TestingSuite` in this module.
    pub tolerance: Float,
    /// Maximum number of test rows compared
    pub num_test_cases: usize,
    /// Random seed
    /// NOTE(review): not read by `TestingSuite` in this module (the consistency test is deterministic).
    pub seed: Option<u64>,
}
113
114impl Default for ConsistencyTestConfig {
115    fn default() -> Self {
116        Self {
117            methods: vec!["permutation".to_string(), "shap".to_string()],
118            tolerance: 0.2,
119            num_test_cases: 50,
120            seed: Some(42),
121        }
122    }
123}
124
/// Robustness test configuration.
///
/// Read by `TestingSuite::test_robustness`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RobustnessTestConfig {
    /// Noise levels to test; each level is the half-width of the uniform noise
    /// interval added independently to every feature
    pub noise_levels: Vec<Float>,
    /// Number of noisy copies explained per row and noise level
    pub perturbations_per_level: usize,
    /// Maximum acceptable explanation change
    /// NOTE(review): not read by `TestingSuite` in this module.
    pub max_explanation_change: Float,
    /// Random seed (`None` seeds from the thread RNG)
    pub seed: Option<u64>,
}
137
138impl Default for RobustnessTestConfig {
139    fn default() -> Self {
140        Self {
141            noise_levels: vec![0.01, 0.05, 0.1, 0.2],
142            perturbations_per_level: 10,
143            max_explanation_change: 0.3,
144            seed: Some(42),
145        }
146    }
147}
148
149/// Comprehensive testing suite
150pub struct TestingSuite {
151    /// Property test configuration
152    property_config: PropertyTestConfig,
153    /// Fidelity test configuration
154    fidelity_config: FidelityTestConfig,
155    /// Consistency test configuration
156    consistency_config: ConsistencyTestConfig,
157    /// Robustness test configuration
158    robustness_config: RobustnessTestConfig,
159}
160
161impl TestingSuite {
162    /// Create a new testing suite
163    ///
164    /// # Examples
165    ///
166    /// ```rust
167    /// use sklears_inspection::testing::TestingSuite;
168    ///
169    /// let suite = TestingSuite::new();
170    /// ```
171    pub fn new() -> Self {
172        Self {
173            property_config: PropertyTestConfig::default(),
174            fidelity_config: FidelityTestConfig::default(),
175            consistency_config: ConsistencyTestConfig::default(),
176            robustness_config: RobustnessTestConfig::default(),
177        }
178    }
179
180    /// Create a new testing suite with custom configurations
181    pub fn with_configs(
182        property_config: PropertyTestConfig,
183        fidelity_config: FidelityTestConfig,
184        consistency_config: ConsistencyTestConfig,
185        robustness_config: RobustnessTestConfig,
186    ) -> Self {
187        Self {
188            property_config,
189            fidelity_config,
190            consistency_config,
191            robustness_config,
192        }
193    }
194
195    /// Run property-based tests for explanation properties
196    pub fn test_explanation_properties<F>(
197        &self,
198        explanation_fn: F,
199        test_name: &str,
200    ) -> SklResult<PropertyTestResult>
201    where
202        F: Fn(&ArrayView2<Float>) -> SklResult<Array1<Float>>,
203    {
204        let start_time = std::time::Instant::now();
205        let mut cases_passed = 0;
206        let mut violations = Vec::new();
207        let mut failure_messages = Vec::new();
208
209        // Set up random number generator
210        let mut rng = if let Some(seed) = self.property_config.seed {
211            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
212        } else {
213            scirs2_core::random::rngs::StdRng::from_rng(&mut scirs2_core::random::thread_rng())
214        };
215
216        for case_idx in 0..self.property_config.num_test_cases {
217            // Generate random test data
218            let test_data = self.generate_test_data(&mut rng)?;
219
220            // Run explanation function
221            match explanation_fn(&test_data.view()) {
222                Ok(explanation) => {
223                    // Test explanation properties
224                    let property_results = self.check_explanation_properties(
225                        &test_data.view(),
226                        &explanation.view(),
227                        case_idx,
228                    );
229
230                    if property_results.is_empty() {
231                        cases_passed += 1;
232                    } else {
233                        violations.extend(property_results);
234                    }
235                }
236                Err(e) => {
237                    failure_messages.push(format!("Case {}: {}", case_idx, e));
238                }
239            }
240        }
241
242        let execution_time = start_time.elapsed().as_millis() as f64;
243        let passed = violations.is_empty() && failure_messages.is_empty();
244
245        Ok(PropertyTestResult {
246            test_name: test_name.to_string(),
247            cases_run: self.property_config.num_test_cases,
248            cases_passed,
249            passed,
250            failure_messages,
251            violations,
252            execution_time_ms: execution_time,
253        })
254    }
255
256    /// Test fidelity of local explanations
257    pub fn test_explanation_fidelity<F, M>(
258        &self,
259        model_fn: M,
260        explanation_fn: F,
261        test_data: &ArrayView2<Float>,
262    ) -> SklResult<Float>
263    where
264        M: Fn(&ArrayView2<Float>) -> SklResult<Array1<Float>>,
265        F: Fn(&ArrayView2<Float>) -> SklResult<Array1<Float>>,
266    {
267        use scirs2_core::random::Rng;
268
269        let mut rng = if let Some(seed) = self.fidelity_config.seed {
270            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
271        } else {
272            scirs2_core::random::rngs::StdRng::from_rng(&mut scirs2_core::random::thread_rng())
273        };
274
275        let mut total_fidelity = 0.0;
276        let n_features = test_data.ncols();
277
278        for i in 0..self.fidelity_config.num_samples.min(test_data.nrows()) {
279            let instance = test_data.row(i);
280            let instance_2d = instance.insert_axis(Axis(0));
281            let original_prediction = model_fn(&instance_2d.view())?;
282            let explanation = explanation_fn(&instance_2d.view())?;
283
284            // Create perturbed instances based on explanation
285            let mut correct_predictions = 0;
286            let mut total_predictions = 0;
287
288            for _ in 0..10 {
289                let mut perturbed_instance = instance.to_owned();
290
291                // Perturb features based on explanation importance
292                for j in 0..n_features {
293                    if rng.gen::<Float>() < self.fidelity_config.perturbation_magnitude {
294                        let importance = explanation[j].abs();
295                        let perturbation = rng.gen_range(-importance..importance);
296                        perturbed_instance[j] += perturbation;
297                    }
298                }
299
300                let perturbed_2d = perturbed_instance.view().insert_axis(Axis(0));
301                let perturbed_prediction = model_fn(&perturbed_2d)?;
302
303                // Check if explanation correctly predicts direction of change
304                let prediction_change = perturbed_prediction[0] - original_prediction[0];
305                let expected_change = explanation
306                    .iter()
307                    .zip(perturbed_instance.iter().zip(instance.iter()))
308                    .map(|(imp, (new_val, old_val))| imp * (new_val - old_val))
309                    .sum::<Float>();
310
311                if prediction_change.signum() == expected_change.signum() {
312                    correct_predictions += 1;
313                }
314                total_predictions += 1;
315            }
316
317            total_fidelity += correct_predictions as Float / total_predictions as Float;
318        }
319
320        Ok(total_fidelity / self.fidelity_config.num_samples.min(test_data.nrows()) as Float)
321    }
322
323    /// Test consistency across different explanation methods
324    pub fn test_method_consistency<F1, F2>(
325        &self,
326        method1: F1,
327        method2: F2,
328        test_data: &ArrayView2<Float>,
329        method1_name: &str,
330        method2_name: &str,
331    ) -> SklResult<Float>
332    where
333        F1: Fn(&ArrayView2<Float>) -> SklResult<Array1<Float>>,
334        F2: Fn(&ArrayView2<Float>) -> SklResult<Array1<Float>>,
335    {
336        let mut total_correlation = 0.0;
337        let mut valid_comparisons = 0;
338
339        for i in 0..self
340            .consistency_config
341            .num_test_cases
342            .min(test_data.nrows())
343        {
344            let instance = test_data.row(i).insert_axis(Axis(0));
345
346            let explanation1 = method1(&instance.view())?;
347            let explanation2 = method2(&instance.view())?;
348
349            // Calculate correlation between explanations
350            let correlation =
351                self.calculate_correlation(&explanation1.view(), &explanation2.view());
352
353            if !correlation.is_nan() {
354                total_correlation += correlation;
355                valid_comparisons += 1;
356            }
357        }
358
359        if valid_comparisons > 0 {
360            Ok(total_correlation / valid_comparisons as Float)
361        } else {
362            Ok(0.0)
363        }
364    }
365
366    /// Test robustness to input noise
367    pub fn test_robustness<F>(
368        &self,
369        explanation_fn: F,
370        test_data: &ArrayView2<Float>,
371    ) -> SklResult<HashMap<String, Float>>
372    where
373        F: Fn(&ArrayView2<Float>) -> SklResult<Array1<Float>>,
374    {
375        use scirs2_core::random::Rng;
376
377        let mut rng = if let Some(seed) = self.robustness_config.seed {
378            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
379        } else {
380            scirs2_core::random::rngs::StdRng::from_rng(&mut scirs2_core::random::thread_rng())
381        };
382
383        let mut results = HashMap::new();
384
385        for &noise_level in &self.robustness_config.noise_levels {
386            let mut total_stability = 0.0;
387            let mut valid_tests = 0;
388
389            for i in 0..test_data.nrows() {
390                let original_instance = test_data.row(i);
391                let original_2d = original_instance.insert_axis(Axis(0));
392                let original_explanation = explanation_fn(&original_2d.view())?;
393
394                let mut perturbation_stabilities = Vec::new();
395
396                for _ in 0..self.robustness_config.perturbations_per_level {
397                    // Add noise to the instance
398                    let mut noisy_instance = original_instance.to_owned();
399                    for j in 0..noisy_instance.len() {
400                        let noise = rng.gen_range(-noise_level..noise_level);
401                        noisy_instance[j] += noise;
402                    }
403
404                    let noisy_2d = noisy_instance.view().insert_axis(Axis(0));
405                    let noisy_explanation = explanation_fn(&noisy_2d)?;
406
407                    // Calculate stability (1 - relative change)
408                    let stability = self.calculate_explanation_stability(
409                        &original_explanation.view(),
410                        &noisy_explanation.view(),
411                    );
412
413                    if !stability.is_nan() {
414                        perturbation_stabilities.push(stability);
415                    }
416                }
417
418                if !perturbation_stabilities.is_empty() {
419                    let avg_stability = perturbation_stabilities.iter().sum::<Float>()
420                        / perturbation_stabilities.len() as Float;
421                    total_stability += avg_stability;
422                    valid_tests += 1;
423                }
424            }
425
426            if valid_tests > 0 {
427                results.insert(
428                    format!("noise_{:.3}", noise_level),
429                    total_stability / valid_tests as Float,
430                );
431            }
432        }
433
434        Ok(results)
435    }
436
437    /// Run comprehensive test suite
438    pub fn run_comprehensive_tests<F, M>(
439        &self,
440        model_fn: M,
441        explanation_fn: F,
442        test_data: &ArrayView2<Float>,
443        test_name: &str,
444    ) -> SklResult<HashMap<String, serde_json::Value>>
445    where
446        F: Fn(&ArrayView2<Float>) -> SklResult<Array1<Float>> + Clone,
447        M: Fn(&ArrayView2<Float>) -> SklResult<Array1<Float>>,
448    {
449        let mut results = HashMap::new();
450
451        // Property-based tests
452        let property_result =
453            self.test_explanation_properties(explanation_fn.clone(), test_name)?;
454        results.insert(
455            "property_tests".to_string(),
456            serde_json::to_value(property_result).map_err(|e| {
457                crate::SklearsError::InvalidInput(format!(
458                    "Failed to serialize property test results: {}",
459                    e
460                ))
461            })?,
462        );
463
464        // Robustness tests
465        let robustness_result = self.test_robustness(explanation_fn, test_data)?;
466        results.insert(
467            "robustness_tests".to_string(),
468            serde_json::to_value(robustness_result).map_err(|e| {
469                crate::SklearsError::InvalidInput(format!(
470                    "Failed to serialize robustness test results: {}",
471                    e
472                ))
473            })?,
474        );
475
476        Ok(results)
477    }
478
479    // Helper methods
480    fn generate_test_data<R: scirs2_core::random::Rng>(
481        &self,
482        rng: &mut R,
483    ) -> SklResult<Array2<Float>> {
484        use scirs2_core::random::Rng;
485
486        let n_samples = rng.gen_range(10..self.property_config.max_samples.min(100 + 1));
487        let n_features = rng.gen_range(5..self.property_config.max_features.min(20 + 1));
488
489        let mut data = Array2::zeros((n_samples, n_features));
490
491        for i in 0..n_samples {
492            for j in 0..n_features {
493                data[[i, j]] = rng.gen_range(-2.0..2.0);
494            }
495        }
496
497        Ok(data)
498    }
499
500    fn check_explanation_properties(
501        &self,
502        _data: &ArrayView2<Float>,
503        explanation: &ArrayView1<Float>,
504        case_idx: usize,
505    ) -> Vec<PropertyViolation> {
506        let mut violations = Vec::new();
507
508        // Property 1: Explanation should not contain NaN or infinite values
509        for (i, &value) in explanation.iter().enumerate() {
510            if value.is_nan() {
511                violations.push(PropertyViolation {
512                    property: "no_nan_values".to_string(),
513                    description: format!("Explanation contains NaN value at index {}", i),
514                    input_data: format!("Case {}", case_idx),
515                    expected: "Finite numeric value".to_string(),
516                    actual: "NaN".to_string(),
517                });
518            }
519            if value.is_infinite() {
520                violations.push(PropertyViolation {
521                    property: "no_infinite_values".to_string(),
522                    description: format!("Explanation contains infinite value at index {}", i),
523                    input_data: format!("Case {}", case_idx),
524                    expected: "Finite numeric value".to_string(),
525                    actual: "Infinite".to_string(),
526                });
527            }
528        }
529
530        // Property 2: Explanation should not be all zeros (unless intended)
531        let sum_abs = explanation.iter().map(|x| x.abs()).sum::<Float>();
532        if sum_abs < self.property_config.tolerance {
533            violations.push(PropertyViolation {
534                property: "non_trivial_explanation".to_string(),
535                description: "Explanation is all zeros or nearly zeros".to_string(),
536                input_data: format!("Case {}", case_idx),
537                expected: "Non-zero explanation values".to_string(),
538                actual: format!("Sum of absolute values: {}", sum_abs),
539            });
540        }
541
542        violations
543    }
544
545    fn calculate_correlation(&self, a: &ArrayView1<Float>, b: &ArrayView1<Float>) -> Float {
546        if a.len() != b.len() {
547            return Float::NAN;
548        }
549
550        let n = a.len() as Float;
551        let mean_a = a.iter().sum::<Float>() / n;
552        let mean_b = b.iter().sum::<Float>() / n;
553
554        let numerator: Float = a
555            .iter()
556            .zip(b.iter())
557            .map(|(ai, bi)| (ai - mean_a) * (bi - mean_b))
558            .sum();
559
560        let sum_sq_a: Float = a.iter().map(|ai| (ai - mean_a).powi(2)).sum();
561        let sum_sq_b: Float = b.iter().map(|bi| (bi - mean_b).powi(2)).sum();
562
563        let denominator = (sum_sq_a * sum_sq_b).sqrt();
564
565        if denominator == 0.0 {
566            Float::NAN
567        } else {
568            numerator / denominator
569        }
570    }
571
572    fn calculate_explanation_stability(
573        &self,
574        original: &ArrayView1<Float>,
575        perturbed: &ArrayView1<Float>,
576    ) -> Float {
577        if original.len() != perturbed.len() {
578            return Float::NAN;
579        }
580
581        let relative_changes: Vec<Float> = original
582            .iter()
583            .zip(perturbed.iter())
584            .map(|(orig, pert)| {
585                if orig.abs() < self.property_config.tolerance {
586                    pert.abs()
587                } else {
588                    (pert - orig).abs() / orig.abs()
589                }
590            })
591            .collect();
592
593        let max_relative_change = relative_changes.iter().copied().fold(0.0f64, f64::max);
594
595        // Stability is 1 - normalized change (clamped to [0, 1])
596        (1.0 - max_relative_change).max(0.0).min(1.0)
597    }
598}
599
600impl Default for TestingSuite {
601    fn default() -> Self {
602        Self::new()
603    }
604}
605
606/// Validate explanation output against expected properties
607pub fn validate_explanation_output(
608    explanation: &ArrayView1<Float>,
609    expected_properties: &ExplanationProperties,
610) -> SklResult<ValidationResult> {
611    let mut violations = Vec::new();
612    let mut passed_checks = 0;
613    let total_checks = 5; // Adjust based on number of checks
614
615    // Check for finite values
616    let has_finite_values = explanation.iter().all(|x| x.is_finite());
617    if has_finite_values {
618        passed_checks += 1;
619    } else {
620        violations.push("Explanation contains non-finite values".to_string());
621    }
622
623    // Check sum constraint if specified
624    if let Some(expected_sum) = expected_properties.expected_sum {
625        let actual_sum = explanation.sum();
626        if (actual_sum - expected_sum).abs() < expected_properties.tolerance {
627            passed_checks += 1;
628        } else {
629            violations.push(format!(
630                "Sum constraint violated: expected {}, got {}",
631                expected_sum, actual_sum
632            ));
633        }
634    } else {
635        passed_checks += 1; // Skip this check
636    }
637
638    // Check non-negativity if required
639    if expected_properties.non_negative {
640        let is_non_negative = explanation.iter().all(|x| *x >= 0.0);
641        if is_non_negative {
642            passed_checks += 1;
643        } else {
644            violations.push(
645                "Explanation contains negative values when non-negativity is required".to_string(),
646            );
647        }
648    } else {
649        passed_checks += 1; // Skip this check
650    }
651
652    // Check magnitude bounds
653    let max_magnitude = explanation.iter().map(|x| x.abs()).fold(0.0, f64::max);
654    if max_magnitude <= expected_properties.max_magnitude {
655        passed_checks += 1;
656    } else {
657        violations.push(format!(
658            "Magnitude bound violated: max magnitude {} exceeds limit {}",
659            max_magnitude, expected_properties.max_magnitude
660        ));
661    }
662
663    // Check sparsity if required
664    if let Some(max_non_zero) = expected_properties.max_non_zero_features {
665        let non_zero_count = explanation
666            .iter()
667            .filter(|x| x.abs() > expected_properties.tolerance)
668            .count();
669        if non_zero_count <= max_non_zero {
670            passed_checks += 1;
671        } else {
672            violations.push(format!(
673                "Sparsity constraint violated: {} non-zero features exceeds limit {}",
674                non_zero_count, max_non_zero
675            ));
676        }
677    } else {
678        passed_checks += 1; // Skip this check
679    }
680
681    Ok(ValidationResult {
682        passed: violations.is_empty(),
683        passed_checks,
684        total_checks,
685        violations,
686        score: passed_checks as Float / total_checks as Float,
687    })
688}
689
690/// Expected properties for explanation validation
691#[derive(Debug, Clone)]
692pub struct ExplanationProperties {
693    /// Expected sum of explanation values
694    pub expected_sum: Option<Float>,
695    /// Whether values should be non-negative
696    pub non_negative: bool,
697    /// Maximum allowed magnitude
698    pub max_magnitude: Float,
699    /// Maximum number of non-zero features
700    pub max_non_zero_features: Option<usize>,
701    /// Tolerance for floating point comparisons
702    pub tolerance: Float,
703}
704
705impl Default for ExplanationProperties {
706    fn default() -> Self {
707        Self {
708            expected_sum: None,
709            non_negative: false,
710            max_magnitude: 10.0,
711            max_non_zero_features: None,
712            tolerance: 1e-6,
713        }
714    }
715}
716
717/// Validation result
718#[derive(Debug, Clone)]
719pub struct ValidationResult {
720    /// Whether all checks passed
721    pub passed: bool,
722    /// Number of checks that passed
723    pub passed_checks: usize,
724    /// Total number of checks
725    pub total_checks: usize,
726    /// Violation messages
727    pub violations: Vec<String>,
728    /// Overall validation score (0-1)
729    pub score: Float,
730}
731
#[cfg(test)]
mod tests {
    use super::*;
    // ✅ SciRS2 Policy Compliant Import
    use scirs2_core::ndarray::array;

    /// Run `validate_explanation_output`, panicking if validation itself errors.
    fn validate(values: &Array1<Float>, properties: &ExplanationProperties) -> ValidationResult {
        validate_explanation_output(&values.view(), properties).unwrap()
    }

    /// True when any violation message contains `needle`.
    fn mentions(result: &ValidationResult, needle: &str) -> bool {
        result
            .violations
            .iter()
            .any(|message| message.contains(needle))
    }

    #[test]
    fn test_testing_suite_creation() {
        let config = TestingSuite::new().property_config;
        assert_eq!(config.num_test_cases, 100);
        assert_eq!(config.seed, Some(42));
    }

    #[test]
    fn test_property_test_config_default() {
        let PropertyTestConfig {
            num_test_cases,
            tolerance,
            max_features,
            ..
        } = PropertyTestConfig::default();
        assert_eq!(num_test_cases, 100);
        assert_eq!(tolerance, 1e-6);
        assert_eq!(max_features, 100);
    }

    #[test]
    fn test_explanation_property_validation() {
        let result = validate(&array![0.3, 0.5, -0.2, 0.1], &ExplanationProperties::default());
        assert!(result.passed);
        assert!(result.score > 0.8);
    }

    #[test]
    fn test_explanation_with_nan_validation() {
        let result = validate(
            &array![0.3, Float::NAN, -0.2, 0.1],
            &ExplanationProperties::default(),
        );
        assert!(!result.passed);
        assert!(!result.violations.is_empty());
    }

    #[test]
    fn test_non_negative_constraint() {
        let properties = ExplanationProperties {
            non_negative: true,
            ..Default::default()
        };
        let result = validate(&array![0.3, 0.5, -0.2, 0.1], &properties);
        assert!(!result.passed);
        assert!(mentions(&result, "negative values"));
    }

    #[test]
    fn test_sum_constraint() {
        let properties = ExplanationProperties {
            expected_sum: Some(1.0),
            tolerance: 1e-6,
            ..Default::default()
        };
        assert!(validate(&array![0.3, 0.5, 0.2, 0.0], &properties).passed);
    }

    #[test]
    fn test_magnitude_constraint() {
        // 15.0 exceeds the default max_magnitude of 10.0.
        let result = validate(&array![0.3, 15.0, 0.2, 0.1], &ExplanationProperties::default());
        assert!(!result.passed);
        assert!(mentions(&result, "Magnitude bound violated"));
    }

    #[test]
    fn test_sparsity_constraint() {
        let properties = ExplanationProperties {
            max_non_zero_features: Some(3),
            tolerance: 1e-6,
            ..Default::default()
        };
        let result = validate(&array![0.3, 0.5, 0.2, 0.1, 0.05], &properties);
        assert!(!result.passed);
        assert!(mentions(&result, "Sparsity constraint violated"));
    }

    #[test]
    fn test_correlation_calculation() {
        let suite = TestingSuite::new();
        let xs = array![1.0, 2.0, 3.0, 4.0];
        let doubled = xs.mapv(|x| 2.0 * x); // perfectly positively correlated

        let correlation = suite.calculate_correlation(&xs.view(), &doubled.view());
        assert!((correlation - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_explanation_stability() {
        let suite = TestingSuite::new();
        let original = array![0.3, 0.5, 0.2];
        let nudged = array![0.31, 0.49, 0.21]; // small changes only

        let stability = suite.calculate_explanation_stability(&original.view(), &nudged.view());
        assert!(stability > 0.8, "expected a stable explanation, got {}", stability);
    }

    #[test]
    fn test_generate_test_data() {
        let suite = TestingSuite {
            property_config: PropertyTestConfig {
                max_samples: 50,
                max_features: 10,
                seed: Some(42),
                ..Default::default()
            },
            ..Default::default()
        };

        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
        let data = suite.generate_test_data(&mut rng).unwrap();

        assert!((10..=50).contains(&data.nrows()));
        assert!((5..=10).contains(&data.ncols()));
    }

    #[test]
    fn test_property_violation_creation() {
        let violation = PropertyViolation {
            property: String::from("test_property"),
            description: String::from("Test violation"),
            input_data: String::from("test_data"),
            expected: String::from("expected_behavior"),
            actual: String::from("actual_behavior"),
        };

        assert_eq!(violation.property, "test_property");
        assert_eq!(violation.description, "Test violation");
    }

    #[test]
    fn test_fidelity_config_default() {
        let FidelityTestConfig {
            min_fidelity,
            num_samples,
            perturbation_magnitude,
            ..
        } = FidelityTestConfig::default();
        assert_eq!(min_fidelity, 0.8);
        assert_eq!(num_samples, 100);
        assert_eq!(perturbation_magnitude, 0.1);
    }

    #[test]
    fn test_consistency_config_default() {
        let ConsistencyTestConfig {
            methods,
            tolerance,
            num_test_cases,
            ..
        } = ConsistencyTestConfig::default();
        assert_eq!(methods.len(), 2);
        assert_eq!(tolerance, 0.2);
        assert_eq!(num_test_cases, 50);
    }

    #[test]
    fn test_robustness_config_default() {
        let RobustnessTestConfig {
            noise_levels,
            perturbations_per_level,
            max_explanation_change,
            ..
        } = RobustnessTestConfig::default();
        assert_eq!(noise_levels.len(), 4);
        assert_eq!(perturbations_per_level, 10);
        assert_eq!(max_explanation_change, 0.3);
    }
}