Skip to main content

use_grid_search/
lib.rs

1#![forbid(unsafe_code)]
2//! Primitive one-dimensional and two-dimensional grid search helpers.
3//!
4//! # Examples
5//!
6//! ```rust
7//! use use_grid_search::{grid_search_1d, grid_search_2d};
8//! use use_objective::ObjectiveDirection;
9//!
10//! let one_dimensional = grid_search_1d(
11//!     &[0.0, 1.0, 2.0, 3.0],
12//!     |value| -(value - 2.0) * (value - 2.0),
13//!     ObjectiveDirection::Maximize,
14//! )
15//! .unwrap();
16//! assert_eq!(one_dimensional.best_value, 2.0);
17//!
18//! let two_dimensional = grid_search_2d(
19//!     &[0.0, 1.0, 2.0],
20//!     &[-2.0, -1.0, 0.0],
21//!     |x, y| -((x - 1.0) * (x - 1.0) + (y + 1.0) * (y + 1.0)),
22//!     ObjectiveDirection::Maximize,
23//! )
24//! .unwrap();
25//! assert_eq!((two_dimensional.best_x, two_dimensional.best_y), (1.0, -1.0));
26//! ```
27
28use use_objective::{ObjectiveDirection, is_better};
29
/// Winning grid point of a one-dimensional search, as produced by [`grid_search_1d`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct GridSearchResult1D {
    /// Grid point, taken from the searched `values` slice, that achieved `best_score`.
    pub best_value: f64,
    /// Objective value observed at `best_value`.
    pub best_score: f64,
}
35
/// Winning `(x, y)` pair of a two-dimensional search, as produced by [`grid_search_2d`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct GridSearchResult2D {
    /// Coordinate taken from the searched `x_values` slice.
    pub best_x: f64,
    /// Coordinate taken from the searched `y_values` slice.
    pub best_y: f64,
    /// Objective value observed at `(best_x, best_y)`.
    pub best_score: f64,
}
42
43pub fn grid_search_1d<F>(
44    values: &[f64],
45    objective: F,
46    direction: ObjectiveDirection,
47) -> Option<GridSearchResult1D>
48where
49    F: Fn(f64) -> f64,
50{
51    if values.is_empty() || values.iter().any(|value| !value.is_finite()) {
52        return None;
53    }
54
55    let mut best: Option<GridSearchResult1D> = None;
56    for value in values.iter().copied() {
57        let score = objective(value);
58        if !score.is_finite() {
59            return None;
60        }
61
62        let candidate = GridSearchResult1D {
63            best_value: value,
64            best_score: score,
65        };
66
67        if best.is_none_or(|current| is_better(candidate.best_score, current.best_score, direction))
68        {
69            best = Some(candidate);
70        }
71    }
72
73    best
74}
75
76pub fn grid_search_2d<F>(
77    x_values: &[f64],
78    y_values: &[f64],
79    objective: F,
80    direction: ObjectiveDirection,
81) -> Option<GridSearchResult2D>
82where
83    F: Fn(f64, f64) -> f64,
84{
85    if x_values.is_empty()
86        || y_values.is_empty()
87        || x_values
88            .iter()
89            .chain(y_values.iter())
90            .any(|value| !value.is_finite())
91    {
92        return None;
93    }
94
95    let mut best: Option<GridSearchResult2D> = None;
96    for x_value in x_values.iter().copied() {
97        for y_value in y_values.iter().copied() {
98            let score = objective(x_value, y_value);
99            if !score.is_finite() {
100                return None;
101            }
102
103            let candidate = GridSearchResult2D {
104                best_x: x_value,
105                best_y: y_value,
106                best_score: score,
107            };
108
109            if best.is_none_or(|current| {
110                is_better(candidate.best_score, current.best_score, direction)
111            }) {
112                best = Some(candidate);
113            }
114        }
115    }
116
117    best
118}
119
#[cfg(test)]
mod tests {
    use std::cell::{Cell, RefCell};

    use super::{grid_search_1d, grid_search_2d};
    use use_objective::ObjectiveDirection;

    #[test]
    fn finds_best_one_dimensional_candidate() {
        let result = grid_search_1d(
            &[0.0, 1.0, 2.0, 3.0],
            |value| -(value - 2.0) * (value - 2.0),
            ObjectiveDirection::Maximize,
        )
        .unwrap();

        assert_eq!(result.best_value, 2.0);
        assert_eq!(result.best_score, 0.0);
    }

