// scirs2_optimize/darts/mod.rs
//! DARTS: Differentiable Architecture Search (Liu et al., ICLR 2019)
//!
//! Relaxes the discrete architecture choice over a set of candidate operations to a
//! continuous mixing via softmax weights. During search both the network weights and
//! architecture weights are optimised in a bi-level fashion. After search the discrete
//! architecture is recovered by taking argmax per edge.
//!
//! ## References
//!
//! - Liu, H., Simonyan, K. and Yang, Y. (2019). "DARTS: Differentiable Architecture
//!   Search". ICLR 2019.

13use crate::error::{OptimizeError, OptimizeResult};
14
15// ───────────────────────────────────────────────────────────────── Operations ──
16
/// Candidate primitive operations for DARTS cells.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operation {
    /// Identity (skip connection).
    Identity,
    /// Zero (no information flow).
    Zero,
    /// 3×3 separable convolution.
    Conv3x3,
    /// 5×5 separable convolution.
    Conv5x5,
    /// 3×3 max pooling.
    MaxPool,
    /// 3×3 average pooling.
    AvgPool,
    /// Skip connection (same as Identity but conceptually distinct).
    SkipConnect,
}

impl Operation {
    /// Rough FLOP estimate for a single forward pass through this operation.
    ///
    /// Formula used: conv FLOPs ≈ 2 * kernel² * channels² (assuming spatial size = 1
    /// for simplicity; callers may scale by H*W).
    pub fn cost_flops(&self, channels: usize) -> f64 {
        let c = channels as f64;
        match self {
            // Convolutions dominate: 2 * k² * C² multiply-adds.
            Operation::Conv3x3 => 2.0 * 9.0 * c * c,
            Operation::Conv5x5 => 2.0 * 25.0 * c * c,
            // Pooling is counted as roughly one op per channel
            // (comparisons for max, adds for average).
            Operation::MaxPool | Operation::AvgPool => c,
            // Pass-through / no-op primitives cost nothing.
            Operation::Identity | Operation::Zero | Operation::SkipConnect => 0.0,
        }
    }

    /// Human-readable name used in diagnostic output.
    pub fn name(&self) -> &'static str {
        match self {
            Operation::Identity => "identity",
            Operation::Zero => "zero",
            Operation::Conv3x3 => "conv3x3",
            Operation::Conv5x5 => "conv5x5",
            Operation::MaxPool => "max_pool",
            Operation::AvgPool => "avg_pool",
            Operation::SkipConnect => "skip_connect",
        }
    }

    /// The canonical searchable primitives in a fixed order (6 entries).
    ///
    /// Note: `SkipConnect` is deliberately not part of this list; callers that
    /// index beyond it fall back to `Identity` (see `derive_discrete_arch`).
    pub fn all() -> &'static [Operation] {
        const CANONICAL: [Operation; 6] = [
            Operation::Identity,
            Operation::Zero,
            Operation::Conv3x3,
            Operation::Conv5x5,
            Operation::MaxPool,
            Operation::AvgPool,
        ];
        &CANONICAL
    }
}
80
81// ────────────────────────────────────────────────────────────── DartsConfig ──
82
/// Configuration for a DARTS architecture search experiment.
#[derive(Debug, Clone)]
pub struct DartsConfig {
    /// Number of cells stacked in the super-network.
    pub n_cells: usize,
    /// Number of candidate operations per edge.
    pub n_operations: usize,
    /// Number of feature channels (used for FLOP estimation).
    pub channels: usize,
    /// Number of intermediate nodes per cell.
    pub n_nodes: usize,
    /// Learning rate for architecture parameter updates.
    pub arch_lr: f64,
    /// Learning rate for network weight updates.
    pub weight_lr: f64,
    /// Softmax temperature (lower → sharper distribution).
    pub temperature: f64,
}

