1pub mod gdas;
14pub mod predictor_nas;
15pub mod snas;
16
17use crate::error::{OptimizeError, OptimizeResult};
18
/// Candidate operations that make up the NAS search space.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operation {
    /// Pass the input through unchanged.
    Identity,
    /// Emit zeros (effectively removes the edge).
    Zero,
    /// 3x3 convolution.
    Conv3x3,
    /// 5x5 convolution.
    Conv5x5,
    /// Max pooling.
    MaxPool,
    /// Average pooling.
    AvgPool,
    /// Residual-style skip connection.
    SkipConnect,
}

impl Operation {
    /// Rough FLOP estimate for one application of this operation at the
    /// given channel count: convolutions cost `2 * k^2 * C^2` (one
    /// multiply-add per tap per channel pair), pooling costs one op per
    /// channel, and pass-through operations are free.
    pub fn cost_flops(&self, channels: usize) -> f64 {
        let c = channels as f64;
        match self {
            Self::Identity | Self::Zero | Self::SkipConnect => 0.0,
            Self::MaxPool | Self::AvgPool => c,
            Self::Conv3x3 => 2.0 * 9.0 * c * c,
            Self::Conv5x5 => 2.0 * 25.0 * c * c,
        }
    }

    /// Stable snake_case identifier for logging and serialization.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Identity => "identity",
            Self::Zero => "zero",
            Self::Conv3x3 => "conv3x3",
            Self::Conv5x5 => "conv5x5",
            Self::MaxPool => "max_pool",
            Self::AvgPool => "avg_pool",
            Self::SkipConnect => "skip_connect",
        }
    }

    /// The default searchable operation set.
    ///
    /// NOTE(review): `SkipConnect` is absent from this list; its six entries
    /// match `DartsConfig::default().n_operations`, so the omission looks
    /// deliberate — confirm before extending.
    pub fn all() -> &'static [Operation] {
        &[
            Self::Identity,
            Self::Zero,
            Self::Conv3x3,
            Self::Conv5x5,
            Self::MaxPool,
            Self::AvgPool,
        ]
    }
}
84
/// Hyper-parameters for a DARTS-style architecture search.
#[derive(Debug, Clone)]
pub struct DartsConfig {
    /// Number of stacked cells in the network.
    pub n_cells: usize,
    /// Number of candidate operations per edge.
    pub n_operations: usize,
    /// Channel width of the network.
    pub channels: usize,
    /// Intermediate (searchable) nodes per cell.
    pub n_nodes: usize,
    /// Learning rate for the architecture parameters.
    pub arch_lr: f64,
    /// Learning rate for the model weights.
    pub weight_lr: f64,
    /// Softmax temperature for the continuous relaxation.
    pub temperature: f64,
}

impl Default for DartsConfig {
    /// Small defaults suitable for quick experiments:
    /// 4 cells of 4 nodes, 6 operations, 16 channels, 3e-4 learning rates.
    fn default() -> Self {
        DartsConfig {
            n_cells: 4,
            n_operations: 6,
            channels: 16,
            n_nodes: 4,
            arch_lr: 3e-4,
            weight_lr: 3e-4,
            temperature: 1.0,
        }
    }
}
119
/// Minimal linear congruential generator used for deterministic,
/// dependency-free pseudo-randomness inside the search routines.
pub(crate) struct Lcg {
    /// Current generator state; advanced on every draw.
    state: u64,
}

impl Lcg {
    /// Multiplier/increment pair from Knuth's MMIX LCG.
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    /// Seed the generator; equal seeds yield identical sequences.
    pub(crate) fn new(seed: u64) -> Self {
        Lcg { state: seed }
    }

    /// Advance the state and map its top 53 bits to a float in `[0, 1)`.
    pub(crate) fn next_f64(&mut self) -> f64 {
        self.state = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        // Keep 53 bits so every draw is exactly representable as an f64;
        // dividing by 2^53 is an exact power-of-two scaling.
        ((self.state >> 11) as f64) * (1.0 / (1u64 << 53) as f64)
    }
}
143
/// How the temperature decays over the annealing horizon.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AnnealingStrategy {
    /// Straight-line interpolation from initial to final temperature.
    Linear,
    /// Geometric decay: equal multiplicative steps per unit of progress.
    Exponential,
    /// Half-cosine decay: slow at both ends, fastest mid-schedule.
    Cosine,
}

