//! DARTS: Differentiable Architecture Search (Liu et al., ICLR 2019)
//!
//! Relaxes the discrete architecture choice over a set of candidate operations to a
//! continuous mixing via softmax weights. During search both the network weights and
//! architecture weights are optimised in a bi-level fashion.  After search the discrete
//! architecture is recovered by taking argmax per edge.
//!
//! ## References
//!
//! - Liu, H., Simonyan, K. and Yang, Y. (2019). "DARTS: Differentiable Architecture
//!   Search". ICLR 2019.

pub mod gdas;
pub mod predictor_nas;
pub mod snas;

use crate::error::{OptimizeError, OptimizeResult};
19// ───────────────────────────────────────────────────────────────── Operations ──
20
/// Candidate primitive operations for DARTS cells.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operation {
    /// Identity (skip connection).
    Identity,
    /// Zero (no information flow).
    Zero,
    /// 3×3 separable convolution.
    Conv3x3,
    /// 5×5 separable convolution.
    Conv5x5,
    /// 3×3 max pooling.
    MaxPool,
    /// 3×3 average pooling.
    AvgPool,
    /// Skip connection (same as Identity but conceptually distinct).
    SkipConnect,
}

impl Operation {
    /// Rough FLOP estimate for a single forward pass through this operation.
    ///
    /// Formula used: conv FLOPs ≈ 2 * kernel² * channels² (assuming spatial size = 1
    /// for simplicity; callers may scale by H*W).
    pub fn cost_flops(&self, channels: usize) -> f64 {
        let c = channels as f64;
        match self {
            Operation::Identity => 0.0,
            Operation::Zero => 0.0,
            Operation::Conv3x3 => 2.0 * 9.0 * c * c,
            Operation::Conv5x5 => 2.0 * 25.0 * c * c,
            Operation::MaxPool => c, // negligible — just comparisons
            Operation::AvgPool => c, // one add per element
            Operation::SkipConnect => 0.0,
        }
    }

    /// Human-readable name used in diagnostic output.
    pub fn name(&self) -> &'static str {
        match self {
            Operation::Identity => "identity",
            Operation::Zero => "zero",
            Operation::Conv3x3 => "conv3x3",
            Operation::Conv5x5 => "conv5x5",
            Operation::MaxPool => "max_pool",
            Operation::AvgPool => "avg_pool",
            Operation::SkipConnect => "skip_connect",
        }
    }

    /// All primitive operations in a fixed canonical order.
    ///
    /// Fix: `SkipConnect` was previously missing from this list even though it is
    /// a declared variant, so "all" was a lie.  Appending it keeps indices 0–5
    /// stable for existing callers; index 6 now resolves to `SkipConnect`
    /// (behaviourally equivalent to the old `Identity` fallback — zero FLOPs).
    pub fn all() -> &'static [Operation] {
        &[
            Operation::Identity,
            Operation::Zero,
            Operation::Conv3x3,
            Operation::Conv5x5,
            Operation::MaxPool,
            Operation::AvgPool,
            Operation::SkipConnect,
        ]
    }
}
84
85// ────────────────────────────────────────────────────────────── DartsConfig ──
86
/// Configuration for a DARTS architecture search experiment.
#[derive(Debug, Clone)]
pub struct DartsConfig {
    /// Number of cells stacked in the super-network.
    pub n_cells: usize,
    /// Number of candidate operations per edge.
    pub n_operations: usize,
    /// Number of feature channels (used for FLOP estimation).
    pub channels: usize,
    /// Number of intermediate nodes per cell.
    pub n_nodes: usize,
    /// Learning rate for architecture parameter updates.
    pub arch_lr: f64,
    /// Learning rate for network weight updates.
    pub weight_lr: f64,
    /// Softmax temperature (lower → sharper distribution).
    pub temperature: f64,
}