impl Default for DartsConfig {
    /// Small, cheap defaults: 4 cells of 4 nodes, 6 candidate ops per edge,
    /// matching learning rates of 3e-4 and unit softmax temperature.
    fn default() -> Self {
        DartsConfig {
            temperature: 1.0,
            weight_lr: 3e-4,
            arch_lr: 3e-4,
            n_nodes: 4,
            channels: 16,
            n_operations: 6,
            n_cells: 4,
        }
    }
}
115
116// ─────────────────────────────────────────────────────────── MixedOperation ──
117
/// One mixed operation on a directed edge in the DARTS cell DAG.
///
/// Maintains per-operation un-normalised log-weights (architecture parameters).
#[derive(Debug, Clone)]
pub struct MixedOperation {
    /// Un-normalised architecture parameters α_k, one per operation.
    pub arch_params: Vec<f64>,
    /// Cached per-operation outputs from the last `forward` call.
    pub operation_outputs: Option<Vec<Vec<f64>>>,
}

impl MixedOperation {
    /// Create a `MixedOperation` over `n_ops` candidate operations with uniform
    /// architecture weights (every log-weight starts at 0).
    pub fn new(n_ops: usize) -> Self {
        MixedOperation {
            arch_params: vec![0.0_f64; n_ops],
            operation_outputs: None,
        }
    }

    /// Softmax-normalised operation weights at the given temperature:
    /// `weights[k] = exp(α_k / T) / Σ_j exp(α_j / T)`.
    pub fn weights(&self, temperature: f64) -> Vec<f64> {
        // Clamp the temperature away from zero to avoid a division blow-up.
        let t = temperature.max(1e-8);
        // Subtract the running maximum before exponentiating (stable softmax).
        let peak = self
            .arch_params
            .iter()
            .fold(f64::NEG_INFINITY, |m, &a| m.max(a / t));
        let exps: Vec<f64> = self
            .arch_params
            .iter()
            .map(|&a| (a / t - peak).exp())
            .collect();
        let total: f64 = exps.iter().sum();
        if total == 0.0 {
            // Degenerate case (e.g. no parameters): fall back to uniform.
            let n = self.arch_params.len();
            vec![1.0 / n as f64; n]
        } else {
            exps.into_iter().map(|e| e / total).collect()
        }
    }

    /// Forward pass: weighted sum Σ_k w_k · op_k(x).
    ///
    /// `op_fn(k, x)` returns the output of operation k applied to input x.
    /// The individual per-operation outputs are cached in `operation_outputs`.
    pub fn forward(
        &mut self,
        x: &[f64],
        op_fn: impl Fn(usize, &[f64]) -> Vec<f64>,
        temperature: f64,
    ) -> Vec<f64> {
        let w = self.weights(temperature);
        // Evaluate every candidate operation once.
        let outputs: Vec<Vec<f64>> =
            (0..self.arch_params.len()).map(|k| op_fn(k, x)).collect();
        // Accumulate the softmax-weighted combination.
        let out_len = outputs.first().map_or(x.len(), Vec::len);
        let mut acc = vec![0.0_f64; out_len];
        for (wk, out) in w.iter().zip(outputs.iter()) {
            for (a, o) in acc.iter_mut().zip(out.iter()) {
                *a += wk * o;
            }
        }
        self.operation_outputs = Some(outputs);
        acc
    }

