Skip to main content

scirs2_optimize/nas/
search_space.rs

1//! Neural architecture search space definition.
2//!
3//! Provides a DAG-based representation of neural architectures
4//! and a configurable search space (DARTS-like by default).
5
6use std::collections::VecDeque;
7
/// Operation types in a NAS cell
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OpType {
    /// Pass-through with no transformation
    Identity,
    /// 3x3 convolution
    Conv3x3,
    /// 5x5 convolution
    Conv5x5,
    /// Dilated 3x3 convolution (dilation=2)
    DilatedConv3x3,
    /// Depthwise separable 3x3 convolution
    DepthwiseSep3x3,
    /// 3x3 max pooling
    MaxPool3x3,
    /// 3x3 average pooling
    AvgPool3x3,
    /// Skip connection (residual)
    Skip,
    /// Zero connection (no gradient flow)
    Zero,
    /// Fully-connected layer with given output size
    Linear(usize),
    /// Gated Recurrent Unit cell
    GRU,
    /// Long Short-Term Memory cell
    LSTM,
}

impl OpType {
    /// Approximate number of trainable parameters for this op.
    ///
    /// Channel-preserving ops are assumed to map `in_channels` to `in_channels`;
    /// `Linear(out)` maps `in_channels` to `out`. Bias terms are not counted, and
    /// the recurrent cells use a coarse `4 * C * C` estimate.
    pub fn num_params(&self, in_channels: usize) -> usize {
        let c = in_channels;
        match self {
            Self::Identity | Self::Skip | Self::Zero | Self::MaxPool3x3 | Self::AvgPool3x3 => 0,
            Self::Conv3x3 | Self::DilatedConv3x3 => 9 * c * c,
            Self::Conv5x5 => 25 * c * c,
            Self::DepthwiseSep3x3 => 9 * c + c * c,
            Self::Linear(out) => c * out,
            Self::GRU | Self::LSTM => 4 * c * c,
        }
    }

