Skip to main content

zeph_experiments/
neighborhood.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Neighborhood perturbation strategy for parameter variation.
5
6use std::collections::HashSet;
7
8use ordered_float::OrderedFloat;
9use rand::Rng as _;
10use rand::SeedableRng as _;
11use rand::rngs::SmallRng;
12
13use super::error::EvalError;
14use super::generator::VariationGenerator;
15use super::search_space::SearchSpace;
16use super::snapshot::ConfigSnapshot;
17use super::types::{Variation, VariationValue};
18
/// Number of sampling attempts a single `next` call makes before the
/// neighborhood is treated as exhausted and `None` is returned.
const MAX_RETRIES: usize = 1_000;
21
/// Divisor applied to a parameter's `max - min` span when it has no explicit step.
///
/// Continuous parameters are therefore perturbed in units of `(max - min) / 20`,
/// a reasonable granularity that does not require the search space definition
/// to spell out a step.
const DEFAULT_STEPS: f64 = 20_f64;
27
28/// Perturbation strategy around the current baseline.
29///
30/// At each call, a parameter is chosen uniformly at random. The new value is
31/// computed as `baseline_value ± U(-radius, radius) * step`, then clamped and
32/// quantized to the nearest grid step. Useful after a grid sweep has narrowed
33/// the search to a promising region.
34///
35/// `radius` must be positive (enforced in [`Neighborhood::new`]).
36pub struct Neighborhood {
37    search_space: SearchSpace,
38    radius: f64,
39    rng: SmallRng,
40}
41
42impl Neighborhood {
43    /// Create a new `Neighborhood` generator.
44    ///
45    /// # Errors
46    ///
47    /// Returns [`EvalError::InvalidRadius`] if `radius` is not finite and positive.
48    pub fn new(search_space: SearchSpace, radius: f64, seed: u64) -> Result<Self, EvalError> {
49        if !radius.is_finite() || radius <= 0.0 {
50            return Err(EvalError::InvalidRadius { radius });
51        }
52        Ok(Self {
53            search_space,
54            radius,
55            rng: SmallRng::seed_from_u64(seed),
56        })
57    }
58}
59
60impl VariationGenerator for Neighborhood {
61    fn next(
62        &mut self,
63        baseline: &ConfigSnapshot,
64        visited: &HashSet<Variation>,
65    ) -> Option<Variation> {
66        if self.search_space.parameters.is_empty() {
67            return None;
68        }
69        for _ in 0..MAX_RETRIES {
70            let idx = self.rng.gen_range(0..self.search_space.parameters.len());
71            let range = &self.search_space.parameters[idx];
72            let current = baseline.get(range.kind);
73            // DEFAULT_STEPS is used when step is None (continuous parameter).
74            let step = range
75                .step
76                .unwrap_or_else(|| (range.max - range.min) / DEFAULT_STEPS);
77            let delta = self.rng.gen_range(-self.radius..=self.radius) * step;
78            // Skip zero perturbations — they produce the baseline value, wasting an attempt.
79            if delta.abs() < f64::EPSILON {
80                continue;
81            }
82            let raw = current + delta;
83            let value = range.quantize(range.clamp(raw));
84            // Skip if the quantized value equals the baseline (no effective change).
85            if (value - current).abs() < f64::EPSILON {
86                continue;
87            }
88            let variation = Variation {
89                parameter: range.kind,
90                value: VariationValue::Float(OrderedFloat(value)),
91            };
92            if !visited.contains(&variation) {
93                return Some(variation);
94            }
95        }
96        None
97    }
98
99    fn name(&self) -> &'static str {
100        "neighborhood"
101    }
102}
103
104#[cfg(test)]
105mod tests {
106    #![allow(
107        clippy::collapsible_if,
108        clippy::field_reassign_with_default,
109        clippy::manual_midpoint,
110        clippy::manual_range_contains
111    )]
112
113    use std::collections::HashSet;
114
115    use super::super::search_space::ParameterRange;
116    use super::super::types::ParameterKind;
117    use super::*;
118
119    fn make_space(kind: ParameterKind, min: f64, max: f64, step: f64) -> SearchSpace {
120        SearchSpace {
121            parameters: vec![ParameterRange {
122                kind,
123                min,
124                max,
125                step: Some(step),
126                default: f64::midpoint(min, max),
127            }],
128        }
129    }
130
131    #[test]
132    fn neighborhood_produces_values_in_range() {
133        let space = make_space(ParameterKind::Temperature, 0.0, 2.0, 0.1);
134        let mut generator = Neighborhood::new(space, 1.0, 42).unwrap();
135        let baseline = ConfigSnapshot::default();
136        let visited = HashSet::new();
137        for _ in 0..20 {
138            if let Some(v) = generator.next(&baseline, &visited) {
139                let val = v.value.as_f64();
140                assert!((0.0..=2.0).contains(&val), "out of range: {val}");
141            }
142        }
143    }
144
145    #[test]
146    fn neighborhood_is_deterministic_with_same_seed() {
147        let space = SearchSpace::default();
148        let baseline = ConfigSnapshot::default();
149        let visited = HashSet::new();
150        let mut gen1 = Neighborhood::new(space.clone(), 1.0, 99).unwrap();
151        let mut gen2 = Neighborhood::new(space, 1.0, 99).unwrap();
152        let v1 = gen1.next(&baseline, &visited);
153        let v2 = gen2.next(&baseline, &visited);
154        assert_eq!(v1, v2, "same seed must produce same first variation");
155    }
156
157    #[test]
158    fn neighborhood_skips_visited() {
159        // Single-point space: min == max == 0.5, step 0.1
160        let space = make_space(ParameterKind::Temperature, 0.5, 0.5, 0.1);
161        let mut generator = Neighborhood::new(space, 1.0, 0).unwrap();
162        let baseline = ConfigSnapshot::default();
163        let mut visited = HashSet::new();
164        visited.insert(Variation {
165            parameter: ParameterKind::Temperature,
166            value: VariationValue::Float(OrderedFloat(0.5)),
167        });
168        assert!(generator.next(&baseline, &visited).is_none());
169    }
170
171    #[test]
172    fn neighborhood_empty_space_returns_none() {
173        let mut generator = Neighborhood::new(SearchSpace { parameters: vec![] }, 1.0, 0).unwrap();
174        let baseline = ConfigSnapshot::default();
175        let visited = HashSet::new();
176        assert!(generator.next(&baseline, &visited).is_none());
177    }
178
179    #[test]
180    fn neighborhood_zero_radius_returns_error() {
181        let result = Neighborhood::new(SearchSpace::default(), 0.0, 0);
182        assert!(result.is_err(), "zero radius must be rejected");
183    }
184
185    #[test]
186    fn neighborhood_negative_radius_returns_error() {
187        let result = Neighborhood::new(SearchSpace::default(), -1.0, 0);
188        assert!(result.is_err(), "negative radius must be rejected");
189    }
190
191    #[test]
192    fn neighborhood_nan_radius_returns_error() {
193        let result = Neighborhood::new(SearchSpace::default(), f64::NAN, 0);
194        assert!(result.is_err(), "NaN radius must be rejected");
195    }
196
197    #[test]
198    fn neighborhood_step_none_uses_default_steps() {
199        // Continuous parameter (step=None) — neighborhood must still produce values.
200        let space = SearchSpace {
201            parameters: vec![super::super::search_space::ParameterRange {
202                kind: ParameterKind::Temperature,
203                min: 0.0,
204                max: 2.0,
205                step: None,
206                default: 1.0,
207            }],
208        };
209        let mut generator = Neighborhood::new(space, 1.0, 77).unwrap();
210        let baseline = ConfigSnapshot::default();
211        let visited = HashSet::new();
212        // With DEFAULT_STEPS=20, perturbation step = 2.0/20.0 = 0.1; must get at least one result.
213        let mut got_any = false;
214        for _ in 0..50 {
215            if generator.next(&baseline, &visited).is_some() {
216                got_any = true;
217                break;
218            }
219        }
220        assert!(
221            got_any,
222            "should produce at least one variation for continuous parameter"
223        );
224    }
225
226    #[test]
227    fn neighborhood_quantizes_perturbed_values() {
228        let space = make_space(ParameterKind::TopP, 0.1, 1.0, 0.05);
229        let mut generator = Neighborhood::new(space, 2.0, 11).unwrap();
230        let mut baseline = ConfigSnapshot::default();
231        baseline.top_p = 0.5;
232        let visited = HashSet::new();
233        for _ in 0..30 {
234            if let Some(v) = generator.next(&baseline, &visited) {
235                let val = v.value.as_f64();
236                // Quantized values must be multiples of 0.05 anchored at min=0.1:
237                // i.e. (val - 0.1) / 0.05 must be an integer.
238                let steps = (val - 0.1) / 0.05;
239                assert!(
240                    (steps - steps.round()).abs() < 1e-10,
241                    "value {val} is not on the 0.05-step grid anchored at 0.1"
242                );
243            }
244        }
245    }
246
247    #[test]
248    fn neighborhood_name() {
249        let generator = Neighborhood::new(SearchSpace::default(), 1.0, 0).unwrap();
250        assert_eq!(generator.name(), "neighborhood");
251    }
252
253    #[test]
254    fn neighborhood_perturbs_around_baseline() {
255        // Baseline temperature 0.7, radius 1.0, step 0.1 => perturbation in [-0.1, +0.1]
256        // All values should be in [0.6, 0.8] within [0.0, 2.0]
257        let space = make_space(ParameterKind::Temperature, 0.0, 2.0, 0.1);
258        let mut generator = Neighborhood::new(space, 1.0, 55).unwrap();
259        let baseline = ConfigSnapshot::default(); // temperature = 0.7
260        let visited = HashSet::new();
261        let mut temp_values = vec![];
262        for _ in 0..50 {
263            if let Some(v) = generator.next(&baseline, &visited)
264                && v.parameter == ParameterKind::Temperature
265            {
266                temp_values.push(v.value.as_f64());
267            }
268        }
269        assert!(
270            !temp_values.is_empty(),
271            "should produce temperature variations"
272        );
273        // All values must be within ±1 step of 0.7 (i.e., ±0.1, so [0.6, 0.8])
274        for val in &temp_values {
275            assert!(
276                *val >= 0.6 - 1e-10 && *val <= 0.8 + 1e-10,
277                "value {val} not within ±1 step of 0.7"
278            );
279        }
280    }
281}