sklears_semi_supervised/
robust_graph_methods.rs

1//! Robust graph learning methods for semi-supervised learning
2//!
3//! This module provides robust graph construction algorithms that are resistant
4//! to outliers, noise, and adversarial examples in semi-supervised learning.
5
6use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
7use scirs2_core::random::rand_prelude::*;
8use scirs2_core::random::Random;
9use sklears_core::error::SklearsError;
10use std::collections::HashMap;
11
/// Robust graph construction using M-estimators and outlier detection
///
/// Builds an affinity graph over samples while down-weighting or excluding
/// points flagged as outliers, so that noisy samples do not dominate the
/// graph structure used by downstream semi-supervised methods.
#[derive(Clone)]
pub struct RobustGraphConstruction {
    /// Number of neighbors for k-NN graph construction
    pub k_neighbors: usize,
    /// Robust distance metric: "huber", "tukey", "cauchy", "welsch"
    /// (any other value falls back to plain Euclidean distance)
    pub robust_metric: String,
    /// Robustness (tuning) parameter `c` for the M-estimators
    pub robustness_param: f64,
    /// Outlier detection threshold, in robust-scale units
    pub outlier_threshold: f64,
    /// Graph construction method: "knn", "epsilon", "adaptive"
    pub construction_method: String,
    /// Epsilon parameter for epsilon-neighborhood graphs
    pub epsilon: f64,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}
