Skip to main content

zeph_experiments/
grid.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Systematic grid sweep strategy for parameter variation.
5//!
6//! [`GridStep`] iterates each parameter through its discrete steps in order,
7//! skipping variations that have already been visited. This gives exhaustive
8//! coverage of the search space and is well-suited as a first-pass exploration
9//! before switching to a [`Neighborhood`] or [`Random`] strategy.
10//!
11//! [`Neighborhood`]: crate::Neighborhood
12//! [`Random`]: crate::Random
13
14use std::collections::HashSet;
15
16use ordered_float::OrderedFloat;
17
18use super::generator::VariationGenerator;
19use super::search_space::SearchSpace;
20use super::snapshot::ConfigSnapshot;
21use super::types::{Variation, VariationValue};
22
23/// Systematic grid sweep: iterate each parameter through its discrete steps, skip visited.
24///
25/// Parameters are swept one at a time. For each parameter, all grid points from
26/// `min` to `max` (with the configured `step`) are enumerated in order. Already-visited
27/// variations are skipped. When all steps for a parameter are exhausted, the next
28/// parameter is tried. Returns `None` when the full grid has been visited.
29///
30/// When a parameter has no discrete `step`, [`GridStep`] falls back to
31/// `(max - min) / 20` as the step size.
32///
33/// # Examples
34///
35/// ```rust
36/// use std::collections::HashSet;
37/// use zeph_experiments::{
38///     ConfigSnapshot, GridStep, ParameterKind, ParameterRange, SearchSpace, VariationGenerator,
39/// };
40///
41/// let space = SearchSpace {
42///     parameters: vec![ParameterRange {
43///         kind: ParameterKind::Temperature,
44///         min: 0.0,
45///         max: 1.0,
46///         step: Some(0.5),
47///         default: 0.5,
48///     }],
49/// };
50/// let mut generator = GridStep::new(space);
51/// let baseline = ConfigSnapshot::default();
52/// let mut visited = HashSet::new();
53///
54/// // Produces 0.0, 0.5, 1.0 in order.
55/// let mut count = 0;
56/// while let Some(v) = generator.next(&baseline, &visited) {
57///     visited.insert(v);
58///     count += 1;
59/// }
60/// assert_eq!(count, 3);
61/// ```
62pub struct GridStep {
63    search_space: SearchSpace,
64    current_param: usize,
65    current_step: usize,
66}
67
68impl GridStep {
69    /// Create a new [`GridStep`] generator starting at the first grid point.
70    ///
71    /// # Examples
72    ///
73    /// ```rust
74    /// use zeph_experiments::{GridStep, SearchSpace, VariationGenerator};
75    ///
76    /// let generator = GridStep::new(SearchSpace::default());
77    /// assert_eq!(generator.name(), "grid");
78    /// ```
79    #[must_use]
80    pub fn new(search_space: SearchSpace) -> Self {
81        Self {
82            search_space,
83            current_param: 0,
84            current_step: 0,
85        }
86    }
87}
88
89impl VariationGenerator for GridStep {
90    fn next(
91        &mut self,
92        _baseline: &ConfigSnapshot,
93        visited: &HashSet<Variation>,
94    ) -> Option<Variation> {
95        while self.current_param < self.search_space.parameters.len() {
96            let range = &self.search_space.parameters[self.current_param];
97            let step = range.step.unwrap_or_else(|| (range.max - range.min) / 20.0);
98            if step <= 0.0 {
99                self.current_param += 1;
100                self.current_step = 0;
101                continue;
102            }
103
104            #[allow(clippy::cast_precision_loss)]
105            let raw = range.min + step * self.current_step as f64;
106
107            if raw > range.max + f64::EPSILON {
108                self.current_param += 1;
109                self.current_step = 0;
110                continue;
111            }
112
113            self.current_step += 1;
114
115            // Quantize to avoid floating-point accumulation before deduplication.
116            let value = range.quantize(raw);
117
118            let variation = Variation {
119                parameter: range.kind,
120                value: VariationValue::Float(OrderedFloat(value)),
121            };
122
123            if !visited.contains(&variation) {
124                return Some(variation);
125            }
126        }
127        None
128    }
129
130    fn name(&self) -> &'static str {
131        "grid"
132    }
133}
134
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::search_space::ParameterRange;
    use super::super::types::ParameterKind;
    use super::*;

    /// Search space with a single `Temperature` parameter over `[min, max]`.
    fn single_param_space(min: f64, max: f64, step: f64) -> SearchSpace {
        SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min,
                max,
                step: Some(step),
                default: min,
            }],
        }
    }

    /// Run `generator` to exhaustion, marking each returned variation as visited,
    /// and return the variations in the order they were produced.
    fn drain(generator: &mut GridStep) -> Vec<Variation> {
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        let mut produced = Vec::new();
        while let Some(v) = generator.next(&baseline, &visited) {
            visited.insert(v.clone());
            produced.push(v);
        }
        produced
    }

    #[test]
    fn grid_step_produces_values_in_range() {
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.5));
        let values: Vec<f64> = drain(&mut generator)
            .iter()
            .map(|v| v.value.as_f64())
            .collect();
        assert_eq!(values.len(), 3, "0.0, 0.5, 1.0");
        for v in &values {
            assert!(*v >= 0.0 && *v <= 1.0);
        }
        // Grid points are emitted in ascending order.
        for (got, want) in values.iter().zip([0.0, 0.5, 1.0]) {
            assert!((got - want).abs() < 1e-10, "expected {want}, got {got}");
        }
    }

    #[test]
    fn grid_step_skips_visited() {
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.5));
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        visited.insert(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.0)),
        });
        let first = generator.next(&baseline, &visited).unwrap();
        assert!(
            (first.value.as_f64() - 0.5).abs() < 1e-10,
            "expected 0.5, got {}",
            first.value.as_f64()
        );
    }

    #[test]
    fn grid_step_returns_none_when_exhausted() {
        let mut generator = GridStep::new(single_param_space(0.0, 0.0, 1.0));
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        // Only one point: 0.0
        generator.next(&baseline, &visited).unwrap();
        visited.insert(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.0)),
        });
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn grid_step_multiple_params() {
        let space = SearchSpace {
            parameters: vec![
                ParameterRange {
                    kind: ParameterKind::Temperature,
                    min: 0.0,
                    max: 0.5,
                    step: Some(0.5),
                    default: 0.0,
                },
                ParameterRange {
                    kind: ParameterKind::TopP,
                    min: 0.5,
                    max: 1.0,
                    step: Some(0.5),
                    default: 0.5,
                },
            ],
        };
        let mut generator = GridStep::new(space);
        let results = drain(&mut generator);
        // Temperature: 0.0, 0.5 — TopP: 0.5, 1.0
        assert_eq!(results.len(), 4);
        let temp_count = results
            .iter()
            .filter(|v| v.parameter == ParameterKind::Temperature)
            .count();
        let top_p_count = results
            .iter()
            .filter(|v| v.parameter == ParameterKind::TopP)
            .count();
        assert_eq!(temp_count, 2);
        assert_eq!(top_p_count, 2);
        // Parameters are swept one at a time, in declaration order.
        assert!(results[..2]
            .iter()
            .all(|v| v.parameter == ParameterKind::Temperature));
    }

    #[test]
    fn grid_step_quantizes_to_avoid_fp_drift() {
        // `0.0 + 0.1 * 7` evaluates to 0.7000000000000001 in f64;
        // quantize must snap it to 0.7.
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.1));
        let values: Vec<f64> = drain(&mut generator)
            .iter()
            .map(|v| v.value.as_f64())
            .collect();
        // Guard against a vacuous pass: 0.0, 0.1, ..., 1.0 is 11 points.
        assert_eq!(values.len(), 11, "expected 11 grid points, got {values:?}");
        // All values should be clean multiples of 0.1
        for v in &values {
            let rounded = (v * 10.0).round() / 10.0;
            assert!(
                (v - rounded).abs() < 1e-10,
                "value {v} is not a clean multiple of 0.1"
            );
        }
    }

    #[test]
    fn grid_step_empty_space_returns_none() {
        let mut generator = GridStep::new(SearchSpace { parameters: vec![] });
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn grid_step_none_step_uses_fallback() {
        // Parameter with step=None — GridStep falls back to (max-min)/20.0 as step size.
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min: 0.0,
                max: 1.0,
                step: None,
                default: 0.5,
            }],
        };
        let mut generator = GridStep::new(space);
        let count = drain(&mut generator).len();
        // With step = 1.0/20.0, there should be 21 steps (0.0, 0.05, ..., 1.0)
        assert_eq!(
            count, 21,
            "expected 21 steps for step=None with fallback step (max - min) / 20"
        );
    }

    #[test]
    fn grid_step_name() {
        let generator = GridStep::new(SearchSpace::default());
        assert_eq!(generator.name(), "grid");
    }
}