sklears_semi_supervised/adversarial_graph_learning.rs

//! Adversarial graph learning for robust semi-supervised scenarios
//!
//! This module provides adversarial graph learning algorithms that can handle
//! adversarial perturbations, malicious nodes, and robust graph construction
//! in adversarial environments.
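//!
//! # Example
//!
//! A minimal usage sketch. The import path below is an assumption about the
//! crate layout; adjust it to wherever this module is re-exported.
//!
//! ```ignore
//! use scirs2_core::array;
//! use sklears_semi_supervised::adversarial_graph_learning::AdversarialGraphLearning;
//!
//! let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
//! let learner = AdversarialGraphLearning::new()
//!     .k_neighbors(2)
//!     .defense_strategy("spectral".to_string())
//!     .random_state(42);
//! let adjacency = learner.fit_robust(features.view(), None).unwrap();
//! assert_eq!(adjacency.dim(), (3, 3));
//! ```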

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::rand_prelude::*;
use scirs2_core::random::Random;
use sklears_core::error::SklearsError;

/// Adversarial graph learning for robust semi-supervised learning
#[derive(Clone)]
pub struct AdversarialGraphLearning {
    /// Number of neighbors for graph construction
    pub k_neighbors: usize,
    /// Robustness parameter for adversarial defense
    pub robustness_lambda: f64,
    /// Maximum perturbation magnitude allowed
    pub max_perturbation: f64,
    /// Number of adversarial iterations
    pub adversarial_steps: usize,
    /// Learning rate for adversarial updates
    pub adversarial_lr: f64,
    /// Defense strategy: "spectral", "robust_pca", "consensus", "adaptive"
    pub defense_strategy: String,
    /// Consensus threshold for agreement-based defense
    pub consensus_threshold: f64,
    /// Maximum iterations for optimization
    pub max_iter: usize,
    /// Convergence tolerance
    pub tolerance: f64,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}

/// Adversarial attack configuration
#[derive(Clone, Debug)]
pub struct AdversarialAttack {
    /// Attack type: "node_injection", "edge_manipulation", "feature_perturbation"
    pub attack_type: String,
    /// Attack strength (0.0 to 1.0)
    pub attack_strength: f64,
    /// Number of nodes to attack
    pub target_nodes: usize,
    /// Perturbation strategy: "random", "gradient", "targeted"
    pub perturbation_strategy: String,
}

impl AdversarialGraphLearning {
    /// Create a new adversarial graph learning instance
    pub fn new() -> Self {
        Self {
            k_neighbors: 5,
            robustness_lambda: 0.1,
            max_perturbation: 0.1,
            adversarial_steps: 10,
            adversarial_lr: 0.01,
            defense_strategy: "spectral".to_string(),
            consensus_threshold: 0.7,
            max_iter: 100,
            tolerance: 1e-6,
            random_state: None,
        }
    }

    /// Set the number of neighbors for graph construction
    pub fn k_neighbors(mut self, k: usize) -> Self {
        self.k_neighbors = k;
        self
    }

    /// Set the robustness parameter
    pub fn robustness_lambda(mut self, lambda: f64) -> Self {
        self.robustness_lambda = lambda;
        self
    }

    /// Set the maximum perturbation magnitude
    pub fn max_perturbation(mut self, max_pert: f64) -> Self {
        self.max_perturbation = max_pert;
        self
    }

    /// Set the number of adversarial steps
    pub fn adversarial_steps(mut self, steps: usize) -> Self {
        self.adversarial_steps = steps;
        self
    }

    /// Set the adversarial learning rate
    pub fn adversarial_lr(mut self, lr: f64) -> Self {
        self.adversarial_lr = lr;
        self
    }

    /// Set the defense strategy
    pub fn defense_strategy(mut self, strategy: String) -> Self {
        self.defense_strategy = strategy;
        self
    }

    /// Set the consensus threshold
    pub fn consensus_threshold(mut self, threshold: f64) -> Self {
        self.consensus_threshold = threshold;
        self
    }

    /// Set the maximum iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the tolerance
    pub fn tolerance(mut self, tol: f64) -> Self {
        self.tolerance = tol;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Learn a robust graph in the presence of adversarial perturbations
    pub fn fit_robust(
        &self,
        features: ArrayView2<f64>,
        labels: Option<ArrayView1<i32>>,
    ) -> Result<Array2<f64>, SklearsError> {
        let n_samples = features.nrows();

        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "No samples provided".to_string(),
            ));
        }

        match self.defense_strategy.as_str() {
            "spectral" => self.spectral_defense(features, labels),
            "robust_pca" => self.robust_pca_defense(features, labels),
            "consensus" => self.consensus_defense(features, labels),
            "adaptive" => self.adaptive_defense(features, labels),
            _ => Err(SklearsError::InvalidInput(format!(
                "Unknown defense strategy: {}",
                self.defense_strategy
            ))),
        }
    }

    /// Spectral defense: robust k-NN graph construction with degree-based spectral regularization
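    ///
    /// Edge weights use a Gaussian-style kernel over the robust (Huber)
    /// distance, `w_ij = exp(-d_ij / (2 * max_perturbation^2))`, and the
    /// resulting graph is then passed through degree-based regularization
    /// (see `apply_spectral_regularization`).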
    fn spectral_defense(
        &self,
        features: ArrayView2<f64>,
        _labels: Option<ArrayView1<i32>>,
    ) -> Result<Array2<f64>, SklearsError> {
        let n_samples = features.nrows();
        let mut adjacency = Array2::zeros((n_samples, n_samples));

        // Build initial graph
        for i in 0..n_samples {
            let mut distances: Vec<(usize, f64)> = Vec::new();

            for j in 0..n_samples {
                if i != j {
                    let dist = self.compute_robust_distance(features.row(i), features.row(j));
                    distances.push((j, dist));
                }
            }

            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
            for &(neighbor, dist) in distances.iter().take(self.k_neighbors) {
                let weight = (-dist / (2.0 * self.max_perturbation.powi(2))).exp();
                adjacency[[i, neighbor]] = weight;
                adjacency[[neighbor, i]] = weight;
            }
        }

        // Apply spectral regularization for robustness
        self.apply_spectral_regularization(&mut adjacency)?;

        Ok(adjacency)
    }

    /// Robust PCA defense using outlier-resistant location and scale estimates
    fn robust_pca_defense(
        &self,
        features: ArrayView2<f64>,
        _labels: Option<ArrayView1<i32>>,
    ) -> Result<Array2<f64>, SklearsError> {
        let n_samples = features.nrows();

        // Estimate robust mean and covariance
        let robust_mean = self.compute_robust_mean(features)?;
        let robust_cov = self.compute_robust_covariance(features, &robust_mean)?;

        // Project data using robust PCA
        let robust_features = self.robust_pca_projection(features, &robust_mean, &robust_cov)?;

        // Build graph using robust features
        let mut adjacency = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            let mut distances: Vec<(usize, f64)> = Vec::new();

            for j in 0..n_samples {
                if i != j {
                    let dist = self.mahalanobis_distance(
                        robust_features.row(i),
                        robust_features.row(j),
                        &robust_cov,
                    )?;
                    distances.push((j, dist));
                }
            }

            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
            for &(neighbor, dist) in distances.iter().take(self.k_neighbors) {
                let weight = (-dist).exp();
                adjacency[[i, neighbor]] = weight;
                adjacency[[neighbor, i]] = weight;
            }
        }

        Ok(adjacency)
    }

    /// Consensus defense using multiple graph constructions
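    ///
    /// Builds five k-NN graphs from independently noise-perturbed copies of
    /// the features (uniform noise in `[-max_perturbation, max_perturbation)`),
    /// averages them, and zeroes any averaged edge weight that falls below
    /// `consensus_threshold`.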
    fn consensus_defense(
        &self,
        features: ArrayView2<f64>,
        _labels: Option<ArrayView1<i32>>,
    ) -> Result<Array2<f64>, SklearsError> {
        let n_samples = features.nrows();
        let num_graphs = 5; // Number of consensus graphs

        let mut consensus_adjacency: Array2<f64> = Array2::zeros((n_samples, n_samples));
        let mut rng = if let Some(seed) = self.random_state {
            Random::seed(seed)
        } else {
            Random::seed(42)
        };

        // Generate multiple graphs with different perturbations
        for _ in 0..num_graphs {
            let mut perturbed_features = features.to_owned();

            // Add small random perturbations
            for i in 0..n_samples {
                for j in 0..features.ncols() {
                    let noise = rng.random_range(-self.max_perturbation..self.max_perturbation);
                    perturbed_features[[i, j]] += noise;
                }
            }

            // Build graph from perturbed features
            let graph = self.build_knn_graph(perturbed_features.view())?;

            // Add to consensus
            consensus_adjacency += &graph;
        }

        // Normalize by number of graphs
        consensus_adjacency /= num_graphs as f64;

        // Apply consensus threshold
        consensus_adjacency.mapv_inplace(|x| if x >= self.consensus_threshold { x } else { 0.0 });

        Ok(consensus_adjacency)
    }

    /// Adaptive defense that combines multiple strategies
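    ///
    /// For each candidate edge the spectral and consensus weights are blended
    /// by their agreement: `c = exp(-|w_s - w_c| / max_perturbation)` and
    /// `w = c * w_s + (1 - c) * w_c`, so the spectral weight dominates where
    /// the two strategies agree and the consensus weight where they diverge.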
    fn adaptive_defense(
        &self,
        features: ArrayView2<f64>,
        labels: Option<ArrayView1<i32>>,
    ) -> Result<Array2<f64>, SklearsError> {
        // Combine spectral and consensus defenses
        let spectral_graph = self.spectral_defense(features, labels)?;
        let consensus_graph = self.consensus_defense(features, labels)?;

        let n_samples = features.nrows();
        let mut adaptive_graph = Array2::zeros((n_samples, n_samples));

        // Adaptive weighting based on local graph properties
        for i in 0..n_samples {
            for j in 0..n_samples {
                if i != j {
                    let spectral_weight = spectral_graph[[i, j]];
                    let consensus_weight = consensus_graph[[i, j]];

                    // Adaptive combination based on weight agreement
                    let disagreement = (spectral_weight - consensus_weight).abs();
                    let confidence = (-disagreement / self.max_perturbation).exp();

                    adaptive_graph[[i, j]] =
                        confidence * spectral_weight + (1.0 - confidence) * consensus_weight;
                }
            }
        }

        Ok(adaptive_graph)
    }

    /// Apply adversarial attack to test robustness
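    ///
    /// # Example
    ///
    /// A minimal sketch, reusing `learner` and `features` from the module-level
    /// example; the field values mirror the unit tests below.
    ///
    /// ```ignore
    /// let attack = AdversarialAttack {
    ///     attack_type: "feature_perturbation".to_string(),
    ///     attack_strength: 0.1,
    ///     target_nodes: 2,
    ///     perturbation_strategy: "random".to_string(),
    /// };
    /// let attacked = learner.apply_attack(features.view(), &attack).unwrap();
    /// ```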
    pub fn apply_attack(
        &self,
        features: ArrayView2<f64>,
        attack: &AdversarialAttack,
    ) -> Result<Array2<f64>, SklearsError> {
        let mut attacked_features = features.to_owned();
        let n_samples = features.nrows();

        let mut rng = if let Some(seed) = self.random_state {
            Random::seed(seed)
        } else {
            Random::seed(42)
        };

        match attack.attack_type.as_str() {
            "feature_perturbation" => {
                let num_target_nodes = attack.target_nodes.min(n_samples);
                let target_indices: Vec<usize> =
                    (0..n_samples).choose_multiple(&mut rng, num_target_nodes);

                for &node_idx in &target_indices {
                    for feature_idx in 0..features.ncols() {
                        let perturbation = match attack.perturbation_strategy.as_str() {
                            "random" => {
                                rng.random_range(-attack.attack_strength..attack.attack_strength)
                            }
                            "gradient" => {
                                self.compute_gradient_perturbation(features, node_idx, feature_idx)?
                                    * attack.attack_strength
                            }
                            "targeted" => {
                                self.compute_targeted_perturbation(features, node_idx, feature_idx)?
                                    * attack.attack_strength
                            }
                            _ => rng.random_range(-attack.attack_strength..attack.attack_strength),
                        };

                        attacked_features[[node_idx, feature_idx]] += perturbation;
                    }
                }
            }
            "node_injection" => {
                // This would require extending the feature matrix
                return Err(SklearsError::InvalidInput(
                    "Node injection not implemented in this context".to_string(),
                ));
            }
            "edge_manipulation" => {
                // This would be applied to the adjacency matrix after construction
                return Err(SklearsError::InvalidInput(
                    "Edge manipulation should be applied to adjacency matrix".to_string(),
                ));
            }
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown attack type: {}",
                    attack.attack_type
                )));
            }
        }

        Ok(attacked_features)
    }

    /// Compute robust distance metric resistant to outliers
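    ///
    /// Uses a per-coordinate Huber penalty with `delta = max_perturbation`:
    /// `rho(r) = 0.5 * r^2` for `|r| <= delta`, else `delta * (|r| - 0.5 * delta)`.
    /// The distance is `sqrt(sum_i rho(a_i - b_i))`, so large coordinate gaps
    /// (potential adversarial outliers) contribute linearly rather than
    /// quadratically under the square root.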
    fn compute_robust_distance(&self, feat1: ArrayView1<f64>, feat2: ArrayView1<f64>) -> f64 {
        // Use Huber loss-based distance for robustness
        let delta = self.max_perturbation;

        feat1
            .iter()
            .zip(feat2.iter())
            .map(|(&a, &b)| {
                let diff = (a - b).abs();
                if diff <= delta {
                    0.5 * diff * diff
                } else {
                    delta * (diff - 0.5 * delta)
                }
            })
            .sum::<f64>()
            .sqrt()
    }

    /// Apply spectral regularization to improve robustness
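    ///
    /// Each edge weight is shrunk by a degree-disparity penalty:
    /// `w_ij *= 1 - robustness_lambda * |d_i - d_j| / (d_i + d_j + 1e-8)`,
    /// a heuristic that down-weights edges between nodes with very different
    /// degrees.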
    fn apply_spectral_regularization(
        &self,
        adjacency: &mut Array2<f64>,
    ) -> Result<(), SklearsError> {
        let n = adjacency.nrows();

        // Compute node degrees
        let mut degree = Array1::zeros(n);
        for i in 0..n {
            degree[i] = adjacency.row(i).sum();
        }

        // Apply regularization to improve spectral properties
        for i in 0..n {
            for j in 0..n {
                if i != j && adjacency[[i, j]] > 0.0 {
                    // Regularize edge weights based on degree difference
                    let degree_penalty =
                        (degree[i] - degree[j]).abs() / (degree[i] + degree[j] + 1e-8);
                    adjacency[[i, j]] *= 1.0 - self.robustness_lambda * degree_penalty;
                }
            }
        }

        Ok(())
    }

    /// Compute robust mean using median
    fn compute_robust_mean(&self, features: ArrayView2<f64>) -> Result<Array1<f64>, SklearsError> {
        let n_features = features.ncols();
        let mut robust_mean = Array1::zeros(n_features);

        for j in 0..n_features {
            let mut column: Vec<f64> = features.column(j).to_vec();
            column.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let median_idx = column.len() / 2;
            robust_mean[j] = if column.len() % 2 == 0 {
                (column[median_idx - 1] + column[median_idx]) / 2.0
            } else {
                column[median_idx]
            };
        }

        Ok(robust_mean)
    }

    /// Compute robust covariance using MAD (Median Absolute Deviation)
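    ///
    /// For each feature `j`, `MAD_j = median_i(|x_ij - median_j|)` and the
    /// variance estimate is `(1.4826 * MAD_j)^2`; the factor 1.4826 makes the
    /// MAD a consistent estimator of the standard deviation under normality.
    /// Only the diagonal of the covariance matrix is estimated.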
    fn compute_robust_covariance(
        &self,
        features: ArrayView2<f64>,
        robust_mean: &Array1<f64>,
    ) -> Result<Array2<f64>, SklearsError> {
        let n_features = features.ncols();
        let mut robust_cov = Array2::eye(n_features);

        for j in 0..n_features {
            let mut deviations: Vec<f64> = features
                .column(j)
                .iter()
                .map(|&x| (x - robust_mean[j]).abs())
                .collect();

            deviations.sort_by(|a, b| a.partial_cmp(b).unwrap());
            let mad = deviations[deviations.len() / 2] * 1.4826; // Scale factor for normal distribution

            robust_cov[[j, j]] = mad * mad;
        }

        Ok(robust_cov)
    }

    /// Project features using a simplified robust PCA (robust standardization)
    fn robust_pca_projection(
        &self,
        features: ArrayView2<f64>,
        robust_mean: &Array1<f64>,
        robust_cov: &Array2<f64>,
    ) -> Result<Array2<f64>, SklearsError> {
        let n_samples = features.nrows();
        let n_features = features.ncols();

        // Simple robust projection (in practice, you'd use proper robust PCA)
        let mut projected = Array2::zeros((n_samples, n_features));

        for i in 0..n_samples {
            for j in 0..n_features {
                // Guard against a zero MAD (constant feature) to avoid NaN/inf
                let scale = robust_cov[[j, j]].sqrt().max(1e-12);
                projected[[i, j]] = (features[[i, j]] - robust_mean[j]) / scale;
            }
        }

        Ok(projected)
    }

    /// Compute Mahalanobis distance
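    ///
    /// With the diagonal covariance used in this module, this reduces to the
    /// standardized Euclidean distance
    /// `d(x, y) = sqrt(sum_i (x_i - y_i)^2 / sigma_i^2)`.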
    fn mahalanobis_distance(
        &self,
        feat1: ArrayView1<f64>,
        feat2: ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> Result<f64, SklearsError> {
        let diff: Array1<f64> = &feat1 - &feat2;

        // Simplified Mahalanobis distance (assuming diagonal covariance)
        let mut distance = 0.0;
        for (i, &d) in diff.iter().enumerate() {
            // Guard against zero variances from degenerate features
            distance += d * d / cov[[i, i]].max(1e-12);
        }

        Ok(distance.sqrt())
    }

    /// Build k-NN graph from features
    fn build_knn_graph(&self, features: ArrayView2<f64>) -> Result<Array2<f64>, SklearsError> {
        let n_samples = features.nrows();
        let mut adjacency = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            let mut distances: Vec<(usize, f64)> = Vec::new();

            for j in 0..n_samples {
                if i != j {
                    let dist = self.compute_robust_distance(features.row(i), features.row(j));
                    distances.push((j, dist));
                }
            }

            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
            for &(neighbor, dist) in distances.iter().take(self.k_neighbors) {
                let weight = (-dist).exp();
                adjacency[[i, neighbor]] = weight;
                adjacency[[neighbor, i]] = weight;
            }
        }

        Ok(adjacency)
    }

    /// Compute gradient-based perturbation (simplified)
    fn compute_gradient_perturbation(
        &self,
        _features: ArrayView2<f64>,
        _node_idx: usize,
        _feature_idx: usize,
    ) -> Result<f64, SklearsError> {
        // Simplified gradient computation
        // In practice, this would involve computing gradients of the loss function
        Ok(0.1) // Placeholder
    }

    /// Compute targeted perturbation
    fn compute_targeted_perturbation(
        &self,
        _features: ArrayView2<f64>,
        _node_idx: usize,
        _feature_idx: usize,
    ) -> Result<f64, SklearsError> {
        // Simplified targeted perturbation
        // In practice, this would target specific nodes or classes
        Ok(0.05) // Placeholder
    }

    /// Evaluate robustness against attack
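    ///
    /// Returns the relative Frobenius distance
    /// `||A - A_attacked||_F / ||A||_F`; 0.0 means the attack left the graph
    /// unchanged, while larger values indicate a bigger structural change.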
    pub fn evaluate_robustness(
        &self,
        original_graph: &Array2<f64>,
        attacked_graph: &Array2<f64>,
    ) -> Result<f64, SklearsError> {
        if original_graph.dim() != attacked_graph.dim() {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("{:?}", original_graph.dim()),
                actual: format!("{:?}", attacked_graph.dim()),
            });
        }

        // Compute Frobenius norm of the difference
        let diff = original_graph - attacked_graph;
        let frobenius_norm = diff.iter().map(|&x| x * x).sum::<f64>().sqrt();

        // Normalize by original graph norm
        let original_norm = original_graph.iter().map(|&x| x * x).sum::<f64>().sqrt();

        if original_norm > 0.0 {
            Ok(frobenius_norm / original_norm)
        } else {
            Ok(0.0)
        }
    }
}

impl Default for AdversarialGraphLearning {
    fn default() -> Self {
        Self::new()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_adversarial_graph_learning_spectral() {
        let agl = AdversarialGraphLearning::new()
            .k_neighbors(2)
            .defense_strategy("spectral".to_string())
            .robustness_lambda(0.1);

        let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let result = agl.fit_robust(features.view(), None);
        assert!(result.is_ok());

        let graph = result.unwrap();
        assert_eq!(graph.dim(), (3, 3));

        // Check that diagonal is zero
        for i in 0..3 {
            assert_eq!(graph[[i, i]], 0.0);
        }
    }

    #[test]
    fn test_adversarial_graph_learning_consensus() {
        let agl = AdversarialGraphLearning::new()
            .k_neighbors(2)
            .defense_strategy("consensus".to_string())
            .consensus_threshold(0.5)
            .random_state(42);

        let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let result = agl.fit_robust(features.view(), None);
        assert!(result.is_ok());

        let graph = result.unwrap();
        assert_eq!(graph.dim(), (3, 3));
    }

    #[test]
    fn test_feature_perturbation_attack() {
        let agl = AdversarialGraphLearning::new().random_state(42);

        let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let attack = AdversarialAttack {
            attack_type: "feature_perturbation".to_string(),
            attack_strength: 0.1,
            target_nodes: 2,
            perturbation_strategy: "random".to_string(),
        };

        let result = agl.apply_attack(features.view(), &attack);
        assert!(result.is_ok());

        let attacked_features = result.unwrap();
        assert_eq!(attacked_features.dim(), features.dim());

        // Check that features have been perturbed
        let different = features
            .iter()
            .zip(attacked_features.iter())
            .any(|(&a, &b)| (a - b).abs() > 1e-10);
        assert!(different);
    }

    #[test]
    fn test_robust_distance() {
        let agl = AdversarialGraphLearning::new().max_perturbation(0.1);

        let feat1 = array![1.0, 2.0];
        let feat2 = array![1.1, 2.1];

        let distance = agl.compute_robust_distance(feat1.view(), feat2.view());
        assert!(distance > 0.0);

        // Test with larger difference (should be robust)
        let feat3 = array![10.0, 20.0];
        let robust_distance = agl.compute_robust_distance(feat1.view(), feat3.view());
        let euclidean_distance =
            ((1.0_f64 - 10.0_f64).powi(2) + (2.0_f64 - 20.0_f64).powi(2)).sqrt();

        // Robust distance should be less than Euclidean for outliers
        assert!(robust_distance < euclidean_distance);
    }

    #[test]
    fn test_robust_mean_computation() {
        let agl = AdversarialGraphLearning::new();

        let features = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [100.0, 200.0] // Outlier
        ];

        let robust_mean = agl.compute_robust_mean(features.view()).unwrap();

        // Robust mean should be closer to median than arithmetic mean
        assert!(robust_mean[0] < 10.0); // Should not be heavily influenced by outlier
        assert!(robust_mean[1] < 20.0);
    }

    #[test]
    fn test_robustness_evaluation() {
        let agl = AdversarialGraphLearning::new();

        let original_graph = array![[0.0, 1.0, 0.5], [1.0, 0.0, 0.8], [0.5, 0.8, 0.0]];

        let attacked_graph = array![[0.0, 0.9, 0.4], [0.9, 0.0, 0.7], [0.4, 0.7, 0.0]];

        let robustness = agl
            .evaluate_robustness(&original_graph, &attacked_graph)
            .unwrap();
        assert!(robustness > 0.0);
        assert!(robustness < 1.0);
    }

    #[test]
    fn test_adaptive_defense() {
        let agl = AdversarialGraphLearning::new()
            .k_neighbors(2)
            .defense_strategy("adaptive".to_string())
            .random_state(42);

        let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let result = agl.fit_robust(features.view(), None);
        assert!(result.is_ok());

        let graph = result.unwrap();
        assert_eq!(graph.dim(), (3, 3));
    }

    #[test]
    fn test_error_cases() {
        let agl = AdversarialGraphLearning::new();

        // Test with empty features
        let empty_features = Array2::<f64>::zeros((0, 2));
        let result = agl.fit_robust(empty_features.view(), None);
        assert!(result.is_err());

        // Test with invalid defense strategy
        let agl_invalid =
            AdversarialGraphLearning::new().defense_strategy("invalid_strategy".to_string());

        let features = array![[1.0, 2.0]];
        let result = agl_invalid.fit_robust(features.view(), None);
        assert!(result.is_err());

        // Test robustness evaluation with mismatched dimensions
        let graph1 = Array2::<f64>::zeros((2, 2));
        let graph2 = Array2::<f64>::zeros((3, 3));
        let result = agl.evaluate_robustness(&graph1, &graph2);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_attack_types() {
        let agl = AdversarialGraphLearning::new();

        let features = array![[1.0, 2.0]];

        let invalid_attack = AdversarialAttack {
            attack_type: "invalid_attack".to_string(),
            attack_strength: 0.1,
            target_nodes: 1,
            perturbation_strategy: "random".to_string(),
        };

        let result = agl.apply_attack(features.view(), &invalid_attack);
        assert!(result.is_err());
    }
}