Skip to main content

wafrift_evolution/search/
novelty.rs

1use crate::evolution::crossover::mutation::mutate_with_log;
2use crate::evolution::{Chromosome, GenePool, population::random_chromosome};
3use crate::lineage::Lineage;
4use crate::search::{EvalCandidate, SearchAlgorithm};
5use crate::types::{Budget, EvolutionError, OracleVerdict, SearchStats};
6use rand::Rng;
7use rand::rngs::StdRng;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11/// Novelty search with k-NN behavioral distance archive.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct NoveltySearch {
14    population: Vec<Chromosome>,
15    archive: Vec<Chromosome>,
16    gene_pool: GenePool,
17    generation: u32,
18    eval_counter: u64,
19    k: usize,
20    threshold: f64,
21    #[serde(skip)]
22    in_flight: HashMap<u64, Chromosome>,
23}
24
25impl NoveltySearch {
26    #[must_use]
27    pub fn new(k: usize, threshold: f64) -> Self {
28        Self {
29            population: Vec::new(),
30            archive: Vec::new(),
31            gene_pool: GenePool::default_wafrift(),
32            generation: 0,
33            eval_counter: 0,
34            k,
35            threshold,
36            in_flight: HashMap::new(),
37        }
38    }
39
40    fn phenotypic_distance(a: &Chromosome, b: &Chromosome) -> f64 {
41        let genes_a: Vec<_> = a.genes.iter().map(|(n, v)| format!("{n}={v}")).collect();
42        let genes_b: Vec<_> = b.genes.iter().map(|(n, v)| format!("{n}={v}")).collect();
43        levenshtein_distance(&genes_a.join("|"), &genes_b.join("|")) as f64
44            / (genes_a.len().max(genes_b.len()).max(1) as f64)
45    }
46
47    fn novelty_score(&self, chromosome: &Chromosome) -> f64 {
48        let mut neighbors: Vec<f64> = self
49            .archive
50            .iter()
51            .chain(self.population.iter())
52            .map(|other| Self::phenotypic_distance(chromosome, other))
53            .collect();
54        neighbors.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
55        neighbors.truncate(self.k);
56        if neighbors.is_empty() {
57            return f64::INFINITY;
58        }
59        neighbors.iter().sum::<f64>() / neighbors.len() as f64
60    }
61
62    fn generate_individual(&self, rng: &mut StdRng) -> Chromosome {
63        if self.population.is_empty() {
64            return random_chromosome(&self.gene_pool, rng);
65        }
66        let parent = &self.population[rng.gen_range(0..self.population.len())];
67        let mut child = parent.clone();
68        let log = mutate_with_log(&mut child, &self.gene_pool, 0.3, rng);
69        child.lineage = Lineage::mutation(parent, log, self.generation);
70        child
71    }
72}
73
74impl Default for NoveltySearch {
75    fn default() -> Self {
76        Self::new(15, 0.3)
77    }
78}
79
80impl SearchAlgorithm for NoveltySearch {
81    fn name(&self) -> &'static str {
82        "novelty_search"
83    }
84
85    fn initialize(&mut self, population: Vec<Chromosome>, gene_pool: &GenePool, _rng: &mut StdRng) {
86        self.gene_pool = gene_pool.clone();
87        self.population = population;
88        self.archive.clear();
89        self.in_flight.clear();
90    }
91
92    fn request_evaluations(&mut self, n: usize, rng: &mut StdRng) -> Vec<EvalCandidate> {
93        let mut out = Vec::with_capacity(n);
94        for _ in 0..n {
95            self.eval_counter += 1;
96            let candidate = self.generate_individual(rng);
97            self.in_flight.insert(self.eval_counter, candidate.clone());
98            out.push(EvalCandidate {
99                id: self.eval_counter,
100                chromosome: candidate,
101            });
102        }
103        out
104    }
105
106    fn submit_evaluations(&mut self, results: Vec<(u64, OracleVerdict)>) {
107        let mut evaluated: Vec<Chromosome> = Vec::with_capacity(results.len());
108        for (id, verdict) in results {
109            if let Some(mut candidate) = self.in_flight.remove(&id) {
110                candidate.record_verdict(&verdict);
111                evaluated.push(candidate);
112            }
113        }
114
115        // Add to archive based on novelty
116        for candidate in evaluated {
117            let score = self.novelty_score(&candidate);
118            if score > self.threshold {
119                self.archive.push(candidate.clone());
120            }
121            self.population.push(candidate);
122        }
123
124        // Cull population to reasonable size, keeping most novel
125        if self.population.len() > 100 {
126            let temp: Vec<Chromosome> = self.population.drain(..).collect();
127            let mut scored: Vec<(f64, Chromosome)> = temp
128                .into_iter()
129                .map(|c| (self.novelty_score(&c), c))
130                .collect();
131            scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
132            scored.truncate(100);
133            self.population = scored.into_iter().map(|(_, c)| c).collect();
134        }
135
136        self.generation += 1;
137    }
138
139    fn should_terminate(&self, stats: &SearchStats, budget: &Budget) -> bool {
140        stats.evaluations >= budget.max_requests
141            || stats.generation >= budget.max_generations
142            || stats.stagnation_counter >= budget.stagnation_limit
143    }
144
145    fn best(&self) -> Option<&Chromosome> {
146        self.population
147            .iter()
148            .chain(self.archive.iter())
149            .max_by(|a, b| {
150                a.fitness
151                    .partial_cmp(&b.fitness)
152                    .unwrap_or(std::cmp::Ordering::Equal)
153            })
154    }
155
156    fn checkpoint(&self) -> Result<Vec<u8>, EvolutionError> {
157        serde_json::to_vec(self).map_err(|e| EvolutionError::SerializationFailed(e.to_string()))
158    }
159
160    fn restore(&mut self, bytes: &[u8]) -> Result<(), EvolutionError> {
161        *self = serde_json::from_slice(bytes)
162            .map_err(|e| EvolutionError::DeserializationFailed(e.to_string()))?;
163        self.in_flight.clear();
164        Ok(())
165    }
166}
167
/// Computes the Levenshtein edit distance between `a` and `b`.
///
/// The distance is counted in Unicode scalar values (`char`s), with unit cost
/// for insertion, deletion and substitution. Only two rows of the dynamic
/// programming table are kept.
fn levenshtein_distance(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();
    let width = target.len();

    // Row for the empty prefix of `a`: turning "" into b[..j] costs j insertions.
    let mut above: Vec<usize> = (0..=width).collect();
    let mut row: Vec<usize> = vec![0; width + 1];

    for (i, &source_char) in source.iter().enumerate() {
        // Turning a[..=i] into "" costs i + 1 deletions.
        row[0] = i + 1;
        for (j, &target_char) in target.iter().enumerate() {
            let insertion = row[j] + 1;
            let deletion = above[j + 1] + 1;
            let substitution = above[j] + usize::from(source_char != target_char);
            row[j + 1] = insertion.min(deletion).min(substitution);
        }
        // The row just filled becomes the "previous" row of the next pass.
        std::mem::swap(&mut above, &mut row);
    }

    above[width]
}