// tensorlogic_train/adversarial.rs

//! Adversarial training utilities for TensorLogic.
//!
//! Provides FGSM (Fast Gradient Sign Method), PGD (Projected Gradient Descent),
//! adversarial example generation, adversarial training loss, and robustness evaluation.
//!
//! # References
//! - Goodfellow et al. (2014): "Explaining and Harnessing Adversarial Examples" (FGSM)
//! - Madry et al. (2017): "Towards Deep Learning Models Resistant to Adversarial Attacks" (PGD)
use std::fmt;
11
// ─────────────────────────────────────────────────────────────────────────────
// Error type
// ─────────────────────────────────────────────────────────────────────────────

/// Errors that can arise during adversarial attack construction or execution.
#[derive(Debug)]
pub enum AdversarialError {
    /// Input and label dimensions did not match what the model expects.
    DimensionMismatch { expected: usize, got: usize },
    /// The epsilon (perturbation budget) is not strictly positive or not finite.
    InvalidEpsilon(f64),
    /// The per-step step-size is not strictly positive or not finite.
    InvalidStepSize(f64),
    /// The number of PGD iterations must be at least 1.
    InvalidIterations(usize),
    /// Gradient computation produced a non-finite value or other failure.
    GradientComputationFailed(String),
}
30
31impl fmt::Display for AdversarialError {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        match self {
34            AdversarialError::DimensionMismatch { expected, got } => {
35                write!(f, "dimension mismatch: expected {expected} but got {got}")
36            }
37            AdversarialError::InvalidEpsilon(e) => {
38                write!(f, "epsilon must be strictly positive, got {e}")
39            }
40            AdversarialError::InvalidStepSize(s) => {
41                write!(f, "step_size must be strictly positive, got {s}")
42            }
43            AdversarialError::InvalidIterations(n) => write!(f, "n_steps must be >= 1, got {n}"),
44            AdversarialError::GradientComputationFailed(msg) => {
45                write!(f, "gradient computation failed: {msg}")
46            }
47        }
48    }
49}
50
51impl std::error::Error for AdversarialError {}
52
// ─────────────────────────────────────────────────────────────────────────────
// Norm type
// ─────────────────────────────────────────────────────────────────────────────

/// The norm used to measure and project the adversarial perturbation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PerturbNorm {
    /// L∞ constraint: max |δᵢ| ≤ ε (each component bounded independently).
    LInf,
    /// L2 constraint: ‖δ‖₂ ≤ ε (Euclidean ball).
    L2,
    /// L1 constraint: ‖δ‖₁ ≤ ε.
    L1,
}
67
// ─────────────────────────────────────────────────────────────────────────────
// AdversarialExample
// ─────────────────────────────────────────────────────────────────────────────

/// The result of running an adversarial attack on a single input.
#[derive(Debug, Clone)]
pub struct AdversarialExample {
    /// The clean (unperturbed) input.
    pub original: Vec<f64>,
    /// The perturbed input `original + perturbation`.
    pub perturbed: Vec<f64>,
    /// The additive perturbation δ = perturbed − original.
    pub perturbation: Vec<f64>,
    /// The actual norm of the perturbation (measured in the configured norm).
    pub perturbation_norm: f64,
    /// Number of attack iterations performed (1 for FGSM).
    pub n_iterations: usize,
}

impl AdversarialExample {
    /// L∞ norm of the perturbation: max |δᵢ| (0.0 for an empty perturbation).
    pub fn perturbation_linf(&self) -> f64 {
        let mut largest = 0.0_f64;
        for &d in &self.perturbation {
            largest = largest.max(d.abs());
        }
        largest
    }

    /// L2 norm of the perturbation: √(Σ δᵢ²).
    pub fn perturbation_l2(&self) -> f64 {
        let sum_sq: f64 = self.perturbation.iter().map(|d| d * d).sum();
        sum_sq.sqrt()
    }

    /// L1 norm of the perturbation: Σ |δᵢ|.
    pub fn perturbation_l1(&self) -> f64 {
        self.perturbation.iter().fold(0.0, |acc, d| acc + d.abs())
    }
}
106
// ─────────────────────────────────────────────────────────────────────────────
// AttackLoss trait
// ─────────────────────────────────────────────────────────────────────────────

/// A differentiable loss function used by attack algorithms.
///
/// Both `loss` and `grad` receive raw model outputs (logits or probabilities)
/// and target labels, and must be thread-safe (`Send + Sync`).
pub trait AttackLoss: Send + Sync {
    /// Compute the scalar loss value for `predictions` against `labels`.
    fn loss(&self, predictions: &[f64], labels: &[f64]) -> f64;

    /// Compute the gradient of the loss with respect to `predictions`.
    ///
    /// NOTE(review): implementations in this file zip with `labels`, so the
    /// returned vector's length is min(predictions.len(), labels.len()).
    fn grad(&self, predictions: &[f64], labels: &[f64]) -> Vec<f64>;
}
122
// ─────────────────────────────────────────────────────────────────────────────
// CrossEntropyAttackLoss
// ─────────────────────────────────────────────────────────────────────────────

/// Cross-entropy loss for multi-class classification attacks.
///
/// Applies softmax internally:
/// - loss = −Σ yᵢ · log(softmax(zᵢ) + ε)
/// - grad = softmax(z) − y
pub struct CrossEntropyAttackLoss;

