solverforge_solver/phase/vnd/
mod.rs1use std::fmt::Debug;
21use std::marker::PhantomData;
22
23use solverforge_core::domain::PlanningSolution;
24use solverforge_scoring::{RecordingScoreDirector, ScoreDirector};
25
26use crate::heuristic::r#move::{Move, MoveArena};
27use crate::heuristic::selector::MoveSelector;
28use crate::phase::Phase;
29use crate::scope::{PhaseScope, SolverScope, StepScope};
30
/// Variable Neighborhood Descent (VND) phase.
///
/// Holds an ordered list of neighborhoods (move selectors). During `solve`, each step
/// evaluates every doable move of the current neighborhood and applies the best move that
/// strictly beats the current score. After an improving step the search continues from
/// the first neighborhood. When a neighborhood yields no improving move, the search
/// advances to the next one. The phase ends after the last neighborhood fails to improve.
pub struct VndPhase<S, M>
where
    S: PlanningSolution,
    M: Move<S>,
{
    /// Neighborhoods in exploration order. Index 0 is where the search continues after
    /// an improving step.
    neighborhoods: Vec<Box<dyn MoveSelector<S, M>>>,
    /// Reusable buffer, reset every step, holding the moves generated by the current
    /// neighborhood.
    arena: MoveArena<M>,
    /// Optional cap on the steps counted for a single neighborhood before the search
    /// advances to the next one. `None` means no cap.
    step_limit_per_neighborhood: Option<u64>,
    /// Marker for `M`. `M` already appears in `neighborhoods` and `arena`, so this field
    /// is not strictly needed for the type to be well-formed.
    _phantom: PhantomData<M>,
}
78
79impl<S, M> Debug for VndPhase<S, M>
80where
81 S: PlanningSolution,
82 M: Move<S>,
83{
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 f.debug_struct("VndPhase")
86 .field("neighborhood_count", &self.neighborhoods.len())
87 .field(
88 "step_limit_per_neighborhood",
89 &self.step_limit_per_neighborhood,
90 )
91 .finish()
92 }
93}
94
95impl<S, M> VndPhase<S, M>
96where
97 S: PlanningSolution,
98 M: Move<S> + 'static,
99{
100 pub fn new(neighborhoods: Vec<Box<dyn MoveSelector<S, M>>>) -> Self {
105 Self {
106 neighborhoods,
107 arena: MoveArena::new(),
108 step_limit_per_neighborhood: None,
109 _phantom: PhantomData,
110 }
111 }
112
113 pub fn with_step_limit(mut self, limit: u64) -> Self {
115 self.step_limit_per_neighborhood = Some(limit);
116 self
117 }
118
119 pub fn neighborhood_count(&self) -> usize {
121 self.neighborhoods.len()
122 }
123}
124
125impl<S, M> Phase<S> for VndPhase<S, M>
126where
127 S: PlanningSolution,
128 M: Move<S>,
129{
130 fn solve(&mut self, solver_scope: &mut SolverScope<S>) {
131 if self.neighborhoods.is_empty() {
132 return;
133 }
134
135 let mut phase_scope = PhaseScope::new(solver_scope, 0);
136 let mut current_score = phase_scope.calculate_score();
137
138 let mut neighborhood_idx = 0;
139 let mut steps_in_neighborhood = 0u64;
140
141 while neighborhood_idx < self.neighborhoods.len() {
142 if let Some(limit) = self.step_limit_per_neighborhood {
144 if steps_in_neighborhood >= limit {
145 neighborhood_idx += 1;
147 steps_in_neighborhood = 0;
148 continue;
149 }
150 }
151
152 let mut step_scope = StepScope::new(&mut phase_scope);
153
154 self.arena.reset();
156 self.arena.extend(
157 self.neighborhoods[neighborhood_idx].iter_moves(step_scope.score_director()),
158 );
159
160 let mut best_move: Option<(M, S::Score)> = None;
162
163 for i in 0..self.arena.len() {
164 let m = self.arena.get(i).unwrap();
165
166 if !m.is_doable(step_scope.score_director()) {
167 continue;
168 }
169
170 let mut recording = RecordingScoreDirector::new(step_scope.score_director_mut());
171 m.do_move(&mut recording);
172 let move_score = recording.calculate_score();
173 recording.undo_changes();
174
175 if move_score > current_score {
177 match &best_move {
178 Some((_, best_score)) if move_score > *best_score => {
179 best_move = Some((m.clone(), move_score));
180 }
181 None => {
182 best_move = Some((m.clone(), move_score));
183 }
184 _ => {}
185 }
186 }
187 }
188
189 if let Some((selected_move, selected_score)) = best_move {
190 selected_move.do_move(step_scope.score_director_mut());
192 step_scope.set_step_score(selected_score.clone());
193 current_score = selected_score;
194
195 step_scope.phase_scope_mut().update_best_solution();
197
198 neighborhood_idx = 0;
200 steps_in_neighborhood = 0;
201 } else {
202 neighborhood_idx += 1;
204 steps_in_neighborhood = 0;
205 }
206
207 steps_in_neighborhood += 1;
208 step_scope.complete();
209 }
210 }
211
212 fn phase_type_name(&self) -> &'static str {
213 "VariableNeighborhoodDescent"
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::heuristic::r#move::ChangeMove;
    use crate::heuristic::selector::ChangeMoveSelector;
    use solverforge_core::domain::{EntityDescriptor, SolutionDescriptor, TypedEntityExtractor};
    use solverforge_core::score::SimpleScore;
    use solverforge_scoring::SimpleScoreDirector;
    use std::any::TypeId;

    #[derive(Clone, Debug)]
    struct Queen {
        column: i32,
        row: Option<i32>,
    }

    #[derive(Clone, Debug)]
    struct NQueensSolution {
        queens: Vec<Queen>,
        score: Option<SimpleScore>,
    }

    impl PlanningSolution for NQueensSolution {
        type Score = SimpleScore;

        fn score(&self) -> Option<Self::Score> {
            self.score
        }

        fn set_score(&mut self, score: Option<Self::Score>) {
            self.score = score;
        }
    }

    fn get_queens(s: &NQueensSolution) -> &Vec<Queen> {
        &s.queens
    }
    fn get_queens_mut(s: &mut NQueensSolution) -> &mut Vec<Queen> {
        &mut s.queens
    }

    fn get_queen_row(s: &NQueensSolution, idx: usize) -> Option<i32> {
        s.queens.get(idx).and_then(|q| q.row)
    }

    fn set_queen_row(s: &mut NQueensSolution, idx: usize, v: Option<i32>) {
        if let Some(queen) = s.queens.get_mut(idx) {
            queen.row = v;
        }
    }

    /// Counts row and diagonal conflicts between assigned queens. The score is the
    /// negated conflict count.
    fn calculate_conflicts(solution: &NQueensSolution) -> SimpleScore {
        let mut conflicts = 0i64;

        for (i, q1) in solution.queens.iter().enumerate() {
            if let Some(row1) = q1.row {
                for q2 in solution.queens.iter().skip(i + 1) {
                    if let Some(row2) = q2.row {
                        if row1 == row2 {
                            conflicts += 1;
                        }
                        let col_diff = (q2.column - q1.column).abs();
                        let row_diff = (row2 - row1).abs();
                        if col_diff == row_diff {
                            conflicts += 1;
                        }
                    }
                }
            }
        }

        SimpleScore::of(-conflicts)
    }

    /// Builds a score director over one queen per column, with queen `i` placed in
    /// `rows[i]`.
    fn create_director(
        rows: &[i32],
    ) -> SimpleScoreDirector<NQueensSolution, impl Fn(&NQueensSolution) -> SimpleScore> {
        let queens: Vec<_> = rows
            .iter()
            .enumerate()
            .map(|(col, &row)| Queen {
                column: col as i32,
                row: Some(row),
            })
            .collect();

        let solution = NQueensSolution {
            queens,
            score: None,
        };

        let extractor = Box::new(TypedEntityExtractor::new(
            "Queen",
            "queens",
            get_queens,
            get_queens_mut,
        ));
        let entity_desc = EntityDescriptor::new("Queen", TypeId::of::<Queen>(), "queens")
            .with_extractor(extractor);

        let descriptor =
            SolutionDescriptor::new("NQueensSolution", TypeId::of::<NQueensSolution>())
                .with_entity(entity_desc);

        SimpleScoreDirector::with_calculator(solution, descriptor, calculate_conflicts)
    }

    type NQueensMove = ChangeMove<NQueensSolution, i32>;

    fn create_move_selector(
        values: Vec<i32>,
    ) -> Box<dyn MoveSelector<NQueensSolution, NQueensMove>> {
        Box::new(ChangeMoveSelector::<NQueensSolution, i32>::simple(
            get_queen_row,
            set_queen_row,
            0,
            "row",
            values,
        ))
    }

    #[test]
    fn test_vnd_improves_solution() {
        let director = create_director(&[0, 0, 0, 0]);
        let mut solver_scope = SolverScope::new(Box::new(director));

        let initial_score = solver_scope.calculate_score();
        assert!(initial_score < SimpleScore::of(0));

        let values: Vec<i32> = (0..4).collect();
        let mut phase = VndPhase::<NQueensSolution, NQueensMove>::new(vec![
            create_move_selector(values.clone()),
            create_move_selector(values.clone()),
        ]);

        phase.solve(&mut solver_scope);

        let final_score = solver_scope.best_score().cloned().unwrap_or(initial_score);
        assert!(final_score >= initial_score);
        // All queens share row 0, so moving any one of them removes row conflicts. An
        // improving move must exist, and the working solution must end strictly better.
        assert!(solver_scope.calculate_score() > initial_score);
    }

    #[test]
    fn test_vnd_empty_neighborhoods() {
        let director = create_director(&[0, 1, 2, 3]);
        let mut solver_scope = SolverScope::new(Box::new(director));
        let initial_score = solver_scope.calculate_score();

        let mut phase = VndPhase::<NQueensSolution, NQueensMove>::new(vec![]);
        assert_eq!(phase.neighborhood_count(), 0);
        phase.solve(&mut solver_scope);

        // With nothing to explore, the working solution must be left untouched.
        assert!(solver_scope.calculate_score() == initial_score);
    }

    #[test]
    fn test_vnd_single_neighborhood() {
        let director = create_director(&[0, 0, 0, 0]);
        let mut solver_scope = SolverScope::new(Box::new(director));

        let initial_score = solver_scope.calculate_score();

        let values: Vec<i32> = (0..4).collect();
        let mut phase =
            VndPhase::<NQueensSolution, NQueensMove>::new(vec![create_move_selector(values)]);

        phase.solve(&mut solver_scope);

        let final_score = solver_scope.best_score().cloned().unwrap_or(initial_score);
        assert!(final_score >= initial_score);
        assert!(solver_scope.calculate_score() > initial_score);
    }

    #[test]
    fn test_vnd_zero_step_limit_skips_all_neighborhoods() {
        let director = create_director(&[0, 0, 0, 0]);
        let mut solver_scope = SolverScope::new(Box::new(director));
        let initial_score = solver_scope.calculate_score();

        let values: Vec<i32> = (0..4).collect();
        let mut phase =
            VndPhase::<NQueensSolution, NQueensMove>::new(vec![create_move_selector(values)])
                .with_step_limit(0);
        phase.solve(&mut solver_scope);

        assert!(solver_scope.calculate_score() == initial_score);
    }

    #[test]
    fn test_vnd_step_limit_still_improves_and_terminates() {
        let director = create_director(&[0, 0, 0, 0]);
        let mut solver_scope = SolverScope::new(Box::new(director));
        let initial_score = solver_scope.calculate_score();

        let values: Vec<i32> = (0..4).collect();
        let mut phase = VndPhase::<NQueensSolution, NQueensMove>::new(vec![
            create_move_selector(values.clone()),
            create_move_selector(values),
        ])
        .with_step_limit(1);
        phase.solve(&mut solver_scope);

        assert!(solver_scope.calculate_score() > initial_score);
    }

    #[test]
    fn test_vnd_builder_and_debug() {
        let values: Vec<i32> = (0..4).collect();
        let phase = VndPhase::<NQueensSolution, NQueensMove>::new(vec![
            create_move_selector(values.clone()),
            create_move_selector(values),
        ])
        .with_step_limit(5);

        assert_eq!(phase.neighborhood_count(), 2);
        assert_eq!(
            format!("{:?}", phase),
            "VndPhase { neighborhood_count: 2, step_limit_per_neighborhood: Some(5) }"
        );
        assert_eq!(phase.phase_type_name(), "VariableNeighborhoodDescent");
    }
}