1#![forbid(unsafe_code)]
2use use_objective::{ObjectiveDirection, is_better};
29
/// Outcome of [`grid_search_1d`]: the winning grid value together with the
/// objective score that was computed for it.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct GridSearchResult1D {
    /// Grid value whose score was judged best.
    pub best_value: f64,
    /// Objective score evaluated at `best_value`.
    pub best_score: f64,
}
35
/// Outcome of [`grid_search_2d`]: the winning `(x, y)` grid point together with
/// the objective score that was computed for it.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct GridSearchResult2D {
    /// X-axis coordinate of the best grid point.
    pub best_x: f64,
    /// Y-axis coordinate of the best grid point.
    pub best_y: f64,
    /// Objective score evaluated at `(best_x, best_y)`.
    pub best_score: f64,
}
42
43pub fn grid_search_1d<F>(
44 values: &[f64],
45 objective: F,
46 direction: ObjectiveDirection,
47) -> Option<GridSearchResult1D>
48where
49 F: Fn(f64) -> f64,
50{
51 if values.is_empty() || values.iter().any(|value| !value.is_finite()) {
52 return None;
53 }
54
55 let mut best: Option<GridSearchResult1D> = None;
56 for value in values.iter().copied() {
57 let score = objective(value);
58 if !score.is_finite() {
59 return None;
60 }
61
62 let candidate = GridSearchResult1D {
63 best_value: value,
64 best_score: score,
65 };
66
67 if best.is_none_or(|current| is_better(candidate.best_score, current.best_score, direction))
68 {
69 best = Some(candidate);
70 }
71 }
72
73 best
74}
75
76pub fn grid_search_2d<F>(
77 x_values: &[f64],
78 y_values: &[f64],
79 objective: F,
80 direction: ObjectiveDirection,
81) -> Option<GridSearchResult2D>
82where
83 F: Fn(f64, f64) -> f64,
84{
85 if x_values.is_empty()
86 || y_values.is_empty()
87 || x_values
88 .iter()
89 .chain(y_values.iter())
90 .any(|value| !value.is_finite())
91 {
92 return None;
93 }
94
95 let mut best: Option<GridSearchResult2D> = None;
96 for x_value in x_values.iter().copied() {
97 for y_value in y_values.iter().copied() {
98 let score = objective(x_value, y_value);
99 if !score.is_finite() {
100 return None;
101 }
102
103 let candidate = GridSearchResult2D {
104 best_x: x_value,
105 best_y: y_value,
106 best_score: score,
107 };
108
109 if best.is_none_or(|current| {
110 is_better(candidate.best_score, current.best_score, direction)
111 }) {
112 best = Some(candidate);
113 }
114 }
115 }
116
117 best
118}
119
#[cfg(test)]
mod tests {
    use super::{GridSearchResult1D, GridSearchResult2D, grid_search_1d, grid_search_2d};
    use std::cell::Cell;
    use use_objective::ObjectiveDirection;

    #[test]
    fn finds_best_one_dimensional_candidate() {
        let result = grid_search_1d(
            &[0.0, 1.0, 2.0, 3.0],
            |value| -(value - 2.0) * (value - 2.0),
            ObjectiveDirection::Maximize,
        )
        .unwrap();

        assert_eq!(result.best_value, 2.0);
        assert_eq!(result.best_score, 0.0);
    }

    #[test]
    fn finds_best_two_dimensional_candidate() {
        let result = grid_search_2d(
            &[0.0, 1.0, 2.0],
            &[-2.0, -1.0, 0.0],
            |x, y| -((x - 1.0) * (x - 1.0) + (y + 1.0) * (y + 1.0)),
            ObjectiveDirection::Maximize,
        )
        .unwrap();

        assert_eq!(result.best_x, 1.0);
        assert_eq!(result.best_y, -1.0);
        assert_eq!(result.best_score, 0.0);
    }

    #[test]
    fn supports_minimization() {
        let result = grid_search_1d(
            &[0.0, 1.0, 2.0, 3.0],
            |value| (value - 1.0) * (value - 1.0),
            ObjectiveDirection::Minimize,
        )
        .unwrap();

        assert_eq!(result.best_value, 1.0);
    }

