1use super::Lcg;
24use crate::error::{OptimizeError, OptimizeResult};
25
/// Flattens a nested architecture description into a numeric feature vector.
///
/// `arch_indices[cell][node][edge]` is an operation index. Every index is divided by
/// `max(n_operations - 1, 1)`, so indices in `0..n_operations` land in `[0, 1]`; the
/// divisor never drops below 1, which keeps `n_operations` of 0 or 1 well-defined.
/// Values are emitted cell by cell, then node by node, then edge by edge.
pub fn encode_architecture(arch_indices: &[Vec<Vec<usize>>], n_operations: usize) -> Vec<f64> {
    let scale = n_operations.saturating_sub(1).max(1) as f64;
    let mut features = Vec::new();
    for cell in arch_indices {
        for node_edges in cell {
            features.extend(node_edges.iter().map(|&op| op as f64 / scale));
        }
    }
    features
}
47
/// Configuration for [`PredictorNasSearcher`].
#[derive(Debug, Clone)]
pub struct PredictorNasConfig {
    /// Number of cells in each sampled architecture.
    pub n_cells: usize,
    /// Number of candidate operations per edge; sampled operation indices lie in
    /// `0..n_operations`.
    pub n_operations: usize,
    /// Channel count. NOTE(review): not read by the search code in this file;
    /// presumably consumed by whoever builds networks from the result — confirm.
    pub channels: usize,
    /// Number of nodes per cell; node `i` receives `2 + i` incoming edges.
    pub n_nodes: usize,
    /// Number of random architectures evaluated before the surrogate is first
    /// fitted; the search rejects 0.
    pub n_initial_samples: usize,
    /// Number of surrogate-guided rounds after the initial random phase.
    pub n_iterations: usize,
    /// Number of random candidates ranked by the acquisition function per round.
    pub n_candidates_per_iter: usize,
    /// Number of top-ranked candidates actually evaluated per round (capped at
    /// the number of candidates).
    pub n_top_to_evaluate: usize,
    /// Exploration weight `kappa` in the UCB acquisition `mean + kappa * std`.
    pub ucb_kappa: f64,
    /// Seed for the searcher's random number generator.
    pub seed: u64,
}
74
75impl Default for PredictorNasConfig {
76 fn default() -> Self {
77 Self {
78 n_cells: 3,
79 n_operations: 6,
80 channels: 32,
81 n_nodes: 4,
82 n_initial_samples: 5,
83 n_iterations: 3,
84 n_candidates_per_iter: 20,
85 n_top_to_evaluate: 2,
86 ucb_kappa: 2.0,
87 seed: 42,
88 }
89 }
90}
91
/// Acquisition function used to rank candidate architectures under the surrogate.
///
/// The enum is fieldless, so it is cheap to copy and can be compared or hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AcquisitionStrategy {
    /// Upper confidence bound: predicted mean plus `ucb_kappa` times the predicted std.
    Ucb,
    /// Expected improvement over the best objective value observed so far.
    ExpectedImprovement,
}
102
/// Outcome of a [`PredictorNasSearcher`] search.
#[derive(Debug, Clone)]
pub struct PredictorNasResult {
    /// Highest-scoring evaluated architecture as `[cell][node][edge]` operation
    /// indices (decoded from its feature encoding); node `i` of each cell has
    /// `2 + i` edges.
    pub best_arch_indices: Vec<Vec<Vec<usize>>>,
    /// Objective value of `best_arch_indices` (the search maximizes it).
    pub best_score: f64,
    /// Total number of objective evaluations performed (random warm-up plus
    /// per-round evaluations).
    pub n_evaluated: usize,
}
115
116struct RidgeSurrogate {
127 x_train: Vec<Vec<f64>>,
129 y_train: Vec<f64>,
131 alpha: f64,
133 coeffs: Vec<f64>,
135}
136
137impl RidgeSurrogate {
138 fn new(alpha: f64) -> Self {
139 Self {
140 x_train: Vec::new(),
141 y_train: Vec::new(),
142 alpha,
143 coeffs: Vec::new(),
144 }
145 }
146
147 fn rbf(&self, a: &[f64], b: &[f64]) -> f64 {
149 let sq_dist: f64 = a
150 .iter()
151 .zip(b.iter())
152 .map(|(ai, bi)| (ai - bi) * (ai - bi))
153 .sum();
154 (-sq_dist / 2.0).exp()
155 }
156
157 fn fit(&mut self, x: &[Vec<f64>], y: &[f64]) {
159 self.x_train = x.to_vec();
160 self.y_train = y.to_vec();
161 let n = x.len();
162 if n == 0 {
163 self.coeffs = Vec::new();
164 return;
165 }
166
167 let mut k_matrix: Vec<Vec<f64>> = (0..n)
169 .map(|i| {
170 (0..n)
171 .map(|j| {
172 let kij = self.rbf(&x[i], &x[j]);
173 if i == j {
174 kij + self.alpha
175 } else {
176 kij
177 }
178 })
179 .collect()
180 })
181 .collect();
182
183 let mut rhs: Vec<f64> = y.to_vec();
185 gauss_elimination(&mut k_matrix, &mut rhs);
186 self.coeffs = rhs;
187 }
188
189 fn predict_mean_std(&self, x: &[f64]) -> (f64, f64) {
193 let n = self.x_train.len();
194 if n == 0 || self.coeffs.len() != n {
195 return (0.0, 1.0);
197 }
198
199 let k_vec: Vec<f64> = self.x_train.iter().map(|xi| self.rbf(x, xi)).collect();
201
202 let mean: f64 = k_vec
204 .iter()
205 .zip(self.coeffs.iter())
206 .map(|(ki, ci)| ki * ci)
207 .sum();
208
209 let k_self = self.rbf(x, x); let var_approx: f64 = k_self
214 - k_vec
215 .iter()
216 .zip(self.x_train.iter())
217 .map(|(&kxi, xi)| {
218 let kii = self.rbf(xi, xi) + self.alpha;
219 kxi * kxi / kii.max(1e-12)
220 })
221 .sum::<f64>();
222
223 let std = var_approx.max(0.0).sqrt();
224 (mean, std)
225 }
226}
227
/// Solves the dense linear system `A x = b` in place by Gaussian elimination with
/// partial (row) pivoting.
///
/// On return `b` holds the solution `x` and `a` holds the row-reduced matrix. Only
/// the leading `n × n` block of `a` is used, where `n = b.len()`.
///
/// A column whose best available pivot has magnitude below `1e-14` is treated as
/// singular. It is skipped during elimination and its solution component is set to
/// `0.0`, so a particular solution is still written to `b`.
///
/// Returns `true` when every pivot was usable. Returns `false` when the system was
/// numerically singular, or when `a` is smaller than `n × n`; in the latter case
/// neither `a` nor `b` is modified.
fn gauss_elimination(a: &mut [Vec<f64>], b: &mut [f64]) -> bool {
    const PIVOT_EPS: f64 = 1e-14;

    let n = b.len();
    if a.len() < n || a[..n].iter().any(|row| row.len() < n) {
        return false;
    }
    let a = &mut a[..n];
    let mut nonsingular = true;

    // Forward elimination to upper-triangular form.
    for col in 0..n {
        let pivot_row = (col..n)
            .max_by(|&r1, &r2| {
                a[r1][col]
                    .abs()
                    .partial_cmp(&a[r2][col].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap_or(col);

        if a[pivot_row][col].abs() < PIVOT_EPS {
            nonsingular = false;
            continue;
        }

        a.swap(col, pivot_row);
        b.swap(col, pivot_row);

        // Split so the pivot row can be read while the rows below are updated.
        let (a_top, a_rest) = a.split_at_mut(col + 1);
        let (b_top, b_rest) = b.split_at_mut(col + 1);
        let pivot_eq = &a_top[col];
        let pivot = pivot_eq[col];
        let rhs_pivot = b_top[col];
        for (row, rhs) in a_rest.iter_mut().zip(b_rest.iter_mut()) {
            let factor = row[col] / pivot;
            *rhs -= factor * rhs_pivot;
            for (dst, &src) in row[col..n].iter_mut().zip(&pivot_eq[col..n]) {
                *dst -= factor * src;
            }
        }
    }

    // Back substitution. `b[col]` still holds the un-normalised right-hand side when
    // the rows above are updated, hence the division inside `factor`.
    for col in (0..n).rev() {
        let diag = a[col][col];
        if diag.abs() < PIVOT_EPS {
            b[col] = 0.0;
            nonsingular = false;
            continue;
        }
        let rhs_col = b[col];
        for (row, rhs) in a[..col].iter().zip(b[..col].iter_mut()) {
            let factor = row[col] / diag;
            *rhs -= factor * rhs_col;
        }
        b[col] /= diag;
    }
    nonsingular
}
283
/// Predictor-based (surrogate-assisted) architecture search.
///
/// Evaluates a batch of random architectures and fits a kernel-ridge surrogate on
/// their encodings. It then repeatedly samples random candidates, ranks them by an
/// acquisition function under the surrogate, evaluates the top few, and refits.
pub struct PredictorNasSearcher {
    /// Search hyperparameters.
    config: PredictorNasConfig,
    /// Surrogate model mapping encoded architectures to predicted scores.
    surrogate: RidgeSurrogate,
    /// Random source for architecture sampling, seeded from `config.seed`.
    rng: Lcg,
    /// Encodings (via `encode_architecture`) of every evaluated architecture, in
    /// evaluation order.
    evaluated_x: Vec<Vec<f64>>,
    /// Objective values, index-aligned with `evaluated_x`.
    evaluated_y: Vec<f64>,
}
297
298impl PredictorNasSearcher {
299 pub fn new(config: PredictorNasConfig) -> Self {
301 let rng = Lcg::new(config.seed);
302 Self {
303 surrogate: RidgeSurrogate::new(1e-3),
304 config,
305 rng,
306 evaluated_x: Vec::new(),
307 evaluated_y: Vec::new(),
308 }
309 }
310
311 fn sample_random_arch(&mut self) -> Vec<Vec<Vec<usize>>> {
313 let n_ops = self.config.n_operations;
314 (0..self.config.n_cells)
315 .map(|_| {
316 (0..self.config.n_nodes)
317 .map(|i| {
318 let n_predecessors = 2 + i; (0..n_predecessors)
320 .map(|_| {
321 let raw = self.rng.next_f64();
322 ((raw * n_ops as f64) as usize).min(n_ops - 1)
323 })
324 .collect()
325 })
326 .collect()
327 })
328 .collect()
329 }
330
331 fn ucb(&self, x: &[f64]) -> f64 {
333 let (mean, std) = self.surrogate.predict_mean_std(x);
334 mean + self.config.ucb_kappa * std
335 }
336
337 fn expected_improvement(&self, x: &[f64]) -> f64 {
339 let f_best = self
340 .evaluated_y
341 .iter()
342 .cloned()
343 .fold(f64::NEG_INFINITY, f64::max);
344 if f_best.is_infinite() {
345 return 0.0;
346 }
347 let (mean, std) = self.surrogate.predict_mean_std(x);
348 if std < 1e-12 {
349 return (mean - f_best).max(0.0);
350 }
351 let z = (mean - f_best) / std;
352 let phi_z = normal_cdf(z);
354 let pdf_z = normal_pdf(z);
355 (mean - f_best) * phi_z + std * pdf_z
356 }
357
358 fn acquisition(&self, x: &[f64], strategy: &AcquisitionStrategy) -> f64 {
360 match strategy {
361 AcquisitionStrategy::Ucb => self.ucb(x),
362 AcquisitionStrategy::ExpectedImprovement => self.expected_improvement(x),
363 }
364 }
365
366 fn evaluate_and_record(
368 &mut self,
369 arch: &[Vec<Vec<usize>>],
370 eval_fn: &impl Fn(&[Vec<Vec<usize>>]) -> f64,
371 ) -> f64 {
372 let score = eval_fn(arch);
373 let enc = encode_architecture(arch, self.config.n_operations);
374 self.evaluated_x.push(enc);
375 self.evaluated_y.push(score);
376 score
377 }
378
379 fn refit_surrogate(&mut self) {
381 self.surrogate.fit(&self.evaluated_x, &self.evaluated_y);
382 }
383
384 pub fn search(
392 &mut self,
393 eval_fn: impl Fn(&[Vec<Vec<usize>>]) -> f64,
394 ) -> OptimizeResult<PredictorNasResult> {
395 if self.config.n_initial_samples == 0 {
396 return Err(OptimizeError::InvalidInput(
397 "n_initial_samples must be > 0".to_string(),
398 ));
399 }
400
401 for _ in 0..self.config.n_initial_samples {
403 let arch = self.sample_random_arch();
404 self.evaluate_and_record(&arch, &eval_fn);
405 }
406 self.refit_surrogate();
407
408 let strategy = AcquisitionStrategy::Ucb;
410 for _ in 0..self.config.n_iterations {
411 let mut candidates: Vec<(f64, Vec<Vec<Vec<usize>>>)> =
413 (0..self.config.n_candidates_per_iter)
414 .map(|_| {
415 let arch = self.sample_random_arch();
416 let enc = encode_architecture(&arch, self.config.n_operations);
417 let acq = self.acquisition(&enc, &strategy);
418 (acq, arch)
419 })
420 .collect();
421
422 candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
424
425 let n_eval = self.config.n_top_to_evaluate.min(candidates.len());
427 for (_, arch) in candidates.into_iter().take(n_eval) {
428 self.evaluate_and_record(&arch, &eval_fn);
429 }
430
431 self.refit_surrogate();
432 }
433
434 let best_idx = self
436 .evaluated_y
437 .iter()
438 .enumerate()
439 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
440 .map(|(i, _)| i)
441 .ok_or_else(|| {
442 OptimizeError::ComputationError("No architectures were evaluated".to_string())
443 })?;
444
445 let best_score = self.evaluated_y[best_idx];
446 let best_enc = &self.evaluated_x[best_idx];
447 let norm = (self.config.n_operations.max(1) - 1) as f64;
449 let denom = norm.max(1.0);
450 let decoded_flat: Vec<usize> = best_enc
451 .iter()
452 .map(|&v| ((v * denom).round() as usize).min(self.config.n_operations - 1))
453 .collect();
454 let best_arch_indices =
455 reconstruct_arch_indices(&decoded_flat, self.config.n_cells, self.config.n_nodes);
456
457 Ok(PredictorNasResult {
458 best_arch_indices,
459 best_score,
460 n_evaluated: self.evaluated_y.len(),
461 })
462 }
463}
464
/// Rebuilds the nested `[cell][node][edge]` layout from a flat list of operation indices.
///
/// This is the structural inverse of the flattening in `encode_architecture`. It
/// assumes node `i` of every cell has `2 + i` edges. Indices are consumed in cell,
/// node, edge order. Any node whose complete edge list is not available in `flat`
/// is filled with zeros (operation 0).
fn reconstruct_arch_indices(
    flat: &[usize],
    n_cells: usize,
    n_nodes: usize,
) -> Vec<Vec<Vec<usize>>> {
    let mut cursor = 0usize;
    (0..n_cells)
        .map(|_| {
            (0..n_nodes)
                .map(|node| {
                    let width = node + 2;
                    let start = cursor;
                    cursor += width;
                    match flat.get(start..cursor) {
                        Some(edges) => edges.to_vec(),
                        None => vec![0; width],
                    }
                })
                .collect::<Vec<Vec<usize>>>()
        })
        .collect()
}
491
/// Standard normal cumulative distribution function `Φ(x)`.
///
/// Evaluates a fifth-order polynomial in `t = 1 / (1 + p·|x|)` (the
/// Abramowitz–Stegun 26.2.17 constants) to get `Φ(|x|)`. Negative arguments are
/// mirrored via `Φ(-x) = 1 - Φ(x)`.
fn normal_cdf(x: f64) -> f64 {
    const P: f64 = 0.2316419;
    // Polynomial coefficients b1..b5, lowest order first.
    const B: [f64; 5] = [
        0.319_381_53,
        -0.356_563_782,
        1.781_477_937,
        -1.821_255_978,
        1.330_274_429,
    ];

    let t = 1.0 / (1.0 + P * x.abs());
    // Horner form of t * (b1 + t*(b2 + t*(b3 + t*(b4 + t*b5)))).
    let poly = t * B.iter().rev().fold(0.0, |acc, &coef| coef + t * acc);
    let phi_abs = 1.0 - normal_pdf(x) * poly;
    if x < 0.0 {
        1.0 - phi_abs
    } else {
        phi_abs
    }
}

/// Standard normal probability density `φ(x) = exp(-x²/2) / √(2π)`.
fn normal_pdf(x: f64) -> f64 {
    let exponent = -(x * x) / 2.0;
    let normaliser = (2.0 * std::f64::consts::PI).sqrt();
    exponent.exp() / normaliser
}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Objective that prefers small operation indices: the negated sum of all indices.
    fn neg_index_sum(arch: &[Vec<Vec<usize>>]) -> f64 {
        let total: usize = arch
            .iter()
            .flat_map(|cell| cell.iter().flat_map(|node| node.iter()))
            .sum();
        -(total as f64)
    }

    // ---- encode_architecture / reconstruct_arch_indices ----

    #[test]
    fn test_encode_architecture_deterministic() {
        let arch: Vec<Vec<Vec<usize>>> = vec![vec![vec![0, 1], vec![2, 3, 0], vec![1, 0, 2, 1]]];
        let enc1 = encode_architecture(&arch, 6);
        let enc2 = encode_architecture(&arch, 6);
        assert_eq!(enc1, enc2);
    }

    #[test]
    fn test_encode_architecture_length() {
        // 2 cells × (2 + 3 + 4 + 5) edges = 28 features.
        let arch: Vec<Vec<Vec<usize>>> = (0..2_usize)
            .map(|_| {
                (0..4_usize)
                    .map(|i| vec![0_usize; 2 + i])
                    .collect::<Vec<_>>()
            })
            .collect();
        let enc = encode_architecture(&arch, 6);
        assert_eq!(enc.len(), 28, "enc.len()={}", enc.len());
    }

    #[test]
    fn test_encode_architecture_range() {
        let arch: Vec<Vec<Vec<usize>>> = vec![vec![vec![0, 5], vec![3, 1, 5]]];
        let enc = encode_architecture(&arch, 6);
        for &v in &enc {
            assert!((0.0..=1.0).contains(&v), "v={v} out of [0,1]");
        }
    }

    #[test]
    fn test_encode_architecture_single_op() {
        // With one operation every index is 0 and the divisor is clamped to 1.
        let arch: Vec<Vec<Vec<usize>>> = vec![vec![vec![0, 0]]];
        let enc = encode_architecture(&arch, 1);
        for &v in &enc {
            assert!(v.abs() < 1e-10, "v={v}");
        }
    }

    #[test]
    fn test_encode_then_reconstruct_roundtrip() {
        let arch: Vec<Vec<Vec<usize>>> = vec![
            vec![vec![0, 5], vec![3, 1, 4]],
            vec![vec![2, 2], vec![5, 0, 1]],
        ];
        let enc = encode_architecture(&arch, 6);
        // Undo the scaling the same way the searcher decodes: multiply by n_operations - 1.
        let flat: Vec<usize> = enc.iter().map(|&v| (v * 5.0).round() as usize).collect();
        assert_eq!(reconstruct_arch_indices(&flat, 2, 2), arch);
    }

    #[test]
    fn test_reconstruct_zero_fills_missing_edges() {
        let arch = reconstruct_arch_indices(&[1, 2, 3], 1, 2);
        assert_eq!(arch, vec![vec![vec![1, 2], vec![0, 0, 0]]]);
    }

    // ---- RidgeSurrogate ----

    #[test]
    fn test_ridge_surrogate_predict_after_fit() {
        let mut surr = RidgeSurrogate::new(1e-3);
        let x = vec![vec![0.0], vec![0.5], vec![1.0]];
        let y = vec![0.0, 0.5, 1.0];
        surr.fit(&x, &y);
        let (mean, _std) = surr.predict_mean_std(&[0.25]);
        assert!(mean.is_finite(), "mean={mean}");
    }

    #[test]
    fn test_ridge_surrogate_fits_training_points() {
        let mut surr = RidgeSurrogate::new(1e-3);
        let x = vec![vec![0.0], vec![0.5], vec![1.0]];
        let y = vec![0.0, 0.5, 1.0];
        surr.fit(&x, &y);
        // A small ridge term should nearly interpolate the training targets.
        for (xi, &yi) in x.iter().zip(&y) {
            let (mean, _) = surr.predict_mean_std(xi);
            assert!((mean - yi).abs() < 0.05, "mean={mean} y={yi}");
        }
        // Uncertainty is reduced near the data and returns to the prior far away.
        let (_, std_near) = surr.predict_mean_std(&[0.5]);
        let (_, std_far) = surr.predict_mean_std(&[10.0]);
        assert!(std_far > 0.9 && std_near < std_far, "near={std_near} far={std_far}");
    }

    #[test]
    fn test_ridge_surrogate_empty_returns_prior() {
        let surr = RidgeSurrogate::new(1e-3);
        let (mean, std) = surr.predict_mean_std(&[0.5]);
        assert!(mean.abs() < 1e-10, "mean={mean}");
        assert!((std - 1.0).abs() < 1e-10, "std={std}");
    }

    #[test]
    fn test_ridge_surrogate_std_nonneg() {
        let mut surr = RidgeSurrogate::new(1e-3);
        let x: Vec<Vec<f64>> = (0..5).map(|i| vec![i as f64 / 4.0]).collect();
        let y: Vec<f64> = (0..5).map(|i| i as f64).collect();
        surr.fit(&x, &y);
        for i in 0..10 {
            let xq = vec![i as f64 / 10.0];
            let (_mean, std) = surr.predict_mean_std(&xq);
            assert!(std >= 0.0, "std={std} at x={}", xq[0]);
        }
    }

    // ---- PredictorNasSearcher ----

    #[test]
    fn test_predictor_search_returns_result() {
        let config = PredictorNasConfig {
            n_cells: 2,
            n_nodes: 3,
            n_operations: 6,
            n_initial_samples: 4,
            n_iterations: 2,
            n_candidates_per_iter: 10,
            n_top_to_evaluate: 2,
            ..Default::default()
        };

        let mut searcher = PredictorNasSearcher::new(config);
        let result = searcher.search(neg_index_sum).expect("search should succeed");

        assert!(result.best_score.is_finite(), "best_score={}", result.best_score);
        assert!(result.n_evaluated >= 4, "n_evaluated={}", result.n_evaluated);

        // The decoded architecture must have the configured shape...
        assert_eq!(result.best_arch_indices.len(), 2);
        for cell in &result.best_arch_indices {
            assert_eq!(cell.len(), 3);
            for (i, node) in cell.iter().enumerate() {
                assert_eq!(node.len(), 2 + i);
                assert!(node.iter().all(|&op| op < 6), "node={node:?}");
            }
        }
        // ...and re-scoring it must reproduce the reported best score exactly.
        assert_eq!(neg_index_sum(&result.best_arch_indices), result.best_score);
    }

    #[test]
    fn test_active_learning_improves_best_score() {
        let config_small = PredictorNasConfig {
            n_cells: 1,
            n_nodes: 2,
            n_operations: 6,
            n_initial_samples: 3,
            n_iterations: 0,
            n_candidates_per_iter: 5,
            n_top_to_evaluate: 1,
            seed: 7,
            ..Default::default()
        };
        let mut searcher_small = PredictorNasSearcher::new(config_small);
        let result_small = searcher_small.search(neg_index_sum).expect("small search");

        let config_large = PredictorNasConfig {
            n_cells: 1,
            n_nodes: 2,
            n_operations: 6,
            n_initial_samples: 3,
            n_iterations: 4,
            n_candidates_per_iter: 10,
            n_top_to_evaluate: 2,
            seed: 7,
            ..Default::default()
        };
        let mut searcher_large = PredictorNasSearcher::new(config_large);
        let result_large = searcher_large.search(neg_index_sum).expect("large search");

        assert!(
            result_large.n_evaluated >= result_small.n_evaluated,
            "large n_eval={} < small n_eval={}",
            result_large.n_evaluated,
            result_small.n_evaluated
        );
        assert!(result_small.best_score.is_finite());
        assert!(result_large.best_score.is_finite());
        // Same seed and warm-up settings: the large run evaluates the same initial
        // architectures plus more, so its best can never be worse.
        assert!(
            result_large.best_score >= result_small.best_score,
            "large best={} < small best={}",
            result_large.best_score,
            result_small.best_score
        );
    }

    #[test]
    fn test_predictor_n_evaluated_count() {
        let config = PredictorNasConfig {
            n_initial_samples: 5,
            n_iterations: 3,
            n_top_to_evaluate: 2,
            n_candidates_per_iter: 10,
            n_cells: 2,
            n_nodes: 3,
            n_operations: 6,
            ..Default::default()
        };
        // 5 warm-up evaluations + 3 rounds × 2 evaluated candidates.
        let expected = 5 + 3 * 2;
        let mut searcher = PredictorNasSearcher::new(config);
        let result = searcher.search(|_| 1.0).expect("search should not fail");

        assert_eq!(result.n_evaluated, expected);
    }

    #[test]
    fn test_predictor_zero_iterations_still_works() {
        let config = PredictorNasConfig {
            n_initial_samples: 3,
            n_iterations: 0,
            ..Default::default()
        };
        let mut searcher = PredictorNasSearcher::new(config);
        let result = searcher
            .search(|_| 42.0)
            .expect("zero-iteration search should succeed");
        assert_eq!(result.best_score, 42.0);
    }

    #[test]
    fn test_search_rejects_zero_initial_samples() {
        let config = PredictorNasConfig {
            n_initial_samples: 0,
            ..Default::default()
        };
        let mut searcher = PredictorNasSearcher::new(config);
        let result = searcher.search(|_| 0.0);
        assert!(matches!(result, Err(OptimizeError::InvalidInput(_))));
    }

    #[test]
    fn test_search_is_deterministic_for_fixed_seed() {
        let config = PredictorNasConfig {
            n_initial_samples: 3,
            n_iterations: 2,
            n_candidates_per_iter: 5,
            seed: 11,
            ..Default::default()
        };
        let a = PredictorNasSearcher::new(config.clone())
            .search(neg_index_sum)
            .expect("first run");
        let b = PredictorNasSearcher::new(config)
            .search(neg_index_sum)
            .expect("second run");
        assert_eq!(a.best_arch_indices, b.best_arch_indices);
        assert_eq!(a.best_score, b.best_score);
        assert_eq!(a.n_evaluated, b.n_evaluated);
    }

    #[test]
    fn test_expected_improvement_is_zero_without_observations() {
        let searcher = PredictorNasSearcher::new(PredictorNasConfig::default());
        assert_eq!(searcher.expected_improvement(&[0.5, 0.25]), 0.0);
    }

    // ---- numerical helpers ----

    #[test]
    fn test_normal_cdf_basic() {
        // The polynomial approximation (A&S 26.2.17 constants) is accurate to ~1e-7.
        assert!((normal_cdf(0.0) - 0.5).abs() < 1e-6);
        assert!((normal_cdf(1.96) - 0.975_002_104_851_78).abs() < 1e-6);
        assert!((normal_cdf(-1.0) - 0.158_655_253_931_457).abs() < 1e-6);
        assert!((normal_cdf(0.7) + normal_cdf(-0.7) - 1.0).abs() < 1e-12);
        assert!(normal_cdf(10.0) > 0.999);
        assert!(normal_cdf(-10.0) < 0.001);
    }

    #[test]
    fn test_normal_pdf_peak() {
        assert!((normal_pdf(0.0) - 0.398_942_280_401_432_7).abs() < 1e-15);
    }

    #[test]
    fn test_gauss_elimination_simple() {
        let mut a = vec![vec![2.0_f64]];
        let mut b = vec![4.0_f64];
        gauss_elimination(&mut a, &mut b);
        assert!((b[0] - 2.0).abs() < 1e-10, "b[0]={}", b[0]);
    }

    #[test]
    fn test_gauss_elimination_2x2() {
        let mut a = vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]];
        let mut b = vec![5.0_f64, 11.0];
        gauss_elimination(&mut a, &mut b);
        assert!((b[0] - 1.0).abs() < 1e-9, "b[0]={}", b[0]);
        assert!((b[1] - 2.0).abs() < 1e-9, "b[1]={}", b[1]);
    }

    #[test]
    fn test_gauss_elimination_3x3() {
        let mut a = vec![
            vec![2.0_f64, 1.0, -1.0],
            vec![-3.0, -1.0, 2.0],
            vec![-2.0, 1.0, 2.0],
        ];
        let mut b = vec![8.0_f64, -11.0, -3.0];
        gauss_elimination(&mut a, &mut b);
        for (got, want) in b.iter().zip([2.0, 3.0, -1.0]) {
            assert!((got - want).abs() < 1e-9, "got={got} want={want}");
        }
    }
}