    #[test]
    fn finds_best_two_dimensional_candidate() {
        let result = grid_search_2d(
            &[0.0, 1.0, 2.0],
            &[-2.0, -1.0, 0.0],
            |x, y| -((x - 1.0) * (x - 1.0) + (y + 1.0) * (y + 1.0)),
            ObjectiveDirection::Maximize,
        )
        .unwrap();

        assert_eq!(result.best_x, 1.0);
        assert_eq!(result.best_y, -1.0);
        assert_eq!(result.best_score, 0.0);
    }

    #[test]
    fn supports_minimization() {
        let result = grid_search_1d(
            &[0.0, 1.0, 2.0, 3.0],
            |value| (value - 1.0) * (value - 1.0),
            ObjectiveDirection::Minimize,
        )
        .unwrap();

        assert_eq!(result.best_value, 1.0);
    }

    #[test]
    fn supports_two_dimensional_minimization() {
        // Unique minimum at (2, 0), so the outcome does not depend on tie-breaking.
        let result = grid_search_2d(
            &[0.0, 1.0, 2.0, 3.0],
            &[-1.0, 0.0, 1.0],
            |x, y| (x - 2.0) * (x - 2.0) + y * y,
            ObjectiveDirection::Minimize,
        )
        .unwrap();

        assert_eq!(
            (result.best_x, result.best_y, result.best_score),
            (2.0, 0.0, 0.0)
        );
    }

    #[test]
    fn returns_none_for_invalid_inputs() {
        assert_eq!(
            grid_search_1d(&[], |value| value, ObjectiveDirection::Maximize),
            None
        );
        assert_eq!(
            grid_search_1d(
                &[1.0, f64::NAN],
                |value| value,
                ObjectiveDirection::Maximize
            ),
            None
        );
        assert_eq!(
            grid_search_2d(
                &[1.0],
                &[2.0],
                |_x, _y| f64::NAN,
                ObjectiveDirection::Maximize,
            ),
            None
        );
    }

    #[test]
    fn returns_none_for_invalid_one_dimensional_grids_and_scores() {
        assert_eq!(
            grid_search_1d(&[f64::NEG_INFINITY], |value| value, ObjectiveDirection::Maximize),
            None
        );
        assert_eq!(
            grid_search_1d(&[0.0, 1.0], |_| f64::INFINITY, ObjectiveDirection::Maximize),
            None
        );
    }

    #[test]
    fn returns_none_for_invalid_two_dimensional_grids() {
        assert_eq!(
            grid_search_2d(&[], &[1.0], |x, y| x + y, ObjectiveDirection::Maximize),
            None
        );
        assert_eq!(
            grid_search_2d(&[1.0], &[], |x, y| x + y, ObjectiveDirection::Maximize),
            None
        );
        assert_eq!(
            grid_search_2d(&[f64::INFINITY], &[1.0], |x, y| x + y, ObjectiveDirection::Maximize),
            None
        );
        assert_eq!(
            grid_search_2d(&[1.0], &[f64::NAN], |x, y| x + y, ObjectiveDirection::Maximize),
            None
        );
    }

    #[test]
    fn rejects_invalid_grid_before_calling_objective() {
        let calls = Cell::new(0_u32);

        let result = grid_search_1d(
            &[1.0, f64::NAN],
            |value| {
                calls.set(calls.get() + 1);
                value
            },
            ObjectiveDirection::Maximize,
        );
        assert_eq!(result, None);

        let result = grid_search_2d(
            &[1.0, 2.0],
            &[f64::NAN],
            |x, y| {
                calls.set(calls.get() + 1);
                x + y
            },
            ObjectiveDirection::Maximize,
        );
        assert_eq!(result, None);

        assert_eq!(calls.get(), 0);
    }

    #[test]
    fn stops_at_first_non_finite_score() {
        let calls = Cell::new(0_u32);
        let result = grid_search_1d(
            &[0.0, 1.0, 2.0],
            |value| {
                calls.set(calls.get() + 1);
                if value == 1.0 { f64::NAN } else { value }
            },
            ObjectiveDirection::Maximize,
        );

        assert_eq!(result, None);
        // The third point must never be evaluated.
        assert_eq!(calls.get(), 2);
    }

    #[test]
    fn visits_two_dimensional_grid_in_row_major_order() {
        let visited = RefCell::new(Vec::new());
        let result = grid_search_2d(
            &[1.0, 2.0],
            &[10.0, 20.0],
            |x, y| {
                visited.borrow_mut().push((x, y));
                x + y
            },
            ObjectiveDirection::Maximize,
        )
        .unwrap();

        assert_eq!((result.best_x, result.best_y), (2.0, 20.0));
        assert_eq!(
            visited.into_inner(),
            vec![(1.0, 10.0), (1.0, 20.0), (2.0, 10.0), (2.0, 20.0)]
        );
    }
}