impl Default for DartsConfig {
    /// Defaults sized for the toy experiments in this module: a small
    /// 4-cell stack with 6 candidate ops and matched learning rates.
    fn default() -> Self {
        DartsConfig {
            n_cells: 4,
            n_nodes: 4,
            n_operations: 6,
            channels: 16,
            arch_lr: 3e-4,
            weight_lr: 3e-4,
            temperature: 1.0,
        }
    }
}
119
120// ──────────────────────────────────────────────────────────────── Lcg (RNG) ──
121
/// Minimal linear congruential generator (LCG) so the DARTS sub-modules need no
/// external `rand` dependency.  Deterministic and fast, but not cryptographically
/// secure — use only for NAS stochastic sampling.
pub(crate) struct Lcg {
    state: u64,
}

impl Lcg {
    /// Knuth's MMIX multiplier/increment pair.
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    /// Seed the generator; identical seeds reproduce identical streams.
    pub(crate) fn new(seed: u64) -> Self {
        Lcg { state: seed }
    }

    /// Returns a pseudo-random `f64` in `[0, 1)`.
    pub(crate) fn next_f64(&mut self) -> f64 {
        self.state = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        // Keep the top 53 bits so the value fits a double's mantissa exactly;
        // dividing by 2^53 (a power of two) is exact, yielding [0, 1).
        let mantissa = self.state >> 11;
        mantissa as f64 / (1u64 << 53) as f64
    }
}
143
144// ─────────────────────────────────────────────── AnnealingStrategy / Schedule ──
145
/// Strategy for annealing the softmax / Gumbel-Softmax temperature.
#[derive(Debug, Clone, PartialEq)]
pub enum AnnealingStrategy {
    /// Linear decay from initial to final.
    Linear,
    /// Exponential decay: `T(t) = T_init * (T_final / T_init)^(t / total)`.
    Exponential,
    /// Cosine annealing.
    Cosine,
}

/// A temperature schedule for Gumbel-Softmax or concrete relaxation.
#[derive(Debug, Clone)]
pub struct TemperatureSchedule {
    /// Starting temperature.
    pub initial: f64,
    /// Ending temperature.
    pub final_temp: f64,
    /// Decay strategy.
    pub strategy: AnnealingStrategy,
    /// Total number of steps over which the schedule spans.
    pub total_steps: usize,
}

impl TemperatureSchedule {
    /// Construct a new `TemperatureSchedule` spanning `total_steps` steps from
    /// `initial` down (or up) to `final_temp` under the given `strategy`.
    pub fn new(
        initial: f64,
        final_temp: f64,
        strategy: AnnealingStrategy,
        total_steps: usize,
    ) -> Self {
        Self {
            initial,
            final_temp,
            strategy,
            total_steps,
        }
    }

    /// Temperature at the given `step` index.
    ///
    /// `step` is clamped to `[0, total_steps]`; with `total_steps == 0` the
    /// schedule is already finished, so the final temperature is returned.
    pub fn temperature_at(&self, step: usize) -> f64 {
        // Fraction of the schedule elapsed, in [0, 1].
        let progress = match self.total_steps {
            0 => 1.0,
            total => step.min(total) as f64 / total as f64,
        };
        match self.strategy {
            AnnealingStrategy::Linear => {
                self.initial + (self.final_temp - self.initial) * progress
            }
            AnnealingStrategy::Exponential => {
                if self.initial <= 0.0 || self.final_temp <= 0.0 {
                    // The geometric ratio is undefined for non-positive
                    // temperatures; pin to the end value instead of NaN.
                    self.final_temp
                } else {
                    self.initial * (self.final_temp / self.initial).powf(progress)
                }
            }
            AnnealingStrategy::Cosine => {
                self.final_temp
                    + 0.5
                        * (self.initial - self.final_temp)
                        * (1.0 + (std::f64::consts::PI * progress).cos())
            }
        }
    }
}
214
215// ─────────────────────────────────────────────────────────── MixedOperation ──
216
/// One mixed operation on a directed edge in the DARTS cell DAG.
///
/// Holds the un-normalised log-weights (architecture parameters) for each
/// candidate operation on this edge.
#[derive(Debug, Clone)]
pub struct MixedOperation {
    /// Un-normalised architecture parameters α_k, one per operation.
    pub arch_params: Vec<f64>,
    /// Cached per-operation outputs from the last `forward` call.
    pub operation_outputs: Option<Vec<Vec<f64>>>,
}

