Skip to main content

wafrift_evolution/
coverage_feedback.rs

1//! WAF rule-coverage feedback for MAP-Elites quality-diversity search.
2//!
3//! When the bench fires against a ModSec-fronted target, the response body
4//! may contain the specific CRS `rule_id` that fired (parsed by
5//! `wafrift_oracle::signal_body_marker::BlockReason::RuleId`).  This module
6//! turns that signal into a 2-D MAP-Elites *behavior descriptor*:
7//!
8//! ```text
9//!  (PayloadClass, Option<RuleId>)
10//! ```
11//!
12//! The grid cell is `(attack-class × rule-id)`.  When a cell is
13//! undiscovered the mutation strategy can target it deliberately, so
14//! bypasses are found ACROSS the rule corpus rather than concentrated on
15//! the rules the engine accidentally hits first.
16//!
17//! # Usage
18//!
19//! ```
20//! use wafrift_evolution::coverage_feedback::{
21//!     RuleCoverage, PayloadClass, RuleId, map_elites_descriptor,
22//! };
23//!
24//! let mut cov = RuleCoverage::default();
25//! let desc = map_elites_descriptor("' OR 1=1--", Some("942100"));
26//! cov.record("' OR 1=1--", desc.1.as_ref().map(|r| r.0.as_str()));
27//!
28//! let report = cov.coverage_report();
29//! assert!(!report.is_empty());
30//! ```
31
32use serde::{Deserialize, Serialize};
33use std::collections::{BTreeMap, BTreeSet};
34
35// ── Types ─────────────────────────────────────────────────────────────────────
36
37/// The attack-class dimension of the MAP-Elites grid.
38///
39/// Derived from the payload content (or from the bench case's `class` field
40/// if the caller has it).  Comparison is case-insensitive; values are stored
41/// as lower-case canonical strings so cells are stable across equivalent
42/// representations.
43#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
44pub struct PayloadClass(pub String);
45
46impl PayloadClass {
47    /// Construct from an arbitrary string.  The value is lower-cased and
48    /// stripped of leading/trailing whitespace so grid cells are stable.
49    #[must_use]
50    pub fn new(raw: &str) -> Self {
51        Self(raw.trim().to_ascii_lowercase())
52    }
53
54    /// Classify a raw payload string heuristically.
55    ///
56    /// The classifier is intentionally lightweight — it looks for the
57    /// strongest textual signal in the payload rather than doing full
58    /// parse-tree analysis.  The categories match the wafrift bench corpus
59    /// class identifiers so coverage reports are directly comparable.
60    #[must_use]
61    pub fn from_payload(payload: &str) -> Self {
62        let lower = payload.to_ascii_lowercase();
63        if lower.contains("select")
64            || lower.contains("union")
65            || lower.contains("insert")
66            || lower.contains("update")
67            || lower.contains("delete")
68            || lower.contains("drop")
69            || lower.contains("' or ")
70            || lower.contains("or 1=1")
71        {
72            return Self::new("sql");
73        }
74        if lower.contains("<script")
75            || lower.contains("onerror")
76            || lower.contains("onload")
77            || lower.contains("javascript:")
78            || lower.contains("alert(")
79        {
80            return Self::new("xss");
81        }
82        if lower.contains("../")
83            || lower.contains("..\\")
84            || lower.contains("%2e%2e")
85            || lower.contains("etc/passwd")
86        {
87            return Self::new("path");
88        }
89        if lower.contains("$(")
90            || lower.contains("`")
91            || lower.contains("|bash")
92            || lower.contains("cmd.exe")
93            || lower.contains("/bin/sh")
94        {
95            return Self::new("cmdi");
96        }
97        if lower.contains("{{")
98            || lower.contains("{%")
99            || lower.contains("#{")
100            || lower.contains("${'")
101        {
102            return Self::new("ssti");
103        }
104        if lower.contains("ldap://") || lower.contains("(uid=") || lower.contains("(cn=") {
105            return Self::new("ldap");
106        }
107        if lower.contains("http://") || lower.contains("https://") || lower.contains("ssrf") {
108            return Self::new("ssrf");
109        }
110        if lower.contains("<!entity") || lower.contains("<!doctype") || lower.contains("xxe") {
111            return Self::new("xxe");
112        }
113        if lower.contains("${jndi:") || lower.contains("log4j") || lower.contains("log4shell") {
114            return Self::new("log4shell");
115        }
116        Self::new("unknown")
117    }
118
119    /// The canonical string representation.
120    #[must_use]
121    pub fn as_str(&self) -> &str {
122        &self.0
123    }
124}
125
126/// A CRS / WAF rule identifier.
127///
128/// Stored as a canonical lower-case ASCII string so that `942100` and
129/// `RULE_942100` resolve to the same cell.
130#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
131pub struct RuleId(pub String);
132
133impl RuleId {
134    /// Normalise: trim whitespace, strip common prefix noise, lower-case.
135    ///
136    /// Accepts:
137    ///  - bare numbers: `"942100"` → `"942100"`
138    ///  - prefixed:     `"RULE_942100"`, `"rule-942100"` → `"942100"`
139    ///  - mixed case:   `"SQL_942100"` → `"sql_942100"` (prefix retained
140    ///    only when it doesn't match `rule`/`RULE_`)
141    #[must_use]
142    pub fn new(raw: &str) -> Self {
143        let s = raw.trim().to_ascii_lowercase();
144        // Strip "rule_" / "rule-" prefix if present — it adds no information
145        // for grid binning (every entry is a rule) and bloats cell keys.
146        let s = s
147            .strip_prefix("rule_")
148            .or_else(|| s.strip_prefix("rule-"))
149            .unwrap_or(&s);
150        Self(s.to_string())
151    }
152
153    /// The canonical identifier string.
154    #[must_use]
155    pub fn as_str(&self) -> &str {
156        &self.0
157    }
158}
159
160impl Default for RuleId {
161    /// Empty rule-id used as the zero value when a `RuleBucket` is constructed
162    /// without a known rule identifier (e.g. as the field default in `RuleBucket::default()`).
163    fn default() -> Self {
164        Self(String::new())
165    }
166}
167
168// ── Coverage tracker ─────────────────────────────────────────────────────────
169
170/// Accumulates `(payload, rule_id)` observations from live bench runs and
171/// exposes coverage analytics used by the `--coverage-report` flag.
172///
173/// Two complementary indices are maintained:
174///
175/// * `by_rule`  — `rule_id → set of distinct payloads that triggered it`
176/// * `by_class` — `payload_class → set of rule_ids it has reached`
177///
178/// Both are updated atomically on every [`record`][RuleCoverage::record]
179/// call so the coverage report is always consistent.
180#[derive(Debug, Default, Clone, Serialize, Deserialize)]
181pub struct RuleCoverage {
182    /// `rule_id → set of payload fingerprints (first 64 chars) observed`.
183    pub by_rule: BTreeMap<RuleId, BTreeSet<String>>,
184    /// `payload_class → set of rule_ids reached from that class`.
185    pub by_class: BTreeMap<PayloadClass, BTreeSet<RuleId>>,
186}
187
188impl RuleCoverage {
189    /// Create an empty coverage tracker.
190    #[must_use]
191    pub fn new() -> Self {
192        Self::default()
193    }
194
195    /// Record one `(payload, rule_id)` observation.
196    ///
197    /// `rule_id = None` means the request was not blocked (or the block
198    /// reason couldn't be extracted) — the payload class is still indexed
199    /// in `by_class` under a synthetic sentinel so "no rule triggered"
200    /// coverage is visible in the report.
201    pub fn record(&mut self, payload: &str, rule_id: Option<&str>) {
202        let cls = PayloadClass::from_payload(payload);
203        // Fingerprint: first 64 chars of the payload, trimmed to ASCII
204        // printable range to keep the report file human-readable.
205        let fp: String = payload
206            .chars()
207            .filter(|c| !c.is_control())
208            .take(64)
209            .collect();
210
211        if let Some(rid_raw) = rule_id {
212            let rid = RuleId::new(rid_raw);
213            self.by_rule.entry(rid.clone()).or_default().insert(fp);
214            self.by_class.entry(cls).or_default().insert(rid);
215        } else {
216            // Sentinel: no rule blocked this payload.
217            let sentinel = RuleId::new("__unblocked__");
218            self.by_class.entry(cls).or_default().insert(sentinel);
219        }
220    }
221
222    /// Produce a human-readable coverage summary.
223    ///
224    /// Each line is `rule_id  payload_count` separated by a tab, sorted
225    /// by rule id.  Suitable for `--coverage-report` output.
226    #[must_use]
227    pub fn coverage_report(&self) -> String {
228        let mut lines = Vec::with_capacity(self.by_rule.len() + 4);
229        lines.push(format!(
230            "# wafrift rule-coverage report — {} distinct rules triggered",
231            self.by_rule.len()
232        ));
233        lines.push(format!(
234            "# payload classes observed: {}",
235            self.by_class.len()
236        ));
237        lines.push("# rule_id\tpayloads_observed".to_string());
238        for (rule_id, payloads) in &self.by_rule {
239            lines.push(format!("{}\t{}", rule_id.as_str(), payloads.len()));
240        }
241        lines.push("# per-class summary".to_string());
242        for (cls, rules) in &self.by_class {
243            lines.push(format!(
244                "#   {}: {} rule(s) — {}",
245                cls.as_str(),
246                rules.len(),
247                rules
248                    .iter()
249                    .map(|r| r.as_str())
250                    .collect::<Vec<_>>()
251                    .join(", ")
252            ));
253        }
254        lines.join("\n")
255    }
256
257    /// Number of distinct rule IDs observed so far.
258    #[must_use]
259    pub fn rule_count(&self) -> usize {
260        // Exclude the synthetic "__unblocked__" sentinel from the headline.
261        self.by_rule
262            .keys()
263            .filter(|r| r.0 != "__unblocked__")
264            .count()
265    }
266
267    /// Rules that have been triggered at least once in this run.
268    #[must_use]
269    pub fn triggered_rules(&self) -> Vec<&RuleId> {
270        self.by_rule
271            .keys()
272            .filter(|r| r.0 != "__unblocked__")
273            .collect()
274    }
275
276    /// Serialize the coverage map to compact JSON.
277    ///
278    /// # Errors
279    ///
280    /// Returns `serde_json::Error` if serialization fails (only possible
281    /// if the in-memory types contain non-string-keyed maps, which they
282    /// cannot for `BTreeMap<RuleId, _>`).
283    pub fn to_json(&self) -> Result<String, serde_json::Error> {
284        serde_json::to_string_pretty(self)
285    }
286
287    /// Deserialize a coverage map from JSON produced by `to_json`.
288    ///
289    /// # Errors
290    ///
291    /// Returns `serde_json::Error` on malformed JSON.
292    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
293        serde_json::from_str(json)
294    }
295}
296
297// ── Descriptor ───────────────────────────────────────────────────────────────
298
299/// Produce the 2-D MAP-Elites behavior descriptor for one `(payload, rule_id)`
300/// observation.
301///
302/// The descriptor is `(PayloadClass, Option<RuleId>)`:
303///  - `Some(RuleId)` — the payload was blocked by a specific rule; the grid
304///    cell is `(class × rule_id)`.
305///  - `None` — the payload was not blocked (or the rule_id could not be
306///    extracted); the grid cell collapses to class-only, matching the
307///    pre-coverage behavior.
308///
309/// Stability guarantee: the same `(payload, rule_id)` pair always produces
310/// the same descriptor.  The classifier is deterministic.
311#[must_use]
312pub fn map_elites_descriptor(
313    payload: &str,
314    rule_id: Option<&str>,
315) -> (PayloadClass, Option<RuleId>) {
316    let cls = PayloadClass::from_payload(payload);
317    let rid = rule_id.map(RuleId::new);
318    (cls, rid)
319}
320
321// ── Tests ─────────────────────────────────────────────────────────────────────
322
#[cfg(test)]
mod tests {
    use super::*;