30
31impl RobustGraphConstruction {
32    /// Create a new robust graph construction instance
33    pub fn new() -> Self {
34        Self {
35            k_neighbors: 5,
36            robust_metric: "huber".to_string(),
37            robustness_param: 1.345,
38            outlier_threshold: 3.0,
39            construction_method: "knn".to_string(),
40            epsilon: 1.0,
41            random_state: None,
42        }
43    }
44
45    /// Set the number of neighbors
46    pub fn k_neighbors(mut self, k: usize) -> Self {
47        self.k_neighbors = k;
48        self
49    }
50
51    /// Set the robust distance metric
52    pub fn robust_metric(mut self, metric: String) -> Self {
53        self.robust_metric = metric;
54        self
55    }
56
57    /// Set the robustness parameter
58    pub fn robustness_param(mut self, param: f64) -> Self {
59        self.robustness_param = param;
60        self
61    }
62
63    /// Set the outlier detection threshold
64    pub fn outlier_threshold(mut self, threshold: f64) -> Self {
65        self.outlier_threshold = threshold;
66        self
67    }
68
69    /// Set the graph construction method
70    pub fn construction_method(mut self, method: String) -> Self {
71        self.construction_method = method;
72        self
73    }
74
75    /// Set the epsilon parameter
76    pub fn epsilon(mut self, eps: f64) -> Self {
77        self.epsilon = eps;
78        self
79    }
80
81    /// Set the random state
82    pub fn random_state(mut self, seed: u64) -> Self {
83        self.random_state = Some(seed);
84        self
85    }
86
87    /// Construct a robust graph from data
88    pub fn fit(&self, X: &ArrayView2<f64>) -> Result<Array2<f64>, SklearsError> {
89        let n_samples = X.nrows();
90
91        // Detect outliers
92        let outlier_mask = self.detect_outliers(X)?;
93
94        // Construct graph with robust distances
95        let mut graph = Array2::<f64>::zeros((n_samples, n_samples));
96
97        match self.construction_method.as_str() {
98            "knn" => {
99                graph = self.construct_robust_knn_graph(X, &outlier_mask)?;
100            }
101            "epsilon" => {
102                graph = self.construct_robust_epsilon_graph(X, &outlier_mask)?;
103            }
104            "adaptive" => {
105                graph = self.construct_adaptive_robust_graph(X, &outlier_mask)?;
106            }
107            _ => {
108                return Err(SklearsError::InvalidInput(format!(
109                    "Unknown construction method: {}",
110                    self.construction_method
111                )));
112            }
113        }
114
115        Ok(graph)
116    }
117
118    /// Detect outliers using robust statistical methods
119    fn detect_outliers(&self, X: &ArrayView2<f64>) -> Result<Array1<bool>, SklearsError> {
120        let n_samples = X.nrows();
121        let mut outlier_mask = Array1::from_elem(n_samples, false);
122
123        // Compute robust center and scale estimates
124        let robust_center = self.compute_robust_center(X)?;
125        let robust_scale = self.compute_robust_scale(X, &robust_center)?;
126
127        // Identify outliers based on Mahalanobis distance
128        for i in 0..n_samples {
129            let distance = self.mahalanobis_distance(&X.row(i), &robust_center, robust_scale);
130            if distance > self.outlier_threshold {
131                outlier_mask[i] = true;
132            }
133        }
134
135        Ok(outlier_mask)
136    }
137
138    /// Compute robust center using median
139    fn compute_robust_center(&self, X: &ArrayView2<f64>) -> Result<Array1<f64>, SklearsError> {
140        let n_features = X.ncols();
141        let mut center = Array1::zeros(n_features);
142
143        for j in 0..n_features {
144            let mut feature_values: Vec<f64> = X.column(j).to_vec();
145            feature_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
146
147            let median = if feature_values.len() % 2 == 0 {
148                let mid = feature_values.len() / 2;
149                (feature_values[mid - 1] + feature_values[mid]) / 2.0
150            } else {
151                feature_values[feature_values.len() / 2]
152            };
153
154            center[j] = median;
155        }
156
157        Ok(center)
158    }
159
160    /// Compute robust scale using MAD (Median Absolute Deviation)
161    fn compute_robust_scale(
162        &self,
163        X: &ArrayView2<f64>,
164        center: &Array1<f64>,
165    ) -> Result<f64, SklearsError> {
166        let mut deviations = Vec::new();
167
168        for i in 0..X.nrows() {
169            let distance = self.euclidean_distance(&X.row(i), &center.view());
170            deviations.push(distance);
171        }
172
173        deviations.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
174
175        let mad = if deviations.len() % 2 == 0 {
176            let mid = deviations.len() / 2;
177            (deviations[mid - 1] + deviations[mid]) / 2.0
178        } else {
179            deviations[deviations.len() / 2]
180        };
181
182        // MAD to standard deviation conversion factor
183        Ok(mad * 1.4826)
184    }
185
186    /// Compute Mahalanobis distance (simplified version using robust scale)
187    fn mahalanobis_distance(&self, x: &ArrayView1<f64>, center: &Array1<f64>, scale: f64) -> f64 {
188        let distance = self.euclidean_distance(x, &center.view());
189        if scale > 0.0 {
190            distance / scale
191        } else {
192            distance
193        }
194    }
195
196    /// Construct robust k-NN graph
197    fn construct_robust_knn_graph(
198        &self,
199        X: &ArrayView2<f64>,
200        outlier_mask: &Array1<bool>,
201    ) -> Result<Array2<f64>, SklearsError> {
202        let n_samples = X.nrows();
203        let mut graph = Array2::<f64>::zeros((n_samples, n_samples));
204
205        for i in 0..n_samples {
206            if outlier_mask[i] {
207                continue; // Skip outliers
208            }
209
210            let mut distances: Vec<(f64, usize)> = Vec::new();
211
212            for j in 0..n_samples {
213                if i != j && !outlier_mask[j] {
214                    let dist = self.robust_distance(&X.row(i), &X.row(j));
215                    distances.push((dist, j));
216                }
217            }
218
219            // Sort by distance and take k nearest neighbors
220            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
221
222            for (dist, j) in distances.iter().take(self.k_neighbors.min(distances.len())) {
223                let weight = self.robust_weight(*dist);
224                graph[[i, *j]] = weight;
225            }
226        }
227
228        // Make graph symmetric
229        for i in 0..n_samples {
230            for j in i + 1..n_samples {
231                let avg_weight = (graph[[i, j]] + graph[[j, i]]) / 2.0;
232                graph[[i, j]] = avg_weight;
233                graph[[j, i]] = avg_weight;
234            }
235        }
236
237        Ok(graph)
238    }
239
240    /// Construct robust epsilon-neighborhood graph
241    fn construct_robust_epsilon_graph(
242        &self,
243        X: &ArrayView2<f64>,
244        outlier_mask: &Array1<bool>,
245    ) -> Result<Array2<f64>, SklearsError> {
246        let n_samples = X.nrows();
247        let mut graph = Array2::<f64>::zeros((n_samples, n_samples));
248
249        for i in 0..n_samples {
250            if outlier_mask[i] {
251                continue; // Skip outliers
252            }
253
254            for j in i + 1..n_samples {
255                if !outlier_mask[j] {
256                    let dist = self.robust_distance(&X.row(i), &X.row(j));
257
258                    if dist <= self.epsilon {
259                        let weight = self.robust_weight(dist);
260                        graph[[i, j]] = weight;
261                        graph[[j, i]] = weight;
262                    }
263                }
264            }
265        }
266
267        Ok(graph)
268    }
269
270    /// Construct adaptive robust graph
271    fn construct_adaptive_robust_graph(
272        &self,
273        X: &ArrayView2<f64>,
274        outlier_mask: &Array1<bool>,
275    ) -> Result<Array2<f64>, SklearsError> {
276        let n_samples = X.nrows();
277        let mut graph = Array2::<f64>::zeros((n_samples, n_samples));
278
279        // Compute adaptive epsilon for each point
280        let adaptive_epsilons = self.compute_adaptive_epsilons(X, outlier_mask)?;
281
282        for i in 0..n_samples {
283            if outlier_mask[i] {
284                continue; // Skip outliers
285            }
286
287            for j in i + 1..n_samples {
288                if !outlier_mask[j] {
289                    let dist = self.robust_distance(&X.row(i), &X.row(j));
290                    let epsilon_ij = (adaptive_epsilons[i] + adaptive_epsilons[j]) / 2.0;
291
292                    if dist <= epsilon_ij {
293                        let weight = self.robust_weight(dist);
294                        graph[[i, j]] = weight;
295                        graph[[j, i]] = weight;
296                    }
297                }
298            }
299        }
300
301        Ok(graph)
302    }
303
304    /// Compute adaptive epsilon values for each point
305    fn compute_adaptive_epsilons(
306        &self,
307        X: &ArrayView2<f64>,
308        outlier_mask: &Array1<bool>,
309    ) -> Result<Array1<f64>, SklearsError> {
310        let n_samples = X.nrows();
311        let mut epsilons = Array1::zeros(n_samples);
312
313        for i in 0..n_samples {
314            if outlier_mask[i] {
315                epsilons[i] = f64::INFINITY; // Outliers get infinite epsilon
316                continue;
317            }
318
319            let mut distances = Vec::new();
320            for j in 0..n_samples {
321                if i != j && !outlier_mask[j] {
322                    let dist = self.robust_distance(&X.row(i), &X.row(j));
323                    distances.push(dist);
324                }
325            }
326
327            if distances.is_empty() {
328                epsilons[i] = 1.0; // Default epsilon
329                continue;
330            }
331
332            distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
333
334            // Use k-th nearest neighbor distance as adaptive epsilon
335            let k_index = self.k_neighbors.min(distances.len()) - 1;
336            epsilons[i] = distances[k_index];
337        }
338
339        Ok(epsilons)
340    }
341
342    /// Compute robust distance using M-estimators
343    fn robust_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
344        let euclidean_dist = self.euclidean_distance(x1, x2);
345
346        match self.robust_metric.as_str() {
347            "huber" => self.huber_distance(euclidean_dist),
348            "tukey" => self.tukey_distance(euclidean_dist),
349            "cauchy" => self.cauchy_distance(euclidean_dist),
350            "welsch" => self.welsch_distance(euclidean_dist),
351            _ => euclidean_dist, // Default to Euclidean
352        }
353    }
354
355    /// Huber robust distance
356    fn huber_distance(&self, dist: f64) -> f64 {
357        let c = self.robustness_param;
358        if dist <= c {
359            0.5 * dist.powi(2)
360        } else {
361            c * dist - 0.5 * c.powi(2)
362        }
363    }
364
365    /// Tukey biweight robust distance
366    fn tukey_distance(&self, dist: f64) -> f64 {
367        let c = self.robustness_param;
368        if dist <= c {
369            let ratio = dist / c;
370            (c.powi(2) / 6.0) * (1.0 - (1.0 - ratio.powi(2)).powi(3))
371        } else {
372            c.powi(2) / 6.0
373        }
374    }
375
376    /// Cauchy robust distance
377    fn cauchy_distance(&self, dist: f64) -> f64 {
378        let c = self.robustness_param;
379        (c.powi(2) / 2.0) * ((1.0 + (dist / c).powi(2)).ln())
380    }
381
382    /// Welsch robust distance
383    fn welsch_distance(&self, dist: f64) -> f64 {
384        let c = self.robustness_param;
385        (c.powi(2) / 2.0) * (1.0 - (-(dist / c).powi(2)).exp())
386    }
387
388    /// Compute robust weight from distance
389    fn robust_weight(&self, dist: f64) -> f64 {
390        match self.robust_metric.as_str() {
391            "huber" => self.huber_weight(dist),
392            "tukey" => self.tukey_weight(dist),
393            "cauchy" => self.cauchy_weight(dist),
394            "welsch" => self.welsch_weight(dist),
395            _ => (-dist.powi(2) / 2.0).exp(), // Default RBF kernel
396        }
397    }
398
399    /// Huber weight function
400    fn huber_weight(&self, dist: f64) -> f64 {
401        let c = self.robustness_param;
402        if dist <= c {
403            1.0
404        } else {
405            c / dist
406        }
407    }
408
409    /// Tukey biweight weight function
410    fn tukey_weight(&self, dist: f64) -> f64 {
411        let c = self.robustness_param;
412        if dist <= c {
413            let ratio = dist / c;
414            (1.0 - ratio.powi(2)).powi(2)
415        } else {
416            0.0
417        }
418    }
419
420    /// Cauchy weight function
421    fn cauchy_weight(&self, dist: f64) -> f64 {
422        let c = self.robustness_param;
423        1.0 / (1.0 + (dist / c).powi(2))
424    }
425
426    /// Welsch weight function
427    fn welsch_weight(&self, dist: f64) -> f64 {
428        let c = self.robustness_param;
429        (-(dist / c).powi(2)).exp()
430    }
431
432    /// Compute Euclidean distance between two vectors
433    fn euclidean_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
434        x1.iter()
435            .zip(x2.iter())
436            .map(|(a, b)| (a - b).powi(2))
437            .sum::<f64>()
438            .sqrt()
439    }
440}
441
442impl Default for RobustGraphConstruction {
443    fn default() -> Self {
444        Self::new()
445    }
446}
447
/// Noise-robust label propagation for semi-supervised learning
///
/// Estimates the noise scale of the data, builds a robust graph via
/// [`RobustGraphConstruction`], and then propagates the known labels
/// (with -1 marking unlabeled samples) over that graph.
#[derive(Clone)]
pub struct NoiseRobustPropagation {
    /// Number of neighbors for graph construction
    pub k_neighbors: usize,
    /// Noise level estimation method: "mad", "iqr", "adaptive"
    /// (any other value falls back to a fixed noise level of 1.0)
    pub noise_estimation: String,
    /// Robustness parameter forwarded to the graph construction
    pub robustness_param: f64,
    /// Maximum iterations for label propagation
    pub max_iter: usize,
    /// Convergence tolerance on the L1 change between iterates
    pub tolerance: f64,
    /// Alpha damping parameter for label spreading
    pub alpha: f64,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}
