1use std::fmt;
11
/// Errors produced while configuring or running adversarial attacks.
#[derive(Debug)]
pub enum AdversarialError {
    /// A vector/matrix dimension did not match its counterpart.
    DimensionMismatch { expected: usize, got: usize },
    /// The perturbation budget was not finite and strictly positive.
    InvalidEpsilon(f64),
    /// The per-iteration step size was not finite and strictly positive.
    InvalidStepSize(f64),
    /// The iteration count was zero.
    InvalidIterations(usize),
    /// The input gradient contained non-finite values.
    GradientComputationFailed(String),
}
30
31impl fmt::Display for AdversarialError {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 match self {
34 AdversarialError::DimensionMismatch { expected, got } => {
35 write!(f, "dimension mismatch: expected {expected} but got {got}")
36 }
37 AdversarialError::InvalidEpsilon(e) => {
38 write!(f, "epsilon must be strictly positive, got {e}")
39 }
40 AdversarialError::InvalidStepSize(s) => {
41 write!(f, "step_size must be strictly positive, got {s}")
42 }
43 AdversarialError::InvalidIterations(n) => write!(f, "n_steps must be >= 1, got {n}"),
44 AdversarialError::GradientComputationFailed(msg) => {
45 write!(f, "gradient computation failed: {msg}")
46 }
47 }
48 }
49}
50
// No wrapped source error, so the default `Error` trait methods suffice.
impl std::error::Error for AdversarialError {}
52
/// Norm under which perturbations are constrained and measured.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PerturbNorm {
    /// Maximum absolute component (L-infinity).
    LInf,
    /// Euclidean norm.
    L2,
    /// Sum of absolute components.
    L1,
}
67
/// Result of an attack: the original input, its adversarial counterpart,
/// and the difference between them.
#[derive(Debug, Clone)]
pub struct AdversarialExample {
    /// The unmodified input vector.
    pub original: Vec<f64>,
    /// The adversarially perturbed input (after clipping to the valid range).
    pub perturbed: Vec<f64>,
    /// Component-wise difference `perturbed - original`.
    pub perturbation: Vec<f64>,
    /// Norm of `perturbation`, measured under the attack's configured norm.
    pub perturbation_norm: f64,
    /// Number of attack iterations that produced this example.
    pub n_iterations: usize,
}
86
87impl AdversarialExample {
88 pub fn perturbation_linf(&self) -> f64 {
90 self.perturbation
91 .iter()
92 .map(|v| v.abs())
93 .fold(0.0_f64, f64::max)
94 }
95
96 pub fn perturbation_l2(&self) -> f64 {
98 self.perturbation.iter().map(|v| v * v).sum::<f64>().sqrt()
99 }
100
101 pub fn perturbation_l1(&self) -> f64 {
103 self.perturbation.iter().map(|v| v.abs()).sum()
104 }
105}
106
/// Differentiable loss used to drive an attack.
pub trait AttackLoss: Send + Sync {
    /// Scalar loss of `predictions` against `labels`.
    fn loss(&self, predictions: &[f64], labels: &[f64]) -> f64;

    /// Gradient of the loss with respect to `predictions`.
    fn grad(&self, predictions: &[f64], labels: &[f64]) -> Vec<f64>;
}
122
/// Softmax cross-entropy loss over raw logits (zero-sized, stateless).
pub struct CrossEntropyAttackLoss;
133
134impl CrossEntropyAttackLoss {
135 fn softmax(logits: &[f64]) -> Vec<f64> {
136 let max_val = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
137 let exp: Vec<f64> = logits.iter().map(|&z| (z - max_val).exp()).collect();
138 let sum: f64 = exp.iter().sum();
139 if sum == 0.0 {
140 vec![1.0 / logits.len() as f64; logits.len()]
141 } else {
142 exp.iter().map(|&e| e / sum).collect()
143 }
144 }
145}
146
147impl AttackLoss for CrossEntropyAttackLoss {
148 fn loss(&self, predictions: &[f64], labels: &[f64]) -> f64 {
149 let probs = Self::softmax(predictions);
150 const EPS: f64 = 1e-12;
151 -probs
152 .iter()
153 .zip(labels.iter())
154 .map(|(&p, &y)| y * (p + EPS).ln())
155 .sum::<f64>()
156 }
157
158 fn grad(&self, predictions: &[f64], labels: &[f64]) -> Vec<f64> {
159 let probs = Self::softmax(predictions);
160 probs
161 .iter()
162 .zip(labels.iter())
163 .map(|(&p, &y)| p - y)
164 .collect()
165 }
166}
167
/// Mean-squared-error loss (zero-sized, stateless).
pub struct MseAttackLoss;
177
178impl AttackLoss for MseAttackLoss {
179 fn loss(&self, predictions: &[f64], labels: &[f64]) -> f64 {
180 let n = predictions.len() as f64;
181 predictions
182 .iter()
183 .zip(labels.iter())
184 .map(|(&p, &y)| (p - y).powi(2))
185 .sum::<f64>()
186 / n
187 }
188
189 fn grad(&self, predictions: &[f64], labels: &[f64]) -> Vec<f64> {
190 let n = predictions.len() as f64;
191 predictions
192 .iter()
193 .zip(labels.iter())
194 .map(|(&p, &y)| 2.0 * (p - y) / n)
195 .collect()
196 }
197}
198
/// A model that attacks can query for predictions and input gradients.
///
/// Implementors must provide `forward`; `input_gradient` has a default
/// central-finite-difference implementation (two forward passes per input
/// dimension) that analytic implementations should override for speed and
/// exactness.
pub trait AttackModel: Send + Sync {
    /// Maps an input vector to an output (e.g. logits) vector.
    fn forward(&self, input: &[f64]) -> Vec<f64>;

