tensorlogic_train/nas/
space.rs1use std::collections::HashMap;
7
8use crate::error::{TrainError, TrainResult};
9use crate::hyperparameter::{HyperparamConfig, HyperparamValue};
10
/// Description of a single layer in a candidate architecture.
///
/// All three fields are plain data; `op` and `activation` are free-form names
/// that this type does not validate.
#[derive(Clone, Debug, PartialEq)]
pub struct LayerSpec {
    /// Name of the operation this layer performs.
    pub op: String,
    /// Number of units in this layer.
    pub width: usize,
    /// Name of the activation function applied by this layer.
    pub activation: String,
}
23
24#[derive(Debug, Clone, PartialEq)]
28pub struct Architecture {
29 pub layers: Vec<LayerSpec>,
31}
32
33impl Architecture {
34 pub fn param_count(&self) -> usize {
38 if self.layers.len() < 2 {
39 return 0;
40 }
41 self.layers
42 .windows(2)
43 .map(|w| w[0].width * w[1].width)
44 .sum()
45 }
46
47 pub fn depth(&self) -> usize {
49 self.layers.len()
50 }
51
52 pub fn to_config(&self) -> HyperparamConfig {
57 let mut m: HashMap<String, HyperparamValue> = HashMap::new();
58 m.insert(
59 "depth".to_string(),
60 HyperparamValue::Int(self.layers.len() as i64),
61 );
62 for (i, layer) in self.layers.iter().enumerate() {
63 m.insert(
64 format!("layer_{i}_op"),
65 HyperparamValue::String(layer.op.clone()),
66 );
67 m.insert(
68 format!("layer_{i}_width"),
69 HyperparamValue::Int(layer.width as i64),
70 );
71 m.insert(
72 format!("layer_{i}_activation"),
73 HyperparamValue::String(layer.activation.clone()),
74 );
75 }
76 m
77 }
78
79 pub fn from_config(cfg: &HyperparamConfig, max_depth: usize) -> TrainResult<Self> {
83 let depth = cfg.get("depth").and_then(|v| v.as_int()).ok_or_else(|| {
84 TrainError::InvalidParameter("config missing 'depth' Int key".to_string())
85 })?;
86
87 if depth < 1 {
88 return Err(TrainError::InvalidParameter(format!(
89 "decoded depth {depth} must be ≥ 1"
90 )));
91 }
92 if depth as usize > max_depth {
93 return Err(TrainError::InvalidParameter(format!(
94 "decoded depth {depth} exceeds max_depth {max_depth}"
95 )));
96 }
97
98 let mut layers = Vec::with_capacity(depth as usize);
99 for i in 0..depth as usize {
100 let op = cfg
101 .get(&format!("layer_{i}_op"))
102 .and_then(|v| v.as_string())
103 .ok_or_else(|| {
104 TrainError::InvalidParameter(format!(
105 "config missing 'layer_{i}_op' String key"
106 ))
107 })?
108 .to_string();
109
110 let width = cfg
111 .get(&format!("layer_{i}_width"))
112 .and_then(|v| v.as_int())
113 .ok_or_else(|| {
114 TrainError::InvalidParameter(format!(
115 "config missing 'layer_{i}_width' Int key"
116 ))
117 })?;
118
119 if width < 1 {
120 return Err(TrainError::InvalidParameter(format!(
121 "layer {i} width {width} must be ≥ 1"
122 )));
123 }
124
125 let activation = cfg
126 .get(&format!("layer_{i}_activation"))
127 .and_then(|v| v.as_string())
128 .ok_or_else(|| {
129 TrainError::InvalidParameter(format!(
130 "config missing 'layer_{i}_activation' String key"
131 ))
132 })?
133 .to_string();
134
135 layers.push(LayerSpec {
136 op,
137 width: width as usize,
138 activation,
139 });
140 }
141
142 Ok(Architecture { layers })
143 }
144}
145
/// Depth bounds and per-layer option lists describing a space of candidate
/// architectures.
///
/// The fields are public, so the checks performed by [`ArchSearchSpace::new`]
/// only apply to values built through that constructor.
#[derive(Clone, Debug)]
pub struct ArchSearchSpace {
    /// Lower bound on the number of layers (`new` requires it to be ≥ 1).
    pub min_depth: usize,
    /// Upper bound on the number of layers (`new` requires it to be ≥ `min_depth`).
    pub max_depth: usize,
    /// Layer widths to choose from.
    pub width_options: Vec<usize>,
    /// Activation function names to choose from.
    pub activation_options: Vec<String>,
    /// Layer operation names to choose from.
    pub op_options: Vec<String>,
}
165
166impl ArchSearchSpace {
167 pub fn new(
176 min_depth: usize,
177 max_depth: usize,
178 width_options: Vec<usize>,
179 activation_options: Vec<String>,
180 op_options: Vec<String>,
181 ) -> TrainResult<Self> {
182 if min_depth < 1 {
183 return Err(TrainError::InvalidParameter(
184 "min_depth must be ≥ 1".to_string(),
185 ));
186 }
187 if max_depth < min_depth {
188 return Err(TrainError::InvalidParameter(format!(
189 "max_depth ({max_depth}) must be ≥ min_depth ({min_depth})"
190 )));
191 }
192 if width_options.is_empty() {
193 return Err(TrainError::InvalidParameter(
194 "width_options must be non-empty".to_string(),
195 ));
196 }
197 if activation_options.is_empty() {
198 return Err(TrainError::InvalidParameter(
199 "activation_options must be non-empty".to_string(),
200 ));
201 }
202 if op_options.is_empty() {
203 return Err(TrainError::InvalidParameter(
204 "op_options must be non-empty".to_string(),
205 ));
206 }
207 Ok(Self {
208 min_depth,
209 max_depth,
210 width_options,
211 activation_options,
212 op_options,
213 })
214 }
215}