impl MixedOperation {
    /// Create a `MixedOperation` over `n_ops` candidates with uniform
    /// architecture weights (every log-weight starts at 0).
    pub fn new(n_ops: usize) -> Self {
        MixedOperation {
            arch_params: vec![0.0_f64; n_ops],
            operation_outputs: None,
        }
    }

    /// Softmax-normalised operation weights at the given temperature:
    ///
    /// `weights[k] = exp(α_k / T) / Σ_j exp(α_j / T)`
    pub fn weights(&self, temperature: f64) -> Vec<f64> {
        // Clamp the temperature away from zero to avoid a division blow-up.
        let temp = temperature.max(1e-8);
        let logits: Vec<f64> = self.arch_params.iter().map(|p| p / temp).collect();
        // Subtract the max logit first — the standard numerically-stable softmax.
        let peak = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let unnorm: Vec<f64> = logits.iter().map(|l| (l - peak).exp()).collect();
        let total: f64 = unnorm.iter().sum();
        if total == 0.0 {
            // Degenerate case: fall back to a uniform distribution.
            let n = self.arch_params.len();
            vec![1.0 / n as f64; n]
        } else {
            unnorm.iter().map(|u| u / total).collect()
        }
    }

    /// Forward pass: weighted sum Σ_k w_k · op_k(x).
    ///
    /// `op_fn(k, x)` returns the output of operation k applied to input x.
    /// The individual op outputs are cached in `operation_outputs`.
    pub fn forward(
        &mut self,
        x: &[f64],
        op_fn: impl Fn(usize, &[f64]) -> Vec<f64>,
        temperature: f64,
    ) -> Vec<f64> {
        let mix = self.weights(temperature);
        // Evaluate every candidate operation on the input.
        let outputs: Vec<Vec<f64>> = (0..self.arch_params.len()).map(|k| op_fn(k, x)).collect();
        // Output length follows the first op's output (or the input if no ops).
        let out_len = outputs.first().map_or(x.len(), Vec::len);
        let mut blended = vec![0.0_f64; out_len];
        for (weight, out) in mix.iter().zip(outputs.iter()) {
            for (acc, val) in blended.iter_mut().zip(out.iter()) {
                *acc += weight * val;
            }
        }
        self.operation_outputs = Some(outputs);
        blended
    }