466
467impl NoiseRobustPropagation {
468    /// Create a new noise-robust propagation instance
469    pub fn new() -> Self {
470        Self {
471            k_neighbors: 5,
472            noise_estimation: "mad".to_string(),
473            robustness_param: 1.345,
474            max_iter: 1000,
475            tolerance: 1e-6,
476            alpha: 0.2,
477            random_state: None,
478        }
479    }
480
481    /// Set the number of neighbors
482    pub fn k_neighbors(mut self, k: usize) -> Self {
483        self.k_neighbors = k;
484        self
485    }
486
487    /// Set the noise estimation method
488    pub fn noise_estimation(mut self, method: String) -> Self {
489        self.noise_estimation = method;
490        self
491    }
492
493    /// Set the robustness parameter
494    pub fn robustness_param(mut self, param: f64) -> Self {
495        self.robustness_param = param;
496        self
497    }
498
499    /// Set the maximum iterations
500    pub fn max_iter(mut self, max_iter: usize) -> Self {
501        self.max_iter = max_iter;
502        self
503    }
504
505    /// Set the convergence tolerance
506    pub fn tolerance(mut self, tol: f64) -> Self {
507        self.tolerance = tol;
508        self
509    }
510
511    /// Set the alpha parameter
512    pub fn alpha(mut self, alpha: f64) -> Self {
513        self.alpha = alpha;
514        self
515    }
516
517    /// Set the random state
518    pub fn random_state(mut self, seed: u64) -> Self {
519        self.random_state = Some(seed);
520        self
521    }
522
523    /// Perform noise-robust label propagation
524    pub fn fit(
525        &self,
526        X: &ArrayView2<f64>,
527        y: &ArrayView1<i32>,
528    ) -> Result<Array1<i32>, SklearsError> {
529        let n_samples = X.nrows();
530
531        if y.len() != n_samples {
532            return Err(SklearsError::ShapeMismatch {
533                expected: format!("X and y should have same number of samples: {}", X.nrows()),
534                actual: format!("X has {} samples, y has {} samples", X.nrows(), y.len()),
535            });
536        }
537
538        // Estimate noise level
539        let noise_level = self.estimate_noise_level(X)?;
540
541        // Construct robust graph
542        let robust_graph_builder = RobustGraphConstruction::new()
543            .k_neighbors(self.k_neighbors)
544            .robustness_param(self.robustness_param)
545            .outlier_threshold(noise_level * 3.0);
546
547        let graph = robust_graph_builder.fit(X)?;
548
549        // Perform robust label propagation
550        let labels = self.robust_propagate_labels(&graph, y)?;
551
552        Ok(labels)
553    }
554
555    /// Estimate noise level in the data
556    fn estimate_noise_level(&self, X: &ArrayView2<f64>) -> Result<f64, SklearsError> {
557        match self.noise_estimation.as_str() {
558            "mad" => self.estimate_noise_mad(X),
559            "iqr" => self.estimate_noise_iqr(X),
560            "adaptive" => self.estimate_noise_adaptive(X),
561            _ => Ok(1.0), // Default noise level
562        }
563    }
564
565    /// Estimate noise using Median Absolute Deviation
566    fn estimate_noise_mad(&self, X: &ArrayView2<f64>) -> Result<f64, SklearsError> {
567        let n_samples = X.nrows();
568        let mut distances = Vec::new();
569
570        // Compute pairwise distances
571        for i in 0..n_samples {
572            for j in i + 1..n_samples {
573                let dist = self.euclidean_distance(&X.row(i), &X.row(j));
574                distances.push(dist);
575            }
576        }
577
578        if distances.is_empty() {
579            return Ok(1.0);
580        }
581
582        distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
583
584        let median = if distances.len() % 2 == 0 {
585            let mid = distances.len() / 2;
586            (distances[mid - 1] + distances[mid]) / 2.0
587        } else {
588            distances[distances.len() / 2]
589        };
590
591        // Compute MAD
592        let mut deviations: Vec<f64> = distances.iter().map(|&d| (d - median).abs()).collect();
593        deviations.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
594
595        let mad = if deviations.len() % 2 == 0 {
596            let mid = deviations.len() / 2;
597            (deviations[mid - 1] + deviations[mid]) / 2.0
598        } else {
599            deviations[deviations.len() / 2]
600        };
601
602        Ok(mad * 1.4826) // MAD to std conversion
603    }
604
605    /// Estimate noise using Interquartile Range
606    fn estimate_noise_iqr(&self, X: &ArrayView2<f64>) -> Result<f64, SklearsError> {
607        let n_samples = X.nrows();
608        let mut distances = Vec::new();
609
610        // Compute pairwise distances
611        for i in 0..n_samples {
612            for j in i + 1..n_samples {
613                let dist = self.euclidean_distance(&X.row(i), &X.row(j));
614                distances.push(dist);
615            }
616        }
617
618        if distances.is_empty() {
619            return Ok(1.0);
620        }
621
622        distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
623
624        let q1_idx = distances.len() / 4;
625        let q3_idx = 3 * distances.len() / 4;
626
627        let iqr = distances[q3_idx] - distances[q1_idx];
628
629        Ok(iqr / 1.349) // IQR to std conversion
630    }
631
632    /// Estimate noise adaptively
633    fn estimate_noise_adaptive(&self, X: &ArrayView2<f64>) -> Result<f64, SklearsError> {
634        // Combine MAD and IQR estimates
635        let mad_estimate = self.estimate_noise_mad(X)?;
636        let iqr_estimate = self.estimate_noise_iqr(X)?;
637
638        // Use the minimum to be more conservative
639        Ok(mad_estimate.min(iqr_estimate))
640    }
641
642    /// Perform robust label propagation
643    #[allow(non_snake_case)]
644    fn robust_propagate_labels(
645        &self,
646        graph: &Array2<f64>,
647        y: &ArrayView1<i32>,
648    ) -> Result<Array1<i32>, SklearsError> {
649        let n_samples = graph.nrows();
650
651        // Identify labeled and unlabeled samples
652        let labeled_mask: Array1<bool> = y.iter().map(|&label| label != -1).collect();
653        let unique_labels: Vec<i32> = y
654            .iter()
655            .filter(|&&label| label != -1)
656            .cloned()
657            .collect::<std::collections::HashSet<_>>()
658            .into_iter()
659            .collect();
660
661        if unique_labels.is_empty() {
662            return Err(SklearsError::InvalidInput(
663                "No labeled samples found".to_string(),
664            ));
665        }
666
667        let n_classes = unique_labels.len();
668
669        // Initialize label probability matrix
670        let mut F = Array2::<f64>::zeros((n_samples, n_classes));
671
672        // Set initial labels for labeled samples
673        for i in 0..n_samples {
674            if labeled_mask[i] {
675                if let Some(class_idx) = unique_labels.iter().position(|&x| x == y[i]) {
676                    F[[i, class_idx]] = 1.0;
677                }
678            }
679        }
680
681        // Normalize graph to get transition matrix
682        let P = self.normalize_graph(graph)?;
683
684        // Iterative label propagation with robustness
685        for _iter in 0..self.max_iter {
686            let F_old = F.clone();
687
688            // Propagate labels: F = α * P * F + (1-α) * Y
689            let propagated = P.dot(&F);
690            F = &propagated * self.alpha;
691
692            // Reset labeled samples
693            for i in 0..n_samples {
694                if labeled_mask[i] {
695                    F.row_mut(i).fill(0.0);
696                    if let Some(class_idx) = unique_labels.iter().position(|&x| x == y[i]) {
697                        F[[i, class_idx]] = 1.0;
698                    }
699                }
700            }
701
702            // Check convergence
703            let change = (&F - &F_old).iter().map(|x| x.abs()).sum::<f64>();
704            if change < self.tolerance {
705                break;
706            }
707        }
708
709        // Convert probabilities to labels
710        let mut labels = Array1::zeros(n_samples);
711        for i in 0..n_samples {
712            let mut max_prob = 0.0;
713            let mut max_class = 0;
714
715            for j in 0..n_classes {
716                if F[[i, j]] > max_prob {
717                    max_prob = F[[i, j]];
718                    max_class = j;
719                }
720            }
721
722            labels[i] = unique_labels[max_class];
723        }
724
725        Ok(labels)
726    }
727
728    /// Normalize graph to get transition matrix
729    fn normalize_graph(&self, graph: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
730        let n_samples = graph.nrows();
731        let mut P = graph.clone();
732
733        for i in 0..n_samples {
734            let row_sum: f64 = P.row(i).sum();
735            if row_sum > 0.0 {
736                for j in 0..n_samples {
737                    P[[i, j]] /= row_sum;
738                }
739            }
740        }
741
742        Ok(P)
743    }
744
745    /// Compute Euclidean distance between two vectors
746    fn euclidean_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
747        x1.iter()
748            .zip(x2.iter())
749            .map(|(a, b)| (a - b).powi(2))
750            .sum::<f64>()
751            .sqrt()
752    }
753}
754
755impl Default for NoiseRobustPropagation {
756    fn default() -> Self {
757        Self::new()
758    }
759}
760
/// Breakdown point analysis for robust graph learning methods
///
/// Empirically estimates how much contamination each robust estimator can
/// tolerate before its graph-level estimate changes drastically, via Monte
/// Carlo simulation with injected outliers.
#[derive(Clone)]
pub struct BreakdownPointAnalysis {
    /// Robust estimators to analyze: "median", "huber", "tukey", "trimmed_mean"
    pub estimators: Vec<String>,
    /// Contamination levels to test (fractions of samples, 0.0 to 0.5)
    pub contamination_levels: Vec<f64>,
    /// Number of Monte Carlo simulations per contamination level
    pub n_simulations: usize,
    /// Breakdown threshold (relative change in the estimate that counts as
    /// a breakdown)
    pub breakdown_threshold: f64,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}