impl CrossEntropyAttackLoss {
    /// Numerically stable softmax: shifts by the max logit before
    /// exponentiating so large logits cannot overflow to infinity.
    fn softmax(logits: &[f64]) -> Vec<f64> {
        let shift = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = logits.iter().map(|&z| (z - shift).exp()).collect();
        let total: f64 = exps.iter().sum();
        if total == 0.0 {
            // Degenerate case (e.g. all logits −∞): fall back to uniform.
            let uniform = 1.0 / logits.len() as f64;
            vec![uniform; logits.len()]
        } else {
            exps.into_iter().map(|e| e / total).collect()
        }
    }
}
146
147impl AttackLoss for CrossEntropyAttackLoss {
148    fn loss(&self, predictions: &[f64], labels: &[f64]) -> f64 {
149        let probs = Self::softmax(predictions);
150        const EPS: f64 = 1e-12;
151        -probs
152            .iter()
153            .zip(labels.iter())
154            .map(|(&p, &y)| y * (p + EPS).ln())
155            .sum::<f64>()
156    }
157
158    fn grad(&self, predictions: &[f64], labels: &[f64]) -> Vec<f64> {
159        let probs = Self::softmax(predictions);
160        probs
161            .iter()
162            .zip(labels.iter())
163            .map(|(&p, &y)| p - y)
164            .collect()
165    }
166}
167
168// ─────────────────────────────────────────────────────────────────────────────
169// MseAttackLoss
170// ─────────────────────────────────────────────────────────────────────────────
171
172/// Mean-squared-error loss for regression attacks.
173///
174/// - loss = mean((predictions − labels)²)
175/// - grad = 2 · (predictions − labels) / n
176pub struct MseAttackLoss;
177
178impl AttackLoss for MseAttackLoss {
179    fn loss(&self, predictions: &[f64], labels: &[f64]) -> f64 {
180        let n = predictions.len() as f64;
181        predictions
182            .iter()
183            .zip(labels.iter())
184            .map(|(&p, &y)| (p - y).powi(2))
185            .sum::<f64>()
186            / n
187    }
188
189    fn grad(&self, predictions: &[f64], labels: &[f64]) -> Vec<f64> {
190        let n = predictions.len() as f64;
191        predictions
192            .iter()
193            .zip(labels.iter())
194            .map(|(&p, &y)| 2.0 * (p - y) / n)
195            .collect()
196    }
197}
198
// ─────────────────────────────────────────────────────────────────────────────
// AttackModel trait
// ─────────────────────────────────────────────────────────────────────────────

/// A model that can be attacked.
///
/// Implementors provide a forward pass; `input_gradient` has a default
/// central-finite-difference implementation that can be overridden for
/// efficiency (e.g. with an analytical gradient).
pub trait AttackModel: Send + Sync {
    /// Forward pass: given an input slice, return predictions (logits or probs).
    fn forward(&self, input: &[f64]) -> Vec<f64>;

    /// Gradient of the scalar `output_grad · forward(input)` w.r.t. input,
    /// via reverse-mode chain rule if available, otherwise via finite differences.
    ///
    /// `output_grad` has the same length as `forward(input)`.
    fn input_gradient(&self, input: &[f64], output_grad: &[f64]) -> Vec<f64> {
        // Default: central differences, perturbing one coordinate at a time.
        const H: f64 = 1e-5;
        let mut probe_hi = input.to_vec();
        let mut probe_lo = input.to_vec();
        let mut grad = Vec::with_capacity(input.len());
        for (i, &xi) in input.iter().enumerate() {
            probe_hi[i] = xi + H;
            probe_lo[i] = xi - H;
            let out_hi = self.forward(&probe_hi);
            let out_lo = self.forward(&probe_lo);
            // Chain rule: dot the per-output central difference with output_grad.
            let gi: f64 = out_hi
                .iter()
                .zip(out_lo.iter())
                .zip(output_grad.iter())
                .map(|((&hi, &lo), &g)| g * (hi - lo) / (2.0 * H))
                .sum();
            grad.push(gi);
            // Restore both probes before moving to the next coordinate.
            probe_hi[i] = xi;
            probe_lo[i] = xi;
        }
        grad
    }
}
238
239// ─────────────────────────────────────────────────────────────────────────────
240// LinearAttackModel
241// ─────────────────────────────────────────────────────────────────────────────
242
243/// A simple linear model `f(x) = W·x + b` used primarily for testing attacks.
244pub struct LinearAttackModel {
245    /// Weight matrix: `weights[i]` is the i-th output row (length = n_inputs).
246    pub weights: Vec<Vec<f64>>,
247    /// Bias vector (length = n_outputs).
248    pub bias: Vec<f64>,
249}
250
251impl LinearAttackModel {
252    /// Construct a new linear model, validating that all rows have the same length.
253    pub fn new(weights: Vec<Vec<f64>>, bias: Vec<f64>) -> Result<Self, AdversarialError> {
254        if weights.is_empty() || bias.is_empty() {
255            return Err(AdversarialError::DimensionMismatch {
256                expected: 1,
257                got: 0,
258            });
259        }
260        if weights.len() != bias.len() {
261            return Err(AdversarialError::DimensionMismatch {
262                expected: weights.len(),
263                got: bias.len(),
264            });
265        }
266        let n_inputs = weights[0].len();
267        for (i, row) in weights.iter().enumerate() {
268            if row.len() != n_inputs {
269                return Err(AdversarialError::DimensionMismatch {
270                    expected: n_inputs,
271                    got: row.len(),
272                });
273            }
274            let _ = i; // suppress unused warning
275        }
276        Ok(Self { weights, bias })
277    }
278}
279
280impl AttackModel for LinearAttackModel {
281    fn forward(&self, input: &[f64]) -> Vec<f64> {
282        self.weights
283            .iter()
284            .zip(self.bias.iter())
285            .map(|(row, &b)| {
286                row.iter()
287                    .zip(input.iter())
288                    .map(|(&w, &x)| w * x)
289                    .sum::<f64>()
290                    + b
291            })
292            .collect()
293    }
294
295    /// Exact analytical gradient for a linear model: ∂(g·Wx)/∂x = Wᵀ·g.
296    fn input_gradient(&self, _input: &[f64], output_grad: &[f64]) -> Vec<f64> {
297        let n_inputs = self.weights[0].len();
298        let mut grad = vec![0.0_f64; n_inputs];
299        for (row, &g) in self.weights.iter().zip(output_grad.iter()) {
300            for (j, &w) in row.iter().enumerate() {
301                grad[j] += w * g;
302            }
303        }
304        grad
305    }
306}
307
// ─────────────────────────────────────────────────────────────────────────────
// AttackConfig
// ─────────────────────────────────────────────────────────────────────────────

