scirs2_optimize/nas/
random_nas.rs1use crate::error::OptimizeError;
8use crate::nas::search_space::{Architecture, SearchSpace};
9use scirs2_core::random::{rngs::StdRng, SeedableRng};
10
11#[derive(Debug, Clone)]
13pub struct NASResult {
14 pub best_arch: Architecture,
16 pub best_score: f64,
18 pub all_scores: Vec<f64>,
20 pub n_evaluated: usize,
22}
23
24pub trait ArchFitness: Send + Sync {
28 fn evaluate(&self, arch: &Architecture) -> Result<f64, OptimizeError>;
32}
33
/// Fitness that rewards architectures whose parameter count is close to a target.
///
/// See the [`ArchFitness`] implementation for the exact scoring formula.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParamCountFitness {
    /// Desired total number of parameters.
    pub target_params: usize,
}

impl ParamCountFitness {
    /// Creates a fitness that targets `target_params` parameters.
    ///
    /// A target of `0` is accepted and handled specially by `evaluate`.
    pub fn new(target_params: usize) -> Self {
        Self { target_params }
    }
}
49
50impl ArchFitness for ParamCountFitness {
51 fn evaluate(&self, arch: &Architecture) -> Result<f64, OptimizeError> {
52 let params = arch.total_params() as f64;
53 let target = self.target_params as f64;
54 if target == 0.0 {
55 return Ok(if params == 0.0 { 0.0 } else { -1.0 });
56 }
57 Ok(-(params - target).abs() / target)
58 }
59}
60
/// Fitness that rewards architectures using as much of a FLOPs budget as
/// possible without exceeding it.
///
/// See the [`ArchFitness`] implementation for the exact scoring formula.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FlopsFitness {
    /// Maximum number of FLOPs an architecture may use.
    pub flops_budget: usize,
    /// Value forwarded to `Architecture::total_flops`.
    // NOTE(review): presumably the input feature-map side length — confirm
    // against `Architecture::total_flops`.
    pub spatial: usize,
}

impl FlopsFitness {
    /// Creates a fitness with the given FLOPs budget and `spatial` argument
    /// for `Architecture::total_flops`.
    pub fn new(flops_budget: usize, spatial: usize) -> Self {
        Self {
            flops_budget,
            spatial,
        }
    }
}
78
79impl ArchFitness for FlopsFitness {
80 fn evaluate(&self, arch: &Architecture) -> Result<f64, OptimizeError> {
81 let flops = arch.total_flops(self.spatial) as f64;
82 let budget = self.flops_budget as f64;
83 if budget == 0.0 {
84 return Ok(0.0);
85 }
86 if flops <= budget {
88 Ok(flops / budget)
89 } else {
90 Ok(-(flops - budget) / budget)
91 }
92 }
93}
94
95pub struct RandomNAS {
100 pub n_trials: usize,
102}
103
104impl RandomNAS {
105 pub fn new(n_trials: usize) -> Self {
107 Self { n_trials }
108 }
109
110 pub fn search<F: ArchFitness>(
117 &self,
118 space: &SearchSpace,
119 fitness: &F,
120 seed: u64,
121 ) -> Result<NASResult, OptimizeError> {
122 use scirs2_core::random::{Rng, RngExt};
123
124 if self.n_trials == 0 {
125 return Err(OptimizeError::InvalidParameter(
126 "n_trials must be at least 1".to_string(),
127 ));
128 }
129
130 let mut rng = StdRng::seed_from_u64(seed);
131
132 let mut best_score = f64::NEG_INFINITY;
133 let mut best_arch = space.sample_random(&mut rng);
135 let mut all_scores = Vec::with_capacity(self.n_trials);
136
137 for _ in 0..self.n_trials {
138 let arch = space.sample_random(&mut rng);
139 let score = fitness.evaluate(&arch)?;
140 all_scores.push(score);
141 if score > best_score {
142 best_score = score;
143 best_arch = arch;
144 }
145 }
146
147 Ok(NASResult {
148 best_arch,
149 best_score,
150 all_scores,
151 n_evaluated: self.n_trials,
152 })
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nas::search_space::{ArchEdge, ArchNode, OpType, SearchSpace};

    #[test]
    fn test_random_nas_returns_result() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(10_000);
        let nas = RandomNAS::new(20);

        let result = nas.search(&space, &fitness, 0).expect("search failed");

        assert_eq!(result.n_evaluated, 20);
        assert_eq!(result.all_scores.len(), 20);
        assert!(result.best_score.is_finite());

        // The reported best score must be the maximum of the recorded scores.
        let max_score = result
            .all_scores
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        assert_eq!(result.best_score, max_score);
    }

    #[test]
    fn test_random_nas_is_deterministic_for_seed() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(10_000);
        let nas = RandomNAS::new(10);

        let first = nas.search(&space, &fitness, 42).expect("search failed");
        let second = nas.search(&space, &fitness, 42).expect("search failed");

        assert_eq!(first.all_scores, second.all_scores);
        assert_eq!(first.best_score, second.best_score);
    }

    #[test]
    fn test_random_nas_zero_trials_errors() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(10_000);
        let nas = RandomNAS::new(0);

        assert!(nas.search(&space, &fitness, 0).is_err());
    }

    #[test]
    fn test_param_count_fitness_zero_target() {
        let mut arch = Architecture::new(1, 32, 10);
        let fitness = ParamCountFitness::new(0);

        // A freshly constructed architecture has zero parameters, which
        // exactly matches a zero target.
        let score = fitness.evaluate(&arch).expect("eval failed");
        assert_eq!(score, 0.0);

        arch.nodes.push(ArchNode {
            id: 0,
            name: "n0".into(),
            output_channels: 32,
        });
        arch.nodes.push(ArchNode {
            id: 1,
            name: "n1".into(),
            output_channels: 32,
        });
        arch.edges.push(ArchEdge {
            from: 0,
            to: 1,
            op: OpType::Conv3x3,
        });

        // Any parameters at all against a zero target score -1.0.
        let score = fitness.evaluate(&arch).expect("eval failed");
        assert_eq!(score, -1.0);
    }

    #[test]
    fn test_param_count_fitness_relative_error() {
        // Zero parameters against a 10k target is a 100% relative error.
        let arch = Architecture::new(1, 32, 10);
        let fitness = ParamCountFitness::new(10_000);

        let score = fitness.evaluate(&arch).expect("eval failed");
        assert_eq!(score, -1.0);
    }

    #[test]
    fn test_flops_fitness_under_budget() {
        let mut arch = Architecture::new(1, 8, 10);
        arch.nodes.push(ArchNode {
            id: 0,
            name: "n0".into(),
            output_channels: 8,
        });
        arch.nodes.push(ArchNode {
            id: 1,
            name: "n1".into(),
            output_channels: 8,
        });
        arch.edges.push(ArchEdge {
            from: 0,
            to: 1,
            op: OpType::Skip,
        });

        let fitness = FlopsFitness::new(1_000_000, 8);
        let score = fitness.evaluate(&arch).expect("eval failed");
        assert!((0.0..=1.0).contains(&score));
    }
}