Skip to main content

scirs2_optimize/hardware_nas/
mod.rs

1//! Hardware-Aware Neural Architecture Search
2//!
3//! Adds a latency (and parameter-count) model on top of architecture search so that
4//! only architectures that fit within a hardware budget are considered.
5//!
6//! ## What is included
7//!
8//! - `LatencyTable`: maps operation names × input sizes to latency in milliseconds.
9//! - `NasObjective`: single or multi-objective formulation.
10//! - `HardwareNasSearcher`: random search and evolutionary search over candidate
11//!   architectures, with constraint filtering and Pareto-front extraction.
12//!
13//! ## References
14//!
15//! - Cai, H. et al. (2019). "ProxylessNAS: Direct Neural Architecture Search on
16//!   Target Task and Hardware". ICLR 2019.
17//! - Tan, M. & Le, Q.V. (2019). "EfficientNet: Rethinking Model Scaling for
18//!   Convolutional Neural Networks". ICML 2019.
19
20use std::collections::HashMap;
21
22use crate::darts::Operation;
23use crate::error::OptimizeError;
24
25// ──────────────────────────────────────────────────────────── LatencyTable ──
26
/// Per-operation latency model, in milliseconds.
///
/// Every operation name has a base latency stored in `op_latencies`; the
/// input size contributes an extra linear term weighted by `size_scale`.
/// [`LatencyTable::new`] (and `Default`) fill the table with a set of
/// representative estimates.
#[derive(Debug, Clone)]
pub struct LatencyTable {
    /// Base latency (ms) keyed by operation name.
    pub op_latencies: HashMap<String, f64>,
    /// Extra latency (ms) charged per unit of input size.
    pub size_scale: f64,
}

impl LatencyTable {
    /// Build a table pre-populated with the default latency estimates.
    ///
    /// Values are representative of a mid-range mobile CPU at 224×224 feature map.
    pub fn new() -> Self {
        // (operation name, base latency in ms)
        const DEFAULT_BASE_LATENCIES: [(&str, f64); 7] = [
            ("conv3x3", 1.5),
            ("conv5x5", 3.0),
            ("max_pool", 0.2),
            ("avg_pool", 0.2),
            ("identity", 0.05),
            ("skip_connect", 0.05),
            ("zero", 0.0),
        ];

        let op_latencies: HashMap<String, f64> = DEFAULT_BASE_LATENCIES
            .iter()
            .map(|&(name, base_ms)| (name.to_string(), base_ms))
            .collect();

        Self {
            op_latencies,
            // Latency added per unit of input size on top of the base value.
            size_scale: 1e-4,
        }
    }

    /// Latency (ms) of one operation applied to an input of `input_size`
    /// elements (e.g., H*W*C).
    ///
    /// The model is linear: `base + size_scale * input_size`. Operation names
    /// that are not present in the table fall back to a base latency of 1.0 ms.
    pub fn latency_of(&self, op: &str, input_size: usize) -> f64 {
        let base = match self.op_latencies.get(op) {
            Some(&known) => known,
            // Unknown operation: charge a flat 1 ms base.
            None => 1.0,
        };
        let size_term = self.size_scale * input_size as f64;
        base + size_term
    }

    /// Sum of [`LatencyTable::latency_of`] over a sequence of
    /// `(operation_name, input_size)` pairs.
    pub fn total_latency(&self, arch: &[(String, usize)]) -> f64 {
        arch.iter()
            .map(|entry| self.latency_of(&entry.0, entry.1))
            .sum()
    }
}

impl Default for LatencyTable {
    /// Same as [`LatencyTable::new`].
    fn default() -> Self {
        LatencyTable::new()
    }
}
76
77// ──────────────────────────────────────────────────────────── NasObjective ──
78
/// What a hardware-aware search optimises for.
///
/// Four variants target a single metric; `MultiObjective` blends accuracy and
/// latency into one score. The `Default` value is the blended form with an
/// accuracy weight of `1.0` and a latency weight of `0.01`.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum NasObjective {
    /// Accuracy alone is maximised; latency does not enter the objective.
    Accuracy,
    /// Latency alone is minimised.
    Latency,
    /// FLOPs are minimised.
    FlopsCount,
    /// The number of parameters is minimised.
    ParamCount,
    /// Weighted trade-off: `accuracy_weight * accuracy - latency_weight * latency`.
    MultiObjective {
        /// Multiplier applied to the accuracy term.
        accuracy_weight: f64,
        /// Multiplier applied to the (subtracted) latency term.
        latency_weight: f64,
    },
}

