tensorlogic_train/nas/
sampler.rs1use scirs2_core::random::{SeedableRng, StdRng};
7
8use crate::error::{TrainError, TrainResult};
9
10use super::space::{ArchSearchSpace, Architecture, LayerSpec};
11
12pub struct ArchSampler {
16 space: ArchSearchSpace,
17 rng: StdRng,
18}
19
20impl ArchSampler {
21 pub fn new(space: ArchSearchSpace, seed: u64) -> Self {
23 Self {
24 space,
25 rng: StdRng::seed_from_u64(seed),
26 }
27 }
28
29 pub fn random_architecture(&mut self) -> TrainResult<Architecture> {
35 let depth_range = self.space.max_depth - self.space.min_depth + 1;
36 let depth = self.space.min_depth + self.rng.gen_range(0..depth_range);
37
38 let mut layers = Vec::with_capacity(depth);
39 for _ in 0..depth {
40 layers.push(self.sample_layer()?);
41 }
42
43 Ok(Architecture { layers })
44 }
45
46 pub fn mutate(&mut self, arch: &Architecture) -> TrainResult<Architecture> {
57 let mut new_arch = arch.clone();
58 let mutation_type = self.rng.gen_range(0..4_usize);
59
60 match mutation_type {
61 0 => {
62 let layer_idx = self.pick_layer_index(&new_arch)?;
64 let new_op = self.pick_option(&self.space.op_options.clone())?;
65 new_arch.layers[layer_idx].op = new_op;
66 }
67 1 => {
68 let layer_idx = self.pick_layer_index(&new_arch)?;
70 let new_width = self.pick_width()?;
71 new_arch.layers[layer_idx].width = new_width;
72 }
73 2 => {
74 let layer_idx = self.pick_layer_index(&new_arch)?;
76 let new_act = self.pick_option(&self.space.activation_options.clone())?;
77 new_arch.layers[layer_idx].activation = new_act;
78 }
79 3 => {
80 let can_add = new_arch.depth() < self.space.max_depth;
82 let can_remove = new_arch.depth() > self.space.min_depth;
83
84 if can_add && can_remove {
85 if self.rng.gen_range(0..2_usize) == 0 {
87 self.add_random_layer(&mut new_arch)?;
88 } else {
89 self.remove_random_layer(&mut new_arch)?;
90 }
91 } else if can_add {
92 self.add_random_layer(&mut new_arch)?;
93 } else if can_remove {
94 self.remove_random_layer(&mut new_arch)?;
95 } else {
96 let layer_idx = self.pick_layer_index(&new_arch)?;
98 let new_op = self.pick_option(&self.space.op_options.clone())?;
99 new_arch.layers[layer_idx].op = new_op;
100 }
101 }
102 _ => unreachable!("gen_range(0..4) is always in 0..3"),
103 }
104
105 Ok(new_arch)
106 }
107
108 fn sample_layer(&mut self) -> TrainResult<LayerSpec> {
112 let op = self.pick_option(&self.space.op_options.clone())?;
113 let width = self.pick_width()?;
114 let activation = self.pick_option(&self.space.activation_options.clone())?;
115 Ok(LayerSpec {
116 op,
117 width,
118 activation,
119 })
120 }
121
122 fn pick_option(&mut self, options: &[String]) -> TrainResult<String> {
124 if options.is_empty() {
125 return Err(TrainError::InvalidParameter(
126 "option list must be non-empty".to_string(),
127 ));
128 }
129 let idx = self.rng.gen_range(0..options.len());
130 Ok(options[idx].clone())
131 }
132
133 fn pick_width(&mut self) -> TrainResult<usize> {
135 if self.space.width_options.is_empty() {
136 return Err(TrainError::InvalidParameter(
137 "width_options must be non-empty".to_string(),
138 ));
139 }
140 let idx = self.rng.gen_range(0..self.space.width_options.len());
141 Ok(self.space.width_options[idx])
142 }
143
144 fn pick_layer_index(&mut self, arch: &Architecture) -> TrainResult<usize> {
146 if arch.layers.is_empty() {
147 return Err(TrainError::InvalidParameter(
148 "architecture has no layers to mutate".to_string(),
149 ));
150 }
151 Ok(self.rng.gen_range(0..arch.layers.len()))
152 }
153
154 fn add_random_layer(&mut self, arch: &mut Architecture) -> TrainResult<()> {
156 let new_layer = self.sample_layer()?;
157 let pos = self.rng.gen_range(0..=arch.layers.len());
159 arch.layers.insert(pos, new_layer);
160 Ok(())
161 }
162
163 fn remove_random_layer(&mut self, arch: &mut Architecture) -> TrainResult<()> {
165 let idx = self.pick_layer_index(arch)?;
166 arch.layers.remove(idx);
167 Ok(())
168 }
169
170 pub fn gen_range_usize(&mut self, lower: usize, upper: usize) -> usize {
175 if upper <= lower {
176 return lower;
177 }
178 lower + self.rng.gen_range(0..(upper - lower))
179 }
180}