// scirs2_optimize/nas/differentiable.rs

use crate::error::OptimizeError;
use crate::nas::search_space::{ArchEdge, ArchNode, Architecture, OpType, SearchSpace};

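/// Differentiable architecture search (DARTS) state over a cell-based
/// search space. Each candidate edge holds one row of architecture
/// parameters (`alpha`); a softmax over that row gives the mixing
/// weights of the candidate operations on the edge.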
#[derive(Debug, Clone)]
pub struct DARTSSearch {
    /// Number of intermediate nodes in the cell.
    pub n_nodes: usize,
    /// Number of candidate operations per edge.
    pub n_ops: usize,
    /// Architecture parameters: one row of length `n_ops` per edge.
    pub alpha: Vec<Vec<f64>>,
    /// Step size applied to gradients on `alpha`.
    pub learning_rate: f64,
    /// Number of input nodes feeding the cell.
    pub n_input_nodes: usize,
}

impl DARTSSearch {
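    /// Creates a search state with uniform operation weights. Node `i`
    /// receives one edge from each of its `n_input_nodes + i`
    /// predecessors, so `alpha` has one row per candidate edge.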
    pub fn new(n_nodes: usize, operations: &[OpType], n_input_nodes: usize) -> Self {
        let n_ops = operations.len();
        // Node i has one incoming edge from each of its n_input_nodes + i
        // predecessors.
        let n_edges: usize = (0..n_nodes).map(|i| n_input_nodes + i).sum();
        // Initialize uniformly; guard against an empty operation set.
        let init_weight = if n_ops > 0 { 1.0 / n_ops as f64 } else { 0.0 };
        let alpha = vec![vec![init_weight; n_ops]; n_edges.max(1)];

        Self {
            n_nodes,
            n_ops,
            alpha,
            learning_rate: 3e-4,
            n_input_nodes,
        }
    }

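    /// Total number of candidate edges tracked by the search.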
    pub fn n_edges(&self) -> usize {
        self.alpha.len()
    }

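    /// Softmax of the architecture parameters for one edge. An
    /// out-of-range `edge_idx` yields a zero vector.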
    pub fn get_op_weights(&self, edge_idx: usize) -> Vec<f64> {
        if edge_idx >= self.alpha.len() {
            return vec![0.0; self.n_ops];
        }
        let raw = &self.alpha[edge_idx];
        // Numerically stable softmax: shift by the row maximum before exp.
        let max = raw.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exp: Vec<f64> = raw.iter().map(|x| (x - max).exp()).collect();
        let sum: f64 = exp.iter().sum();
        // Guard against a zero denominator (possible only for an empty row).
        if sum == 0.0 {
            return vec![1.0 / self.n_ops as f64; self.n_ops];
        }
        exp.iter().map(|e| e / sum).collect()
    }

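    /// Discretizes the relaxed architecture: for every edge, keep the
    /// operation with the largest softmax weight (the usual DARTS
    /// argmax rule), falling back to `OpType::Skip` if the index is
    /// out of range.
    ///
    /// A minimal usage sketch; the `scirs2_optimize` module paths are
    /// assumed from this file's location and imports:
    ///
    /// ```ignore
    /// use scirs2_optimize::nas::search_space::SearchSpace;
    /// use scirs2_optimize::nas::differentiable::DARTSSearch;
    ///
    /// let space = SearchSpace::darts_like(4);
    /// let darts = DARTSSearch::new(4, &space.operations, 2);
    /// let arch = darts.derive_architecture(&space, 2, 64, 10);
    /// assert_eq!(arch.nodes.len(), 2 + 4); // input nodes + intermediate nodes
    /// ```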
    pub fn derive_architecture(
        &self,
        space: &SearchSpace,
        n_cells: usize,
        channels: usize,
        n_classes: usize,
    ) -> Architecture {
        let mut arch = Architecture::new(n_cells, channels, n_classes);

        // Register the fixed input nodes first.
        for i in 0..self.n_input_nodes {
            arch.nodes.push(ArchNode {
                id: i,
                name: format!("input{}", i),
                output_channels: channels,
            });
        }

        // Visit edges in the same order they were enumerated in `new`.
        let mut edge_idx = 0usize;
        for i in 0..self.n_nodes {
            let node_id = self.n_input_nodes + i;
            arch.nodes.push(ArchNode {
                id: node_id,
                name: format!("node{}", i),
                output_channels: channels,
            });

            // Node i is connected to every input node and every earlier node.
            let n_prev = self.n_input_nodes + i;
            for from_id in 0..n_prev {
                let weights = self.get_op_weights(edge_idx);
                // Argmax over the softmax weights; ties resolve to the
                // last maximal index.
                let best_op_idx = weights
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                    .map(|(idx, _)| idx)
                    .unwrap_or(0);

                let op = space
                    .operations
                    .get(best_op_idx)
                    .cloned()
                    .unwrap_or(OpType::Skip);

                arch.edges.push(ArchEdge {
                    from: from_id,
                    to: node_id,
                    op,
                });
                edge_idx += 1;
            }
        }

        arch
    }

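    /// Applies a single gradient step to one architecture parameter:
    /// `alpha[edge_idx][op_idx] += learning_rate * grad`.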
    pub fn update_alpha(
        &mut self,
        edge_idx: usize,
        op_idx: usize,
        grad: f64,
    ) -> Result<(), OptimizeError> {
        if edge_idx >= self.alpha.len() {
            return Err(OptimizeError::InvalidParameter(format!(
                "edge_idx {} out of range (n_edges = {})",
                edge_idx,
                self.alpha.len()
            )));
        }
        if op_idx >= self.n_ops {
            return Err(OptimizeError::InvalidParameter(format!(
                "op_idx {} out of range (n_ops = {})",
                op_idx, self.n_ops
            )));
        }
        self.alpha[edge_idx][op_idx] += self.learning_rate * grad;
        Ok(())
    }

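    /// Applies a gradient step to every architecture parameter at once.
    /// `grads` must have one row per edge and `n_ops` entries per row;
    /// a shape mismatch yields an `InvalidParameter` error.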
    pub fn update_alpha_batch(&mut self, grads: &[Vec<f64>]) -> Result<(), OptimizeError> {
        if grads.len() != self.alpha.len() {
            return Err(OptimizeError::InvalidParameter(format!(
                "grads has {} rows but alpha has {}",
                grads.len(),
                self.alpha.len()
            )));
        }
        for (e, row) in grads.iter().enumerate() {
            if row.len() != self.n_ops {
                return Err(OptimizeError::InvalidParameter(format!(
                    "grads[{}] has {} columns but n_ops = {}",
                    e,
                    row.len(),
                    self.n_ops
                )));
            }
            for (k, &g) in row.iter().enumerate() {
                self.alpha[e][k] += self.learning_rate * g;
            }
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::nas::search_space::SearchSpace;

    /// Four intermediate nodes, two input nodes, DARTS-like operation set.
    fn make_darts() -> DARTSSearch {
        let space = SearchSpace::darts_like(4);
        DARTSSearch::new(4, &space.operations, 2)
    }

    #[test]
    fn test_get_op_weights_sum_to_one() {
        let darts = make_darts();
        for e in 0..darts.n_edges() {
            let w = darts.get_op_weights(e);
            let sum: f64 = w.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-10,
                "weights do not sum to 1: {}",
                sum
            );
        }
    }

    #[test]
    fn test_derive_architecture_correct_structure() {
        let space = SearchSpace::darts_like(4);
        let darts = DARTSSearch::new(4, &space.operations, 2);
        let arch = darts.derive_architecture(&space, 2, 64, 10);

        // 2 input nodes + 4 intermediate nodes.
        assert_eq!(arch.nodes.len(), 2 + 4);
        // Every edge must reference nodes that exist.
        for e in &arch.edges {
            assert!(e.from < arch.nodes.len());
            assert!(e.to < arch.nodes.len());
        }
    }

    #[test]
    fn test_update_alpha_changes_weights() {
        let mut darts = make_darts();
        let before = darts.alpha[0][0];
        darts.update_alpha(0, 0, 1.0).expect("update failed");
        assert!(
            (darts.alpha[0][0] - before).abs() > 1e-12,
            "alpha did not change"
        );
    }

    #[test]
    fn test_update_alpha_out_of_range_errors() {
        let mut darts = make_darts();
        assert!(darts.update_alpha(9999, 0, 1.0).is_err());
        assert!(darts.update_alpha(0, 9999, 1.0).is_err());
    }

    #[test]
    fn test_update_alpha_batch_correct_shape() {
        let mut darts = make_darts();
        let n_e = darts.n_edges();
        let n_o = darts.n_ops;
        let grads = vec![vec![0.1; n_o]; n_e];
        darts
            .update_alpha_batch(&grads)
            .expect("batch update failed");
    }

    #[test]
    fn test_update_alpha_batch_wrong_shape_errors() {
        let mut darts = make_darts();
        let grads = vec![vec![0.1; darts.n_ops]; darts.n_edges() + 1];
        assert!(darts.update_alpha_batch(&grads).is_err());
    }

    #[test]
    fn test_argmax_selects_highest_weight() {
        let space = SearchSpace::darts_like(2);
        let mut darts = DARTSSearch::new(2, &space.operations, 2);
        let n_ops = darts.n_ops;
        // Zero out the first edge's parameters, then make op 3 dominant.
        for k in 0..n_ops {
            darts.alpha[0][k] = 0.0;
        }
        darts.alpha[0][3] = 10.0;

        let arch = darts.derive_architecture(&space, 1, 32, 10);
        // The argmax discretization must pick the dominant operation.
        if let Some(e) = arch.edges.first() {
            assert_eq!(e.op, space.operations[3]);
        }
    }
}