Skip to main content

tensorlogic_train/nas/
space.rs

//! Neural architecture search (NAS) space definitions.
//!
//! Defines [`ArchSearchSpace`], [`Architecture`], and [`LayerSpec`] — the primitives
//! that describe which architectures are valid candidates during NAS.
5
6use std::collections::HashMap;
7
8use crate::error::{TrainError, TrainResult};
9use crate::hyperparameter::{HyperparamConfig, HyperparamValue};
10
// ─── LayerSpec ──────────────────────────────────────────────────────────────
12
/// Specification for a single layer in a neural architecture.
///
/// All fields are `String`/`usize`, so the type derives `Eq` and `Hash` in
/// addition to `PartialEq`; this lets layer specs (and whole architectures)
/// be used as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LayerSpec {
    /// Operation type (e.g. "linear", "conv", "attention").
    pub op: String,
    /// Width (number of output units / channels).
    pub width: usize,
    /// Non-linearity applied after the operation (e.g. "relu", "gelu").
    pub activation: String,
}

// ─── Architecture ───────────────────────────────────────────────────────────

/// A concrete neural architecture represented as an ordered sequence of layers.
///
/// Derives `Eq` and `Hash` (via [`LayerSpec`]), so an architecture can be used
/// as a key, e.g. to skip re-evaluating a candidate that was already scored.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Architecture {
    /// Ordered layer specifications (input → output).
    pub layers: Vec<LayerSpec>,
}
32
33impl Architecture {
34    /// Proxy parameter count: sum of width_i × width_{i+1} over consecutive pairs.
35    ///
36    /// Returns 0 if fewer than 2 layers.
37    pub fn param_count(&self) -> usize {
38        if self.layers.len() < 2 {
39            return 0;
40        }
41        self.layers
42            .windows(2)
43            .map(|w| w[0].width * w[1].width)
44            .sum()
45    }
46
47    /// Depth (number of layers).
48    pub fn depth(&self) -> usize {
49        self.layers.len()
50    }
51
52    /// Encode this architecture as a [`HyperparamConfig`].
53    ///
54    /// Keys: `depth` (Int), `layer_{i}_op` (String), `layer_{i}_width` (Int),
55    /// `layer_{i}_activation` (String).
56    pub fn to_config(&self) -> HyperparamConfig {
57        let mut m: HashMap<String, HyperparamValue> = HashMap::new();
58        m.insert(
59            "depth".to_string(),
60            HyperparamValue::Int(self.layers.len() as i64),
61        );
62        for (i, layer) in self.layers.iter().enumerate() {
63            m.insert(
64                format!("layer_{i}_op"),
65                HyperparamValue::String(layer.op.clone()),
66            );
67            m.insert(
68                format!("layer_{i}_width"),
69                HyperparamValue::Int(layer.width as i64),
70            );
71            m.insert(
72                format!("layer_{i}_activation"),
73                HyperparamValue::String(layer.activation.clone()),
74            );
75        }
76        m
77    }
78
79    /// Reconstruct an [`Architecture`] from a [`HyperparamConfig`] created by [`Architecture::to_config`].
80    ///
81    /// `max_depth` is used only for bounds-checking the encoded depth.
82    pub fn from_config(cfg: &HyperparamConfig, max_depth: usize) -> TrainResult<Self> {
83        let depth = cfg.get("depth").and_then(|v| v.as_int()).ok_or_else(|| {
84            TrainError::InvalidParameter("config missing 'depth' Int key".to_string())
85        })?;
86
87        if depth < 1 {
88            return Err(TrainError::InvalidParameter(format!(
89                "decoded depth {depth} must be ≥ 1"
90            )));
91        }
92        if depth as usize > max_depth {
93            return Err(TrainError::InvalidParameter(format!(
94                "decoded depth {depth} exceeds max_depth {max_depth}"
95            )));
96        }
97
98        let mut layers = Vec::with_capacity(depth as usize);
99        for i in 0..depth as usize {
100            let op = cfg
101                .get(&format!("layer_{i}_op"))
102                .and_then(|v| v.as_string())
103                .ok_or_else(|| {
104                    TrainError::InvalidParameter(format!(
105                        "config missing 'layer_{i}_op' String key"
106                    ))
107                })?
108                .to_string();
109
110            let width = cfg
111                .get(&format!("layer_{i}_width"))
112                .and_then(|v| v.as_int())
113                .ok_or_else(|| {
114                    TrainError::InvalidParameter(format!(
115                        "config missing 'layer_{i}_width' Int key"
116                    ))
117                })?;
118
119            if width < 1 {
120                return Err(TrainError::InvalidParameter(format!(
121                    "layer {i} width {width} must be ≥ 1"
122                )));
123            }
124
125            let activation = cfg
126                .get(&format!("layer_{i}_activation"))
127                .and_then(|v| v.as_string())
128                .ok_or_else(|| {
129                    TrainError::InvalidParameter(format!(
130                        "config missing 'layer_{i}_activation' String key"
131                    ))
132                })?
133                .to_string();
134
135            layers.push(LayerSpec {
136                op,
137                width: width as usize,
138                activation,
139            });
140        }
141
142        Ok(Architecture { layers })
143    }
144}
145
// ─── ArchSearchSpace ────────────────────────────────────────────────────────
147
/// The space of candidate architectures explored by neural architecture search.
///
/// Any sampled architecture is constrained to a depth in
/// `min_depth..=max_depth`, and each of its layers picks its width, activation
/// and operation type from the corresponding option list.
///
/// Prefer [`ArchSearchSpace::new`], which validates these constraints; the
/// fields are public, so building the struct literally skips that validation.
#[derive(Clone, Debug)]
pub struct ArchSearchSpace {
    /// Smallest allowed number of layers (inclusive; `new` requires ≥ 1).
    pub min_depth: usize,
    /// Largest allowed number of layers (inclusive; `new` requires ≥ `min_depth`).
    pub max_depth: usize,
    /// Candidate widths (hidden-unit counts) a layer may take.
    pub width_options: Vec<usize>,
    /// Candidate activation-function names a layer may use.
    pub activation_options: Vec<String>,
    /// Candidate operation-type names a layer may use.
    pub op_options: Vec<String>,
}
165
166impl ArchSearchSpace {
167    /// Construct a validated [`ArchSearchSpace`].
168    ///
169    /// # Errors
170    ///
171    /// Returns [`TrainError::InvalidParameter`] if:
172    /// - `min_depth` < 1
173    /// - `max_depth` < `min_depth`
174    /// - any of the option vecs is empty
175    pub fn new(
176        min_depth: usize,
177        max_depth: usize,
178        width_options: Vec<usize>,
179        activation_options: Vec<String>,
180        op_options: Vec<String>,
181    ) -> TrainResult<Self> {
182        if min_depth < 1 {
183            return Err(TrainError::InvalidParameter(
184                "min_depth must be ≥ 1".to_string(),
185            ));
186        }
187        if max_depth < min_depth {
188            return Err(TrainError::InvalidParameter(format!(
189                "max_depth ({max_depth}) must be ≥ min_depth ({min_depth})"
190            )));
191        }
192        if width_options.is_empty() {
193            return Err(TrainError::InvalidParameter(
194                "width_options must be non-empty".to_string(),
195            ));
196        }
197        if activation_options.is_empty() {
198            return Err(TrainError::InvalidParameter(
199                "activation_options must be non-empty".to_string(),
200            ));
201        }
202        if op_options.is_empty() {
203            return Err(TrainError::InvalidParameter(
204                "op_options must be non-empty".to_string(),
205            ));
206        }
207        Ok(Self {
208            min_depth,
209            max_depth,
210            width_options,
211            activation_options,
212            op_options,
213        })
214    }
215}