zeph_experiments/
random.rs1use std::collections::HashSet;
14use std::sync::Mutex;
15
16use ordered_float::OrderedFloat;
17use rand::RngExt as _;
18use rand::SeedableRng as _;
19use rand::rngs::SmallRng;
20
21use super::generator::VariationGenerator;
22use super::search_space::SearchSpace;
23use super::snapshot::ConfigSnapshot;
24use super::types::{Variation, VariationValue};
25
/// Maximum number of candidate variations sampled in a single `next` call
/// before the generator gives up and reports `None`.
const MAX_RETRIES: usize = 1_000;
28
29pub struct Random {
61 search_space: SearchSpace,
62 rng: Mutex<SmallRng>,
63}
64
65impl Random {
66 #[must_use]
80 pub fn new(search_space: SearchSpace, seed: u64) -> Self {
81 Self {
82 search_space,
83 rng: Mutex::new(SmallRng::seed_from_u64(seed)),
84 }
85 }
86}
87
88impl VariationGenerator for Random {
89 fn next(
90 &mut self,
91 _baseline: &ConfigSnapshot,
92 visited: &HashSet<Variation>,
93 ) -> Option<Variation> {
94 if self.search_space.parameters.is_empty() {
95 return None;
96 }
97 let mut rng = self.rng.lock().expect("rng mutex poisoned");
98 for _ in 0..MAX_RETRIES {
99 let idx = rng.random_range(0..self.search_space.parameters.len());
100 let range = &self.search_space.parameters[idx];
101 let raw: f64 = rng.random_range(range.min()..=range.max());
102 let value = range.quantize(raw);
103 let variation = Variation {
104 parameter: range.kind(),
105 value: VariationValue::Float(OrderedFloat(value)),
106 };
107 if !visited.contains(&variation) {
108 return Some(variation);
109 }
110 }
111 None
112 }
113
114 fn name(&self) -> &'static str {
115 "random"
116 }
117}
118
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::search_space::ParameterRange;
    use super::super::types::ParameterKind;
    use super::*;

    #[test]
    fn random_produces_values_in_range() {
        let space = SearchSpace {
            parameters: vec![
                ParameterRange::new(ParameterKind::Temperature, 0.0, 1.0, Some(0.1), 0.5).unwrap(),
            ],
        };
        let mut generator = Random::new(space, 42);
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        for _ in 0..20 {
            // Nothing is visited, so every call must produce a candidate; a
            // `None` here would otherwise make the range check vacuous.
            let v = generator
                .next(&baseline, &visited)
                .expect("non-empty space with nothing visited must yield a variation");
            let val = v.value.as_f64();
            assert!((0.0..=1.0).contains(&val), "out of range: {val}");
        }
    }

    #[test]
    fn random_skips_visited() {
        // A step wider than the range leaves a single reachable grid value.
        let space = SearchSpace {
            parameters: vec![
                ParameterRange::new(ParameterKind::Temperature, 0.5, 0.6, Some(1.0), 0.55).unwrap(),
            ],
        };
        let mut generator = Random::new(space, 0);
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        visited.insert(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.5)),
        });
        let result = generator.next(&baseline, &visited);
        assert!(
            result.is_none(),
            "expected None when only option is already visited"
        );
    }

    #[test]
    fn random_empty_space_returns_none() {
        let mut generator = Random::new(SearchSpace { parameters: vec![] }, 0);
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn random_is_deterministic_with_same_seed() {
        let space = SearchSpace::default();
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        let mut gen1 = Random::new(space.clone(), 123);
        let mut gen2 = Random::new(space, 123);
        // Compare a run of draws, not just the first, so divergence in the
        // RNG stream after the initial sample is also caught.
        for step in 0..10 {
            let v1 = gen1.next(&baseline, &visited);
            let v2 = gen2.next(&baseline, &visited);
            assert_eq!(
                v1, v2,
                "same seed must produce the same variation at step {step}"
            );
        }
    }

    #[test]
    fn random_quantizes_sampled_values() {
        let space = SearchSpace {
            parameters: vec![
                ParameterRange::new(ParameterKind::TopP, 0.1, 1.0, Some(0.05), 0.9).unwrap(),
            ],
        };
        let mut generator = Random::new(space, 7);
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        for _ in 0..30 {
            let v = generator
                .next(&baseline, &visited)
                .expect("non-empty space with nothing visited must yield a variation");
            let val = v.value.as_f64();
            let steps = (val - 0.1) / 0.05;
            assert!(
                (steps - steps.round()).abs() < 1e-10,
                "value {val} is not on the 0.05-step grid anchored at 0.1"
            );
        }
    }

    #[test]
    fn random_name() {
        let generator = Random::new(SearchSpace::default(), 0);
        assert_eq!(generator.name(), "random");
    }

    #[test]
    fn random_is_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<Random>();
    }
}