1use std::collections::HashSet;
7
8use ordered_float::OrderedFloat;
9
10use super::generator::VariationGenerator;
11use super::search_space::SearchSpace;
12use super::snapshot::ConfigSnapshot;
13use super::types::{Variation, VariationValue};
14
15pub struct GridStep {
22 search_space: SearchSpace,
23 current_param: usize,
24 current_step: usize,
25}
26
27impl GridStep {
28 #[must_use]
30 pub fn new(search_space: SearchSpace) -> Self {
31 Self {
32 search_space,
33 current_param: 0,
34 current_step: 0,
35 }
36 }
37}
38
39impl VariationGenerator for GridStep {
40 fn next(
41 &mut self,
42 _baseline: &ConfigSnapshot,
43 visited: &HashSet<Variation>,
44 ) -> Option<Variation> {
45 while self.current_param < self.search_space.parameters.len() {
46 let range = &self.search_space.parameters[self.current_param];
47 let step = range.step.unwrap_or_else(|| (range.max - range.min) / 20.0);
48 if step <= 0.0 {
49 self.current_param += 1;
50 self.current_step = 0;
51 continue;
52 }
53
54 #[allow(clippy::cast_precision_loss)]
55 let raw = range.min + step * self.current_step as f64;
56
57 if raw > range.max + f64::EPSILON {
58 self.current_param += 1;
59 self.current_step = 0;
60 continue;
61 }
62
63 self.current_step += 1;
64
65 let value = range.quantize(raw);
67
68 let variation = Variation {
69 parameter: range.kind,
70 value: VariationValue::Float(OrderedFloat(value)),
71 };
72
73 if !visited.contains(&variation) {
74 return Some(variation);
75 }
76 }
77 None
78 }
79
80 fn name(&self) -> &'static str {
81 "grid"
82 }
83}
84
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::search_space::ParameterRange;
    use super::super::types::ParameterKind;
    use super::*;

    /// Search space holding a single temperature range with an explicit step;
    /// the default is pinned to `min`.
    fn single_param_space(min: f64, max: f64, step: f64) -> SearchSpace {
        SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min,
                max,
                step: Some(step),
                default: min,
            }],
        }
    }

    /// Temperature variation with the given value, used to seed `visited`.
    fn temperature(value: f64) -> Variation {
        Variation {
            parameter: ParameterKind::Temperature,
            value: VariationValue::Float(OrderedFloat(value)),
        }
    }

    /// Runs `generator` to exhaustion, marking each proposal as visited so the
    /// next call moves on, and returns the proposals in emission order.
    fn drain(generator: &mut GridStep) -> Vec<Variation> {
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        let mut emitted = Vec::new();
        while let Some(variation) = generator.next(&baseline, &visited) {
            visited.insert(variation.clone());
            emitted.push(variation);
        }
        emitted
    }

    /// Numeric values of every variation `generator` produces.
    fn drain_values(generator: &mut GridStep) -> Vec<f64> {
        drain(generator)
            .into_iter()
            .map(|variation| variation.value.as_f64())
            .collect()
    }

    #[test]
    fn grid_step_produces_values_in_range() {
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.5));
        let values = drain_values(&mut generator);
        assert_eq!(values.len(), 3, "0.0, 0.5, 1.0");
        assert!(values.iter().all(|v| (0.0..=1.0).contains(v)));
    }

    #[test]
    fn grid_step_skips_visited() {
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.5));
        let baseline = ConfigSnapshot::default();
        let visited: HashSet<Variation> = std::iter::once(temperature(0.0)).collect();
        let first = generator
            .next(&baseline, &visited)
            .expect("0.5 and 1.0 are still unvisited");
        let got = first.value.as_f64();
        assert!((got - 0.5).abs() < 1e-10, "expected 0.5, got {got}");
    }

    #[test]
    fn grid_step_returns_none_when_exhausted() {
        let mut generator = GridStep::new(single_param_space(0.0, 0.0, 1.0));
        let baseline = ConfigSnapshot::default();
        let mut visited = HashSet::new();
        // The only grid point is handed out once...
        assert!(generator.next(&baseline, &visited).is_some());
        visited.insert(temperature(0.0));
        // ...and afterwards the sweep is over.
        assert!(generator.next(&baseline, &visited).is_none());
    }

    #[test]
    fn grid_step_multiple_params() {
        let space = SearchSpace {
            parameters: vec![
                ParameterRange {
                    kind: ParameterKind::Temperature,
                    min: 0.0,
                    max: 0.5,
                    step: Some(0.5),
                    default: 0.0,
                },
                ParameterRange {
                    kind: ParameterKind::TopP,
                    min: 0.5,
                    max: 1.0,
                    step: Some(0.5),
                    default: 0.5,
                },
            ],
        };
        let mut generator = GridStep::new(space);
        let results = drain(&mut generator);
        assert_eq!(results.len(), 4);

        let count_of = |kind: ParameterKind| {
            results
                .iter()
                .filter(|variation| variation.parameter == kind)
                .count()
        };
        assert_eq!(count_of(ParameterKind::Temperature), 2);
        assert_eq!(count_of(ParameterKind::TopP), 2);
    }

    #[test]
    fn grid_step_quantizes_to_avoid_fp_drift() {
        // Naive accumulation of 0.1 drifts (0.30000000000000004, ...);
        // every emitted value must be a clean multiple of the step.
        let mut generator = GridStep::new(single_param_space(0.0, 1.0, 0.1));
        for v in drain_values(&mut generator) {
            let rounded = (v * 10.0).round() / 10.0;
            assert!(
                (v - rounded).abs() < 1e-10,
                "value {v} is not a clean multiple of 0.1"
            );
        }
    }

    #[test]
    fn grid_step_empty_space_returns_none() {
        let mut generator = GridStep::new(SearchSpace { parameters: vec![] });
        let baseline = ConfigSnapshot::default();
        assert!(generator.next(&baseline, &HashSet::new()).is_none());
    }

    #[test]
    fn grid_step_none_step_uses_fallback() {
        // Without an explicit step the range is split into 20 intervals,
        // giving 21 grid points including both endpoints.
        let space = SearchSpace {
            parameters: vec![ParameterRange {
                kind: ParameterKind::Temperature,
                min: 0.0,
                max: 1.0,
                step: None,
                default: 0.5,
            }],
        };
        let mut generator = GridStep::new(space);
        let count = drain(&mut generator).len();
        assert_eq!(
            count, 21,
            "expected 21 steps for step=None with DEFAULT_STEPS=20"
        );
    }

    #[test]
    fn grid_step_name() {
        let generator = GridStep::new(SearchSpace::default());
        assert_eq!(generator.name(), "grid");
    }
}