Skip to main content

wafrift_evolution/search/
map_elites.rs

1use crate::evolution::crossover::mutation::mutate_with_log;
2use crate::evolution::{Chromosome, GenePool, population::random_chromosome};
3use crate::lineage::Lineage;
4use crate::search::{EvalCandidate, SearchAlgorithm};
5use crate::types::{Budget, EvolutionError, OracleVerdict, SearchStats};
6use rand::Rng;
7use rand::rngs::StdRng;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11/// Feature descriptor for MAP-Elites grid binning.
12#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
13pub struct FeatureDescriptor {
14    pub encoding: String,
15    pub grammar: String,
16    pub content_type: String,
17}
18
19impl FeatureDescriptor {
20    #[must_use]
21    pub fn from_chromosome(chromosome: &Chromosome) -> Self {
22        Self {
23            encoding: chromosome.gene("encoding").unwrap_or("None").to_string(),
24            grammar: chromosome
25                .gene("grammar_rule")
26                .unwrap_or("None")
27                .to_string(),
28            content_type: chromosome
29                .gene("content_type")
30                .unwrap_or("None")
31                .to_string(),
32        }
33    }
34}
35
36/// MAP-Elites quality-diversity search.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct MapElites {
39    grid: Vec<(FeatureDescriptor, Chromosome)>,
40    gene_pool: GenePool,
41    generation: u32,
42    eval_counter: u64,
43    #[serde(skip)]
44    in_flight: HashMap<u64, Chromosome>,
45}
46
47impl MapElites {
48    #[must_use]
49    pub fn new() -> Self {
50        Self {
51            grid: Vec::new(),
52            gene_pool: GenePool::default_wafrift(),
53            generation: 0,
54            eval_counter: 0,
55            in_flight: HashMap::new(),
56        }
57    }
58
59    fn sample_parent(&self, rng: &mut StdRng) -> Option<Chromosome> {
60        if self.grid.is_empty() {
61            return None;
62        }
63        // 50% of the time sample from under-filled regions (random bin)
64        // 50% of the time sample uniformly from existing elites
65        if rng.gen_bool(0.5) {
66            let idx = rng.gen_range(0..self.grid.len());
67            Some(self.grid[idx].1.clone())
68        } else {
69            // Try to fill a random feature combination
70            let encoding = self
71                .gene_pool
72                .random_value("encoding", rng)
73                .unwrap_or_else(|| "None".into());
74            let grammar = self
75                .gene_pool
76                .random_value("grammar_rule", rng)
77                .unwrap_or_else(|| "None".into());
78            let content_type = self
79                .gene_pool
80                .random_value("content_type", rng)
81                .unwrap_or_else(|| "None".into());
82            let descriptor = FeatureDescriptor {
83                encoding,
84                grammar,
85                content_type,
86            };
87            self.grid
88                .iter()
89                .find(|(d, _)| *d == descriptor)
90                .map(|(_, c)| c.clone())
91                .or_else(|| {
92                    let idx = rng.gen_range(0..self.grid.len());
93                    Some(self.grid[idx].1.clone())
94                })
95        }
96    }
97
98    fn generate_individual(&self, rng: &mut StdRng) -> Chromosome {
99        match self.sample_parent(rng) {
100            Some(parent) => {
101                let mut child = parent.clone();
102                let log = mutate_with_log(&mut child, &self.gene_pool, 0.25, rng);
103                child.lineage = Lineage::mutation(&parent, log, self.generation);
104                child
105            }
106            None => random_chromosome(&self.gene_pool, rng),
107        }
108    }
109}
110
111impl Default for MapElites {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117impl SearchAlgorithm for MapElites {
118    fn name(&self) -> &'static str {
119        "map_elites"
120    }
121
122    fn initialize(&mut self, population: Vec<Chromosome>, gene_pool: &GenePool, _rng: &mut StdRng) {
123        self.gene_pool = gene_pool.clone();
124        self.grid.clear();
125        self.in_flight.clear();
126        for chromosome in population {
127            let descriptor = FeatureDescriptor::from_chromosome(&chromosome);
128            if !self.grid.iter().any(|(d, _)| *d == descriptor) {
129                self.grid.push((descriptor, chromosome));
130            }
131        }
132    }
133
134    fn request_evaluations(&mut self, n: usize, rng: &mut StdRng) -> Vec<EvalCandidate> {
135        let mut out = Vec::with_capacity(n);
136        for _ in 0..n {
137            self.eval_counter += 1;
138            let candidate = self.generate_individual(rng);
139            self.in_flight.insert(self.eval_counter, candidate.clone());
140            out.push(EvalCandidate {
141                id: self.eval_counter,
142                chromosome: candidate,
143            });
144        }
145        out
146    }
147
148    fn submit_evaluations(&mut self, results: Vec<(u64, OracleVerdict)>) {
149        for (id, verdict) in results {
150            if let Some(mut candidate) = self.in_flight.remove(&id) {
151                candidate.record_verdict(&verdict);
152                let descriptor = FeatureDescriptor::from_chromosome(&candidate);
153                let should_insert = match self.grid.iter().find(|(d, _)| *d == descriptor) {
154                    Some((_, existing)) => candidate.fitness > existing.fitness,
155                    None => true,
156                };
157                if should_insert {
158                    if let Some((idx, _)) = self
159                        .grid
160                        .iter()
161                        .enumerate()
162                        .find(|(_, (d, _))| *d == descriptor)
163                    {
164                        self.grid[idx] = (descriptor, candidate);
165                    } else {
166                        self.grid.push((descriptor, candidate));
167                    }
168                }
169            }
170        }
171        self.generation += 1;
172    }
173
174    fn should_terminate(&self, stats: &SearchStats, budget: &Budget) -> bool {
175        stats.evaluations >= budget.max_requests
176            || stats.generation >= budget.max_generations
177            || stats.stagnation_counter >= budget.stagnation_limit
178    }
179
180    fn best(&self) -> Option<&Chromosome> {
181        self.grid.iter().map(|(_, c)| c).max_by(|a, b| {
182            a.fitness
183                .partial_cmp(&b.fitness)
184                .unwrap_or(std::cmp::Ordering::Equal)
185        })
186    }
187
188    fn checkpoint(&self) -> Result<Vec<u8>, EvolutionError> {
189        serde_json::to_vec(self).map_err(EvolutionError::SerializationFailed)
190    }
191
192    fn restore(&mut self, bytes: &[u8]) -> Result<(), EvolutionError> {
193        if bytes.len() > crate::types::MAX_CHECKPOINT_BYTES {
194            return Err(EvolutionError::OversizedData {
195                context: "map_elites checkpoint restore".into(),
196                size: bytes.len(),
197                max: crate::types::MAX_CHECKPOINT_BYTES,
198            });
199        }
200        *self = serde_json::from_slice(bytes).map_err(EvolutionError::DeserializationFailed)?;
201        self.in_flight.clear();
202        Ok(())
203    }
204
205    /// Every grid cell holds a (descriptor, elite chromosome) pair —
206    /// the elite set IS the live population for diversity purposes.
207    fn population_snapshot(&self) -> Vec<Chromosome> {
208        self.grid.iter().map(|(_, c)| c.clone()).collect()
209    }
210
211    fn clone_box(&self) -> Box<dyn SearchAlgorithm> {
212        Box::new(self.clone())
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use std::collections::HashSet;

    /// Chromosome carrying exactly the three genes that drive binning.
    fn dummy_chromosome(encoding: &str, grammar: &str, content_type: &str) -> Chromosome {
        Chromosome::new(vec![
            ("encoding".into(), encoding.into()),
            ("grammar_rule".into(), grammar.into()),
            ("content_type".into(), content_type.into()),
        ])
    }

    /// Passing verdict shared by the submission tests.
    fn passing_verdict() -> OracleVerdict {
        OracleVerdict {
            passed: true,
            status_delta: 1,
            body_delta: 1,
            latency_ms: 10,
            confidence: 0.9,
            triggered_rules: 0,
        }
    }

    /// Fresh search, default gene pool and an RNG seeded with `seed`.
    fn setup(seed: u64) -> (MapElites, GenePool, StdRng) {
        (
            MapElites::new(),
            GenePool::default_wafrift(),
            StdRng::seed_from_u64(seed),
        )
    }

    #[test]
    fn initialize_populates_grid() {
        let (mut alg, pool, mut rng) = setup(1);
        alg.initialize(
            vec![
                dummy_chromosome("UrlEncode", "sqli", "json"),
                dummy_chromosome("CaseAlternation", "cmdi", "form"),
            ],
            &pool,
            &mut rng,
        );
        assert_eq!(alg.grid.len(), 2);
    }

    #[test]
    fn request_evaluations_returns_unique_ids() {
        let (mut alg, pool, mut rng) = setup(2);
        alg.initialize(
            vec![dummy_chromosome("UrlEncode", "sqli", "json")],
            &pool,
            &mut rng,
        );

        let first = alg.request_evaluations(2, &mut rng);
        let second = alg.request_evaluations(2, &mut rng);
        let ids: Vec<u64> = first.iter().chain(&second).map(|c| c.id).collect();
        let distinct: HashSet<u64> = ids.iter().copied().collect();
        assert_eq!(distinct.len(), ids.len());
    }

    #[test]
    fn submit_evaluation_inserts_into_grid() {
        let (mut alg, pool, mut rng) = setup(3);
        alg.initialize(Vec::new(), &pool, &mut rng);

        let id = alg.request_evaluations(1, &mut rng)[0].id;
        alg.submit_evaluations(vec![(id, passing_verdict())]);

        assert!(!alg.grid.is_empty());
        let best = alg.best().expect("grid must hold an elite after submission");
        assert!(best.fitness > 0.0);
    }

    #[test]
    fn higher_fitness_replaces_existing_grid_cell() {
        let (mut alg, pool, mut rng) = setup(4);
        let mut incumbent = dummy_chromosome("UrlEncode", "sqli", "json");
        incumbent.fitness = 0.1;
        alg.initialize(vec![incumbent], &pool, &mut rng);

        // Inject a same-cell candidate directly so the test controls its descriptor.
        let mut challenger = dummy_chromosome("UrlEncode", "sqli", "json");
        challenger.fitness = 0.9;
        alg.in_flight.insert(42, challenger);
        alg.submit_evaluations(vec![(42, passing_verdict())]);

        assert!(alg.best().unwrap().fitness > 0.5);
    }

    #[test]
    fn checkpoint_roundtrip_clears_in_flight() {
        let (mut alg, pool, mut rng) = setup(5);
        alg.initialize(
            vec![dummy_chromosome("UrlEncode", "sqli", "json")],
            &pool,
            &mut rng,
        );
        let _pending = alg.request_evaluations(3, &mut rng);
        assert!(!alg.in_flight.is_empty());

        let bytes = alg.checkpoint().expect("checkpoint must serialize");
        let mut restored = MapElites::new();
        restored.restore(&bytes).expect("restore must succeed");
        assert!(restored.in_flight.is_empty());
        assert_eq!(restored.grid.len(), alg.grid.len());
    }

    #[test]
    fn should_terminate_respects_budget() {
        let alg = MapElites::new();
        let budget = Budget::default_wafrift();
        let just_below = SearchStats {
            evaluations: budget.max_requests - 1,
            ..SearchStats::default()
        };
        let at_limit = SearchStats {
            evaluations: budget.max_requests,
            ..SearchStats::default()
        };
        assert!(!alg.should_terminate(&just_below, &budget));
        assert!(alg.should_terminate(&at_limit, &budget));
    }

    #[test]
    fn best_returns_none_for_empty_grid() {
        let alg = MapElites::new();
        assert!(alg.best().is_none());
    }
}