Skip to main content

wafrift_evolution/
types.rs

1//! Core types for the evolution engine.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5use std::time::{Duration, Instant};
6
/// Rich oracle verdict providing gradient signals for fitness.
///
/// A bare pass/fail result gives the search nothing to climb on; the extra
/// fields let [`OracleVerdict::to_fitness`] give partial credit to blocked
/// payloads. Use [`OracleVerdict::from_bool`] when only a binary outcome is
/// available.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct OracleVerdict {
    /// Whether the payload passed the WAF.
    pub passed: bool,
    /// Delta from baseline response status code.
    ///
    /// Not read by `to_fitness`; `Feedback::TargetError` sets it to 500.
    pub status_delta: i16,
    /// Delta from baseline response body size.
    ///
    /// `to_fitness` penalises only its magnitude, so the sign does not matter there.
    pub body_delta: i32,
    /// Response latency in milliseconds.
    pub latency_ms: u32,
    /// Oracle confidence (0.0–1.0).
    ///
    /// NOTE(review): the range is not enforced by the type itself; consumers
    /// must tolerate out-of-range values (and NaN).
    pub confidence: f64,
    /// Number of WAF rules triggered.
    pub triggered_rules: u32,
}
23
24/// Penalty per triggered WAF rule in fitness calculation.
25const RULE_PENALTY_PER_RULE: f64 = 0.05;
26/// Maximum rule-based penalty (caps at 6 rules).
27const MAX_RULE_PENALTY: f64 = 0.3;
28/// Reference latency in ms for normalising the latency penalty.
29const LATENCY_REFERENCE_MS: f64 = 5000.0;
30/// Maximum latency-based penalty.
31const MAX_LATENCY_PENALTY: f64 = 0.1;
32/// Reference body-size delta in bytes for normalising the body penalty.
33const BODY_DELTA_REFERENCE: f64 = 10000.0;
34/// Maximum body-delta-based penalty.
35const MAX_BODY_PENALTY: f64 = 0.1;
36/// Maximum partial-credit pool for a non-passing verdict.
37const MAX_PARTIAL_CREDIT: f64 = 0.3;
38/// Confidence bonus multiplier.
39const CONFIDENCE_BONUS_MULTIPLIER: f64 = 0.05;
40
41impl OracleVerdict {
42    /// Create a binary pass/fail verdict.
43    #[must_use]
44    pub fn from_bool(passed: bool) -> Self {
45        Self {
46            passed,
47            status_delta: 0,
48            body_delta: 0,
49            latency_ms: 0,
50            confidence: 1.0,
51            triggered_rules: if passed { 0 } else { 1 },
52        }
53    }
54
55    /// Compute a scalar fitness from the rich verdict.
56    ///
57    /// Rewards partial progress: fewer triggered rules, lower latency,
58    /// smaller body delta, and high oracle confidence.
59    #[must_use]
60    pub fn to_fitness(&self) -> f64 {
61        let base = if self.passed { 1.0 } else { 0.0 };
62        let partial = if self.passed {
63            0.0
64        } else {
65            // Partial credit for fewer triggered rules, faster response
66            let rule_penalty =
67                (self.triggered_rules as f64 * RULE_PENALTY_PER_RULE).min(MAX_RULE_PENALTY);
68            let latency_penalty =
69                (self.latency_ms as f64 / LATENCY_REFERENCE_MS).min(MAX_LATENCY_PENALTY);
70            let body_penalty = (self.body_delta.abs() as f64 / BODY_DELTA_REFERENCE)
71                .min(MAX_BODY_PENALTY);
72            MAX_PARTIAL_CREDIT - rule_penalty - latency_penalty - body_penalty
73        };
74        let confidence_bonus = self.confidence * CONFIDENCE_BONUS_MULTIPLIER;
75        (base + partial + confidence_bonus).clamp(0.0, 1.0)
76    }
77}
78
79impl Default for OracleVerdict {
80    fn default() -> Self {
81        Self::from_bool(false)
82    }
83}
84
85/// Feedback from evaluating a candidate.
86#[derive(Debug, Clone, PartialEq)]
87pub enum Feedback {
88    /// Payload passed the WAF.
89    Passed,
90    /// Payload was blocked.
91    Blocked,
92    /// Target returned an error (5xx, timeout, etc.).
93    TargetError(String),
94}
95
96impl Feedback {
97    /// Convert feedback to an oracle verdict with default metadata.
98    #[must_use]
99    pub fn to_verdict(&self) -> OracleVerdict {
100        match self {
101            Self::Passed => OracleVerdict::from_bool(true),
102            Self::Blocked => OracleVerdict::from_bool(false),
103            Self::TargetError(_) => OracleVerdict {
104                passed: false,
105                status_delta: 500,
106                body_delta: 0,
107                latency_ms: 0,
108                confidence: 0.0,
109                triggered_rules: 0,
110            },
111        }
112    }
113}
114
/// Hard budget limits for the search.
///
/// This type only stores the limits; it has no enforcement logic of its own.
/// `Budget::default()` yields the same values as `Budget::default_wafrift()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct Budget {
    /// Maximum total oracle evaluations (requests).
    pub max_requests: usize,
    /// Maximum generations.
    pub max_generations: u32,
    /// Maximum time in seconds.
    pub max_time_seconds: u64,
    /// Early-termination stagnation threshold (generations with no improvement).
    pub stagnation_limit: u32,
}
127
128impl Budget {
129    /// Default conservative budget.
130    #[must_use]
131    pub fn default_wafrift() -> Self {
132        Self {
133            max_requests: 10_000,
134            max_generations: 200,
135            max_time_seconds: 3_600,
136            stagnation_limit: 10,
137        }
138    }
139}
140
141impl Default for Budget {
142    fn default() -> Self {
143        Self::default_wafrift()
144    }
145}
146
/// Errors that can occur in the evolution engine.
///
/// NOTE(review): `SerializationFailed`, `DeserializationFailed` and `Io` both
/// interpolate the inner error into their `Display` message (`{0}`) and expose
/// it through `source()` (`#[source]` / `#[from]`). Reporters that print the
/// whole source chain will therefore show the inner message twice. Changing the
/// messages is user-visible, so this is flagged rather than changed here.
#[derive(Debug, thiserror::Error)]
pub enum EvolutionError {
    /// A chromosome index was invalid; carries the offending index.
    #[error("invalid chromosome index: {0}")]
    InvalidChromosomeIndex(usize),
    /// A search budget was exhausted; the string is free-form context.
    #[error("budget exhausted: {0}")]
    BudgetExhausted(String),
    /// The target was judged unhealthy; the string is free-form context.
    #[error("target health critical: {0}")]
    TargetHealthCritical(String),
    /// `serde_json` failed to serialize a value (e.g. in `save_checkpoint`).
    #[error("serialization failed: {0}")]
    SerializationFailed(#[source] serde_json::Error),
    /// `serde_json` failed to deserialize a value (e.g. in `load_checkpoint`).
    #[error("deserialization failed: {0}")]
    DeserializationFailed(#[source] serde_json::Error),
    /// An I/O error; `#[from]` lets `?` convert a `std::io::Error` directly.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// A search algorithm reported a failure; the string is free-form context.
    #[error("search algorithm error: {0}")]
    AlgorithmError(String),
    /// Data exceeded a size limit, e.g. a checkpoint larger than
    /// `MAX_CHECKPOINT_BYTES`.
    #[error("data exceeds size limit: {context} ({size} bytes, max {max})")]
    OversizedData {
        /// What was being processed (the checkpoint helpers use `"checkpoint <path>"`).
        context: String,
        /// Actual size in bytes.
        size: usize,
        /// Permitted maximum in bytes.
        max: usize,
    },
}
171
/// Reason for terminating evolution.
///
/// Carried by `LoopAction::Terminate`. The enum derives serde, and with serde's
/// default representation unit variants serialize by name in formats such as
/// JSON. Renaming or reordering variants can therefore break data that was
/// already serialized.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TerminationReason {
    /// Request/evaluation budget ran out
    /// (presumably `Budget::max_requests` — confirm at the call site).
    BudgetExhausted,
    /// Generation limit reached (presumably `Budget::max_generations`).
    MaxGenerationsReached,
    /// Time limit reached (presumably `Budget::max_time_seconds`).
    TimeLimitReached,
    /// No improvement for too long (presumably `Budget::stagnation_limit`).
    StagnationLimitReached,
    /// Target judged unhealthy (presumably via `TargetHealthMonitor::is_healthy` — confirm).
    TargetHealthCritical,
    /// Presumably: a candidate payload passed the WAF.
    BypassFound,
}
182
/// Action emitted by the intelligence loop state machine.
///
/// NOTE(review): neither the state machine that emits these nor the code that
/// executes them is in this file. Because the enum derives `PartialEq`, the
/// payload types `Probe` and `Chromosome` must implement `PartialEq` too.
#[derive(Debug, Clone, PartialEq)]
pub enum LoopAction {
    /// Evaluate a differential probe.
    SendProbe(crate::differential::Probe),
    /// Evaluate an evolved payload.
    SendPayload(crate::evolution::Chromosome),
    /// Save checkpoint to disk.
    SaveCheckpoint,
    /// Terminate the loop, giving the reason.
    Terminate(TerminationReason),
}
195
/// Target health monitor with exponential backoff.
///
/// Tracks consecutive target errors. Each error starts a new backoff period
/// and doubles the backoff length: it starts at 1 s, so the first error gives
/// 2 s, and it is capped at 300 s. Five errors in a row mark the target
/// unhealthy. A success resets the error count and the backoff, and ends any
/// active backoff period.
#[derive(Debug, Clone)]
pub struct TargetHealthMonitor {
    /// Errors recorded since the last success.
    consecutive_errors: u32,
    /// Time of the most recent error; `None` before any error or after a success.
    last_error: Option<Instant>,
    /// Current backoff length in seconds.
    backoff_seconds: u64,
    /// Upper bound on `backoff_seconds`.
    max_backoff_seconds: u64,
    /// Consecutive errors at which the target is considered unhealthy.
    error_threshold: u32,
}

impl TargetHealthMonitor {
    /// Create a healthy monitor with a 1 s backoff, a 300 s backoff cap and an
    /// unhealthy threshold of 5 consecutive errors.
    #[must_use]
    pub fn new() -> Self {
        Self {
            consecutive_errors: 0,
            last_error: None,
            backoff_seconds: 1,
            max_backoff_seconds: 300,
            error_threshold: 5,
        }
    }

