wafrift_evolution/coverage_feedback.rs
1//! WAF rule-coverage feedback for MAP-Elites quality-diversity search.
2//!
3//! When the bench fires against a ModSec-fronted target, the response body
4//! may contain the specific CRS `rule_id` that fired (parsed by
5//! `wafrift_oracle::signal_body_marker::BlockReason::RuleId`). This module
6//! turns that signal into a 2-D MAP-Elites *behavior descriptor*:
7//!
8//! ```text
9//! (PayloadClass, Option<RuleId>)
10//! ```
11//!
12//! The grid cell is `(attack-class × rule-id)`. When a cell is
13//! undiscovered the mutation strategy can target it deliberately, so
14//! bypasses are found ACROSS the rule corpus rather than concentrated on
15//! the rules the engine accidentally hits first.
16//!
17//! # Usage
18//!
19//! ```
20//! use wafrift_evolution::coverage_feedback::{
21//! RuleCoverage, PayloadClass, RuleId, map_elites_descriptor,
22//! };
23//!
24//! let mut cov = RuleCoverage::default();
25//! let desc = map_elites_descriptor("' OR 1=1--", Some("942100"));
26//! cov.record("' OR 1=1--", desc.1.as_ref().map(|r| r.0.as_str()));
27//!
28//! let report = cov.coverage_report();
29//! assert!(!report.is_empty());
30//! ```
31
32use serde::{Deserialize, Serialize};
33use std::collections::{BTreeMap, BTreeSet};
34
35// ── Types ─────────────────────────────────────────────────────────────────────
36
37/// The attack-class dimension of the MAP-Elites grid.
38///
39/// Derived from the payload content (or from the bench case's `class` field
40/// if the caller has it). Comparison is case-insensitive; values are stored
41/// as lower-case canonical strings so cells are stable across equivalent
42/// representations.
43#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
44pub struct PayloadClass(pub String);
45
46impl PayloadClass {
47 /// Construct from an arbitrary string. The value is lower-cased and
48 /// stripped of leading/trailing whitespace so grid cells are stable.
49 #[must_use]
50 pub fn new(raw: &str) -> Self {
51 Self(raw.trim().to_ascii_lowercase())
52 }
53
54 /// Classify a raw payload string heuristically.
55 ///
56 /// The classifier is intentionally lightweight — it looks for the
57 /// strongest textual signal in the payload rather than doing full
58 /// parse-tree analysis. The categories match the wafrift bench corpus
59 /// class identifiers so coverage reports are directly comparable.
60 #[must_use]
61 pub fn from_payload(payload: &str) -> Self {
62 let lower = payload.to_ascii_lowercase();
63 if lower.contains("select")
64 || lower.contains("union")
65 || lower.contains("insert")
66 || lower.contains("update")
67 || lower.contains("delete")
68 || lower.contains("drop")
69 || lower.contains("' or ")
70 || lower.contains("or 1=1")
71 {
72 return Self::new("sql");
73 }
74 if lower.contains("<script")
75 || lower.contains("onerror")
76 || lower.contains("onload")
77 || lower.contains("javascript:")
78 || lower.contains("alert(")
79 {
80 return Self::new("xss");
81 }
82 if lower.contains("../")
83 || lower.contains("..\\")
84 || lower.contains("%2e%2e")
85 || lower.contains("etc/passwd")
86 {
87 return Self::new("path");
88 }
89 if lower.contains("$(")
90 || lower.contains("`")
91 || lower.contains("|bash")
92 || lower.contains("cmd.exe")
93 || lower.contains("/bin/sh")
94 {
95 return Self::new("cmdi");
96 }
97 if lower.contains("{{")
98 || lower.contains("{%")
99 || lower.contains("#{")
100 || lower.contains("${'")
101 {
102 return Self::new("ssti");
103 }
104 if lower.contains("ldap://") || lower.contains("(uid=") || lower.contains("(cn=") {
105 return Self::new("ldap");
106 }
107 if lower.contains("http://") || lower.contains("https://") || lower.contains("ssrf") {
108 return Self::new("ssrf");
109 }
110 if lower.contains("<!entity") || lower.contains("<!doctype") || lower.contains("xxe") {
111 return Self::new("xxe");
112 }
113 if lower.contains("${jndi:") || lower.contains("log4j") || lower.contains("log4shell") {
114 return Self::new("log4shell");
115 }
116 Self::new("unknown")
117 }
118
119 /// The canonical string representation.
120 #[must_use]
121 pub fn as_str(&self) -> &str {
122 &self.0
123 }
124}
125
126/// A CRS / WAF rule identifier.
127///
128/// Stored as a canonical lower-case ASCII string so that `942100` and
129/// `RULE_942100` resolve to the same cell.
130#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
131pub struct RuleId(pub String);
132
133impl RuleId {
134 /// Normalise: trim whitespace, strip common prefix noise, lower-case.
135 ///
136 /// Accepts:
137 /// - bare numbers: `"942100"` → `"942100"`
138 /// - prefixed: `"RULE_942100"`, `"rule-942100"` → `"942100"`
139 /// - mixed case: `"SQL_942100"` → `"sql_942100"` (prefix retained
140 /// only when it doesn't match `rule`/`RULE_`)
141 #[must_use]
142 pub fn new(raw: &str) -> Self {
143 let s = raw.trim().to_ascii_lowercase();
144 // Strip "rule_" / "rule-" prefix if present — it adds no information
145 // for grid binning (every entry is a rule) and bloats cell keys.
146 let s = s
147 .strip_prefix("rule_")
148 .or_else(|| s.strip_prefix("rule-"))
149 .unwrap_or(&s);
150 Self(s.to_string())
151 }
152
153 /// The canonical identifier string.
154 #[must_use]
155 pub fn as_str(&self) -> &str {
156 &self.0
157 }
158}
159
160impl Default for RuleId {
161 /// Empty rule-id used as the zero value when a `RuleBucket` is constructed
162 /// without a known rule identifier (e.g. as the field default in `RuleBucket::default()`).
163 fn default() -> Self {
164 Self(String::new())
165 }
166}
167
168// ── Coverage tracker ─────────────────────────────────────────────────────────
169
170/// Accumulates `(payload, rule_id)` observations from live bench runs and
171/// exposes coverage analytics used by the `--coverage-report` flag.
172///
173/// Two complementary indices are maintained:
174///
175/// * `by_rule` — `rule_id → set of distinct payloads that triggered it`
176/// * `by_class` — `payload_class → set of rule_ids it has reached`
177///
178/// Both are updated atomically on every [`record`][RuleCoverage::record]
179/// call so the coverage report is always consistent.
180#[derive(Debug, Default, Clone, Serialize, Deserialize)]
181pub struct RuleCoverage {
182 /// `rule_id → set of payload fingerprints (first 64 chars) observed`.
183 pub by_rule: BTreeMap<RuleId, BTreeSet<String>>,
184 /// `payload_class → set of rule_ids reached from that class`.
185 pub by_class: BTreeMap<PayloadClass, BTreeSet<RuleId>>,
186}
187
188impl RuleCoverage {
189 /// Create an empty coverage tracker.
190 #[must_use]
191 pub fn new() -> Self {
192 Self::default()
193 }
194
195 /// Record one `(payload, rule_id)` observation.
196 ///
197 /// `rule_id = None` means the request was not blocked (or the block
198 /// reason couldn't be extracted) — the payload class is still indexed
199 /// in `by_class` under a synthetic sentinel so "no rule triggered"
200 /// coverage is visible in the report.
201 pub fn record(&mut self, payload: &str, rule_id: Option<&str>) {
202 let cls = PayloadClass::from_payload(payload);
203 // Fingerprint: first 64 chars of the payload, trimmed to ASCII
204 // printable range to keep the report file human-readable.
205 let fp: String = payload
206 .chars()
207 .filter(|c| !c.is_control())
208 .take(64)
209 .collect();
210
211 if let Some(rid_raw) = rule_id {
212 let rid = RuleId::new(rid_raw);
213 self.by_rule.entry(rid.clone()).or_default().insert(fp);
214 self.by_class.entry(cls).or_default().insert(rid);
215 } else {
216 // Sentinel: no rule blocked this payload.
217 let sentinel = RuleId::new("__unblocked__");
218 self.by_class.entry(cls).or_default().insert(sentinel);
219 }
220 }
221
222 /// Produce a human-readable coverage summary.
223 ///
224 /// Each line is `rule_id payload_count` separated by a tab, sorted
225 /// by rule id. Suitable for `--coverage-report` output.
226 #[must_use]
227 pub fn coverage_report(&self) -> String {
228 let mut lines = Vec::with_capacity(self.by_rule.len() + 4);
229 lines.push(format!(
230 "# wafrift rule-coverage report — {} distinct rules triggered",
231 self.by_rule.len()
232 ));
233 lines.push(format!(
234 "# payload classes observed: {}",
235 self.by_class.len()
236 ));
237 lines.push("# rule_id\tpayloads_observed".to_string());
238 for (rule_id, payloads) in &self.by_rule {
239 lines.push(format!("{}\t{}", rule_id.as_str(), payloads.len()));
240 }
241 lines.push("# per-class summary".to_string());
242 for (cls, rules) in &self.by_class {
243 lines.push(format!(
244 "# {}: {} rule(s) — {}",
245 cls.as_str(),
246 rules.len(),
247 rules
248 .iter()
249 .map(|r| r.as_str())
250 .collect::<Vec<_>>()
251 .join(", ")
252 ));
253 }
254 lines.join("\n")
255 }
256
257 /// Number of distinct rule IDs observed so far.
258 #[must_use]
259 pub fn rule_count(&self) -> usize {
260 // Exclude the synthetic "__unblocked__" sentinel from the headline.
261 self.by_rule
262 .keys()
263 .filter(|r| r.0 != "__unblocked__")
264 .count()
265 }
266
267 /// Rules that have been triggered at least once in this run.
268 #[must_use]
269 pub fn triggered_rules(&self) -> Vec<&RuleId> {
270 self.by_rule
271 .keys()
272 .filter(|r| r.0 != "__unblocked__")
273 .collect()
274 }
275
276 /// Serialize the coverage map to compact JSON.
277 ///
278 /// # Errors
279 ///
280 /// Returns `serde_json::Error` if serialization fails (only possible
281 /// if the in-memory types contain non-string-keyed maps, which they
282 /// cannot for `BTreeMap<RuleId, _>`).
283 pub fn to_json(&self) -> Result<String, serde_json::Error> {
284 serde_json::to_string_pretty(self)
285 }
286
287 /// Deserialize a coverage map from JSON produced by `to_json`.
288 ///
289 /// # Errors
290 ///
291 /// Returns `serde_json::Error` on malformed JSON.
292 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
293 serde_json::from_str(json)
294 }
295}
296
297// ── Descriptor ───────────────────────────────────────────────────────────────
298
299/// Produce the 2-D MAP-Elites behavior descriptor for one `(payload, rule_id)`
300/// observation.
301///
302/// The descriptor is `(PayloadClass, Option<RuleId>)`:
303/// - `Some(RuleId)` — the payload was blocked by a specific rule; the grid
304/// cell is `(class × rule_id)`.
305/// - `None` — the payload was not blocked (or the rule_id could not be
306/// extracted); the grid cell collapses to class-only, matching the
307/// pre-coverage behavior.
308///
309/// Stability guarantee: the same `(payload, rule_id)` pair always produces
310/// the same descriptor. The classifier is deterministic.
311#[must_use]
312pub fn map_elites_descriptor(
313 payload: &str,
314 rule_id: Option<&str>,
315) -> (PayloadClass, Option<RuleId>) {
316 let cls = PayloadClass::from_payload(payload);
317 let rid = rule_id.map(RuleId::new);
318 (cls, rid)
319}
320
321// ── Tests ─────────────────────────────────────────────────────────────────────
322
#[cfg(test)]
mod tests {
    use super::*;

