// wafrift_evolution/types.rs

//! Core types for the evolution engine.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5use std::time::{Duration, Instant};
6
7/// Rich oracle verdict providing gradient signals for fitness.
8#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
9pub struct OracleVerdict {
10    /// Whether the payload passed the WAF.
11    pub passed: bool,
12    /// Delta from baseline response status code.
13    pub status_delta: i16,
14    /// Delta from baseline response body size.
15    pub body_delta: i32,
16    /// Response latency in milliseconds.
17    pub latency_ms: u32,
18    /// Oracle confidence (0.0–1.0).
19    pub confidence: f64,
20    /// Number of WAF rules triggered.
21    pub triggered_rules: u32,
22}
23
24impl OracleVerdict {
25    /// Create a binary pass/fail verdict.
26    #[must_use]
27    pub fn from_bool(passed: bool) -> Self {
28        Self {
29            passed,
30            status_delta: 0,
31            body_delta: 0,
32            latency_ms: 0,
33            confidence: 1.0,
34            triggered_rules: if passed { 0 } else { 1 },
35        }
36    }
37
38    /// Compute a scalar fitness from the rich verdict.
39    ///
40    /// Rewards partial progress: fewer triggered rules, lower latency,
41    /// smaller body delta, and high oracle confidence.
42    #[must_use]
43    pub fn to_fitness(&self) -> f64 {
44        let base = if self.passed { 1.0 } else { 0.0 };
45        let partial = if self.passed {
46            0.0
47        } else {
48            // Partial credit for fewer triggered rules, faster response
49            let rule_penalty = (self.triggered_rules as f64 * 0.05).min(0.3);
50            let latency_penalty = (self.latency_ms as f64 / 5000.0).min(0.1);
51            let body_penalty = (self.body_delta.abs() as f64 / 10000.0).min(0.1);
52            0.3 - rule_penalty - latency_penalty - body_penalty
53        };
54        let confidence_bonus = self.confidence * 0.05;
55        (base + partial + confidence_bonus).clamp(0.0, 1.0)
56    }
57}
58
59impl Default for OracleVerdict {
60    fn default() -> Self {
61        Self::from_bool(false)
62    }
63}
64
/// What came back from evaluating a single candidate.
#[derive(Debug, Clone, PartialEq)]
pub enum Feedback {
    /// The WAF let the payload through.
    Passed,
    /// The WAF stopped the payload.
    Blocked,
    /// The target itself failed (5xx, timeout, etc.). Carries a `String`,
    /// presumably a description of the failure.
    TargetError(String),
}
75
76impl Feedback {
77    /// Convert feedback to an oracle verdict with default metadata.
78    #[must_use]
79    pub fn to_verdict(&self) -> OracleVerdict {
80        match self {
81            Self::Passed => OracleVerdict::from_bool(true),
82            Self::Blocked => OracleVerdict::from_bool(false),
83            Self::TargetError(_) => OracleVerdict {
84                passed: false,
85                status_delta: 500,
86                body_delta: 0,
87                latency_ms: 0,
88                confidence: 0.0,
89                triggered_rules: 0,
90            },
91        }
92    }
93}
94
95/// Hard budget limits for the search.
96#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
97pub struct Budget {
98    /// Maximum total oracle evaluations (requests).
99    pub max_requests: usize,
100    /// Maximum generations.
101    pub max_generations: u32,
102    /// Maximum time in seconds.
103    pub max_time_seconds: u64,
104    /// Early-termination stagnation threshold (generations with no improvement).
105    pub stagnation_limit: u32,
106}
107
108impl Budget {
109    /// Default conservative budget.
110    #[must_use]
111    pub fn default_wafrift() -> Self {
112        Self {
113            max_requests: 10_000,
114            max_generations: 200,
115            max_time_seconds: 3_600,
116            stagnation_limit: 10,
117        }
118    }
119}
120
121impl Default for Budget {
122    fn default() -> Self {
123        Self::default_wafrift()
124    }
125}
126
127/// Errors that can occur in the evolution engine.
128#[derive(Debug, thiserror::Error)]
129pub enum EvolutionError {
130    #[error("invalid chromosome index: {0}")]
131    InvalidChromosomeIndex(usize),
132    #[error("budget exhausted: {0}")]
133    BudgetExhausted(String),
134    #[error("target health critical: {0}")]
135    TargetHealthCritical(String),
136    #[error("serialization failed: {0}")]
137    SerializationFailed(String),
138    #[error("deserialization failed: {0}")]
139    DeserializationFailed(String),
140    #[error("search algorithm error: {0}")]
141    AlgorithmError(String),
142}
143
144/// Reason for terminating evolution.
145#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
146pub enum TerminationReason {
147    BudgetExhausted,
148    MaxGenerationsReached,
149    TimeLimitReached,
150    StagnationLimitReached,
151    TargetHealthCritical,
152    BypassFound,
153}
154
155/// Action emitted by the intelligence loop state machine.
156#[derive(Debug, Clone, PartialEq)]
157pub enum LoopAction {
158    /// Evaluate a differential probe.
159    SendProbe(crate::differential::Probe),
160    /// Evaluate an evolved payload.
161    SendPayload(crate::evolution::Chromosome),
162    /// Save checkpoint to disk.
163    SaveCheckpoint,
164    /// Terminate the loop.
165    Terminate(TerminationReason),
166}
167
/// Target health monitor with exponential backoff.
///
/// Counts consecutive target errors. Each error doubles the backoff (capped
/// at `max_backoff_seconds`); a success resets the monitor to its initial
/// state.
#[derive(Debug, Clone)]
pub struct TargetHealthMonitor {
    /// Errors recorded since the last success.
    consecutive_errors: u32,
    /// When the most recent error was recorded; `None` when there has been no
    /// error since construction or since the last success.
    last_error: Option<Instant>,
    /// Current backoff, in seconds.
    backoff_seconds: u64,
    /// Upper bound for `backoff_seconds`.
    max_backoff_seconds: u64,
    /// Consecutive-error count at which the target stops being healthy.
    error_threshold: u32,
}