/// Configuration for an adversarial attack.
#[derive(Debug, Clone)]
pub struct AttackConfig {
    /// Maximum allowed perturbation magnitude (ε > 0, finite).
    pub epsilon: f64,
    /// Norm used to constrain the perturbation.
    pub norm: PerturbNorm,
    /// Number of iterative steps (used by PGD; FGSM uses 1).
    pub n_steps: usize,
    /// Per-step size α.  Defaults to `epsilon / 4.0`.
    pub step_size: f64,
    /// If true (PGD), initialise with a random perturbation inside the ε-ball.
    pub random_start: bool,
    /// Minimum allowed value for the perturbed input (default −∞: no clipping).
    pub clip_min: f64,
    /// Maximum allowed value for the perturbed input (default +∞: no clipping).
    pub clip_max: f64,
}
330
331impl AttackConfig {
332    /// Create a new config with `epsilon` as the perturbation budget.
333    ///
334    /// Defaults: L∞ norm, 40 PGD steps, step_size = ε/4, no random start,
335    /// no input clipping.
336    pub fn new(epsilon: f64) -> Result<Self, AdversarialError> {
337        if epsilon <= 0.0 || !epsilon.is_finite() {
338            return Err(AdversarialError::InvalidEpsilon(epsilon));
339        }
340        Ok(Self {
341            epsilon,
342            norm: PerturbNorm::LInf,
343            n_steps: 40,
344            step_size: epsilon / 4.0,
345            random_start: false,
346            clip_min: f64::NEG_INFINITY,
347            clip_max: f64::INFINITY,
348        })
349    }
350
351    /// Override the perturbation norm.
352    pub fn with_norm(mut self, norm: PerturbNorm) -> Self {
353        self.norm = norm;
354        self
355    }
356
357    /// Override the number of PGD steps.  Must be ≥ 1.
358    pub fn with_steps(mut self, n: usize) -> Result<Self, AdversarialError> {
359        if n == 0 {
360            return Err(AdversarialError::InvalidIterations(n));
361        }
362        self.n_steps = n;
363        Ok(self)
364    }
365
366    /// Override the per-step size.  Must be strictly positive.
367    pub fn with_step_size(mut self, s: f64) -> Result<Self, AdversarialError> {
368        if s <= 0.0 || !s.is_finite() {
369            return Err(AdversarialError::InvalidStepSize(s));
370        }
371        self.step_size = s;
372        Ok(self)
373    }
374
375    /// Enable or disable random initialisation of the perturbation.
376    pub fn with_random_start(mut self, b: bool) -> Self {
377        self.random_start = b;
378        self
379    }
380
381    /// Set the input clipping range [min, max].
382    pub fn with_clip(mut self, min: f64, max: f64) -> Self {
383        self.clip_min = min;
384        self.clip_max = max;
385        self
386    }
387}
388
// ─────────────────────────────────────────────────────────────────────────────
// AdversarialTrainStats
// ─────────────────────────────────────────────────────────────────────────────

/// Summary statistics collected during adversarial training over a batch.
#[derive(Debug, Default, Clone)]
pub struct AdversarialTrainStats {
    /// Number of samples processed.
    pub n_samples: usize,
    /// Mean perturbation magnitude across the batch, measured in the configured norm.
    pub mean_perturbation_norm: f64,
    /// Mean clean (unperturbed) loss across the batch.
    pub clean_loss: f64,
    /// Mean adversarial (perturbed) loss across the batch.
    pub adversarial_loss: f64,
    /// Combined loss: α · clean + (1−α) · adversarial.
    pub combined_loss: f64,
}
407
// ─────────────────────────────────────────────────────────────────────────────
// Projection helpers
// ─────────────────────────────────────────────────────────────────────────────

/// Project `perturbation` onto the L∞ ball of radius `epsilon`.
///
/// Each component is clamped independently to [−ε, ε].
pub fn project_linf(perturbation: &[f64], epsilon: f64) -> Vec<f64> {
    let mut projected = perturbation.to_vec();
    for d in &mut projected {
        *d = d.clamp(-epsilon, epsilon);
    }
    projected
}
421
/// Project `perturbation` onto the L2 ball of radius `epsilon`.
///
/// If ‖δ‖₂ > ε, the vector is scaled down to have norm exactly ε.
pub fn project_l2(perturbation: &[f64], epsilon: f64) -> Vec<f64> {
    let norm = perturbation.iter().map(|&d| d * d).sum::<f64>().sqrt();
    // Already inside the ball, or the zero vector: identity.
    if norm <= epsilon || norm == 0.0 {
        return perturbation.to_vec();
    }
    // Keep the original `d * epsilon / norm` operation order for bit-identical
    // rounding behaviour.
    perturbation.iter().map(|&d| d * epsilon / norm).collect()
}
433
/// Project `perturbation` onto the L1 ball of radius `epsilon`.
///
/// Uses the classic Duchi et al. (2008) algorithm via sorting of absolute values.
pub fn project_l1(perturbation: &[f64], epsilon: f64) -> Vec<f64> {
    let total_abs: f64 = perturbation.iter().map(|&d| d.abs()).sum();
    if total_abs <= epsilon {
        // Already inside the ball: identity.
        return perturbation.to_vec();
    }

    // Sort the magnitudes in descending order to locate the soft-threshold θ.
    let mut magnitudes: Vec<f64> = perturbation.iter().map(|&d| d.abs()).collect();
    magnitudes.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));

    // rho = largest index whose magnitude exceeds the running threshold; the
    // qualifying indices form a prefix, so keeping the last update is correct.
    let mut running = 0.0_f64;
    let mut rho = 0_usize;
    for (i, &m) in magnitudes.iter().enumerate() {
        running += m;
        if m > (running - epsilon) / (i as f64 + 1.0) {
            rho = i;
        }
    }
    let head_sum: f64 = magnitudes[..=rho].iter().sum();
    let theta = (head_sum - epsilon) / (rho as f64 + 1.0);

    // Soft-threshold each component toward zero by θ, preserving its sign.
    perturbation
        .iter()
        .map(|&d| {
            let sign = if d >= 0.0 { 1.0 } else { -1.0 };
            sign * (d.abs() - theta).max(0.0)
        })
        .collect()
}
465
466// ─────────────────────────────────────────────────────────────────────────────
467// Internal helpers
468// ─────────────────────────────────────────────────────────────────────────────
469
470/// Compute ∇_x L(f(x), y) = J_x^T · ∇_z L(f(x), y).
471fn loss_input_gradient(
472    model: &dyn AttackModel,
473    loss: &dyn AttackLoss,
474    input: &[f64],
475    labels: &[f64],
476) -> Result<Vec<f64>, AdversarialError> {
477    let predictions = model.forward(input);
478    let loss_grad = loss.grad(&predictions, labels); // ∂L/∂z
479    let input_grad = model.input_gradient(input, &loss_grad); // ∂L/∂x
480
481    // Validate that all values are finite.
482    for &g in &input_grad {
483        if !g.is_finite() {
484            return Err(AdversarialError::GradientComputationFailed(
485                "non-finite value in input gradient".to_string(),
486            ));
487        }
488    }
489    Ok(input_grad)
490}
491
492/// Clip `x` component-wise to the configured [clip_min, clip_max] range.
493#[inline]
494fn clip_input(x: &[f64], config: &AttackConfig) -> Vec<f64> {
495    x.iter()
496        .map(|&v| v.clamp(config.clip_min, config.clip_max))
497        .collect()
498}
499
/// Project a perturbation δ onto the ε-ball determined by the configured norm.
///
/// Thin dispatcher over the public `project_*` helpers above.
fn project(perturbation: &[f64], config: &AttackConfig) -> Vec<f64> {
    match config.norm {
        PerturbNorm::LInf => project_linf(perturbation, config.epsilon),
        PerturbNorm::L2 => project_l2(perturbation, config.epsilon),
        PerturbNorm::L1 => project_l1(perturbation, config.epsilon),
    }
}
508
509/// Measure the norm of `perturbation` under the configured `norm`.
510fn measure_norm(perturbation: &[f64], norm: PerturbNorm) -> f64 {
511    match norm {
512        PerturbNorm::LInf => perturbation
513            .iter()
514            .map(|&d| d.abs())
515            .fold(0.0_f64, f64::max),
516        PerturbNorm::L2 => perturbation.iter().map(|&d| d * d).sum::<f64>().sqrt(),
517        PerturbNorm::L1 => perturbation.iter().map(|&d| d.abs()).sum(),
518    }
519}
520
// ─────────────────────────────────────────────────────────────────────────────
// Minimal LCG PRNG (no external rand dependency)
// ─────────────────────────────────────────────────────────────────────────────

