1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::random::rand_prelude::*;
9use scirs2_core::random::Random;
10use sklears_core::error::SklearsError;
11
/// Learns k-nearest-neighbour affinity graphs that are meant to stay stable
/// under small adversarial perturbations of the input features.
///
/// Configure an instance with the builder methods and build a graph with
/// `fit_robust`.
#[derive(Clone, Debug, PartialEq)]
pub struct AdversarialGraphLearning {
    /// Number of nearest neighbours each sample is connected to.
    pub k_neighbors: usize,
    /// Strength of the degree-balancing penalty applied by the spectral defense.
    pub robustness_lambda: f64,
    /// Perturbation budget: Huber threshold of the robust distance, kernel
    /// bandwidth of the spectral defense and noise amplitude of the consensus
    /// defense.
    pub max_perturbation: f64,
    /// Number of adversarial optimisation steps (not read by the current defenses).
    pub adversarial_steps: usize,
    /// Adversarial learning rate (not read by the current defenses).
    pub adversarial_lr: f64,
    /// Defense to run: `"spectral"`, `"robust_pca"`, `"consensus"` or `"adaptive"`.
    pub defense_strategy: String,
    /// Minimum averaged edge weight kept by the consensus defense.
    pub consensus_threshold: f64,
    /// Iteration cap (not read by the current defenses).
    pub max_iter: usize,
    /// Convergence tolerance (not read by the current defenses).
    pub tolerance: f64,
    /// Seed for the randomised routines; a fixed default seed is used when `None`.
    pub random_state: Option<u64>,
}
36
/// Description of an adversarial attack to simulate against a feature matrix.
#[derive(Clone, Debug, PartialEq)]
pub struct AdversarialAttack {
    /// Kind of attack. `"feature_perturbation"` is the only kind that can be
    /// applied to features; `"node_injection"` and `"edge_manipulation"` are
    /// recognised but rejected.
    pub attack_type: String,
    /// Magnitude of the perturbation applied to each targeted feature.
    pub attack_strength: f64,
    /// Number of samples (rows) to perturb; capped at the number of samples.
    pub target_nodes: usize,
    /// How perturbations are generated: `"random"`, `"gradient"` or
    /// `"targeted"`; any other value falls back to random noise.
    pub perturbation_strategy: String,
}
49
50impl AdversarialGraphLearning {
51 pub fn new() -> Self {
53 Self {
54 k_neighbors: 5,
55 robustness_lambda: 0.1,
56 max_perturbation: 0.1,
57 adversarial_steps: 10,
58 adversarial_lr: 0.01,
59 defense_strategy: "spectral".to_string(),
60 consensus_threshold: 0.7,
61 max_iter: 100,
62 tolerance: 1e-6,
63 random_state: None,
64 }
65 }
66
67 pub fn k_neighbors(mut self, k: usize) -> Self {
69 self.k_neighbors = k;
70 self
71 }
72
73 pub fn robustness_lambda(mut self, lambda: f64) -> Self {
75 self.robustness_lambda = lambda;
76 self
77 }
78
79 pub fn max_perturbation(mut self, max_pert: f64) -> Self {
81 self.max_perturbation = max_pert;
82 self
83 }
84
85 pub fn adversarial_steps(mut self, steps: usize) -> Self {
87 self.adversarial_steps = steps;
88 self
89 }
90
91 pub fn adversarial_lr(mut self, lr: f64) -> Self {
93 self.adversarial_lr = lr;
94 self
95 }
96
97 pub fn defense_strategy(mut self, strategy: String) -> Self {
99 self.defense_strategy = strategy;
100 self
101 }
102
103 pub fn consensus_threshold(mut self, threshold: f64) -> Self {
105 self.consensus_threshold = threshold;
106 self
107 }
108
109 pub fn max_iter(mut self, max_iter: usize) -> Self {
111 self.max_iter = max_iter;
112 self
113 }
114
115 pub fn tolerance(mut self, tol: f64) -> Self {
117 self.tolerance = tol;
118 self
119 }
120
121 pub fn random_state(mut self, seed: u64) -> Self {
123 self.random_state = Some(seed);
124 self
125 }
126
127 pub fn fit_robust(
129 &self,
130 features: ArrayView2<f64>,
131 labels: Option<ArrayView1<i32>>,
132 ) -> Result<Array2<f64>, SklearsError> {
133 let n_samples = features.nrows();
134
135 if n_samples == 0 {
136 return Err(SklearsError::InvalidInput(
137 "No samples provided".to_string(),
138 ));
139 }
140
141 match self.defense_strategy.as_str() {
142 "spectral" => self.spectral_defense(features, labels),
143 "robust_pca" => self.robust_pca_defense(features, labels),
144 "consensus" => self.consensus_defense(features, labels),
145 "adaptive" => self.adaptive_defense(features, labels),
146 _ => Err(SklearsError::InvalidInput(format!(
147 "Unknown defense strategy: {}",
148 self.defense_strategy
149 ))),
150 }
151 }
152
153 fn spectral_defense(
155 &self,
156 features: ArrayView2<f64>,
157 _labels: Option<ArrayView1<i32>>,
158 ) -> Result<Array2<f64>, SklearsError> {
159 let n_samples = features.nrows();
160 let mut adjacency = Array2::zeros((n_samples, n_samples));
161
162 for i in 0..n_samples {
164 let mut distances: Vec<(usize, f64)> = Vec::new();
165
166 for j in 0..n_samples {
167 if i != j {
168 let dist = self.compute_robust_distance(features.row(i), features.row(j));
169 distances.push((j, dist));
170 }
171 }
172
173 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
174 for &(neighbor, dist) in distances.iter().take(self.k_neighbors) {
175 let weight = (-dist / (2.0 * self.max_perturbation.powi(2))).exp();
176 adjacency[[i, neighbor]] = weight;
177 adjacency[[neighbor, i]] = weight;
178 }
179 }
180
181 self.apply_spectral_regularization(&mut adjacency)?;
183
184 Ok(adjacency)
185 }
186
187 fn robust_pca_defense(
189 &self,
190 features: ArrayView2<f64>,
191 _labels: Option<ArrayView1<i32>>,
192 ) -> Result<Array2<f64>, SklearsError> {
193 let n_samples = features.nrows();
194 let n_features = features.ncols();
195
196 let robust_mean = self.compute_robust_mean(features)?;
198 let robust_cov = self.compute_robust_covariance(features, &robust_mean)?;
199
200 let robust_features = self.robust_pca_projection(features, &robust_mean, &robust_cov)?;
202
203 let mut adjacency = Array2::zeros((n_samples, n_samples));
205
206 for i in 0..n_samples {
207 let mut distances: Vec<(usize, f64)> = Vec::new();
208
209 for j in 0..n_samples {
210 if i != j {
211 let dist = self.mahalanobis_distance(
212 robust_features.row(i),
213 robust_features.row(j),
214 &robust_cov,
215 )?;
216 distances.push((j, dist));
217 }
218 }
219
220 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
221 for &(neighbor, dist) in distances.iter().take(self.k_neighbors) {
222 let weight = (-dist).exp();
223 adjacency[[i, neighbor]] = weight;
224 adjacency[[neighbor, i]] = weight;
225 }
226 }
227
228 Ok(adjacency)
229 }
230
231 fn consensus_defense(
233 &self,
234 features: ArrayView2<f64>,
235 labels: Option<ArrayView1<i32>>,
236 ) -> Result<Array2<f64>, SklearsError> {
237 let n_samples = features.nrows();
238 let num_graphs = 5; let mut consensus_adjacency = Array2::zeros((n_samples, n_samples));
241 let mut rng = if let Some(seed) = self.random_state {
242 Random::seed(seed)
243 } else {
244 Random::seed(42)
245 };
246
247 for graph_idx in 0..num_graphs {
249 let mut perturbed_features = features.to_owned();
250
251 for i in 0..n_samples {
253 for j in 0..features.ncols() {
254 let noise = rng.random_range(-self.max_perturbation..self.max_perturbation);
255 perturbed_features[[i, j]] += noise;
256 }
257 }
258
259 let graph = self.build_knn_graph(perturbed_features.view())?;
261
262 consensus_adjacency = consensus_adjacency + graph;
264 }
265
266 consensus_adjacency /= num_graphs as f64;
268
269 consensus_adjacency.mapv_inplace(|x| {
271 if x >= self.consensus_threshold {
272 x
273 } else {
274 0.0
275 }
276 });
277
278 Ok(consensus_adjacency)
279 }
280
281 fn adaptive_defense(
283 &self,
284 features: ArrayView2<f64>,
285 labels: Option<ArrayView1<i32>>,
286 ) -> Result<Array2<f64>, SklearsError> {
287 let spectral_graph = self.spectral_defense(features, labels)?;
289 let consensus_graph = self.consensus_defense(features, labels)?;
290
291 let n_samples = features.nrows();
292 let mut adaptive_graph = Array2::zeros((n_samples, n_samples));
293
294 for i in 0..n_samples {
296 for j in 0..n_samples {
297 if i != j {
298 let spectral_weight = spectral_graph[[i, j]];
299 let consensus_weight = consensus_graph[[i, j]];
300
301 let agreement = (spectral_weight - consensus_weight).abs();
303 let confidence = (-agreement / self.max_perturbation).exp();
304
305 adaptive_graph[[i, j]] =
306 confidence * spectral_weight + (1.0 - confidence) * consensus_weight;
307 }
308 }
309 }
310
311 Ok(adaptive_graph)
312 }
313
314 pub fn apply_attack(
316 &self,
317 features: ArrayView2<f64>,
318 attack: &AdversarialAttack,
319 ) -> Result<Array2<f64>, SklearsError> {
320 let mut attacked_features = features.to_owned();
321 let n_samples = features.nrows();
322
323 let mut rng = if let Some(seed) = self.random_state {
324 Random::seed(seed)
325 } else {
326 Random::seed(42)
327 };
328
329 match attack.attack_type.as_str() {
330 "feature_perturbation" => {
331 let num_target_nodes = attack.target_nodes.min(n_samples);
332 let target_indices: Vec<usize> = (0..n_samples)
333 .choose_multiple(&mut rng, num_target_nodes)
334 .into_iter()
335 .collect();
336
337 for &node_idx in &target_indices {
338 for feature_idx in 0..features.ncols() {
339 let perturbation = match attack.perturbation_strategy.as_str() {
340 "random" => {
341 rng.random_range(-attack.attack_strength..attack.attack_strength)
342 }
343 "gradient" => {
344 self.compute_gradient_perturbation(features, node_idx, feature_idx)?
345 * attack.attack_strength
346 }
347 "targeted" => {
348 self.compute_targeted_perturbation(features, node_idx, feature_idx)?
349 * attack.attack_strength
350 }
351 _ => rng.random_range(-attack.attack_strength..attack.attack_strength),
352 };
353
354 attacked_features[[node_idx, feature_idx]] += perturbation;
355 }
356 }
357 }
358 "node_injection" => {
359 return Err(SklearsError::InvalidInput(
361 "Node injection not implemented in this context".to_string(),
362 ));
363 }
364 "edge_manipulation" => {
365 return Err(SklearsError::InvalidInput(
367 "Edge manipulation should be applied to adjacency matrix".to_string(),
368 ));
369 }
370 _ => {
371 return Err(SklearsError::InvalidInput(format!(
372 "Unknown attack type: {}",
373 attack.attack_type
374 )));
375 }
376 }
377
378 Ok(attacked_features)
379 }
380
381 fn compute_robust_distance(&self, feat1: ArrayView1<f64>, feat2: ArrayView1<f64>) -> f64 {
383 let delta = self.max_perturbation;
385
386 feat1
387 .iter()
388 .zip(feat2.iter())
389 .map(|(&a, &b)| {
390 let diff = (a - b).abs();
391 if diff <= delta {
392 0.5 * diff * diff
393 } else {
394 delta * (diff - 0.5 * delta)
395 }
396 })
397 .sum::<f64>()
398 .sqrt()
399 }
400
401 fn apply_spectral_regularization(
403 &self,
404 adjacency: &mut Array2<f64>,
405 ) -> Result<(), SklearsError> {
406 let n = adjacency.nrows();
407
408 let mut degree = Array1::zeros(n);
410 for i in 0..n {
411 degree[i] = adjacency.row(i).sum();
412 }
413
414 for i in 0..n {
416 for j in 0..n {
417 if i != j && adjacency[[i, j]] > 0.0 {
418 let degree_penalty =
420 (degree[i] - degree[j]).abs() / (degree[i] + degree[j] + 1e-8);
421 adjacency[[i, j]] *= 1.0 - self.robustness_lambda * degree_penalty;
422 }
423 }
424 }
425
426 Ok(())
427 }
428
429 fn compute_robust_mean(&self, features: ArrayView2<f64>) -> Result<Array1<f64>, SklearsError> {
431 let n_features = features.ncols();
432 let mut robust_mean = Array1::zeros(n_features);
433
434 for j in 0..n_features {
435 let mut column: Vec<f64> = features.column(j).to_vec();
436 column.sort_by(|a, b| a.partial_cmp(b).unwrap());
437
438 let median_idx = column.len() / 2;
439 robust_mean[j] = if column.len() % 2 == 0 {
440 (column[median_idx - 1] + column[median_idx]) / 2.0
441 } else {
442 column[median_idx]
443 };
444 }
445
446 Ok(robust_mean)
447 }
448
449 fn compute_robust_covariance(
451 &self,
452 features: ArrayView2<f64>,
453 robust_mean: &Array1<f64>,
454 ) -> Result<Array2<f64>, SklearsError> {
455 let n_features = features.ncols();
456 let mut robust_cov = Array2::eye(n_features);
457
458 for j in 0..n_features {
459 let mut deviations: Vec<f64> = features
460 .column(j)
461 .iter()
462 .map(|&x| (x - robust_mean[j]).abs())
463 .collect();
464
465 deviations.sort_by(|a, b| a.partial_cmp(b).unwrap());
466 let mad = deviations[deviations.len() / 2] * 1.4826; robust_cov[[j, j]] = mad * mad;
469 }
470
471 Ok(robust_cov)
472 }
473
474 fn robust_pca_projection(
476 &self,
477 features: ArrayView2<f64>,
478 robust_mean: &Array1<f64>,
479 robust_cov: &Array2<f64>,
480 ) -> Result<Array2<f64>, SklearsError> {
481 let n_samples = features.nrows();
482 let n_features = features.ncols();
483
484 let mut projected = Array2::zeros((n_samples, n_features));
486
487 for i in 0..n_samples {
488 for j in 0..n_features {
489 projected[[i, j]] = (features[[i, j]] - robust_mean[j]) / robust_cov[[j, j]].sqrt();
490 }
491 }
492
493 Ok(projected)
494 }
495
496 fn mahalanobis_distance(
498 &self,
499 feat1: ArrayView1<f64>,
500 feat2: ArrayView1<f64>,
501 cov: &Array2<f64>,
502 ) -> Result<f64, SklearsError> {
503 let diff: Array1<f64> = &feat1.to_owned() - &feat2.to_owned();
504
505 let mut distance = 0.0;
507 for (i, &d) in diff.iter().enumerate() {
508 distance += d * d / cov[[i, i]];
509 }
510
511 Ok(distance.sqrt())
512 }
513
514 fn build_knn_graph(&self, features: ArrayView2<f64>) -> Result<Array2<f64>, SklearsError> {
516 let n_samples = features.nrows();
517 let mut adjacency = Array2::zeros((n_samples, n_samples));
518
519 for i in 0..n_samples {
520 let mut distances: Vec<(usize, f64)> = Vec::new();
521
522 for j in 0..n_samples {
523 if i != j {
524 let dist = self.compute_robust_distance(features.row(i), features.row(j));
525 distances.push((j, dist));
526 }
527 }
528
529 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
530 for &(neighbor, dist) in distances.iter().take(self.k_neighbors) {
531 let weight = (-dist).exp();
532 adjacency[[i, neighbor]] = weight;
533 adjacency[[neighbor, i]] = weight;
534 }
535 }
536
537 Ok(adjacency)
538 }
539
540 fn compute_gradient_perturbation(
542 &self,
543 _features: ArrayView2<f64>,
544 _node_idx: usize,
545 _feature_idx: usize,
546 ) -> Result<f64, SklearsError> {
547 Ok(0.1) }
551
552 fn compute_targeted_perturbation(
554 &self,
555 _features: ArrayView2<f64>,
556 _node_idx: usize,
557 _feature_idx: usize,
558 ) -> Result<f64, SklearsError> {
559 Ok(0.05) }
563
564 pub fn evaluate_robustness(
566 &self,
567 original_graph: &Array2<f64>,
568 attacked_graph: &Array2<f64>,
569 ) -> Result<f64, SklearsError> {
570 if original_graph.dim() != attacked_graph.dim() {
571 return Err(SklearsError::ShapeMismatch {
572 expected: format!("{:?}", original_graph.dim()),
573 actual: format!("{:?}", attacked_graph.dim()),
574 });
575 }
576
577 let diff = original_graph - attacked_graph;
579 let frobenius_norm = diff.iter().map(|&x| x * x).sum::<f64>().sqrt();
580
581 let original_norm = original_graph.iter().map(|&x| x * x).sum::<f64>().sqrt();
583
584 if original_norm > 0.0 {
585 Ok(frobenius_norm / original_norm)
586 } else {
587 Ok(0.0)
588 }
589 }
590}
591
592impl Default for AdversarialGraphLearning {
593 fn default() -> Self {
594 Self::new()
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::array;

    /// Three evenly spaced 2-D points shared by the graph-construction tests.
    fn sample_features() -> Array2<f64> {
        array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]
    }

    /// Asserts that `graph` is `n x n`, has no self-loops and is symmetric.
    fn assert_valid_graph(graph: &Array2<f64>, n: usize) {
        assert_eq!(graph.dim(), (n, n));
        for i in 0..n {
            assert_eq!(graph[[i, i]], 0.0);
            for j in 0..n {
                assert_abs_diff_eq!(graph[[i, j]], graph[[j, i]], epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_adversarial_graph_learning_spectral() {
        let agl = AdversarialGraphLearning::new()
            .k_neighbors(2)
            .defense_strategy("spectral".to_string())
            .robustness_lambda(0.1);

        let features = sample_features();
        let graph = agl
            .fit_robust(features.view(), None)
            .expect("spectral defense should succeed");
        assert_valid_graph(&graph, 3);
    }

    #[test]
    fn test_adversarial_graph_learning_robust_pca() {
        let agl = AdversarialGraphLearning::new()
            .k_neighbors(2)
            .defense_strategy("robust_pca".to_string());

        let features = sample_features();
        let graph = agl
            .fit_robust(features.view(), None)
            .expect("robust_pca defense should succeed");
        assert_valid_graph(&graph, 3);
        assert!(graph.iter().all(|w| w.is_finite() && *w >= 0.0));
    }

    #[test]
    fn test_adversarial_graph_learning_consensus() {
        let agl = AdversarialGraphLearning::new()
            .k_neighbors(2)
            .defense_strategy("consensus".to_string())
            .consensus_threshold(0.5)
            .random_state(42);

        let features = sample_features();
        let graph = agl
            .fit_robust(features.view(), None)
            .expect("consensus defense should succeed");
        assert_valid_graph(&graph, 3);
    }

    #[test]
    fn test_consensus_is_deterministic_for_fixed_seed() {
        let agl = AdversarialGraphLearning::new()
            .k_neighbors(2)
            .defense_strategy("consensus".to_string())
            .random_state(7);

        let features = sample_features();
        let first = agl.fit_robust(features.view(), None).unwrap();
        let second = agl.fit_robust(features.view(), None).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn test_feature_perturbation_attack() {
        let agl = AdversarialGraphLearning::new().random_state(42);

        let features = sample_features();

        let attack = AdversarialAttack {
            attack_type: "feature_perturbation".to_string(),
            attack_strength: 0.1,
            target_nodes: 2,
            perturbation_strategy: "random".to_string(),
        };

        let attacked_features = agl
            .apply_attack(features.view(), &attack)
            .expect("feature perturbation should succeed");
        assert_eq!(attacked_features.dim(), features.dim());

        let changed_rows = (0..features.nrows())
            .filter(|&i| {
                features
                    .row(i)
                    .iter()
                    .zip(attacked_features.row(i).iter())
                    .any(|(a, b)| (a - b).abs() > 1e-10)
            })
            .count();
        assert!(changed_rows >= 1);
        assert!(changed_rows <= attack.target_nodes);

        // Each perturbation is bounded by the attack strength.
        for (a, b) in features.iter().zip(attacked_features.iter()) {
            assert!((a - b).abs() <= attack.attack_strength);
        }
    }

    #[test]
    fn test_robust_distance() {
        let agl = AdversarialGraphLearning::new().max_perturbation(0.1);

        let feat1 = array![1.0, 2.0];
        let feat2 = array![1.1, 2.1];

        let distance = agl.compute_robust_distance(feat1.view(), feat2.view());
        assert!(distance > 0.0);

        // Far-away points grow only linearly under the Huber loss, so the
        // robust distance stays below the Euclidean one.
        let feat3 = array![10.0, 20.0];
        let robust_distance = agl.compute_robust_distance(feat1.view(), feat3.view());
        let euclidean_distance =
            ((1.0_f64 - 10.0_f64).powi(2) + (2.0_f64 - 20.0_f64).powi(2)).sqrt();

        assert!(robust_distance < euclidean_distance);
    }

    #[test]
    fn test_robust_mean_computation() {
        let agl = AdversarialGraphLearning::new();

        // The last row is an outlier; the median of an even count averages the
        // two middle values, so it is unaffected by the outlier's magnitude.
        let features = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [100.0, 200.0]
        ];

        let robust_mean = agl.compute_robust_mean(features.view()).unwrap();

        assert_abs_diff_eq!(robust_mean[0], 2.5, epsilon = 1e-12);
        assert_abs_diff_eq!(robust_mean[1], 3.5, epsilon = 1e-12);
    }

    #[test]
    fn test_robustness_evaluation() {
        let agl = AdversarialGraphLearning::new();

        let original_graph = array![[0.0, 1.0, 0.5], [1.0, 0.0, 0.8], [0.5, 0.8, 0.0]];
        let attacked_graph = array![[0.0, 0.9, 0.4], [0.9, 0.0, 0.7], [0.4, 0.7, 0.0]];

        let robustness = agl
            .evaluate_robustness(&original_graph, &attacked_graph)
            .unwrap();

        // Six off-diagonal entries each drop by 0.1.
        let expected = (6.0 * 0.01_f64).sqrt() / (2.0 * (1.0 + 0.25 + 0.64_f64)).sqrt();
        assert_abs_diff_eq!(robustness, expected, epsilon = 1e-9);

        let unchanged = agl
            .evaluate_robustness(&original_graph, &original_graph)
            .unwrap();
        assert_abs_diff_eq!(unchanged, 0.0);
    }

    #[test]
    fn test_adaptive_defense() {
        let agl = AdversarialGraphLearning::new()
            .k_neighbors(2)
            .defense_strategy("adaptive".to_string())
            .random_state(42);

        let features = sample_features();
        let graph = agl
            .fit_robust(features.view(), None)
            .expect("adaptive defense should succeed");
        assert_valid_graph(&graph, 3);
    }

    #[test]
    fn test_error_cases() {
        let agl = AdversarialGraphLearning::new();

        let empty_features = Array2::<f64>::zeros((0, 2));
        assert!(agl.fit_robust(empty_features.view(), None).is_err());

        let agl_invalid =
            AdversarialGraphLearning::new().defense_strategy("invalid_strategy".to_string());
        let features = array![[1.0, 2.0]];
        assert!(agl_invalid.fit_robust(features.view(), None).is_err());

        let graph1 = Array2::<f64>::zeros((2, 2));
        let graph2 = Array2::<f64>::zeros((3, 3));
        assert!(agl.evaluate_robustness(&graph1, &graph2).is_err());
    }

    #[test]
    fn test_invalid_attack_types() {
        let agl = AdversarialGraphLearning::new();

        let features = array![[1.0, 2.0]];

        for attack_type in ["invalid_attack", "node_injection", "edge_manipulation"] {
            let attack = AdversarialAttack {
                attack_type: attack_type.to_string(),
                attack_strength: 0.1,
                target_nodes: 1,
                perturbation_strategy: "random".to_string(),
            };
            assert!(
                agl.apply_attack(features.view(), &attack).is_err(),
                "attack type {attack_type} should be rejected"
            );
        }
    }
}