impl TargetHealthMonitor {
    /// Create a monitor with no recorded errors, a 1 s backoff, a 300 s
    /// backoff cap and an unhealthy threshold of 5 consecutive errors.
    #[must_use]
    pub fn new() -> Self {
        Self {
            consecutive_errors: 0,
            last_error: None,
            backoff_seconds: 1,
            max_backoff_seconds: 300,
            error_threshold: 5,
        }
    }

    /// Record a target error.
    ///
    /// Bumps the consecutive-error count, timestamps the error and doubles
    /// the backoff up to the cap. The doubling happens on every error, so the
    /// first error after a healthy period already yields a 2 s backoff.
    pub fn record_error(&mut self) {
        // Saturate instead of overflowing on a pathologically long error run.
        self.consecutive_errors = self.consecutive_errors.saturating_add(1);
        self.last_error = Some(Instant::now());
        self.backoff_seconds = self
            .backoff_seconds
            .saturating_mul(2)
            .min(self.max_backoff_seconds);
    }

    /// Record a successful request, returning the monitor to its initial
    /// state.
    pub fn record_success(&mut self) {
        self.consecutive_errors = 0;
        self.backoff_seconds = 1;
        // Forget the last error as well; otherwise `in_backoff` would keep
        // reporting an active backoff for up to a second after recovery.
        self.last_error = None;
    }

    /// Check if the target is considered healthy, i.e. fewer consecutive
    /// errors than the threshold have been recorded.
    #[must_use]
    pub fn is_healthy(&self) -> bool {
        self.consecutive_errors < self.error_threshold
    }

    /// Current backoff duration.
    #[must_use]
    pub fn backoff(&self) -> Duration {
        Duration::from_secs(self.backoff_seconds)
    }

