Skip to main content

wafrift_evolution/search/
map_elites.rs

1use crate::evolution::crossover::mutation::mutate_with_log;
2use crate::evolution::{Chromosome, GenePool, population::random_chromosome};
3use crate::lineage::Lineage;
4use crate::search::{EvalCandidate, SearchAlgorithm};
5use crate::types::{Budget, EvolutionError, OracleVerdict, SearchStats};
6use rand::Rng;
7use rand::rngs::StdRng;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11/// Feature descriptor for MAP-Elites grid binning.
12#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
13pub struct FeatureDescriptor {
14    pub encoding: String,
15    pub grammar: String,
16    pub content_type: String,
17}
18
19impl FeatureDescriptor {
20    #[must_use]
21    pub fn from_chromosome(chromosome: &Chromosome) -> Self {
22        Self {
23            encoding: chromosome.gene("encoding").unwrap_or("None").to_string(),
24            grammar: chromosome
25                .gene("grammar_rule")
26                .unwrap_or("None")
27                .to_string(),
28            content_type: chromosome
29                .gene("content_type")
30                .unwrap_or("None")
31                .to_string(),
32        }
33    }
34}
35
36/// MAP-Elites quality-diversity search.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct MapElites {
39    grid: HashMap<FeatureDescriptor, Chromosome>,
40    gene_pool: GenePool,
41    generation: u32,
42    eval_counter: u64,
43    #[serde(skip)]
44    in_flight: HashMap<u64, Chromosome>,
45}
46
47impl MapElites {
48    #[must_use]
49    pub fn new() -> Self {
50        Self {
51            grid: HashMap::new(),
52            gene_pool: GenePool::default_wafrift(),
53            generation: 0,
54            eval_counter: 0,
55            in_flight: HashMap::new(),
56        }
57    }
58
59    fn sample_parent(&self, rng: &mut StdRng) -> Option<Chromosome> {
60        if self.grid.is_empty() {
61            return None;
62        }
63        // 50% of the time sample from under-filled regions (random bin)
64        // 50% of the time sample uniformly from existing elites
65        if rng.gen_bool(0.5) {
66            let values: Vec<&Chromosome> = self.grid.values().collect();
67            Some(values[rng.gen_range(0..values.len())].clone())
68        } else {
69            // Try to fill a random feature combination
70            let encoding = self
71                .gene_pool
72                .random_value("encoding", rng)
73                .unwrap_or_else(|| "None".into());
74            let grammar = self
75                .gene_pool
76                .random_value("grammar_rule", rng)
77                .unwrap_or_else(|| "None".into());
78            let content_type = self
79                .gene_pool
80                .random_value("content_type", rng)
81                .unwrap_or_else(|| "None".into());
82            let descriptor = FeatureDescriptor {
83                encoding,
84                grammar,
85                content_type,
86            };
87            self.grid.get(&descriptor).cloned().or_else(|| {
88                let values: Vec<&Chromosome> = self.grid.values().collect();
89                Some(values[rng.gen_range(0..values.len())].clone())
90            })
91        }
92    }
93
94    fn generate_individual(&self, rng: &mut StdRng) -> Chromosome {
95        match self.sample_parent(rng) {
96            Some(parent) => {
97                let mut child = parent.clone();
98                let log = mutate_with_log(&mut child, &self.gene_pool, 0.25, rng);
99                child.lineage = Lineage::mutation(&parent, log, self.generation);
100                child
101            }
102            None => random_chromosome(&self.gene_pool, rng),
103        }
104    }
105}
106
107impl Default for MapElites {
108    fn default() -> Self {
109        Self::new()
110    }
111}
112
113impl SearchAlgorithm for MapElites {
114    fn name(&self) -> &'static str {
115        "map_elites"
116    }
117
118    fn initialize(&mut self, population: Vec<Chromosome>, gene_pool: &GenePool, _rng: &mut StdRng) {
119        self.gene_pool = gene_pool.clone();
120        self.grid.clear();
121        self.in_flight.clear();
122        for chromosome in population {
123            let descriptor = FeatureDescriptor::from_chromosome(&chromosome);
124            self.grid.entry(descriptor).or_insert(chromosome);
125        }
126    }
127
128    fn request_evaluations(&mut self, n: usize, rng: &mut StdRng) -> Vec<EvalCandidate> {
129        let mut out = Vec::with_capacity(n);
130        for _ in 0..n {
131            self.eval_counter += 1;
132            let candidate = self.generate_individual(rng);
133            self.in_flight.insert(self.eval_counter, candidate.clone());
134            out.push(EvalCandidate {
135                id: self.eval_counter,
136                chromosome: candidate,
137            });
138        }
139        out
140    }
141
142    fn submit_evaluations(&mut self, results: Vec<(u64, OracleVerdict)>) {
143        for (id, verdict) in results {
144            if let Some(mut candidate) = self.in_flight.remove(&id) {
145                candidate.record_verdict(&verdict);
146                let descriptor = FeatureDescriptor::from_chromosome(&candidate);
147                let should_insert = match self.grid.get(&descriptor) {
148                    Some(existing) => candidate.fitness > existing.fitness,
149                    None => true,
150                };
151                if should_insert {
152                    self.grid.insert(descriptor, candidate);
153                }
154            }
155        }
156        self.generation += 1;
157    }
158
159    fn should_terminate(&self, stats: &SearchStats, budget: &Budget) -> bool {
160        stats.evaluations >= budget.max_requests
161            || stats.generation >= budget.max_generations
162            || stats.stagnation_counter >= budget.stagnation_limit
163    }
164
165    fn best(&self) -> Option<&Chromosome> {
166        self.grid.values().max_by(|a, b| {
167            a.fitness
168                .partial_cmp(&b.fitness)
169                .unwrap_or(std::cmp::Ordering::Equal)
170        })
171    }
172
173    fn checkpoint(&self) -> Result<Vec<u8>, EvolutionError> {
174        serde_json::to_vec(self).map_err(|e| EvolutionError::SerializationFailed(e.to_string()))
175    }
176
177    fn restore(&mut self, bytes: &[u8]) -> Result<(), EvolutionError> {
178        *self = serde_json::from_slice(bytes)
179            .map_err(|e| EvolutionError::DeserializationFailed(e.to_string()))?;
180        self.in_flight.clear();
181        Ok(())
182    }
183}