Skip to main content

tensorlogic_train/nas/
sampler.rs

//! Random architecture sampler and mutation operators for NAS.
//!
//! [`ArchSampler`] draws uniformly random architectures from an [`ArchSearchSpace`]
//! and provides a neighbourhood mutation used by evolutionary search.
5
6use scirs2_core::random::{SeedableRng, StdRng};
7
8use crate::error::{TrainError, TrainResult};
9
10use super::space::{ArchSearchSpace, Architecture, LayerSpec};
11
12// ─── ArchSampler ────────────────────────────────────────────────────────────
13
14/// Samples and mutates architectures within an [`ArchSearchSpace`].
15pub struct ArchSampler {
16    space: ArchSearchSpace,
17    rng: StdRng,
18}
19
20impl ArchSampler {
21    /// Create a new sampler for the given search space, seeded for reproducibility.
22    pub fn new(space: ArchSearchSpace, seed: u64) -> Self {
23        Self {
24            space,
25            rng: StdRng::seed_from_u64(seed),
26        }
27    }
28
29    /// Sample a uniformly random architecture within the search space.
30    ///
31    /// Depth is drawn uniformly from `[min_depth, max_depth]`.  Each layer's
32    /// op, width, and activation are drawn independently and uniformly from
33    /// their respective option lists.
34    pub fn random_architecture(&mut self) -> TrainResult<Architecture> {
35        let depth_range = self.space.max_depth - self.space.min_depth + 1;
36        let depth = self.space.min_depth + self.rng.gen_range(0..depth_range);
37
38        let mut layers = Vec::with_capacity(depth);
39        for _ in 0..depth {
40            layers.push(self.sample_layer()?);
41        }
42
43        Ok(Architecture { layers })
44    }
45
46    /// Mutate one aspect of `arch` at random, returning a new valid architecture.
47    ///
48    /// Four mutations are selected uniformly at random:
49    ///
50    /// | Index | Mutation |
51    /// |-------|----------|
52    /// | 0 | Change a random layer's **op** |
53    /// | 1 | Change a random layer's **width** |
54    /// | 2 | Change a random layer's **activation** |
55    /// | 3 | **Add** a layer (if depth < max_depth) or **remove** one (if depth > min_depth); if neither is possible, fall back to changing op |
56    pub fn mutate(&mut self, arch: &Architecture) -> TrainResult<Architecture> {
57        let mut new_arch = arch.clone();
58        let mutation_type = self.rng.gen_range(0..4_usize);
59
60        match mutation_type {
61            0 => {
62                // Change a random layer's op
63                let layer_idx = self.pick_layer_index(&new_arch)?;
64                let new_op = self.pick_option(&self.space.op_options.clone())?;
65                new_arch.layers[layer_idx].op = new_op;
66            }
67            1 => {
68                // Change a random layer's width
69                let layer_idx = self.pick_layer_index(&new_arch)?;
70                let new_width = self.pick_width()?;
71                new_arch.layers[layer_idx].width = new_width;
72            }
73            2 => {
74                // Change a random layer's activation
75                let layer_idx = self.pick_layer_index(&new_arch)?;
76                let new_act = self.pick_option(&self.space.activation_options.clone())?;
77                new_arch.layers[layer_idx].activation = new_act;
78            }
79            3 => {
80                // Add or remove a layer
81                let can_add = new_arch.depth() < self.space.max_depth;
82                let can_remove = new_arch.depth() > self.space.min_depth;
83
84                if can_add && can_remove {
85                    // Choose randomly
86                    if self.rng.gen_range(0..2_usize) == 0 {
87                        self.add_random_layer(&mut new_arch)?;
88                    } else {
89                        self.remove_random_layer(&mut new_arch)?;
90                    }
91                } else if can_add {
92                    self.add_random_layer(&mut new_arch)?;
93                } else if can_remove {
94                    self.remove_random_layer(&mut new_arch)?;
95                } else {
96                    // Neither add nor remove possible — fall back to op change
97                    let layer_idx = self.pick_layer_index(&new_arch)?;
98                    let new_op = self.pick_option(&self.space.op_options.clone())?;
99                    new_arch.layers[layer_idx].op = new_op;
100                }
101            }
102            _ => unreachable!("gen_range(0..4) is always in 0..3"),
103        }
104
105        Ok(new_arch)
106    }
107
108    // ── private helpers ──────────────────────────────────────────────────
109
110    /// Sample a single fresh layer from the search space.
111    fn sample_layer(&mut self) -> TrainResult<LayerSpec> {
112        let op = self.pick_option(&self.space.op_options.clone())?;
113        let width = self.pick_width()?;
114        let activation = self.pick_option(&self.space.activation_options.clone())?;
115        Ok(LayerSpec {
116            op,
117            width,
118            activation,
119        })
120    }
121
122    /// Return a random element from `options` (assumed non-empty by space invariants).
123    fn pick_option(&mut self, options: &[String]) -> TrainResult<String> {
124        if options.is_empty() {
125            return Err(TrainError::InvalidParameter(
126                "option list must be non-empty".to_string(),
127            ));
128        }
129        let idx = self.rng.gen_range(0..options.len());
130        Ok(options[idx].clone())
131    }
132
133    /// Return a random width from `width_options`.
134    fn pick_width(&mut self) -> TrainResult<usize> {
135        if self.space.width_options.is_empty() {
136            return Err(TrainError::InvalidParameter(
137                "width_options must be non-empty".to_string(),
138            ));
139        }
140        let idx = self.rng.gen_range(0..self.space.width_options.len());
141        Ok(self.space.width_options[idx])
142    }
143
144    /// Return a valid random index into `arch.layers`.
145    fn pick_layer_index(&mut self, arch: &Architecture) -> TrainResult<usize> {
146        if arch.layers.is_empty() {
147            return Err(TrainError::InvalidParameter(
148                "architecture has no layers to mutate".to_string(),
149            ));
150        }
151        Ok(self.rng.gen_range(0..arch.layers.len()))
152    }
153
154    /// Insert a new random layer at a random position.
155    fn add_random_layer(&mut self, arch: &mut Architecture) -> TrainResult<()> {
156        let new_layer = self.sample_layer()?;
157        // Insert at a random position in [0, depth]
158        let pos = self.rng.gen_range(0..=arch.layers.len());
159        arch.layers.insert(pos, new_layer);
160        Ok(())
161    }
162
163    /// Remove the layer at a random position.
164    fn remove_random_layer(&mut self, arch: &mut Architecture) -> TrainResult<()> {
165        let idx = self.pick_layer_index(arch)?;
166        arch.layers.remove(idx);
167        Ok(())
168    }
169
170    /// Generate a uniformly random `usize` in `[lower, upper)`.
171    ///
172    /// Used by [`super::RegularizedEvolution`] to share the sampler's RNG for
173    /// tournament index shuffling.  Returns `lower` when `upper <= lower`.
174    pub fn gen_range_usize(&mut self, lower: usize, upper: usize) -> usize {
175        if upper <= lower {
176            return lower;
177        }
178        lower + self.rng.gen_range(0..(upper - lower))
179    }
180}