    /// Index of the operation with the highest architecture weight
    /// (on ties, the last maximal index wins — `max_by` semantics).
    pub fn argmax_op(&self) -> usize {
        self.arch_params
            .iter()
            .enumerate()
            .max_by(|x, y| x.1.partial_cmp(y.1).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(i, _)| i)
    }
}
191
192// ──────────────────────────────────────────────────────────────── DartsCell ──
193
/// A DARTS cell: a DAG with `n_input_nodes` fixed inputs and `n_nodes`
/// intermediate nodes.  Each node aggregates outputs from all prior nodes via
/// `MixedOperation` edges.
#[derive(Debug, Clone)]
pub struct DartsCell {
    /// Number of intermediate (learnable) nodes.
    pub n_nodes: usize,
    /// Number of fixed input nodes (typically 2 in the standard DARTS setup).
    pub n_input_nodes: usize,
    /// `edges[i][j]` is the `MixedOperation` from node j to intermediate node i.
    /// Node indices: 0..n_input_nodes are inputs; n_input_nodes..n_input_nodes+n_nodes are
    /// intermediate nodes.
    /// Row i therefore holds `n_input_nodes + i` edges — one per predecessor.
    pub edges: Vec<Vec<MixedOperation>>,
}
208
209impl DartsCell {
210    /// Create a new DARTS cell.
211    ///
212    /// # Arguments
213    /// - `n_input_nodes`: Number of input nodes (preceding-cell outputs).
214    /// - `n_intermediate_nodes`: Number of intermediate nodes to build.
215    /// - `n_ops`: Number of candidate operations per edge.
216    pub fn new(n_input_nodes: usize, n_intermediate_nodes: usize, n_ops: usize) -> Self {
217        // edges[i] = mixed operations coming into intermediate node i
218        // node i receives edges from all n_input_nodes + i prior nodes
219        let edges: Vec<Vec<MixedOperation>> = (0..n_intermediate_nodes)
220            .map(|i| {
221                let n_predecessors = n_input_nodes + i;
222                (0..n_predecessors)
223                    .map(|_| MixedOperation::new(n_ops))
224                    .collect()
225            })
226            .collect();
227
228        Self {
229            n_nodes: n_intermediate_nodes,
230            n_input_nodes,
231            edges,
232        }
233    }
234
235    /// Forward pass through the cell.
236    ///
237    /// Each intermediate node output = Σ_{j < i} mixed_op_{ij}(node_j_output).
238    /// Final cell output = concatenation of all intermediate node outputs.
239    ///
240    /// # Arguments
241    /// - `inputs`: Outputs of the n_input_nodes preceding this cell.
242    /// - `temperature`: Softmax temperature forwarded to each `MixedOperation`.
243    pub fn forward(&mut self, inputs: &[Vec<f64>], temperature: f64) -> Vec<f64> {
244        if inputs.is_empty() {
245            return Vec::new();
246        }
247        let feature_len = inputs[0].len();
248        // All node outputs (inputs first, then intermediate).
249        let mut node_outputs: Vec<Vec<f64>> = inputs.to_vec();
250
251        for i in 0..self.n_nodes {
252            let n_prev = self.n_input_nodes + i;
253            let mut node_out = vec![0.0_f64; feature_len];
254            for j in 0..n_prev {
255                let src = node_outputs[j].clone();
256                let edge_out = self.edges[i][j].forward(&src, default_op_fn, temperature);
257                for (no, eo) in node_out.iter_mut().zip(edge_out.iter()) {
258                    *no += eo;
259                }
260            }
261            node_outputs.push(node_out);
262        }
263
264        // Concatenate intermediate node outputs (skip the input nodes).
265        let mut result = Vec::with_capacity(self.n_nodes * feature_len);
266        for node_out in node_outputs.iter().skip(self.n_input_nodes) {
267            result.extend_from_slice(node_out);
268        }
269        result
270    }
271
272    /// Collect all architecture parameters from every edge in this cell, flattened.
273    pub fn arch_parameters(&self) -> Vec<f64> {
274        self.edges
275            .iter()
276            .flat_map(|row| row.iter().flat_map(|mo| mo.arch_params.iter().cloned()))
277            .collect()
278    }
279
280    /// Apply gradient updates to architecture parameters using a gradient slice.
281    ///
282    /// `grads` must have the same length as `arch_parameters()`.
283    pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
284        let n_params: usize = self
285            .edges
286            .iter()
287            .flat_map(|row| row.iter())
288            .map(|mo| mo.arch_params.len())
289            .sum();
290        if grads.len() != n_params {
291            return Err(OptimizeError::InvalidInput(format!(
292                "Expected {} gradient values, got {}",
293                n_params,
294                grads.len()
295            )));
296        }
297        let mut idx = 0;
298        for row in self.edges.iter_mut() {
299            for mo in row.iter_mut() {
300                for p in mo.arch_params.iter_mut() {
301                    *p -= lr * grads[idx];
302                    idx += 1;
303                }
304            }
305        }
306        Ok(())
307    }
308
309    /// Derive the discrete architecture for this cell: argmax per edge.
310    ///
311    /// Returns a `Vec<Vec<usize>>` with the same shape as `edges`, where each
312    /// entry is the index of the best operation.
313    pub fn derive_discrete(&self) -> Vec<Vec<usize>> {
314        self.edges
315            .iter()
316            .map(|row| row.iter().map(|mo| mo.argmax_op()).collect())
317            .collect()
318    }
319}
320
/// Default per-edge operation used inside cells: every candidate operation is
/// the identity, so the operation index is ignored and `x` is returned as-is.
fn default_op_fn(_k: usize, x: &[f64]) -> Vec<f64> {
    Vec::from(x)
}
325
326// ────────────────────────────────────────────────────────────── DartsSearch ──
327
/// Top-level DARTS search controller.
///
/// Manages a stack of `DartsCell`s and implements the bi-level optimisation loop.
#[derive(Debug, Clone)]
pub struct DartsSearch {
    /// Stack of cells forming the super-network.
    pub cells: Vec<DartsCell>,
    /// Configuration.
    pub config: DartsConfig,
    /// Flat network weights (shared across all cells for this toy model).
    /// Private on purpose: written only by `new` and the inner step of
    /// `bilevel_step`.
    weights: Vec<f64>,
}
340
341impl DartsSearch {
342    /// Construct a `DartsSearch` from the given config.
343    pub fn new(config: DartsConfig) -> Self {
344        let cells: Vec<DartsCell> = (0..config.n_cells)
345            .map(|_| DartsCell::new(2, config.n_nodes, config.n_operations))
346            .collect();
347        // Simple weight vector (one scalar weight per cell for the toy model).
348        let weights = vec![0.01_f64; config.n_cells];
349        Self {
350            cells,
351            config,
352            weights,
353        }
354    }
355
356    /// Return all architecture parameters across all cells, flattened.
357    pub fn arch_parameters(&self) -> Vec<f64> {
358        self.cells
359            .iter()
360            .flat_map(|c| c.arch_parameters())
361            .collect()
362    }
363
364    /// Total number of architecture parameters.
365    pub fn n_arch_params(&self) -> usize {
366        self.cells.iter().map(|c| c.arch_parameters().len()).sum()
367    }
368
369    /// Apply a gradient step to architecture parameters.
370    ///
371    /// `grads` must be the same length as `arch_parameters()`.
372    pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
373        let total = self.n_arch_params();
374        if grads.len() != total {
375            return Err(OptimizeError::InvalidInput(format!(
376                "Expected {} arch-param grads, got {}",
377                total,
378                grads.len()
379            )));
380        }
381        let mut offset = 0;
382        for cell in self.cells.iter_mut() {
383            let n = cell.arch_parameters().len();
384            cell.update_arch_params(&grads[offset..offset + n], lr)?;
385            offset += n;
386        }
387        Ok(())
388    }
389
390    /// Derive the discrete architecture: for each cell, argmax op per edge.
391    ///
392    /// Returns `Vec<Vec<Vec<usize>>>` — `[cell][intermediate_node][predecessor]`.
393    pub fn derive_discrete_arch_indices(&self) -> Vec<Vec<Vec<usize>>> {
394        self.cells.iter().map(|c| c.derive_discrete()).collect()
395    }
396
397    /// Derive the discrete architecture as a vec of `Operation` vectors.
398    ///
399    /// Uses `Operation::all()` to map operation index to `Operation`.  If
400    /// `n_operations` exceeds the canonical list length, `Operation::Identity`
401    /// is used as a fallback.
402    pub fn derive_discrete_arch(&self) -> Vec<Vec<Operation>> {
403        let ops = Operation::all();
404        self.derive_discrete_arch_indices()
405            .iter()
406            .map(|cell_disc| {
407                cell_disc
408                    .iter()
409                    .flat_map(|node_edges| {
410                        node_edges.iter().map(|&idx| {
411                            if idx < ops.len() {
412                                ops[idx]
413                            } else {
414                                Operation::Identity
415                            }
416                        })
417                    })
418                    .collect()
419            })
420            .collect()
421    }
422
423    /// Compute a simple regression loss (MSE) over a dataset.
424    ///
425    /// The toy model prediction is: y_hat_i = w · mean(x_i), where w is the
426    /// sum of `self.weights`.  This is purely for exercising the bi-level loop.
427    fn compute_loss(&self, x: &[Vec<f64>], y: &[f64]) -> f64 {
428        if x.is_empty() || y.is_empty() {
429            return 0.0;
430        }
431        let w_sum: f64 = self.weights.iter().sum();
432        let mut loss = 0.0_f64;
433        let n = x.len().min(y.len());
434        for i in 0..n {
435            let x_mean = if x[i].is_empty() {
436                0.0
437            } else {
438                x[i].iter().sum::<f64>() / x[i].len() as f64
439            };
440            let pred = w_sum * x_mean;
441            let diff = pred - y[i];
442            loss += diff * diff;
443        }
444        loss / n as f64
445    }
446
447    /// Compute MSE gradient with respect to `self.weights` (one grad per cell weight).
448    fn weight_grads(&self, x: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
449        let n = x.len().min(y.len());
450        if n == 0 {
451            return vec![0.0_f64; self.weights.len()];
452        }
453        let w_sum: f64 = self.weights.iter().sum();
454        let mut grad_sum = 0.0_f64;
455        for i in 0..n {
456            let x_mean = if x[i].is_empty() {
457                0.0
458            } else {
459                x[i].iter().sum::<f64>() / x[i].len() as f64
460            };
461            let pred = w_sum * x_mean;
462            let diff = pred - y[i];
463            // d(loss)/d(w_sum) = 2 * diff * x_mean / n
464            grad_sum += 2.0 * diff * x_mean / n as f64;
465        }
466        // Each cell weight contributes equally to w_sum.
467        vec![grad_sum; self.weights.len()]
468    }
469
470    /// Compute approximate gradient of loss with respect to architecture params
471    /// via finite differences (central differences, step = 1e-4).
472    fn arch_grads_fd(&self, x: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
473        let n = self.n_arch_params();
474        if n == 0 {
475            return Vec::new();
476        }
477        let mut grads = vec![0.0_f64; n];
478        let h = 1e-4;
479        let mut offset = 0;
480        for cell_idx in 0..self.cells.len() {
481            let cell_n = self.cells[cell_idx].arch_parameters().len();
482            for local_j in 0..cell_n {
483                let global_j = offset + local_j;
484                // +h
485                let mut search_plus = self.clone();
486                let params_plus = search_plus.cells[cell_idx].arch_parameters();
487                let mut p_plus = params_plus.clone();
488                p_plus[local_j] += h;
489                // Rebuild cell arch params from the modified flat vector
490                let _ = search_plus.cells[cell_idx].set_arch_params(&p_plus);
491                let loss_plus = search_plus.compute_loss(x, y);
492
493                // -h
494                let mut search_minus = self.clone();
495                let params_minus = search_minus.cells[cell_idx].arch_parameters();
496                let mut p_minus = params_minus.clone();
497                p_minus[local_j] -= h;
498                let _ = search_minus.cells[cell_idx].set_arch_params(&p_minus);
499                let loss_minus = search_minus.compute_loss(x, y);
500
501                grads[global_j] = (loss_plus - loss_minus) / (2.0 * h);
502            }
503            offset += cell_n;
504        }
505        grads
506    }
507
508    /// One bilevel optimisation step (approximate first-order DARTS).
509    ///
510    /// Inner step: update network weights on `train_x`/`train_y`.
511    /// Outer step: update architecture params on `val_x`/`val_y`.
512    ///
513    /// Returns `(train_loss, val_loss)` before the update.
514    pub fn bilevel_step(
515        &mut self,
516        train_x: &[Vec<f64>],
517        train_y: &[f64],
518        val_x: &[Vec<f64>],
519        val_y: &[f64],
520    ) -> (f64, f64) {
521        let train_loss = self.compute_loss(train_x, train_y);
522        let val_loss = self.compute_loss(val_x, val_y);
523
524        // Inner: gradient step on network weights using train data.
525        let w_grads = self.weight_grads(train_x, train_y);
526        let lr_w = self.config.weight_lr;
527        for (w, g) in self.weights.iter_mut().zip(w_grads.iter()) {
528            *w -= lr_w * g;
529        }
530
531        // Outer: gradient step on architecture params using val data.
532        let a_grads = self.arch_grads_fd(val_x, val_y);
533        let lr_a = self.config.arch_lr;
534        if !a_grads.is_empty() {
535            let _ = self.update_arch_params(&a_grads, lr_a);
536        }
537
538        (train_loss, val_loss)
539    }
540}
541
542// ──────────────────────────────────────── helper: set_arch_params on DartsCell ──
543
544impl DartsCell {
545    /// Replace all architecture parameters with values from the flat slice.
546    pub fn set_arch_params(&mut self, params: &[f64]) -> OptimizeResult<()> {
547        let total: usize = self
548            .edges
549            .iter()
550            .flat_map(|r| r.iter())
551            .map(|m| m.arch_params.len())
552            .sum();
553        if params.len() != total {
554            return Err(OptimizeError::InvalidInput(format!(
555                "set_arch_params: expected {total} values, got {}",
556                params.len()
557            )));
558        }
559        let mut idx = 0;
560        for row in self.edges.iter_mut() {
561            for mo in row.iter_mut() {
562                for p in mo.arch_params.iter_mut() {
563                    *p = params[idx];
564                    idx += 1;
565                }
566            }
567        }
568        Ok(())
569    }
570}
571
572// ═══════════════════════════════════════════════════════════════════ tests ═══
573
#[cfg(test)]
mod tests {
    //! Unit tests exercising the softmax weights, cell forward shape,
    //! discrete-architecture derivation, and the bilevel loop plumbing.
    use super::*;

