// scirs2_optimize/hardware_nas/mod.rs
use std::collections::HashMap;

use crate::darts::Operation;
use crate::error::OptimizeError;
/// Lookup table of per-operation latency estimates for hardware-aware NAS.
///
/// The latency of a single operation is modelled linearly as
/// `base + size_scale * input_size`, where `base` is looked up by operation
/// name in [`op_latencies`](Self::op_latencies) (see [`LatencyTable::latency_of`]).
/// Units are whatever the search budget uses — presumably milliseconds, to match
/// `HardwareNasConfig::max_latency_ms` (TODO confirm).
#[derive(Debug, Clone)]
pub struct LatencyTable {
    /// Base (size-independent) latency of each operation, keyed by operation name.
    pub op_latencies: HashMap<String, f64>,
    /// Extra latency charged per unit of input size.
    pub size_scale: f64,
}
37
38impl LatencyTable {
39 pub fn new() -> Self {
43 let mut op_latencies = HashMap::new();
44 op_latencies.insert("conv3x3".to_string(), 1.5);
45 op_latencies.insert("conv5x5".to_string(), 3.0);
46 op_latencies.insert("max_pool".to_string(), 0.2);
47 op_latencies.insert("avg_pool".to_string(), 0.2);
48 op_latencies.insert("identity".to_string(), 0.05);
49 op_latencies.insert("skip_connect".to_string(), 0.05);
50 op_latencies.insert("zero".to_string(), 0.0);
51 Self {
52 op_latencies,
53 size_scale: 1e-4, }
55 }
56
57 pub fn latency_of(&self, op: &str, input_size: usize) -> f64 {
61 let base = self.op_latencies.get(op).cloned().unwrap_or(1.0);
62 base + self.size_scale * input_size as f64
63 }
64
65 pub fn total_latency(&self, arch: &[(String, usize)]) -> f64 {
67 arch.iter().map(|(op, sz)| self.latency_of(op, *sz)).sum()
68 }
69}
70
71impl Default for LatencyTable {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
/// Quantity a hardware-aware NAS search tries to optimise.
///
/// Each variant is reduced to a scalar score by `ArchCandidate::objective_value`;
/// the searchers always keep the candidate with the *larger* score, so cost-type
/// objectives are scored as negated costs.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum NasObjective {
    /// Maximise the estimated accuracy.
    Accuracy,
    /// Minimise the estimated latency.
    Latency,
    /// Minimise compute cost. No FLOP estimate is tracked per candidate, so this
    /// is currently scored with the same parameter-count proxy as `ParamCount`.
    FlopsCount,
    /// Minimise the estimated parameter count.
    ParamCount,
    /// Weighted trade-off scored as
    /// `accuracy_weight * accuracy - latency_weight * latency`.
    MultiObjective {
        /// Weight on the estimated accuracy (rewarded).
        accuracy_weight: f64,
        /// Weight on the estimated latency (penalised).
        latency_weight: f64,
    },
}
99
100impl Default for NasObjective {
101 fn default() -> Self {
102 NasObjective::MultiObjective {
103 accuracy_weight: 1.0,
104 latency_weight: 0.01,
105 }
106 }
107}
108
109#[derive(Debug, Clone)]
113pub struct HardwareNasConfig {
114 pub max_latency_ms: f64,
116 pub max_params: usize,
118 pub n_search_iter: usize,
120 pub objective: NasObjective,
122 pub seed: u64,
124 pub n_ops_per_arch: usize,
126 pub input_size: usize,
128 pub params_per_op: usize,
130 pub population_size: usize,
132 pub tournament_size: usize,
134 pub n_generations: usize,
136}
137
138impl Default for HardwareNasConfig {
139 fn default() -> Self {
140 Self {
141 max_latency_ms: 10.0,
142 max_params: 1_000_000,
143 n_search_iter: 100,
144 objective: NasObjective::default(),
145 seed: 42,
146 n_ops_per_arch: 8,
147 input_size: 224 * 224 * 3,
148 params_per_op: 9 * 16 * 16, population_size: 20,
150 tournament_size: 3,
151 n_generations: 10,
152 }
153 }
154}
155
156#[derive(Debug, Clone)]
160pub struct ArchCandidate {
161 pub operations: Vec<Operation>,
163 pub estimated_accuracy: f64,
165 pub estimated_latency: f64,
167 pub n_params: usize,
169}
170
171impl ArchCandidate {
172 pub fn objective_value(&self, obj: &NasObjective) -> f64 {
174 match obj {
175 NasObjective::Accuracy => self.estimated_accuracy,
176 NasObjective::Latency => -self.estimated_latency,
177 NasObjective::FlopsCount => -(self.n_params as f64), NasObjective::ParamCount => -(self.n_params as f64),
179 NasObjective::MultiObjective {
180 accuracy_weight,
181 latency_weight,
182 } => {
183 accuracy_weight * self.estimated_accuracy - latency_weight * self.estimated_latency
184 }
185 }
186 }
187}
188
189#[derive(Debug)]
193pub struct HardwareNasSearcher {
194 config: HardwareNasConfig,
195 latency_table: LatencyTable,
196 rng_state: u64,
198}
199
200impl HardwareNasSearcher {
201 pub fn new(config: HardwareNasConfig, latency_table: LatencyTable) -> Self {
203 let rng_state = config.seed;
204 Self {
205 config,
206 latency_table,
207 rng_state,
208 }
209 }
210
211 fn lcg_next(&mut self) -> u64 {
215 self.rng_state = self
217 .rng_state
218 .wrapping_mul(6_364_136_223_846_793_005)
219 .wrapping_add(1_442_695_040_888_963_407);
220 self.rng_state
221 }
222
223 fn rand_usize(&mut self, n: usize) -> usize {
225 if n == 0 {
226 return 0;
227 }
228 (self.lcg_next() as usize) % n
229 }
230
231 fn rand_f64(&mut self) -> f64 {
233 (self.lcg_next() >> 11) as f64 / (1u64 << 53) as f64
234 }
235
236 fn sample_random_arch(&mut self) -> Vec<Operation> {
240 let ops = Operation::all();
241 let n = self.config.n_ops_per_arch;
242 (0..n).map(|_| ops[self.rand_usize(ops.len())]).collect()
243 }
244
245 fn estimate_latency(&self, ops: &[Operation]) -> f64 {
247 let pairs: Vec<(String, usize)> = ops
248 .iter()
249 .map(|o| (o.name().to_string(), self.config.input_size))
250 .collect();
251 self.latency_table.total_latency(&pairs)
252 }
253
254 fn estimate_params(&self, ops: &[Operation]) -> usize {
256 ops.iter()
257 .map(|o| match o {
258 Operation::Zero | Operation::Identity | Operation::SkipConnect => 0,
259 Operation::MaxPool | Operation::AvgPool => 0,
260 Operation::Conv3x3 => self.config.params_per_op,
261 Operation::Conv5x5 => self.config.params_per_op * 2,
262 })
263 .sum()
264 }
265
266 fn satisfies_constraints(&self, candidate: &ArchCandidate) -> bool {
268 candidate.estimated_latency <= self.config.max_latency_ms
269 && candidate.n_params <= self.config.max_params
270 }
271
272 fn build_candidate(&mut self, ops: Vec<Operation>, accuracy: f64) -> ArchCandidate {
274 let latency = self.estimate_latency(&ops);
275 let n_params = self.estimate_params(&ops);
276 ArchCandidate {
277 operations: ops,
278 estimated_accuracy: accuracy,
279 estimated_latency: latency,
280 n_params,
281 }
282 }
283
284 pub fn random_search(
294 &mut self,
295 eval_fn: impl Fn(&[Operation]) -> f64,
296 ) -> Result<ArchCandidate, OptimizeError> {
297 let mut best: Option<ArchCandidate> = None;
298 let obj = self.config.objective.clone();
299
300 for _ in 0..self.config.n_search_iter {
301 let ops = self.sample_random_arch();
302 let acc = eval_fn(&ops);
303 let candidate = self.build_candidate(ops, acc);
304 if !self.satisfies_constraints(&candidate) {
305 continue;
306 }
307 match &best {
308 None => best = Some(candidate),
309 Some(b) => {
310 if candidate.objective_value(&obj) > b.objective_value(&obj) {
311 best = Some(candidate);
312 }
313 }
314 }
315 }
316
317 best.ok_or_else(|| {
318 OptimizeError::ConvergenceError(
319 "No architecture found satisfying hardware constraints".to_string(),
320 )
321 })
322 }
323
324 pub fn evolutionary_search(
328 &mut self,
329 eval_fn: impl Fn(&[Operation]) -> f64,
330 ) -> Result<ArchCandidate, OptimizeError> {
331 let pop_size = self.config.population_size;
332 let obj = self.config.objective.clone();
333
334 let mut population: Vec<ArchCandidate> = (0..pop_size)
336 .map(|_| {
337 let ops = self.sample_random_arch();
338 let acc = eval_fn(&ops);
339 self.build_candidate(ops, acc)
340 })
341 .collect();
342
343 let mut best: Option<ArchCandidate> = population
344 .iter()
345 .filter(|c| self.satisfies_constraints(c))
346 .max_by(|a, b| {
347 a.objective_value(&obj)
348 .partial_cmp(&b.objective_value(&obj))
349 .unwrap_or(std::cmp::Ordering::Equal)
350 })
351 .cloned();
352
353 for _gen in 0..self.config.n_generations {
354 let mut next_pop: Vec<ArchCandidate> = Vec::with_capacity(pop_size);
355
356 for _ in 0..pop_size {
357 let parent = self.tournament_select(&population, &obj);
359 let child_ops = self.mutate(&parent.operations);
361 let acc = eval_fn(&child_ops);
362 let child = self.build_candidate(child_ops, acc);
363
364 if self.satisfies_constraints(&child) {
365 match &best {
366 None => best = Some(child.clone()),
367 Some(b) => {
368 if child.objective_value(&obj) > b.objective_value(&obj) {
369 best = Some(child.clone());
370 }
371 }
372 }
373 }
374 next_pop.push(child);
375 }
376 population = next_pop;
377 }
378
379 best.ok_or_else(|| {
380 OptimizeError::ConvergenceError(
381 "Evolutionary search: no constraint-satisfying architecture found".to_string(),
382 )
383 })
384 }
385
386 fn tournament_select(
389 &mut self,
390 population: &[ArchCandidate],
391 obj: &NasObjective,
392 ) -> ArchCandidate {
393 let t = self.config.tournament_size.min(population.len()).max(1);
394 let mut best_idx = self.rand_usize(population.len());
395 for _ in 1..t {
396 let idx = self.rand_usize(population.len());
397 if population[idx].objective_value(obj) > population[best_idx].objective_value(obj) {
398 best_idx = idx;
399 }
400 }
401 population[best_idx].clone()
402 }
403
404 fn mutate(&mut self, ops: &[Operation]) -> Vec<Operation> {
406 if ops.is_empty() {
407 return Vec::new();
408 }
409 let mut new_ops = ops.to_vec();
410 let pos = self.rand_usize(new_ops.len());
411 let all_ops = Operation::all();
412 new_ops[pos] = all_ops[self.rand_usize(all_ops.len())];
413 new_ops
414 }
415
416 pub fn pareto_front(candidates: &[ArchCandidate]) -> Vec<usize> {
421 let n = candidates.len();
422 let mut dominated = vec![false; n];
423
424 for i in 0..n {
425 if dominated[i] {
426 continue;
427 }
428 for j in 0..n {
429 if i == j || dominated[j] {
430 continue;
431 }
432 let j_dom_i = candidates[j].estimated_accuracy >= candidates[i].estimated_accuracy
434 && candidates[j].estimated_latency <= candidates[i].estimated_latency
435 && (candidates[j].estimated_accuracy > candidates[i].estimated_accuracy
436 || candidates[j].estimated_latency < candidates[i].estimated_latency);
437 if j_dom_i {
438 dominated[i] = true;
439 break;
440 }
441 }
442 }
443
444 (0..n).filter(|&i| !dominated[i]).collect()
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// Searcher with the stock configuration and latency table.
    fn make_searcher() -> HardwareNasSearcher {
        HardwareNasSearcher::new(HardwareNasConfig::default(), LatencyTable::new())
    }

    /// Config whose latency budget admits every 4-op architecture: at most
    /// 4 × (3.0 + 1e-4 × 150_528) ≈ 72 latency units and 4 × 4608 parameters.
    fn relaxed_config() -> HardwareNasConfig {
        HardwareNasConfig {
            max_latency_ms: 10_000.0,
            n_ops_per_arch: 4,
            ..HardwareNasConfig::default()
        }
    }

    /// Toy accuracy oracle: each identity / skip op adds 0.05 to a 0.5 baseline.
    fn acc_oracle(ops: &[Operation]) -> f64 {
        let light_count = ops
            .iter()
            .filter(|o| matches!(o, Operation::Identity | Operation::SkipConnect))
            .count();
        0.5 + 0.05 * light_count as f64
    }

    /// Candidate with no operations and the given metrics.
    fn cand(accuracy: f64, latency: f64, n_params: usize) -> ArchCandidate {
        ArchCandidate {
            operations: vec![],
            estimated_accuracy: accuracy,
            estimated_latency: latency,
            n_params,
        }
    }

    #[test]
    fn latency_table_default_contains_ops() {
        let lt = LatencyTable::new();
        assert!(lt.latency_of("conv3x3", 1000) > 0.0);
        assert_eq!(lt.latency_of("zero", 0), 0.0);
    }

    #[test]
    fn total_latency_sums_correctly() {
        let lt = LatencyTable::new();
        let arch = vec![("conv3x3".to_string(), 0), ("max_pool".to_string(), 0)];
        let total = lt.total_latency(&arch);
        let expected = lt.latency_of("conv3x3", 0) + lt.latency_of("max_pool", 0);
        assert!((total - expected).abs() < 1e-12);
    }

    #[test]
    fn random_search_finds_valid_candidate() {
        let config = HardwareNasConfig {
            n_search_iter: 50,
            ..relaxed_config()
        };
        let mut searcher = HardwareNasSearcher::new(config, LatencyTable::new());
        let result = searcher.random_search(acc_oracle);
        assert!(result.is_ok(), "Should find a valid candidate");
        let cand = result.unwrap();
        assert!(cand.estimated_latency <= 10_000.0);
    }

    #[test]
    fn default_config_is_infeasible() {
        // Every op costs at least 1e-4 * 224*224*3 ≈ 15.05 with the stock table,
        // so 8 ops can never fit the default budget of 10.0.
        assert!(make_searcher().random_search(acc_oracle).is_err());
    }

    #[test]
    fn random_search_is_deterministic_for_fixed_seed() {
        let run = || {
            let config = HardwareNasConfig {
                n_search_iter: 20,
                ..relaxed_config()
            };
            HardwareNasSearcher::new(config, LatencyTable::new())
                .random_search(acc_oracle)
                .expect("relaxed budget admits every architecture")
        };
        let (a, b) = (run(), run());
        assert_eq!(a.estimated_accuracy, b.estimated_accuracy);
        assert_eq!(a.estimated_latency, b.estimated_latency);
        assert_eq!(a.n_params, b.n_params);
    }

    #[test]
    fn rand_usize_stays_in_range() {
        let mut searcher = make_searcher();
        assert_eq!(searcher.rand_usize(0), 0);
        for n in 1..=32 {
            for _ in 0..16 {
                assert!(searcher.rand_usize(n) < n);
            }
        }
    }

    #[test]
    fn mutate_changes_at_most_one_position() {
        let mut searcher = make_searcher();
        assert!(searcher.mutate(&[]).is_empty());
        for _ in 0..20 {
            let parent = searcher.sample_random_arch();
            let child = searcher.mutate(&parent);
            assert_eq!(child.len(), parent.len());
            let changed = parent
                .iter()
                .zip(&child)
                .filter(|(a, b)| a.name().to_string() != b.name().to_string())
                .count();
            assert!(changed <= 1, "mutation touched {changed} positions");
        }
    }

    #[test]
    fn objective_value_scores_each_variant() {
        let c = cand(0.8, 4.0, 100);
        assert_eq!(c.objective_value(&NasObjective::Accuracy), 0.8);
        assert_eq!(c.objective_value(&NasObjective::Latency), -4.0);
        assert_eq!(c.objective_value(&NasObjective::ParamCount), -100.0);
        let multi = NasObjective::MultiObjective {
            accuracy_weight: 1.0,
            latency_weight: 0.5,
        };
        assert!((c.objective_value(&multi) - (0.8 - 2.0)).abs() < 1e-12);
    }

    #[test]
    fn pareto_front_returns_non_dominated_subset() {
        let candidates = vec![cand(0.9, 5.0, 100), cand(0.8, 3.0, 80), cand(0.7, 8.0, 90)];
        let front = HardwareNasSearcher::pareto_front(&candidates);
        assert!(
            front.contains(&0),
            "high accuracy / moderate latency should be on front"
        );
        assert!(
            front.contains(&1),
            "low latency / moderate accuracy should be on front"
        );
        assert!(
            !front.contains(&2),
            "dominated candidate should not be on front"
        );
    }

    #[test]
    fn evolutionary_search_runs() {
        let config = HardwareNasConfig {
            population_size: 10,
            n_generations: 5,
            ..relaxed_config()
        };
        let mut searcher = HardwareNasSearcher::new(config, LatencyTable::new());
        let result = searcher.evolutionary_search(acc_oracle);
        assert!(
            result.is_ok(),
            "Evolutionary search should find a candidate"
        );
    }

    #[test]
    fn pareto_front_single_candidate() {
        let candidates = vec![cand(0.85, 4.0, 50)];
        let front = HardwareNasSearcher::pareto_front(&candidates);
        assert_eq!(front, vec![0]);
    }

    #[test]
    fn pareto_front_empty_input() {
        assert!(HardwareNasSearcher::pareto_front(&[]).is_empty());
    }
}