    /// Index of the operation with the highest architecture weight.
    pub fn argmax_op(&self) -> usize {
        self.arch_params
            .iter()
            .enumerate()
            .max_by(|(_, lhs), (_, rhs)| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(best, _)| best)
    }
}
290
291// ──────────────────────────────────────────────────────────────── DartsCell ──
292
/// A DARTS cell: a DAG with `n_input_nodes` fixed inputs and `n_nodes`
/// intermediate nodes.  Each node aggregates outputs from all prior nodes via
/// `MixedOperation` edges.
#[derive(Debug, Clone)]
pub struct DartsCell {
    /// Number of intermediate (learnable) nodes.
    pub n_nodes: usize,
    /// Number of fixed input nodes (typically 2 in the standard DARTS setup,
    /// i.e. the outputs of the two preceding cells).
    pub n_input_nodes: usize,
    /// `edges[i][j]` is the `MixedOperation` from node j to intermediate node i.
    /// Node indices: 0..n_input_nodes are inputs; n_input_nodes..n_input_nodes+n_nodes are
    /// intermediate nodes.  Row i therefore has `n_input_nodes + i` entries.
    pub edges: Vec<Vec<MixedOperation>>,
}
307
308impl DartsCell {
309    /// Create a new DARTS cell.
310    ///
311    /// # Arguments
312    /// - `n_input_nodes`: Number of input nodes (preceding-cell outputs).
313    /// - `n_intermediate_nodes`: Number of intermediate nodes to build.
314    /// - `n_ops`: Number of candidate operations per edge.
315    pub fn new(n_input_nodes: usize, n_intermediate_nodes: usize, n_ops: usize) -> Self {
316        // edges[i] = mixed operations coming into intermediate node i
317        // node i receives edges from all n_input_nodes + i prior nodes
318        let edges: Vec<Vec<MixedOperation>> = (0..n_intermediate_nodes)
319            .map(|i| {
320                let n_predecessors = n_input_nodes + i;
321                (0..n_predecessors)
322                    .map(|_| MixedOperation::new(n_ops))
323                    .collect()
324            })
325            .collect();
326
327        Self {
328            n_nodes: n_intermediate_nodes,
329            n_input_nodes,
330            edges,
331        }
332    }
333
334    /// Forward pass through the cell.
335    ///
336    /// Each intermediate node output = Σ_{j < i} mixed_op_{ij}(node_j_output).
337    /// Final cell output = concatenation of all intermediate node outputs.
338    ///
339    /// # Arguments
340    /// - `inputs`: Outputs of the n_input_nodes preceding this cell.
341    /// - `temperature`: Softmax temperature forwarded to each `MixedOperation`.
342    pub fn forward(&mut self, inputs: &[Vec<f64>], temperature: f64) -> Vec<f64> {
343        if inputs.is_empty() {
344            return Vec::new();
345        }
346        let feature_len = inputs[0].len();
347        // All node outputs (inputs first, then intermediate).
348        let mut node_outputs: Vec<Vec<f64>> = inputs.to_vec();
349
350        for i in 0..self.n_nodes {
351            let n_prev = self.n_input_nodes + i;
352            let mut node_out = vec![0.0_f64; feature_len];
353            for j in 0..n_prev {
354                let src = node_outputs[j].clone();
355                let edge_out = self.edges[i][j].forward(&src, default_op_fn, temperature);
356                for (no, eo) in node_out.iter_mut().zip(edge_out.iter()) {
357                    *no += eo;
358                }
359            }
360            node_outputs.push(node_out);
361        }
362
363        // Concatenate intermediate node outputs (skip the input nodes).
364        let mut result = Vec::with_capacity(self.n_nodes * feature_len);
365        for node_out in node_outputs.iter().skip(self.n_input_nodes) {
366            result.extend_from_slice(node_out);
367        }
368        result
369    }
370
371    /// Collect all architecture parameters from every edge in this cell, flattened.
372    pub fn arch_parameters(&self) -> Vec<f64> {
373        self.edges
374            .iter()
375            .flat_map(|row| row.iter().flat_map(|mo| mo.arch_params.iter().cloned()))
376            .collect()
377    }
378
379    /// Apply gradient updates to architecture parameters using a gradient slice.
380    ///
381    /// `grads` must have the same length as `arch_parameters()`.
382    pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
383        let n_params: usize = self
384            .edges
385            .iter()
386            .flat_map(|row| row.iter())
387            .map(|mo| mo.arch_params.len())
388            .sum();
389        if grads.len() != n_params {
390            return Err(OptimizeError::InvalidInput(format!(
391                "Expected {} gradient values, got {}",
392                n_params,
393                grads.len()
394            )));
395        }
396        let mut idx = 0;
397        for row in self.edges.iter_mut() {
398            for mo in row.iter_mut() {
399                for p in mo.arch_params.iter_mut() {
400                    *p -= lr * grads[idx];
401                    idx += 1;
402                }
403            }
404        }
405        Ok(())
406    }
407
408    /// Derive the discrete architecture for this cell: argmax per edge.
409    ///
410    /// Returns a `Vec<Vec<usize>>` with the same shape as `edges`, where each
411    /// entry is the index of the best operation.
412    pub fn derive_discrete(&self) -> Vec<Vec<usize>> {
413        self.edges
414            .iter()
415            .map(|row| row.iter().map(|mo| mo.argmax_op()).collect())
416            .collect()
417    }
418}
419
/// Default per-edge operation used inside cells: every candidate op index `_k`
/// acts as the identity, returning a fresh copy of `x`.
fn default_op_fn(_k: usize, x: &[f64]) -> Vec<f64> {
    Vec::from(x)
}
424
425// ────────────────────────────────────────────────────────────── DartsSearch ──
426
/// Top-level DARTS search controller.
///
/// Manages a stack of `DartsCell`s and implements the bi-level optimisation loop.
#[derive(Debug, Clone)]
pub struct DartsSearch {
    /// Stack of cells forming the super-network.
    pub cells: Vec<DartsCell>,
    /// Configuration.
    pub config: DartsConfig,
    /// Flat network weights (shared across all cells for this toy model).
    // One scalar per cell; private so updates only happen via `bilevel_step`.
    weights: Vec<f64>,
}
439
440impl DartsSearch {
441    /// Construct a `DartsSearch` from the given config.
442    pub fn new(config: DartsConfig) -> Self {
443        let cells: Vec<DartsCell> = (0..config.n_cells)
444            .map(|_| DartsCell::new(2, config.n_nodes, config.n_operations))
445            .collect();
446        // Simple weight vector (one scalar weight per cell for the toy model).
447        let weights = vec![0.01_f64; config.n_cells];
448        Self {
449            cells,
450            config,
451            weights,
452        }
453    }
454
455    /// Return all architecture parameters across all cells, flattened.
456    pub fn arch_parameters(&self) -> Vec<f64> {
457        self.cells
458            .iter()
459            .flat_map(|c| c.arch_parameters())
460            .collect()
461    }
462
463    /// Total number of architecture parameters.
464    pub fn n_arch_params(&self) -> usize {
465        self.cells.iter().map(|c| c.arch_parameters().len()).sum()
466    }
467
468    /// Apply a gradient step to architecture parameters.
469    ///
470    /// `grads` must be the same length as `arch_parameters()`.
471    pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
472        let total = self.n_arch_params();
473        if grads.len() != total {
474            return Err(OptimizeError::InvalidInput(format!(
475                "Expected {} arch-param grads, got {}",
476                total,
477                grads.len()
478            )));
479        }
480        let mut offset = 0;
481        for cell in self.cells.iter_mut() {
482            let n = cell.arch_parameters().len();
483            cell.update_arch_params(&grads[offset..offset + n], lr)?;
484            offset += n;
485        }
486        Ok(())
487    }
488
489    /// Derive the discrete architecture: for each cell, argmax op per edge.
490    ///
491    /// Returns `Vec<Vec<Vec<usize>>>` — `[cell][intermediate_node][predecessor]`.
492    pub fn derive_discrete_arch_indices(&self) -> Vec<Vec<Vec<usize>>> {
493        self.cells.iter().map(|c| c.derive_discrete()).collect()
494    }
495
496    /// Derive the discrete architecture as a vec of `Operation` vectors.
497    ///
498    /// Uses `Operation::all()` to map operation index to `Operation`.  If
499    /// `n_operations` exceeds the canonical list length, `Operation::Identity`
500    /// is used as a fallback.
501    pub fn derive_discrete_arch(&self) -> Vec<Vec<Operation>> {
502        let ops = Operation::all();
503        self.derive_discrete_arch_indices()
504            .iter()
505            .map(|cell_disc| {
506                cell_disc
507                    .iter()
508                    .flat_map(|node_edges| {
509                        node_edges.iter().map(|&idx| {
510                            if idx < ops.len() {
511                                ops[idx]
512                            } else {
513                                Operation::Identity
514                            }
515                        })
516                    })
517                    .collect()
518            })
519            .collect()
520    }
521
522    /// Compute a simple regression loss (MSE) over a dataset.
523    ///
524    /// The toy model prediction is: y_hat_i = w · mean(x_i), where w is the
525    /// sum of `self.weights`.  This is purely for exercising the bi-level loop.
526    fn compute_loss(&self, x: &[Vec<f64>], y: &[f64]) -> f64 {
527        if x.is_empty() || y.is_empty() {
528            return 0.0;
529        }
530        let w_sum: f64 = self.weights.iter().sum();
531        let mut loss = 0.0_f64;
532        let n = x.len().min(y.len());
533        for i in 0..n {
534            let x_mean = if x[i].is_empty() {
535                0.0
536            } else {
537                x[i].iter().sum::<f64>() / x[i].len() as f64
538            };
539            let pred = w_sum * x_mean;
540            let diff = pred - y[i];
541            loss += diff * diff;
542        }
543        loss / n as f64
544    }
545
546    /// Compute MSE gradient with respect to `self.weights` (one grad per cell weight).
547    fn weight_grads(&self, x: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
548        let n = x.len().min(y.len());
549        if n == 0 {
550            return vec![0.0_f64; self.weights.len()];
551        }
552        let w_sum: f64 = self.weights.iter().sum();
553        let mut grad_sum = 0.0_f64;
554        for i in 0..n {
555            let x_mean = if x[i].is_empty() {
556                0.0
557            } else {
558                x[i].iter().sum::<f64>() / x[i].len() as f64
559            };
560            let pred = w_sum * x_mean;
561            let diff = pred - y[i];
562            // d(loss)/d(w_sum) = 2 * diff * x_mean / n
563            grad_sum += 2.0 * diff * x_mean / n as f64;
564        }
565        // Each cell weight contributes equally to w_sum.
566        vec![grad_sum; self.weights.len()]
567    }
568
569    /// Compute approximate gradient of loss with respect to architecture params
570    /// via finite differences (central differences, step = 1e-4).
571    fn arch_grads_fd(&self, x: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
572        let n = self.n_arch_params();
573        if n == 0 {
574            return Vec::new();
575        }
576        let mut grads = vec![0.0_f64; n];
577        let h = 1e-4;
578        let mut offset = 0;
579        for cell_idx in 0..self.cells.len() {
580            let cell_n = self.cells[cell_idx].arch_parameters().len();
581            for local_j in 0..cell_n {
582                let global_j = offset + local_j;
583                // +h
584                let mut search_plus = self.clone();
585                let params_plus = search_plus.cells[cell_idx].arch_parameters();
586                let mut p_plus = params_plus.clone();
587                p_plus[local_j] += h;
588                // Rebuild cell arch params from the modified flat vector
589                let _ = search_plus.cells[cell_idx].set_arch_params(&p_plus);
590                let loss_plus = search_plus.compute_loss(x, y);
591
592                // -h
593                let mut search_minus = self.clone();
594                let params_minus = search_minus.cells[cell_idx].arch_parameters();
595                let mut p_minus = params_minus.clone();
596                p_minus[local_j] -= h;
597                let _ = search_minus.cells[cell_idx].set_arch_params(&p_minus);
598                let loss_minus = search_minus.compute_loss(x, y);
599
600                grads[global_j] = (loss_plus - loss_minus) / (2.0 * h);
601            }
602            offset += cell_n;
603        }
604        grads
605    }
606
607    /// One bilevel optimisation step (approximate first-order DARTS).
608    ///
609    /// Inner step: update network weights on `train_x`/`train_y`.
610    /// Outer step: update architecture params on `val_x`/`val_y`.
611    ///
612    /// Returns `(train_loss, val_loss)` before the update.
613    pub fn bilevel_step(
614        &mut self,
615        train_x: &[Vec<f64>],
616        train_y: &[f64],
617        val_x: &[Vec<f64>],
618        val_y: &[f64],
619    ) -> (f64, f64) {
620        let train_loss = self.compute_loss(train_x, train_y);
621        let val_loss = self.compute_loss(val_x, val_y);
622
623        // Inner: gradient step on network weights using train data.
624        let w_grads = self.weight_grads(train_x, train_y);
625        let lr_w = self.config.weight_lr;
626        for (w, g) in self.weights.iter_mut().zip(w_grads.iter()) {
627            *w -= lr_w * g;
628        }
629
630        // Outer: gradient step on architecture params using val data.
631        let a_grads = self.arch_grads_fd(val_x, val_y);
632        let lr_a = self.config.arch_lr;
633        if !a_grads.is_empty() {
634            let _ = self.update_arch_params(&a_grads, lr_a);
635        }
636
637        (train_loss, val_loss)
638    }
639}
640
641// ──────────────────────────────────────── helper: set_arch_params on DartsCell ──
642
643impl DartsCell {
644    /// Replace all architecture parameters with values from the flat slice.
645    pub fn set_arch_params(&mut self, params: &[f64]) -> OptimizeResult<()> {
646        let total: usize = self
647            .edges
648            .iter()
649            .flat_map(|r| r.iter())
650            .map(|m| m.arch_params.len())
651            .sum();
652        if params.len() != total {
653            return Err(OptimizeError::InvalidInput(format!(
654                "set_arch_params: expected {total} values, got {}",
655                params.len()
656            )));
657        }
658        let mut idx = 0;
659        for row in self.edges.iter_mut() {
660            for mo in row.iter_mut() {
661                for p in mo.arch_params.iter_mut() {
662                    *p = params[idx];
663                    idx += 1;
664                }
665            }
666        }
667        Ok(())
668    }
669}
670
671// ═══════════════════════════════════════════════════════════════════ tests ═══
672
#[cfg(test)]
mod tests {
    use super::*;