    /// Record a target error.
    ///
    /// Starts a new backoff period now and doubles the backoff up to the cap.
    /// The error count saturates instead of overflowing.
    pub fn record_error(&mut self) {
        self.consecutive_errors = self.consecutive_errors.saturating_add(1);
        self.last_error = Some(Instant::now());
        self.backoff_seconds = self
            .backoff_seconds
            .saturating_mul(2)
            .min(self.max_backoff_seconds);
    }

    /// Record a successful request.
    ///
    /// Resets the error count and backoff, and clears the last-error time.
    /// Without that, [`Self::in_backoff`] would keep reporting a backoff left
    /// over from an error the target has since recovered from.
    pub fn record_success(&mut self) {
        self.consecutive_errors = 0;
        self.last_error = None;
        self.backoff_seconds = 1;
    }

    /// Whether fewer than `error_threshold` consecutive errors have been recorded.
    #[must_use]
    pub fn is_healthy(&self) -> bool {
        self.consecutive_errors < self.error_threshold
    }

    /// Current backoff duration.
    #[must_use]
    pub fn backoff(&self) -> Duration {
        Duration::from_secs(self.backoff_seconds)
    }

    /// Whether less than [`Self::backoff`] has elapsed since the last recorded
    /// error (always `false` if no error has been recorded since the last success).
    #[must_use]
    pub fn in_backoff(&self) -> bool {
        self.last_error
            .is_some_and(|t| t.elapsed() < self.backoff())
    }
}

impl Default for TargetHealthMonitor {
    fn default() -> Self {
        Self::new()
    }
}
256
/// Search statistics passed to algorithms for termination decisions.
///
/// `Instant` has no serde impl, so `start_time` is skipped by serde and set to
/// `Instant::now()` when deserializing. `start_time_system` is the serialized
/// wall-clock copy, and `SearchStats::fixup_start_time` uses it to rebuild
/// `start_time`. Presumably that is called right after loading a checkpoint;
/// confirm at the call site.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct SearchStats {
    pub generation: u32,
    pub evaluations: usize,
    pub best_fitness: f64,
    /// Presumably generations without improvement, compared against
    /// `Budget::stagnation_limit` — confirm at the call site.
    pub stagnation_counter: u32,
    /// Monotonic start time; not serialized (reset to "now" on deserialization).
    #[serde(skip, default = "Instant::now")]
    pub start_time: Instant,
    /// Wall-clock start time; serialized so `start_time` can be rebuilt.
    pub start_time_system: std::time::SystemTime,
}
268
269impl SearchStats {
270    pub fn new() -> Self {
271        Self {
272            generation: 0,
273            evaluations: 0,
274            best_fitness: 0.0,
275            stagnation_counter: 0,
276            start_time: Instant::now(),
277            start_time_system: std::time::SystemTime::now(),
278        }
279    }
280
281    pub fn fixup_start_time(&mut self) {
282        if let Ok(elapsed) = self.start_time_system.elapsed() {
283            self.start_time = Instant::now()
284                .checked_sub(elapsed)
285                .unwrap_or_else(Instant::now);
286        }
287    }
288}
289
290impl Default for SearchStats {
291    fn default() -> Self {
292        Self::new()
293    }
294}
295
296/// Deduplication helpers.
297#[derive(Debug, Clone)]
298pub struct Deduper {
299    seen: HashSet<u64>,
300}
301
302impl Deduper {
303    #[must_use]
304    pub fn new() -> Self {
305        Self {
306            seen: HashSet::new(),
307        }
308    }
309
310    /// Compute a hash for a chromosome based on its genes.
311    #[must_use]
312    pub fn hash_chromosome(chromosome: &crate::evolution::Chromosome) -> u64 {
313        use std::collections::hash_map::DefaultHasher;
314        use std::hash::{Hash, Hasher};
315        let mut hasher = DefaultHasher::new();
316        for (name, value) in &chromosome.genes {
317            name.hash(&mut hasher);
318            value.hash(&mut hasher);
319        }
320        hasher.finish()
321    }
322
323    /// Check if this chromosome has been seen before.
324    #[must_use]
325    pub fn is_duplicate(&self, chromosome: &crate::evolution::Chromosome) -> bool {
326        self.seen.contains(&Self::hash_chromosome(chromosome))
327    }
328
329    /// Mark a chromosome as seen.
330    pub fn insert(&mut self, chromosome: &crate::evolution::Chromosome) {
331        self.seen.insert(Self::hash_chromosome(chromosome));
332    }
333
334    /// Insert multiple chromosomes.
335    pub fn insert_many(&mut self, chromosomes: &[crate::evolution::Chromosome]) {
336        for c in chromosomes {
337            self.insert(c);
338        }
339    }
340}
341
342impl Default for Deduper {
343    fn default() -> Self {
344        Self::new()
345    }
346}
347
/// Upper bound on checkpoint size: 512 MiB.
///
/// Both saving and loading reject data larger than this, so a corrupt or
/// malicious checkpoint cannot make the loader allocate unbounded memory.
pub(crate) const MAX_CHECKPOINT_BYTES: usize = 512 << 20;
351
352/// Checkpoint persistence helpers.
353pub fn save_checkpoint(
354    path: &std::path::Path,
355    data: &impl Serialize,
356) -> Result<(), EvolutionError> {
357    let json = serde_json::to_string_pretty(data).map_err(EvolutionError::SerializationFailed)?;
358    if json.len() > MAX_CHECKPOINT_BYTES {
359        return Err(EvolutionError::OversizedData {
360            context: format!("checkpoint {}", path.display()),
361            size: json.len(),
362            max: MAX_CHECKPOINT_BYTES,
363        });
364    }
365    std::fs::write(path, json)?;
366    Ok(())
367}
368
369/// Load a checkpoint from disk.
370pub fn load_checkpoint<T: for<'de> Deserialize<'de>>(
371    path: &std::path::Path,
372) -> Result<T, EvolutionError> {
373    let meta = std::fs::metadata(path)?;
374    let len = meta.len() as usize;
375    if len > MAX_CHECKPOINT_BYTES {
376        return Err(EvolutionError::OversizedData {
377            context: format!("checkpoint {}", path.display()),
378            size: len,
379            max: MAX_CHECKPOINT_BYTES,
380        });
381    }
382    let json = std::fs::read_to_string(path)?;
383    serde_json::from_str(&json).map_err(EvolutionError::DeserializationFailed)
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Process-unique path in the system temp dir for checkpoint tests.
    fn temp_checkpoint_path(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "wafrift_types_{tag}_{}.json",
            std::process::id()
        ))
    }

    #[test]
    fn oracle_verdict_from_bool_true() {
        let v = OracleVerdict::from_bool(true);
        assert!(v.passed);
        assert_eq!(v.triggered_rules, 0);
        assert_eq!(v.confidence, 1.0);
    }

    #[test]
    fn oracle_verdict_from_bool_false() {
        let v = OracleVerdict::from_bool(false);
        assert!(!v.passed);
        assert_eq!(v.triggered_rules, 1);
    }

    #[test]
    fn oracle_verdict_fitness_passed_is_one() {
        let v = OracleVerdict::from_bool(true);
        // clamped to 1.0 (1.0 base + 0.05 confidence bonus)
        assert_eq!(v.to_fitness(), 1.0);
    }

    #[test]
    fn oracle_verdict_fitness_blocked_penalizes_rules() {
        let v = OracleVerdict {
            passed: false,
            triggered_rules: 5,
            confidence: 1.0,
            ..Default::default()
        };
        // 0.3 - 0.25 - 0 - 0 + 0.05 = 0.10
        assert!((v.to_fitness() - 0.10).abs() < 0.01);
    }

    #[test]
    fn oracle_verdict_default_is_blocked() {
        let v = OracleVerdict::default();
        assert_eq!(v, OracleVerdict::from_bool(false));
        // 0.3 - 0.05 (one rule) + 0.05 (full confidence) = 0.30
        assert!((v.to_fitness() - 0.30).abs() < 1e-9);
    }

    #[test]
    fn oracle_verdict_fitness_floors_at_zero() {
        let worst = OracleVerdict {
            passed: false,
            triggered_rules: u32::MAX,
            latency_ms: u32::MAX,
            body_delta: i32::MAX,
            confidence: 0.0,
            ..Default::default()
        };
        assert_eq!(worst.to_fitness(), 0.0);
    }

    #[test]
    fn oracle_verdict_serde_round_trip() {
        let v = OracleVerdict {
            body_delta: -120,
            latency_ms: 250,
            confidence: 0.5,
            ..OracleVerdict::from_bool(false)
        };
        let json = serde_json::to_string(&v).expect("serialize verdict");
        let back: OracleVerdict = serde_json::from_str(&json).expect("deserialize verdict");
        assert_eq!(back, v);
    }

    #[test]
    fn feedback_to_verdict_passed() {
        assert!(Feedback::Passed.to_verdict().passed);
    }

    #[test]
    fn feedback_to_verdict_blocked() {
        let v = Feedback::Blocked.to_verdict();
        assert!(!v.passed);
        assert_eq!(v.triggered_rules, 1);
    }

    #[test]
    fn feedback_to_verdict_target_error() {
        let v = Feedback::TargetError("timeout".into()).to_verdict();
        assert!(!v.passed);
        assert_eq!(v.status_delta, 500);
        assert_eq!(v.confidence, 0.0);
    }

    #[test]
    fn budget_default_wafrift_values() {
        let b = Budget::default_wafrift();
        assert_eq!(b.max_requests, 10_000);
        assert_eq!(b.max_generations, 200);
        assert_eq!(b.max_time_seconds, 3_600);
        assert_eq!(b.stagnation_limit, 10);
    }

    #[test]
    fn budget_default_matches_default_wafrift() {
        assert_eq!(Budget::default(), Budget::default_wafrift());
    }

    #[test]
    fn target_health_monitor_starts_healthy() {
        let h = TargetHealthMonitor::new();
        assert!(h.is_healthy());
        assert!(!h.in_backoff());
        assert_eq!(h.backoff(), Duration::from_secs(1));
    }

    #[test]
    fn target_health_monitor_records_errors() {
        let mut h = TargetHealthMonitor::new();
        for _ in 0..4 {
            h.record_error();
        }
        assert!(h.is_healthy());
        assert_eq!(h.backoff(), Duration::from_secs(16));
        h.record_error();
        assert!(!h.is_healthy());
    }

    #[test]
    fn target_health_monitor_resets_on_success() {
        let mut h = TargetHealthMonitor::new();
        h.record_error();
        h.record_error();
        h.record_success();
        assert!(h.is_healthy());
        assert_eq!(h.backoff(), Duration::from_secs(1));
    }

    #[test]
    fn target_health_monitor_backoff_is_capped() {
        let mut h = TargetHealthMonitor::new();
        for _ in 0..20 {
            h.record_error();
        }
        assert_eq!(h.backoff(), Duration::from_secs(300));
        assert!(!h.is_healthy());
    }

    #[test]
    fn search_stats_new_starts_at_zero() {
        let s = SearchStats::new();
        assert_eq!(s.generation, 0);
        assert_eq!(s.evaluations, 0);
        assert_eq!(s.best_fitness, 0.0);
        assert_eq!(s.stagnation_counter, 0);
    }

    #[test]
    fn search_stats_fixup_restores_elapsed_time() {
        let mut stats = SearchStats::new();
        stats.start_time_system = std::time::SystemTime::now() - Duration::from_secs(5);
        stats.fixup_start_time();
        assert!(stats.start_time.elapsed() >= Duration::from_secs(4));
    }

    #[test]
    fn search_stats_fixup_ignores_future_start() {
        let mut stats = SearchStats::new();
        let before = stats.start_time;
        stats.start_time_system = std::time::SystemTime::now() + Duration::from_secs(3_600);
        stats.fixup_start_time();
        assert_eq!(stats.start_time, before);
    }

    #[test]
    fn deduper_detects_duplicates() {
        use crate::evolution::Chromosome;
        let c1 = Chromosome::new(vec![("a".into(), "1".into())]);
        let c2 = Chromosome::new(vec![("a".into(), "1".into())]);
        let c3 = Chromosome::new(vec![("a".into(), "2".into())]);

        let mut d = Deduper::new();
        assert!(!d.is_duplicate(&c1));
        d.insert(&c1);
        assert!(d.is_duplicate(&c2));
        assert!(!d.is_duplicate(&c3));
    }

    #[test]
    fn deduper_insert_many() {
        use crate::evolution::Chromosome;
        let c1 = Chromosome::new(vec![("a".into(), "1".into())]);
        let c2 = Chromosome::new(vec![("b".into(), "2".into())]);
        let mut d = Deduper::new();
        d.insert_many(&[c1.clone(), c2.clone()]);
        assert!(d.is_duplicate(&c1));
        assert!(d.is_duplicate(&c2));
    }

    #[test]
    fn deduper_hash_consistent() {
        use crate::evolution::Chromosome;
        let c = Chromosome::new(vec![("x".into(), "y".into())]);
        let h1 = Deduper::hash_chromosome(&c);
        let h2 = Deduper::hash_chromosome(&c);
        assert_eq!(h1, h2);
    }

    #[test]
    fn checkpoint_round_trip() {
        let path = temp_checkpoint_path("roundtrip");
        let budget = Budget {
            max_requests: 42,
            ..Budget::default()
        };
        save_checkpoint(&path, &budget).expect("save checkpoint");
        let loaded: Budget = load_checkpoint(&path).expect("load checkpoint");
        let _ = std::fs::remove_file(&path);
        assert_eq!(loaded, budget);
    }

    #[test]
    fn load_checkpoint_missing_file_is_io_error() {
        let path = temp_checkpoint_path("missing");
        let _ = std::fs::remove_file(&path);
        let result = load_checkpoint::<Budget>(&path);
        assert!(matches!(result, Err(EvolutionError::Io(_))));
    }
}