solverforge_solver/phase/localsearch/
phase.rs1use std::fmt::Debug;
4use std::marker::PhantomData;
5
6use solverforge_core::domain::PlanningSolution;
7use solverforge_scoring::{RecordingScoreDirector, ScoreDirector};
8
9use crate::heuristic::r#move::{Move, MoveArena};
10use crate::heuristic::selector::MoveSelector;
11use crate::phase::localsearch::{Acceptor, LocalSearchForager};
12use crate::phase::Phase;
13use crate::scope::{PhaseScope, SolverScope, StepScope};
14
15pub struct LocalSearchPhase<S, M, MS, A, Fo>
36where
37 S: PlanningSolution,
38 M: Move<S>,
39 MS: MoveSelector<S, M>,
40 A: Acceptor<S>,
41 Fo: LocalSearchForager<S, M>,
42{
43 move_selector: MS,
44 acceptor: A,
45 forager: Fo,
46 arena: MoveArena<M>,
47 step_limit: Option<u64>,
48 _phantom: PhantomData<fn() -> (S, M)>,
49}
50
51impl<S, M, MS, A, Fo> LocalSearchPhase<S, M, MS, A, Fo>
52where
53 S: PlanningSolution,
54 M: Move<S> + 'static,
55 MS: MoveSelector<S, M>,
56 A: Acceptor<S>,
57 Fo: LocalSearchForager<S, M>,
58{
59 pub fn new(move_selector: MS, acceptor: A, forager: Fo, step_limit: Option<u64>) -> Self {
61 Self {
62 move_selector,
63 acceptor,
64 forager,
65 arena: MoveArena::new(),
66 step_limit,
67 _phantom: PhantomData,
68 }
69 }
70}
71
72impl<S, M, MS, A, Fo> Debug for LocalSearchPhase<S, M, MS, A, Fo>
73where
74 S: PlanningSolution,
75 M: Move<S>,
76 MS: MoveSelector<S, M> + Debug,
77 A: Acceptor<S> + Debug,
78 Fo: LocalSearchForager<S, M> + Debug,
79{
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 f.debug_struct("LocalSearchPhase")
82 .field("move_selector", &self.move_selector)
83 .field("acceptor", &self.acceptor)
84 .field("forager", &self.forager)
85 .field("arena", &self.arena)
86 .field("step_limit", &self.step_limit)
87 .finish()
88 }
89}
90
91impl<S, D, M, MS, A, Fo> Phase<S, D> for LocalSearchPhase<S, M, MS, A, Fo>
92where
93 S: PlanningSolution,
94 D: ScoreDirector<S>,
95 M: Move<S>,
96 MS: MoveSelector<S, M>,
97 A: Acceptor<S>,
98 Fo: LocalSearchForager<S, M>,
99{
100 fn solve(&mut self, solver_scope: &mut SolverScope<S, D>) {
101 let mut phase_scope = PhaseScope::new(solver_scope, 0);
102
103 let mut last_step_score = phase_scope.calculate_score();
105
106 self.acceptor.phase_started(&last_step_score);
108
109 loop {
110 if phase_scope.solver_scope().is_terminate_early() {
112 break;
113 }
114
115 if let Some(limit) = self.step_limit {
117 if phase_scope.step_count() >= limit {
118 break;
119 }
120 }
121
122 let mut step_scope = StepScope::new(&mut phase_scope);
123
124 self.forager.step_started();
126
127 self.arena.reset();
129 self.arena
130 .extend(self.move_selector.iter_moves(step_scope.score_director()));
131
132 for i in 0..self.arena.len() {
134 let m = self.arena.get(i).unwrap();
135
136 if !m.is_doable(step_scope.score_director()) {
137 continue;
138 }
139
140 let move_score = {
142 let mut recording =
143 RecordingScoreDirector::new(step_scope.score_director_mut());
144
145 m.do_move(&mut recording);
147
148 let score = recording.calculate_score();
150
151 recording.undo_changes();
153
154 score
155 };
156
157 let accepted = self.acceptor.is_accepted(&last_step_score, &move_score);
159
160 if accepted {
162 self.forager.add_move_index(i, move_score);
163 }
164
165 if self.forager.is_quit_early() {
167 break;
168 }
169 }
170
171 if let Some((selected_index, selected_score)) = self.forager.pick_move_index() {
173 let selected_move = self.arena.take(selected_index);
175
176 selected_move.do_move(step_scope.score_director_mut());
178 step_scope.set_step_score(selected_score);
179
180 last_step_score = selected_score;
182
183 step_scope.phase_scope_mut().update_best_solution();
185 } else {
186 break;
188 }
189
190 step_scope.complete();
191 }
192
193 self.acceptor.phase_ended();
195 }
196
197 fn phase_type_name(&self) -> &'static str {
198 "LocalSearch"
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::heuristic::selector::ChangeMoveSelector;
    use crate::phase::localsearch::{AcceptedCountForager, HillClimbingAcceptor};
    use solverforge_core::domain::{EntityDescriptor, SolutionDescriptor, TypedEntityExtractor};
    use solverforge_core::score::SimpleScore;
    use solverforge_scoring::SimpleScoreDirector;
    use std::any::TypeId;

    /// A queen fixed to a column; `row` is the value the moves change.
    #[derive(Clone, Debug)]
    struct Queen {
        column: i32,
        row: Option<i32>,
    }

    /// Minimal N-Queens planning solution used as the test fixture.
    #[derive(Clone, Debug)]
    struct NQueensSolution {
        queens: Vec<Queen>,
        score: Option<SimpleScore>,
    }

    impl PlanningSolution for NQueensSolution {
        type Score = SimpleScore;

        fn score(&self) -> Option<Self::Score> {
            self.score
        }

        fn set_score(&mut self, score: Option<Self::Score>) {
            self.score = score;
        }
    }

    /// Entity accessor handed to `TypedEntityExtractor`.
    fn get_queens(s: &NQueensSolution) -> &Vec<Queen> {
        &s.queens
    }

    /// Mutable entity accessor handed to `TypedEntityExtractor`.
    fn get_queens_mut(s: &mut NQueensSolution) -> &mut Vec<Queen> {
        &mut s.queens
    }

    /// Reads the row of the queen at `idx`; `None` if the index is out of
    /// range or the queen has no row.
    fn get_queen_row(s: &NQueensSolution, idx: usize) -> Option<i32> {
        s.queens.get(idx).and_then(|q| q.row)
    }

    /// Writes the row of the queen at `idx`; an out-of-range index is ignored.
    fn set_queen_row(s: &mut NQueensSolution, idx: usize, v: Option<i32>) {
        if let Some(queen) = s.queens.get_mut(idx) {
            queen.row = v;
        }
    }

    /// Scores a solution as the negated number of conflicts.
    ///
    /// Every unordered pair of queens that both have a row contributes one
    /// conflict if the rows are equal and one conflict if the column distance
    /// equals the row distance (same diagonal).
    fn calculate_conflicts(solution: &NQueensSolution) -> SimpleScore {
        let mut conflicts = 0i64;

        for (i, q1) in solution.queens.iter().enumerate() {
            if let Some(row1) = q1.row {
                for q2 in solution.queens.iter().skip(i + 1) {
                    if let Some(row2) = q2.row {
                        if row1 == row2 {
                            conflicts += 1;
                        }
                        let col_diff = (q2.column - q1.column).abs();
                        let row_diff = (row2 - row1).abs();
                        if col_diff == row_diff {
                            conflicts += 1;
                        }
                    }
                }
            }
        }

        SimpleScore::of(-conflicts)
    }

    /// Builds a score director over a board whose queen in column `i` sits in
    /// row `rows[i]`, scored by [`calculate_conflicts`].
    fn create_test_director(
        rows: &[i32],
    ) -> SimpleScoreDirector<NQueensSolution, impl Fn(&NQueensSolution) -> SimpleScore> {
        // Zipping with an `i32` counter numbers the columns without a
        // `usize as i32` cast.
        let queens: Vec<Queen> = (0..)
            .zip(rows)
            .map(|(column, &row)| Queen {
                column,
                row: Some(row),
            })
            .collect();

        let solution = NQueensSolution {
            queens,
            score: None,
        };

        let extractor = Box::new(TypedEntityExtractor::new(
            "Queen",
            "queens",
            get_queens,
            get_queens_mut,
        ));
        let entity_desc = EntityDescriptor::new("Queen", TypeId::of::<Queen>(), "queens")
            .with_extractor(extractor);

        let descriptor =
            SolutionDescriptor::new("NQueensSolution", TypeId::of::<NQueensSolution>())
                .with_entity(entity_desc);

        SimpleScoreDirector::with_calculator(solution, descriptor, calculate_conflicts)
    }

    type NQueensMove = crate::heuristic::r#move::ChangeMove<NQueensSolution, i32>;

    /// Builds a change-move selector over the queens' `row` variable offering
    /// the given candidate `values`.
    fn create_move_selector(
        values: Vec<i32>,
    ) -> ChangeMoveSelector<
        NQueensSolution,
        i32,
        crate::heuristic::selector::FromSolutionEntitySelector,
        crate::heuristic::selector::StaticTypedValueSelector<NQueensSolution, i32>,
    > {
        ChangeMoveSelector::simple(get_queen_row, set_queen_row, 0, "row", values)
    }

    /// Runs a hill-climbing local search phase on a 4-queens board.
    ///
    /// The board starts with the given `rows`, candidate rows are `0..4`, the
    /// forager is created with `AcceptedCountForager::new(1)`, and the phase is
    /// capped at `step_limit` steps.
    ///
    /// Returns `(initial_score, final_score)`, where `final_score` is the
    /// solver scope's best score after the phase, falling back to the initial
    /// score if no best score is recorded.
    fn run_hill_climbing(rows: &[i32], step_limit: u64) -> (SimpleScore, SimpleScore) {
        let director = create_test_director(rows);
        let mut solver_scope = SolverScope::new(director);

        let initial_score = solver_scope.calculate_score();

        let values: Vec<i32> = (0..4).collect();
        let move_selector = create_move_selector(values);
        let acceptor = HillClimbingAcceptor::new();
        let forager: AcceptedCountForager<_> = AcceptedCountForager::new(1);
        let mut phase: LocalSearchPhase<_, NQueensMove, _, _, _> =
            LocalSearchPhase::new(move_selector, acceptor, forager, Some(step_limit));

        phase.solve(&mut solver_scope);

        let final_score = solver_scope.best_score().copied().unwrap_or(initial_score);
        (initial_score, final_score)
    }

    /// Starting with all queens in row 0 (a conflicting board), hill climbing
    /// must never end with a worse best score than it started with.
    #[test]
    fn test_local_search_hill_climbing() {
        let (initial_score, final_score) = run_hill_climbing(&[0, 0, 0, 0], 100);

        assert!(initial_score < SimpleScore::of(0));
        assert!(final_score >= initial_score);
    }

    /// Starts from `[0, 2, 1, 3]`, which has two diagonal conflicts (queens
    /// 0/3 and 1/2). Despite the test name, only non-regression of the score
    /// is asserted; reaching a conflict-free board is not required.
    #[test]
    fn test_local_search_reaches_optimal() {
        let (initial_score, final_score) = run_hill_climbing(&[0, 2, 1, 3], 50);

        assert!(final_score >= initial_score);
    }

    /// With a step limit of 3 the phase must return, and the best score it
    /// leaves behind must not be worse than the starting score.
    #[test]
    fn test_local_search_step_limit() {
        let (initial_score, final_score) = run_hill_climbing(&[0, 0, 0, 0], 3);

        assert!(final_score >= initial_score);
    }
}