/// A simple 64-bit LCG (Knuth's constants) used only for `random_start`.
///
/// Deterministic for a given seed; not suitable for cryptographic use.
struct Lcg64 {
    /// Current generator state, updated on every draw.
    state: u64,
}

impl Lcg64 {
    /// Seed the generator; a zero seed is remapped to a fixed non-zero value.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 0xdeadbeef_cafebabe,
            s => s,
        };
        Self { state }
    }

    /// Advance and return the next u64.
    fn next_u64(&mut self) -> u64 {
        const MUL: u64 = 6_364_136_223_846_793_005;
        const INC: u64 = 1_442_695_040_888_963_407;
        self.state = self.state.wrapping_mul(MUL).wrapping_add(INC);
        self.state
    }

    /// Return a uniform f64 in (−1, 1).
    fn next_f64_signed(&mut self) -> f64 {
        // Top 53 bits give a uniform value in [0, 1); rescale to (−1, 1).
        let unit = (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64;
        unit * 2.0 - 1.0
    }
}
554
555// ─────────────────────────────────────────────────────────────────────────────
556// FGSM
557// ─────────────────────────────────────────────────────────────────────────────
558
559/// Fast Gradient Sign Method (Goodfellow et al., 2014).
560///
561/// Computes a single-step adversarial perturbation:
562///
563/// - L∞: δ = ε · sign(∇_x L)
564/// - L2:  δ = ε · ∇_x L / ‖∇_x L‖₂
565/// - L1:  δ = ε · e_k  where k = argmax |∂L/∂xᵢ|
566pub fn fgsm(
567    model: &dyn AttackModel,
568    loss: &dyn AttackLoss,
569    input: &[f64],
570    labels: &[f64],
571    config: &AttackConfig,
572) -> Result<AdversarialExample, AdversarialError> {
573    let grad = loss_input_gradient(model, loss, input, labels)?;
574
575    let raw_delta: Vec<f64> = match config.norm {
576        PerturbNorm::LInf => grad
577            .iter()
578            .map(|&g| {
579                if g == 0.0 {
580                    0.0
581                } else {
582                    config.epsilon * g.signum()
583                }
584            })
585            .collect(),
586        PerturbNorm::L2 => {
587            let norm: f64 = grad.iter().map(|&g| g * g).sum::<f64>().sqrt();
588            if norm < 1e-12 {
589                vec![0.0; grad.len()]
590            } else {
591                grad.iter().map(|&g| config.epsilon * g / norm).collect()
592            }
593        }
594        PerturbNorm::L1 => {
595            // Largest-coordinate attack: unit vector in the direction of max |gradient|.
596            let argmax = grad
597                .iter()
598                .enumerate()
599                .max_by(|(_, a), (_, b)| {
600                    a.abs()
601                        .partial_cmp(&b.abs())
602                        .unwrap_or(std::cmp::Ordering::Equal)
603                })
604                .map(|(i, _)| i)
605                .unwrap_or(0);
606            let mut delta = vec![0.0_f64; grad.len()];
607            delta[argmax] = config.epsilon * grad[argmax].signum();
608            delta
609        }
610    };
611
612    let perturbed_raw: Vec<f64> = input
613        .iter()
614        .zip(raw_delta.iter())
615        .map(|(&x, &d)| x + d)
616        .collect();
617    let perturbed = clip_input(&perturbed_raw, config);
618
619    let perturbation: Vec<f64> = perturbed
620        .iter()
621        .zip(input.iter())
622        .map(|(&p, &x)| p - x)
623        .collect();
624
625    let perturbation_norm = measure_norm(&perturbation, config.norm);
626
627    Ok(AdversarialExample {
628        original: input.to_vec(),
629        perturbed,
630        perturbation,
631        perturbation_norm,
632        n_iterations: 1,
633    })
634}
635
// ─────────────────────────────────────────────────────────────────────────────
// PGD
// ─────────────────────────────────────────────────────────────────────────────