775
776impl BreakdownPointAnalysis {
777    /// Create a new breakdown point analysis instance
778    pub fn new() -> Self {
779        Self {
780            estimators: vec![
781                "median".to_string(),
782                "huber".to_string(),
783                "tukey".to_string(),
784                "trimmed_mean".to_string(),
785            ],
786            contamination_levels: (1..=25).map(|x| x as f64 / 100.0).collect(), // 1% to 25%
787            n_simulations: 100,
788            breakdown_threshold: 10.0, // 10x change indicates breakdown
789            random_state: None,
790        }
791    }
792
793    /// Set the estimators to analyze
794    pub fn estimators(mut self, estimators: Vec<String>) -> Self {
795        self.estimators = estimators;
796        self
797    }
798
799    /// Set the contamination levels to test
800    pub fn contamination_levels(mut self, levels: Vec<f64>) -> Self {
801        self.contamination_levels = levels;
802        self
803    }
804
805    /// Set the number of simulations
806    pub fn n_simulations(mut self, n: usize) -> Self {
807        self.n_simulations = n;
808        self
809    }
810
811    /// Set the breakdown threshold
812    pub fn breakdown_threshold(mut self, threshold: f64) -> Self {
813        self.breakdown_threshold = threshold;
814        self
815    }
816
817    /// Set the random state
818    pub fn random_state(mut self, seed: u64) -> Self {
819        self.random_state = Some(seed);
820        self
821    }
822
    /// Analyze breakdown points for robust graph construction
    ///
    /// Runs `estimate_breakdown_point` once per configured estimator and
    /// returns the results keyed by estimator name.
    ///
    /// NOTE(review): `self.random_state` is never consulted here — the RNG
    /// is always default-constructed, so results are not reproducible even
    /// when a seed was set via `random_state(..)`. Confirm whether `Random`
    /// exposes a seeded constructor and wire the seed in.
    pub fn analyze_graph_breakdown(
        &self,
        X: &ArrayView2<f64>,
    ) -> Result<HashMap<String, BreakdownResult>, SklearsError> {
        let mut results = HashMap::new();
        let mut rng = Random::default();

        for estimator in &self.estimators {
            let breakdown_result = self.estimate_breakdown_point(X, estimator, &mut rng)?;
            results.insert(estimator.clone(), breakdown_result);
        }

        Ok(results)
    }
838
839    /// Estimate breakdown point for a specific estimator
840    fn estimate_breakdown_point(
841        &self,
842        X: &ArrayView2<f64>,
843        estimator: &str,
844        rng: &mut Random,
845    ) -> Result<BreakdownResult, SklearsError> {
846        let n_samples = X.nrows();
847
848        // Compute clean estimate (no contamination)
849        let clean_estimate = self.compute_robust_estimate(X, estimator, 0.0, rng)?;
850
851        let mut breakdown_rates = Vec::new();
852        let mut first_breakdown = None;
853
854        for &contamination_level in &self.contamination_levels {
855            let mut breakdown_count = 0;
856
857            for _sim in 0..self.n_simulations {
858                // Create contaminated data
859                let contaminated_X = self.contaminate_data(X, contamination_level, rng)?;
860
861                // Compute estimate on contaminated data
862                let contaminated_estimate =
863                    self.compute_robust_estimate(&contaminated_X.view(), estimator, 0.0, rng)?;
864
865                // Check if breakdown occurred
866                let relative_change =
867                    self.compute_relative_change(clean_estimate, contaminated_estimate);
868
869                if relative_change > self.breakdown_threshold {
870                    breakdown_count += 1;
871                }
872            }
873
874            let breakdown_rate = breakdown_count as f64 / self.n_simulations as f64;
875            breakdown_rates.push(breakdown_rate);
876
877            // Record first significant breakdown
878            if first_breakdown.is_none() && breakdown_rate > 0.5 {
879                first_breakdown = Some(contamination_level);
880            }
881        }
882
883        Ok(BreakdownResult {
884            estimator: estimator.to_string(),
885            theoretical_breakdown_point: self.theoretical_breakdown_point(estimator),
886            empirical_breakdown_point: first_breakdown.unwrap_or(0.5),
887            contamination_levels: self.contamination_levels.clone(),
888            breakdown_rates,
889            clean_estimate,
890        })
891    }
892
893    /// Contaminate data by replacing a fraction with outliers
894    fn contaminate_data<R>(
895        &self,
896        X: &ArrayView2<f64>,
897        contamination_level: f64,
898        rng: &mut Random<R>,
899    ) -> Result<Array2<f64>, SklearsError>
900    where
901        R: Rng,
902    {
903        let n_samples = X.nrows();
904        let n_features = X.ncols();
905        let n_outliers = (n_samples as f64 * contamination_level).round() as usize;
906
907        let mut contaminated_X = X.to_owned();
908
909        if n_outliers == 0 {
910            return Ok(contaminated_X);
911        }
912
913        // Select random samples to contaminate
914        let outlier_indices: Vec<usize> = (0..n_samples)
915            .choose_multiple(rng, n_outliers)
916            .into_iter()
917            .collect();
918
919        // Compute data range for generating outliers
920        let mut feature_ranges = Vec::new();
921        for j in 0..n_features {
922            let column = X.column(j);
923            let min_val = column.iter().fold(f64::INFINITY, |a, &b| a.min(b));
924            let max_val = column.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
925            feature_ranges.push((min_val, max_val));
926        }
927
928        // Replace selected samples with outliers
929        for &idx in &outlier_indices {
930            for j in 0..n_features {
931                let (min_val, max_val) = feature_ranges[j];
932                let range = max_val - min_val;
933
934                // Generate outlier far from the data range
935                let outlier_multiplier = rng.random_range(5.0..10.0);
936                let outlier_value = if rng.gen_bool(0.5) {
937                    min_val - outlier_multiplier * range
938                } else {
939                    max_val + outlier_multiplier * range
940                };
941
942                contaminated_X[[idx, j]] = outlier_value;
943            }
944        }
945
946        Ok(contaminated_X)
947    }
948
949    /// Compute robust estimate for graph properties
950    fn compute_robust_estimate<R>(
951        &self,
952        X: &ArrayView2<f64>,
953        estimator: &str,
954        _contamination: f64,
955        _rng: &mut Random<R>,
956    ) -> Result<f64, SklearsError>
957    where
958        R: Rng,
959    {
960        match estimator {
961            "median" => self.compute_median_graph_property(X),
962            "huber" => self.compute_huber_graph_property(X),
963            "tukey" => self.compute_tukey_graph_property(X),
964            "trimmed_mean" => self.compute_trimmed_mean_graph_property(X),
965            _ => Err(SklearsError::InvalidInput(format!(
966                "Unknown estimator: {}",
967                estimator
968            ))),
969        }
970    }
971
972    /// Compute median-based graph property (median edge weight)
973    fn compute_median_graph_property(&self, X: &ArrayView2<f64>) -> Result<f64, SklearsError> {
974        let n_samples = X.nrows();
975        let mut edge_weights = Vec::new();
976
977        for i in 0..n_samples {
978            for j in i + 1..n_samples {
979                let dist = self.euclidean_distance(&X.row(i), &X.row(j));
980                let weight = (-dist.powi(2) / 2.0).exp(); // RBF kernel
981                edge_weights.push(weight);
982            }
983        }
984
985        if edge_weights.is_empty() {
986            return Ok(0.0);
987        }
988
989        edge_weights.sort_by(|a, b| a.partial_cmp(b).unwrap());
990        let median_idx = edge_weights.len() / 2;
991        Ok(if edge_weights.len() % 2 == 0 {
992            (edge_weights[median_idx - 1] + edge_weights[median_idx]) / 2.0
993        } else {
994            edge_weights[median_idx]
995        })
996    }
997
998    /// Compute Huber-based graph property
999    fn compute_huber_graph_property(&self, X: &ArrayView2<f64>) -> Result<f64, SklearsError> {
1000        let n_samples = X.nrows();
1001        let mut huber_weights = Vec::new();
1002        let c = 1.345; // Standard Huber parameter
1003
1004        for i in 0..n_samples {
1005            for j in i + 1..n_samples {
1006                let dist = self.euclidean_distance(&X.row(i), &X.row(j));
1007                let huber_dist = if dist <= c {
1008                    0.5 * dist.powi(2)
1009                } else {
1010                    c * dist - 0.5 * c.powi(2)
1011                };
1012                huber_weights.push((-huber_dist).exp());
1013            }
1014        }
1015
1016        Ok(huber_weights.iter().sum::<f64>() / huber_weights.len() as f64)
1017    }
1018
1019    /// Compute Tukey-based graph property
1020    fn compute_tukey_graph_property(&self, X: &ArrayView2<f64>) -> Result<f64, SklearsError> {
1021        let n_samples = X.nrows();
1022        let mut tukey_weights = Vec::new();
1023        let c = 4.685; // Standard Tukey parameter
1024
1025        for i in 0..n_samples {
1026            for j in i + 1..n_samples {
1027                let dist = self.euclidean_distance(&X.row(i), &X.row(j));
1028                let tukey_weight = if dist <= c {
1029                    let ratio = dist / c;
1030                    (1.0 - ratio.powi(2)).powi(2)
1031                } else {
1032                    0.0
1033                };
1034                tukey_weights.push(tukey_weight);
1035            }
1036        }
1037
1038        Ok(tukey_weights.iter().sum::<f64>() / tukey_weights.len() as f64)
1039    }
1040
1041    /// Compute trimmed mean-based graph property
1042    fn compute_trimmed_mean_graph_property(
1043        &self,
1044        X: &ArrayView2<f64>,
1045    ) -> Result<f64, SklearsError> {
1046        let n_samples = X.nrows();
1047        let mut edge_weights = Vec::new();
1048
1049        for i in 0..n_samples {
1050            for j in i + 1..n_samples {
1051                let dist = self.euclidean_distance(&X.row(i), &X.row(j));
1052                let weight = (-dist.powi(2) / 2.0).exp(); // RBF kernel
1053                edge_weights.push(weight);
1054            }
1055        }
1056
1057        if edge_weights.is_empty() {
1058            return Ok(0.0);
1059        }
1060
1061        edge_weights.sort_by(|a, b| a.partial_cmp(b).unwrap());
1062
1063        // Trim 20% from each end
1064        let trim_size = (edge_weights.len() as f64 * 0.2) as usize;
1065        let trimmed_weights = &edge_weights[trim_size..edge_weights.len() - trim_size];
1066
1067        Ok(trimmed_weights.iter().sum::<f64>() / trimmed_weights.len() as f64)
1068    }
1069
1070    /// Compute relative change between estimates
1071    fn compute_relative_change(&self, clean_estimate: f64, contaminated_estimate: f64) -> f64 {
1072        if clean_estimate.abs() < 1e-10 {
1073            if contaminated_estimate.abs() < 1e-10 {
1074                0.0
1075            } else {
1076                f64::INFINITY
1077            }
1078        } else {
1079            (contaminated_estimate - clean_estimate).abs() / clean_estimate.abs()
1080        }
1081    }
1082
1083    /// Get theoretical breakdown point for different estimators
1084    fn theoretical_breakdown_point(&self, estimator: &str) -> f64 {
1085        match estimator {
1086            "median" => 0.5,
1087            "huber" => 0.5,
1088            "tukey" => 0.5,
1089            "trimmed_mean" => 0.2, // 20% trimming
1090            _ => 0.0,
1091        }
1092    }
1093
1094    /// Compute Euclidean distance between two vectors
1095    fn euclidean_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
1096        x1.iter()
1097            .zip(x2.iter())
1098            .map(|(a, b)| (a - b).powi(2))
1099            .sum::<f64>()
1100            .sqrt()
1101    }
1102}
1103
1104impl Default for BreakdownPointAnalysis {
1105    fn default() -> Self {
1106        Self::new()
1107    }
1108}
1109
/// Result of breakdown point analysis
///
/// Produced by [`BreakdownPointAnalysis`] for a single estimator; records how
/// the estimator degrades as the contamination level increases.
#[derive(Clone, Debug)]
pub struct BreakdownResult {
    /// Name of the estimator (e.g. "median", "huber", "tukey", "trimmed_mean")
    pub estimator: String,
    /// Theoretical breakdown point (maximum contamination fraction tolerated)
    pub theoretical_breakdown_point: f64,
    /// Empirically observed breakdown point: the first contamination level at
    /// which more than half of the simulations broke down (0.5 if never)
    pub empirical_breakdown_point: f64,
    /// Contamination levels tested, in the order they were simulated
    pub contamination_levels: Vec<f64>,
    /// Breakdown rates (fraction of simulations classified as broken down),
    /// one entry per contamination level
    pub breakdown_rates: Vec<f64>,
    /// Clean estimate (no contamination), used as the comparison baseline
    pub clean_estimate: f64,
}
1126
1127impl BreakdownResult {
1128    /// Get the efficiency of the estimator (1 - empirical_breakdown_point)
1129    pub fn efficiency(&self) -> f64 {
1130        1.0 - self.empirical_breakdown_point
1131    }
1132
1133    /// Check if the estimator meets theoretical expectations
1134    pub fn meets_theory(&self) -> bool {
1135        self.empirical_breakdown_point >= self.theoretical_breakdown_point * 0.9
1136    }
1137
1138    /// Get summary statistics
1139    pub fn summary(&self) -> String {
1140        format!(
1141            "Estimator: {}\nTheoretical BP: {:.3}\nEmpirical BP: {:.3}\nEfficiency: {:.3}\nMeets Theory: {}",
1142            self.estimator,
1143            self.theoretical_breakdown_point,
1144            self.empirical_breakdown_point,
1145            self.efficiency(),
1146            self.meets_theory()
1147        )
1148    }
1149}
1150
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::array;

    // NOTE: the module-level #[allow(non_snake_case)] above applies to every
    // item in this module, so per-test copies of the attribute were removed
    // as redundant.

    /// Robust k-NN graph on data containing one outlier: the result must be
    /// square, zero on the diagonal, and symmetric.
    #[test]
    fn test_robust_graph_construction() {
        let X = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [10.0, 20.0] // Outlier
        ];

        let rgc = RobustGraphConstruction::new()
            .k_neighbors(2)
            .robust_metric("huber".to_string())
            .outlier_threshold(2.0);

        let result = rgc.fit(&X.view());
        assert!(result.is_ok());

        let graph = result.unwrap();
        assert_eq!(graph.dim(), (4, 4));

        // Check that diagonal is zero
        for i in 0..4 {
            assert_eq!(graph[[i, i]], 0.0);
        }

        // Check symmetry
        for i in 0..4 {
            for j in 0..4 {
                assert_abs_diff_eq!(graph[[i, j]], graph[[j, i]], epsilon = 1e-10);
            }
        }
    }

    /// Every supported robust metric should fit without error and return a
    /// graph of the right shape.
    #[test]
    fn test_robust_metrics() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let metrics = vec!["huber", "tukey", "cauchy", "welsch"];

        for metric in metrics {
            let rgc = RobustGraphConstruction::new()
                .k_neighbors(2)
                .robust_metric(metric.to_string());

            let result = rgc.fit(&X.view());
            assert!(result.is_ok());

            let graph = result.unwrap();
            assert_eq!(graph.dim(), (3, 3));
        }
    }

    /// Every supported construction method should fit without error.
    #[test]
    fn test_robust_construction_methods() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let methods = vec!["knn", "epsilon", "adaptive"];

        for method in methods {
            let rgc = RobustGraphConstruction::new()
                .construction_method(method.to_string())
                .k_neighbors(2)
                .epsilon(2.0);

            let result = rgc.fit(&X.view());
            assert!(result.is_ok());

            let graph = result.unwrap();
            assert_eq!(graph.dim(), (3, 3));
        }
    }

    /// Noise-robust propagation must return one label per sample and leave
    /// the originally labeled samples untouched.
    #[test]
    fn test_noise_robust_propagation() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let nrp = NoiseRobustPropagation::new()
            .k_neighbors(2)
            .noise_estimation("mad".to_string())
            .max_iter(100)
            .alpha(0.2);

        let result = nrp.fit(&X.view(), &y.view());
        assert!(result.is_ok());

        let labels = result.unwrap();
        assert_eq!(labels.len(), 4);

        // Check that labeled samples retain their labels
        assert_eq!(labels[0], 0);
        assert_eq!(labels[1], 1);
    }

    /// Every supported noise-estimation method should fit without error.
    #[test]
    fn test_noise_estimation_methods() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![0, 1, -1];

        let methods = vec!["mad", "iqr", "adaptive"];

        for method in methods {
            let nrp = NoiseRobustPropagation::new()
                .noise_estimation(method.to_string())
                .k_neighbors(2)
                .max_iter(50);

            let result = nrp.fit(&X.view(), &y.view());
            assert!(result.is_ok());

            let labels = result.unwrap();
            assert_eq!(labels.len(), 3);
        }
    }

    /// An unknown construction method must be rejected.
    #[test]
    fn test_robust_graph_error_cases() {
        let rgc = RobustGraphConstruction::new().construction_method("invalid".to_string());

        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let result = rgc.fit(&X.view());
        assert!(result.is_err());
    }

    /// Propagation must reject mismatched X/y lengths and fully-unlabeled y.
    #[test]
    fn test_noise_robust_propagation_error_cases() {
        let nrp = NoiseRobustPropagation::new();

        // Test with mismatched dimensions
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![0]; // Wrong size

        let result = nrp.fit(&X.view(), &y.view());
        assert!(result.is_err());

        // Test with no labeled samples
        let y_unlabeled = array![-1, -1];
        let result = nrp.fit(&X.view(), &y_unlabeled.view());
        assert!(result.is_err());
    }

    /// End-to-end breakdown analysis: every default estimator must produce a
    /// result with a breakdown point in [0, 0.5].
    #[test]
    fn test_breakdown_point_analysis() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];

        let bpa = BreakdownPointAnalysis::new()
            .n_simulations(20) // Reduced for faster testing
            .contamination_levels(vec![0.1, 0.2, 0.3])
            .random_state(42);

        let result = bpa.analyze_graph_breakdown(&X.view());
        assert!(result.is_ok());

        let results = result.unwrap();
        assert!(!results.is_empty());

        // Check that we have results for each estimator
        for estimator in &["median", "huber", "tukey", "trimmed_mean"] {
            assert!(results.contains_key(*estimator));
            let breakdown_result = &results[*estimator];

            // Check that breakdown point is reasonable
            assert!(breakdown_result.empirical_breakdown_point >= 0.0);
            assert!(breakdown_result.empirical_breakdown_point <= 0.5);

            // Check that theoretical breakdown point is set correctly
            assert!(breakdown_result.theoretical_breakdown_point > 0.0);
        }
    }

    /// Accessor methods on a hand-built BreakdownResult.
    #[test]
    fn test_breakdown_result_methods() {
        let breakdown_result = BreakdownResult {
            estimator: "median".to_string(),
            theoretical_breakdown_point: 0.5,
            empirical_breakdown_point: 0.45, // Changed to meet theory threshold
            contamination_levels: vec![0.1, 0.2, 0.3],
            breakdown_rates: vec![0.0, 0.1, 0.8],
            clean_estimate: 1.0,
        };

        // Test efficiency calculation
        assert_abs_diff_eq!(breakdown_result.efficiency(), 0.55, epsilon = 1e-10);

        // Test meets_theory check
        assert!(breakdown_result.meets_theory());

        // Test summary generation
        let summary = breakdown_result.summary();
        assert!(summary.contains("median"));
        assert!(summary.contains("0.500"));
        assert!(summary.contains("0.450"));
    }

    /// Contamination must preserve shape, change at least one value at 25%,
    /// and change nothing at 0%.
    #[test]
    fn test_contaminate_data() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];

        let bpa = BreakdownPointAnalysis::new().random_state(42);
        let mut rng = Random::seed(42);

        // Test with 25% contamination
        let contaminated = bpa.contaminate_data(&X.view(), 0.25, &mut rng).unwrap();
        assert_eq!(contaminated.dim(), X.dim());

        // Check that at least one sample was contaminated
        let mut different = false;
        for i in 0..X.nrows() {
            for j in 0..X.ncols() {
                if (X[[i, j]] - contaminated[[i, j]]).abs() > 1e-10 {
                    different = true;
                    break;
                }
            }
        }
        assert!(different);

        // Test with 0% contamination
        let no_contamination = bpa.contaminate_data(&X.view(), 0.0, &mut rng).unwrap();
        for i in 0..X.nrows() {
            for j in 0..X.ncols() {
                assert_abs_diff_eq!(X[[i, j]], no_contamination[[i, j]], epsilon = 1e-10);
            }
        }
    }

    /// All named estimators produce a normalized estimate in [0, 1]; unknown
    /// names are rejected.
    #[test]
    fn test_robust_graph_property_estimators() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let bpa = BreakdownPointAnalysis::new();

        // Test all estimators
        let estimators = vec!["median", "huber", "tukey", "trimmed_mean"];
        for estimator in estimators {
            let mut rng = Random::seed(42);
            let result = bpa.compute_robust_estimate(&X.view(), estimator, 0.0, &mut rng);
            assert!(result.is_ok());

            let estimate = result.unwrap();
            assert!(estimate >= 0.0);
            assert!(estimate <= 1.0); // Should be normalized
        }

        // Test invalid estimator
        let mut rng = Random::seed(42);
        let result = bpa.compute_robust_estimate(&X.view(), "invalid", 0.0, &mut rng);
        assert!(result.is_err());
    }

    /// Theoretical breakdown points per estimator name.
    #[test]
    fn test_theoretical_breakdown_points() {
        let bpa = BreakdownPointAnalysis::new();

        assert_eq!(bpa.theoretical_breakdown_point("median"), 0.5);
        assert_eq!(bpa.theoretical_breakdown_point("huber"), 0.5);
        assert_eq!(bpa.theoretical_breakdown_point("tukey"), 0.5);
        assert_eq!(bpa.theoretical_breakdown_point("trimmed_mean"), 0.2);
        assert_eq!(bpa.theoretical_breakdown_point("unknown"), 0.0);
    }

    /// Relative change: ordinary ratios plus the zero / near-zero baselines.
    #[test]
    fn test_relative_change_computation() {
        let bpa = BreakdownPointAnalysis::new();

        // Normal case
        assert_abs_diff_eq!(bpa.compute_relative_change(1.0, 1.5), 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(bpa.compute_relative_change(2.0, 1.0), 0.5, epsilon = 1e-10);

        // Zero clean estimate case
        assert_eq!(bpa.compute_relative_change(0.0, 0.0), 0.0);
        assert_eq!(bpa.compute_relative_change(0.0, 1.0), f64::INFINITY);

        // Very small clean estimate
        let small_val = 1e-12;
        assert_eq!(bpa.compute_relative_change(small_val, 1.0), f64::INFINITY);
    }

    /// Builder setters must store exactly what they are given.
    #[test]
    fn test_breakdown_point_analysis_builder() {
        let custom_levels = vec![0.05, 0.15, 0.25];
        let custom_estimators = vec!["median".to_string(), "huber".to_string()];

        let bpa = BreakdownPointAnalysis::new()
            .estimators(custom_estimators.clone())
            .contamination_levels(custom_levels.clone())
            .n_simulations(50)
            .breakdown_threshold(5.0)
            .random_state(123);

        assert_eq!(bpa.estimators, custom_estimators);
        assert_eq!(bpa.contamination_levels, custom_levels);
        assert_eq!(bpa.n_simulations, 50);
        assert_eq!(bpa.breakdown_threshold, 5.0);
        assert_eq!(bpa.random_state, Some(123));
    }
}