1use std::collections::VecDeque;
7
/// A candidate operation that can be placed on an edge of an architecture cell.
// `GRU` / `LSTM` are public variant names; renaming them would break callers,
// so the acronym lint is silenced instead.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OpType {
    /// Identity mapping.
    Identity,
    /// 3x3 convolution.
    Conv3x3,
    /// 5x5 convolution.
    Conv5x5,
    /// Dilated 3x3 convolution.
    DilatedConv3x3,
    /// Depthwise-separable 3x3 convolution.
    DepthwiseSep3x3,
    /// 3x3 max pooling.
    MaxPool3x3,
    /// 3x3 average pooling.
    AvgPool3x3,
    /// Skip connection.
    Skip,
    /// Zero operation.
    Zero,
    /// Fully-connected layer; the payload is the number of output features.
    Linear(usize),
    /// Gated recurrent unit.
    GRU,
    /// Long short-term memory cell.
    LSTM,
}

impl OpType {
    /// Approximate number of trainable weights for this op, assuming
    /// `in_channels` input and output channels (`Linear(out)` uses `out` outputs).
    ///
    /// Bias terms are not counted. GRU and LSTM share a single `4 * c^2` estimate.
    pub fn num_params(&self, in_channels: usize) -> usize {
        let c = in_channels;
        match self {
            Self::Identity | Self::Skip | Self::Zero | Self::MaxPool3x3 | Self::AvgPool3x3 => 0,
            Self::Conv3x3 | Self::DilatedConv3x3 => 9 * c * c,
            Self::Conv5x5 => 25 * c * c,
            // One 3x3 kernel per channel plus a c x c pointwise mixing matrix.
            Self::DepthwiseSep3x3 => 9 * c + c * c,
            Self::Linear(out) => c * out,
            Self::GRU | Self::LSTM => 4 * c * c,
        }
    }

