1use std::collections::HashSet;
15
16use ordered_float::OrderedFloat;
17
18use super::generator::VariationGenerator;
19use super::search_space::SearchSpace;
20use super::snapshot::ConfigSnapshot;
21use super::types::{Variation, VariationValue};
22
23pub struct GridStep {
63 search_space: SearchSpace,
64 current_param: usize,
65 current_step: usize,
66}
67
68impl GridStep {
69 #[must_use]
80 pub fn new(search_space: SearchSpace) -> Self {
81 Self {
82 search_space,
83 current_param: 0,
84 current_step: 0,
85 }
86 }
87}
88
89impl VariationGenerator for GridStep {
90 fn next(
91 &mut self,
92 _baseline: &ConfigSnapshot,
93 visited: &HashSet<Variation>,
94 ) -> Option<Variation> {
95 while self.current_param < self.search_space.parameters.len() {
96 let range = &self.search_space.parameters[self.current_param];
97 let step = range.step.unwrap_or_else(|| (range.max - range.min) / 20.0);
98 if step <= 0.0 {
99 self.current_param += 1;
100 self.current_step = 0;
101 continue;
102 }
103
104 #[allow(clippy::cast_precision_loss)]
105 let raw = range.min + step * self.current_step as f64;
106
107 if raw > range.max + f64::EPSILON {
108 self.current_param += 1;
109 self.current_step = 0;
110 continue;
111 }
112
113 self.current_step += 1;
114
115 let value = range.quantize(raw);
117
118 let variation = Variation {
119 parameter: range.kind,
120 value: VariationValue::Float(OrderedFloat(value)),
121 };
122
123 if !visited.contains(&variation) {
124 return Some(variation);
125 }
126 }
127 None
128 }
129
130 fn name(&self) -> &'static str {
131 "grid"
132 }
133}
134
135#[cfg(test)]
136mod tests {
137 use std::collections::HashSet;
138
139 use super::super::search_space::ParameterRange;
140 use super::super::types::ParameterKind;
141 use super::*;
142
143 fn single_param_space(min: f64, max: f64, step: f64) -> SearchSpace {
144 SearchSpace {
145 parameters: vec![ParameterRange {
146 kind: ParameterKind::Temperature,
147 min,
148 max,
149 step: Some(step),
150 default: min,
151 }],
152 }
153 }
154
155 #[test]
156 fn grid_step_produces_values_in_range() {
157 let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.5));
158 let baseline = ConfigSnapshot::default();
159 let mut visited = HashSet::new();
160 let mut values = vec![];
161 while let Some(v) = generator.next(&baseline, &visited) {
162 visited.insert(v.clone());
163 values.push(v.value.as_f64());
164 }
165 assert_eq!(values.len(), 3, "0.0, 0.5, 1.0");
166 for v in &values {
167 assert!(*v >= 0.0 && *v <= 1.0);
168 }
169 }
170
171 #[test]
172 fn grid_step_skips_visited() {
173 let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.5));
174 let baseline = ConfigSnapshot::default();
175 let mut visited = HashSet::new();
176 visited.insert(Variation {
177 parameter: ParameterKind::Temperature,
178 value: VariationValue::Float(OrderedFloat(0.0)),
179 });
180 let first = generator.next(&baseline, &visited).unwrap();
181 assert!(
182 (first.value.as_f64() - 0.5).abs() < 1e-10,
183 "expected 0.5, got {}",
184 first.value.as_f64()
185 );
186 }
187
188 #[test]
189 fn grid_step_returns_none_when_exhausted() {
190 let mut generator = GridStep::new(single_param_space(0.0, 0.0, 1.0));
191 let baseline = ConfigSnapshot::default();
192 let mut visited = HashSet::new();
193 generator.next(&baseline, &visited).unwrap();
195 visited.insert(Variation {
196 parameter: ParameterKind::Temperature,
197 value: VariationValue::Float(OrderedFloat(0.0)),
198 });
199 assert!(generator.next(&baseline, &visited).is_none());
200 }
201
202 #[test]
203 fn grid_step_multiple_params() {
204 let space = SearchSpace {
205 parameters: vec![
206 ParameterRange {
207 kind: ParameterKind::Temperature,
208 min: 0.0,
209 max: 0.5,
210 step: Some(0.5),
211 default: 0.0,
212 },
213 ParameterRange {
214 kind: ParameterKind::TopP,
215 min: 0.5,
216 max: 1.0,
217 step: Some(0.5),
218 default: 0.5,
219 },
220 ],
221 };
222 let mut generator = GridStep::new(space);
223 let baseline = ConfigSnapshot::default();
224 let mut visited = HashSet::new();
225 let mut results = vec![];
226 while let Some(v) = generator.next(&baseline, &visited) {
227 visited.insert(v.clone());
228 results.push(v);
229 }
230 assert_eq!(results.len(), 4);
232 let temp_count = results
233 .iter()
234 .filter(|v| v.parameter == ParameterKind::Temperature)
235 .count();
236 let top_p_count = results
237 .iter()
238 .filter(|v| v.parameter == ParameterKind::TopP)
239 .count();
240 assert_eq!(temp_count, 2);
241 assert_eq!(top_p_count, 2);
242 }
243
244 #[test]
245 fn grid_step_quantizes_to_avoid_fp_drift() {
246 let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.1));
249 let baseline = ConfigSnapshot::default();
250 let mut visited = HashSet::new();
251 let mut values = vec![];
252 while let Some(v) = generator.next(&baseline, &visited) {
253 visited.insert(v.clone());
254 values.push(v.value.as_f64());
255 }
256 for v in &values {
258 let rounded = (v * 10.0).round() / 10.0;
259 assert!(
260 (v - rounded).abs() < 1e-10,
261 "value {v} is not a clean multiple of 0.1"
262 );
263 }
264 }
265
266 #[test]
267 fn grid_step_empty_space_returns_none() {
268 let mut generator = GridStep::new(SearchSpace { parameters: vec![] });
269 let baseline = ConfigSnapshot::default();
270 let visited = HashSet::new();
271 assert!(generator.next(&baseline, &visited).is_none());
272 }
273
274 #[test]
275 fn grid_step_none_step_uses_fallback() {
276 let space = SearchSpace {
278 parameters: vec![ParameterRange {
279 kind: ParameterKind::Temperature,
280 min: 0.0,
281 max: 1.0,
282 step: None,
283 default: 0.5,
284 }],
285 };
286 let mut generator = GridStep::new(space);
287 let baseline = ConfigSnapshot::default();
288 let mut visited = HashSet::new();
289 let mut count = 0;
290 while let Some(v) = generator.next(&baseline, &visited) {
291 visited.insert(v.clone());
292 count += 1;
293 }
294 assert_eq!(
296 count, 21,
297 "expected 21 steps for step=None with DEFAULT_STEPS=20"
298 );
299 }
300
301 #[test]
302 fn grid_step_name() {
303 let generator = GridStep::new(SearchSpace::default());
304 assert_eq!(generator.name(), "grid");
305 }
306}