    #[test]
    fn mixed_operation_weights_sum_to_one() {
        // Uniform init (all α = 0) must still produce a valid distribution.
        let mo = MixedOperation::new(6);
        let w = mo.weights(1.0);
        assert_eq!(w.len(), 6);
        let sum: f64 = w.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "weights sum = {sum}");
    }

    #[test]
    fn mixed_operation_weights_temperature_effect() {
        // Low temperature should sharpen the distribution.
        let mut mo = MixedOperation::new(4);
        mo.arch_params = vec![1.0, 0.5, 0.3, 0.2];
        let w_hot = mo.weights(10.0);
        let w_cold = mo.weights(0.1);
        // Highest-weight op (index 0) should have larger weight at low temp.
        assert!(w_cold[0] > w_hot[0], "cold should be sharper");
    }

    #[test]
    fn mixed_operation_forward_correct_shape() {
        // With identity ops, output length must match input length.
        let mut mo = MixedOperation::new(3);
        let x = vec![1.0_f64; 8];
        let out = mo.forward(&x, |_k, v| v.to_vec(), 1.0);
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn darts_cell_forward_output_shape() {
        let mut cell = DartsCell::new(2, 3, 4);
        let inputs = vec![vec![1.0_f64; 8], vec![0.5_f64; 8]];
        let out = cell.forward(&inputs, 1.0);
        // Output should be n_nodes * feature_len = 3 * 8 = 24.
        assert_eq!(out.len(), 24);
    }

    #[test]
    fn derive_discrete_arch_returns_ops() {
        let config = DartsConfig {
            n_cells: 2,
            n_operations: 6,
            n_nodes: 3,
            ..Default::default()
        };
        let search = DartsSearch::new(config);
        let arch = search.derive_discrete_arch();
        assert_eq!(arch.len(), 2, "one vec per cell");
        // Each cell has n_nodes intermediate nodes, each receiving 2+(0..n-1) edges.
        // Total edges per cell = 2+3+4 = 9 for n_nodes=3, n_input_nodes=2.
        for cell_ops in &arch {
            assert!(!cell_ops.is_empty());
        }
    }

    #[test]
    fn bilevel_step_runs_without_error() {
        // Smoke test: one inner+outer step on tiny data yields finite losses.
        let config = DartsConfig::default();
        let mut search = DartsSearch::new(config);
        let train_x = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let train_y = vec![1.5, 3.5];
        let val_x = vec![vec![0.5, 1.5]];
        let val_y = vec![1.0];
        let (tl, vl) = search.bilevel_step(&train_x, &train_y, &val_x, &val_y);
        assert!(tl.is_finite());
        assert!(vl.is_finite());
    }

    #[test]
    fn arch_parameters_length_consistent() {
        // arch_parameters() and n_arch_params() must agree.
        let config = DartsConfig {
            n_cells: 3,
            n_operations: 5,
            n_nodes: 2,
            ..Default::default()
        };
        let search = DartsSearch::new(config);
        let params = search.arch_parameters();
        assert_eq!(params.len(), search.n_arch_params());
    }

    #[test]
    fn update_arch_params_wrong_length_errors() {
        // A 2-element grad slice can never match the full parameter count.
        let mut search = DartsSearch::new(DartsConfig::default());
        let result = search.update_arch_params(&[1.0, 2.0], 0.01);
        assert!(result.is_err());
    }
}
665}