Skip to main content

zeph_experiments/
grid.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Grid sweep strategy for parameter variation.
5
6use std::collections::HashSet;
7
8use ordered_float::OrderedFloat;
9
10use super::generator::VariationGenerator;
11use super::search_space::SearchSpace;
12use super::snapshot::ConfigSnapshot;
13use super::types::{Variation, VariationValue};
14
15/// Systematic grid sweep: iterate each parameter through its discrete steps, skip visited.
16///
17/// Parameters are swept one at a time. For each parameter, all grid points from
18/// `min` to `max` (with the configured `step`) are enumerated in order. Already-visited
19/// variations are skipped. When all steps for a parameter are exhausted, the next parameter
20/// is tried. Returns `None` when the full grid has been visited.
21pub struct GridStep {
22    search_space: SearchSpace,
23    current_param: usize,
24    current_step: usize,
25}
26
27impl GridStep {
28    /// Create a new `GridStep` generator with the given search space.
29    #[must_use]
30    pub fn new(search_space: SearchSpace) -> Self {
31        Self {
32            search_space,
33            current_param: 0,
34            current_step: 0,
35        }
36    }
37}
38
39impl VariationGenerator for GridStep {
40    fn next(
41        &mut self,
42        _baseline: &ConfigSnapshot,
43        visited: &HashSet<Variation>,
44    ) -> Option<Variation> {
45        while self.current_param < self.search_space.parameters.len() {
46            let range = &self.search_space.parameters[self.current_param];
47            let step = range.step.unwrap_or_else(|| (range.max - range.min) / 20.0);
48            if step <= 0.0 {
49                self.current_param += 1;
50                self.current_step = 0;
51                continue;
52            }
53
54            #[allow(clippy::cast_precision_loss)]
55            let raw = range.min + step * self.current_step as f64;
56
57            if raw > range.max + f64::EPSILON {
58                self.current_param += 1;
59                self.current_step = 0;
60                continue;
61            }
62
63            self.current_step += 1;
64
65            // Quantize to avoid floating-point accumulation before deduplication.
66            let value = range.quantize(raw);
67
68            let variation = Variation {
69                parameter: range.kind,
70                value: VariationValue::Float(OrderedFloat(value)),
71            };
72
73            if !visited.contains(&variation) {
74                return Some(variation);
75            }
76        }
77        None
78    }
79
80    fn name(&self) -> &'static str {
81        "grid"
82    }
83}
84
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::search_space::ParameterRange;
    use super::super::types::ParameterKind;
    use super::*;

    /// Search space with a single `Temperature` range and an explicit step.
    fn single_param_space(min: f64, max: f64, step: f64) -> SearchSpace {
        SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min,
                max,
                step: Some(step),
                default: min,
            }],
        }
    }

    /// Run `generator` to exhaustion, marking every returned variation as visited,
    /// and return the variations in the order they were produced.
    fn drain(generator: &mut GridStep) -> Vec<Variation> {
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        let mut produced = Vec::new();
        while let Some(variation) = generator.next(&baseline, &visited) {
            visited.insert(variation.clone());
            produced.push(variation);
        }
        produced
    }

    /// Extract the numeric value of each variation.
    fn values_of(variations: &[Variation]) -> Vec<f64> {
        variations.iter().map(|v| v.value.as_f64()).collect()
    }

    #[test]
    fn grid_step_produces_values_in_range() {
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.5));
        let values = values_of(&drain(&mut generator));
        assert_eq!(values.len(), 3, "0.0, 0.5, 1.0");
        for v in &values {
            assert!(*v >= 0.0 && *v <= 1.0);
        }
    }

    #[test]
    fn grid_step_skips_visited() {
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.5));
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        visited.insert(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.0)),
        });
        let first = generator.next(&baseline, &visited).unwrap();
        assert!(
            (first.value.as_f64() - 0.5).abs() < 1e-10,
            "expected 0.5, got {}",
            first.value.as_f64()
        );
    }

    #[test]
    fn grid_step_returns_none_when_exhausted() {
        let mut generator = GridStep::new(single_param_space(0.0, 0.0, 1.0));
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        // Only one point: 0.0
        generator.next(&baseline, &visited).unwrap();
        visited.insert(Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(0.0)),
        });
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn grid_step_multiple_params() {
        let space = SearchSpace {
            parameters: vec![
                ParameterRange {
                    kind: ParameterKind::Temperature,
                    min: 0.0,
                    max: 0.5,
                    step: Some(0.5),
                    default: 0.0,
                },
                ParameterRange {
                    kind: ParameterKind::TopP,
                    min: 0.5,
                    max: 1.0,
                    step: Some(0.5),
                    default: 0.5,
                },
            ],
        };
        let mut generator = GridStep::new(space);
        let results = drain(&mut generator);
        // Temperature: 0.0, 0.5 — TopP: 0.5, 1.0
        assert_eq!(results.len(), 4);
        let temp_count = results
            .iter()
            .filter(|v| v.parameter == ParameterKind::Temperature)
            .count();
        let top_p_count = results
            .iter()
            .filter(|v| v.parameter == ParameterKind::TopP)
            .count();
        assert_eq!(temp_count, 2);
        assert_eq!(top_p_count, 2);
    }

    #[test]
    fn grid_step_quantizes_to_avoid_fp_drift() {
        // 0.1 * 7 evaluates to 0.7000000000000001 in f64;
        // quantize must snap it to 0.7.
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.1));
        let values = values_of(&drain(&mut generator));
        // Guard against a vacuous pass: the sweep must cover 0.0, 0.1, ..., 1.0.
        assert_eq!(values.len(), 11);
        // All values should be clean multiples of 0.1
        for v in &values {
            let rounded = (v * 10.0).round() / 10.0;
            assert!(
                (v - rounded).abs() < 1e-10,
                "value {v} is not a clean multiple of 0.1"
            );
        }
    }

    #[test]
    fn grid_step_empty_space_returns_none() {
        let mut generator = GridStep::new(SearchSpace { parameters: vec![] });
        let baseline = ConfigSnapshot::default();
        let visited = HashSet::new();
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn grid_step_none_step_uses_fallback() {
        // Parameter with step=None — GridStep falls back to (max-min)/20.0 as step size.
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min: 0.0,
                max: 1.0,
                step: None,
                default: 0.5,
            }],
        };
        let mut generator = GridStep::new(space);
        let count = drain(&mut generator).len();
        // With step = 1.0/20.0, there should be 21 steps (0.0, 0.05, ..., 1.0)
        assert_eq!(
            count, 21,
            "expected 21 steps for step=None with DEFAULT_STEPS=20"
        );
    }

    #[test]
    fn grid_step_skips_parameter_with_non_positive_step() {
        // A zero step can never advance through the range, so the whole parameter is
        // skipped and the sweep continues with the next one.
        let space = SearchSpace {
            parameters: vec![
                ParameterRange {
                    kind: ParameterKind::Temperature,
                    min: 0.0,
                    max: 1.0,
                    step: Some(0.0),
                    default: 0.0,
                },
                ParameterRange {
                    kind: ParameterKind::TopP,
                    min: 0.5,
                    max: 1.0,
                    step: Some(0.5),
                    default: 0.5,
                },
            ],
        };
        let mut generator = GridStep::new(space);
        let results = drain(&mut generator);
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|v| v.parameter == ParameterKind::TopP));
    }

    #[test]
    fn grid_step_name() {
        let generator = GridStep::new(SearchSpace::default());
        assert_eq!(generator.name(), "grid");
    }
}