// Source path: scirs2_optimize/darts/predictor_nas.rs

//! Predictor-based Neural Architecture Search.
//!
//! Uses a Gaussian kernel ridge regression surrogate (manually implemented,
//! no external GP dependency) to model the mapping from architecture encodings
//! to validation scores.  An active learning loop iterates:
//!
//! 1. Maintain a surrogate trained on evaluated architectures.
//! 2. Use UCB or Expected Improvement acquisition to propose candidates.
//! 3. Evaluate top candidates with the true evaluation function.
//! 4. Retrain the surrogate and repeat.
//!
//! ## Architecture Encoding
//!
//! An architecture is encoded as a flat `Vec<f64>` of normalised operation
//! indices in `[0, 1]`.  Each element is `op_idx / (n_operations - 1)`.
//!
//! ## Surrogate
//!
//! Gaussian kernel ridge regression with RBF kernel:
//! `k(a, b) = exp(-||a - b||² / 2)`.
//! Prediction: `ŷ(x) = k(x, X)ᵀ (K + αI)⁻¹ y`.
22
23use super::Lcg;
24use crate::error::{OptimizeError, OptimizeResult};
25
// ────────────────────────────────────────────────── Architecture encoding ──
27
/// Flatten a discrete architecture into a normalised `Vec<f64>` encoding.
///
/// Operation indices are visited in `[cell][node][predecessor]` order and each
/// one is divided by `max(n_operations - 1, 1)`, so every index below
/// `n_operations` lands in `[0, 1]`.
///
/// # Arguments
/// - `arch_indices`: `[cell][node][predecessor]` operation indices.
/// - `n_operations`: Total number of candidate operations.  Both `0` and `1`
///   yield a divisor of `1`.
pub fn encode_architecture(arch_indices: &[Vec<Vec<usize>>], n_operations: usize) -> Vec<f64> {
    // `saturating_sub` keeps `n_operations == 0` from underflowing; clamping to
    // 1.0 avoids a division by zero when there is at most one operation.
    let scale = (n_operations.saturating_sub(1) as f64).max(1.0);

    let mut encoding = Vec::new();
    for cell in arch_indices {
        for node_edges in cell {
            encoding.extend(node_edges.iter().map(|&op_idx| op_idx as f64 / scale));
        }
    }
    encoding
}
47
// ──────────────────────────────────────────────── PredictorNasConfig ──
49
/// Settings that control the search space and evaluation budget of the
/// predictor-based NAS searcher.
#[derive(Debug, Clone)]
pub struct PredictorNasConfig {
    /// How many cells make up one architecture.
    pub n_cells: usize,
    /// How many candidate operations each edge can choose from.
    pub n_operations: usize,
    /// Number of feature channels.  Informational: the search code in this
    /// module does not read it.
    pub channels: usize,
    /// How many intermediate nodes each cell contains.  Node `i` has `2 + i`
    /// incoming edges.
    pub n_nodes: usize,
    /// Phase 1 (warm-up) budget: random architectures evaluated before the
    /// surrogate is first fitted.
    pub n_initial_samples: usize,
    /// Phase 2 budget: number of active-learning rounds.
    pub n_iterations: usize,
    /// Random candidates scored by the acquisition function in each round.
    pub n_candidates_per_iter: usize,
    /// How many of the best-scoring candidates are truly evaluated per round.
    pub n_top_to_evaluate: usize,
    /// Exploration weight `κ` in the UCB acquisition `μ + κ·σ`.
    pub ucb_kappa: f64,
    /// Seed handed to the internal LCG random number generator.
    pub seed: u64,
}
74
75impl Default for PredictorNasConfig {
76    fn default() -> Self {
77        Self {
78            n_cells: 3,
79            n_operations: 6,
80            channels: 32,
81            n_nodes: 4,
82            n_initial_samples: 5,
83            n_iterations: 3,
84            n_candidates_per_iter: 20,
85            n_top_to_evaluate: 2,
86            ucb_kappa: 2.0,
87            seed: 42,
88        }
89    }
90}
91
// ─────────────────────────────────────────────────── AcquisitionStrategy ──
93
/// Acquisition rule used to rank candidate architectures with the surrogate.
#[derive(Debug, Clone, PartialEq)]
pub enum AcquisitionStrategy {
    /// Upper Confidence Bound, `μ(x) + κ · σ(x)`: predicted score plus an
    /// exploration bonus proportional to the predictive uncertainty.
    Ucb,
    /// Expected Improvement, `E[max(0, f(x) - f_best)]`: expected gain over
    /// the best score observed so far.
    ExpectedImprovement,
}
102
// ──────────────────────────────────────────────────── PredictorNasResult ──
104
/// Outcome of a predictor-based NAS search.
#[derive(Debug, Clone)]
pub struct PredictorNasResult {
    /// `[cell][node][predecessor]` operation indices of the best architecture
    /// found.
    pub best_arch_indices: Vec<Vec<Vec<usize>>>,
    /// Score the evaluation function returned for that architecture (higher
    /// is better).
    pub best_score: f64,
    /// Number of architectures scored with the true evaluation function.
    pub n_evaluated: usize,
}
115
// ──────────────────────────────────────────── RidgeSurrogate (private) ──
117
118/// Gaussian kernel ridge regression surrogate.
119///
120/// RBF kernel: `k(a, b) = exp(-||a - b||² / 2)`.
121/// Predictions: `ŷ(x) = k(x, X)ᵀ α_coeff` where
122/// `α_coeff = (K + α·I)⁻¹ y`.
123///
124/// For the small datasets that arise during NAS warm-up (n ≤ ~20) we solve
125/// the linear system via Gaussian elimination.
126struct RidgeSurrogate {
127    /// Training inputs, one row per sample.
128    x_train: Vec<Vec<f64>>,
129    /// Training targets.
130    y_train: Vec<f64>,
131    /// Regularisation coefficient.
132    alpha: f64,
133    /// Solved coefficients `(K + α·I)⁻¹ y`.  Empty when not yet fitted.
134    coeffs: Vec<f64>,
135}
136
137impl RidgeSurrogate {
138    fn new(alpha: f64) -> Self {
139        Self {
140            x_train: Vec::new(),
141            y_train: Vec::new(),
142            alpha,
143            coeffs: Vec::new(),
144        }
145    }
146
147    /// Compute RBF kernel value between two vectors.
148    fn rbf(&self, a: &[f64], b: &[f64]) -> f64 {
149        let sq_dist: f64 = a
150            .iter()
151            .zip(b.iter())
152            .map(|(ai, bi)| (ai - bi) * (ai - bi))
153            .sum();
154        (-sq_dist / 2.0).exp()
155    }
156
157    /// Fit the surrogate to `(x, y)` training data.
158    fn fit(&mut self, x: &[Vec<f64>], y: &[f64]) {
159        self.x_train = x.to_vec();
160        self.y_train = y.to_vec();
161        let n = x.len();
162        if n == 0 {
163            self.coeffs = Vec::new();
164            return;
165        }
166
167        // Build kernel matrix K (n×n) and add regularisation on diagonal.
168        let mut k_matrix: Vec<Vec<f64>> = (0..n)
169            .map(|i| {
170                (0..n)
171                    .map(|j| {
172                        let kij = self.rbf(&x[i], &x[j]);
173                        if i == j {
174                            kij + self.alpha
175                        } else {
176                            kij
177                        }
178                    })
179                    .collect()
180            })
181            .collect();
182
183        // Solve (K + αI) α_coeff = y via Gaussian elimination with partial pivoting.
184        let mut rhs: Vec<f64> = y.to_vec();
185        gauss_elimination(&mut k_matrix, &mut rhs);
186        self.coeffs = rhs;
187    }
188
189    /// Predict `(mean, std)` for a query point `x`.
190    ///
191    /// The predictive std is estimated as a diagonal approximation.
192    fn predict_mean_std(&self, x: &[f64]) -> (f64, f64) {
193        let n = self.x_train.len();
194        if n == 0 || self.coeffs.len() != n {
195            // Uninformed prior: mean = 0, large std.
196            return (0.0, 1.0);
197        }
198
199        // k_vec[i] = k(x, x_train[i])
200        let k_vec: Vec<f64> = self.x_train.iter().map(|xi| self.rbf(x, xi)).collect();
201
202        // Mean: k_vec · coeffs
203        let mean: f64 = k_vec
204            .iter()
205            .zip(self.coeffs.iter())
206            .map(|(ki, ci)| ki * ci)
207            .sum();
208
209        // Predictive variance approximation:
210        // var ≈ k(x,x) - Σ_i k(x,xi)² / (k(xi,xi) + α)
211        let k_self = self.rbf(x, x); // = 1.0 for RBF
212
213        let var_approx: f64 = k_self
214            - k_vec
215                .iter()
216                .zip(self.x_train.iter())
217                .map(|(&kxi, xi)| {
218                    let kii = self.rbf(xi, xi) + self.alpha;
219                    kxi * kxi / kii.max(1e-12)
220                })
221                .sum::<f64>();
222
223        let std = var_approx.max(0.0).sqrt();
224        (mean, std)
225    }
226}
227
/// Solve the dense linear system `a · x = b` in place by Gaussian elimination
/// with partial pivoting.
///
/// Only the leading `n × n` part of `a` is used, where `n = b.len()`.  Both
/// arguments are overwritten: `a` with intermediate elimination results and
/// `b` with the solution `x`.
///
/// # Returns
/// - `true` when every pivot was usable (an empty system counts as solved).
/// - `false` when `a` has fewer than `n` rows or one of its first `n` rows has
///   fewer than `n` entries; nothing is modified in that case.
/// - `false` when a pivot magnitude is below `1e-14`.  The unknown for each
///   such column is set to `0.0` and the remaining unknowns are still solved
///   on a best-effort basis.
fn gauss_elimination(a: &mut [Vec<f64>], b: &mut [f64]) -> bool {
    /// Pivots with a smaller magnitude are treated as zero.
    const PIVOT_EPS: f64 = 1e-14;

    let n = b.len();
    if n == 0 {
        return true;
    }
    // Report a shape mismatch instead of panicking on an out-of-range index.
    if a.len() < n || a[..n].iter().any(|row| row.len() < n) {
        return false;
    }

    let mut all_pivots_usable = true;

    // Forward elimination.
    for col in 0..n {
        // Partial pivoting: take the row with the largest magnitude in `col`.
        let pivot_row = (col..n)
            .max_by(|&r1, &r2| {
                a[r1][col]
                    .abs()
                    .partial_cmp(&a[r2][col].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap_or(col);

        if a[pivot_row][col].abs() < PIVOT_EPS {
            // Near-singular column: leave it alone and record the failure so
            // the caller can tell the result is not a true solution.
            all_pivots_usable = false;
            continue;
        }

        a.swap(col, pivot_row);
        b.swap(col, pivot_row);

        let pivot = a[col][col];
        for row in (col + 1)..n {
            let factor = a[row][col] / pivot;
            b[row] -= factor * b[col];
            for k in col..n {
                a[row][k] -= factor * a[col][k];
            }
        }
    }

    // Back-substitution, last unknown first.
    for col in (0..n).rev() {
        let diag = a[col][col];
        if diag.abs() < PIVOT_EPS {
            // Column skipped above: pin its unknown to zero.
            b[col] = 0.0;
            continue;
        }
        for row in 0..col {
            let factor = a[row][col] / diag;
            b[row] -= factor * b[col];
        }
        b[col] /= diag;
    }

    all_pivots_usable
}
283
// ──────────────────────────────────────────────── PredictorNasSearcher ──
285
286/// Predictor-based NAS searcher.
287///
288/// Maintains a `RidgeSurrogate` that is retrained after each batch of true
289/// evaluations.
290pub struct PredictorNasSearcher {
291    config: PredictorNasConfig,
292    surrogate: RidgeSurrogate,
293    rng: Lcg,
294    evaluated_x: Vec<Vec<f64>>,
295    evaluated_y: Vec<f64>,
296}
297
298impl PredictorNasSearcher {
299    /// Construct a new searcher from the given config.
300    pub fn new(config: PredictorNasConfig) -> Self {
301        let rng = Lcg::new(config.seed);
302        Self {
303            surrogate: RidgeSurrogate::new(1e-3),
304            config,
305            rng,
306            evaluated_x: Vec::new(),
307            evaluated_y: Vec::new(),
308        }
309    }
310
311    /// Sample a random architecture — `[cell][node][predecessor]` indices.
312    fn sample_random_arch(&mut self) -> Vec<Vec<Vec<usize>>> {
313        let n_ops = self.config.n_operations;
314        (0..self.config.n_cells)
315            .map(|_| {
316                (0..self.config.n_nodes)
317                    .map(|i| {
318                        let n_predecessors = 2 + i; // 2 fixed input nodes
319                        (0..n_predecessors)
320                            .map(|_| {
321                                let raw = self.rng.next_f64();
322                                ((raw * n_ops as f64) as usize).min(n_ops - 1)
323                            })
324                            .collect()
325                    })
326                    .collect()
327            })
328            .collect()
329    }
330
331    /// Compute the UCB acquisition value for a query encoding.
332    fn ucb(&self, x: &[f64]) -> f64 {
333        let (mean, std) = self.surrogate.predict_mean_std(x);
334        mean + self.config.ucb_kappa * std
335    }
336
337    /// Compute the Expected Improvement acquisition value.
338    fn expected_improvement(&self, x: &[f64]) -> f64 {
339        let f_best = self
340            .evaluated_y
341            .iter()
342            .cloned()
343            .fold(f64::NEG_INFINITY, f64::max);
344        if f_best.is_infinite() {
345            return 0.0;
346        }
347        let (mean, std) = self.surrogate.predict_mean_std(x);
348        if std < 1e-12 {
349            return (mean - f_best).max(0.0);
350        }
351        let z = (mean - f_best) / std;
352        // EI = (mean - f_best) · Φ(z) + std · φ(z)
353        let phi_z = normal_cdf(z);
354        let pdf_z = normal_pdf(z);
355        (mean - f_best) * phi_z + std * pdf_z
356    }
357
358    /// Acquisition function value (dispatches to UCB or EI).
359    fn acquisition(&self, x: &[f64], strategy: &AcquisitionStrategy) -> f64 {
360        match strategy {
361            AcquisitionStrategy::Ucb => self.ucb(x),
362            AcquisitionStrategy::ExpectedImprovement => self.expected_improvement(x),
363        }
364    }
365
366    /// Evaluate an architecture with `eval_fn` and record the result.
367    fn evaluate_and_record(
368        &mut self,
369        arch: &[Vec<Vec<usize>>],
370        eval_fn: &impl Fn(&[Vec<Vec<usize>>]) -> f64,
371    ) -> f64 {
372        let score = eval_fn(arch);
373        let enc = encode_architecture(arch, self.config.n_operations);
374        self.evaluated_x.push(enc);
375        self.evaluated_y.push(score);
376        score
377    }
378
379    /// Refit the surrogate to all evaluated data.
380    fn refit_surrogate(&mut self) {
381        self.surrogate.fit(&self.evaluated_x, &self.evaluated_y);
382    }
383
384    /// Run the full predictor-based NAS search.
385    ///
386    /// Phase 1: evaluate `n_initial_samples` random architectures.
387    /// Phase 2: run `n_iterations` active-learning rounds.
388    ///
389    /// # Arguments
390    /// - `eval_fn`: True evaluation function.  Higher return value = better arch.
391    pub fn search(
392        &mut self,
393        eval_fn: impl Fn(&[Vec<Vec<usize>>]) -> f64,
394    ) -> OptimizeResult<PredictorNasResult> {
395        if self.config.n_initial_samples == 0 {
396            return Err(OptimizeError::InvalidInput(
397                "n_initial_samples must be > 0".to_string(),
398            ));
399        }
400
401        // ── Phase 1: warm-up with random samples ──────────────────────────────
402        for _ in 0..self.config.n_initial_samples {
403            let arch = self.sample_random_arch();
404            self.evaluate_and_record(&arch, &eval_fn);
405        }
406        self.refit_surrogate();
407
408        // ── Phase 2: active learning ──────────────────────────────────────────
409        let strategy = AcquisitionStrategy::Ucb;
410        for _ in 0..self.config.n_iterations {
411            // Generate candidate architectures.
412            let mut candidates: Vec<(f64, Vec<Vec<Vec<usize>>>)> =
413                (0..self.config.n_candidates_per_iter)
414                    .map(|_| {
415                        let arch = self.sample_random_arch();
416                        let enc = encode_architecture(&arch, self.config.n_operations);
417                        let acq = self.acquisition(&enc, &strategy);
418                        (acq, arch)
419                    })
420                    .collect();
421
422            // Sort by acquisition (descending).
423            candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
424
425            // Evaluate top-k.
426            let n_eval = self.config.n_top_to_evaluate.min(candidates.len());
427            for (_, arch) in candidates.into_iter().take(n_eval) {
428                self.evaluate_and_record(&arch, &eval_fn);
429            }
430
431            self.refit_surrogate();
432        }
433
434        // ── Find best ─────────────────────────────────────────────────────────
435        let best_idx = self
436            .evaluated_y
437            .iter()
438            .enumerate()
439            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
440            .map(|(i, _)| i)
441            .ok_or_else(|| {
442                OptimizeError::ComputationError("No architectures were evaluated".to_string())
443            })?;
444
445        let best_score = self.evaluated_y[best_idx];
446        let best_enc = &self.evaluated_x[best_idx];
447        // Decode back to arch indices: round to nearest integer index.
448        let norm = (self.config.n_operations.max(1) - 1) as f64;
449        let denom = norm.max(1.0);
450        let decoded_flat: Vec<usize> = best_enc
451            .iter()
452            .map(|&v| ((v * denom).round() as usize).min(self.config.n_operations - 1))
453            .collect();
454        let best_arch_indices =
455            reconstruct_arch_indices(&decoded_flat, self.config.n_cells, self.config.n_nodes);
456
457        Ok(PredictorNasResult {
458            best_arch_indices,
459            best_score,
460            n_evaluated: self.evaluated_y.len(),
461        })
462    }
463}
464
// ─────────────────────────────────────────────────── helpers ──
466
/// Rebuild the nested `[cell][node][predecessor]` layout from a flat list of
/// operation indices.
///
/// Node `i` of every cell consumes the next `2 + i` entries of `flat`.  When
/// fewer entries remain than a node needs, that node is filled with zeros;
/// the read position still advances by the node's width.
fn reconstruct_arch_indices(
    flat: &[usize],
    n_cells: usize,
    n_nodes: usize,
) -> Vec<Vec<Vec<usize>>> {
    let mut cursor = 0_usize;
    (0..n_cells)
        .map(|_| {
            (0..n_nodes)
                .map(|node| {
                    let width = 2 + node;
                    let end = cursor + width;
                    // `get` yields `None` exactly when the window runs past
                    // the end of `flat`.
                    let edges = match flat.get(cursor..end) {
                        Some(window) => window.to_vec(),
                        None => vec![0; width],
                    };
                    cursor = end;
                    edges
                })
                .collect::<Vec<_>>()
        })
        .collect()
}
491
492/// Standard normal CDF approximation (Abramowitz & Stegun 26.2.17).
493fn normal_cdf(x: f64) -> f64 {
494    let t = 1.0 / (1.0 + 0.2316419 * x.abs());
495    let poly = t
496        * (0.319_381_53
497            + t * (-0.356_563_782
498                + t * (1.781_477_937 + t * (-1.821_255_978 + t * 1.330_274_429))));
499    let pdf = normal_pdf(x);
500    let cdf_pos = 1.0 - pdf * poly;
501    if x >= 0.0 {
502        cdf_pos
503    } else {
504        1.0 - cdf_pos
505    }
506}
507
/// Density of the standard normal distribution, `exp(-x² / 2) / √(2π)`.
fn normal_pdf(x: f64) -> f64 {
    // TAU is 2π.
    let normaliser = std::f64::consts::TAU.sqrt();
    let exponent = -x * x / 2.0;
    exponent.exp() / normaliser
}
512
// ═══════════════════════════════════════════════════════════════════ tests ═══
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Objective shared by the search tests: the negated sum of all operation
    /// indices, so lower-index operations score higher.
    fn neg_index_sum(arch: &[Vec<Vec<usize>>]) -> f64 {
        let total: usize = arch.iter().flatten().flatten().sum();
        -(total as f64)
    }

    // ── encode_architecture ────────────────────────────────────────────────────

    #[test]
    fn test_encode_architecture_deterministic() {
        let arch: Vec<Vec<Vec<usize>>> = vec![vec![vec![0, 1], vec![2, 3, 0], vec![1, 0, 2, 1]]];
        let first = encode_architecture(&arch, 6);
        let second = encode_architecture(&arch, 6);
        assert_eq!(first, second);
    }

    #[test]
    fn test_encode_architecture_length() {
        // n_cells=2, n_nodes=4: edges = 2+3+4+5 = 14 per cell → 28 total.
        let cell: Vec<Vec<usize>> = (0..4_usize).map(|i| vec![0_usize; 2 + i]).collect();
        let arch: Vec<Vec<Vec<usize>>> = vec![cell.clone(), cell];
        let enc = encode_architecture(&arch, 6);
        assert_eq!(enc.len(), 28, "enc.len()={}", enc.len());
    }

    #[test]
    fn test_encode_architecture_range() {
        let arch: Vec<Vec<Vec<usize>>> = vec![vec![vec![0, 5], vec![3, 1, 5]]];
        for v in encode_architecture(&arch, 6) {
            assert!((0.0..=1.0).contains(&v), "v={v} out of [0,1]");
        }
    }

    #[test]
    fn test_encode_architecture_single_op() {
        // With n_operations=1 all encodings should be 0.0.
        let arch: Vec<Vec<Vec<usize>>> = vec![vec![vec![0, 0]]];
        for v in encode_architecture(&arch, 1) {
            assert!((v - 0.0).abs() < 1e-10, "v={v}");
        }
    }

    // ── RidgeSurrogate ─────────────────────────────────────────────────────────

    #[test]
    fn test_ridge_surrogate_predict_after_fit() {
        let inputs = vec![vec![0.0], vec![0.5], vec![1.0]];
        let targets = vec![0.0, 0.5, 1.0];
        let mut surr = RidgeSurrogate::new(1e-3);
        surr.fit(&inputs, &targets);
        let (mean, _std) = surr.predict_mean_std(&[0.25]);
        assert!(mean.is_finite(), "mean={mean}");
    }

    #[test]
    fn test_ridge_surrogate_empty_returns_prior() {
        let surr = RidgeSurrogate::new(1e-3);
        let (mean, std) = surr.predict_mean_std(&[0.5]);
        assert!((mean - 0.0).abs() < 1e-10, "mean={mean}");
        assert!((std - 1.0).abs() < 1e-10, "std={std}");
    }

    #[test]
    fn test_ridge_surrogate_std_nonneg() {
        let inputs: Vec<Vec<f64>> = (0..5).map(|i| vec![i as f64 / 4.0]).collect();
        let targets: Vec<f64> = (0..5).map(|i| i as f64).collect();
        let mut surr = RidgeSurrogate::new(1e-3);
        surr.fit(&inputs, &targets);
        for step in 0..10 {
            let query = vec![step as f64 / 10.0];
            let (_mean, std) = surr.predict_mean_std(&query);
            assert!(std >= 0.0, "std={std} at x={}", query[0]);
        }
    }

    // ── PredictorNasSearcher ────────────────────────────────────────────────────

    #[test]
    fn test_predictor_search_returns_result() {
        let config = PredictorNasConfig {
            n_cells: 2,
            n_nodes: 3,
            n_operations: 6,
            n_initial_samples: 4,
            n_iterations: 2,
            n_candidates_per_iter: 10,
            n_top_to_evaluate: 2,
            ..Default::default()
        };

        let result = PredictorNasSearcher::new(config)
            .search(neg_index_sum)
            .expect("search should succeed");

        assert!(
            result.best_score.is_finite(),
            "best_score={}",
            result.best_score
        );
        assert!(
            result.n_evaluated >= 4,
            "n_evaluated={}",
            result.n_evaluated
        );
    }

    #[test]
    fn test_active_learning_improves_best_score() {
        // Warm-up only: no active-learning rounds.
        let config_small = PredictorNasConfig {
            n_cells: 1,
            n_nodes: 2,
            n_operations: 6,
            n_initial_samples: 3,
            n_iterations: 0,
            n_candidates_per_iter: 5,
            n_top_to_evaluate: 1,
            seed: 7,
            ..Default::default()
        };
        let result_small = PredictorNasSearcher::new(config_small)
            .search(neg_index_sum)
            .expect("small search");

        // Same warm-up plus four active-learning rounds.
        let config_large = PredictorNasConfig {
            n_cells: 1,
            n_nodes: 2,
            n_operations: 6,
            n_initial_samples: 3,
            n_iterations: 4,
            n_candidates_per_iter: 10,
            n_top_to_evaluate: 2,
            seed: 7,
            ..Default::default()
        };
        let result_large = PredictorNasSearcher::new(config_large)
            .search(neg_index_sum)
            .expect("large search");

        // Larger budget should evaluate more architectures.
        assert!(
            result_large.n_evaluated >= result_small.n_evaluated,
            "large n_eval={} < small n_eval={}",
            result_large.n_evaluated,
            result_small.n_evaluated
        );
        assert!(result_small.best_score.is_finite());
        assert!(result_large.best_score.is_finite());
    }

    #[test]
    fn test_predictor_n_evaluated_count() {
        let config = PredictorNasConfig {
            n_initial_samples: 5,
            n_iterations: 3,
            n_top_to_evaluate: 2,
            n_candidates_per_iter: 10,
            n_cells: 2,
            n_nodes: 3,
            n_operations: 6,
            ..Default::default()
        };
        let expected_min = 5 + 3 * 2; // initial + iterations * top_k

        let result = PredictorNasSearcher::new(config)
            .search(|_| 1.0)
            .expect("search should not fail");

        assert!(
            result.n_evaluated >= expected_min,
            "n_evaluated={} < expected_min={expected_min}",
            result.n_evaluated
        );
    }

    #[test]
    fn test_predictor_zero_iterations_still_works() {
        let config = PredictorNasConfig {
            n_initial_samples: 3,
            n_iterations: 0,
            ..Default::default()
        };
        let result = PredictorNasSearcher::new(config)
            .search(|_| 42.0)
            .expect("zero-iteration search should succeed");
        assert_eq!(result.best_score, 42.0);
    }

    // ── numeric helpers ────────────────────────────────────────────────────────

    #[test]
    fn test_normal_cdf_basic() {
        // Φ(0) ≈ 0.5
        assert!((normal_cdf(0.0) - 0.5).abs() < 0.01);
        // Φ(∞) ≈ 1
        assert!(normal_cdf(10.0) > 0.999);
        // Φ(-∞) ≈ 0
        assert!(normal_cdf(-10.0) < 0.001);
    }

    #[test]
    fn test_gauss_elimination_simple() {
        // Solve 2x = 4 → x = 2.
        let mut matrix = vec![vec![2.0_f64]];
        let mut rhs = vec![4.0_f64];
        gauss_elimination(&mut matrix, &mut rhs);
        assert!((rhs[0] - 2.0).abs() < 1e-10, "b[0]={}", rhs[0]);
    }

    #[test]
    fn test_gauss_elimination_2x2() {
        // [[1,2],[3,4]] x = [5, 11] → x = [1, 2]
        let mut matrix = vec![vec![1.0_f64, 2.0], vec![3.0, 4.0]];
        let mut rhs = vec![5.0_f64, 11.0];
        gauss_elimination(&mut matrix, &mut rhs);
        assert!((rhs[0] - 1.0).abs() < 1e-9, "b[0]={}", rhs[0]);
        assert!((rhs[1] - 2.0).abs() < 1e-9, "b[1]={}", rhs[1]);
    }
}