    /// Vector-Jacobian product: gradient of
    /// `sum_k output_grad[k] * forward(input)[k]` with respect to `input`,
    /// estimated by central differences with probe step `H = 1e-5`.
    fn input_gradient(&self, input: &[f64], output_grad: &[f64]) -> Vec<f64> {
        const H: f64 = 1e-5;
        let mut grad_in = vec![0.0_f64; input.len()];
        let mut x_plus = input.to_vec();
        let mut x_minus = input.to_vec();
        for i in 0..input.len() {
            // Perturb only coordinate i in each probe vector.
            x_plus[i] = input[i] + H;
            x_minus[i] = input[i] - H;
            let f_plus = self.forward(&x_plus);
            let f_minus = self.forward(&x_minus);
            // Central difference per output, weighted by the upstream gradient.
            grad_in[i] = f_plus
                .iter()
                .zip(f_minus.iter())
                .zip(output_grad.iter())
                .map(|((&fp, &fm), &g)| g * (fp - fm) / (2.0 * H))
                .sum::<f64>();
            // Restore coordinate i before probing the next dimension.
            x_plus[i] = input[i];
            x_minus[i] = input[i];
        }
        grad_in
    }
}
238
/// Linear (affine) model: `output[k] = weights[k] . input + bias[k]`.
pub struct LinearAttackModel {
    /// Weight matrix; one row per output, all rows the same length.
    pub weights: Vec<Vec<f64>>,
    /// Per-output bias terms; same length as `weights`.
    pub bias: Vec<f64>,
}
250
251impl LinearAttackModel {
252 pub fn new(weights: Vec<Vec<f64>>, bias: Vec<f64>) -> Result<Self, AdversarialError> {
254 if weights.is_empty() || bias.is_empty() {
255 return Err(AdversarialError::DimensionMismatch {
256 expected: 1,
257 got: 0,
258 });
259 }
260 if weights.len() != bias.len() {
261 return Err(AdversarialError::DimensionMismatch {
262 expected: weights.len(),
263 got: bias.len(),
264 });
265 }
266 let n_inputs = weights[0].len();
267 for (i, row) in weights.iter().enumerate() {
268 if row.len() != n_inputs {
269 return Err(AdversarialError::DimensionMismatch {
270 expected: n_inputs,
271 got: row.len(),
272 });
273 }
274 let _ = i; }
276 Ok(Self { weights, bias })
277 }
278}
279
280impl AttackModel for LinearAttackModel {
281 fn forward(&self, input: &[f64]) -> Vec<f64> {
282 self.weights
283 .iter()
284 .zip(self.bias.iter())
285 .map(|(row, &b)| {
286 row.iter()
287 .zip(input.iter())
288 .map(|(&w, &x)| w * x)
289 .sum::<f64>()
290 + b
291 })
292 .collect()
293 }
294
295 fn input_gradient(&self, _input: &[f64], output_grad: &[f64]) -> Vec<f64> {
297 let n_inputs = self.weights[0].len();
298 let mut grad = vec![0.0_f64; n_inputs];
299 for (row, &g) in self.weights.iter().zip(output_grad.iter()) {
300 for (j, &w) in row.iter().enumerate() {
301 grad[j] += w * g;
302 }
303 }
304 grad
305 }
306}
307
/// Configuration shared by the FGSM and PGD attacks.
#[derive(Debug, Clone)]
pub struct AttackConfig {
    /// Perturbation budget (ball radius under `norm`); finite and > 0.
    pub epsilon: f64,
    /// Norm constraining the perturbation.
    pub norm: PerturbNorm,
    /// Number of PGD iterations (FGSM always takes a single step).
    pub n_steps: usize,
    /// Per-iteration step magnitude; finite and > 0.
    pub step_size: f64,
    /// Whether PGD starts from a random point inside the epsilon ball.
    pub random_start: bool,
    /// Lower bound to which perturbed inputs are clamped.
    pub clip_min: f64,
    /// Upper bound to which perturbed inputs are clamped.
    pub clip_max: f64,
}
330
331impl AttackConfig {
332 pub fn new(epsilon: f64) -> Result<Self, AdversarialError> {
337 if epsilon <= 0.0 || !epsilon.is_finite() {
338 return Err(AdversarialError::InvalidEpsilon(epsilon));
339 }
340 Ok(Self {
341 epsilon,
342 norm: PerturbNorm::LInf,
343 n_steps: 40,
344 step_size: epsilon / 4.0,
345 random_start: false,
346 clip_min: f64::NEG_INFINITY,
347 clip_max: f64::INFINITY,
348 })
349 }
350
351 pub fn with_norm(mut self, norm: PerturbNorm) -> Self {
353 self.norm = norm;
354 self
355 }
356
357 pub fn with_steps(mut self, n: usize) -> Result<Self, AdversarialError> {
359 if n == 0 {
360 return Err(AdversarialError::InvalidIterations(n));
361 }
362 self.n_steps = n;
363 Ok(self)
364 }
365
366 pub fn with_step_size(mut self, s: f64) -> Result<Self, AdversarialError> {
368 if s <= 0.0 || !s.is_finite() {
369 return Err(AdversarialError::InvalidStepSize(s));
370 }
371 self.step_size = s;
372 Ok(self)
373 }
374
375 pub fn with_random_start(mut self, b: bool) -> Self {
377 self.random_start = b;
378 self
379 }
380
381 pub fn with_clip(mut self, min: f64, max: f64) -> Self {
383 self.clip_min = min;
384 self.clip_max = max;
385 self
386 }
387}
388
/// Per-batch statistics reported by `adversarial_training_loss`.
#[derive(Debug, Default, Clone)]
pub struct AdversarialTrainStats {
    /// Number of samples in the batch.
    pub n_samples: usize,
    /// Mean perturbation norm of the generated adversarial examples.
    pub mean_perturbation_norm: f64,
    /// Mean loss on the unperturbed inputs.
    pub clean_loss: f64,
    /// Mean loss on the adversarially perturbed inputs.
    pub adversarial_loss: f64,
    /// `alpha * clean_loss + (1 - alpha) * adversarial_loss`.
    pub combined_loss: f64,
}
407
/// Projects `perturbation` onto the L-infinity ball of radius `epsilon`
/// by clamping every component into `[-epsilon, epsilon]`.
pub fn project_linf(perturbation: &[f64], epsilon: f64) -> Vec<f64> {
    let mut out = perturbation.to_vec();
    for d in &mut out {
        *d = d.clamp(-epsilon, epsilon);
    }
    out
}
421
/// Projects `perturbation` onto the L2 ball of radius `epsilon`.
/// Vectors already inside the ball (or exactly zero) are returned unchanged;
/// otherwise the vector is rescaled to length `epsilon`.
pub fn project_l2(perturbation: &[f64], epsilon: f64) -> Vec<f64> {
    let sq_sum: f64 = perturbation.iter().map(|&d| d * d).sum();
    let norm = sq_sum.sqrt();
    if norm <= epsilon || norm == 0.0 {
        return perturbation.to_vec();
    }
    perturbation.iter().map(|&d| d * epsilon / norm).collect()
}
433
/// Projects `perturbation` onto the L1 ball of radius `epsilon` using the
/// sort-based simplex projection: every component's magnitude is shrunk by a
/// common threshold `theta` (soft-thresholding), keeping original signs.
/// Vectors already inside the ball are returned unchanged.
pub fn project_l1(perturbation: &[f64], epsilon: f64) -> Vec<f64> {
    let l1: f64 = perturbation.iter().map(|&d| d.abs()).sum();
    if l1 <= epsilon {
        return perturbation.to_vec();
    }

    // Component magnitudes sorted in descending order.
    let mut mags: Vec<f64> = perturbation.iter().map(|&d| d.abs()).collect();
    mags.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));

    // rho = largest index whose magnitude still exceeds the running threshold.
    let mut running = 0.0_f64;
    let mut rho = 0_usize;
    for (i, &m) in mags.iter().enumerate() {
        running += m;
        if m > (running - epsilon) / (i as f64 + 1.0) {
            rho = i;
        }
    }
    let prefix: f64 = mags[..=rho].iter().sum();
    let theta = (prefix - epsilon) / (rho as f64 + 1.0);

    // Soft-threshold each component by theta, preserving its sign.
    perturbation
        .iter()
        .map(|&d| {
            let sign = if d >= 0.0 { 1.0 } else { -1.0 };
            sign * (d.abs() - theta).max(0.0)
        })
        .collect()
}
465
466fn loss_input_gradient(
472 model: &dyn AttackModel,
473 loss: &dyn AttackLoss,
474 input: &[f64],
475 labels: &[f64],
476) -> Result<Vec<f64>, AdversarialError> {
477 let predictions = model.forward(input);
478 let loss_grad = loss.grad(&predictions, labels); let input_grad = model.input_gradient(input, &loss_grad); for &g in &input_grad {
483 if !g.is_finite() {
484 return Err(AdversarialError::GradientComputationFailed(
485 "non-finite value in input gradient".to_string(),
486 ));
487 }
488 }
489 Ok(input_grad)
490}
491
492#[inline]
494fn clip_input(x: &[f64], config: &AttackConfig) -> Vec<f64> {
495 x.iter()
496 .map(|&v| v.clamp(config.clip_min, config.clip_max))
497 .collect()
498}
499
/// Dispatches to the projection matching the configured norm, using the
/// configured epsilon as the ball radius.
fn project(perturbation: &[f64], config: &AttackConfig) -> Vec<f64> {
    match config.norm {
        PerturbNorm::LInf => project_linf(perturbation, config.epsilon),
        PerturbNorm::L2 => project_l2(perturbation, config.epsilon),
        PerturbNorm::L1 => project_l1(perturbation, config.epsilon),
    }
}
508
509fn measure_norm(perturbation: &[f64], norm: PerturbNorm) -> f64 {
511 match norm {
512 PerturbNorm::LInf => perturbation
513 .iter()
514 .map(|&d| d.abs())
515 .fold(0.0_f64, f64::max),
516 PerturbNorm::L2 => perturbation.iter().map(|&d| d * d).sum::<f64>().sqrt(),
517 PerturbNorm::L1 => perturbation.iter().map(|&d| d.abs()).sum(),
518 }
519}
520
/// Minimal deterministic 64-bit linear congruential generator used for
/// reproducible random starts without an external RNG dependency.
struct Lcg64 {
    // Current generator state; `new` remaps a zero seed to a fixed constant.
    state: u64,
}
529
impl Lcg64 {
    /// Creates a generator; a zero seed is remapped to a fixed non-zero
    /// constant so distinct behavior for seed 0 is still deterministic.
    fn new(seed: u64) -> Self {
        Self {
            state: if seed == 0 { 0xdeadbeef_cafebabe } else { seed },
        }
    }