impl Default for NasObjective {
    /// Blended objective: accuracy weighted `1.0`, latency weighted `0.01`.
    fn default() -> Self {
        let (accuracy_weight, latency_weight) = (1.0, 0.01);
        Self::MultiObjective {
            accuracy_weight,
            latency_weight,
        }
    }
}
108
109// ──────────────────────────────────────────────────────── HardwareNasConfig ──
110
111/// Configuration for hardware-aware NAS.
112#[derive(Debug, Clone)]
113pub struct HardwareNasConfig {
114    /// Maximum allowed latency in milliseconds.
115    pub max_latency_ms: f64,
116    /// Maximum allowed number of parameters.
117    pub max_params: usize,
118    /// Number of random search iterations (or initial population size for evolution).
119    pub n_search_iter: usize,
120    /// Optimisation objective.
121    pub objective: NasObjective,
122    /// RNG seed for reproducibility.
123    pub seed: u64,
124    /// Number of operations to include in each candidate architecture.
125    pub n_ops_per_arch: usize,
126    /// Input size (H*W*C) used for latency estimation.
127    pub input_size: usize,
128    /// Number of parameters assumed for each operation (simplified model).
129    pub params_per_op: usize,
130    /// Population size for evolutionary search.
131    pub population_size: usize,
132    /// Tournament size for evolutionary selection.
133    pub tournament_size: usize,
134    /// Number of generations for evolutionary search.
135    pub n_generations: usize,
136}
137
138impl Default for HardwareNasConfig {
139    fn default() -> Self {
140        Self {
141            max_latency_ms: 10.0,
142            max_params: 1_000_000,
143            n_search_iter: 100,
144            objective: NasObjective::default(),
145            seed: 42,
146            n_ops_per_arch: 8,
147            input_size: 224 * 224 * 3,
148            params_per_op: 9 * 16 * 16, // 3×3 conv, C=16
149            population_size: 20,
150            tournament_size: 3,
151            n_generations: 10,
152        }
153    }
154}
155
156// ───────────────────────────────────────────────────────── ArchCandidate ──
157
158/// A concrete architecture candidate with performance estimates.
159#[derive(Debug, Clone)]
160pub struct ArchCandidate {
161    /// Sequence of operations.
162    pub operations: Vec<Operation>,
163    /// Estimated top-1 accuracy (fraction in [0, 1]).
164    pub estimated_accuracy: f64,
165    /// Estimated latency in milliseconds.
166    pub estimated_latency: f64,
167    /// Estimated parameter count.
168    pub n_params: usize,
169}
170
171impl ArchCandidate {
172    /// Scalar objective value (higher is better).
173    pub fn objective_value(&self, obj: &NasObjective) -> f64 {
174        match obj {
175            NasObjective::Accuracy => self.estimated_accuracy,
176            NasObjective::Latency => -self.estimated_latency,
177            NasObjective::FlopsCount => -(self.n_params as f64), // use params as proxy for FLOPs
178            NasObjective::ParamCount => -(self.n_params as f64),
179            NasObjective::MultiObjective {
180                accuracy_weight,
181                latency_weight,
182            } => {
183                accuracy_weight * self.estimated_accuracy - latency_weight * self.estimated_latency
184            }
185        }
186    }
187}
188
189// ─────────────────────────────────────────────────── HardwareNasSearcher ──
190
191/// Hardware-aware NAS searcher.
192#[derive(Debug)]
193pub struct HardwareNasSearcher {
194    config: HardwareNasConfig,
195    latency_table: LatencyTable,
196    /// Internal LCG state.
197    rng_state: u64,
198}
199
200impl HardwareNasSearcher {
201    /// Create a new searcher.
202    pub fn new(config: HardwareNasConfig, latency_table: LatencyTable) -> Self {
203        let rng_state = config.seed;
204        Self {
205            config,
206            latency_table,
207            rng_state,
208        }
209    }
210
211    // ── LCG random number generator ──────────────────────────────────────
212
213    /// Advance LCG and return next u64.
214    fn lcg_next(&mut self) -> u64 {
215        // Knuth's multiplicative LCG
216        self.rng_state = self
217            .rng_state
218            .wrapping_mul(6_364_136_223_846_793_005)
219            .wrapping_add(1_442_695_040_888_963_407);
220        self.rng_state
221    }
222
223    /// Sample a uniformly random `usize` in `0..n`.
224    fn rand_usize(&mut self, n: usize) -> usize {
225        if n == 0 {
226            return 0;
227        }
228        (self.lcg_next() as usize) % n
229    }
230
231    /// Sample a uniformly random f64 in `[0, 1)`.
232    fn rand_f64(&mut self) -> f64 {
233        (self.lcg_next() >> 11) as f64 / (1u64 << 53) as f64
234    }
235
236    // ── Architecture sampling helpers ─────────────────────────────────────
237
238    /// Sample a random architecture (sequence of `n_ops_per_arch` operations).
239    fn sample_random_arch(&mut self) -> Vec<Operation> {
240        let ops = Operation::all();
241        let n = self.config.n_ops_per_arch;
242        (0..n).map(|_| ops[self.rand_usize(ops.len())]).collect()
243    }
244
245    /// Estimate latency for an architecture using the latency table.
246    fn estimate_latency(&self, ops: &[Operation]) -> f64 {
247        let pairs: Vec<(String, usize)> = ops
248            .iter()
249            .map(|o| (o.name().to_string(), self.config.input_size))
250            .collect();
251        self.latency_table.total_latency(&pairs)
252    }
253
254    /// Estimate parameter count for an architecture.
255    fn estimate_params(&self, ops: &[Operation]) -> usize {
256        ops.iter()
257            .map(|o| match o {
258                Operation::Zero | Operation::Identity | Operation::SkipConnect => 0,
259                Operation::MaxPool | Operation::AvgPool => 0,
260                Operation::Conv3x3 => self.config.params_per_op,
261                Operation::Conv5x5 => self.config.params_per_op * 2,
262            })
263            .sum()
264    }
265
266    /// Check whether a candidate satisfies the hardware constraints.
267    fn satisfies_constraints(&self, candidate: &ArchCandidate) -> bool {
268        candidate.estimated_latency <= self.config.max_latency_ms
269            && candidate.n_params <= self.config.max_params
270    }
271
272    /// Build an `ArchCandidate` for given ops using the provided accuracy estimate.
273    fn build_candidate(&mut self, ops: Vec<Operation>, accuracy: f64) -> ArchCandidate {
274        let latency = self.estimate_latency(&ops);
275        let n_params = self.estimate_params(&ops);
276        ArchCandidate {
277            operations: ops,
278            estimated_accuracy: accuracy,
279            estimated_latency: latency,
280            n_params,
281        }
282    }
283
284    // ── Public search methods ─────────────────────────────────────────────
285
286    /// Random search: sample `n_search_iter` architectures, evaluate with `eval_fn`,
287    /// filter by constraints, return the best.
288    ///
289    /// `eval_fn` receives a slice of `Operation` and returns an accuracy estimate
290    /// in `[0, 1]`.
291    ///
292    /// Returns an error if no candidate satisfies the hardware constraints.
293    pub fn random_search(
294        &mut self,
295        eval_fn: impl Fn(&[Operation]) -> f64,
296    ) -> Result<ArchCandidate, OptimizeError> {
297        let mut best: Option<ArchCandidate> = None;
298        let obj = self.config.objective.clone();
299
300        for _ in 0..self.config.n_search_iter {
301            let ops = self.sample_random_arch();
302            let acc = eval_fn(&ops);
303            let candidate = self.build_candidate(ops, acc);
304            if !self.satisfies_constraints(&candidate) {
305                continue;
306            }
307            match &best {
308                None => best = Some(candidate),
309                Some(b) => {
310                    if candidate.objective_value(&obj) > b.objective_value(&obj) {
311                        best = Some(candidate);
312                    }
313                }
314            }
315        }
316
317        best.ok_or_else(|| {
318            OptimizeError::ConvergenceError(
319                "No architecture found satisfying hardware constraints".to_string(),
320            )
321        })
322    }
323
324    /// Evolutionary search: start with a random population, apply tournament
325    /// selection and random mutation for `n_generations`, return the best
326    /// constraint-satisfying candidate found.
327    pub fn evolutionary_search(
328        &mut self,
329        eval_fn: impl Fn(&[Operation]) -> f64,
330    ) -> Result<ArchCandidate, OptimizeError> {
331        let pop_size = self.config.population_size;
332        let obj = self.config.objective.clone();
333
334        // Initialise population.
335        let mut population: Vec<ArchCandidate> = (0..pop_size)
336            .map(|_| {
337                let ops = self.sample_random_arch();
338                let acc = eval_fn(&ops);
339                self.build_candidate(ops, acc)
340            })
341            .collect();
342
343        let mut best: Option<ArchCandidate> = population
344            .iter()
345            .filter(|c| self.satisfies_constraints(c))
346            .max_by(|a, b| {
347                a.objective_value(&obj)
348                    .partial_cmp(&b.objective_value(&obj))
349                    .unwrap_or(std::cmp::Ordering::Equal)
350            })
351            .cloned();
352
353        for _gen in 0..self.config.n_generations {
354            let mut next_pop: Vec<ArchCandidate> = Vec::with_capacity(pop_size);
355
356            for _ in 0..pop_size {
357                // Tournament selection.
358                let parent = self.tournament_select(&population, &obj);
359                // Mutate: randomly swap one operation.
360                let child_ops = self.mutate(&parent.operations);
361                let acc = eval_fn(&child_ops);
362                let child = self.build_candidate(child_ops, acc);
363
364                if self.satisfies_constraints(&child) {
365                    match &best {
366                        None => best = Some(child.clone()),
367                        Some(b) => {
368                            if child.objective_value(&obj) > b.objective_value(&obj) {
369                                best = Some(child.clone());
370                            }
371                        }
372                    }
373                }
374                next_pop.push(child);
375            }
376            population = next_pop;
377        }
378
379        best.ok_or_else(|| {
380            OptimizeError::ConvergenceError(
381                "Evolutionary search: no constraint-satisfying architecture found".to_string(),
382            )
383        })
384    }
385
386    /// Tournament selection: sample `tournament_size` candidates uniformly at
387    /// random, return a clone of the one with the best objective.
388    fn tournament_select(
389        &mut self,
390        population: &[ArchCandidate],
391        obj: &NasObjective,
392    ) -> ArchCandidate {
393        let t = self.config.tournament_size.min(population.len()).max(1);
394        let mut best_idx = self.rand_usize(population.len());
395        for _ in 1..t {
396            let idx = self.rand_usize(population.len());
397            if population[idx].objective_value(obj) > population[best_idx].objective_value(obj) {
398                best_idx = idx;
399            }
400        }
401        population[best_idx].clone()
402    }
403
404    /// Mutation: randomly replace one operation with another candidate operation.
405    fn mutate(&mut self, ops: &[Operation]) -> Vec<Operation> {
406        if ops.is_empty() {
407            return Vec::new();
408        }
409        let mut new_ops = ops.to_vec();
410        let pos = self.rand_usize(new_ops.len());
411        let all_ops = Operation::all();
412        new_ops[pos] = all_ops[self.rand_usize(all_ops.len())];
413        new_ops
414    }
415
416    /// Compute the Pareto front of a set of candidates w.r.t.
417    /// `(estimated_accuracy, -estimated_latency)` (both maximised).
418    ///
419    /// Returns the indices of non-dominated candidates.
420    pub fn pareto_front(candidates: &[ArchCandidate]) -> Vec<usize> {
421        let n = candidates.len();
422        let mut dominated = vec![false; n];
423
424        for i in 0..n {
425            if dominated[i] {
426                continue;
427            }
428            for j in 0..n {
429                if i == j || dominated[j] {
430                    continue;
431                }
432                // Does j dominate i?
433                let j_dom_i = candidates[j].estimated_accuracy >= candidates[i].estimated_accuracy
434                    && candidates[j].estimated_latency <= candidates[i].estimated_latency
435                    && (candidates[j].estimated_accuracy > candidates[i].estimated_accuracy
436                        || candidates[j].estimated_latency < candidates[i].estimated_latency);
437                if j_dom_i {
438                    dominated[i] = true;
439                    break;
440                }
441            }
442        }
443
444        (0..n).filter(|&i| !dominated[i]).collect()
445    }
446}
447
448// ═══════════════════════════════════════════════════════════════════ tests ═══
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Searcher over the default latency table with the given configuration.
    fn make_searcher(config: HardwareNasConfig) -> HardwareNasSearcher {
        HardwareNasSearcher::new(config, LatencyTable::new())
    }

    /// Configuration with short architectures and a very loose latency budget,
    /// so the search tests always find a feasible candidate.
    fn loose_config() -> HardwareNasConfig {
        HardwareNasConfig {
            max_latency_ms: 10_000.0,
            n_ops_per_arch: 4,
            ..HardwareNasConfig::default()
        }
    }

    /// Candidate with no operations and the given metrics.
    fn candidate(accuracy: f64, latency: f64, n_params: usize) -> ArchCandidate {
        ArchCandidate {
            operations: vec![],
            estimated_accuracy: accuracy,
            estimated_latency: latency,
            n_params,
        }
    }

    /// A simple accuracy oracle: favour architectures with more identity/skip ops.
    fn acc_oracle(ops: &[Operation]) -> f64 {
        let light_count = ops
            .iter()
            .filter(|o| matches!(o, Operation::Identity | Operation::SkipConnect))
            .count();
        0.5 + 0.05 * light_count as f64
    }

    #[test]
    fn latency_table_default_contains_ops() {
        let lt = LatencyTable::new();
        assert!(lt.latency_of("conv3x3", 1000) > 0.0);
        assert_eq!(lt.latency_of("zero", 0), 0.0);
    }

    #[test]
    fn total_latency_sums_correctly() {
        let lt = LatencyTable::new();
        let arch = vec![("conv3x3".to_string(), 0), ("max_pool".to_string(), 0)];
        let total = lt.total_latency(&arch);
        let expected = lt.latency_of("conv3x3", 0) + lt.latency_of("max_pool", 0);
        assert!((total - expected).abs() < 1e-12);
    }

    #[test]
    fn random_search_finds_valid_candidate() {
        let config = HardwareNasConfig {
            n_search_iter: 50,
            ..loose_config()
        };
        let mut searcher = make_searcher(config);
        let result = searcher.random_search(acc_oracle);
        assert!(result.is_ok(), "Should find a valid candidate");
        let cand = result.unwrap();
        assert!(cand.estimated_latency <= 10_000.0);
    }

    #[test]
    fn pareto_front_returns_non_dominated_subset() {
        let candidates = vec![
            candidate(0.9, 5.0, 100),
            candidate(0.8, 3.0, 80),
            // Dominated by both candidates above.
            candidate(0.7, 8.0, 90),
        ];
        let front = HardwareNasSearcher::pareto_front(&candidates);
        assert!(
            front.contains(&0),
            "high accuracy / moderate latency should be on front"
        );
        assert!(
            front.contains(&1),
            "low latency / moderate accuracy should be on front"
        );
        assert!(
            !front.contains(&2),
            "dominated candidate should not be on front"
        );
    }

    #[test]
    fn evolutionary_search_runs() {
        let config = HardwareNasConfig {
            population_size: 10,
            n_generations: 5,
            ..loose_config()
        };
        let mut searcher = make_searcher(config);
        let result = searcher.evolutionary_search(acc_oracle);
        assert!(
            result.is_ok(),
            "Evolutionary search should find a candidate"
        );
    }

    #[test]
    fn pareto_front_single_candidate() {
        let candidates = vec![candidate(0.85, 4.0, 50)];
        let front = HardwareNasSearcher::pareto_front(&candidates);
        assert_eq!(front, vec![0]);
    }
}