1use scirs2_core::random::Random;
8use sklears_core::error::SklearsError;
9use std::collections::HashMap;
10
/// Settings that govern a single convergence test run.
#[derive(Debug, Clone)]
pub struct ConvergenceTestConfig {
    /// Upper bound on the number of iterations a run may execute.
    pub max_iterations: usize,
    /// Threshold applied to both the error value and its change between iterations.
    pub tolerance: f64,
    /// Iteration index before which convergence is never declared.
    pub min_iterations: usize,
    /// Number of trailing history entries used to estimate the convergence rate.
    pub window_size: usize,
    /// Whether an increase of the error between iterations is recorded.
    pub test_monotonic: bool,
    /// Whether a convergence rate is estimated once the run has finished.
    pub test_convergence_rate: bool,
}

impl ConvergenceTestConfig {
    /// Returns the stock configuration: 1000 iterations at most, tolerance `1e-6`,
    /// 10 minimum iterations, a window of 10, and both optional checks enabled.
    pub fn new() -> Self {
        Self {
            tolerance: 1e-6,
            max_iterations: 1000,
            min_iterations: 10,
            window_size: 10,
            test_convergence_rate: true,
            test_monotonic: true,
        }
    }

    /// Returns the configuration with `max_iterations` replaced by `max_iter`.
    pub fn max_iterations(self, max_iter: usize) -> Self {
        Self {
            max_iterations: max_iter,
            ..self
        }
    }

    /// Returns the configuration with `tolerance` replaced by `tol`.
    pub fn tolerance(self, tol: f64) -> Self {
        Self {
            tolerance: tol,
            ..self
        }
    }

    /// Returns the configuration with `min_iterations` replaced by `min_iter`.
    pub fn min_iterations(self, min_iter: usize) -> Self {
        Self {
            min_iterations: min_iter,
            ..self
        }
    }

    /// Returns the configuration with `window_size` replaced by `window`.
    pub fn window_size(self, window: usize) -> Self {
        Self {
            window_size: window,
            ..self
        }
    }

    /// Returns the configuration with monotonicity tracking switched to `test`.
    pub fn test_monotonic(self, test: bool) -> Self {
        Self {
            test_monotonic: test,
            ..self
        }
    }

