wafrift_evolution/search/
tabu.rs1use crate::evolution::crossover::mutation::mutate_with_log;
2use crate::evolution::{Chromosome, GenePool, population::random_chromosome};
3use crate::lineage::Lineage;
4use crate::search::{EvalCandidate, SearchAlgorithm};
5use crate::types::{Budget, EvolutionError, OracleVerdict, SearchStats};
6use rand::rngs::StdRng;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashSet, VecDeque};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct TabuSearch {
13 current: Chromosome,
14 best: Chromosome,
15 gene_pool: GenePool,
16 generation: u32,
17 eval_counter: u64,
18 tabu_list: VecDeque<u64>,
19 tabu_tenure: usize,
20 tabu_set: HashSet<u64>,
21}
22
23impl TabuSearch {
24 #[must_use]
25 pub fn new(tabu_tenure: usize) -> Self {
26 Self {
27 current: Chromosome::new(vec![]),
28 best: Chromosome::new(vec![]),
29 gene_pool: GenePool::default_wafrift(),
30 generation: 0,
31 eval_counter: 0,
32 tabu_list: VecDeque::new(),
33 tabu_tenure,
34 tabu_set: HashSet::new(),
35 }
36 }
37
38 fn neighbor(&self, rng: &mut StdRng) -> Chromosome {
39 let mut child = self.current.clone();
40 let log = mutate_with_log(&mut child, &self.gene_pool, 0.25, rng);
41 child.lineage = Lineage::mutation(&self.current, log, self.generation);
42 child
43 }
44
45 fn add_tabu(&mut self, hash: u64) {
46 if self.tabu_set.insert(hash) {
47 self.tabu_list.push_back(hash);
48 }
49 while self.tabu_list.len() > self.tabu_tenure {
50 if let Some(old) = self.tabu_list.pop_front() {
51 self.tabu_set.remove(&old);
52 }
53 }
54 }
55}
56
57impl Default for TabuSearch {
58 fn default() -> Self {
59 Self::new(20)
60 }
61}
62
63impl SearchAlgorithm for TabuSearch {
64 fn name(&self) -> &'static str {
65 "tabu_search"
66 }
67
68 fn initialize(&mut self, population: Vec<Chromosome>, gene_pool: &GenePool, rng: &mut StdRng) {
69 self.gene_pool = gene_pool.clone();
70 if let Some(best) = population.iter().max_by(|a, b| {
71 a.fitness
72 .partial_cmp(&b.fitness)
73 .unwrap_or(std::cmp::Ordering::Equal)
74 }) {
75 self.current = best.clone();
76 self.best = best.clone();
77 self.add_tabu(best.hash());
78 } else {
79 self.current = random_chromosome(gene_pool, rng);
80 self.best = self.current.clone();
81 self.add_tabu(self.current.hash());
82 }
83 }
84
85 fn request_evaluations(&mut self, n: usize, rng: &mut StdRng) -> Vec<EvalCandidate> {
86 let mut out = Vec::with_capacity(n);
87 let mut attempts = 0;
88 while out.len() < n && attempts < n * 10 {
89 attempts += 1;
90 self.eval_counter += 1;
91 let candidate = self.neighbor(rng);
92 let hash = candidate.hash();
93 let is_tabu = self.tabu_set.contains(&hash);
95 let beats_best = candidate.fitness > self.best.fitness;
96 if !is_tabu || beats_best {
97 out.push(EvalCandidate {
98 id: self.eval_counter,
99 chromosome: candidate,
100 });
101 }
102 }
103 out
104 }
105
106 fn submit_evaluations(&mut self, results: Vec<(u64, OracleVerdict)>) {
107 for (_id, verdict) in results {
108 let mut candidate = self.current.clone();
109 candidate.record_verdict(&verdict);
110 self.add_tabu(candidate.hash());
111 if candidate.fitness >= self.current.fitness {
112 self.current = candidate;
113 if self.current.fitness > self.best.fitness {
114 self.best = self.current.clone();
115 }
116 }
117 }
118 self.generation += 1;
119 }
120
121 fn should_terminate(&self, stats: &SearchStats, budget: &Budget) -> bool {
122 stats.evaluations >= budget.max_requests
123 || stats.generation >= budget.max_generations
124 || stats.stagnation_counter >= budget.stagnation_limit
125 }
126
127 fn best(&self) -> Option<&Chromosome> {
128 Some(&self.best)
129 }
130
131 fn checkpoint(&self) -> Result<Vec<u8>, EvolutionError> {
132 serde_json::to_vec(self).map_err(|e| EvolutionError::SerializationFailed(e.to_string()))
133 }
134
135 fn restore(&mut self, bytes: &[u8]) -> Result<(), EvolutionError> {
136 *self = serde_json::from_slice(bytes)
137 .map_err(|e| EvolutionError::DeserializationFailed(e.to_string()))?;
138 Ok(())
139 }
140}