    /// Approximate floating-point operation count for one forward pass over a
    /// `spatial x spatial` feature map with `in_channels` channels.
    ///
    /// Convolutions count a multiply and an add per weight per output position.
    /// `Linear` and recurrent ops are costed per call and ignore `spatial`.
    /// Parameter-free ops are charged one operation per input element.
    pub fn flops(&self, in_channels: usize, spatial: usize) -> usize {
        let c = in_channels;
        let positions = spatial * spatial;
        match self {
            Self::Conv3x3 | Self::DilatedConv3x3 => 9 * 2 * c * c * positions,
            Self::Conv5x5 => 25 * 2 * c * c * positions,
            Self::DepthwiseSep3x3 => (9 * c + c * c) * positions,
            Self::Linear(out) => c * out * 2,
            Self::GRU | Self::LSTM => 4 * c * c * 2,
            // Listed explicitly rather than via `_` so that adding a variant
            // forces a deliberate cost-model decision here.
            Self::Identity | Self::Skip | Self::Zero | Self::MaxPool3x3 | Self::AvgPool3x3 => {
                c * positions
            }
        }
    }
}
65
/// A node of an architecture graph.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ArchNode {
    /// Node identifier. Edges address nodes by their position in the owning
    /// node list, so this is expected to match that position.
    pub id: usize,
    /// Human-readable node name.
    pub name: String,
    /// Number of channels this node outputs.
    pub output_channels: usize,
}
76
77#[derive(Debug, Clone)]
79pub struct ArchEdge {
80 pub from: usize,
82 pub to: usize,
84 pub op: OpType,
86}
87
88#[derive(Debug, Clone)]
90pub struct Architecture {
91 pub nodes: Vec<ArchNode>,
93 pub edges: Vec<ArchEdge>,
95 pub n_cells: usize,
97 pub channels: usize,
99 pub n_classes: usize,
101}
102
103impl Architecture {
104 pub fn new(n_cells: usize, channels: usize, n_classes: usize) -> Self {
106 Self {
107 nodes: Vec::new(),
108 edges: Vec::new(),
109 n_cells,
110 channels,
111 n_classes,
112 }
113 }
114
115 pub fn total_params(&self) -> usize {
117 self.edges
118 .iter()
119 .map(|e| e.op.num_params(self.channels))
120 .sum()
121 }
122
123 pub fn total_flops(&self, spatial: usize) -> usize {
125 self.edges
126 .iter()
127 .map(|e| e.op.flops(self.channels, spatial))
128 .sum()
129 }
130
131 pub fn topological_sort(&self) -> Vec<usize> {
136 let n = self.nodes.len();
137 let mut in_degree = vec![0usize; n];
138 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
139
140 for e in &self.edges {
141 if e.from < n && e.to < n {
142 adj[e.from].push(e.to);
143 in_degree[e.to] = in_degree[e.to].saturating_add(1);
144 }
145 }
146
147 let mut queue: VecDeque<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
148 let mut order = Vec::with_capacity(n);
149
150 while let Some(v) = queue.pop_front() {
151 order.push(v);
152 for &u in &adj[v] {
153 in_degree[u] -= 1;
154 if in_degree[u] == 0 {
155 queue.push_back(u);
156 }
157 }
158 }
159 order
160 }
161
162 pub fn is_valid(&self) -> bool {
164 self.topological_sort().len() == self.nodes.len()
165 }
166}
167
168pub struct SearchSpace {
170 pub operations: Vec<OpType>,
172 pub n_nodes_per_cell: usize,
174 pub n_input_nodes: usize,
176 pub channels: Vec<usize>,
178 pub n_cells_range: (usize, usize),
180}
181
182impl SearchSpace {
183 pub fn darts_like(n_nodes: usize) -> Self {
185 Self {
186 operations: vec![
187 OpType::Skip,
188 OpType::Zero,
189 OpType::MaxPool3x3,
190 OpType::AvgPool3x3,
191 OpType::Conv3x3,
192 OpType::Conv5x5,
193 OpType::DilatedConv3x3,
194 OpType::DepthwiseSep3x3,
195 ],
196 n_nodes_per_cell: n_nodes,
197 n_input_nodes: 2,
198 channels: vec![16, 32, 64, 128],
199 n_cells_range: (2, 20),
200 }
201 }
202
203 pub fn n_architectures(&self) -> u64 {
205 let n_ops = self.operations.len() as u64;
206 let n = self.n_nodes_per_cell;
207 let n_edges = (self.n_input_nodes * n) as u64;
209 n_ops.saturating_pow(n_edges as u32)
210 }
211
212 pub fn sample_random(
214 &self,
215 rng: &mut (impl scirs2_core::random::Rng + ?Sized),
216 ) -> Architecture {
217 use scirs2_core::random::{Rng, RngExt};
218
219 let cells_lo = self.n_cells_range.0;
220 let cells_hi = self.n_cells_range.1;
221 let n_cells = if cells_lo >= cells_hi {
222 cells_lo
223 } else {
224 rng.random_range(cells_lo..=cells_hi)
225 };
226
227 let ch_idx = rng.random_range(0..self.channels.len());
228 let channels = self.channels[ch_idx];
229 let n_classes = 10;
230
231 let mut arch = Architecture::new(n_cells, channels, n_classes);
232
233 for c in 0..n_cells {
235 for j in 0..self.n_nodes_per_cell {
236 arch.nodes.push(ArchNode {
237 id: c * self.n_nodes_per_cell + j,
238 name: format!("cell{}_node{}", c, j),
239 output_channels: channels,
240 });
241 }
242 }
243
244 for c in 0..n_cells {
246 for j in 0..self.n_nodes_per_cell {
247 let n_inputs = j.min(self.n_input_nodes);
249 for k in 0..n_inputs.max(1) {
250 let from_offset = if n_inputs == 0 {
251 0
253 } else {
254 c * self.n_nodes_per_cell + (j.saturating_sub(k + 1))
255 };
256 let to = c * self.n_nodes_per_cell + j;
257 if from_offset != to {
258 let op_idx = rng.random_range(0..self.operations.len());
259 arch.edges.push(ArchEdge {
260 from: from_offset,
261 to,
262 op: self.operations[op_idx].clone(),
263 });
264 }
265 }
266 }
267 }
268
269 arch
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::{rngs::StdRng, SeedableRng};

    /// Builds an architecture with `n` nodes and no edges.
    fn arch_with_nodes(n: usize) -> Architecture {
        let mut arch = Architecture::new(1, 32, 10);
        for i in 0..n {
            arch.nodes.push(ArchNode {
                id: i,
                name: format!("n{}", i),
                output_channels: 32,
            });
        }
        arch
    }

    #[test]
    fn test_search_space_sample_produces_arch() {
        let space = SearchSpace::darts_like(4);
        let mut rng = StdRng::seed_from_u64(42);
        let arch = space.sample_random(&mut rng);

        assert!(arch.n_cells > 0);
        assert!(arch.channels > 0);
        assert!(!arch.nodes.is_empty());
    }

    #[test]
    fn test_architecture_params_nonzero_for_conv() {
        let mut arch = Architecture::new(2, 32, 10);
        arch.nodes.push(ArchNode {
            id: 0,
            name: "node0".into(),
            output_channels: 32,
        });
        arch.nodes.push(ArchNode {
            id: 1,
            name: "node1".into(),
            output_channels: 32,
        });
        arch.edges.push(ArchEdge {
            from: 0,
            to: 1,
            op: OpType::Conv3x3,
        });

        assert_eq!(arch.total_params(), 9 * 32 * 32);
        assert!(arch.total_flops(8) > 0);
    }

    #[test]
    fn test_topological_sort_linear_dag() {
        let mut arch = arch_with_nodes(3);
        arch.edges.push(ArchEdge {
            from: 0,
            to: 1,
            op: OpType::Skip,
        });
        arch.edges.push(ArchEdge {
            from: 1,
            to: 2,
            op: OpType::Conv3x3,
        });

        assert!(arch.is_valid());
        let order = arch.topological_sort();
        assert_eq!(order, vec![0, 1, 2]);
    }

    #[test]
    fn test_cycle_is_invalid() {
        let mut arch = arch_with_nodes(2);
        arch.edges.push(ArchEdge {
            from: 0,
            to: 1,
            op: OpType::Skip,
        });
        arch.edges.push(ArchEdge {
            from: 1,
            to: 0,
            op: OpType::Skip,
        });

        assert!(!arch.is_valid());
        assert!(arch.topological_sort().is_empty());
    }

    #[test]
    fn test_sampled_arch_is_valid_dag() {
        let space = SearchSpace::darts_like(4);
        let (lo, hi) = space.n_cells_range;
        for seed in 0..8_u64 {
            let mut rng = StdRng::seed_from_u64(seed);
            let arch = space.sample_random(&mut rng);

            assert!(arch.is_valid());
            assert!((lo..=hi).contains(&arch.n_cells));
            assert!(space.channels.contains(&arch.channels));
            assert_eq!(arch.nodes.len(), arch.n_cells * space.n_nodes_per_cell);
            assert!(arch.edges.iter().all(|e| e.from < e.to));
        }
    }

    #[test]
    fn test_sampling_is_deterministic_per_seed() {
        let space = SearchSpace::darts_like(3);
        let a = space.sample_random(&mut StdRng::seed_from_u64(7));
        let b = space.sample_random(&mut StdRng::seed_from_u64(7));

        let edge_key = |arch: &Architecture| {
            arch.edges
                .iter()
                .map(|e| (e.from, e.to, e.op.clone()))
                .collect::<Vec<_>>()
        };
        assert_eq!(a.n_cells, b.n_cells);
        assert_eq!(a.channels, b.channels);
        assert_eq!(edge_key(&a), edge_key(&b));
    }

    #[test]
    fn test_n_architectures_positive() {
        let space = SearchSpace::darts_like(4);
        assert!(space.n_architectures() > 0);
    }

    #[test]
    fn test_op_type_zero_params_for_pooling() {
        assert_eq!(OpType::MaxPool3x3.num_params(64), 0);
        assert_eq!(OpType::AvgPool3x3.num_params(64), 0);
        assert_eq!(OpType::Skip.num_params(64), 0);
    }
}