    /// Approximate FLOPs for this op given input channels and spatial side length.
    ///
    /// `spatial` is the side of a square feature map, so the per-channel element
    /// count is `spatial * spatial`. Convolutions count a multiply-add as 2 FLOPs.
    /// `Linear` and the recurrent cells ignore `spatial`.
    ///
    /// `Zero` represents an absent connection and is charged 0 FLOPs, matching its
    /// 0 parameters. The remaining parameter-free ops (identity, skip, pooling) are
    /// charged one operation per output element.
    pub fn flops(&self, in_channels: usize, spatial: usize) -> usize {
        let c = in_channels;
        let spatial_sq = spatial * spatial;
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Zero => 0,
            Self::Identity | Self::Skip | Self::MaxPool3x3 | Self::AvgPool3x3 => c * spatial_sq,
            Self::Conv3x3 | Self::DilatedConv3x3 => 9 * 2 * c * c * spatial_sq,
            Self::Conv5x5 => 25 * 2 * c * c * spatial_sq,
            Self::DepthwiseSep3x3 => (9 * c + c * c) * spatial_sq,
            Self::Linear(out) => c * out * 2,
            Self::GRU | Self::LSTM => 4 * c * c * 2,
        }
    }
}
65
/// A node in the architecture DAG.
///
/// Edges ([`ArchEdge`]) refer to nodes through their `id`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ArchNode {
    /// Unique identifier for the node
    pub id: usize,
    /// Human-readable node name
    pub name: String,
    /// Number of output channels produced by this node
    pub output_channels: usize,
}
76
77/// A directed edge (operation) in the architecture DAG
78#[derive(Debug, Clone)]
79pub struct ArchEdge {
80    /// Source node id
81    pub from: usize,
82    /// Destination node id
83    pub to: usize,
84    /// Operation applied along this edge
85    pub op: OpType,
86}
87
88/// A complete architecture specification as a DAG
89#[derive(Debug, Clone)]
90pub struct Architecture {
91    /// All nodes in the architecture
92    pub nodes: Vec<ArchNode>,
93    /// All directed edges (operations) in the architecture
94    pub edges: Vec<ArchEdge>,
95    /// Number of cells in the architecture
96    pub n_cells: usize,
97    /// Channel width
98    pub channels: usize,
99    /// Number of output classes
100    pub n_classes: usize,
101}
102
103impl Architecture {
104    /// Create a new empty architecture.
105    pub fn new(n_cells: usize, channels: usize, n_classes: usize) -> Self {
106        Self {
107            nodes: Vec::new(),
108            edges: Vec::new(),
109            n_cells,
110            channels,
111            n_classes,
112        }
113    }
114
115    /// Sum of parameter counts across all edges.
116    pub fn total_params(&self) -> usize {
117        self.edges
118            .iter()
119            .map(|e| e.op.num_params(self.channels))
120            .sum()
121    }
122
123    /// Sum of FLOPs across all edges for given spatial dimension.
124    pub fn total_flops(&self, spatial: usize) -> usize {
125        self.edges
126            .iter()
127            .map(|e| e.op.flops(self.channels, spatial))
128            .sum()
129    }
130
131    /// Kahn's algorithm topological sort of DAG nodes.
132    ///
133    /// Returns a sorted list of node ids. If the graph contains a cycle,
134    /// the returned list will be shorter than `nodes.len()`.
135    pub fn topological_sort(&self) -> Vec<usize> {
136        let n = self.nodes.len();
137        let mut in_degree = vec![0usize; n];
138        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
139
140        for e in &self.edges {
141            if e.from < n && e.to < n {
142                adj[e.from].push(e.to);
143                in_degree[e.to] = in_degree[e.to].saturating_add(1);
144            }
145        }
146
147        let mut queue: VecDeque<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
148        let mut order = Vec::with_capacity(n);
149
150        while let Some(v) = queue.pop_front() {
151            order.push(v);
152            for &u in &adj[v] {
153                in_degree[u] -= 1;
154                if in_degree[u] == 0 {
155                    queue.push_back(u);
156                }
157            }
158        }
159        order
160    }
161
162    /// Returns `true` if the architecture DAG contains no cycles.
163    pub fn is_valid(&self) -> bool {
164        self.topological_sort().len() == self.nodes.len()
165    }
166}
167
168/// Defines the search space: the set of valid architecture configurations.
169pub struct SearchSpace {
170    /// Available operation types
171    pub operations: Vec<OpType>,
172    /// Number of intermediate nodes per cell
173    pub n_nodes_per_cell: usize,
174    /// How many previous outputs each node can take as input
175    pub n_input_nodes: usize,
176    /// Possible channel widths
177    pub channels: Vec<usize>,
178    /// (min, max) number of cells (inclusive)
179    pub n_cells_range: (usize, usize),
180}
181
182impl SearchSpace {
183    /// Create a DARTS-like search space with the specified number of intermediate nodes.
184    pub fn darts_like(n_nodes: usize) -> Self {
185        Self {
186            operations: vec![
187                OpType::Skip,
188                OpType::Zero,
189                OpType::MaxPool3x3,
190                OpType::AvgPool3x3,
191                OpType::Conv3x3,
192                OpType::Conv5x5,
193                OpType::DilatedConv3x3,
194                OpType::DepthwiseSep3x3,
195            ],
196            n_nodes_per_cell: n_nodes,
197            n_input_nodes: 2,
198            channels: vec![16, 32, 64, 128],
199            n_cells_range: (2, 20),
200        }
201    }
202
203    /// Upper bound on the number of distinct architectures in this search space.
204    pub fn n_architectures(&self) -> u64 {
205        let n_ops = self.operations.len() as u64;
206        let n = self.n_nodes_per_cell;
207        // edges: n_input_nodes inputs per node across all nodes
208        let n_edges = (self.n_input_nodes * n) as u64;
209        n_ops.saturating_pow(n_edges as u32)
210    }
211
212    /// Sample a random architecture from this search space.
213    pub fn sample_random(
214        &self,
215        rng: &mut (impl scirs2_core::random::Rng + ?Sized),
216    ) -> Architecture {
217        use scirs2_core::random::{Rng, RngExt};
218
219        let cells_lo = self.n_cells_range.0;
220        let cells_hi = self.n_cells_range.1;
221        let n_cells = if cells_lo >= cells_hi {
222            cells_lo
223        } else {
224            rng.random_range(cells_lo..=cells_hi)
225        };
226
227        let ch_idx = rng.random_range(0..self.channels.len());
228        let channels = self.channels[ch_idx];
229        let n_classes = 10;
230
231        let mut arch = Architecture::new(n_cells, channels, n_classes);
232
233        // Add nodes: n_nodes_per_cell per cell
234        for c in 0..n_cells {
235            for j in 0..self.n_nodes_per_cell {
236                arch.nodes.push(ArchNode {
237                    id: c * self.n_nodes_per_cell + j,
238                    name: format!("cell{}_node{}", c, j),
239                    output_channels: channels,
240                });
241            }
242        }
243
244        // Add directed edges with randomly selected operations
245        for c in 0..n_cells {
246            for j in 0..self.n_nodes_per_cell {
247                // Each node gets up to n_input_nodes incoming edges from earlier nodes
248                let n_inputs = j.min(self.n_input_nodes);
249                for k in 0..n_inputs.max(1) {
250                    let from_offset = if n_inputs == 0 {
251                        // first node in cell: connect to beginning (node 0)
252                        0
253                    } else {
254                        c * self.n_nodes_per_cell + (j.saturating_sub(k + 1))
255                    };
256                    let to = c * self.n_nodes_per_cell + j;
257                    if from_offset != to {
258                        let op_idx = rng.random_range(0..self.operations.len());
259                        arch.edges.push(ArchEdge {
260                            from: from_offset,
261                            to,
262                            op: self.operations[op_idx].clone(),
263                        });
264                    }
265                }
266            }
267        }
268
269        arch
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::{rngs::StdRng, SeedableRng};

    /// Builds a 32-channel node, the width used by every hand-built architecture below.
    fn node32(id: usize, name: String) -> ArchNode {
        ArchNode {
            id,
            name,
            output_channels: 32,
        }
    }

    /// Builds an edge between two node ids.
    fn edge(from: usize, to: usize, op: OpType) -> ArchEdge {
        ArchEdge { from, to, op }
    }

    #[test]
    fn test_search_space_sample_produces_arch() {
        let mut rng = StdRng::seed_from_u64(42);
        let arch = SearchSpace::darts_like(4).sample_random(&mut rng);

        assert!(arch.n_cells > 0);
        assert!(arch.channels > 0);
        assert!(!arch.nodes.is_empty());
    }

    #[test]
    fn test_architecture_params_nonzero_for_conv() {
        let mut arch = Architecture::new(2, 32, 10);
        arch.nodes
            .extend([node32(0, "node0".into()), node32(1, "node1".into())]);
        arch.edges.push(edge(0, 1, OpType::Conv3x3));

        // A single Conv3x3 at 32 channels carries 9 * 32 * 32 = 9216 weights.
        assert_eq!(arch.total_params(), 9 * 32 * 32);
        assert!(arch.total_flops(8) > 0);
    }

    #[test]
    fn test_topological_sort_linear_dag() {
        let mut arch = Architecture::new(1, 32, 10);
        arch.nodes
            .extend((0..3_usize).map(|i| node32(i, format!("n{}", i))));
        arch.edges.push(edge(0, 1, OpType::Skip));
        arch.edges.push(edge(1, 2, OpType::Conv3x3));

        assert!(arch.is_valid());
        let order = arch.topological_sort();
        assert_eq!(order.len(), 3);
        assert_eq!(order[0], 0);
    }

    #[test]
    fn test_n_architectures_positive() {
        assert!(SearchSpace::darts_like(4).n_architectures() > 0);
    }

    #[test]
    fn test_op_type_zero_params_for_pooling() {
        for op in [OpType::MaxPool3x3, OpType::AvgPool3x3, OpType::Skip] {
            assert_eq!(op.num_params(64), 0);
        }
    }
}