    // ── 1. Empty coverage tracker ──────────────────────────────────────────────

    #[test]
    fn empty_coverage_has_no_rules() {
        let cov = RuleCoverage::new();
        assert_eq!(cov.rule_count(), 0);
        assert!(cov.by_rule.is_empty());
        assert!(cov.by_class.is_empty());
    }

    // ── 2. Single rule_id observation ─────────────────────────────────────────

    #[test]
    fn single_rule_id_recorded_correctly() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        assert_eq!(cov.rule_count(), 1);
        let rid = RuleId::new("942100");
        assert!(cov.by_rule.contains_key(&rid));
        let cls = PayloadClass::new("sql");
        assert!(cov.by_class.contains_key(&cls));
    }

    // ── 3. Mixed classes — distinct cells ─────────────────────────────────────

    #[test]
    fn mixed_classes_produce_distinct_cells() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100")); // sql
        cov.record("<script>alert(1)</script>", Some("941100")); // xss
        cov.record("../../../etc/passwd", Some("930100")); // path

        assert_eq!(cov.rule_count(), 3);
        // Three distinct classes must be present.
        assert!(cov.by_class.contains_key(&PayloadClass::new("sql")));
        assert!(cov.by_class.contains_key(&PayloadClass::new("xss")));
        assert!(cov.by_class.contains_key(&PayloadClass::new("path")));
    }

    // ── 4. Descriptor stability — same input → same descriptor ────────────────

    #[test]
    fn descriptor_is_stable_for_same_input() {
        let payload = "' UNION SELECT 1,2,3--";
        let rule = Some("942190");
        let d1 = map_elites_descriptor(payload, rule);
        let d2 = map_elites_descriptor(payload, rule);
        assert_eq!(d1, d2);
    }

    // ── 5. Descriptor with no rule_id → class only ────────────────────────────

    #[test]
    fn descriptor_without_rule_id_has_none_dimension() {
        let (cls, rid) = map_elites_descriptor("' OR 1=1--", None);
        assert_eq!(cls, PayloadClass::new("sql"));
        assert!(rid.is_none());
    }

    // ── 6. JSON round-trip ────────────────────────────────────────────────────

    #[test]
    fn json_roundtrip_preserves_coverage() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        cov.record("<script>alert(1)</script>", Some("941100"));
        cov.record("../../../etc/passwd", None);

        let json = cov.to_json().expect("serialization must not fail");
        let restored = RuleCoverage::from_json(&json).expect("deserialization must not fail");

        assert_eq!(restored.rule_count(), cov.rule_count());
        // Compare full contents, not just sizes.
        assert_eq!(restored.by_rule, cov.by_rule);
        assert_eq!(restored.by_class, cov.by_class);
    }

    // ── 7. rule_id case-folding ───────────────────────────────────────────────

    #[test]
    fn rule_id_case_folding_normalises() {
        // All three forms should resolve to the same canonical RuleId.
        let r1 = RuleId::new("942100");
        let r2 = RuleId::new("RULE_942100");
        let r3 = RuleId::new("rule-942100");
        assert_eq!(r1, r2);
        assert_eq!(r1, r3);
    }

    // ── 8. Rule ID with unusual prefix (no "rule_" prefix) stays intact ───────

    #[test]
    fn rule_id_without_rule_prefix_preserved() {
        let r = RuleId::new("sql_942100");
        // Should NOT strip "sql_" — only "rule_" is stripped.
        assert_eq!(r.as_str(), "sql_942100");
    }

    // ── 9. PayloadClass from SQL payload ──────────────────────────────────────

    #[test]
    fn payload_class_detects_sql() {
        let cls = PayloadClass::from_payload("' UNION SELECT username, password FROM users--");
        assert_eq!(cls, PayloadClass::new("sql"));
    }

    // ── 10. PayloadClass from XSS payload ─────────────────────────────────────

    #[test]
    fn payload_class_detects_xss() {
        let cls = PayloadClass::from_payload("<script>alert(document.cookie)</script>");
        assert_eq!(cls, PayloadClass::new("xss"));
    }

    // ── 11. Multiple payloads hitting the same rule accumulate ────────────────

    #[test]
    fn same_rule_accumulates_multiple_payloads() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        cov.record("' OR 'x'='x'--", Some("942100"));
        cov.record("1 AND 1=1--", Some("942100"));
        let rid = RuleId::new("942100");
        // All three payloads are distinct fingerprints.
        assert_eq!(cov.by_rule[&rid].len(), 3);
        // Still only one rule.
        assert_eq!(cov.rule_count(), 1);
    }

    // ── 12. coverage_report contains expected rule ─────────────────────────────

    #[test]
    fn coverage_report_contains_triggered_rule() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        let report = cov.coverage_report();
        // Per-rule lines are `<rule_id>\t<payload_count>`.  Matching the whole
        // line pins the count; a bare "1" would already match "942100" itself.
        assert!(
            report.lines().any(|l| l == "942100\t1"),
            "report must contain the line `942100<TAB>1`:\n{}",
            report
        );
    }

    // ── 13. Unblocked observation only touches by_class ───────────────────────

    #[test]
    fn unblocked_observation_only_indexes_class() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", None);
        assert!(cov.by_rule.is_empty());
        assert_eq!(cov.rule_count(), 0);
        assert!(cov.triggered_rules().is_empty());
        assert!(cov.by_class[&PayloadClass::new("sql")].contains(&RuleId::new("__unblocked__")));
    }

    // ── 14. Fingerprints keep only the first 64 characters ────────────────────

    #[test]
    fn fingerprint_truncates_to_64_chars() {
        let mut cov = RuleCoverage::new();
        let prefix = "x".repeat(64);
        cov.record(&format!("{prefix}A"), Some("1"));
        cov.record(&format!("{prefix}B"), Some("1"));
        // Both payloads share the 64-char prefix, so they collapse.
        assert_eq!(cov.by_rule[&RuleId::new("1")].len(), 1);
    }

    // ── 15. Fingerprints drop control characters ──────────────────────────────

    #[test]
    fn fingerprint_drops_control_characters() {
        let mut cov = RuleCoverage::new();
        cov.record("a\nb", Some("1"));
        cov.record("ab", Some("1"));
        let fps = &cov.by_rule[&RuleId::new("1")];
        assert_eq!(fps.len(), 1);
        assert!(fps.contains("ab"));
    }

    // ── 16. Constructors normalise whitespace and case ────────────────────────

    #[test]
    fn constructors_trim_and_lowercase() {
        assert_eq!(PayloadClass::new("  SQL ").as_str(), "sql");
        assert_eq!(RuleId::new("  RULE_942100\t").as_str(), "942100");
    }
}