/// Temperature annealing schedule for softmax-relaxed architecture search.
#[derive(Debug, Clone)]
pub struct TemperatureSchedule {
    /// Temperature at step 0.
    pub initial: f64,
    /// Temperature at `total_steps` and beyond.
    pub final_temp: f64,
    /// Decay curve connecting `initial` to `final_temp`.
    pub strategy: AnnealingStrategy,
    /// Number of steps over which the decay is spread.
    pub total_steps: usize,
}

impl TemperatureSchedule {
    /// Construct a schedule; no validation is performed on the bounds.
    pub fn new(
        initial: f64,
        final_temp: f64,
        strategy: AnnealingStrategy,
        total_steps: usize,
    ) -> Self {
        Self {
            initial,
            final_temp,
            strategy,
            total_steps,
        }
    }

    /// Temperature at `step`, with `step` clamped to `total_steps`.
    ///
    /// A schedule with `total_steps == 0` is treated as already finished and
    /// yields `final_temp`. Exponential decay falls back to `final_temp`
    /// whenever either endpoint is non-positive, since the geometric formula
    /// is undefined there.
    pub fn temperature_at(&self, step: usize) -> f64 {
        let t = step.min(self.total_steps);
        let frac = if self.total_steps == 0 {
            1.0
        } else {
            t as f64 / self.total_steps as f64
        };
        match self.strategy {
            AnnealingStrategy::Linear => self.initial + (self.final_temp - self.initial) * frac,
            AnnealingStrategy::Exponential => {
                if self.initial <= 0.0 || self.final_temp <= 0.0 {
                    self.final_temp
                } else {
                    self.initial * (self.final_temp / self.initial).powf(frac)
                }
            }
            AnnealingStrategy::Cosine => {
                self.final_temp
                    + 0.5
                        * (self.initial - self.final_temp)
                        * (1.0 + (std::f64::consts::PI * frac).cos())
            }
        }
    }
}
214
/// A softmax-weighted mixture over candidate operations on a single edge —
/// the continuous relaxation at the heart of DARTS.
#[derive(Debug, Clone)]
pub struct MixedOperation {
    /// One architecture logit per candidate operation.
    pub arch_params: Vec<f64>,
    /// Raw per-operation outputs cached by the most recent `forward` call
    /// (`None` until `forward` has run).
    pub operation_outputs: Option<Vec<Vec<f64>>>,
}

impl MixedOperation {
    /// Create a mixture over `n_ops` candidates with every logit at zero,
    /// i.e. an initially uniform softmax.
    pub fn new(n_ops: usize) -> Self {
        MixedOperation {
            arch_params: vec![0.0; n_ops],
            operation_outputs: None,
        }
    }

    /// Softmax of the logits at the given temperature.
    ///
    /// The temperature is clamped to 1e-8 to avoid division by zero, and the
    /// max logit is subtracted before exponentiation for numerical
    /// stability. Falls back to a uniform distribution should the
    /// exponentials sum to zero.
    pub fn weights(&self, temperature: f64) -> Vec<f64> {
        let t = temperature.max(1e-8);
        let logits: Vec<f64> = self.arch_params.iter().map(|&a| a / t).collect();
        let peak = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = logits.iter().map(|&l| (l - peak).exp()).collect();
        let total: f64 = exps.iter().sum();
        if total == 0.0 {
            let n = self.arch_params.len();
            vec![1.0 / n as f64; n]
        } else {
            exps.into_iter().map(|e| e / total).collect()
        }
    }

    /// Softmax-weighted sum of every candidate's output on `x`.
    ///
    /// `op_fn(k, x)` produces candidate `k`'s output; all raw outputs are
    /// cached in `operation_outputs` for later inspection. The result length
    /// follows the first candidate's output (or `x` when there are none).
    pub fn forward(
        &mut self,
        x: &[f64],
        op_fn: impl Fn(usize, &[f64]) -> Vec<f64>,
        temperature: f64,
    ) -> Vec<f64> {
        let weights = self.weights(temperature);
        let outputs: Vec<Vec<f64>> = (0..self.arch_params.len()).map(|k| op_fn(k, x)).collect();
        let out_len = match outputs.first() {
            Some(first) => first.len(),
            None => x.len(),
        };
        let mut mixed = vec![0.0_f64; out_len];
        for (weight, out) in weights.iter().zip(outputs.iter()) {
            for (acc, val) in mixed.iter_mut().zip(out.iter()) {
                *acc += weight * val;
            }
        }
        self.operation_outputs = Some(outputs);
        mixed
    }

