Skip to main content

scirs2_optimize/nas/
random_nas.rs

1//! Random Neural Architecture Search baseline.
2//!
3//! Implements the simplest NAS strategy: uniformly sample architectures
4//! from the search space and evaluate each with a fitness function.
5//! Serves as a strong baseline for more sophisticated methods.
6
7use crate::error::OptimizeError;
8use crate::nas::search_space::{Architecture, SearchSpace};
9use scirs2_core::random::{rngs::StdRng, SeedableRng};
10
11/// Result of a NAS search run.
12#[derive(Debug, Clone)]
13pub struct NASResult {
14    /// Architecture with the highest fitness score found
15    pub best_arch: Architecture,
16    /// Fitness score of the best architecture
17    pub best_score: f64,
18    /// All fitness scores in the order they were evaluated
19    pub all_scores: Vec<f64>,
20    /// Total number of architectures evaluated
21    pub n_evaluated: usize,
22}
23
24/// Trait for evaluating the fitness of an architecture.
25///
26/// Implementations should be deterministic for reproducibility.
27pub trait ArchFitness: Send + Sync {
28    /// Evaluate an architecture and return a scalar fitness score.
29    ///
30    /// Higher scores are better (the search maximizes this value).
31    fn evaluate(&self, arch: &Architecture) -> Result<f64, OptimizeError>;
32}
33
/// Proxy fitness: score based on closeness to a target parameter count.
///
/// The score is `-|params - target| / target`, i.e. the negated relative
/// deviation from the target. It is `0.0` for an exact match and grows more
/// negative the further the parameter count strays. With a target of `0`,
/// an architecture with zero parameters scores `0.0` and any other scores
/// `-1.0` (the relative deviation is undefined there).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParamCountFitness {
    /// Desired number of parameters.
    pub target_params: usize,
}

impl ParamCountFitness {
    /// Create a new `ParamCountFitness` with the given target parameter count.
    #[must_use]
    pub const fn new(target_params: usize) -> Self {
        Self { target_params }
    }
}
49
50impl ArchFitness for ParamCountFitness {
51    fn evaluate(&self, arch: &Architecture) -> Result<f64, OptimizeError> {
52        let params = arch.total_params() as f64;
53        let target = self.target_params as f64;
54        if target == 0.0 {
55            return Ok(if params == 0.0 { 0.0 } else { -1.0 });
56        }
57        Ok(-(params - target).abs() / target)
58    }
59}
60
/// Proxy fitness based on FLOPs efficiency at a given spatial resolution.
///
/// FLOPs are estimated with `Architecture::total_flops(spatial)`. At or
/// under the budget the score is `flops / budget`, in `[0, 1]`, which
/// rewards using more of the budget. Over the budget the score is
/// `-(flops - budget) / budget`, which is negative. Note the jump from `1.0`
/// to roughly `0.0` right at the budget boundary.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FlopsFitness {
    /// Maximum FLOPs budget (architectures exceeding this are penalized).
    pub flops_budget: usize,
    /// Spatial dimension used for FLOPs estimation.
    pub spatial: usize,
}

