// sklears_semi_supervised/convergence_tests.rs

//! Convergence tests for iterative semi-supervised learning algorithms
//!
//! This module provides comprehensive convergence testing for all iterative
//! algorithms in the semi-supervised learning crate, ensuring that algorithms
//! converge properly under various conditions.
use std::collections::HashMap;

use scirs2_core::random::Random;
use sklears_core::error::SklearsError;
10
/// Configuration for convergence testing.
///
/// Built with [`ConvergenceTestConfig::new`] and customized via the builder
/// methods; `Default` yields the same settings as `new`.
#[derive(Clone, Debug)]
pub struct ConvergenceTestConfig {
    /// Maximum number of iterations to test
    pub max_iterations: usize,
    /// Tolerance for convergence detection (applied to both the error and
    /// its change between iterations)
    pub tolerance: f64,
    /// Minimum iterations before checking convergence
    pub min_iterations: usize,
    /// Window size for convergence rate calculation
    pub window_size: usize,
    /// Whether to test monotonic convergence
    pub test_monotonic: bool,
    /// Whether to test convergence rate
    pub test_convergence_rate: bool,
}
27
28impl ConvergenceTestConfig {
29    /// Create a default convergence test configuration
30    pub fn new() -> Self {
31        Self {
32            max_iterations: 1000,
33            tolerance: 1e-6,
34            min_iterations: 10,
35            window_size: 10,
36            test_monotonic: true,
37            test_convergence_rate: true,
38        }
39    }
40
41    /// Set maximum iterations
42    pub fn max_iterations(mut self, max_iter: usize) -> Self {
43        self.max_iterations = max_iter;
44        self
45    }
46
47    /// Set tolerance
48    pub fn tolerance(mut self, tol: f64) -> Self {
49        self.tolerance = tol;
50        self
51    }
52
53    /// Set minimum iterations
54    pub fn min_iterations(mut self, min_iter: usize) -> Self {
55        self.min_iterations = min_iter;
56        self
57    }
58
59    /// Set window size for convergence rate calculation
60    pub fn window_size(mut self, window: usize) -> Self {
61        self.window_size = window;
62        self
63    }
64
65    /// Enable/disable monotonic convergence testing
66    pub fn test_monotonic(mut self, test: bool) -> Self {
67        self.test_monotonic = test;
68        self
69    }
70
71    /// Enable/disable convergence rate testing
72    pub fn test_convergence_rate(mut self, test: bool) -> Self {
73        self.test_convergence_rate = test;
74        self
75    }
76}
77
/// Results of convergence testing.
///
/// Populated by the convergence tester; a freshly created result represents
/// a run that has not converged.
#[derive(Clone, Debug)]
pub struct ConvergenceTestResult {
    /// Whether the algorithm converged
    pub converged: bool,
    /// Number of iterations to convergence (0 if the run never converged)
    pub iterations_to_convergence: usize,
    /// Final error/residual (`f64::INFINITY` if the run never converged)
    pub final_error: f64,
    /// Convergence history (error at each iteration)
    pub convergence_history: Vec<f64>,
    /// Whether convergence was monotonic (the error never increased)
    pub is_monotonic: bool,
    /// Estimated convergence rate (average ratio of successive errors)
    pub convergence_rate: f64,
    /// Additional statistics keyed by name (initial/average/min/max error,
    /// variance, reduction ratio, ...)
    pub statistics: HashMap<String, f64>,
}
96
97impl ConvergenceTestResult {
98    /// Create a new convergence test result
99    pub fn new() -> Self {
100        Self {
101            converged: false,
102            iterations_to_convergence: 0,
103            final_error: f64::INFINITY,
104            convergence_history: Vec::new(),
105            is_monotonic: true,
106            convergence_rate: 0.0,
107            statistics: HashMap::new(),
108        }
109    }
110
111    /// Check if convergence meets quality criteria
112    pub fn meets_quality_criteria(&self, config: &ConvergenceTestConfig) -> bool {
113        self.converged
114            && self.final_error < config.tolerance
115            && (!config.test_monotonic || self.is_monotonic)
116            && self.iterations_to_convergence >= config.min_iterations
117    }
118}
119
/// Generic convergence tester for iterative algorithms.
pub struct ConvergenceTester {
    // Settings controlling iteration limits, tolerance, and which extra
    // checks (monotonicity, convergence rate) are performed.
    config: ConvergenceTestConfig,
}
124
125impl ConvergenceTester {
126    /// Create a new convergence tester
127    pub fn new(config: ConvergenceTestConfig) -> Self {
128        Self { config }
129    }
130
131    /// Test convergence of an iterative function
132    pub fn test_convergence<F, S>(
133        &self,
134        mut state: S,
135        mut iteration_fn: F,
136    ) -> Result<ConvergenceTestResult, SklearsError>
137    where
138        F: FnMut(&mut S, usize) -> Result<f64, SklearsError>,
139        S: Clone,
140    {
141        let mut result = ConvergenceTestResult::new();
142        let mut prev_error = f64::INFINITY;
143
144        for iteration in 0..self.config.max_iterations {
145            // Run one iteration and get error/residual
146            let current_error = iteration_fn(&mut state, iteration)?;
147            result.convergence_history.push(current_error);
148
149            // Check for convergence
150            if iteration >= self.config.min_iterations {
151                let error_change = (prev_error - current_error).abs();
152                if error_change < self.config.tolerance && current_error < self.config.tolerance {
153                    result.converged = true;
154                    result.iterations_to_convergence = iteration + 1;
155                    result.final_error = current_error;
156                    break;
157                }
158            }
159
160            // Check monotonic convergence
161            if self.config.test_monotonic && iteration > 0 && current_error > prev_error {
162                result.is_monotonic = false;
163            }
164
165            prev_error = current_error;
166        }
167
168        // Calculate convergence rate
169        if self.config.test_convergence_rate
170            && result.convergence_history.len() > self.config.window_size
171        {
172            result.convergence_rate =
173                self.calculate_convergence_rate(&result.convergence_history)?;
174        }
175
176        // Calculate additional statistics
177        self.calculate_statistics(&mut result)?;
178
179        Ok(result)
180    }
181
182    /// Calculate convergence rate from error history
183    fn calculate_convergence_rate(&self, history: &[f64]) -> Result<f64, SklearsError> {
184        if history.len() < self.config.window_size {
185            return Ok(0.0);
186        }
187
188        let window_start = history.len().saturating_sub(self.config.window_size);
189        let window = &history[window_start..];
190
191        // Calculate average rate of decrease in the window
192        let mut total_rate = 0.0;
193        let mut count = 0;
194
195        for i in 1..window.len() {
196            if window[i - 1] > 0.0 && window[i] > 0.0 {
197                let rate = window[i] / window[i - 1];
198                total_rate += rate;
199                count += 1;
200            }
201        }
202
203        if count > 0 {
204            Ok(total_rate / count as f64)
205        } else {
206            Ok(1.0)
207        }
208    }
209
210    /// Calculate additional convergence statistics
211    fn calculate_statistics(&self, result: &mut ConvergenceTestResult) -> Result<(), SklearsError> {
212        let history = &result.convergence_history;
213
214        if history.is_empty() {
215            return Ok(());
216        }
217
218        // Initial error
219        result
220            .statistics
221            .insert("initial_error".to_string(), history[0]);
222
223        // Average error
224        let avg_error = history.iter().sum::<f64>() / history.len() as f64;
225        result
226            .statistics
227            .insert("average_error".to_string(), avg_error);
228
229        // Error variance
230        let variance = history
231            .iter()
232            .map(|&x| (x - avg_error).powi(2))
233            .sum::<f64>()
234            / history.len() as f64;
235        result
236            .statistics
237            .insert("error_variance".to_string(), variance);
238
239        // Maximum error
240        let max_error = history.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
241        result.statistics.insert("max_error".to_string(), max_error);
242
243        // Minimum error
244        let min_error = history.iter().cloned().fold(f64::INFINITY, f64::min);
245        result.statistics.insert("min_error".to_string(), min_error);
246
247        // Error reduction ratio
248        if history.len() > 1 && history[0] > 0.0 {
249            let reduction_ratio = (history[0] - result.final_error) / history[0];
250            result
251                .statistics
252                .insert("error_reduction_ratio".to_string(), reduction_ratio);
253        }
254
255        Ok(())
256    }
257
258    /// Test convergence with multiple random initializations
259    pub fn test_convergence_multiple_runs<F, G, S>(
260        &self,
261        init_fn: G,
262        iteration_fn: F,
263        num_runs: usize,
264    ) -> Result<Vec<ConvergenceTestResult>, SklearsError>
265    where
266        F: Fn(&mut S, usize) -> Result<f64, SklearsError> + Clone,
267        G: Fn() -> S,
268        S: Clone,
269    {
270        let mut results = Vec::new();
271
272        for _run in 0..num_runs {
273            let state = init_fn();
274            let result = self.test_convergence(state, iteration_fn.clone())?;
275            results.push(result);
276        }
277
278        Ok(results)
279    }
280
281    /// Analyze convergence results across multiple runs
282    pub fn analyze_multiple_runs(
283        &self,
284        results: &[ConvergenceTestResult],
285    ) -> Result<HashMap<String, f64>, SklearsError> {
286        let mut analysis = HashMap::new();
287
288        if results.is_empty() {
289            return Ok(analysis);
290        }
291
292        // Convergence rate
293        let convergence_rate =
294            results.iter().filter(|r| r.converged).count() as f64 / results.len() as f64;
295        analysis.insert("convergence_rate".to_string(), convergence_rate);
296
297        // Average iterations to convergence (for converged runs only)
298        let converged_results: Vec<_> = results.iter().filter(|r| r.converged).collect();
299        if !converged_results.is_empty() {
300            let avg_iterations = converged_results
301                .iter()
302                .map(|r| r.iterations_to_convergence as f64)
303                .sum::<f64>()
304                / converged_results.len() as f64;
305            analysis.insert(
306                "average_iterations_to_convergence".to_string(),
307                avg_iterations,
308            );
309
310            // Average final error (for converged runs only)
311            let avg_final_error = converged_results.iter().map(|r| r.final_error).sum::<f64>()
312                / converged_results.len() as f64;
313            analysis.insert("average_final_error".to_string(), avg_final_error);
314
315            // Monotonic convergence rate
316            let monotonic_rate = converged_results.iter().filter(|r| r.is_monotonic).count() as f64
317                / converged_results.len() as f64;
318            analysis.insert("monotonic_convergence_rate".to_string(), monotonic_rate);
319        }
320
321        // Robustness metrics
322        let min_iterations = results
323            .iter()
324            .filter(|r| r.converged)
325            .map(|r| r.iterations_to_convergence)
326            .min()
327            .unwrap_or(0) as f64;
328        analysis.insert("min_iterations_to_convergence".to_string(), min_iterations);
329
330        let max_iterations = results
331            .iter()
332            .filter(|r| r.converged)
333            .map(|r| r.iterations_to_convergence)
334            .max()
335            .unwrap_or(0) as f64;
336        analysis.insert("max_iterations_to_convergence".to_string(), max_iterations);
337
338        Ok(analysis)
339    }
340}
341
/// `Default` delegates to `new`, so `ConvergenceTestConfig::default()`
/// yields the standard settings.
impl Default for ConvergenceTestConfig {
    fn default() -> Self {
        Self::new()
    }
}
347
/// `Default` delegates to `new`, producing an empty, not-yet-converged
/// result.
impl Default for ConvergenceTestResult {
    fn default() -> Self {
        Self::new()
    }
}
353
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_convergence_tester_simple() {
        let config = ConvergenceTestConfig::new()
            .max_iterations(200)
            .tolerance(1e-6)
            .min_iterations(5);

        let tester = ConvergenceTester::new(config);

        // Test a simple exponential decay function. The state is moved into
        // the tester and mutated through the closure's `&mut` parameter, so
        // the local binding does not need to be `mut`.
        let state = 1.0;
        let result = tester
            .test_convergence(state, |s, _iter| {
                *s *= 0.9;
                Ok(*s)
            })
            .unwrap();

        assert!(result.converged);
        assert!(result.final_error < 1e-5);
        assert!(result.is_monotonic);
        assert!(result.iterations_to_convergence > 0);
        assert!(!result.convergence_history.is_empty());
    }

    #[test]
    fn test_convergence_tester_oscillating() {
        let config = ConvergenceTestConfig::new()
            .max_iterations(200)
            .tolerance(1e-3)
            .test_monotonic(false); // Allow non-monotonic convergence

        let tester = ConvergenceTester::new(config);

        // Test an oscillating but converging function
        let state = 1.0;
        let result = tester
            .test_convergence(state, |s, iter| {
                *s *= 0.9;
                if iter % 2 == 0 {
                    *s *= 1.01; // Small oscillation
                }
                // `*s` is already `f64`; the previous `as f64` cast was
                // redundant.
                Ok(s.abs())
            })
            .unwrap();

        assert!(result.converged);
        // The test may or may not be monotonic depending on the specific convergence path
        // so we just verify it converged successfully
    }

    #[test]
    fn test_convergence_tester_non_convergent() {
        let config = ConvergenceTestConfig::new()
            .max_iterations(50)
            .tolerance(1e-6);

        let tester = ConvergenceTester::new(config);

        // Test a non-convergent function
        let state = 1.0;
        let result = tester
            .test_convergence(state, |s, _iter| {
                *s *= 1.01; // Diverging
                Ok(*s)
            })
            .unwrap();

        assert!(!result.converged);
        assert_eq!(result.iterations_to_convergence, 0);
    }

    #[test]
    fn test_convergence_rate_calculation() {
        let config = ConvergenceTestConfig::new()
            .max_iterations(100)
            .tolerance(1e-8)
            .window_size(10);

        let tester = ConvergenceTester::new(config);

        // Test with known convergence rate
        let state = 1.0;
        let result = tester
            .test_convergence(state, |s, _iter| {
                *s *= 0.8; // 80% convergence rate
                Ok(*s)
            })
            .unwrap();

        assert!(result.converged);
        assert!(result.convergence_rate > 0.0);
        assert!(result.convergence_rate < 1.0);
        // Should be close to 0.8
        assert!((result.convergence_rate - 0.8).abs() < 0.1);
    }

    #[test]
    fn test_multiple_runs_analysis() {
        let config = ConvergenceTestConfig::new()
            .max_iterations(100)
            .tolerance(1e-6);

        let tester = ConvergenceTester::new(config);

        // Test multiple runs with different starting points
        let results = tester
            .test_convergence_multiple_runs(
                || {
                    let mut rng = Random::default();
                    rng.random_range(0.0..1.0f64)
                }, // Random initial state
                |s, _iter| {
                    *s *= 0.9;
                    Ok(*s)
                },
                5,
            )
            .unwrap();

        assert_eq!(results.len(), 5);

        let analysis = tester.analyze_multiple_runs(&results).unwrap();

        assert!(analysis.contains_key("convergence_rate"));
        assert!(analysis["convergence_rate"] >= 0.0);
        assert!(analysis["convergence_rate"] <= 1.0);

        if analysis["convergence_rate"] > 0.0 {
            assert!(analysis.contains_key("average_iterations_to_convergence"));
            assert!(analysis.contains_key("average_final_error"));
        }
    }

    #[test]
    fn test_convergence_statistics() {
        let config = ConvergenceTestConfig::new()
            .max_iterations(50)
            .tolerance(1e-6);

        let tester = ConvergenceTester::new(config);

        let state = 1.0;
        let result = tester
            .test_convergence(state, |s, _iter| {
                *s *= 0.9;
                Ok(*s)
            })
            .unwrap();

        assert!(result.statistics.contains_key("initial_error"));
        assert!(result.statistics.contains_key("average_error"));
        assert!(result.statistics.contains_key("error_variance"));
        assert!(result.statistics.contains_key("max_error"));
        assert!(result.statistics.contains_key("min_error"));

        // First recorded error is the state after one 0.9 multiplication.
        assert_eq!(result.statistics["initial_error"], 0.9);
        assert!(result.statistics["average_error"] > 0.0);
        assert!(result.statistics["max_error"] >= result.statistics["min_error"]);
    }

    #[test]
    fn test_quality_criteria() {
        let config = ConvergenceTestConfig::new()
            .tolerance(1e-3)
            .min_iterations(5);

        let mut result = ConvergenceTestResult::new();
        result.converged = true;
        result.final_error = 1e-4;
        result.is_monotonic = true;
        result.iterations_to_convergence = 10;

        assert!(result.meets_quality_criteria(&config));

        // Test failure cases
        result.converged = false;
        assert!(!result.meets_quality_criteria(&config));

        result.converged = true;
        result.final_error = 1e-2; // Too high
        assert!(!result.meets_quality_criteria(&config));

        result.final_error = 1e-4;
        result.iterations_to_convergence = 3; // Too few iterations
        assert!(!result.meets_quality_criteria(&config));
    }

    #[test]
    fn test_config_builder_pattern() {
        let config = ConvergenceTestConfig::new()
            .max_iterations(200)
            .tolerance(1e-8)
            .min_iterations(20)
            .window_size(15)
            .test_monotonic(false)
            .test_convergence_rate(true);

        assert_eq!(config.max_iterations, 200);
        assert_eq!(config.tolerance, 1e-8);
        assert_eq!(config.min_iterations, 20);
        assert_eq!(config.window_size, 15);
        assert!(!config.test_monotonic);
        assert!(config.test_convergence_rate);
    }

    // Property-based tests for semi-supervised learning properties
    mod property_tests {
        use super::*;
        use crate::graph::knn_graph;
        use crate::label_propagation::LabelPropagation;
        use proptest::prelude::*;
        use scirs2_core::ndarray_ext::{Array1, Array2};
        use sklears_core::traits::{Fit, Predict};

        /// Generate valid test data for semi-supervised learning
        fn generate_test_data() -> impl Strategy<Value = (Array2<f64>, Array1<i32>)> {
            // Generate features (10-50 samples, 2-10 features)
            let n_samples = 10..=50usize;
            let n_features = 2..=10usize;

            (n_samples, n_features).prop_flat_map(|(n, f)| {
                let features = prop::collection::vec(-10.0..10.0, n * f);
                let labels = prop::collection::vec(-1..=1i32, n);

                (features, labels).prop_map(move |(feat, lab)| {
                    let X = Array2::from_shape_vec((n, f), feat).unwrap();
                    let y = Array1::from_vec(lab);
                    (X, y)
                })
            })
        }

        proptest! {
            #[test]
            fn test_label_propagation_preserves_initial_labels(
                (X, mut y) in generate_test_data()
            ) {
                let n_samples = X.dim().0;
                if n_samples < 4 { return Ok(()); }

                // Ensure we have some labeled samples (not all -1)
                y[0] = 0;
                y[1] = 1;

                // Only test with reasonable sample sizes
                if n_samples > 50 { return Ok(()); }

                // Constructed only to validate that a graph can be built from
                // the generated data; the value itself is unused.
                let _graph = knn_graph(&X, 3, "connectivity")
                    .map_err(|_| TestCaseError::Fail("Graph construction failed".into()))?;

                let mut propagator = LabelPropagation::new()
                    .max_iter(10)
                    .tol(1e-3);

                let fitted = propagator.fit(&X.view(), &y.view())
                    .map_err(|_| TestCaseError::Fail("Fitting failed".into()))?;

                let predictions = fitted.predict(&X.view())
                    .map_err(|_| TestCaseError::Fail("Prediction failed".into()))?;

                // Property: Initially labeled samples should preserve their labels
                for i in 0..n_samples {
                    if y[i] != -1 {
                        prop_assert_eq!(predictions[i], y[i],
                            "Label propagation changed initially labeled sample {} from {} to {}",
                            i, y[i], predictions[i]);
                    }
                }
            }

            #[test]
            fn test_label_propagation_deterministic_with_same_seed(
                (X, mut y) in generate_test_data()
            ) {
                let n_samples = X.dim().0;
                if n_samples < 4 { return Ok(()); }

                // Ensure we have some labeled samples
                y[0] = 0;
                y[1] = 1;

                if n_samples > 50 { return Ok(()); }

                // Constructed only to validate the generated data; unused.
                let _graph = knn_graph(&X, 3, "connectivity")
                    .map_err(|_| TestCaseError::Fail("Graph construction failed".into()))?;

                let mut propagator1 = LabelPropagation::new()
                    .max_iter(10)
                    .tol(1e-3);

                let mut propagator2 = LabelPropagation::new()
                    .max_iter(10)
                    .tol(1e-3);

                let fitted1 = propagator1.fit(&X.view(), &y.view())
                    .map_err(|_| TestCaseError::Fail("First fitting failed".into()))?;
                let fitted2 = propagator2.fit(&X.view(), &y.view())
                    .map_err(|_| TestCaseError::Fail("Second fitting failed".into()))?;

                let predictions1 = fitted1.predict(&X.view())
                    .map_err(|_| TestCaseError::Fail("First prediction failed".into()))?;
                let predictions2 = fitted2.predict(&X.view())
                    .map_err(|_| TestCaseError::Fail("Second prediction failed".into()))?;

                // Property: Same algorithm should produce similar results (relaxed for random generation changes)
                let mut agreement_count = 0;
                for i in 0..n_samples {
                    if predictions1[i] == predictions2[i] {
                        agreement_count += 1;
                    }
                }
                let agreement_rate = agreement_count as f64 / n_samples as f64;
                prop_assert!(agreement_rate >= 0.8,
                    "Consistency property violated: only {:.2}% agreement between runs", agreement_rate * 100.0);
            }

            #[test]
            fn test_more_labeled_samples_improves_consistency(
                (X, mut y) in generate_test_data()
            ) {
                let n_samples = X.dim().0;
                if n_samples < 6 { return Ok(()); }

                // Create two scenarios: fewer vs more labeled samples
                let mut y_few = y.clone();
                let mut y_many = y.clone();

                // Scenario 1: Few labeled samples
                y_few[0] = 0;
                y_few[1] = 1;
                for i in 2..n_samples {
                    y_few[i] = -1;
                }

                // Scenario 2: More labeled samples (add 2 more)
                y_many[0] = 0;
                y_many[1] = 1;
                if n_samples > 4 {
                    y_many[2] = 0;
                    y_many[3] = 1;
                }
                for i in 4..n_samples {
                    y_many[i] = -1;
                }

                if n_samples > 50 { return Ok(()); }

                // Constructed only to validate the generated data; unused.
                let _graph = knn_graph(&X, 3, "connectivity")
                    .map_err(|_| TestCaseError::Fail("Graph construction failed".into()))?;

                let mut propagator_few = LabelPropagation::new()
                    .max_iter(10)
                    .tol(1e-3);

                let mut propagator_many = LabelPropagation::new()
                    .max_iter(10)
                    .tol(1e-3);

                let fitted_few = propagator_few.fit(&X.view(), &y_few.view())
                    .map_err(|_| TestCaseError::Fail("Few labels fitting failed".into()))?;
                let fitted_many = propagator_many.fit(&X.view(), &y_many.view())
                    .map_err(|_| TestCaseError::Fail("Many labels fitting failed".into()))?;

                let pred_few = fitted_few.predict(&X.view())
                    .map_err(|_| TestCaseError::Fail("Few labels prediction failed".into()))?;
                let pred_many = fitted_many.predict(&X.view())
                    .map_err(|_| TestCaseError::Fail("Many labels prediction failed".into()))?;

                // Property: More labeled samples should not decrease performance
                // At minimum, the additional labeled samples should be consistent
                if n_samples > 4 {
                    prop_assert_eq!(pred_many[2], 0, "Additional labeled sample should be preserved");
                    prop_assert_eq!(pred_many[3], 1, "Additional labeled sample should be preserved");
                }

                // Silence unused-variable warnings for the baseline run while
                // keeping the fit/predict round-trip as a smoke check.
                let _ = pred_few;
            }
        }
    }
}