1use std::marker::PhantomData;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6
7use solverforge_config::SolverConfig;
8use solverforge_core::domain::PlanningSolution;
9use solverforge_core::SolverForgeError;
10use solverforge_scoring::ScoreDirector;
11
12use crate::phase::Phase;
13use crate::scope::SolverScope;
14use crate::termination::Termination;
15
/// Builds [`Solver`] instances from a stored [`SolverConfig`].
///
/// `S` is the planning-solution type that the produced solvers operate on.
pub struct SolverFactory<S: PlanningSolution> {
    /// Configuration cloned into each solver produced by `build_solver`.
    config: SolverConfig,
    // Ties the factory to `S` without storing a value of that type.
    _phantom: PhantomData<S>,
}
21
22impl<S: PlanningSolution> SolverFactory<S> {
23 pub fn create(config: SolverConfig) -> Result<Self, SolverForgeError> {
25 Ok(SolverFactory {
26 config,
27 _phantom: PhantomData,
28 })
29 }
30
31 pub fn from_toml_file(path: impl AsRef<std::path::Path>) -> Result<Self, SolverForgeError> {
33 let config = SolverConfig::from_toml_file(path)
34 .map_err(|e| SolverForgeError::Config(e.to_string()))?;
35 Self::create(config)
36 }
37
38 pub fn from_yaml_file(path: impl AsRef<std::path::Path>) -> Result<Self, SolverForgeError> {
40 let config = SolverConfig::from_yaml_file(path)
41 .map_err(|e| SolverForgeError::Config(e.to_string()))?;
42 Self::create(config)
43 }
44
45 pub fn build_solver(&self) -> Solver<S> {
47 Solver::from_config(self.config.clone())
48 }
49
50 pub fn config(&self) -> &SolverConfig {
52 &self.config
53 }
54}
55
/// Runs an ordered sequence of phases against a planning solution.
pub struct Solver<S: PlanningSolution> {
    /// Phases executed in order by `solve_with_director`.
    phases: Vec<Box<dyn Phase<S>>>,
    /// Optional termination condition, consulted before each phase starts.
    termination: Option<Box<dyn Termination<S>>>,
    /// Early-termination request flag. Set by `terminate_early` and shared
    /// with the `SolverScope` during a solve.
    terminate_early_flag: Arc<AtomicBool>,
    /// Raised by `solve_with_director` while it runs the phases and cleared
    /// before it returns.
    solving: Arc<AtomicBool>,
    /// Configuration the solver was built from; `None` when built via `new`.
    config: Option<SolverConfig>,
}
72
73impl<S: PlanningSolution> Solver<S> {
74 pub fn new(phases: Vec<Box<dyn Phase<S>>>) -> Self {
76 Solver {
77 phases,
78 termination: None,
79 terminate_early_flag: Arc::new(AtomicBool::new(false)),
80 solving: Arc::new(AtomicBool::new(false)),
81 config: None,
82 }
83 }
84
85 pub fn from_config(config: SolverConfig) -> Self {
87 Solver {
88 phases: Vec::new(),
89 termination: None,
90 terminate_early_flag: Arc::new(AtomicBool::new(false)),
91 solving: Arc::new(AtomicBool::new(false)),
92 config: Some(config),
93 }
94 }
95
96 pub fn with_phase(mut self, phase: Box<dyn Phase<S>>) -> Self {
98 self.phases.push(phase);
99 self
100 }
101
102 pub fn with_phases(mut self, phases: Vec<Box<dyn Phase<S>>>) -> Self {
104 self.phases.extend(phases);
105 self
106 }
107
108 pub fn with_termination(mut self, termination: Box<dyn Termination<S>>) -> Self {
110 self.termination = Some(termination);
111 self
112 }
113
114 pub fn solve_with_director(&mut self, score_director: Box<dyn ScoreDirector<S>>) -> S {
118 self.solving.store(true, Ordering::SeqCst);
119 self.terminate_early_flag.store(false, Ordering::SeqCst);
120
121 let mut solver_scope = SolverScope::new(score_director);
122 solver_scope.set_terminate_early_flag(self.terminate_early_flag.clone());
123 solver_scope.start_solving();
124
125 let mut phase_index = 0;
132 while phase_index < self.phases.len() {
133 if self.check_termination(&solver_scope) {
135 tracing::debug!(
136 "Terminating before phase {} ({})",
137 phase_index,
138 self.phases[phase_index].phase_type_name()
139 );
140 break;
141 }
142
143 tracing::debug!(
144 "Starting phase {} ({})",
145 phase_index,
146 self.phases[phase_index].phase_type_name()
147 );
148
149 self.phases[phase_index].solve(&mut solver_scope);
150
151 tracing::debug!(
152 "Finished phase {} ({}) with score {:?}",
153 phase_index,
154 self.phases[phase_index].phase_type_name(),
155 solver_scope.best_score()
156 );
157
158 phase_index += 1;
159 }
160
161 self.solving.store(false, Ordering::SeqCst);
162
163 solver_scope.take_best_or_working_solution()
167 }
168
169 fn check_termination(&self, solver_scope: &SolverScope<S>) -> bool {
171 if self.terminate_early_flag.load(Ordering::SeqCst) {
173 return true;
174 }
175
176 if let Some(ref termination) = self.termination {
178 if termination.is_terminated(solver_scope) {
179 return true;
180 }
181 }
182
183 false
184 }
185
186 pub fn terminate_early(&self) -> bool {
190 if self.solving.load(Ordering::SeqCst) {
191 self.terminate_early_flag.store(true, Ordering::SeqCst);
192 true
193 } else {
194 false
195 }
196 }
197
198 pub fn is_solving(&self) -> bool {
200 self.solving.load(Ordering::SeqCst)
201 }
202
203 pub fn config(&self) -> Option<&SolverConfig> {
205 self.config.as_ref()
206 }
207}
208
/// Coarse lifecycle state of a solve.
///
/// NOTE(review): nothing in this file produces or consumes `SolverStatus`.
/// The per-variant meanings below are inferred from the variant names and
/// should be confirmed against the code that uses this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SolverStatus {
    /// Presumably: no solve is running or pending.
    NotSolving,
    /// Presumably: a solve has been requested but has not started yet.
    SolvingScheduled,
    /// Presumably: a solve is currently running.
    SolvingActive,
}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::heuristic::r#move::ChangeMove;
    use crate::heuristic::selector::ChangeMoveSelector;
    use crate::manager::SolverPhaseFactory;
    use crate::termination::StepCountTermination;
    use solverforge_core::domain::{EntityDescriptor, SolutionDescriptor, TypedEntityExtractor};
    use solverforge_core::score::SimpleScore;
    use solverforge_scoring::SimpleScoreDirector;
    use std::any::TypeId;

    #[derive(Clone, Debug)]
    struct Queen {
        column: i32,
        row: Option<i32>,
    }

    #[derive(Clone, Debug)]
    struct NQueensSolution {
        n: i32,
        queens: Vec<Queen>,
        score: Option<SimpleScore>,
    }

    type NQueensMove = ChangeMove<NQueensSolution, i32>;

    impl PlanningSolution for NQueensSolution {
        type Score = SimpleScore;

        fn score(&self) -> Option<Self::Score> {
            self.score
        }

        fn set_score(&mut self, score: Option<Self::Score>) {
            self.score = score;
        }
    }

    fn get_queens(s: &NQueensSolution) -> &Vec<Queen> {
        &s.queens
    }

    fn get_queens_mut(s: &mut NQueensSolution) -> &mut Vec<Queen> {
        &mut s.queens
    }

    /// Reads the row of queen `idx`; `None` if unassigned or out of range.
    fn get_queen_row(s: &NQueensSolution, idx: usize) -> Option<i32> {
        s.queens.get(idx).and_then(|q| q.row)
    }

    /// Writes the row of queen `idx`; out-of-range indices are ignored.
    fn set_queen_row(s: &mut NQueensSolution, idx: usize, v: Option<i32>) {
        if let Some(queen) = s.queens.get_mut(idx) {
            queen.row = v;
        }
    }

    /// Scores a solution as minus the number of attacking queen pairs.
    ///
    /// Same-row and same-diagonal pairs each count as one conflict.
    /// Unassigned queens are ignored.
    fn calculate_conflicts(solution: &NQueensSolution) -> SimpleScore {
        let mut conflicts = 0i64;

        for (i, q1) in solution.queens.iter().enumerate() {
            if let Some(row1) = q1.row {
                for q2 in solution.queens.iter().skip(i + 1) {
                    if let Some(row2) = q2.row {
                        // Same row.
                        if row1 == row2 {
                            conflicts += 1;
                        }
                        // Same diagonal.
                        let col_diff = (q2.column - q1.column).abs();
                        let row_diff = (row2 - row1).abs();
                        if col_diff == row_diff {
                            conflicts += 1;
                        }
                    }
                }
            }
        }

        SimpleScore::of(-conflicts)
    }

    /// Builds an `n`-queens solution with queen `col` in column `col` and
    /// row `row_of(col)`, with no score set.
    fn solution_with_rows(n: i32, row_of: impl Fn(i32) -> i32) -> NQueensSolution {
        let queens = (0..n)
            .map(|col| Queen {
                column: col,
                row: Some(row_of(col)),
            })
            .collect();

        NQueensSolution {
            n,
            queens,
            score: None,
        }
    }

    fn create_test_director(
        solution: NQueensSolution,
    ) -> SimpleScoreDirector<NQueensSolution, impl Fn(&NQueensSolution) -> SimpleScore> {
        let extractor = Box::new(TypedEntityExtractor::new(
            "Queen",
            "queens",
            get_queens,
            get_queens_mut,
        ));
        let entity_desc = EntityDescriptor::new("Queen", TypeId::of::<Queen>(), "queens")
            .with_extractor(extractor);

        let descriptor =
            SolutionDescriptor::new("NQueensSolution", TypeId::of::<NQueensSolution>())
                .with_entity(entity_desc);

        SimpleScoreDirector::with_calculator(solution, descriptor, calculate_conflicts)
    }

    #[test]
    fn test_solver_new() {
        let solver: Solver<NQueensSolution> = Solver::new(vec![]);
        assert!(!solver.is_solving());
    }

    #[test]
    fn test_solver_with_phases() {
        use crate::manager::LocalSearchPhaseFactory;

        let values: Vec<i32> = (0..4).collect();
        let factory =
            LocalSearchPhaseFactory::<NQueensSolution, NQueensMove, _>::hill_climbing(move || {
                Box::new(ChangeMoveSelector::<NQueensSolution, i32>::simple(
                    get_queen_row,
                    set_queen_row,
                    0,
                    "row",
                    values.clone(),
                ))
            })
            .with_step_limit(10);
        let local_search = factory.create_phase();

        let solver: Solver<NQueensSolution> = Solver::new(vec![]).with_phase(local_search);

        assert!(!solver.is_solving());
    }

    #[test]
    fn test_solver_local_search_only() {
        use crate::manager::LocalSearchPhaseFactory;

        let n = 4;
        // Every queen starts in row 0, so the initial solution has conflicts.
        let director = create_test_director(solution_with_rows(n, |_| 0));

        let initial_score = calculate_conflicts(director.working_solution());
        assert!(initial_score < SimpleScore::of(0));

        let values: Vec<i32> = (0..n).collect();
        let factory =
            LocalSearchPhaseFactory::<NQueensSolution, NQueensMove, _>::hill_climbing(move || {
                Box::new(ChangeMoveSelector::<NQueensSolution, i32>::simple(
                    get_queen_row,
                    set_queen_row,
                    0,
                    "row",
                    values.clone(),
                ))
            })
            .with_step_limit(50);
        let local_search = factory.create_phase();

        let mut solver = Solver::new(vec![local_search]);

        let result = solver.solve_with_director(Box::new(director));

        assert!(!solver.is_solving());
        // Hill climbing must not return a worse solution than it started with.
        let final_score = calculate_conflicts(&result);
        assert!(final_score >= initial_score);
    }

    #[test]
    fn test_solver_with_termination() {
        use crate::manager::LocalSearchPhaseFactory;

        let n = 4;
        let director = create_test_director(solution_with_rows(n, |_| 0));
        let initial_score = calculate_conflicts(director.working_solution());

        let values: Vec<i32> = (0..n).collect();
        let factory =
            LocalSearchPhaseFactory::<NQueensSolution, NQueensMove, _>::hill_climbing(move || {
                Box::new(ChangeMoveSelector::<NQueensSolution, i32>::simple(
                    get_queen_row,
                    set_queen_row,
                    0,
                    "row",
                    values.clone(),
                ))
            })
            .with_step_limit(1000);
        let local_search = factory.create_phase();

        let termination = StepCountTermination::new(5);

        let mut solver = Solver::new(vec![local_search]).with_termination(Box::new(termination));

        let result = solver.solve_with_director(Box::new(director));

        assert!(!solver.is_solving());
        // Every queen starts with a row, so "any row is assigned" was always
        // true. Instead, check that every queen still has a row and the
        // solution did not get worse.
        assert!(result.queens.iter().all(|q| q.row.is_some()));
        assert!(calculate_conflicts(&result) >= initial_score);
    }

    #[test]
    fn test_solver_status() {
        let solver: Solver<NQueensSolution> = Solver::new(vec![]);

        assert!(!solver.is_solving());
        // Not solving, so there is nothing to terminate.
        assert!(!solver.terminate_early());
    }

    #[test]
    fn test_solver_multiple_phases() {
        use crate::manager::LocalSearchPhaseFactory;

        let n = 4;
        let values: Vec<i32> = (0..n).collect();

        let values1 = values.clone();
        let factory1 =
            LocalSearchPhaseFactory::<NQueensSolution, NQueensMove, _>::hill_climbing(move || {
                Box::new(ChangeMoveSelector::<NQueensSolution, i32>::simple(
                    get_queen_row,
                    set_queen_row,
                    0,
                    "row",
                    values1.clone(),
                ))
            })
            .with_step_limit(10);
        let phase1 = factory1.create_phase();

        let values2 = values.clone();
        let factory2 =
            LocalSearchPhaseFactory::<NQueensSolution, NQueensMove, _>::hill_climbing(move || {
                Box::new(ChangeMoveSelector::<NQueensSolution, i32>::simple(
                    get_queen_row,
                    set_queen_row,
                    0,
                    "row",
                    values2.clone(),
                ))
            })
            .with_step_limit(10);
        let phase2 = factory2.create_phase();

        let mut solver = Solver::new(vec![phase1, phase2]);

        // Queens start on the main diagonal (row == column).
        let solution = solution_with_rows(n, |col| col % n);
        assert_eq!(solution.n, n);

        let director = create_test_director(solution);
        let result = solver.solve_with_director(Box::new(director));

        assert!(!solver.is_solving());
        assert!(result.queens.iter().all(|q| q.row.is_some()));
    }
}