impl FlopsFitness {
    /// Create a new `FlopsFitness` with the given FLOPs budget and spatial size.
    #[must_use]
    pub const fn new(flops_budget: usize, spatial: usize) -> Self {
        Self {
            flops_budget,
            spatial,
        }
    }
}
78
79impl ArchFitness for FlopsFitness {
80    fn evaluate(&self, arch: &Architecture) -> Result<f64, OptimizeError> {
81        let flops = arch.total_flops(self.spatial) as f64;
82        let budget = self.flops_budget as f64;
83        if budget == 0.0 {
84            return Ok(0.0);
85        }
86        // Reward for staying under budget; penalize excess
87        if flops <= budget {
88            Ok(flops / budget)
89        } else {
90            Ok(-(flops - budget) / budget)
91        }
92    }
93}
94
/// Random Neural Architecture Search.
///
/// The simplest NAS baseline: draw `n_trials` architectures with
/// `SearchSpace::sample_random`, score each one with an [`ArchFitness`], and
/// keep the highest-scoring one. See [`RandomNAS::search`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RandomNAS {
    /// Number of random architectures to sample and evaluate; must be at
    /// least 1 when calling [`RandomNAS::search`].
    pub n_trials: usize,
}
103
104impl RandomNAS {
105    /// Create a new `RandomNAS` with the specified trial budget.
106    pub fn new(n_trials: usize) -> Self {
107        Self { n_trials }
108    }
109
110    /// Run random search over the given `space` using `fitness`.
111    ///
112    /// # Arguments
113    /// - `space`: The architecture search space to sample from.
114    /// - `fitness`: Fitness evaluator (higher = better).
115    /// - `seed`: Random seed for reproducibility.
116    pub fn search<F: ArchFitness>(
117        &self,
118        space: &SearchSpace,
119        fitness: &F,
120        seed: u64,
121    ) -> Result<NASResult, OptimizeError> {
122        use scirs2_core::random::{Rng, RngExt};
123
124        if self.n_trials == 0 {
125            return Err(OptimizeError::InvalidParameter(
126                "n_trials must be at least 1".to_string(),
127            ));
128        }
129
130        let mut rng = StdRng::seed_from_u64(seed);
131
132        let mut best_score = f64::NEG_INFINITY;
133        // Sample the initial best architecture before the loop
134        let mut best_arch = space.sample_random(&mut rng);
135        let mut all_scores = Vec::with_capacity(self.n_trials);
136
137        for _ in 0..self.n_trials {
138            let arch = space.sample_random(&mut rng);
139            let score = fitness.evaluate(&arch)?;
140            all_scores.push(score);
141            if score > best_score {
142                best_score = score;
143                best_arch = arch;
144            }
145        }
146
147        Ok(NASResult {
148            best_arch,
149            best_score,
150            all_scores,
151            n_evaluated: self.n_trials,
152        })
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nas::search_space::SearchSpace;

    #[test]
    fn test_random_nas_returns_result() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(10_000);
        let nas = RandomNAS::new(20);

        let result = nas.search(&space, &fitness, 0).expect("search failed");

        assert_eq!(result.n_evaluated, 20);
        assert_eq!(result.all_scores.len(), 20);
        // best_score is finite
        assert!(result.best_score.is_finite());
    }

    #[test]
    fn test_random_nas_zero_trials_errors() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(10_000);
        let nas = RandomNAS::new(0);

        assert!(nas.search(&space, &fitness, 0).is_err());
    }

    #[test]
    fn test_random_nas_single_trial() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(10_000);
        let nas = RandomNAS::new(1);

        let result = nas.search(&space, &fitness, 3).expect("search failed");

        assert_eq!(result.n_evaluated, 1);
        assert_eq!(result.all_scores.len(), 1);
        assert_eq!(result.best_score, result.all_scores[0]);
    }

    #[test]
    fn test_random_nas_is_deterministic_for_seed() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(10_000);
        let nas = RandomNAS::new(10);

        let a = nas.search(&space, &fitness, 42).expect("search failed");
        let b = nas.search(&space, &fitness, 42).expect("search failed");

        assert_eq!(a.all_scores, b.all_scores);
        assert_eq!(a.best_score, b.best_score);
    }

    #[test]
    fn test_random_nas_best_is_consistent_with_scores() {
        let space = SearchSpace::darts_like(3);
        let fitness = ParamCountFitness::new(10_000);
        let result = RandomNAS::new(15)
            .search(&space, &fitness, 7)
            .expect("search failed");

        // best_score is the maximum of everything that was evaluated ...
        let max = result
            .all_scores
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        assert_eq!(result.best_score, max);

        // ... and re-scoring the returned architecture reproduces it.
        let rescored = fitness.evaluate(&result.best_arch).expect("eval failed");
        assert_eq!(rescored, result.best_score);
    }

    #[test]
    fn test_param_count_fitness_exact_match() {
        let mut arch = Architecture::new(1, 32, 10);
        // Architecture with zero params
        let fitness = ParamCountFitness::new(0);
        let score = fitness.evaluate(&arch).expect("eval failed");
        assert_eq!(score, 0.0);

        // Arch with non-zero params vs 0 target
        use crate::nas::search_space::{ArchEdge, ArchNode, OpType};
        arch.nodes.push(ArchNode {
            id: 0,
            name: "n0".into(),
            output_channels: 32,
        });
        arch.nodes.push(ArchNode {
            id: 1,
            name: "n1".into(),
            output_channels: 32,
        });
        arch.edges.push(ArchEdge {
            from: 0,
            to: 1,
            op: OpType::Conv3x3,
        });
        let fitness2 = ParamCountFitness::new(0);
        let score2 = fitness2.evaluate(&arch).expect("eval failed");
        assert_eq!(score2, -1.0);
    }

    #[test]
    fn test_flops_fitness_under_budget() {
        use crate::nas::search_space::{ArchEdge, ArchNode, OpType};
        let mut arch = Architecture::new(1, 8, 10);
        arch.nodes.push(ArchNode {
            id: 0,
            name: "n0".into(),
            output_channels: 8,
        });
        arch.nodes.push(ArchNode {
            id: 1,
            name: "n1".into(),
            output_channels: 8,
        });
        arch.edges.push(ArchEdge {
            from: 0,
            to: 1,
            op: OpType::Skip,
        });

        let fitness = FlopsFitness::new(1_000_000, 8);
        let score = fitness.evaluate(&arch).expect("eval failed");
        // Skip has near-zero flops, so score should be between 0 and 1
        assert!((0.0..=1.0).contains(&score));
    }
}