zeph_experiments/
random.rs1use std::collections::HashSet;
14use std::sync::Mutex;
15
16use ordered_float::OrderedFloat;
17use rand::RngExt as _;
18use rand::SeedableRng as _;
19use rand::rngs::SmallRng;
20
21use super::generator::VariationGenerator;
22use super::search_space::SearchSpace;
23use super::snapshot::ConfigSnapshot;
24use super::types::{Variation, VariationValue};
25
/// Maximum number of rejection-sampling attempts made by a single
/// `Random::next` call before it gives up on finding an unvisited variation.
const MAX_RETRIES: usize = 1_000;
28
29pub struct Random {
61 search_space: SearchSpace,
62 rng: Mutex<SmallRng>,
63}
64
65impl Random {
66 #[must_use]
80 pub fn new(search_space: SearchSpace, seed: u64) -> Self {
81 Self {
82 search_space,
83 rng: Mutex::new(SmallRng::seed_from_u64(seed)),
84 }
85 }
86}
87
88impl VariationGenerator for Random {
89 fn next(
90 &mut self,
91 _baseline: &ConfigSnapshot,
92 visited: &HashSet<Variation>,
93 ) -> Option<Variation> {
94 if self.search_space.parameters.is_empty() {
95 return None;
96 }
97 let mut rng = self.rng.lock().expect("rng mutex poisoned");
98 for _ in 0..MAX_RETRIES {
99 let idx = rng.random_range(0..self.search_space.parameters.len());
100 let range = &self.search_space.parameters[idx];
101 let raw: f64 = rng.random_range(range.min..=range.max);
102 let value = range.quantize(raw);
103 let variation = Variation {
104 parameter: range.kind,
105 value: VariationValue::Float(OrderedFloat(value)),
106 };
107 if !visited.contains(&variation) {
108 return Some(variation);
109 }
110 }
111 None
112 }
113
114 fn name(&self) -> &'static str {
115 "random"
116 }
117}
118
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::search_space::ParameterRange;
    use super::super::types::ParameterKind;
    use super::*;

    #[test]
    fn random_produces_values_in_range() {
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min: 0.0,
                max: 1.0,
                step: Some(0.1),
                default: 0.5,
            }],
        };
        let mut generator = Random::new(space, 42);
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        for _ in 0..20 {
            // Nothing is visited, so the first draw is always accepted. A `None`
            // here would be a bug, not a reason to skip the assertion.
            let v = generator
                .next(&baseline, &visited)
                .expect("non-empty space with nothing visited must yield a variation");
            assert!(matches!(v.parameter, ParameterKind::Temperature));
            let val = v.value.as_f64();
            assert!((0.0..=1.0).contains(&val), "out of range: {val}");
        }
    }

    #[test]
    fn random_skips_visited() {
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min: 0.5,
                max: 0.5,
                step: Some(0.1),
                default: 0.5,
            }],
        };
        let mut generator = Random::new(space, 0);
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        visited.insert(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.5)),
        });
        let result = generator.next(&baseline, &visited);
        assert!(
            result.is_none(),
            "expected None when only option is already visited"
        );
    }

    #[test]
    fn random_empty_space_returns_none() {
        let mut generator = Random::new(SearchSpace { parameters: vec![] }, 0);
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn random_is_deterministic_with_same_seed() {
        let space = SearchSpace::default();
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        let mut gen1 = Random::new(space.clone(), 123);
        let mut gen2 = Random::new(space, 123);
        // Compare a run of draws, not just the first, so RNG state evolution is covered too.
        for step in 0..10 {
            let v1 = gen1.next(&baseline, &visited);
            let v2 = gen2.next(&baseline, &visited);
            assert_eq!(v1, v2, "same seed must produce same variation at step {step}");
        }
    }

    #[test]
    fn random_quantizes_sampled_values() {
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::TopP,
                min: 0.1,
                max: 1.0,
                step: Some(0.05),
                default: 0.9,
            }],
        };
        let mut generator = Random::new(space, 7);
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        for _ in 0..30 {
            let v = generator
                .next(&baseline, &visited)
                .expect("non-empty space with nothing visited must yield a variation");
            let val = v.value.as_f64();
            // Values must land on the grid min + k * step.
            let steps = (val - 0.1) / 0.05;
            assert!(
                (steps - steps.round()).abs() < 1e-10,
                "value {val} is not on the 0.05-step grid anchored at 0.1"
            );
        }
    }

    #[test]
    fn random_name() {
        let generator = Random::new(SearchSpace::default(), 0);
        assert_eq!(generator.name(), "random");
    }

    #[test]
    fn random_is_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<Random>();
    }
}