Skip to main content

tensorlogic_train/nas/
random_search.rs

1//! Random architecture search for NAS.
2//!
3//! A simple baseline: architectures are sampled uniformly at random from the
4//! [`ArchSearchSpace`] with no model of performance.  Useful as a lower-bound
5//! comparison against more sophisticated strategies such as
6//! [`RegularizedEvolution`](super::evolution::RegularizedEvolution).
7
8use crate::error::TrainResult;
9
10use super::evolution::NasResult;
11use super::sampler::ArchSampler;
12use super::space::{ArchSearchSpace, Architecture};
13
14// ─── RandomArchSearch ───────────────────────────────────────────────────────
15
16/// Random-search NAS: samples architectures uniformly at random, keeps track
17/// of the best seen so far.
18pub struct RandomArchSearch {
19    sampler: ArchSampler,
20    best: Option<(Architecture, f64)>,
21    history: Vec<(Architecture, f64)>,
22}
23
24impl RandomArchSearch {
25    /// Create a new random architecture searcher.
26    ///
27    /// # Arguments
28    ///
29    /// * `space` - Architecture search space to sample from.
30    /// * `seed` - RNG seed for reproducibility.
31    pub fn new(space: ArchSearchSpace, seed: u64) -> Self {
32        Self {
33            sampler: ArchSampler::new(space, seed),
34            best: None,
35            history: Vec::new(),
36        }
37    }
38
39    /// Ask for the next architecture to evaluate.
40    ///
41    /// Always returns a fresh uniformly random architecture independent of
42    /// previous results.
43    pub fn ask(&mut self) -> TrainResult<Architecture> {
44        self.sampler.random_architecture()
45    }
46
47    /// Tell the result of evaluating an architecture.
48    ///
49    /// Updates the running best when `score` exceeds the current best.
50    pub fn tell(&mut self, arch: Architecture, score: f64) {
51        let is_better = self
52            .best
53            .as_ref()
54            .is_none_or(|(_, best_score)| score > *best_score);
55
56        if is_better {
57            self.best = Some((arch.clone(), score));
58        }
59        self.history.push((arch, score));
60    }
61
62    /// Return a reference to the best (architecture, score) pair seen so far,
63    /// or `None` if no evaluations have been recorded.
64    pub fn best(&self) -> Option<&(Architecture, f64)> {
65        self.best.as_ref()
66    }
67
68    /// Produce a [`NasResult`] summarising the current search state.
69    ///
70    /// Returns `None` if no evaluations have been recorded yet.
71    pub fn result(&self) -> Option<NasResult> {
72        let (best_arch, best_score) = self.best.as_ref()?;
73        Some(NasResult {
74            best: best_arch.clone(),
75            best_score: *best_score,
76            history: self.history.clone(),
77        })
78    }
79}