Skip to main content

zeph_experiments/
random.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Uniform random sampling strategy for parameter variation.
5
6use std::collections::HashSet;
7use std::sync::Mutex;
8
9use ordered_float::OrderedFloat;
10use rand::Rng as _;
11use rand::SeedableRng as _;
12use rand::rngs::SmallRng;
13
14use super::generator::VariationGenerator;
15use super::search_space::SearchSpace;
16use super::snapshot::ConfigSnapshot;
17use super::types::{Variation, VariationValue};
18
/// Upper bound on consecutive rejected samples; once this many draws in a row
/// hit already-visited points, the search space is treated as exhausted.
const MAX_RETRIES: usize = 1_000;
21
22/// Uniform random sampling within parameter bounds.
23///
24/// At each call, a parameter is chosen uniformly at random, then a value is
25/// sampled uniformly from its `[min, max]` range and quantized to the nearest
26/// step (if configured). The sample is rejected if it was already visited.
27/// Returns `None` after `MAX_RETRIES` consecutive rejections.
28///
29/// `rng` is wrapped in a [`Mutex`] so that `Random` implements [`Sync`], which is
30/// required by [`VariationGenerator`] to allow [`ExperimentEngine`] to be used
31/// with `tokio::spawn`. The mutex is only ever locked from a single thread
32/// (the experiment loop is sequential), so there is no contention.
33pub struct Random {
34    search_space: SearchSpace,
35    rng: Mutex<SmallRng>,
36}
37
38impl Random {
39    /// Create a new `Random` generator with a deterministic seed.
40    #[must_use]
41    pub fn new(search_space: SearchSpace, seed: u64) -> Self {
42        Self {
43            search_space,
44            rng: Mutex::new(SmallRng::seed_from_u64(seed)),
45        }
46    }
47}
48
49impl VariationGenerator for Random {
50    fn next(
51        &mut self,
52        _baseline: &ConfigSnapshot,
53        visited: &HashSet<Variation>,
54    ) -> Option<Variation> {
55        if self.search_space.parameters.is_empty() {
56            return None;
57        }
58        let mut rng = self.rng.lock().expect("rng mutex poisoned");
59        for _ in 0..MAX_RETRIES {
60            let idx = rng.gen_range(0..self.search_space.parameters.len());
61            let range = &self.search_space.parameters[idx];
62            let raw: f64 = rng.gen_range(range.min..=range.max);
63            let value = range.quantize(raw);
64            let variation = Variation {
65                parameter: range.kind,
66                value: VariationValue::Float(OrderedFloat(value)),
67            };
68            if !visited.contains(&variation) {
69                return Some(variation);
70            }
71        }
72        None
73    }
74
75    fn name(&self) -> &'static str {
76        "random"
77    }
78}
79
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::search_space::ParameterRange;
    use super::super::types::ParameterKind;
    use super::*;

    #[test]
    fn random_produces_values_in_range() {
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min: 0.0,
                max: 1.0,
                step: Some(0.1),
                default: 0.5,
            }],
        };
        let mut generator = Random::new(space, 42);
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        for _ in 0..20 {
            // With nothing visited the first sample is always accepted, so a
            // `None` here would be a bug, not an exhausted space.
            let v = generator
                .next(&baseline, &visited)
                .expect("non-empty space with nothing visited must yield a variation");
            let val = v.value.as_f64();
            assert!((0.0..=1.0).contains(&val), "out of range: {val}");
        }
    }

    #[test]
    fn random_skips_visited() {
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min: 0.5,
                max: 0.5,
                step: Some(0.1),
                default: 0.5,
            }],
        };
        let mut generator = Random::new(space, 0);
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        visited.insert(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.5)),
        });
        // Only one point in space (min==max==0.5), so after visiting it, must return None.
        let result = generator.next(&baseline, &visited);
        assert!(
            result.is_none(),
            "expected None when only option is already visited"
        );
    }

    #[test]
    fn random_empty_space_returns_none() {
        let mut generator = Random::new(SearchSpace { parameters: vec![] }, 0);
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn random_is_deterministic_with_same_seed() {
        let space = SearchSpace::default();
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        let mut gen1 = Random::new(space.clone(), 123);
        let mut gen2 = Random::new(space, 123);
        // Compare a short sequence rather than only the first draw, so any
        // divergence that appears once the RNG has advanced is caught too.
        for step in 0..10 {
            let v1 = gen1.next(&baseline, &visited);
            let v2 = gen2.next(&baseline, &visited);
            assert_eq!(
                v1, v2,
                "same seed must produce the same variation at step {step}"
            );
        }
    }

    #[test]
    fn random_quantizes_sampled_values() {
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::TopP,
                min: 0.1,
                max: 1.0,
                step: Some(0.05),
                default: 0.9,
            }],
        };
        let mut generator = Random::new(space, 7);
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        for _ in 0..30 {
            let v = generator
                .next(&baseline, &visited)
                .expect("non-empty space with nothing visited must yield a variation");
            let val = v.value.as_f64();
            // Quantized values must be on the 0.05-step grid anchored at min=0.1:
            // i.e. (val - 0.1) / 0.05 must be an integer.
            let steps = (val - 0.1) / 0.05;
            assert!(
                (steps - steps.round()).abs() < 1e-10,
                "value {val} is not on the 0.05-step grid anchored at 0.1"
            );
        }
    }

    #[test]
    fn random_name() {
        let generator = Random::new(SearchSpace::default(), 0);
        assert_eq!(generator.name(), "random");
    }

    #[test]
    fn random_is_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<Random>();
    }
}