    /// Advances the LCG one step and returns the new state.
    /// Multiplier/increment are Knuth's MMIX LCG constants.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Uniform sample in [-1, 1): takes the top 53 bits (a full f64
    /// mantissa) to get u in [0, 1), then maps linearly to [-1, 1).
    fn next_f64_signed(&mut self) -> f64 {
        let u = (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64;
        u * 2.0 - 1.0
    }
}
554
555pub fn fgsm(
567 model: &dyn AttackModel,
568 loss: &dyn AttackLoss,
569 input: &[f64],
570 labels: &[f64],
571 config: &AttackConfig,
572) -> Result<AdversarialExample, AdversarialError> {
573 let grad = loss_input_gradient(model, loss, input, labels)?;
574
575 let raw_delta: Vec<f64> = match config.norm {
576 PerturbNorm::LInf => grad
577 .iter()
578 .map(|&g| {
579 if g == 0.0 {
580 0.0
581 } else {
582 config.epsilon * g.signum()
583 }
584 })
585 .collect(),
586 PerturbNorm::L2 => {
587 let norm: f64 = grad.iter().map(|&g| g * g).sum::<f64>().sqrt();
588 if norm < 1e-12 {
589 vec![0.0; grad.len()]
590 } else {
591 grad.iter().map(|&g| config.epsilon * g / norm).collect()
592 }
593 }
594 PerturbNorm::L1 => {
595 let argmax = grad
597 .iter()
598 .enumerate()
599 .max_by(|(_, a), (_, b)| {
600 a.abs()
601 .partial_cmp(&b.abs())
602 .unwrap_or(std::cmp::Ordering::Equal)
603 })
604 .map(|(i, _)| i)
605 .unwrap_or(0);
606 let mut delta = vec![0.0_f64; grad.len()];
607 delta[argmax] = config.epsilon * grad[argmax].signum();
608 delta
609 }
610 };
611
612 let perturbed_raw: Vec<f64> = input
613 .iter()
614 .zip(raw_delta.iter())
615 .map(|(&x, &d)| x + d)
616 .collect();
617 let perturbed = clip_input(&perturbed_raw, config);
618
619 let perturbation: Vec<f64> = perturbed
620 .iter()
621 .zip(input.iter())
622 .map(|(&p, &x)| p - x)
623 .collect();
624
625 let perturbation_norm = measure_norm(&perturbation, config.norm);
626
627 Ok(AdversarialExample {
628 original: input.to_vec(),
629 perturbed,
630 perturbation,
631 perturbation_norm,
632 n_iterations: 1,
633 })
634}
635
/// Projected gradient descent attack: iteratively takes steps of
/// `config.step_size` in the loss-increasing direction under the configured
/// norm, clipping each iterate into `[clip_min, clip_max]` and projecting
/// the accumulated perturbation back onto the epsilon ball.
///
/// `seed` drives the deterministic LCG used for the optional random start,
/// so results are reproducible for a fixed seed.
///
/// # Errors
/// Propagates `GradientComputationFailed` from gradient evaluation.
pub fn pgd(
    model: &dyn AttackModel,
    loss: &dyn AttackLoss,
    input: &[f64],
    labels: &[f64],
    config: &AttackConfig,
    seed: u64,
) -> Result<AdversarialExample, AdversarialError> {
    let n = input.len();
    let mut rng = Lcg64::new(seed);

    // Optional random initialization, projected into the epsilon ball.
    let mut delta: Vec<f64> = if config.random_start {
        let raw: Vec<f64> = (0..n)
            .map(|_| rng.next_f64_signed() * config.epsilon)
            .collect();
        project(&raw, config)
    } else {
        vec![0.0_f64; n]
    };

    for _ in 0..config.n_steps {
        // Current adversarial candidate, clipped into the valid range.
        let x_adv: Vec<f64> = input
            .iter()
            .zip(delta.iter())
            .map(|(&x, &d)| x + d)
            .collect();
        let x_adv = clip_input(&x_adv, config);

        let grad = loss_input_gradient(model, loss, &x_adv, labels)?;

        // Step direction scaled by step_size under the chosen norm
        // (sign step for L-inf, normalized gradient for L2, single
        // largest-magnitude coordinate for L1).
        let step: Vec<f64> = match config.norm {
            PerturbNorm::LInf => grad
                .iter()
                .map(|&g| {
                    if g == 0.0 {
                        0.0
                    } else {
                        config.step_size * g.signum()
                    }
                })
                .collect(),
            PerturbNorm::L2 => {
                let norm: f64 = grad.iter().map(|&g| g * g).sum::<f64>().sqrt();
                if norm < 1e-12 {
                    vec![0.0; n]
                } else {
                    grad.iter().map(|&g| config.step_size * g / norm).collect()
                }
            }
            PerturbNorm::L1 => {
                let argmax = grad
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| {
                        a.abs()
                            .partial_cmp(&b.abs())
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .map(|(i, _)| i)
                    .unwrap_or(0);
                let mut s = vec![0.0_f64; n];
                s[argmax] = config.step_size * grad[argmax].signum();
                s
            }
        };

        // Apply the step, clip to the valid range, then re-derive delta
        // from the clipped iterate and project it back onto the ball.
        let new_x_adv: Vec<f64> = input
            .iter()
            .zip(delta.iter())
            .zip(step.iter())
            .map(|((&x, &d), &s)| x + d + s)
            .collect();
        let new_x_adv = clip_input(&new_x_adv, config);

        delta = new_x_adv
            .iter()
            .zip(input.iter())
            .map(|(&xa, &x)| xa - x)
            .collect();
        delta = project(&delta, config);
    }

    // Final clamp: projection happens after clipping inside the loop, so the
    // last projected delta may need one more clip into the valid range.
    let perturbed: Vec<f64> = input
        .iter()
        .zip(delta.iter())
        .map(|(&x, &d)| (x + d).clamp(config.clip_min, config.clip_max))
        .collect();

    let perturbation: Vec<f64> = perturbed
        .iter()
        .zip(input.iter())
        .map(|(&p, &x)| p - x)
        .collect();

    let perturbation_norm = measure_norm(&perturbation, config.norm);

    Ok(AdversarialExample {
        original: input.to_vec(),
        perturbed,
        perturbation,
        perturbation_norm,
        n_iterations: config.n_steps,
    })
}
763
/// Mixed clean/adversarial training objective:
/// `combined = alpha * mean_clean_loss + (1 - alpha) * mean_adv_loss`,
/// where adversarial examples are generated per sample with PGD.
///
/// Each sample receives a distinct deterministic seed derived from `seed`
/// (fixed odd multiplier of the sample index), so results are reproducible.
/// An empty batch returns `(0.0, Default::default())`.
///
/// # Errors
/// `DimensionMismatch` when `inputs` and `labels` lengths differ;
/// propagates any error from `pgd`.
pub fn adversarial_training_loss(
    model: &dyn AttackModel,
    loss: &dyn AttackLoss,
    inputs: &[Vec<f64>],
    labels: &[Vec<f64>],
    config: &AttackConfig,
    alpha: f64,
    seed: u64,
) -> Result<(f64, AdversarialTrainStats), AdversarialError> {
    if inputs.is_empty() {
        return Ok((0.0, AdversarialTrainStats::default()));
    }
    if inputs.len() != labels.len() {
        return Err(AdversarialError::DimensionMismatch {
            expected: inputs.len(),
            got: labels.len(),
        });
    }

    let mut total_clean = 0.0_f64;
    let mut total_adv = 0.0_f64;
    let mut total_norm = 0.0_f64;
    let n = inputs.len();

    for (i, (x, y)) in inputs.iter().zip(labels.iter()).enumerate() {
        // Loss on the unperturbed input.
        let preds_clean = model.forward(x);
        total_clean += loss.loss(&preds_clean, y);

        // Per-sample seed keeps PGD random starts decorrelated across samples.
        let sample_seed = seed.wrapping_add((i as u64).wrapping_mul(0x9e3779b97f4a7c15));
        let adv_ex = pgd(model, loss, x, y, config, sample_seed)?;
        let preds_adv = model.forward(&adv_ex.perturbed);
        total_adv += loss.loss(&preds_adv, y);
        total_norm += adv_ex.perturbation_norm;
    }

    let mean_clean = total_clean / n as f64;
    let mean_adv = total_adv / n as f64;
    let combined = alpha * mean_clean + (1.0 - alpha) * mean_adv;

    let stats = AdversarialTrainStats {
        n_samples: n,
        mean_perturbation_norm: total_norm / n as f64,
        clean_loss: mean_clean,
        adversarial_loss: mean_adv,
        combined_loss: combined,
    };

    Ok((combined, stats))
}
828
/// Fraction of samples that are both correctly classified on the clean input
/// (argmax of predictions matches argmax of labels) and keep the same
/// predicted class after a PGD attack driven by a cross-entropy surrogate.
///
/// Misclassified clean samples are skipped but still counted in the
/// denominator, so they reduce the reported robustness. An empty batch
/// returns 1.0.
///
/// # Errors
/// `DimensionMismatch` when `inputs` and `labels` lengths differ;
/// propagates any error from `pgd`.
pub fn robustness_eval(
    model: &dyn AttackModel,
    inputs: &[Vec<f64>],
    labels: &[Vec<f64>],
    config: &AttackConfig,
    seed: u64,
) -> Result<f64, AdversarialError> {
    if inputs.is_empty() {
        return Ok(1.0);
    }
    if inputs.len() != labels.len() {
        return Err(AdversarialError::DimensionMismatch {
            expected: inputs.len(),
            got: labels.len(),
        });
    }

    let mut robust_count = 0_usize;
    let n = inputs.len();

    for (i, (x, y)) in inputs.iter().zip(labels.iter()).enumerate() {
        let clean_preds = model.forward(x);
        let clean_argmax = argmax_vec(&clean_preds);
        let label_argmax = argmax_vec(y);

        // Already wrong on the clean input: cannot be robustly correct.
        if clean_argmax != label_argmax {
            continue;
        }

        // Per-sample deterministic seed (fixed odd multiplier of the index).
        let sample_seed = seed.wrapping_add((i as u64).wrapping_mul(0x6c62272e07bb0142));
        let adv_ex = pgd(model, loss_for_eval(), x, y, config, sample_seed)?;
        let adv_preds = model.forward(&adv_ex.perturbed);
        let adv_argmax = argmax_vec(&adv_preds);

        // Robust = prediction unchanged under attack.
        if adv_argmax == clean_argmax {
            robust_count += 1;
        }
    }

    Ok(robust_count as f64 / n as f64)
}
885
/// Shared cross-entropy loss used by `robustness_eval`.
/// `CrossEntropyAttackLoss` is a stateless zero-sized type, so one static
/// instance suffices for all calls.
fn loss_for_eval() -> &'static CrossEntropyAttackLoss {
    static LOSS: CrossEntropyAttackLoss = CrossEntropyAttackLoss;
    &LOSS
}
891
/// Index of the largest element; 0 for an empty slice.
/// Ties (and incomparable NaN pairs) resolve to the later index, matching
/// `Iterator::max_by`, which keeps the incumbent only on a strict `Less`.
fn argmax_vec(v: &[f64]) -> usize {
    v.iter().enumerate().fold(0_usize, |best, (i, x)| {
        match x.partial_cmp(&v[best]).unwrap_or(std::cmp::Ordering::Equal) {
            std::cmp::Ordering::Less => best,
            _ => i,
        }
    })
}
900
// Unit tests: attacks on a 2x2 identity model, where forward(x) == x and
// gradients are analytically simple, plus projection and config validation.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // Identity weights with zero bias: predictions equal the input, so MSE
    // gradients below are easy to reason about by hand.
    fn identity_model_2x2() -> LinearAttackModel {
        LinearAttackModel::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])
            .expect("valid model")
    }

    // Default attack config: epsilon = 0.1, L-inf norm.
    fn default_config() -> AttackConfig {
        AttackConfig::new(0.1).expect("valid epsilon")
    }

    // ---- FGSM ----

    #[test]
    fn test_fgsm_linf_norm_le_epsilon() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.5, 0.5];
        let labels = vec![1.0, 0.0];
        let config = default_config();
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("fgsm ok");
        assert!(ex.perturbation_linf() <= config.epsilon + 1e-10);
    }

    #[test]
    fn test_fgsm_changes_input_when_gradient_nonzero() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.5, 0.3];
        let labels = vec![1.0, 0.0];
        let config = default_config();
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("fgsm ok");
        let norm: f64 = ex.perturbation.iter().map(|&d| d * d).sum::<f64>().sqrt();
        assert!(norm > 1e-10, "perturbation should be non-zero");
    }

    #[test]
    fn test_fgsm_zero_gradient_produces_zero_perturbation() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        // Labels equal predictions, so the MSE gradient is exactly zero.
        let input = vec![0.5, 0.5];
        let labels = vec![0.5, 0.5];
        let config = default_config();
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("fgsm ok");
        assert_abs_diff_eq!(ex.perturbation_linf(), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_fgsm_all_dims_within_epsilon() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.2, 0.8];
        let labels = vec![0.0, 1.0];
        let config = AttackConfig::new(0.05).expect("ok");
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("fgsm ok");
        for &d in &ex.perturbation {
            assert!(d.abs() <= 0.05 + 1e-10, "component {d} exceeds epsilon");
        }
    }

    // ---- PGD ----

    #[test]
    fn test_pgd_linf_norm_le_epsilon() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.5, 0.5];
        let labels = vec![1.0, 0.0];
        let config = default_config();
        let ex = pgd(&model, &loss, &input, &labels, &config, 42).expect("pgd ok");
        assert!(ex.perturbation_linf() <= config.epsilon + 1e-10);
    }

    #[test]
    fn test_pgd_n_steps_1_matches_fgsm_linf() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.3, 0.7];
        let labels = vec![1.0, 0.0];
        let eps = 0.1_f64;
        // One PGD step with step_size == epsilon is exactly an FGSM step.
        let config_fgsm = AttackConfig::new(eps)
            .expect("ok")
            .with_step_size(eps)
            .expect("ok")
            .with_steps(1)
            .expect("ok");
        let config_pgd = config_fgsm.clone();
        let ex_fgsm = fgsm(&model, &loss, &input, &labels, &config_fgsm).expect("ok");
        let ex_pgd = pgd(&model, &loss, &input, &labels, &config_pgd, 0).expect("ok");
        for (df, dp) in ex_fgsm.perturbation.iter().zip(ex_pgd.perturbation.iter()) {
            assert_abs_diff_eq!(df, dp, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_pgd_iterates_more_than_one() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.5, 0.5];
        let labels = vec![1.0, 0.0];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_steps(10)
            .expect("ok");
        let ex = pgd(&model, &loss, &input, &labels, &config, 7).expect("ok");
        assert_eq!(ex.n_iterations, 10);
    }

    // ---- Projections ----

    #[test]
    fn test_project_linf_clamps_each_dim() {
        let delta = vec![0.2, -0.3, 0.05, -0.01];
        let eps = 0.1;
        let proj = project_linf(&delta, eps);
        for &d in &proj {
            assert!(d >= -eps - 1e-12 && d <= eps + 1e-12);
        }
        assert_abs_diff_eq!(proj[0], 0.1, epsilon = 1e-10);
        assert_abs_diff_eq!(proj[1], -0.1, epsilon = 1e-10);
        assert_abs_diff_eq!(proj[2], 0.05, epsilon = 1e-10);
    }

    #[test]
    fn test_project_l2_result_within_epsilon() {
        let delta = vec![0.3, 0.4];
        let eps = 0.1;
        let proj = project_l2(&delta, eps);
        let norm: f64 = proj.iter().map(|&d| d * d).sum::<f64>().sqrt();
        assert!(norm <= eps + 1e-10, "L2 norm {norm} exceeds epsilon {eps}");
    }

    #[test]
    fn test_project_l2_identity_when_within_ball() {
        let delta = vec![0.03, 0.04];
        let eps = 0.1;
        let proj = project_l2(&delta, eps);
        assert_abs_diff_eq!(proj[0], 0.03, epsilon = 1e-10);
        assert_abs_diff_eq!(proj[1], 0.04, epsilon = 1e-10);
    }

    // ---- Losses ----

    #[test]
    fn test_cross_entropy_grad_finite_difference() {
        let ce = CrossEntropyAttackLoss;
        let preds = vec![1.0, 0.5, -0.5];
        let labels = vec![1.0, 0.0, 0.0];
        let grad = ce.grad(&preds, &labels);
        let h = 1e-5_f64;
        // Central finite differences should match the analytic gradient.
        for i in 0..preds.len() {
            let mut p_plus = preds.clone();
            let mut p_minus = preds.clone();
            p_plus[i] += h;
            p_minus[i] -= h;
            let fd = (ce.loss(&p_plus, &labels) - ce.loss(&p_minus, &labels)) / (2.0 * h);
            assert_abs_diff_eq!(grad[i], fd, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_mse_loss_zero_for_equal_predictions_and_labels() {
        let mse = MseAttackLoss;
        let v = vec![0.1, 0.5, -0.3];
        assert_abs_diff_eq!(mse.loss(&v, &v), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_mse_grad_zero_for_equal_predictions_and_labels() {
        let mse = MseAttackLoss;
        let v = vec![0.1, 0.5, -0.3];
        let grad = mse.grad(&v, &v);
        for &g in &grad {
            assert_abs_diff_eq!(g, 0.0, epsilon = 1e-12);
        }
    }

    // ---- Linear model ----

    #[test]
    fn test_linear_model_forward_correct_dimension() {
        let model = identity_model_2x2();
        let preds = model.forward(&[0.3, 0.7]);
        assert_eq!(preds.len(), 2);
    }

    #[test]
    fn test_linear_model_forward_correct_values() {
        let model = identity_model_2x2();
        let preds = model.forward(&[0.3, 0.7]);
        assert_abs_diff_eq!(preds[0], 0.3, epsilon = 1e-12);
        assert_abs_diff_eq!(preds[1], 0.7, epsilon = 1e-12);
    }

    #[test]
    fn test_linear_model_input_gradient_finite_difference() {
        let model = LinearAttackModel::new(
            vec![vec![2.0, -1.0], vec![0.5, 3.0], vec![-1.0, 1.0]],
            vec![0.0, 0.0, 0.0],
        )
        .expect("ok");
        let input = vec![0.4, 0.6];
        let out_grad = vec![1.0, 0.0, 0.0];
        let analytic = model.input_gradient(&input, &out_grad);
        let h = 1e-5_f64;
        // Compare the analytic W^T * g against central finite differences.
        for j in 0..input.len() {
            let mut ip = input.clone();
            let mut im = input.clone();
            ip[j] += h;
            im[j] -= h;
            let fd: f64 = model
                .forward(&ip)
                .iter()
                .zip(model.forward(&im).iter())
                .zip(out_grad.iter())
                .map(|((&fp, &fm), &g)| g * (fp - fm) / (2.0 * h))
                .sum();
            assert_abs_diff_eq!(analytic[j], fd, epsilon = 1e-6);
        }
    }

    // ---- AdversarialExample invariants ----

    #[test]
    fn test_adversarial_example_perturbation_equals_diff() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.3, 0.7];
        let labels = vec![1.0, 0.0];
        let config = default_config();
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("ok");
        for (i, (&p, &o)) in ex.perturbed.iter().zip(ex.original.iter()).enumerate() {
            assert_abs_diff_eq!(ex.perturbation[i], p - o, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_adversarial_example_linf_le_epsilon() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.3, 0.7];
        let labels = vec![1.0, 0.0];
        let config = AttackConfig::new(0.05).expect("ok");
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("ok");
        assert!(ex.perturbation_linf() <= 0.05 + 1e-10);
    }

    // ---- Config validation ----

    #[test]
    fn test_attack_config_negative_epsilon_is_error() {
        let result = AttackConfig::new(-0.1);
        assert!(
            matches!(result, Err(AdversarialError::InvalidEpsilon(_))),
            "expected InvalidEpsilon"
        );
    }

    #[test]
    fn test_attack_config_zero_epsilon_is_error() {
        let result = AttackConfig::new(0.0);
        assert!(matches!(result, Err(AdversarialError::InvalidEpsilon(_))));
    }

    #[test]
    fn test_attack_config_zero_steps_is_error() {
        let result = AttackConfig::new(0.1).expect("ok").with_steps(0);
        assert!(matches!(
            result,
            Err(AdversarialError::InvalidIterations(0))
        ));
    }

    // ---- Adversarial training ----

    #[test]
    fn test_adversarial_training_loss_alpha_one_equals_clean_loss() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let inputs = vec![vec![0.5_f64, 0.5_f64]];
        let labels = vec![vec![1.0_f64, 0.0_f64]];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_steps(1)
            .expect("ok");
        let (combined, stats) =
            adversarial_training_loss(&model, &loss, &inputs, &labels, &config, 1.0, 0)
                .expect("ok");
        assert_abs_diff_eq!(combined, stats.clean_loss, epsilon = 1e-10);
    }

    #[test]
    fn test_adversarial_training_loss_alpha_zero_equals_adversarial_loss() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let inputs = vec![vec![0.5_f64, 0.5_f64]];
        let labels = vec![vec![1.0_f64, 0.0_f64]];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_steps(1)
            .expect("ok");
        let (combined, stats) =
            adversarial_training_loss(&model, &loss, &inputs, &labels, &config, 0.0, 0)
                .expect("ok");
        assert_abs_diff_eq!(combined, stats.adversarial_loss, epsilon = 1e-10);
    }

    // ---- Robustness evaluation ----

    #[test]
    fn test_robustness_eval_result_in_0_1() {
        let model = identity_model_2x2();
        let inputs = vec![
            vec![0.9_f64, 0.1_f64],
            vec![0.1_f64, 0.9_f64],
        ];
        let labels = vec![vec![1.0_f64, 0.0_f64], vec![0.0_f64, 1.0_f64]];
        let config = AttackConfig::new(0.05)
            .expect("ok")
            .with_steps(5)
            .expect("ok");
        let frac = robustness_eval(&model, &inputs, &labels, &config, 42).expect("ok");
        assert!(
            (0.0..=1.0).contains(&frac),
            "robustness fraction {frac} out of range"
        );
    }

    #[test]
    fn test_robustness_eval_empty_inputs() {
        let model = identity_model_2x2();
        let config = default_config();
        let frac = robustness_eval(&model, &[], &[], &config, 0).expect("ok");
        assert_abs_diff_eq!(frac, 1.0, epsilon = 1e-12);
    }

    // ---- Training statistics ----

    #[test]
    fn test_adversarial_train_stats_n_samples() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let inputs = vec![
            vec![0.5_f64, 0.5_f64],
            vec![0.2_f64, 0.8_f64],
            vec![0.7_f64, 0.3_f64],
        ];
        let labels = vec![
            vec![1.0_f64, 0.0_f64],
            vec![0.0_f64, 1.0_f64],
            vec![1.0_f64, 0.0_f64],
        ];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_steps(2)
            .expect("ok");
        let (_, stats) =
            adversarial_training_loss(&model, &loss, &inputs, &labels, &config, 0.5, 1)
                .expect("ok");
        assert_eq!(stats.n_samples, 3);
        assert!(stats.mean_perturbation_norm >= 0.0);
    }

    #[test]
    fn test_adversarial_train_stats_combined_loss_between_clean_and_adv() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let inputs = vec![vec![0.5_f64, 0.5_f64]];
        let labels = vec![vec![1.0_f64, 0.0_f64]];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_steps(3)
            .expect("ok");
        let alpha = 0.5;
        let (combined, stats) =
            adversarial_training_loss(&model, &loss, &inputs, &labels, &config, alpha, 99)
                .expect("ok");
        let expected = alpha * stats.clean_loss + (1.0 - alpha) * stats.adversarial_loss;
        assert_abs_diff_eq!(combined, expected, epsilon = 1e-10);
    }

    // ---- Misc attack variants ----

    #[test]
    fn test_pgd_random_start_stays_within_epsilon() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.5_f64, 0.5_f64];
        let labels = vec![1.0_f64, 0.0_f64];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_steps(5)
            .expect("ok")
            .with_random_start(true);
        let ex = pgd(&model, &loss, &input, &labels, &config, 12345).expect("ok");
        assert!(ex.perturbation_linf() <= 0.1 + 1e-10);
    }

    #[test]
    fn test_fgsm_l2_norm_attack() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.3, 0.7];
        let labels = vec![0.0, 1.0];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_norm(PerturbNorm::L2);
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("ok");
        assert!(ex.perturbation_l2() <= 0.1 + 1e-10);
    }

    #[test]
    fn test_fgsm_l1_norm_attack_single_nonzero_component() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.3, 0.7];
        let labels = vec![1.0, 0.0];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_norm(PerturbNorm::L1);
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("ok");
        let nonzero: Vec<f64> = ex
            .perturbation
            .iter()
            .cloned()
            .filter(|&d| d.abs() > 1e-12)
            .collect();
        assert_eq!(
            nonzero.len(),
            1,
            "L1 FGSM should perturb exactly one dimension"
        );
    }

    #[test]
    fn test_linear_model_construction_invalid_bias_len() {
        let result = LinearAttackModel::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0], // two weight rows but only one bias
        );
        assert!(result.is_err());
    }
}