    #[test]
    fn supports_two_dimensional_minimization() {
        let result = grid_search_2d(
            &[0.0, 1.0, 2.0],
            &[0.0, 1.0],
            |x, y| (x - 2.0) * (x - 2.0) + y,
            ObjectiveDirection::Minimize,
        );

        assert_eq!(
            result,
            Some(GridSearchResult2D {
                best_x: 2.0,
                best_y: 0.0,
                best_score: 0.0,
            })
        );
    }

    #[test]
    fn returns_single_candidate_with_its_score() {
        let result = grid_search_1d(&[5.0], |value| value * 2.0, ObjectiveDirection::Minimize);

        assert_eq!(
            result,
            Some(GridSearchResult1D {
                best_value: 5.0,
                best_score: 10.0,
            })
        );
    }

    #[test]
    fn returns_none_for_invalid_inputs() {
        assert_eq!(
            grid_search_1d(&[], |value| value, ObjectiveDirection::Maximize),
            None
        );
        assert_eq!(
            grid_search_1d(
                &[1.0, f64::NAN],
                |value| value,
                ObjectiveDirection::Maximize
            ),
            None
        );
        assert_eq!(
            grid_search_2d(
                &[1.0],
                &[2.0],
                |_x, _y| f64::NAN,
                ObjectiveDirection::Maximize,
            ),
            None
        );
    }

    #[test]
    fn returns_none_for_empty_or_non_finite_axes() {
        assert_eq!(
            grid_search_2d(&[], &[1.0], |x, y| x + y, ObjectiveDirection::Maximize),
            None
        );
        assert_eq!(
            grid_search_2d(&[1.0], &[], |x, y| x + y, ObjectiveDirection::Maximize),
            None
        );
        assert_eq!(
            grid_search_2d(
                &[1.0],
                &[f64::INFINITY],
                |x, y| x + y,
                ObjectiveDirection::Maximize,
            ),
            None
        );
        assert_eq!(
            grid_search_1d(
                &[f64::NEG_INFINITY, 0.0],
                |value| value,
                ObjectiveDirection::Minimize,
            ),
            None
        );
    }

    #[test]
    fn returns_none_when_any_score_is_non_finite() {
        assert_eq!(
            grid_search_1d(&[1.0], |_value| f64::NAN, ObjectiveDirection::Minimize),
            None
        );
        // A later non-finite score discards an otherwise valid earlier best.
        assert_eq!(
            grid_search_1d(
                &[0.0, 1.0],
                |value| if value > 0.5 { f64::INFINITY } else { value },
                ObjectiveDirection::Maximize,
            ),
            None
        );
    }

    #[test]
    fn stops_evaluating_after_first_non_finite_score() {
        let calls = Cell::new(0_usize);
        let result = grid_search_1d(
            &[0.0, 1.0, 2.0],
            |value| {
                calls.set(calls.get() + 1);
                if value > 0.5 { f64::INFINITY } else { value }
            },
            ObjectiveDirection::Maximize,
        );

        assert_eq!(result, None);
        assert_eq!(calls.get(), 2);
    }

    #[test]
    fn evaluates_every_grid_point_exactly_once() {
        let calls = Cell::new(0_usize);
        let result = grid_search_2d(
            &[0.0, 1.0, 2.0],
            &[0.0, 1.0],
            |x, y| {
                calls.set(calls.get() + 1);
                x + y
            },
            ObjectiveDirection::Maximize,
        );

        assert!(result.is_some());
        assert_eq!(calls.get(), 6);
    }

    #[test]
    fn does_not_evaluate_objective_on_invalid_grid() {
        let calls = Cell::new(0_usize);
        let result = grid_search_2d(
            &[1.0, 2.0],
            &[f64::NAN],
            |x, y| {
                calls.set(calls.get() + 1);
                x + y
            },
            ObjectiveDirection::Maximize,
        );

        assert_eq!(result, None);
        assert_eq!(calls.get(), 0);
    }
}