Skip to main content

tensorlogic_train/nas/
evolution.rs

1//! Regularized (aging) evolution for neural architecture search.
2//!
3//! Implements the algorithm from:
4//!
5//! > Real et al. (2019) "Regularized Evolution for Image Classifier Architecture Search"
6//! > <https://arxiv.org/abs/1802.01548>
7//!
8//! The key idea is a **cyclic population** (VecDeque): the oldest individual is
9//! evicted whenever the population exceeds `population_size`, even if it happens
10//! to be the highest-scoring member.  This aging pressure prevents premature
11//! convergence and keeps the search exploratory.
12//!
13//! The **ask/tell** API lets the caller supply evaluation scores without any
14//! stored objective closure, matching the hyperparameter-optimization convention
15//! used elsewhere in this crate.
16
17use std::collections::VecDeque;
18
19use crate::error::{TrainError, TrainResult};
20
21use super::sampler::ArchSampler;
22use super::space::{ArchSearchSpace, Architecture};
23
24// ─── NasResult ──────────────────────────────────────────────────────────────
25
26/// Summary of a completed (or in-progress) NAS run.
27#[derive(Debug, Clone)]
28pub struct NasResult {
29    /// Best architecture found at the time of the call.
30    pub best: Architecture,
31    /// Score of the best architecture (higher = better).
32    pub best_score: f64,
33    /// All evaluated (architecture, score) pairs in evaluation order.
34    pub history: Vec<(Architecture, f64)>,
35}
36
37// ─── RegularizedEvolution ───────────────────────────────────────────────────
38
39/// Regularized (aging) evolution NAS searcher.
40///
41/// Uses a cyclic population with tournament selection and single-step mutation.
42pub struct RegularizedEvolution {
43    /// Cyclic population of (architecture, score) pairs ordered by age (oldest first).
44    population: VecDeque<(Architecture, f64)>,
45    /// Target population size.
46    pub population_size: usize,
47    /// Number of random members sampled per tournament.
48    pub tournament_size: usize,
49    /// Architecture sampler (owns the search space and RNG).
50    sampler: ArchSampler,
51    /// Full evaluation history in tell() order.
52    history: Vec<(Architecture, f64)>,
53    /// True once the population has been filled for the first time.
54    filled: bool,
55}
56
57impl RegularizedEvolution {
58    /// Create a new regularized evolution searcher.
59    ///
60    /// # Arguments
61    ///
62    /// * `space` - Architecture search space.
63    /// * `population_size` - Target size of the cyclic population (≥ 2).
64    /// * `tournament_size` - Number of randomly drawn population members per tournament (≤ population_size).
65    /// * `seed` - RNG seed for reproducibility.
66    ///
67    /// # Errors
68    ///
69    /// Returns [`TrainError::InvalidParameter`] when:
70    /// - `population_size` < 2
71    /// - `tournament_size` == 0 or `tournament_size` > `population_size`
72    pub fn new(
73        space: ArchSearchSpace,
74        population_size: usize,
75        tournament_size: usize,
76        seed: u64,
77    ) -> TrainResult<Self> {
78        if population_size < 2 {
79            return Err(TrainError::InvalidParameter(format!(
80                "population_size ({population_size}) must be ≥ 2"
81            )));
82        }
83        if tournament_size == 0 {
84            return Err(TrainError::InvalidParameter(
85                "tournament_size must be ≥ 1".to_string(),
86            ));
87        }
88        if tournament_size > population_size {
89            return Err(TrainError::InvalidParameter(format!(
90                "tournament_size ({tournament_size}) must be ≤ population_size ({population_size})"
91            )));
92        }
93
94        Ok(Self {
95            population: VecDeque::new(),
96            population_size,
97            tournament_size,
98            sampler: ArchSampler::new(space, seed),
99            history: Vec::new(),
100            filled: false,
101        })
102    }
103
104    /// Ask for the next architecture to evaluate.
105    ///
106    /// * While the population has fewer than `population_size` members (warm-up
107    ///   phase), returns a freshly sampled random architecture.
108    /// * Once the population is full, performs tournament selection among
109    ///   `tournament_size` randomly chosen members and returns the winner
110    ///   mutated by one step.
111    pub fn ask(&mut self) -> TrainResult<Architecture> {
112        if !self.filled {
113            // Warm-up: fill the population with random architectures.
114            self.sampler.random_architecture()
115        } else {
116            // Tournament selection + mutation.
117            let winner = self.tournament_select()?;
118            self.sampler.mutate(&winner)
119        }
120    }
121
122    /// Tell the result of evaluating an architecture.
123    ///
124    /// Appends `(arch, score)` to the full history and pushes it into the
125    /// population (at the back / newest position).  When the population exceeds
126    /// `population_size`, the oldest entry (front) is evicted.
127    pub fn tell(&mut self, arch: Architecture, score: f64) {
128        self.history.push((arch.clone(), score));
129        self.population.push_back((arch, score));
130        if self.population.len() >= self.population_size {
131            self.filled = true;
132        }
133        if self.population.len() > self.population_size {
134            self.population.pop_front();
135        }
136    }
137
138    /// Return a reference to the highest-scored (architecture, score) pair
139    /// currently in the population, or `None` if the population is empty.
140    pub fn best(&self) -> Option<&(Architecture, f64)> {
141        self.population
142            .iter()
143            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
144    }
145
146    /// Produce a [`NasResult`] summarising the current search state.
147    ///
148    /// Returns `None` if no evaluations have been recorded yet.
149    pub fn result(&self) -> Option<NasResult> {
150        // best over the *full* history (not just surviving population members)
151        let (best_arch, best_score) = self
152            .history
153            .iter()
154            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))?;
155
156        Some(NasResult {
157            best: best_arch.clone(),
158            best_score: *best_score,
159            history: self.history.clone(),
160        })
161    }
162
163    // ── private helpers ──────────────────────────────────────────────────
164
165    /// Draw `tournament_size` members uniformly at random (without replacement
166    /// when possible) from the population and return a clone of the one with
167    /// the highest score.
168    fn tournament_select(&mut self) -> TrainResult<Architecture> {
169        let pop_len = self.population.len();
170        if pop_len == 0 {
171            return Err(TrainError::InvalidParameter(
172                "tournament_select called on empty population".to_string(),
173            ));
174        }
175
176        // Collect `tournament_size` distinct indices (Fisher-Yates partial shuffle).
177        let sample_size = self.tournament_size.min(pop_len);
178        let mut indices: Vec<usize> = (0..pop_len).collect();
179
180        // Partial Fisher-Yates: bring `sample_size` elements to the front.
181        // For step i, pick j uniformly from [i, pop_len) and swap.
182        for i in 0..sample_size {
183            // gen_range_usize(i, pop_len) returns a value in [i, pop_len)
184            let j = self.sampler.gen_range_usize(i, pop_len);
185            indices.swap(i, j);
186        }
187
188        // Best among the sample.
189        let best_idx = indices[..sample_size]
190            .iter()
191            .max_by(|&&a, &&b| {
192                self.population[a]
193                    .1
194                    .partial_cmp(&self.population[b].1)
195                    .unwrap_or(std::cmp::Ordering::Equal)
196            })
197            .copied()
198            .ok_or_else(|| {
199                TrainError::InvalidParameter("tournament sample was empty".to_string())
200            })?;
201
202        Ok(self.population[best_idx].0.clone())
203    }
204}