1use std::collections::HashSet;
14
15use ordered_float::OrderedFloat;
16use rand::RngExt as _;
17use rand::SeedableRng as _;
18use rand::rngs::SmallRng;
19
20use super::error::EvalError;
21use super::generator::VariationGenerator;
22use super::search_space::SearchSpace;
23use super::snapshot::ConfigSnapshot;
24use super::types::{Variation, VariationValue};
25
/// Upper bound on sampling attempts made by a single call to `next` before it
/// gives up and reports that no new variation was found.
const MAX_RETRIES: usize = 1_000;

/// Number of equal-width steps the `[min, max]` interval of a parameter is
/// divided into when that parameter declares no explicit step size.
const DEFAULT_STEPS: f64 = 20.0;
34
35pub struct Neighborhood {
64 search_space: SearchSpace,
65 radius: f64,
66 rng: SmallRng,
67}
68
69impl Neighborhood {
70 pub fn new(search_space: SearchSpace, radius: f64, seed: u64) -> Result<Self, EvalError> {
76 if !radius.is_finite() || radius <= 0.0 {
77 return Err(EvalError::InvalidRadius { radius });
78 }
79 Ok(Self {
80 search_space,
81 radius,
82 rng: SmallRng::seed_from_u64(seed),
83 })
84 }
85}
86
87impl VariationGenerator for Neighborhood {
88 fn next(
89 &mut self,
90 baseline: &ConfigSnapshot,
91 visited: &HashSet<Variation>,
92 ) -> Option<Variation> {
93 if self.search_space.parameters.is_empty() {
94 return None;
95 }
96 for _ in 0..MAX_RETRIES {
97 let idx = self.rng.random_range(0..self.search_space.parameters.len());
98 let range = &self.search_space.parameters[idx];
99 let current = baseline.get(range.kind());
100 let step = range
102 .step()
103 .unwrap_or_else(|| (range.max() - range.min()) / DEFAULT_STEPS);
104 let delta = self.rng.random_range(-self.radius..=self.radius) * step;
105 if delta.abs() < f64::EPSILON {
107 continue;
108 }
109 let raw = current + delta;
110 let value = range.quantize(range.clamp(raw));
111 if (value - current).abs() < f64::EPSILON {
113 continue;
114 }
115 let variation = Variation {
116 parameter: range.kind(),
117 value: VariationValue::Float(OrderedFloat(value)),
118 };
119 if !visited.contains(&variation) {
120 return Some(variation);
121 }
122 }
123 None
124 }
125
126 fn name(&self) -> &'static str {
127 "neighborhood"
128 }
129}
130
131#[cfg(test)]
132mod tests {
133 #![allow(
134 clippy::collapsible_if,
135 clippy::field_reassign_with_default,
136 clippy::manual_midpoint,
137 clippy::manual_range_contains
138 )]
139
140 use std::collections::HashSet;
141
142 use super::super::search_space::ParameterRange;
143 use super::super::types::ParameterKind;
144 use super::*;
145
146 fn make_space(kind: ParameterKind, min: f64, max: f64, step: f64) -> SearchSpace {
147 SearchSpace {
148 parameters: vec![
149 ParameterRange::new(kind, min, max, Some(step), f64::midpoint(min, max)).unwrap(),
150 ],
151 }
152 }
153
154 #[test]
155 fn neighborhood_produces_values_in_range() {
156 let space = make_space(ParameterKind::Temperature, 0.0, 2.0, 0.1);
157 let mut generator = Neighborhood::new(space, 1.0, 42).unwrap();
158 let baseline = ConfigSnapshot::default();
159 let visited = HashSet::new();
160 for _ in 0..20 {
161 if let Some(v) = generator.next(&baseline, &visited) {
162 let val = v.value.as_f64();
163 assert!((0.0..=2.0).contains(&val), "out of range: {val}");
164 }
165 }
166 }
167
168 #[test]
169 fn neighborhood_is_deterministic_with_same_seed() {
170 let space = SearchSpace::default();
171 let baseline = ConfigSnapshot::default();
172 let visited = HashSet::new();
173 let mut gen1 = Neighborhood::new(space.clone(), 1.0, 99).unwrap();
174 let mut gen2 = Neighborhood::new(space, 1.0, 99).unwrap();
175 let v1 = gen1.next(&baseline, &visited);
176 let v2 = gen2.next(&baseline, &visited);
177 assert_eq!(v1, v2, "same seed must produce same first variation");
178 }
179
180 #[test]
181 fn neighborhood_skips_visited() {
182 let space = make_space(ParameterKind::Temperature, 0.5, 0.6, 0.1);
185 let mut generator = Neighborhood::new(space, 1.0, 0).unwrap();
186 let baseline = ConfigSnapshot::default();
187 let mut visited = HashSet::new();
188 visited.insert(Variation {
189 parameter: ParameterKind::Temperature,
190 value: VariationValue::Float(OrderedFloat(0.5)),
191 });
192 visited.insert(Variation {
193 parameter: ParameterKind::Temperature,
194 value: VariationValue::Float(OrderedFloat(0.6)),
195 });
196 assert!(generator.next(&baseline, &visited).is_none());
197 }
198
199 #[test]
200 fn neighborhood_empty_space_returns_none() {
201 let mut generator = Neighborhood::new(SearchSpace { parameters: vec![] }, 1.0, 0).unwrap();
202 let baseline = ConfigSnapshot::default();
203 let visited = HashSet::new();
204 assert!(generator.next(&baseline, &visited).is_none());
205 }
206
207 #[test]
208 fn neighborhood_zero_radius_returns_error() {
209 let result = Neighborhood::new(SearchSpace::default(), 0.0, 0);
210 assert!(result.is_err(), "zero radius must be rejected");
211 }
212
213 #[test]
214 fn neighborhood_negative_radius_returns_error() {
215 let result = Neighborhood::new(SearchSpace::default(), -1.0, 0);
216 assert!(result.is_err(), "negative radius must be rejected");
217 }
218
219 #[test]
220 fn neighborhood_nan_radius_returns_error() {
221 let result = Neighborhood::new(SearchSpace::default(), f64::NAN, 0);
222 assert!(result.is_err(), "NaN radius must be rejected");
223 }
224
225 #[test]
226 fn neighborhood_step_none_uses_default_steps() {
227 let space = SearchSpace {
229 parameters: vec![
230 super::super::search_space::ParameterRange::new(
231 ParameterKind::Temperature,
232 0.0,
233 2.0,
234 None,
235 1.0,
236 )
237 .unwrap(),
238 ],
239 };
240 let mut generator = Neighborhood::new(space, 1.0, 77).unwrap();
241 let baseline = ConfigSnapshot::default();
242 let visited = HashSet::new();
243 let mut got_any = false;
245 for _ in 0..50 {
246 if generator.next(&baseline, &visited).is_some() {
247 got_any = true;
248 break;
249 }
250 }
251 assert!(
252 got_any,
253 "should produce at least one variation for continuous parameter"
254 );
255 }
256
257 #[test]
258 fn neighborhood_quantizes_perturbed_values() {
259 let space = make_space(ParameterKind::TopP, 0.1, 1.0, 0.05);
260 let mut generator = Neighborhood::new(space, 2.0, 11).unwrap();
261 let mut baseline = ConfigSnapshot::default();
262 baseline.top_p = 0.5;
263 let visited = HashSet::new();
264 for _ in 0..30 {
265 if let Some(v) = generator.next(&baseline, &visited) {
266 let val = v.value.as_f64();
267 let steps = (val - 0.1) / 0.05;
270 assert!(
271 (steps - steps.round()).abs() < 1e-10,
272 "value {val} is not on the 0.05-step grid anchored at 0.1"
273 );
274 }
275 }
276 }
277
278 #[test]
279 fn neighborhood_name() {
280 let generator = Neighborhood::new(SearchSpace::default(), 1.0, 0).unwrap();
281 assert_eq!(generator.name(), "neighborhood");
282 }
283
284 #[test]
285 fn neighborhood_perturbs_around_baseline() {
286 let space = make_space(ParameterKind::Temperature, 0.0, 2.0, 0.1);
289 let mut generator = Neighborhood::new(space, 1.0, 55).unwrap();
290 let baseline = ConfigSnapshot::default(); let visited = HashSet::new();
292 let mut temp_values = vec![];
293 for _ in 0..50 {
294 if let Some(v) = generator.next(&baseline, &visited)
295 && v.parameter == ParameterKind::Temperature
296 {
297 temp_values.push(v.value.as_f64());
298 }
299 }
300 assert!(
301 !temp_values.is_empty(),
302 "should produce temperature variations"
303 );
304 for val in &temp_values {
306 assert!(
307 *val >= 0.6 - 1e-10 && *val <= 0.8 + 1e-10,
308 "value {val} not within ±1 step of 0.7"
309 );
310 }
311 }
312}