scirs2_optimize/nas/
evolutionary_nas.rs1use crate::error::OptimizeError;
10use crate::nas::random_nas::{ArchFitness, NASResult};
11use crate::nas::search_space::{Architecture, SearchSpace};
12use scirs2_core::random::{rngs::StdRng, Rng, RngExt, SeedableRng};
13
/// Hyper-parameters for [`EvolutionaryNAS`].
///
/// Start from `EvolutionaryNASConfig::default()` and override individual
/// fields with struct-update syntax.
#[derive(Debug, Clone, PartialEq)]
pub struct EvolutionaryNASConfig {
    /// Number of architectures kept in the population; the search rejects
    /// values below 2.
    pub population_size: usize,
    /// Number of generations; each generation evaluates exactly one child.
    pub n_generations: usize,
    /// Per-edge probability that mutation reassigns the edge's operation.
    pub mutation_rate: f64,
    /// Number of random draws per parent tournament (capped at the
    /// population size).
    pub tournament_size: usize,
    /// Presumably requests elitist replacement (keeping the best members).
    ///
    /// NOTE(review): `EvolutionaryNAS::search` never reads this flag; its
    /// "replace the worst member only with a strictly better child" rule is
    /// always elitist.
    pub elitism: bool,
}
28
29impl Default for EvolutionaryNASConfig {
30 fn default() -> Self {
31 Self {
32 population_size: 20,
33 n_generations: 50,
34 mutation_rate: 0.2,
35 tournament_size: 5,
36 elitism: true,
37 }
38 }
39}
40
41pub struct EvolutionaryNAS {
46 pub config: EvolutionaryNASConfig,
48}
49
50impl EvolutionaryNAS {
51 pub fn new(population_size: usize, n_generations: usize) -> Self {
53 Self {
54 config: EvolutionaryNASConfig {
55 population_size,
56 n_generations,
57 ..EvolutionaryNASConfig::default()
58 },
59 }
60 }
61
62 pub fn with_config(config: EvolutionaryNASConfig) -> Self {
64 Self { config }
65 }
66
67 pub fn search<F: ArchFitness>(
74 &self,
75 space: &SearchSpace,
76 fitness: &F,
77 seed: u64,
78 ) -> Result<NASResult, OptimizeError> {
79 if self.config.population_size < 2 {
80 return Err(OptimizeError::InvalidParameter(
81 "population_size must be at least 2".to_string(),
82 ));
83 }
84
85 let mut rng = StdRng::seed_from_u64(seed);
86
87 let mut population: Vec<(Architecture, f64)> =
89 Vec::with_capacity(self.config.population_size);
90 for _ in 0..self.config.population_size {
91 let arch = space.sample_random(&mut rng);
92 let score = fitness.evaluate(&arch).unwrap_or(f64::NEG_INFINITY);
93 population.push((arch, score));
94 }
95
96 let mut all_scores: Vec<f64> = population.iter().map(|(_, s)| *s).collect();
97
98 for _gen in 0..self.config.n_generations {
99 let parent_idx = self.tournament_select(&population, &mut rng);
101
102 let mut child = population[parent_idx].0.clone();
104 self.mutate(&mut child, space, &mut rng);
105 let child_score = fitness.evaluate(&child).unwrap_or(f64::NEG_INFINITY);
106 all_scores.push(child_score);
107
108 let worst_idx = self.find_worst(&population);
110 if child_score > population[worst_idx].1 {
111 population[worst_idx] = (child, child_score);
112 }
113 }
114
115 let (best_arch, best_score) = population
117 .into_iter()
118 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
119 .map(|(arch, score)| (arch, score))
120 .unwrap_or_else(|| {
121 let arch = space.sample_random(&mut rng);
122 (arch, f64::NEG_INFINITY)
123 });
124
125 let n_evaluated = self.config.population_size + self.config.n_generations;
126 Ok(NASResult {
127 best_arch,
128 best_score,
129 all_scores,
130 n_evaluated,
131 })
132 }
133
134 fn tournament_select(&self, population: &[(Architecture, f64)], rng: &mut StdRng) -> usize {
137 let n = population.len();
138 let k = self.config.tournament_size.min(n);
139 let mut best_idx = rng.random_range(0..n);
140 for _ in 1..k {
141 let idx = rng.random_range(0..n);
142 if population[idx].1 > population[best_idx].1 {
143 best_idx = idx;
144 }
145 }
146 best_idx
147 }
148
149 fn find_worst(&self, population: &[(Architecture, f64)]) -> usize {
151 population
152 .iter()
153 .enumerate()
154 .min_by(|a, b| {
155 a.1 .1
156 .partial_cmp(&b.1 .1)
157 .unwrap_or(std::cmp::Ordering::Equal)
158 })
159 .map(|(i, _)| i)
160 .unwrap_or(0)
161 }
162
163 fn mutate(&self, arch: &mut Architecture, space: &SearchSpace, rng: &mut StdRng) {
166 if space.operations.is_empty() || arch.edges.is_empty() {
167 return;
168 }
169 for edge in arch.edges.iter_mut() {
170 if rng.random::<f64>() < self.config.mutation_rate {
171 let op_idx = rng.random_range(0..space.operations.len());
172 edge.op = space.operations[op_idx].clone();
173 }
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nas::random_nas::ParamCountFitness;
    use crate::nas::search_space::SearchSpace;

    /// Largest value in `scores` (`NEG_INFINITY` for an empty slice).
    fn max_score(scores: &[f64]) -> f64 {
        scores.iter().copied().fold(f64::NEG_INFINITY, f64::max)
    }

    #[test]
    fn test_evolutionary_nas_runs() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(8_000);
        let nas = EvolutionaryNAS::new(10, 20);

        let result = nas.search(&space, &fitness, 99).expect("search failed");

        // One evaluation per initial member plus one child per generation.
        assert_eq!(result.n_evaluated, 10 + 20);
        assert_eq!(result.all_scores.len(), result.n_evaluated);
    }

    #[test]
    fn test_evolutionary_nas_small_population_error() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(8_000);
        let nas = EvolutionaryNAS::new(1, 5);

        assert!(nas.search(&space, &fitness, 0).is_err());
    }

    #[test]
    fn test_evolutionary_nas_monotone_best_score() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(5_000);
        let nas = EvolutionaryNAS::new(8, 30);

        let result = nas.search(&space, &fitness, 7).expect("search failed");

        // Only the worst member is ever evicted, and only for a strictly
        // better child, so the best score cannot decrease: it is at least the
        // best initial score and equals the best score ever evaluated.
        let initial_best = max_score(&result.all_scores[..8]);
        assert!(result.best_score >= initial_best);
        assert_eq!(result.best_score, max_score(&result.all_scores));
    }

    #[test]
    fn test_evolutionary_nas_deterministic_for_seed() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(6_000);
        let nas = EvolutionaryNAS::new(6, 15);

        let a = nas.search(&space, &fitness, 42).expect("search failed");
        let b = nas.search(&space, &fitness, 42).expect("search failed");

        assert_eq!(a.all_scores, b.all_scores);
        assert_eq!(a.best_score, b.best_score);
    }

    #[test]
    fn test_evolutionary_nas_with_config() {
        let config = EvolutionaryNASConfig {
            population_size: 6,
            n_generations: 10,
            mutation_rate: 0.5,
            tournament_size: 3,
            elitism: false,
        };
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(4_000);
        let nas = EvolutionaryNAS::with_config(config);

        let result = nas.search(&space, &fitness, 1).expect("search failed");
        assert_eq!(result.n_evaluated, 6 + 10);
        assert_eq!(result.all_scores.len(), result.n_evaluated);
    }

    #[test]
    fn test_evolutionary_nas_degenerate_tournament_and_mutation() {
        // A zero-size tournament still makes one draw; a mutation rate of 1
        // touches every edge. Neither may panic.
        let config = EvolutionaryNASConfig {
            population_size: 2,
            n_generations: 5,
            mutation_rate: 1.0,
            tournament_size: 0,
            elitism: true,
        };
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(4_000);
        let nas = EvolutionaryNAS::with_config(config);

        let result = nas.search(&space, &fitness, 3).expect("search failed");
        assert_eq!(result.n_evaluated, 2 + 5);
        assert_eq!(result.best_score, max_score(&result.all_scores));
    }
}