tensorlogic_train/nas/evolution.rs
//! Regularized (aging) evolution for neural architecture search.
//!
//! Implements the algorithm from:
//!
//! > Real et al. (2019) "Regularized Evolution for Image Classifier Architecture Search"
//! > <https://arxiv.org/abs/1802.01548>
//!
//! The key idea is a **cyclic population** (a `VecDeque`): the oldest individual is
//! evicted whenever the population exceeds `population_size`, even if it happens
//! to be the highest-scoring member. This aging pressure helps prevent premature
//! convergence and keeps the search exploratory.
//!
//! The **ask/tell** API lets the caller supply evaluation scores without storing
//! an objective closure, matching the hyperparameter-optimization convention
//! used elsewhere in this crate.
16
17use std::collections::VecDeque;
18
19use crate::error::{TrainError, TrainResult};
20
21use super::sampler::ArchSampler;
22use super::space::{ArchSearchSpace, Architecture};
23
24// ─── NasResult ──────────────────────────────────────────────────────────────
25
26/// Summary of a completed (or in-progress) NAS run.
27#[derive(Debug, Clone)]
28pub struct NasResult {
29 /// Best architecture found at the time of the call.
30 pub best: Architecture,
31 /// Score of the best architecture (higher = better).
32 pub best_score: f64,
33 /// All evaluated (architecture, score) pairs in evaluation order.
34 pub history: Vec<(Architecture, f64)>,
35}
36
37// ─── RegularizedEvolution ───────────────────────────────────────────────────
38
39/// Regularized (aging) evolution NAS searcher.
40///
41/// Uses a cyclic population with tournament selection and single-step mutation.
42pub struct RegularizedEvolution {
43 /// Cyclic population of (architecture, score) pairs ordered by age (oldest first).
44 population: VecDeque<(Architecture, f64)>,
45 /// Target population size.
46 pub population_size: usize,
47 /// Number of random members sampled per tournament.
48 pub tournament_size: usize,
49 /// Architecture sampler (owns the search space and RNG).
50 sampler: ArchSampler,
51 /// Full evaluation history in tell() order.
52 history: Vec<(Architecture, f64)>,
53 /// True once the population has been filled for the first time.
54 filled: bool,
55}
56
57impl RegularizedEvolution {
58 /// Create a new regularized evolution searcher.
59 ///
60 /// # Arguments
61 ///
62 /// * `space` - Architecture search space.
63 /// * `population_size` - Target size of the cyclic population (≥ 2).
64 /// * `tournament_size` - Number of randomly drawn population members per tournament (≤ population_size).
65 /// * `seed` - RNG seed for reproducibility.
66 ///
67 /// # Errors
68 ///
69 /// Returns [`TrainError::InvalidParameter`] when:
70 /// - `population_size` < 2
71 /// - `tournament_size` == 0 or `tournament_size` > `population_size`
72 pub fn new(
73 space: ArchSearchSpace,
74 population_size: usize,
75 tournament_size: usize,
76 seed: u64,
77 ) -> TrainResult<Self> {
78 if population_size < 2 {
79 return Err(TrainError::InvalidParameter(format!(
80 "population_size ({population_size}) must be ≥ 2"
81 )));
82 }
83 if tournament_size == 0 {
84 return Err(TrainError::InvalidParameter(
85 "tournament_size must be ≥ 1".to_string(),
86 ));
87 }
88 if tournament_size > population_size {
89 return Err(TrainError::InvalidParameter(format!(
90 "tournament_size ({tournament_size}) must be ≤ population_size ({population_size})"
91 )));
92 }
93
94 Ok(Self {
95 population: VecDeque::new(),
96 population_size,
97 tournament_size,
98 sampler: ArchSampler::new(space, seed),
99 history: Vec::new(),
100 filled: false,
101 })
102 }
103
104 /// Ask for the next architecture to evaluate.
105 ///
106 /// * While the population has fewer than `population_size` members (warm-up
107 /// phase), returns a freshly sampled random architecture.
108 /// * Once the population is full, performs tournament selection among
109 /// `tournament_size` randomly chosen members and returns the winner
110 /// mutated by one step.
111 pub fn ask(&mut self) -> TrainResult<Architecture> {
112 if !self.filled {
113 // Warm-up: fill the population with random architectures.
114 self.sampler.random_architecture()
115 } else {
116 // Tournament selection + mutation.
117 let winner = self.tournament_select()?;
118 self.sampler.mutate(&winner)
119 }
120 }
121
122 /// Tell the result of evaluating an architecture.
123 ///
124 /// Appends `(arch, score)` to the full history and pushes it into the
125 /// population (at the back / newest position). When the population exceeds
126 /// `population_size`, the oldest entry (front) is evicted.
127 pub fn tell(&mut self, arch: Architecture, score: f64) {
128 self.history.push((arch.clone(), score));
129 self.population.push_back((arch, score));
130 if self.population.len() >= self.population_size {
131 self.filled = true;
132 }
133 if self.population.len() > self.population_size {
134 self.population.pop_front();
135 }
136 }
137
138 /// Return a reference to the highest-scored (architecture, score) pair
139 /// currently in the population, or `None` if the population is empty.
140 pub fn best(&self) -> Option<&(Architecture, f64)> {
141 self.population
142 .iter()
143 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
144 }
145
146 /// Produce a [`NasResult`] summarising the current search state.
147 ///
148 /// Returns `None` if no evaluations have been recorded yet.
149 pub fn result(&self) -> Option<NasResult> {
150 // best over the *full* history (not just surviving population members)
151 let (best_arch, best_score) = self
152 .history
153 .iter()
154 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))?;
155
156 Some(NasResult {
157 best: best_arch.clone(),
158 best_score: *best_score,
159 history: self.history.clone(),
160 })
161 }
162
163 // ── private helpers ──────────────────────────────────────────────────
164
165 /// Draw `tournament_size` members uniformly at random (without replacement
166 /// when possible) from the population and return a clone of the one with
167 /// the highest score.
168 fn tournament_select(&mut self) -> TrainResult<Architecture> {
169 let pop_len = self.population.len();
170 if pop_len == 0 {
171 return Err(TrainError::InvalidParameter(
172 "tournament_select called on empty population".to_string(),
173 ));
174 }
175
176 // Collect `tournament_size` distinct indices (Fisher-Yates partial shuffle).
177 let sample_size = self.tournament_size.min(pop_len);
178 let mut indices: Vec<usize> = (0..pop_len).collect();
179
180 // Partial Fisher-Yates: bring `sample_size` elements to the front.
181 // For step i, pick j uniformly from [i, pop_len) and swap.
182 for i in 0..sample_size {
183 // gen_range_usize(i, pop_len) returns a value in [i, pop_len)
184 let j = self.sampler.gen_range_usize(i, pop_len);
185 indices.swap(i, j);
186 }
187
188 // Best among the sample.
189 let best_idx = indices[..sample_size]
190 .iter()
191 .max_by(|&&a, &&b| {
192 self.population[a]
193 .1
194 .partial_cmp(&self.population[b].1)
195 .unwrap_or(std::cmp::Ordering::Equal)
196 })
197 .copied()
198 .ok_or_else(|| {
199 TrainError::InvalidParameter("tournament sample was empty".to_string())
200 })?;
201
202 Ok(self.population[best_idx].0.clone())
203 }
204}