    // ── 1. Empty coverage tracker ──────────────────────────────────────────────

    #[test]
    fn empty_coverage_has_no_rules() {
        let cov = RuleCoverage::new();
        assert_eq!(cov.rule_count(), 0);
        assert!(cov.by_rule.is_empty());
        assert!(cov.by_class.is_empty());
        assert!(cov.triggered_rules().is_empty());
    }

    // ── 2. Single rule_id observation ─────────────────────────────────────────

    #[test]
    fn single_rule_id_recorded_correctly() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        assert_eq!(cov.rule_count(), 1);
        let rid = RuleId::new("942100");
        assert!(cov.by_rule.contains_key(&rid));
        let cls = PayloadClass::new("sql");
        assert!(cov.by_class.contains_key(&cls));
        assert!(cov.by_class[&cls].contains(&rid));
    }

    // ── 3. Mixed classes — distinct cells ─────────────────────────────────────

    #[test]
    fn mixed_classes_produce_distinct_cells() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100")); // sql
        cov.record("<script>alert(1)</script>", Some("941100")); // xss
        cov.record("../../../etc/passwd", Some("930100")); // path

        assert_eq!(cov.rule_count(), 3);
        // Three distinct classes must be present.
        assert!(cov.by_class.contains_key(&PayloadClass::new("sql")));
        assert!(cov.by_class.contains_key(&PayloadClass::new("xss")));
        assert!(cov.by_class.contains_key(&PayloadClass::new("path")));
    }

    // ── 4. Descriptor stability — same input → same descriptor ────────────────

    #[test]
    fn descriptor_is_stable_for_same_input() {
        let payload = "' UNION SELECT 1,2,3--";
        let rule = Some("942190");
        let d1 = map_elites_descriptor(payload, rule);
        let d2 = map_elites_descriptor(payload, rule);
        assert_eq!(d1, d2);
    }

    // ── 5. Descriptor with no rule_id → class only ────────────────────────────

    #[test]
    fn descriptor_without_rule_id_has_none_dimension() {
        let (cls, rid) = map_elites_descriptor("' OR 1=1--", None);
        assert_eq!(cls, PayloadClass::new("sql"));
        assert!(rid.is_none());
    }

    // ── 6. JSON round-trip ────────────────────────────────────────────────────

    #[test]
    fn json_roundtrip_preserves_coverage() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        cov.record("<script>alert(1)</script>", Some("941100"));
        cov.record("plain text", None);

        let json = cov.to_json().expect("serialization must not fail");
        let restored = RuleCoverage::from_json(&json).expect("deserialization must not fail");

        assert_eq!(restored.rule_count(), cov.rule_count());
        // Compare full contents, not just sizes, so a lossy key encoding or a
        // dropped fingerprint is caught.
        assert_eq!(restored.by_rule, cov.by_rule);
        assert_eq!(restored.by_class, cov.by_class);
    }

    // ── 7. rule_id case-folding ───────────────────────────────────────────────

    #[test]
    fn rule_id_case_folding_normalises() {
        // All three forms should resolve to the same canonical RuleId.
        let r1 = RuleId::new("942100");
        let r2 = RuleId::new("RULE_942100");
        let r3 = RuleId::new("rule-942100");
        assert_eq!(r1, r2);
        assert_eq!(r1, r3);
    }

    // ── 8. Rule ID with unusual prefix (no "rule_" prefix) stays intact ───────

    #[test]
    fn rule_id_without_rule_prefix_preserved() {
        let r = RuleId::new("sql_942100");
        // Should NOT strip "sql_" — only "rule_" / "rule-" is stripped.
        assert_eq!(r.as_str(), "sql_942100");
    }

    // ── 9. PayloadClass from SQL payload ──────────────────────────────────────

    #[test]
    fn payload_class_detects_sql() {
        let cls = PayloadClass::from_payload("' UNION SELECT username, password FROM users--");
        assert_eq!(cls, PayloadClass::new("sql"));
    }

    // ── 10. PayloadClass from XSS payload ─────────────────────────────────────

    #[test]
    fn payload_class_detects_xss() {
        let cls = PayloadClass::from_payload("<script>alert(document.cookie)</script>");
        assert_eq!(cls, PayloadClass::new("xss"));
    }

    // ── 11. Multiple payloads hitting the same rule accumulate ────────────────

    #[test]
    fn same_rule_accumulates_multiple_payloads() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        cov.record("' OR 'x'='x'--", Some("942100"));
        cov.record("1 AND 1=1--", Some("942100"));
        let rid = RuleId::new("942100");
        // All three payloads are distinct fingerprints.
        assert_eq!(cov.by_rule[&rid].len(), 3);
        // Still only one rule.
        assert_eq!(cov.rule_count(), 1);
    }

    // ── 12. coverage_report contains expected rule ─────────────────────────────

    #[test]
    fn coverage_report_contains_triggered_rule() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        let report = cov.coverage_report();
        // Pin the exact `rule_id<TAB>count` line; a bare `contains("1")` would
        // be satisfied by the header alone.
        assert!(
            report.lines().any(|l| l == "942100\t1"),
            "report must list rule 942100 with exactly one payload:\n{report}"
        );
    }

    // ── 13. Unblocked observations are indexed by class only ──────────────────

    #[test]
    fn unblocked_observation_is_not_counted_as_rule() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", None);
        assert_eq!(cov.rule_count(), 0);
        assert!(cov.by_rule.is_empty());
        assert!(cov.triggered_rules().is_empty());
        assert!(cov.by_class[&PayloadClass::new("sql")].contains(&RuleId::new("__unblocked__")));
    }

    // ── 14. Fingerprints are truncated and control-character free ────────────

    #[test]
    fn fingerprints_are_truncated_and_strip_control_chars() {
        let mut cov = RuleCoverage::new();
        let long = format!("' OR 1=1--{}", "x".repeat(200));
        cov.record(&long, Some("942100"));
        cov.record("' OR 1=1--\r\n", Some("942100"));
        cov.record("' OR 1=1--", Some("942100"));
        let fps = &cov.by_rule[&RuleId::new("942100")];
        // The CR/LF variant collapses onto the plain payload.
        assert_eq!(fps.len(), 2);
        assert!(fps.iter().all(|fp| fp.chars().count() <= 64));
        assert!(fps.iter().all(|fp| !fp.chars().any(char::is_control)));
    }

    // ── 15. triggered_rules is sorted by rule id ──────────────────────────────

    #[test]
    fn triggered_rules_are_sorted() {
        let mut cov = RuleCoverage::new();
        cov.record("' OR 1=1--", Some("942100"));
        cov.record("<script>alert(1)</script>", Some("941100"));
        cov.record("plain text", None);
        let ids: Vec<&str> = cov
            .triggered_rules()
            .into_iter()
            .map(RuleId::as_str)
            .collect();
        assert_eq!(ids, vec!["941100", "942100"]);
    }

    // ── 16. PayloadClass::new normalises whitespace and case ──────────────────

    #[test]
    fn payload_class_new_normalises() {
        assert_eq!(PayloadClass::new("  SQL "), PayloadClass::new("sql"));
        assert_eq!(PayloadClass::new("  SQL ").as_str(), "sql");
    }

    // ── 17. Malformed JSON is rejected ────────────────────────────────────────

    #[test]
    fn from_json_rejects_malformed_input() {
        assert!(RuleCoverage::from_json("not json").is_err());
    }
}