    // Uniform init (all α = 0) must produce a proper probability distribution.
    #[test]
    fn mixed_operation_weights_sum_to_one() {
        let mo = MixedOperation::new(6);
        let w = mo.weights(1.0);
        assert_eq!(w.len(), 6);
        let sum: f64 = w.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "weights sum = {sum}");
    }

    #[test]
    fn mixed_operation_weights_temperature_effect() {
        // Low temperature should sharpen the distribution.
        let mut mo = MixedOperation::new(4);
        mo.arch_params = vec![1.0, 0.5, 0.3, 0.2];
        let w_hot = mo.weights(10.0);
        let w_cold = mo.weights(0.1);
        // Highest-weight op (index 0) should have larger weight at low temp.
        assert!(w_cold[0] > w_hot[0], "cold should be sharper");
    }

    // With identity ops the mixed output must preserve the input length.
    #[test]
    fn mixed_operation_forward_correct_shape() {
        let mut mo = MixedOperation::new(3);
        let x = vec![1.0_f64; 8];
        let out = mo.forward(&x, |_k, v| v.to_vec(), 1.0);
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn darts_cell_forward_output_shape() {
        let mut cell = DartsCell::new(2, 3, 4);
        let inputs = vec![vec![1.0_f64; 8], vec![0.5_f64; 8]];
        let out = cell.forward(&inputs, 1.0);
        // Output should be n_nodes * feature_len = 3 * 8 = 24.
        assert_eq!(out.len(), 24);
    }

    #[test]
    fn derive_discrete_arch_returns_ops() {
        let config = DartsConfig {
            n_cells: 2,
            n_operations: 6,
            n_nodes: 3,
            ..Default::default()
        };
        let search = DartsSearch::new(config);
        let arch = search.derive_discrete_arch();
        assert_eq!(arch.len(), 2, "one vec per cell");
        // Each cell has n_nodes intermediate nodes, each receiving 2+(0..n-1) edges.
        // Total edges per cell = 2+3+4 = 9 for n_nodes=3, n_input_nodes=2.
        for cell_ops in &arch {
            assert!(!cell_ops.is_empty());
        }
    }

    // Smoke test: one inner+outer update completes and yields finite losses.
    #[test]
    fn bilevel_step_runs_without_error() {
        let config = DartsConfig::default();
        let mut search = DartsSearch::new(config);
        let train_x = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let train_y = vec![1.5, 3.5];
        let val_x = vec![vec![0.5, 1.5]];
        let val_y = vec![1.0];
        let (tl, vl) = search.bilevel_step(&train_x, &train_y, &val_x, &val_y);
        assert!(tl.is_finite());
        assert!(vl.is_finite());
    }

    // The flattened parameter vector and the reported count must agree.
    #[test]
    fn arch_parameters_length_consistent() {
        let config = DartsConfig {
            n_cells: 3,
            n_operations: 5,
            n_nodes: 2,
            ..Default::default()
        };
        let search = DartsSearch::new(config);
        let params = search.arch_parameters();
        assert_eq!(params.len(), search.n_arch_params());
    }

    // A gradient slice of the wrong length must be rejected, not silently applied.
    #[test]
    fn update_arch_params_wrong_length_errors() {
        let mut search = DartsSearch::new(DartsConfig::default());
        let result = search.update_arch_params(&[1.0, 2.0], 0.01);
        assert!(result.is_err());
    }

    // ── TemperatureSchedule tests ──────────────────────────────────────────────

    #[test]
    fn temperature_schedule_linear_bounds() {
        let sched = TemperatureSchedule::new(10.0, 1.0, AnnealingStrategy::Linear, 100);
        let t0 = sched.temperature_at(0);
        let t_half = sched.temperature_at(50);
        let t_end = sched.temperature_at(100);
        assert!((t0 - 10.0).abs() < 1e-10, "t0={t0}");
        assert!((t_half - 5.5).abs() < 1e-10, "t_half={t_half}");
        assert!((t_end - 1.0).abs() < 1e-10, "t_end={t_end}");
    }

    #[test]
    fn temperature_schedule_exponential_bounds() {
        let sched = TemperatureSchedule::new(10.0, 1.0, AnnealingStrategy::Exponential, 100);
        let t0 = sched.temperature_at(0);
        let t_end = sched.temperature_at(100);
        assert!((t0 - 10.0).abs() < 1e-8, "t0={t0}");
        assert!((t_end - 1.0).abs() < 1e-8, "t_end={t_end}");
        // Intermediate should be between bounds.
        let t_mid = sched.temperature_at(50);
        assert!(t_mid > 1.0 && t_mid < 10.0, "t_mid={t_mid}");
    }

    #[test]
    fn temperature_schedule_cosine_bounds() {
        let sched = TemperatureSchedule::new(10.0, 1.0, AnnealingStrategy::Cosine, 100);
        let t0 = sched.temperature_at(0);
        let t_end = sched.temperature_at(100);
        assert!((t0 - 10.0).abs() < 1e-8, "t0={t0}");
        assert!((t_end - 1.0).abs() < 1e-8, "t_end={t_end}");
    }

    // Steps past total_steps must clamp to the final temperature.
    #[test]
    fn temperature_schedule_clamped_beyond_total() {
        let sched = TemperatureSchedule::new(5.0, 1.0, AnnealingStrategy::Linear, 10);
        let t_over = sched.temperature_at(999);
        let t_end = sched.temperature_at(10);
        assert!((t_over - t_end).abs() < 1e-10);
    }

    #[test]
    fn temperature_schedule_zero_steps() {
        // When total_steps == 0, frac = 1.0 immediately for any step.
        let sched = TemperatureSchedule::new(5.0, 1.0, AnnealingStrategy::Linear, 0);
        let t = sched.temperature_at(0);
        assert!((t - 1.0).abs() < 1e-10, "t={t}");
    }
}