    /// Index of the largest logit (ties resolve to the last maximum, per
    /// `Iterator::max_by`); 0 when there are no candidates.
    pub fn argmax_op(&self) -> usize {
        self.arch_params
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(i, _)| i)
    }
}
290
/// A single DARTS cell: a small DAG whose edges each carry a softmax-mixed
/// operation over the candidate set.
#[derive(Debug, Clone)]
pub struct DartsCell {
    /// Number of intermediate (searchable) nodes.
    pub n_nodes: usize,
    /// Number of input nodes feeding the cell.
    pub n_input_nodes: usize,
    /// `edges[i][j]` is the mixed operation on the edge from predecessor `j`
    /// to intermediate node `i`; node `i` has `n_input_nodes + i`
    /// predecessors (all inputs plus every earlier intermediate node).
    pub edges: Vec<Vec<MixedOperation>>,
}
307
308impl DartsCell {
309 pub fn new(n_input_nodes: usize, n_intermediate_nodes: usize, n_ops: usize) -> Self {
316 let edges: Vec<Vec<MixedOperation>> = (0..n_intermediate_nodes)
319 .map(|i| {
320 let n_predecessors = n_input_nodes + i;
321 (0..n_predecessors)
322 .map(|_| MixedOperation::new(n_ops))
323 .collect()
324 })
325 .collect();
326
327 Self {
328 n_nodes: n_intermediate_nodes,
329 n_input_nodes,
330 edges,
331 }
332 }
333
334 pub fn forward(&mut self, inputs: &[Vec<f64>], temperature: f64) -> Vec<f64> {
343 if inputs.is_empty() {
344 return Vec::new();
345 }
346 let feature_len = inputs[0].len();
347 let mut node_outputs: Vec<Vec<f64>> = inputs.to_vec();
349
350 for i in 0..self.n_nodes {
351 let n_prev = self.n_input_nodes + i;
352 let mut node_out = vec![0.0_f64; feature_len];
353 for j in 0..n_prev {
354 let src = node_outputs[j].clone();
355 let edge_out = self.edges[i][j].forward(&src, default_op_fn, temperature);
356 for (no, eo) in node_out.iter_mut().zip(edge_out.iter()) {
357 *no += eo;
358 }
359 }
360 node_outputs.push(node_out);
361 }
362
363 let mut result = Vec::with_capacity(self.n_nodes * feature_len);
365 for node_out in node_outputs.iter().skip(self.n_input_nodes) {
366 result.extend_from_slice(node_out);
367 }
368 result
369 }
370
371 pub fn arch_parameters(&self) -> Vec<f64> {
373 self.edges
374 .iter()
375 .flat_map(|row| row.iter().flat_map(|mo| mo.arch_params.iter().cloned()))
376 .collect()
377 }
378
379 pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
383 let n_params: usize = self
384 .edges
385 .iter()
386 .flat_map(|row| row.iter())
387 .map(|mo| mo.arch_params.len())
388 .sum();
389 if grads.len() != n_params {
390 return Err(OptimizeError::InvalidInput(format!(
391 "Expected {} gradient values, got {}",
392 n_params,
393 grads.len()
394 )));
395 }
396 let mut idx = 0;
397 for row in self.edges.iter_mut() {
398 for mo in row.iter_mut() {
399 for p in mo.arch_params.iter_mut() {
400 *p -= lr * grads[idx];
401 idx += 1;
402 }
403 }
404 }
405 Ok(())
406 }
407
408 pub fn derive_discrete(&self) -> Vec<Vec<usize>> {
413 self.edges
414 .iter()
415 .map(|row| row.iter().map(|mo| mo.argmax_op()).collect())
416 .collect()
417 }
418}
419
/// Placeholder per-operation transform used by `DartsCell::forward`: every
/// candidate operation acts as the identity on the feature vector.
fn default_op_fn(_k: usize, x: &[f64]) -> Vec<f64> {
    Vec::from(x)
}
424
/// Differentiable architecture search (DARTS) driver: a stack of cells whose
/// softmax-relaxed edge parameters are optimized alongside a tiny surrogate
/// model via `bilevel_step`.
#[derive(Debug, Clone)]
pub struct DartsSearch {
    /// Searchable cells; each holds per-edge architecture parameters.
    pub cells: Vec<DartsCell>,
    /// Hyper-parameters used to build the cells and drive the updates.
    pub config: DartsConfig,
    /// Scalar weights of the surrogate predictor, one per cell (initialized
    /// to 0.01 in `new`); updated on training data in `bilevel_step`.
    weights: Vec<f64>,
}
439
440impl DartsSearch {
441 pub fn new(config: DartsConfig) -> Self {
443 let cells: Vec<DartsCell> = (0..config.n_cells)
444 .map(|_| DartsCell::new(2, config.n_nodes, config.n_operations))
445 .collect();
446 let weights = vec![0.01_f64; config.n_cells];
448 Self {
449 cells,
450 config,
451 weights,
452 }
453 }
454
455 pub fn arch_parameters(&self) -> Vec<f64> {
457 self.cells
458 .iter()
459 .flat_map(|c| c.arch_parameters())
460 .collect()
461 }
462
463 pub fn n_arch_params(&self) -> usize {
465 self.cells.iter().map(|c| c.arch_parameters().len()).sum()
466 }
467
468 pub fn update_arch_params(&mut self, grads: &[f64], lr: f64) -> OptimizeResult<()> {
472 let total = self.n_arch_params();
473 if grads.len() != total {
474 return Err(OptimizeError::InvalidInput(format!(
475 "Expected {} arch-param grads, got {}",
476 total,
477 grads.len()
478 )));
479 }
480 let mut offset = 0;
481 for cell in self.cells.iter_mut() {
482 let n = cell.arch_parameters().len();
483 cell.update_arch_params(&grads[offset..offset + n], lr)?;
484 offset += n;
485 }
486 Ok(())
487 }
488
489 pub fn derive_discrete_arch_indices(&self) -> Vec<Vec<Vec<usize>>> {
493 self.cells.iter().map(|c| c.derive_discrete()).collect()
494 }
495
496 pub fn derive_discrete_arch(&self) -> Vec<Vec<Operation>> {
502 let ops = Operation::all();
503 self.derive_discrete_arch_indices()
504 .iter()
505 .map(|cell_disc| {
506 cell_disc
507 .iter()
508 .flat_map(|node_edges| {
509 node_edges.iter().map(|&idx| {
510 if idx < ops.len() {
511 ops[idx]
512 } else {
513 Operation::Identity
514 }
515 })
516 })
517 .collect()
518 })
519 .collect()
520 }
521
522 fn compute_loss(&self, x: &[Vec<f64>], y: &[f64]) -> f64 {
527 if x.is_empty() || y.is_empty() {
528 return 0.0;
529 }
530 let w_sum: f64 = self.weights.iter().sum();
531 let mut loss = 0.0_f64;
532 let n = x.len().min(y.len());
533 for i in 0..n {
534 let x_mean = if x[i].is_empty() {
535 0.0
536 } else {
537 x[i].iter().sum::<f64>() / x[i].len() as f64
538 };
539 let pred = w_sum * x_mean;
540 let diff = pred - y[i];
541 loss += diff * diff;
542 }
543 loss / n as f64
544 }
545
546 fn weight_grads(&self, x: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
548 let n = x.len().min(y.len());
549 if n == 0 {
550 return vec![0.0_f64; self.weights.len()];
551 }
552 let w_sum: f64 = self.weights.iter().sum();
553 let mut grad_sum = 0.0_f64;
554 for i in 0..n {
555 let x_mean = if x[i].is_empty() {
556 0.0
557 } else {
558 x[i].iter().sum::<f64>() / x[i].len() as f64
559 };
560 let pred = w_sum * x_mean;
561 let diff = pred - y[i];
562 grad_sum += 2.0 * diff * x_mean / n as f64;
564 }
565 vec![grad_sum; self.weights.len()]
567 }
568
569 fn arch_grads_fd(&self, x: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
572 let n = self.n_arch_params();
573 if n == 0 {
574 return Vec::new();
575 }
576 let mut grads = vec![0.0_f64; n];
577 let h = 1e-4;
578 let mut offset = 0;
579 for cell_idx in 0..self.cells.len() {
580 let cell_n = self.cells[cell_idx].arch_parameters().len();
581 for local_j in 0..cell_n {
582 let global_j = offset + local_j;
583 let mut search_plus = self.clone();
585 let params_plus = search_plus.cells[cell_idx].arch_parameters();
586 let mut p_plus = params_plus.clone();
587 p_plus[local_j] += h;
588 let _ = search_plus.cells[cell_idx].set_arch_params(&p_plus);
590 let loss_plus = search_plus.compute_loss(x, y);
591
592 let mut search_minus = self.clone();
594 let params_minus = search_minus.cells[cell_idx].arch_parameters();
595 let mut p_minus = params_minus.clone();
596 p_minus[local_j] -= h;
597 let _ = search_minus.cells[cell_idx].set_arch_params(&p_minus);
598 let loss_minus = search_minus.compute_loss(x, y);
599
600 grads[global_j] = (loss_plus - loss_minus) / (2.0 * h);
601 }
602 offset += cell_n;
603 }
604 grads
605 }
606
607 pub fn bilevel_step(
614 &mut self,
615 train_x: &[Vec<f64>],
616 train_y: &[f64],
617 val_x: &[Vec<f64>],
618 val_y: &[f64],
619 ) -> (f64, f64) {
620 let train_loss = self.compute_loss(train_x, train_y);
621 let val_loss = self.compute_loss(val_x, val_y);
622
623 let w_grads = self.weight_grads(train_x, train_y);
625 let lr_w = self.config.weight_lr;
626 for (w, g) in self.weights.iter_mut().zip(w_grads.iter()) {
627 *w -= lr_w * g;
628 }
629
630 let a_grads = self.arch_grads_fd(val_x, val_y);
632 let lr_a = self.config.arch_lr;
633 if !a_grads.is_empty() {
634 let _ = self.update_arch_params(&a_grads, lr_a);
635 }
636
637 (train_loss, val_loss)
638 }
639}
640
641impl DartsCell {
644 pub fn set_arch_params(&mut self, params: &[f64]) -> OptimizeResult<()> {
646 let total: usize = self
647 .edges
648 .iter()
649 .flat_map(|r| r.iter())
650 .map(|m| m.arch_params.len())
651 .sum();
652 if params.len() != total {
653 return Err(OptimizeError::InvalidInput(format!(
654 "set_arch_params: expected {total} values, got {}",
655 params.len()
656 )));
657 }
658 let mut idx = 0;
659 for row in self.edges.iter_mut() {
660 for mo in row.iter_mut() {
661 for p in mo.arch_params.iter_mut() {
662 *p = params[idx];
663 idx += 1;
664 }
665 }
666 }
667 Ok(())
668 }
669}
670
#[cfg(test)]
mod tests {
    use super::*;

