Skip to main content

zeph_experiments/
neighborhood.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Neighborhood perturbation strategy for parameter variation.
5//!
6//! [`Neighborhood`] is a local-search strategy that generates variations by
7//! perturbing the current baseline value of a randomly chosen parameter by a
8//! small random amount proportional to the configured `radius`. It is most
9//! effective after a coarse [`GridStep`] sweep has identified a promising region.
10//!
11//! [`GridStep`]: crate::GridStep
12
13use std::collections::HashSet;
14
15use ordered_float::OrderedFloat;
16use rand::RngExt as _;
17use rand::SeedableRng as _;
18use rand::rngs::SmallRng;
19
20use super::error::EvalError;
21use super::generator::VariationGenerator;
22use super::search_space::SearchSpace;
23use super::snapshot::ConfigSnapshot;
24use super::types::{Variation, VariationValue};
25
/// Upper bound on sampling attempts made by a single `next` call; once it is
/// reached without finding a fresh candidate, the neighborhood is treated as
/// exhausted.
const MAX_RETRIES: usize = 1_000;

/// Number of subdivisions of `[min, max]` used to derive a perturbation step
/// for parameters that declare no step of their own (continuous parameters).
///
/// Twenty subdivisions give reasonable granularity without forcing every
/// search-space entry to spell out an explicit step.
const DEFAULT_STEPS: f64 = 20.0;
34
35/// Perturbation strategy that explores the neighborhood of the current baseline.
36///
37/// At each call, a parameter is chosen uniformly at random. The new value is
38/// computed as `baseline_value ± U(-radius, radius) * step`, then clamped and
39/// quantized to the nearest grid step. Useful after a [`GridStep`] sweep has
40/// narrowed the search to a promising region.
41///
42/// The generator is seeded deterministically via `seed`, making experiments
43/// reproducible. `radius` must be finite and positive (enforced in [`Neighborhood::new`]).
44///
45/// # Examples
46///
47/// ```rust
48/// use std::collections::HashSet;
49/// use zeph_experiments::{ConfigSnapshot, Neighborhood, SearchSpace, VariationGenerator};
50///
51/// let mut generator = Neighborhood::new(SearchSpace::default(), 1.0, 42).unwrap();
52/// let baseline = ConfigSnapshot::default();
53/// let visited = HashSet::new();
54///
55/// // Each call perturbs a random parameter by a small amount.
56/// if let Some(v) = generator.next(&baseline, &visited) {
57///     let val = v.value.as_f64();
58///     assert!(val.is_finite());
59/// }
60/// ```
61///
62/// [`GridStep`]: crate::GridStep
63pub struct Neighborhood {
64    search_space: SearchSpace,
65    radius: f64,
66    rng: SmallRng,
67}
68
69impl Neighborhood {
70    /// Create a new `Neighborhood` generator.
71    ///
72    /// # Errors
73    ///
74    /// Returns [`EvalError::InvalidRadius`] if `radius` is not finite and positive.
75    pub fn new(search_space: SearchSpace, radius: f64, seed: u64) -> Result<Self, EvalError> {
76        if !radius.is_finite() || radius <= 0.0 {
77            return Err(EvalError::InvalidRadius { radius });
78        }
79        Ok(Self {
80            search_space,
81            radius,
82            rng: SmallRng::seed_from_u64(seed),
83        })
84    }
85}
86
87impl VariationGenerator for Neighborhood {
88    fn next(
89        &mut self,
90        baseline: &ConfigSnapshot,
91        visited: &HashSet<Variation>,
92    ) -> Option<Variation> {
93        if self.search_space.parameters.is_empty() {
94            return None;
95        }
96        for _ in 0..MAX_RETRIES {
97            let idx = self.rng.random_range(0..self.search_space.parameters.len());
98            let range = &self.search_space.parameters[idx];
99            let current = baseline.get(range.kind());
100            // DEFAULT_STEPS is used when step is None (continuous parameter).
101            let step = range
102                .step()
103                .unwrap_or_else(|| (range.max() - range.min()) / DEFAULT_STEPS);
104            let delta = self.rng.random_range(-self.radius..=self.radius) * step;
105            // Skip zero perturbations — they produce the baseline value, wasting an attempt.
106            if delta.abs() < f64::EPSILON {
107                continue;
108            }
109            let raw = current + delta;
110            let value = range.quantize(range.clamp(raw));
111            // Skip if the quantized value equals the baseline (no effective change).
112            if (value - current).abs() < f64::EPSILON {
113                continue;
114            }
115            let variation = Variation {
116                parameter: range.kind(),
117                value: VariationValue::Float(OrderedFloat(value)),
118            };
119            if !visited.contains(&variation) {
120                return Some(variation);
121            }
122        }
123        None
124    }
125
126    fn name(&self) -> &'static str {
127        "neighborhood"
128    }
129}
130
131#[cfg(test)]
132mod tests {
133    #![allow(
134        clippy::collapsible_if,
135        clippy::field_reassign_with_default,
136        clippy::manual_midpoint,
137        clippy::manual_range_contains
138    )]
139
140    use std::collections::HashSet;
141
142    use super::super::search_space::ParameterRange;
143    use super::super::types::ParameterKind;
144    use super::*;
145
146    fn make_space(kind: ParameterKind, min: f64, max: f64, step: f64) -> SearchSpace {
147        SearchSpace {
148            parameters: vec![
149                ParameterRange::new(kind, min, max, Some(step), f64::midpoint(min, max)).unwrap(),
150            ],
151        }
152    }
153
154    #[test]
155    fn neighborhood_produces_values_in_range() {
156        let space = make_space(ParameterKind::Temperature, 0.0, 2.0, 0.1);
157        let mut generator = Neighborhood::new(space, 1.0, 42).unwrap();
158        let baseline = ConfigSnapshot::default();
159        let visited = HashSet::new();
160        for _ in 0..20 {
161            if let Some(v) = generator.next(&baseline, &visited) {
162                let val = v.value.as_f64();
163                assert!((0.0..=2.0).contains(&val), "out of range: {val}");
164            }
165        }
166    }
167
168    #[test]
169    fn neighborhood_is_deterministic_with_same_seed() {
170        let space = SearchSpace::default();
171        let baseline = ConfigSnapshot::default();
172        let visited = HashSet::new();
173        let mut gen1 = Neighborhood::new(space.clone(), 1.0, 99).unwrap();
174        let mut gen2 = Neighborhood::new(space, 1.0, 99).unwrap();
175        let v1 = gen1.next(&baseline, &visited);
176        let v2 = gen2.next(&baseline, &visited);
177        assert_eq!(v1, v2, "same seed must produce same first variation");
178    }
179
180    #[test]
181    fn neighborhood_skips_visited() {
182        // Narrow range [0.5, 0.6] with large step: perturbations always clamp to 0.5 or 0.6.
183        // After visiting both, must return None.
184        let space = make_space(ParameterKind::Temperature, 0.5, 0.6, 0.1);
185        let mut generator = Neighborhood::new(space, 1.0, 0).unwrap();
186        let baseline = ConfigSnapshot::default();
187        let mut visited = HashSet::new();
188        visited.insert(Variation {
189            parameter: ParameterKind::Temperature,
190            value: VariationValue::Float(OrderedFloat(0.5)),
191        });
192        visited.insert(Variation {
193            parameter: ParameterKind::Temperature,
194            value: VariationValue::Float(OrderedFloat(0.6)),
195        });
196        assert!(generator.next(&baseline, &visited).is_none());
197    }
198
199    #[test]
200    fn neighborhood_empty_space_returns_none() {
201        let mut generator = Neighborhood::new(SearchSpace { parameters: vec![] }, 1.0, 0).unwrap();
202        let baseline = ConfigSnapshot::default();
203        let visited = HashSet::new();
204        assert!(generator.next(&baseline, &visited).is_none());
205    }
206
207    #[test]
208    fn neighborhood_zero_radius_returns_error() {
209        let result = Neighborhood::new(SearchSpace::default(), 0.0, 0);
210        assert!(result.is_err(), "zero radius must be rejected");
211    }
212
213    #[test]
214    fn neighborhood_negative_radius_returns_error() {
215        let result = Neighborhood::new(SearchSpace::default(), -1.0, 0);
216        assert!(result.is_err(), "negative radius must be rejected");
217    }
218
219    #[test]
220    fn neighborhood_nan_radius_returns_error() {
221        let result = Neighborhood::new(SearchSpace::default(), f64::NAN, 0);
222        assert!(result.is_err(), "NaN radius must be rejected");
223    }
224
225    #[test]
226    fn neighborhood_step_none_uses_default_steps() {
227        // Continuous parameter (step=None) — neighborhood must still produce values.
228        let space = SearchSpace {
229            parameters: vec![
230                super::super::search_space::ParameterRange::new(
231                    ParameterKind::Temperature,
232                    0.0,
233                    2.0,
234                    None,
235                    1.0,
236                )
237                .unwrap(),
238            ],
239        };
240        let mut generator = Neighborhood::new(space, 1.0, 77).unwrap();
241        let baseline = ConfigSnapshot::default();
242        let visited = HashSet::new();
243        // With DEFAULT_STEPS=20, perturbation step = 2.0/20.0 = 0.1; must get at least one result.
244        let mut got_any = false;
245        for _ in 0..50 {
246            if generator.next(&baseline, &visited).is_some() {
247                got_any = true;
248                break;
249            }
250        }
251        assert!(
252            got_any,
253            "should produce at least one variation for continuous parameter"
254        );
255    }
256
257    #[test]
258    fn neighborhood_quantizes_perturbed_values() {
259        let space = make_space(ParameterKind::TopP, 0.1, 1.0, 0.05);
260        let mut generator = Neighborhood::new(space, 2.0, 11).unwrap();
261        let mut baseline = ConfigSnapshot::default();
262        baseline.top_p = 0.5;
263        let visited = HashSet::new();
264        for _ in 0..30 {
265            if let Some(v) = generator.next(&baseline, &visited) {
266                let val = v.value.as_f64();
267                // Quantized values must be multiples of 0.05 anchored at min=0.1:
268                // i.e. (val - 0.1) / 0.05 must be an integer.
269                let steps = (val - 0.1) / 0.05;
270                assert!(
271                    (steps - steps.round()).abs() < 1e-10,
272                    "value {val} is not on the 0.05-step grid anchored at 0.1"
273                );
274            }
275        }
276    }
277
278    #[test]
279    fn neighborhood_name() {
280        let generator = Neighborhood::new(SearchSpace::default(), 1.0, 0).unwrap();
281        assert_eq!(generator.name(), "neighborhood");
282    }
283
284    #[test]
285    fn neighborhood_perturbs_around_baseline() {
286        // Baseline temperature 0.7, radius 1.0, step 0.1 => perturbation in [-0.1, +0.1]
287        // All values should be in [0.6, 0.8] within [0.0, 2.0]
288        let space = make_space(ParameterKind::Temperature, 0.0, 2.0, 0.1);
289        let mut generator = Neighborhood::new(space, 1.0, 55).unwrap();
290        let baseline = ConfigSnapshot::default(); // temperature = 0.7
291        let visited = HashSet::new();
292        let mut temp_values = vec![];
293        for _ in 0..50 {
294            if let Some(v) = generator.next(&baseline, &visited)
295                && v.parameter == ParameterKind::Temperature
296            {
297                temp_values.push(v.value.as_f64());
298            }
299        }
300        assert!(
301            !temp_values.is_empty(),
302            "should produce temperature variations"
303        );
304        // All values must be within ±1 step of 0.7 (i.e., ±0.1, so [0.6, 0.8])
305        for val in &temp_values {
306            assert!(
307                *val >= 0.6 - 1e-10 && *val <= 0.8 + 1e-10,
308                "value {val} not within ±1 step of 0.7"
309            );
310        }
311    }
312}