Skip to main content

wafrift_evolution/search/
novelty.rs

1use crate::evolution::crossover::mutation::mutate_with_log;
2use crate::evolution::{Chromosome, GenePool, population::random_chromosome};
3use crate::lineage::Lineage;
4use crate::search::{EvalCandidate, SearchAlgorithm, fitness_cmp};
5use crate::types::{Budget, EvolutionError, OracleVerdict, SearchStats};
6use rand::rngs::StdRng;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use wafrift_types::pick::pick_ref_from_rng;
10
11/// Novelty search with k-NN behavioral distance archive.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct NoveltySearch {
14    population: Vec<Chromosome>,
15    archive: Vec<Chromosome>,
16    gene_pool: GenePool,
17    generation: u32,
18    eval_counter: u64,
19    k: usize,
20    threshold: f64,
21    #[serde(skip)]
22    in_flight: HashMap<u64, Chromosome>,
23}
24
25impl NoveltySearch {
26    #[must_use]
27    pub fn new(k: usize, threshold: f64) -> Self {
28        Self {
29            population: Vec::new(),
30            archive: Vec::new(),
31            gene_pool: GenePool::default_wafrift(),
32            generation: 0,
33            eval_counter: 0,
34            k,
35            threshold,
36            in_flight: HashMap::new(),
37        }
38    }
39
40    fn phenotypic_distance(a: &Chromosome, b: &Chromosome) -> f64 {
41        let genes_a: Vec<_> = a.genes.iter().map(|(n, v)| format!("{n}={v}")).collect();
42        let genes_b: Vec<_> = b.genes.iter().map(|(n, v)| format!("{n}={v}")).collect();
43        levenshtein_distance(&genes_a.join("|"), &genes_b.join("|")) as f64
44            / (genes_a.len().max(genes_b.len()).max(1) as f64)
45    }
46
47    fn novelty_score(&self, chromosome: &Chromosome) -> f64 {
48        let mut neighbors: Vec<f64> = self
49            .archive
50            .iter()
51            .chain(self.population.iter())
52            .map(|other| Self::phenotypic_distance(chromosome, other))
53            .collect();
54        neighbors.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
55        neighbors.truncate(self.k);
56        if neighbors.is_empty() {
57            return f64::INFINITY;
58        }
59        neighbors.iter().sum::<f64>() / neighbors.len() as f64
60    }
61
62    fn generate_individual(&self, rng: &mut StdRng) -> Chromosome {
63        if self.population.is_empty() {
64            return random_chromosome(&self.gene_pool, rng);
65        }
66        let parent = pick_ref_from_rng(&self.population, rng).unwrap_or(&self.population[0]);
67        let mut child = parent.clone();
68        let log = mutate_with_log(&mut child, &self.gene_pool, 0.3, rng);
69        child.lineage = Lineage::mutation(parent, log, self.generation);
70        child
71    }
72}
73
74impl Default for NoveltySearch {
75    fn default() -> Self {
76        Self::new(15, 0.3)
77    }
78}
79
80impl SearchAlgorithm for NoveltySearch {
81    fn name(&self) -> &'static str {
82        "novelty_search"
83    }
84
85    fn initialize(&mut self, population: Vec<Chromosome>, gene_pool: &GenePool, _rng: &mut StdRng) {
86        self.gene_pool = gene_pool.clone();
87        self.population = population;
88        self.archive.clear();
89        self.in_flight.clear();
90    }
91
92    fn request_evaluations(&mut self, n: usize, rng: &mut StdRng) -> Vec<EvalCandidate> {
93        let mut out = Vec::with_capacity(n);
94        for _ in 0..n {
95            self.eval_counter = self.eval_counter.saturating_add(1);
96            let candidate = self.generate_individual(rng);
97            self.in_flight.insert(self.eval_counter, candidate.clone());
98            out.push(EvalCandidate {
99                id: self.eval_counter,
100                chromosome: candidate,
101            });
102        }
103        out
104    }
105
106    fn submit_evaluations(&mut self, results: Vec<(u64, OracleVerdict)>) {
107        let mut evaluated: Vec<Chromosome> = Vec::with_capacity(results.len());
108        for (id, verdict) in results {
109            if let Some(mut candidate) = self.in_flight.remove(&id) {
110                candidate.record_verdict(&verdict);
111                evaluated.push(candidate);
112            }
113        }
114
115        // Add to archive based on novelty. Cap the archive at 10_000
116        // to prevent unbounded growth on long-running scans (every
117        // novel candidate would otherwise stay alive forever, leaking
118        // memory until OOM). When full, evict the least-novel entry
119        // by score so the highest-novelty history is retained.
120        const ARCHIVE_CAP: usize = 10_000;
121        for candidate in evaluated {
122            let score = self.novelty_score(&candidate);
123            if score > self.threshold {
124                if self.archive.len() >= ARCHIVE_CAP
125                    && let Some((min_idx, _)) = self
126                        .archive
127                        .iter()
128                        .enumerate()
129                        .map(|(i, c)| (i, self.novelty_score(c)))
130                        .min_by(|(_, a), (_, b)| {
131                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
132                        })
133                {
134                    self.archive.swap_remove(min_idx);
135                }
136                self.archive.push(candidate.clone());
137            }
138            self.population.push(candidate);
139        }
140
141        // Cull population to reasonable size, keeping most novel
142        if self.population.len() > 100 {
143            let temp: Vec<Chromosome> = self.population.drain(..).collect();
144            let mut scored: Vec<(f64, Chromosome)> = temp
145                .into_iter()
146                .map(|c| (self.novelty_score(&c), c))
147                .collect();
148            scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
149            scored.truncate(100);
150            self.population = scored.into_iter().map(|(_, c)| c).collect();
151        }
152
153        self.generation = self.generation.saturating_add(1);
154    }
155
156    fn should_terminate(&self, stats: &SearchStats, budget: &Budget) -> bool {
157        stats.evaluations >= budget.max_requests
158            || stats.generation >= budget.max_generations
159            || stats.stagnation_counter >= budget.stagnation_limit
160    }
161
162    fn best(&self) -> Option<&Chromosome> {
163        // F144 sibling: route through fitness_cmp so a NaN-fitness
164        // chromosome can never beat a finite one. Bare partial_cmp
165        // with `.unwrap_or(Equal)` collapsed every NaN comparison
166        // into Equal, letting a poisoned chromosome become "best"
167        // by simple iteration order.
168        self.population
169            .iter()
170            .chain(self.archive.iter())
171            .max_by(|a, b| fitness_cmp(a.fitness, b.fitness))
172    }
173
174    fn checkpoint(&self) -> Result<Vec<u8>, EvolutionError> {
175        serde_json::to_vec(self).map_err(EvolutionError::SerializationFailed)
176    }
177
178    fn restore(&mut self, bytes: &[u8]) -> Result<(), EvolutionError> {
179        if bytes.len() > crate::types::MAX_CHECKPOINT_BYTES {
180            return Err(EvolutionError::OversizedData {
181                context: "novelty checkpoint restore".into(),
182                size: bytes.len(),
183                max: crate::types::MAX_CHECKPOINT_BYTES,
184            });
185        }
186        *self = serde_json::from_slice(bytes).map_err(EvolutionError::DeserializationFailed)?;
187        self.in_flight.clear();
188        Ok(())
189    }
190
191    /// Population + archive — both are live state the algorithm draws
192    /// candidates from. Diversity over the union is the meaningful
193    /// signal for adaptive mutation pressure.
194    fn population_snapshot(&self) -> Vec<Chromosome> {
195        let mut out = Vec::with_capacity(self.population.len() + self.archive.len());
196        out.extend(self.population.iter().cloned());
197        out.extend(self.archive.iter().cloned());
198        out
199    }
200
201    fn clone_box(&self) -> Box<dyn SearchAlgorithm> {
202        Box::new(self.clone())
203    }
204}
205
/// Character-level edit distance (insertions, deletions, substitutions of
/// `char`s) between `a` and `b`, computed with two rolling DP rows.
fn levenshtein_distance(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();
    let width = target.len();

    // previous_row[j] = distance between the source prefix processed so far
    // and target[..j]; the first row is the cost of inserting j characters.
    let mut previous_row: Vec<usize> = (0..=width).collect();
    let mut current_row = vec![0usize; width + 1];

    for (i, &s) in source.iter().enumerate() {
        current_row[0] = i + 1;
        for (j, &t) in target.iter().enumerate() {
            let substitution = previous_row[j] + usize::from(s != t);
            let deletion = previous_row[j + 1] + 1;
            let insertion = current_row[j] + 1;
            current_row[j + 1] = substitution.min(deletion).min(insertion);
        }
        std::mem::swap(&mut previous_row, &mut current_row);
    }
    previous_row[width]
}
228
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use std::collections::HashSet;

    /// Fixed three-gene chromosome used as a fixture throughout this module.
    fn dummy_chromosome(encoding: &str, grammar: &str, content_type: &str) -> Chromosome {
        Chromosome::new(vec![
            ("encoding".into(), encoding.into()),
            ("grammar_rule".into(), grammar.into()),
            ("content_type".into(), content_type.into()),
        ])
    }

    /// A passing, high-confidence verdict.
    fn passing_verdict() -> OracleVerdict {
        OracleVerdict {
            passed: true,
            status_delta: 1,
            body_delta: 1,
            latency_ms: 10,
            confidence: 0.9,
            triggered_rules: 0,
            ..Default::default()
        }
    }

    /// A failing, low-confidence verdict with one triggered rule.
    fn failing_verdict() -> OracleVerdict {
        OracleVerdict {
            passed: false,
            status_delta: 0,
            body_delta: 0,
            latency_ms: 10,
            confidence: 0.1,
            triggered_rules: 1,
            ..Default::default()
        }
    }

    /// A k = 5 search initialized with `population`, plus the seeded RNG.
    fn setup(threshold: f64, seed: u64, population: Vec<Chromosome>) -> (NoveltySearch, StdRng) {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut alg = NoveltySearch::new(5, threshold);
        alg.initialize(population, &GenePool::default_wafrift(), &mut rng);
        (alg, rng)
    }

    #[test]
    fn initialize_sets_population() {
        let (alg, _rng) = setup(
            0.3,
            1,
            vec![
                dummy_chromosome("UrlEncode", "sqli", "json"),
                dummy_chromosome("CaseAlternation", "cmdi", "form"),
            ],
        );
        assert_eq!(alg.population.len(), 2);
        assert!(alg.archive.is_empty());
    }

    #[test]
    fn request_evaluations_returns_unique_ids() {
        let (mut alg, mut rng) =
            setup(0.3, 2, vec![dummy_chromosome("UrlEncode", "sqli", "json")]);
        let mut ids: Vec<u64> = Vec::new();
        for _ in 0..2 {
            ids.extend(alg.request_evaluations(2, &mut rng).iter().map(|c| c.id));
        }
        let unique: HashSet<u64> = ids.iter().copied().collect();
        assert_eq!(ids.len(), unique.len());
    }

    #[test]
    fn submit_evaluation_populates_archive_and_population() {
        // Zero threshold: every candidate counts as novel.
        let (mut alg, mut rng) = setup(0.0, 3, vec![]);
        let results: Vec<(u64, OracleVerdict)> = alg
            .request_evaluations(2, &mut rng)
            .iter()
            .map(|c| c.id)
            .zip([passing_verdict(), failing_verdict()])
            .collect();
        alg.submit_evaluations(results);

        assert!(!alg.population.is_empty());
        assert!(!alg.archive.is_empty());
        assert!(alg.best().is_some());
    }

    #[test]
    fn archive_respects_threshold() {
        // Infinite threshold: nothing counts as novel.
        let (mut alg, mut rng) = setup(f64::INFINITY, 4, vec![]);
        let results: Vec<(u64, OracleVerdict)> = alg
            .request_evaluations(3, &mut rng)
            .into_iter()
            .map(|c| (c.id, passing_verdict()))
            .collect();
        alg.submit_evaluations(results);

        assert!(alg.archive.is_empty(), "nothing should enter the archive");
        assert!(!alg.population.is_empty(), "the population still grows");
    }

    #[test]
    fn checkpoint_roundtrip_clears_in_flight() {
        let (mut alg, mut rng) =
            setup(0.3, 5, vec![dummy_chromosome("UrlEncode", "sqli", "json")]);
        let _ = alg.request_evaluations(3, &mut rng);
        assert!(!alg.in_flight.is_empty());

        let bytes = alg.checkpoint().expect("checkpoint must serialize");
        let mut restored = NoveltySearch::new(5, 0.3);
        restored.restore(&bytes).expect("restore must succeed");
        assert!(restored.in_flight.is_empty());
    }

    #[test]
    fn should_terminate_respects_budget() {
        let alg = NoveltySearch::new(5, 0.3);
        let budget = Budget::default_wafrift();
        let stats_at = |generation| SearchStats {
            generation,
            ..SearchStats::default()
        };
        assert!(!alg.should_terminate(&stats_at(budget.max_generations - 1), &budget));
        assert!(alg.should_terminate(&stats_at(budget.max_generations), &budget));
    }

    #[test]
    fn best_returns_none_for_empty_population_and_archive() {
        assert!(NoveltySearch::new(5, 0.3).best().is_none());
    }

    #[test]
    fn phenotypic_distance_is_symmetric() {
        let a = dummy_chromosome("UrlEncode", "sqli", "json");
        let b = dummy_chromosome("CaseAlternation", "cmdi", "form");
        let forward = NoveltySearch::phenotypic_distance(&a, &b);
        let backward = NoveltySearch::phenotypic_distance(&b, &a);
        assert!((forward - backward).abs() < f64::EPSILON);
    }

    #[test]
    fn phenotypic_distance_self_is_zero() {
        let a = dummy_chromosome("UrlEncode", "sqli", "json");
        assert!(NoveltySearch::phenotypic_distance(&a, &a).abs() < f64::EPSILON);
    }

    #[test]
    fn levenshtein_distance_smoke() {
        for (a, b, expected) in [("kitten", "sitting", 3), ("", "", 0), ("a", "", 1), ("", "b", 1)] {
            assert_eq!(super::levenshtein_distance(a, b), expected, "{a:?} vs {b:?}");
        }
    }

    // ── Saturating-arithmetic regression tests ────────────────────────────────

    /// `eval_counter` must saturate at `u64::MAX` instead of wrapping to 0.
    #[test]
    fn eval_counter_saturates_at_u64_max() {
        let (mut alg, mut rng) =
            setup(0.0, 50, vec![dummy_chromosome("UrlEncode", "sqli", "json")]);
        alg.eval_counter = u64::MAX;
        let _ = alg.request_evaluations(1, &mut rng);
        assert_eq!(
            alg.eval_counter,
            u64::MAX,
            "eval_counter must saturate at u64::MAX, not wrap to 0"
        );
    }

    /// `generation` must saturate at `u32::MAX` instead of wrapping to 0.
    #[test]
    fn generation_saturates_at_u32_max() {
        let (mut alg, _rng) = setup(0.0, 51, vec![]);
        alg.generation = u32::MAX;
        alg.submit_evaluations(vec![]);
        assert_eq!(
            alg.generation,
            u32::MAX,
            "generation must saturate at u32::MAX, not wrap to 0"
        );
    }

    /// IDs emitted by `request_evaluations` must never collide across rounds.
    #[test]
    fn eval_counter_ids_are_unique_across_generations() {
        let (mut alg, mut rng) =
            setup(0.0, 52, vec![dummy_chromosome("CaseAlternation", "xss", "form")]);
        let mut ids: Vec<u64> = Vec::new();
        for _ in 0..8 {
            let batch_ids: Vec<u64> = alg
                .request_evaluations(3, &mut rng)
                .iter()
                .map(|c| c.id)
                .collect();
            ids.extend_from_slice(&batch_ids);
            alg.submit_evaluations(
                batch_ids
                    .into_iter()
                    .map(|id| (id, OracleVerdict::from_bool(false)))
                    .collect(),
            );
        }
        let unique: HashSet<u64> = ids.iter().copied().collect();
        assert_eq!(unique.len(), ids.len(), "eval IDs must never collide");
    }
}