    /// Whether we are currently in an active backoff period: an error has
    /// been recorded since the last success and less than [`Self::backoff`]
    /// has elapsed since it.
    #[must_use]
    pub fn in_backoff(&self) -> bool {
        self.last_error
            .is_some_and(|t| t.elapsed() < self.backoff())
    }
}

impl Default for TargetHealthMonitor {
    /// Same as [`TargetHealthMonitor::new`].
    fn default() -> Self {
        Self::new()
    }
}
228
229/// Search statistics passed to algorithms for termination decisions.
230#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
231pub struct SearchStats {
232    pub generation: u32,
233    pub evaluations: usize,
234    pub best_fitness: f64,
235    pub stagnation_counter: u32,
236    #[serde(skip, default = "Instant::now")]
237    pub start_time: Instant,
238    pub start_time_system: std::time::SystemTime,
239}
240
241impl SearchStats {
242    pub fn new() -> Self {
243        Self {
244            generation: 0,
245            evaluations: 0,
246            best_fitness: 0.0,
247            stagnation_counter: 0,
248            start_time: Instant::now(),
249            start_time_system: std::time::SystemTime::now(),
250        }
251    }
252
253    pub fn fixup_start_time(&mut self) {
254        if let Ok(elapsed) = self.start_time_system.elapsed() {
255            self.start_time = Instant::now()
256                .checked_sub(elapsed)
257                .unwrap_or_else(Instant::now);
258        }
259    }
260}
261
262impl Default for SearchStats {
263    fn default() -> Self {
264        Self::new()
265    }
266}
267
268/// Deduplication helpers.
269#[derive(Debug, Clone)]
270pub struct Deduper {
271    seen: HashSet<u64>,
272}
273
274impl Deduper {
275    #[must_use]
276    pub fn new() -> Self {
277        Self {
278            seen: HashSet::new(),
279        }
280    }
281
282    /// Compute a hash for a chromosome based on its genes.
283    #[must_use]
284    pub fn hash_chromosome(chromosome: &crate::evolution::Chromosome) -> u64 {
285        use std::collections::hash_map::DefaultHasher;
286        use std::hash::{Hash, Hasher};
287        let mut hasher = DefaultHasher::new();
288        for (name, value) in &chromosome.genes {
289            name.hash(&mut hasher);
290            value.hash(&mut hasher);
291        }
292        hasher.finish()
293    }
294
295    /// Check if this chromosome has been seen before.
296    #[must_use]
297    pub fn is_duplicate(&self, chromosome: &crate::evolution::Chromosome) -> bool {
298        self.seen.contains(&Self::hash_chromosome(chromosome))
299    }
300
301    /// Mark a chromosome as seen.
302    pub fn insert(&mut self, chromosome: &crate::evolution::Chromosome) {
303        self.seen.insert(Self::hash_chromosome(chromosome));
304    }
305
306    /// Insert multiple chromosomes.
307    pub fn insert_many(&mut self, chromosomes: &[crate::evolution::Chromosome]) {
308        for c in chromosomes {
309            self.insert(c);
310        }
311    }
312}
313
314impl Default for Deduper {
315    fn default() -> Self {
316        Self::new()
317    }
318}
319
320/// Checkpoint persistence helpers.
321pub fn save_checkpoint(
322    path: &std::path::Path,
323    data: &impl Serialize,
324) -> Result<(), EvolutionError> {
325    let json = serde_json::to_string_pretty(data)
326        .map_err(|e| EvolutionError::SerializationFailed(e.to_string()))?;
327    std::fs::write(path, json).map_err(|e| EvolutionError::SerializationFailed(e.to_string()))?;
328    Ok(())
329}
330
331/// Load a checkpoint from disk.
332pub fn load_checkpoint<T: for<'de> Deserialize<'de>>(
333    path: &std::path::Path,
334) -> Result<T, EvolutionError> {
335    let json = std::fs::read_to_string(path)
336        .map_err(|e| EvolutionError::DeserializationFailed(e.to_string()))?;
337    serde_json::from_str(&json).map_err(|e| EvolutionError::DeserializationFailed(e.to_string()))
338}