    // Zero-initialized logits: softmax must be uniform and sum to 1.
    #[test]
    fn mixed_operation_weights_sum_to_one() {
        let mo = MixedOperation::new(6);
        let w = mo.weights(1.0);
        assert_eq!(w.len(), 6);
        let sum: f64 = w.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "weights sum = {sum}");
    }

    // Lower temperature should concentrate mass on the largest logit.
    #[test]
    fn mixed_operation_weights_temperature_effect() {
        let mut mo = MixedOperation::new(4);
        mo.arch_params = vec![1.0, 0.5, 0.3, 0.2];
        let w_hot = mo.weights(10.0);
        let w_cold = mo.weights(0.1);
        assert!(w_cold[0] > w_hot[0], "cold should be sharper");
    }

    // Forward output keeps the input feature length (identity op_fn).
    #[test]
    fn mixed_operation_forward_correct_shape() {
        let mut mo = MixedOperation::new(3);
        let x = vec![1.0_f64; 8];
        let out = mo.forward(&x, |_k, v| v.to_vec(), 1.0);
        assert_eq!(out.len(), 8);
    }

    // 3 intermediate nodes x 8 features = 24 concatenated outputs.
    #[test]
    fn darts_cell_forward_output_shape() {
        let mut cell = DartsCell::new(2, 3, 4);
        let inputs = vec![vec![1.0_f64; 8], vec![0.5_f64; 8]];
        let out = cell.forward(&inputs, 1.0);
        assert_eq!(out.len(), 24);
    }

    // Discretization yields one non-empty operation list per cell.
    #[test]
    fn derive_discrete_arch_returns_ops() {
        let config = DartsConfig {
            n_cells: 2,
            n_operations: 6,
            n_nodes: 3,
            ..Default::default()
        };
        let search = DartsSearch::new(config);
        let arch = search.derive_discrete_arch();
        assert_eq!(arch.len(), 2, "one vec per cell");
        for cell_ops in &arch {
            assert!(!cell_ops.is_empty());
        }
    }

    // Smoke test: a full bilevel step must produce finite losses.
    #[test]
    fn bilevel_step_runs_without_error() {
        let config = DartsConfig::default();
        let mut search = DartsSearch::new(config);
        let train_x = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let train_y = vec![1.5, 3.5];
        let val_x = vec![vec![0.5, 1.5]];
        let val_y = vec![1.0];
        let (tl, vl) = search.bilevel_step(&train_x, &train_y, &val_x, &val_y);
        assert!(tl.is_finite());
        assert!(vl.is_finite());
    }

    // Flattened parameter vector length must match the reported count.
    #[test]
    fn arch_parameters_length_consistent() {
        let config = DartsConfig {
            n_cells: 3,
            n_operations: 5,
            n_nodes: 2,
            ..Default::default()
        };
        let search = DartsSearch::new(config);
        let params = search.arch_parameters();
        assert_eq!(params.len(), search.n_arch_params());
    }

    // A gradient vector of the wrong length is rejected with an error.
    #[test]
    fn update_arch_params_wrong_length_errors() {
        let mut search = DartsSearch::new(DartsConfig::default());
        let result = search.update_arch_params(&[1.0, 2.0], 0.01);
        assert!(result.is_err());
    }

    // Linear schedule: exact endpoints and midpoint.
    #[test]
    fn temperature_schedule_linear_bounds() {
        let sched = TemperatureSchedule::new(10.0, 1.0, AnnealingStrategy::Linear, 100);
        let t0 = sched.temperature_at(0);
        let t_half = sched.temperature_at(50);
        let t_end = sched.temperature_at(100);
        assert!((t0 - 10.0).abs() < 1e-10, "t0={t0}");
        assert!((t_half - 5.5).abs() < 1e-10, "t_half={t_half}");
        assert!((t_end - 1.0).abs() < 1e-10, "t_end={t_end}");
    }

    // Exponential schedule: endpoints match, interior stays in range.
    #[test]
    fn temperature_schedule_exponential_bounds() {
        let sched = TemperatureSchedule::new(10.0, 1.0, AnnealingStrategy::Exponential, 100);
        let t0 = sched.temperature_at(0);
        let t_end = sched.temperature_at(100);
        assert!((t0 - 10.0).abs() < 1e-8, "t0={t0}");
        assert!((t_end - 1.0).abs() < 1e-8, "t_end={t_end}");
        let t_mid = sched.temperature_at(50);
        assert!(t_mid > 1.0 && t_mid < 10.0, "t_mid={t_mid}");
    }

    // Cosine schedule: endpoints match within floating-point tolerance.
    #[test]
    fn temperature_schedule_cosine_bounds() {
        let sched = TemperatureSchedule::new(10.0, 1.0, AnnealingStrategy::Cosine, 100);
        let t0 = sched.temperature_at(0);
        let t_end = sched.temperature_at(100);
        assert!((t0 - 10.0).abs() < 1e-8, "t0={t0}");
        assert!((t_end - 1.0).abs() < 1e-8, "t_end={t_end}");
    }

    // Steps past the horizon clamp to the final temperature.
    #[test]
    fn temperature_schedule_clamped_beyond_total() {
        let sched = TemperatureSchedule::new(5.0, 1.0, AnnealingStrategy::Linear, 10);
        let t_over = sched.temperature_at(999);
        let t_end = sched.temperature_at(10);
        assert!((t_over - t_end).abs() < 1e-10);
    }

    // Degenerate zero-step schedule reports the final temperature.
    #[test]
    fn temperature_schedule_zero_steps() {
        let sched = TemperatureSchedule::new(5.0, 1.0, AnnealingStrategy::Linear, 0);
        let t = sched.temperature_at(0);
        assert!((t - 1.0).abs() < 1e-10, "t={t}");
    }
}