/// Projected Gradient Descent (Madry et al., 2017).
///
/// Iterative attack with optional random initialisation:
///
/// ```text
/// x₀ = x + Uniform(−ε, ε)  [if random_start]
/// xₜ₊₁ = Proj_{Bε(x)}( clip( xₜ + α · step_direction ) )
/// ```
///
/// Step direction:
/// - L∞: sign(∇_x L)
/// - L2:  ∇_x L / ‖∇_x L‖₂
/// - L1:  argmax-coordinate (greedy Frank-Wolfe step)
///
/// `seed` is used only when `config.random_start = true`.
///
/// # Errors
///
/// Propagates [`AdversarialError::GradientComputationFailed`] from any
/// per-step gradient computation.
pub fn pgd(
    model: &dyn AttackModel,
    loss: &dyn AttackLoss,
    input: &[f64],
    labels: &[f64],
    config: &AttackConfig,
    seed: u64,
) -> Result<AdversarialExample, AdversarialError> {
    let n = input.len();
    let mut rng = Lcg64::new(seed);

    // Initialise δ: either uniform in the ε-box and projected (a componentwise
    // draw can fall outside the L2/L1 ball), or at zero.
    let mut delta: Vec<f64> = if config.random_start {
        let raw: Vec<f64> = (0..n)
            .map(|_| rng.next_f64_signed() * config.epsilon)
            .collect();
        project(&raw, config)
    } else {
        vec![0.0_f64; n]
    };

    for _ in 0..config.n_steps {
        // Construct current adversarial input (clipped to the valid range).
        let x_adv: Vec<f64> = input
            .iter()
            .zip(delta.iter())
            .map(|(&x, &d)| x + d)
            .collect();
        let x_adv = clip_input(&x_adv, config);

        let grad = loss_input_gradient(model, loss, &x_adv, labels)?;

        // Compute step direction, pre-scaled by step_size.
        let step: Vec<f64> = match config.norm {
            // Sign step; zero gradients are skipped because signum(0.0) = 1.0.
            PerturbNorm::LInf => grad
                .iter()
                .map(|&g| {
                    if g == 0.0 {
                        0.0
                    } else {
                        config.step_size * g.signum()
                    }
                })
                .collect(),
            // Normalised-gradient step, guarded against a vanishing gradient.
            PerturbNorm::L2 => {
                let norm: f64 = grad.iter().map(|&g| g * g).sum::<f64>().sqrt();
                if norm < 1e-12 {
                    vec![0.0; n]
                } else {
                    grad.iter().map(|&g| config.step_size * g / norm).collect()
                }
            }
            // Greedy coordinate step along the steepest component.
            PerturbNorm::L1 => {
                let argmax = grad
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| {
                        a.abs()
                            .partial_cmp(&b.abs())
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .map(|(i, _)| i)
                    .unwrap_or(0);
                let mut s = vec![0.0_f64; n];
                s[argmax] = config.step_size * grad[argmax].signum();
                s
            }
        };

        // Take the step in input space, clip, re-express as δ relative to the
        // clean input, then project δ back onto the ε-ball.
        let new_x_adv: Vec<f64> = input
            .iter()
            .zip(delta.iter())
            .zip(step.iter())
            .map(|((&x, &d), &s)| x + d + s)
            .collect();
        let new_x_adv = clip_input(&new_x_adv, config);

        delta = new_x_adv
            .iter()
            .zip(input.iter())
            .map(|(&xa, &x)| xa - x)
            .collect();
        delta = project(&delta, config);
    }

    // Final adversarial input: clip once more, then report the effective
    // (post-clipping) perturbation and its norm.
    let perturbed: Vec<f64> = input
        .iter()
        .zip(delta.iter())
        .map(|(&x, &d)| (x + d).clamp(config.clip_min, config.clip_max))
        .collect();

    let perturbation: Vec<f64> = perturbed
        .iter()
        .zip(input.iter())
        .map(|(&p, &x)| p - x)
        .collect();

    let perturbation_norm = measure_norm(&perturbation, config.norm);

    Ok(AdversarialExample {
        original: input.to_vec(),
        perturbed,
        perturbation,
        perturbation_norm,
        n_iterations: config.n_steps,
    })
}
763
764// ─────────────────────────────────────────────────────────────────────────────
765// Adversarial training loss
766// ─────────────────────────────────────────────────────────────────────────────
767
768/// Compute the combined adversarial training loss over a batch:
769///
770/// ```text
771/// L = α · L_clean(x, y)  +  (1−α) · L_adv(x+δ*, y)
772/// ```
773///
774/// where δ* is the PGD adversarial perturbation for each sample.
775///
776/// Returns the combined scalar loss and per-batch statistics.
777pub fn adversarial_training_loss(
778    model: &dyn AttackModel,
779    loss: &dyn AttackLoss,
780    inputs: &[Vec<f64>],
781    labels: &[Vec<f64>],
782    config: &AttackConfig,
783    alpha: f64,
784    seed: u64,
785) -> Result<(f64, AdversarialTrainStats), AdversarialError> {
786    if inputs.is_empty() {
787        return Ok((0.0, AdversarialTrainStats::default()));
788    }
789    if inputs.len() != labels.len() {
790        return Err(AdversarialError::DimensionMismatch {
791            expected: inputs.len(),
792            got: labels.len(),
793        });
794    }
795
796    let mut total_clean = 0.0_f64;
797    let mut total_adv = 0.0_f64;
798    let mut total_norm = 0.0_f64;
799    let n = inputs.len();
800
801    for (i, (x, y)) in inputs.iter().zip(labels.iter()).enumerate() {
802        // Clean loss.
803        let preds_clean = model.forward(x);
804        total_clean += loss.loss(&preds_clean, y);
805
806        // PGD adversarial example — vary seed per sample to avoid correlation.
807        let sample_seed = seed.wrapping_add((i as u64).wrapping_mul(0x9e3779b97f4a7c15));
808        let adv_ex = pgd(model, loss, x, y, config, sample_seed)?;
809        let preds_adv = model.forward(&adv_ex.perturbed);
810        total_adv += loss.loss(&preds_adv, y);
811        total_norm += adv_ex.perturbation_norm;
812    }
813
814    let mean_clean = total_clean / n as f64;
815    let mean_adv = total_adv / n as f64;
816    let combined = alpha * mean_clean + (1.0 - alpha) * mean_adv;
817
818    let stats = AdversarialTrainStats {
819        n_samples: n,
820        mean_perturbation_norm: total_norm / n as f64,
821        clean_loss: mean_clean,
822        adversarial_loss: mean_adv,
823        combined_loss: combined,
824    };
825
826    Ok((combined, stats))
827}
828
// ─────────────────────────────────────────────────────────────────────────────
// Robustness evaluation
// ─────────────────────────────────────────────────────────────────────────────

/// Evaluate the model's adversarial robustness on a set of samples.
///
/// For each sample that is correctly classified on the clean input (argmax of
/// the model output equals argmax of the label vector), a PGD attack is run
/// with a cross-entropy loss; the sample counts as robust when the argmax
/// prediction is unchanged on the perturbed input.  Samples already
/// misclassified on clean input count against robustness.
///
/// NOTE(review): this is a classification-only metric — the attack loss is
/// always cross-entropy and robustness is judged by argmax, so it does not
/// apply to regression models (an earlier doc claim to the contrary was
/// incorrect: no loss comparison is performed).
///
/// Returns the fraction of robust samples in \[0, 1\]; an empty input set
/// returns 1.0.
///
/// # Errors
///
/// Returns [`AdversarialError::DimensionMismatch`] if `inputs` and `labels`
/// have different lengths, and propagates any PGD failure.
pub fn robustness_eval(
    model: &dyn AttackModel,
    inputs: &[Vec<f64>],
    labels: &[Vec<f64>],
    config: &AttackConfig,
    seed: u64,
) -> Result<f64, AdversarialError> {
    if inputs.is_empty() {
        return Ok(1.0);
    }
    if inputs.len() != labels.len() {
        return Err(AdversarialError::DimensionMismatch {
            expected: inputs.len(),
            got: labels.len(),
        });
    }

    let mut robust_count = 0_usize;
    let n = inputs.len();

    for (i, (x, y)) in inputs.iter().zip(labels.iter()).enumerate() {
        let clean_preds = model.forward(x);
        let clean_argmax = argmax_vec(&clean_preds);
        let label_argmax = argmax_vec(y);

        // Only count samples that are correctly classified before the attack.
        if clean_argmax != label_argmax {
            // Misclassified even on clean input — not robust by definition
            // (still counted in the denominator below).
            continue;
        }

        // Per-sample seed: mix the index in with an odd multiplier so samples
        // do not share the same random start.
        let sample_seed = seed.wrapping_add((i as u64).wrapping_mul(0x6c62272e07bb0142));
        let adv_ex = pgd(model, loss_for_eval(), x, y, config, sample_seed)?;
        let adv_preds = model.forward(&adv_ex.perturbed);
        let adv_argmax = argmax_vec(&adv_preds);

        if adv_argmax == clean_argmax {
            robust_count += 1;
        }
    }

    Ok(robust_count as f64 / n as f64)
}
885
886/// Internal: build a cross-entropy loss instance for robustness evaluation.
887fn loss_for_eval() -> &'static CrossEntropyAttackLoss {
888    static LOSS: CrossEntropyAttackLoss = CrossEntropyAttackLoss;
889    &LOSS
890}
891
/// Return the index of the maximum element in `v`, or `0` for an empty slice.
///
/// Ties — including incomparable (NaN) pairs, which are treated as equal —
/// resolve in favour of the *later* index, matching `Iterator::max_by`.
fn argmax_vec(v: &[f64]) -> usize {
    use std::cmp::Ordering;

    let mut best_idx = 0_usize;
    for idx in 1..v.len() {
        // Keep the current best only when it is strictly greater; on Equal
        // (or a NaN comparison, mapped to Equal) the later element wins.
        match v[best_idx].partial_cmp(&v[idx]).unwrap_or(Ordering::Equal) {
            Ordering::Greater => {}
            _ => best_idx = idx,
        }
    }
    best_idx
}
900
901// ─────────────────────────────────────────────────────────────────────────────
902// Tests
903// ─────────────────────────────────────────────────────────────────────────────
904
#[cfg(test)]
mod tests {
    // Unit tests for the adversarial-attack utilities. Most tests use a tiny
    // 2x2 identity linear model, for which gradients are exact and attack
    // behavior is easy to reason about. Float comparisons carry a small
    // tolerance (typically 1e-10) to absorb rounding at projection bounds.
    use super::*;
    use approx::assert_abs_diff_eq;

