Skip to main content

scirs2_optimize/nas/
evolutionary_nas.rs

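//! Evolutionary Neural Architecture Search (inspired by AmoebaNet).
//!
//! Implements a steady-state, tournament-selection evolutionary algorithm for NAS:
//! - Initialise a random population
//! - Each generation: tournament-select a parent, mutate a copy of it, and
//!   replace the weakest member of the population with the child if the
//!   child scores strictly higher
//! - Return the best architecture in the final population
//!
//! Unlike AmoebaNet's regularized ("aging") evolution (Real et al. 2019),
//! which discards the oldest individual every generation, this variant
//! discards the weakest one, so the best score in the population never
//! decreases.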
8
9use crate::error::OptimizeError;
10use crate::nas::random_nas::{ArchFitness, NASResult};
11use crate::nas::search_space::{Architecture, SearchSpace};
12use scirs2_core::random::{rngs::StdRng, Rng, RngExt, SeedableRng};
13
/// Configuration for the evolutionary NAS algorithm.
#[derive(Debug, Clone)]
pub struct EvolutionaryNASConfig {
    /// Number of individuals in the population. The search returns an
    /// error if this is less than 2.
    pub population_size: usize,
    /// Number of evolutionary generations. Each generation evaluates exactly
    /// one mutated child.
    pub n_generations: usize,
    /// Independent per-edge probability that a mutation resamples the edge's
    /// operation. The resampled op may equal the previous one, so a child can
    /// end up identical to its parent.
    pub mutation_rate: f64,
    /// Number of individuals drawn (with replacement) in each tournament.
    /// Values larger than the population are capped at the population size,
    /// and 0 behaves like 1.
    pub tournament_size: usize,
    /// Whether to use elitism (always keep the best individual).
    ///
    /// Note: the search does not currently read this flag. Its replacement
    /// rule (evict the weakest member only when the child is strictly better)
    /// already never lowers the best score in the population, so the search
    /// behaves elitistically regardless of this setting.
    pub elitism: bool,
}
28
29impl Default for EvolutionaryNASConfig {
30    fn default() -> Self {
31        Self {
32            population_size: 20,
33            n_generations: 50,
34            mutation_rate: 0.2,
35            tournament_size: 5,
36            elitism: true,
37        }
38    }
39}
40
41/// Evolutionary Neural Architecture Search.
42///
43/// Uses tournament selection and mutation to search the architecture space.
44/// Inspired by the AmoebaNet regularized evolutionary search (Real et al. 2019).
45pub struct EvolutionaryNAS {
46    /// Algorithm configuration
47    pub config: EvolutionaryNASConfig,
48}
49
50impl EvolutionaryNAS {
51    /// Create with default configuration.
52    pub fn new(population_size: usize, n_generations: usize) -> Self {
53        Self {
54            config: EvolutionaryNASConfig {
55                population_size,
56                n_generations,
57                ..EvolutionaryNASConfig::default()
58            },
59        }
60    }
61
62    /// Create from an explicit config.
63    pub fn with_config(config: EvolutionaryNASConfig) -> Self {
64        Self { config }
65    }
66
67    /// Run evolutionary search.
68    ///
69    /// # Arguments
70    /// - `space`: Architecture search space.
71    /// - `fitness`: Fitness evaluator (higher = better).
72    /// - `seed`: Random seed.
73    pub fn search<F: ArchFitness>(
74        &self,
75        space: &SearchSpace,
76        fitness: &F,
77        seed: u64,
78    ) -> Result<NASResult, OptimizeError> {
79        if self.config.population_size < 2 {
80            return Err(OptimizeError::InvalidParameter(
81                "population_size must be at least 2".to_string(),
82            ));
83        }
84
85        let mut rng = StdRng::seed_from_u64(seed);
86
87        // Initialise population
88        let mut population: Vec<(Architecture, f64)> =
89            Vec::with_capacity(self.config.population_size);
90        for _ in 0..self.config.population_size {
91            let arch = space.sample_random(&mut rng);
92            let score = fitness.evaluate(&arch).unwrap_or(f64::NEG_INFINITY);
93            population.push((arch, score));
94        }
95
96        let mut all_scores: Vec<f64> = population.iter().map(|(_, s)| *s).collect();
97
98        for _gen in 0..self.config.n_generations {
99            // Tournament selection: pick the best from a random subset
100            let parent_idx = self.tournament_select(&population, &mut rng);
101
102            // Mutate the selected parent
103            let mut child = population[parent_idx].0.clone();
104            self.mutate(&mut child, space, &mut rng);
105            let child_score = fitness.evaluate(&child).unwrap_or(f64::NEG_INFINITY);
106            all_scores.push(child_score);
107
108            // Find the weakest individual and replace if child is better
109            let worst_idx = self.find_worst(&population);
110            if child_score > population[worst_idx].1 {
111                population[worst_idx] = (child, child_score);
112            }
113        }
114
115        // Return the best individual
116        let (best_arch, best_score) = population
117            .into_iter()
118            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
119            .map(|(arch, score)| (arch, score))
120            .unwrap_or_else(|| {
121                let arch = space.sample_random(&mut rng);
122                (arch, f64::NEG_INFINITY)
123            });
124
125        let n_evaluated = self.config.population_size + self.config.n_generations;
126        Ok(NASResult {
127            best_arch,
128            best_score,
129            all_scores,
130            n_evaluated,
131        })
132    }
133
134    /// Tournament selection: sample `tournament_size` indices and return
135    /// the index of the individual with the highest score.
136    fn tournament_select(&self, population: &[(Architecture, f64)], rng: &mut StdRng) -> usize {
137        let n = population.len();
138        let k = self.config.tournament_size.min(n);
139        let mut best_idx = rng.random_range(0..n);
140        for _ in 1..k {
141            let idx = rng.random_range(0..n);
142            if population[idx].1 > population[best_idx].1 {
143                best_idx = idx;
144            }
145        }
146        best_idx
147    }
148
149    /// Return the index of the individual with the lowest score.
150    fn find_worst(&self, population: &[(Architecture, f64)]) -> usize {
151        population
152            .iter()
153            .enumerate()
154            .min_by(|a, b| {
155                a.1 .1
156                    .partial_cmp(&b.1 .1)
157                    .unwrap_or(std::cmp::Ordering::Equal)
158            })
159            .map(|(i, _)| i)
160            .unwrap_or(0)
161    }
162
163    /// Mutate an architecture: for each edge, replace its op with a random
164    /// one from the space with probability `mutation_rate`.
165    fn mutate(&self, arch: &mut Architecture, space: &SearchSpace, rng: &mut StdRng) {
166        if space.operations.is_empty() || arch.edges.is_empty() {
167            return;
168        }
169        for edge in arch.edges.iter_mut() {
170            if rng.random::<f64>() < self.config.mutation_rate {
171                let op_idx = rng.random_range(0..space.operations.len());
172                edge.op = space.operations[op_idx].clone();
173            }
174        }
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nas::random_nas::ParamCountFitness;
    use crate::nas::search_space::SearchSpace;

    /// Largest score in `scores`, or `NEG_INFINITY` when empty.
    fn max_score(scores: &[f64]) -> f64 {
        scores.iter().copied().fold(f64::NEG_INFINITY, f64::max)
    }

    #[test]
    fn test_evolutionary_nas_runs() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(8_000);
        let nas = EvolutionaryNAS::new(10, 20);

        let result = nas.search(&space, &fitness, 99).expect("search failed");

        // Every generation evaluates exactly one child, whether or not it
        // survives, so the total is always population + generations.
        assert_eq!(result.n_evaluated, 10 + 20);
        assert_eq!(result.all_scores.len(), result.n_evaluated);
    }

    #[test]
    fn test_evolutionary_nas_small_population_error() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(8_000);
        let nas = EvolutionaryNAS::new(1, 5);

        assert!(nas.search(&space, &fitness, 0).is_err());
    }

    #[test]
    fn test_evolutionary_nas_monotone_best_score() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(5_000);
        let nas = EvolutionaryNAS::new(8, 30);

        let result = nas.search(&space, &fitness, 7).expect("search failed");

        // The first `population_size` scores belong to the initial population;
        // evolution must never end below their best.
        let initial_best = max_score(&result.all_scores[..8]);
        assert!(result.best_score >= initial_best);

        // Replace-worst-if-strictly-better never evicts the best individual
        // seen so far, so the reported best is the best score ever evaluated.
        assert_eq!(result.best_score, max_score(&result.all_scores));
    }

    #[test]
    fn test_evolutionary_nas_deterministic_for_seed() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(6_000);
        let nas = EvolutionaryNAS::new(6, 12);

        let a = nas.search(&space, &fitness, 42).expect("search failed");
        let b = nas.search(&space, &fitness, 42).expect("search failed");

        assert_eq!(a.all_scores, b.all_scores);
        assert_eq!(a.best_score, b.best_score);
    }

    #[test]
    fn test_evolutionary_nas_zero_tournament_size_runs() {
        // A tournament size of 0 behaves like a single uniform draw.
        let config = EvolutionaryNASConfig {
            population_size: 4,
            n_generations: 5,
            tournament_size: 0,
            ..EvolutionaryNASConfig::default()
        };
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(4_000);
        let nas = EvolutionaryNAS::with_config(config);

        let result = nas.search(&space, &fitness, 3).expect("search failed");
        assert_eq!(result.n_evaluated, 9);
    }

    #[test]
    fn test_evolutionary_nas_with_config() {
        let config = EvolutionaryNASConfig {
            population_size: 6,
            n_generations: 10,
            mutation_rate: 0.5,
            tournament_size: 3,
            elitism: false,
        };
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(4_000);
        let nas = EvolutionaryNAS::with_config(config);

        let result = nas.search(&space, &fitness, 1).expect("search failed");
        assert_eq!(result.n_evaluated, 6 + 10);
        assert_eq!(result.all_scores.len(), 16);
    }
}