    /// Returns the configuration with convergence-rate estimation switched to `test`.
    pub fn test_convergence_rate(self, test: bool) -> Self {
        Self {
            test_convergence_rate: test,
            ..self
        }
    }
}
77
78#[derive(Clone, Debug)]
80pub struct ConvergenceTestResult {
81 pub converged: bool,
83 pub iterations_to_convergence: usize,
85 pub final_error: f64,
87 pub convergence_history: Vec<f64>,
89 pub is_monotonic: bool,
91 pub convergence_rate: f64,
93 pub statistics: HashMap<String, f64>,
95}
96
97impl ConvergenceTestResult {
98 pub fn new() -> Self {
100 Self {
101 converged: false,
102 iterations_to_convergence: 0,
103 final_error: f64::INFINITY,
104 convergence_history: Vec::new(),
105 is_monotonic: true,
106 convergence_rate: 0.0,
107 statistics: HashMap::new(),
108 }
109 }
110
111 pub fn meets_quality_criteria(&self, config: &ConvergenceTestConfig) -> bool {
113 self.converged
114 && self.final_error < config.tolerance
115 && (!config.test_monotonic || self.is_monotonic)
116 && self.iterations_to_convergence >= config.min_iterations
117 }
118}
119
120pub struct ConvergenceTester {
122 config: ConvergenceTestConfig,
123}
124
125impl ConvergenceTester {
126 pub fn new(config: ConvergenceTestConfig) -> Self {
128 Self { config }
129 }
130
131 pub fn test_convergence<F, S>(
133 &self,
134 mut state: S,
135 mut iteration_fn: F,
136 ) -> Result<ConvergenceTestResult, SklearsError>
137 where
138 F: FnMut(&mut S, usize) -> Result<f64, SklearsError>,
139 S: Clone,
140 {
141 let mut result = ConvergenceTestResult::new();
142 let mut prev_error = f64::INFINITY;
143
144 for iteration in 0..self.config.max_iterations {
145 let current_error = iteration_fn(&mut state, iteration)?;
147 result.convergence_history.push(current_error);
148
149 if iteration >= self.config.min_iterations {
151 let error_change = (prev_error - current_error).abs();
152 if error_change < self.config.tolerance && current_error < self.config.tolerance {
153 result.converged = true;
154 result.iterations_to_convergence = iteration + 1;
155 result.final_error = current_error;
156 break;
157 }
158 }
159
160 if self.config.test_monotonic && iteration > 0 && current_error > prev_error {
162 result.is_monotonic = false;
163 }
164
165 prev_error = current_error;
166 }
167
168 if self.config.test_convergence_rate
170 && result.convergence_history.len() > self.config.window_size
171 {
172 result.convergence_rate =
173 self.calculate_convergence_rate(&result.convergence_history)?;
174 }
175
176 self.calculate_statistics(&mut result)?;
178
179 Ok(result)
180 }
181
182 fn calculate_convergence_rate(&self, history: &[f64]) -> Result<f64, SklearsError> {
184 if history.len() < self.config.window_size {
185 return Ok(0.0);
186 }
187
188 let window_start = history.len().saturating_sub(self.config.window_size);
189 let window = &history[window_start..];
190
191 let mut total_rate = 0.0;
193 let mut count = 0;
194
195 for i in 1..window.len() {
196 if window[i - 1] > 0.0 && window[i] > 0.0 {
197 let rate = window[i] / window[i - 1];
198 total_rate += rate;
199 count += 1;
200 }
201 }
202
203 if count > 0 {
204 Ok(total_rate / count as f64)
205 } else {
206 Ok(1.0)
207 }
208 }
209
210 fn calculate_statistics(&self, result: &mut ConvergenceTestResult) -> Result<(), SklearsError> {
212 let history = &result.convergence_history;
213
214 if history.is_empty() {
215 return Ok(());
216 }
217
218 result
220 .statistics
221 .insert("initial_error".to_string(), history[0]);
222
223 let avg_error = history.iter().sum::<f64>() / history.len() as f64;
225 result
226 .statistics
227 .insert("average_error".to_string(), avg_error);
228
229 let variance = history
231 .iter()
232 .map(|&x| (x - avg_error).powi(2))
233 .sum::<f64>()
234 / history.len() as f64;
235 result
236 .statistics
237 .insert("error_variance".to_string(), variance);
238
239 let max_error = history.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
241 result.statistics.insert("max_error".to_string(), max_error);
242
243 let min_error = history.iter().cloned().fold(f64::INFINITY, f64::min);
245 result.statistics.insert("min_error".to_string(), min_error);
246
247 if history.len() > 1 && history[0] > 0.0 {
249 let reduction_ratio = (history[0] - result.final_error) / history[0];
250 result
251 .statistics
252 .insert("error_reduction_ratio".to_string(), reduction_ratio);
253 }
254
255 Ok(())
256 }
257
258 pub fn test_convergence_multiple_runs<F, G, S>(
260 &self,
261 init_fn: G,
262 iteration_fn: F,
263 num_runs: usize,
264 ) -> Result<Vec<ConvergenceTestResult>, SklearsError>
265 where
266 F: Fn(&mut S, usize) -> Result<f64, SklearsError> + Clone,
267 G: Fn() -> S,
268 S: Clone,
269 {
270 let mut results = Vec::new();
271
272 for _run in 0..num_runs {
273 let state = init_fn();
274 let result = self.test_convergence(state, iteration_fn.clone())?;
275 results.push(result);
276 }
277
278 Ok(results)
279 }
280
281 pub fn analyze_multiple_runs(
283 &self,
284 results: &[ConvergenceTestResult],
285 ) -> Result<HashMap<String, f64>, SklearsError> {
286 let mut analysis = HashMap::new();
287
288 if results.is_empty() {
289 return Ok(analysis);
290 }
291
292 let convergence_rate =
294 results.iter().filter(|r| r.converged).count() as f64 / results.len() as f64;
295 analysis.insert("convergence_rate".to_string(), convergence_rate);
296
297 let converged_results: Vec<_> = results.iter().filter(|r| r.converged).collect();
299 if !converged_results.is_empty() {
300 let avg_iterations = converged_results
301 .iter()
302 .map(|r| r.iterations_to_convergence as f64)
303 .sum::<f64>()
304 / converged_results.len() as f64;
305 analysis.insert(
306 "average_iterations_to_convergence".to_string(),
307 avg_iterations,
308 );
309
310 let avg_final_error = converged_results.iter().map(|r| r.final_error).sum::<f64>()
312 / converged_results.len() as f64;
313 analysis.insert("average_final_error".to_string(), avg_final_error);
314
315 let monotonic_rate = converged_results.iter().filter(|r| r.is_monotonic).count() as f64
317 / converged_results.len() as f64;
318 analysis.insert("monotonic_convergence_rate".to_string(), monotonic_rate);
319 }
320
321 let min_iterations = results
323 .iter()
324 .filter(|r| r.converged)
325 .map(|r| r.iterations_to_convergence)
326 .min()
327 .unwrap_or(0) as f64;
328 analysis.insert("min_iterations_to_convergence".to_string(), min_iterations);
329
330 let max_iterations = results
331 .iter()
332 .filter(|r| r.converged)
333 .map(|r| r.iterations_to_convergence)
334 .max()
335 .unwrap_or(0) as f64;
336 analysis.insert("max_iterations_to_convergence".to_string(), max_iterations);
337
338 Ok(analysis)
339 }
340}
341
342impl Default for ConvergenceTestConfig {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
348impl Default for ConvergenceTestResult {
349 fn default() -> Self {
350 Self::new()
351 }
352}
353
354#[allow(non_snake_case)]
355#[cfg(test)]
356mod tests {
357 use super::*;
358 use approx::assert_abs_diff_eq;
359 use scirs2_core::array;
360
361 #[test]
362 fn test_convergence_tester_simple() {
363 let config = ConvergenceTestConfig::new()
364 .max_iterations(200)
365 .tolerance(1e-6)
366 .min_iterations(5);
367
368 let tester = ConvergenceTester::new(config);
369
370 let mut state = 1.0;
372 let result = tester
373 .test_convergence(state, |s, _iter| {
374 *s *= 0.9;
375 Ok(*s)
376 })
377 .unwrap();
378
379 assert!(result.converged);
380 assert!(result.final_error < 1e-5);
381 assert!(result.is_monotonic);
382 assert!(result.iterations_to_convergence > 0);
383 assert!(!result.convergence_history.is_empty());
384 }
385
386 #[test]
387 fn test_convergence_tester_oscillating() {
388 let config = ConvergenceTestConfig::new()
389 .max_iterations(200)
390 .tolerance(1e-3)
391 .test_monotonic(false); let tester = ConvergenceTester::new(config);
394
395 let mut state = 1.0;
397 let result = tester
398 .test_convergence(state, |s, iter| {
399 *s *= 0.9;
400 if iter % 2 == 0 {
401 *s *= 1.01; }
403 Ok((*s as f64).abs())
404 })
405 .unwrap();
406
407 assert!(result.converged);
408 }
411
412 #[test]
413 fn test_convergence_tester_non_convergent() {
414 let config = ConvergenceTestConfig::new()
415 .max_iterations(50)
416 .tolerance(1e-6);
417
418 let tester = ConvergenceTester::new(config);
419
420 let mut state = 1.0;
422 let result = tester
423 .test_convergence(state, |s, _iter| {
424 *s *= 1.01; Ok(*s)
426 })
427 .unwrap();
428
429 assert!(!result.converged);
430 assert_eq!(result.iterations_to_convergence, 0);
431 }
432
433 #[test]
434 fn test_convergence_rate_calculation() {
435 let config = ConvergenceTestConfig::new()
436 .max_iterations(100)
437 .tolerance(1e-8)
438 .window_size(10);
439
440 let tester = ConvergenceTester::new(config);
441
442 let mut state = 1.0;
444 let result = tester
445 .test_convergence(state, |s, _iter| {
446 *s *= 0.8; Ok(*s)
448 })
449 .unwrap();
450
451 assert!(result.converged);
452 assert!(result.convergence_rate > 0.0);
453 assert!(result.convergence_rate < 1.0);
454 assert!((result.convergence_rate - 0.8).abs() < 0.1);
456 }
457
458 #[test]
459 fn test_multiple_runs_analysis() {
460 let config = ConvergenceTestConfig::new()
461 .max_iterations(100)
462 .tolerance(1e-6);
463
464 let tester = ConvergenceTester::new(config);
465
466 let results = tester
468 .test_convergence_multiple_runs(
469 || {
470 let mut rng = Random::default();
471 rng.random_range(0.0..1.0f64)
472 }, |s, _iter| {
474 *s *= 0.9;
475 Ok(*s)
476 },
477 5,
478 )
479 .unwrap();
480
481 assert_eq!(results.len(), 5);
482
483 let analysis = tester.analyze_multiple_runs(&results).unwrap();
484
485 assert!(analysis.contains_key("convergence_rate"));
486 assert!(analysis["convergence_rate"] >= 0.0);
487 assert!(analysis["convergence_rate"] <= 1.0);
488
489 if analysis["convergence_rate"] > 0.0 {
490 assert!(analysis.contains_key("average_iterations_to_convergence"));
491 assert!(analysis.contains_key("average_final_error"));
492 }
493 }
494
495 #[test]
496 fn test_convergence_statistics() {
497 let config = ConvergenceTestConfig::new()
498 .max_iterations(50)
499 .tolerance(1e-6);
500
501 let tester = ConvergenceTester::new(config);
502
503 let mut state = 1.0;
504 let result = tester
505 .test_convergence(state, |s, _iter| {
506 *s *= 0.9;
507 Ok(*s)
508 })
509 .unwrap();
510
511 assert!(result.statistics.contains_key("initial_error"));
512 assert!(result.statistics.contains_key("average_error"));
513 assert!(result.statistics.contains_key("error_variance"));
514 assert!(result.statistics.contains_key("max_error"));
515 assert!(result.statistics.contains_key("min_error"));
516
517 assert_eq!(result.statistics["initial_error"], 0.9);
518 assert!(result.statistics["average_error"] > 0.0);
519 assert!(result.statistics["max_error"] >= result.statistics["min_error"]);
520 }
521
522 #[test]
523 fn test_quality_criteria() {
524 let config = ConvergenceTestConfig::new()
525 .tolerance(1e-3)
526 .min_iterations(5);
527
528 let mut result = ConvergenceTestResult::new();
529 result.converged = true;
530 result.final_error = 1e-4;
531 result.is_monotonic = true;
532 result.iterations_to_convergence = 10;
533
534 assert!(result.meets_quality_criteria(&config));
535
536 result.converged = false;
538 assert!(!result.meets_quality_criteria(&config));
539
540 result.converged = true;
541 result.final_error = 1e-2; assert!(!result.meets_quality_criteria(&config));
543
544 result.final_error = 1e-4;
545 result.iterations_to_convergence = 3; assert!(!result.meets_quality_criteria(&config));
547 }
548
549 #[test]
550 fn test_config_builder_pattern() {
551 let config = ConvergenceTestConfig::new()
552 .max_iterations(200)
553 .tolerance(1e-8)
554 .min_iterations(20)
555 .window_size(15)
556 .test_monotonic(false)
557 .test_convergence_rate(true);
558
559 assert_eq!(config.max_iterations, 200);
560 assert_eq!(config.tolerance, 1e-8);
561 assert_eq!(config.min_iterations, 20);
562 assert_eq!(config.window_size, 15);
563 assert!(!config.test_monotonic);
564 assert!(config.test_convergence_rate);
565 }
566
567 mod property_tests {
569 use super::*;
570 use crate::graph::knn_graph;
571 use crate::label_propagation::LabelPropagation;
572 use proptest::prelude::*;
573 use scirs2_core::ndarray_ext::{Array1, Array2};
574 use sklears_core::traits::{Fit, Predict};
575
576 fn generate_test_data() -> impl Strategy<Value = (Array2<f64>, Array1<i32>)> {
578 let n_samples = 10..=50usize;
580 let n_features = 2..=10usize;
581
582 (n_samples, n_features).prop_flat_map(|(n, f)| {
583 let features = prop::collection::vec(-10.0..10.0, n * f);
584 let labels = prop::collection::vec(-1..=1i32, n);
585
586 (features, labels).prop_map(move |(feat, lab)| {
587 let X = Array2::from_shape_vec((n, f), feat).unwrap();
588 let y = Array1::from_vec(lab);
589 (X, y)
590 })
591 })
592 }
593
594 proptest! {
595 #[test]
596 fn test_label_propagation_preserves_initial_labels(
597 (X, mut y) in generate_test_data()
598 ) {
599 let n_samples = X.dim().0;
600 if n_samples < 4 { return Ok(()); }
601
602 y[0] = 0;
604 y[1] = 1;
605
606 if n_samples > 50 { return Ok(()); }
608
609 let graph = knn_graph(&X, 3, "connectivity")
610 .map_err(|_| TestCaseError::Fail("Graph construction failed".into()))?;
611
612 let mut propagator = LabelPropagation::new()
613 .max_iter(10)
614 .tol(1e-3);
615
616 let fitted = propagator.fit(&X.view(), &y.view())
617 .map_err(|_| TestCaseError::Fail("Fitting failed".into()))?;
618
619 let predictions = fitted.predict(&X.view())
620 .map_err(|_| TestCaseError::Fail("Prediction failed".into()))?;
621
622 for i in 0..n_samples {
624 if y[i] != -1 {
625 prop_assert_eq!(predictions[i], y[i],
626 "Label propagation changed initially labeled sample {} from {} to {}",
627 i, y[i], predictions[i]);
628 }
629 }
630 }
631
632 #[test]
633 fn test_label_propagation_deterministic_with_same_seed(
634 (X, mut y) in generate_test_data()
635 ) {
636 let n_samples = X.dim().0;
637 if n_samples < 4 { return Ok(()); }
638
639 y[0] = 0;
641 y[1] = 1;
642
643 if n_samples > 50 { return Ok(()); }
644
645 let graph = knn_graph(&X, 3, "connectivity")
646 .map_err(|_| TestCaseError::Fail("Graph construction failed".into()))?;
647
648 let mut propagator1 = LabelPropagation::new()
649 .max_iter(10)
650 .tol(1e-3);
651
652 let mut propagator2 = LabelPropagation::new()
653 .max_iter(10)
654 .tol(1e-3);
655
656 let fitted1 = propagator1.fit(&X.view(), &y.view())
657 .map_err(|_| TestCaseError::Fail("First fitting failed".into()))?;
658 let fitted2 = propagator2.fit(&X.view(), &y.view())
659 .map_err(|_| TestCaseError::Fail("Second fitting failed".into()))?;
660
661 let predictions1 = fitted1.predict(&X.view())
662 .map_err(|_| TestCaseError::Fail("First prediction failed".into()))?;
663 let predictions2 = fitted2.predict(&X.view())
664 .map_err(|_| TestCaseError::Fail("Second prediction failed".into()))?;
665
666 let mut agreement_count = 0;
668 for i in 0..n_samples {
669 if predictions1[i] == predictions2[i] {
670 agreement_count += 1;
671 }
672 }
673 let agreement_rate = agreement_count as f64 / n_samples as f64;
674 prop_assert!(agreement_rate >= 0.8,
675 "Consistency property violated: only {:.2}% agreement between runs", agreement_rate * 100.0);
676 }
677
678 #[test]
679 fn test_more_labeled_samples_improves_consistency(
680 (X, mut y) in generate_test_data()
681 ) {
682 let n_samples = X.dim().0;
683 if n_samples < 6 { return Ok(()); }
684
685 let mut y_few = y.clone();
687 let mut y_many = y.clone();
688
689 y_few[0] = 0;
691 y_few[1] = 1;
692 for i in 2..n_samples {
693 y_few[i] = -1;
694 }
695
696 y_many[0] = 0;
698 y_many[1] = 1;
699 if n_samples > 4 {
700 y_many[2] = 0;
701 y_many[3] = 1;
702 }
703 for i in 4..n_samples {
704 y_many[i] = -1;
705 }
706
707 if n_samples > 50 { return Ok(()); }
708
709 let graph = knn_graph(&X, 3, "connectivity")
710 .map_err(|_| TestCaseError::Fail("Graph construction failed".into()))?;
711
712 let mut propagator_few = LabelPropagation::new()
713 .max_iter(10)
714 .tol(1e-3);
715
716 let mut propagator_many = LabelPropagation::new()
717 .max_iter(10)
718 .tol(1e-3);
719
720 let fitted_few = propagator_few.fit(&X.view(), &y_few.view())
721 .map_err(|_| TestCaseError::Fail("Few labels fitting failed".into()))?;
722 let fitted_many = propagator_many.fit(&X.view(), &y_many.view())
723 .map_err(|_| TestCaseError::Fail("Many labels fitting failed".into()))?;
724
725 let pred_few = fitted_few.predict(&X.view())
726 .map_err(|_| TestCaseError::Fail("Few labels prediction failed".into()))?;
727 let pred_many = fitted_many.predict(&X.view())
728 .map_err(|_| TestCaseError::Fail("Many labels prediction failed".into()))?;
729
730 if n_samples > 4 {
733 prop_assert_eq!(pred_many[2], 0, "Additional labeled sample should be preserved");
734 prop_assert_eq!(pred_many[3], 1, "Additional labeled sample should be preserved");
735 }
736 }
737 }
738 }
739}