    // ── Helpers ────────────────────────────────────────────────────────────────

    /// 2-class linear model with weights [[1,0],[0,1]] and zero bias.
    fn identity_model_2x2() -> LinearAttackModel {
        LinearAttackModel::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])
            .expect("valid model")
    }

    /// Default attack configuration: epsilon = 0.1, library defaults otherwise.
    fn default_config() -> AttackConfig {
        AttackConfig::new(0.1).expect("valid epsilon")
    }

    // ── FGSM tests ─────────────────────────────────────────────────────────────

    #[test]
    fn test_fgsm_linf_norm_le_epsilon() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.5, 0.5];
        let labels = vec![1.0, 0.0];
        let config = default_config();
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("fgsm ok");
        assert!(ex.perturbation_linf() <= config.epsilon + 1e-10);
    }

    #[test]
    fn test_fgsm_changes_input_when_gradient_nonzero() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.5, 0.3];
        let labels = vec![1.0, 0.0]; // gradient is non-zero
        let config = default_config();
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("fgsm ok");
        let norm: f64 = ex.perturbation.iter().map(|&d| d * d).sum::<f64>().sqrt();
        assert!(norm > 1e-10, "perturbation should be non-zero");
    }

    #[test]
    fn test_fgsm_zero_gradient_produces_zero_perturbation() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        // labels == predictions → MSE grad = 0
        let input = vec![0.5, 0.5];
        let labels = vec![0.5, 0.5];
        let config = default_config();
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("fgsm ok");
        assert_abs_diff_eq!(ex.perturbation_linf(), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_fgsm_all_dims_within_epsilon() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.2, 0.8];
        let labels = vec![0.0, 1.0];
        let config = AttackConfig::new(0.05).expect("ok");
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("fgsm ok");
        // Per-component check is stricter than the aggregated L∞ check above.
        for &d in &ex.perturbation {
            assert!(d.abs() <= 0.05 + 1e-10, "component {d} exceeds epsilon");
        }
    }

    // ── PGD tests ──────────────────────────────────────────────────────────────

    #[test]
    fn test_pgd_linf_norm_le_epsilon() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.5, 0.5];
        let labels = vec![1.0, 0.0];
        let config = default_config();
        let ex = pgd(&model, &loss, &input, &labels, &config, 42).expect("pgd ok");
        assert!(ex.perturbation_linf() <= config.epsilon + 1e-10);
    }

    #[test]
    fn test_pgd_n_steps_1_matches_fgsm_linf() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.3, 0.7];
        let labels = vec![1.0, 0.0];
        let eps = 0.1_f64;
        // Both should produce the same perturbation for a linear model (one step).
        // Step size is set to eps so a single PGD step saturates the budget,
        // exactly as FGSM does.
        let config_fgsm = AttackConfig::new(eps)
            .expect("ok")
            .with_step_size(eps)
            .expect("ok")
            .with_steps(1)
            .expect("ok");
        let config_pgd = config_fgsm.clone();
        let ex_fgsm = fgsm(&model, &loss, &input, &labels, &config_fgsm).expect("ok");
        let ex_pgd = pgd(&model, &loss, &input, &labels, &config_pgd, 0).expect("ok");
        for (df, dp) in ex_fgsm.perturbation.iter().zip(ex_pgd.perturbation.iter()) {
            assert_abs_diff_eq!(df, dp, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_pgd_iterates_more_than_one() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.5, 0.5];
        let labels = vec![1.0, 0.0];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_steps(10)
            .expect("ok");
        let ex = pgd(&model, &loss, &input, &labels, &config, 7).expect("ok");
        // The reported iteration count should match the configured step count.
        assert_eq!(ex.n_iterations, 10);
    }

    // ── Projection tests ───────────────────────────────────────────────────────

    #[test]
    fn test_project_linf_clamps_each_dim() {
        let delta = vec![0.2, -0.3, 0.05, -0.01];
        let eps = 0.1;
        let proj = project_linf(&delta, eps);
        for &d in &proj {
            assert!(d >= -eps - 1e-12 && d <= eps + 1e-12);
        }
        // Out-of-budget components clamp to ±eps; in-budget ones pass through.
        assert_abs_diff_eq!(proj[0], 0.1, epsilon = 1e-10);
        assert_abs_diff_eq!(proj[1], -0.1, epsilon = 1e-10);
        assert_abs_diff_eq!(proj[2], 0.05, epsilon = 1e-10);
    }

    #[test]
    fn test_project_l2_result_within_epsilon() {
        let delta = vec![0.3, 0.4]; // norm = 0.5
        let eps = 0.1;
        let proj = project_l2(&delta, eps);
        let norm: f64 = proj.iter().map(|&d| d * d).sum::<f64>().sqrt();
        assert!(norm <= eps + 1e-10, "L2 norm {norm} exceeds epsilon {eps}");
    }

    #[test]
    fn test_project_l2_identity_when_within_ball() {
        let delta = vec![0.03, 0.04]; // norm = 0.05 < 0.1
        let eps = 0.1;
        let proj = project_l2(&delta, eps);
        // Vectors already inside the ball must be returned unchanged.
        assert_abs_diff_eq!(proj[0], 0.03, epsilon = 1e-10);
        assert_abs_diff_eq!(proj[1], 0.04, epsilon = 1e-10);
    }

    // ── CrossEntropyAttackLoss tests ───────────────────────────────────────────

    #[test]
    fn test_cross_entropy_grad_finite_difference() {
        // Validate the analytic gradient against a central finite difference.
        let ce = CrossEntropyAttackLoss;
        let preds = vec![1.0, 0.5, -0.5];
        let labels = vec![1.0, 0.0, 0.0];
        let grad = ce.grad(&preds, &labels);
        let h = 1e-5_f64;
        for i in 0..preds.len() {
            let mut p_plus = preds.clone();
            let mut p_minus = preds.clone();
            p_plus[i] += h;
            p_minus[i] -= h;
            let fd = (ce.loss(&p_plus, &labels) - ce.loss(&p_minus, &labels)) / (2.0 * h);
            assert_abs_diff_eq!(grad[i], fd, epsilon = 1e-5);
        }
    }

    // ── MseAttackLoss tests ────────────────────────────────────────────────────

    #[test]
    fn test_mse_loss_zero_for_equal_predictions_and_labels() {
        let mse = MseAttackLoss;
        let v = vec![0.1, 0.5, -0.3];
        assert_abs_diff_eq!(mse.loss(&v, &v), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_mse_grad_zero_for_equal_predictions_and_labels() {
        let mse = MseAttackLoss;
        let v = vec![0.1, 0.5, -0.3];
        let grad = mse.grad(&v, &v);
        for &g in &grad {
            assert_abs_diff_eq!(g, 0.0, epsilon = 1e-12);
        }
    }

    // ── LinearAttackModel tests ────────────────────────────────────────────────

    #[test]
    fn test_linear_model_forward_correct_dimension() {
        let model = identity_model_2x2();
        let preds = model.forward(&[0.3, 0.7]);
        assert_eq!(preds.len(), 2);
    }

    #[test]
    fn test_linear_model_forward_correct_values() {
        // Identity weights with zero bias: forward() must echo the input.
        let model = identity_model_2x2();
        let preds = model.forward(&[0.3, 0.7]);
        assert_abs_diff_eq!(preds[0], 0.3, epsilon = 1e-12);
        assert_abs_diff_eq!(preds[1], 0.7, epsilon = 1e-12);
    }

    #[test]
    fn test_linear_model_input_gradient_finite_difference() {
        // 3-output × 2-input model.
        let model = LinearAttackModel::new(
            vec![vec![2.0, -1.0], vec![0.5, 3.0], vec![-1.0, 1.0]],
            vec![0.0, 0.0, 0.0],
        )
        .expect("ok");
        let input = vec![0.4, 0.6];
        let out_grad = vec![1.0, 0.0, 0.0]; // select first output
        let analytic = model.input_gradient(&input, &out_grad);
        // Verify against numerical FD (default impl).
        let h = 1e-5_f64;
        for j in 0..input.len() {
            let mut ip = input.clone();
            let mut im = input.clone();
            ip[j] += h;
            im[j] -= h;
            let fd: f64 = model
                .forward(&ip)
                .iter()
                .zip(model.forward(&im).iter())
                .zip(out_grad.iter())
                .map(|((&fp, &fm), &g)| g * (fp - fm) / (2.0 * h))
                .sum();
            assert_abs_diff_eq!(analytic[j], fd, epsilon = 1e-6);
        }
    }

    // ── AdversarialExample tests ───────────────────────────────────────────────

    #[test]
    fn test_adversarial_example_perturbation_equals_diff() {
        // Invariant: perturbation[i] == perturbed[i] - original[i].
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.3, 0.7];
        let labels = vec![1.0, 0.0];
        let config = default_config();
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("ok");
        for (i, (&p, &o)) in ex.perturbed.iter().zip(ex.original.iter()).enumerate() {
            assert_abs_diff_eq!(ex.perturbation[i], p - o, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_adversarial_example_linf_le_epsilon() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.3, 0.7];
        let labels = vec![1.0, 0.0];
        let config = AttackConfig::new(0.05).expect("ok");
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("ok");
        assert!(ex.perturbation_linf() <= 0.05 + 1e-10);
    }

    // ── AttackConfig validation tests ─────────────────────────────────────────

    #[test]
    fn test_attack_config_negative_epsilon_is_error() {
        let result = AttackConfig::new(-0.1);
        assert!(
            matches!(result, Err(AdversarialError::InvalidEpsilon(_))),
            "expected InvalidEpsilon"
        );
    }

    #[test]
    fn test_attack_config_zero_epsilon_is_error() {
        // Epsilon must be strictly positive; zero is rejected too.
        let result = AttackConfig::new(0.0);
        assert!(matches!(result, Err(AdversarialError::InvalidEpsilon(_))));
    }

    #[test]
    fn test_attack_config_zero_steps_is_error() {
        let result = AttackConfig::new(0.1).expect("ok").with_steps(0);
        assert!(matches!(
            result,
            Err(AdversarialError::InvalidIterations(0))
        ));
    }

    // ── adversarial_training_loss tests ───────────────────────────────────────

    #[test]
    fn test_adversarial_training_loss_alpha_one_equals_clean_loss() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let inputs = vec![vec![0.5_f64, 0.5_f64]];
        let labels = vec![vec![1.0_f64, 0.0_f64]];
        // 1 step PGD = FGSM-like; but alpha=1 should zero-out adv contribution.
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_steps(1)
            .expect("ok");
        let (combined, stats) =
            adversarial_training_loss(&model, &loss, &inputs, &labels, &config, 1.0, 0)
                .expect("ok");
        assert_abs_diff_eq!(combined, stats.clean_loss, epsilon = 1e-10);
    }

    #[test]
    fn test_adversarial_training_loss_alpha_zero_equals_adversarial_loss() {
        // alpha=0 is the opposite extreme: clean loss contributes nothing.
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let inputs = vec![vec![0.5_f64, 0.5_f64]];
        let labels = vec![vec![1.0_f64, 0.0_f64]];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_steps(1)
            .expect("ok");
        let (combined, stats) =
            adversarial_training_loss(&model, &loss, &inputs, &labels, &config, 0.0, 0)
                .expect("ok");
        assert_abs_diff_eq!(combined, stats.adversarial_loss, epsilon = 1e-10);
    }

    // ── robustness_eval test ───────────────────────────────────────────────────

    #[test]
    fn test_robustness_eval_result_in_0_1() {
        let model = identity_model_2x2();
        let inputs = vec![
            vec![0.9_f64, 0.1_f64], // predicts class 0
            vec![0.1_f64, 0.9_f64], // predicts class 1
        ];
        let labels = vec![vec![1.0_f64, 0.0_f64], vec![0.0_f64, 1.0_f64]];
        let config = AttackConfig::new(0.05)
            .expect("ok")
            .with_steps(5)
            .expect("ok");
        let frac = robustness_eval(&model, &inputs, &labels, &config, 42).expect("ok");
        assert!(
            (0.0..=1.0).contains(&frac),
            "robustness fraction {frac} out of range"
        );
    }

    #[test]
    fn test_robustness_eval_empty_inputs() {
        // An empty evaluation set is vacuously robust (fraction 1.0).
        let model = identity_model_2x2();
        let config = default_config();
        let frac = robustness_eval(&model, &[], &[], &config, 0).expect("ok");
        assert_abs_diff_eq!(frac, 1.0, epsilon = 1e-12);
    }

    // ── AdversarialTrainStats tests ────────────────────────────────────────────

    #[test]
    fn test_adversarial_train_stats_n_samples() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let inputs = vec![
            vec![0.5_f64, 0.5_f64],
            vec![0.2_f64, 0.8_f64],
            vec![0.7_f64, 0.3_f64],
        ];
        let labels = vec![
            vec![1.0_f64, 0.0_f64],
            vec![0.0_f64, 1.0_f64],
            vec![1.0_f64, 0.0_f64],
        ];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_steps(2)
            .expect("ok");
        let (_, stats) =
            adversarial_training_loss(&model, &loss, &inputs, &labels, &config, 0.5, 1)
                .expect("ok");
        assert_eq!(stats.n_samples, 3);
        assert!(stats.mean_perturbation_norm >= 0.0);
    }

    #[test]
    fn test_adversarial_train_stats_combined_loss_between_clean_and_adv() {
        // The combined loss must be the exact convex combination
        // alpha * clean + (1 - alpha) * adversarial.
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let inputs = vec![vec![0.5_f64, 0.5_f64]];
        let labels = vec![vec![1.0_f64, 0.0_f64]];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_steps(3)
            .expect("ok");
        let alpha = 0.5;
        let (combined, stats) =
            adversarial_training_loss(&model, &loss, &inputs, &labels, &config, alpha, 99)
                .expect("ok");
        let expected = alpha * stats.clean_loss + (1.0 - alpha) * stats.adversarial_loss;
        assert_abs_diff_eq!(combined, expected, epsilon = 1e-10);
    }

    // ── Additional coverage ────────────────────────────────────────────────────

    #[test]
    fn test_pgd_random_start_stays_within_epsilon() {
        // Random initialization must not push the perturbation outside the
        // epsilon ball at any point.
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.5_f64, 0.5_f64];
        let labels = vec![1.0_f64, 0.0_f64];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_steps(5)
            .expect("ok")
            .with_random_start(true);
        let ex = pgd(&model, &loss, &input, &labels, &config, 12345).expect("ok");
        assert!(ex.perturbation_linf() <= 0.1 + 1e-10);
    }

    #[test]
    fn test_fgsm_l2_norm_attack() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.3, 0.7];
        let labels = vec![0.0, 1.0];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_norm(PerturbNorm::L2);
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("ok");
        assert!(ex.perturbation_l2() <= 0.1 + 1e-10);
    }

    #[test]
    fn test_fgsm_l1_norm_attack_single_nonzero_component() {
        let model = identity_model_2x2();
        let loss = MseAttackLoss;
        let input = vec![0.3, 0.7];
        let labels = vec![1.0, 0.0];
        let config = AttackConfig::new(0.1)
            .expect("ok")
            .with_norm(PerturbNorm::L1);
        let ex = fgsm(&model, &loss, &input, &labels, &config).expect("ok");
        // L1 FGSM puts all budget on one coordinate.
        let nonzero: Vec<f64> = ex
            .perturbation
            .iter()
            .cloned()
            .filter(|&d| d.abs() > 1e-12)
            .collect();
        assert_eq!(
            nonzero.len(),
            1,
            "L1 FGSM should perturb exactly one dimension"
        );
    }

    #[test]
    fn test_linear_model_construction_invalid_bias_len() {
        // Bias length must match the number of output rows in the weights.
        let result = LinearAttackModel::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0], // wrong length
        );
        assert!(result.is_err());
    }
}