solverforge_solver/
solver.rs

//! The [`Solver`], which runs optimization phases in sequence, and the
//! [`SolverFactory`], which builds solvers from a [`SolverConfig`].
2
3use std::marker::PhantomData;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6
7use solverforge_config::SolverConfig;
8use solverforge_core::domain::PlanningSolution;
9use solverforge_core::SolverForgeError;
10use solverforge_scoring::ScoreDirector;
11
12use crate::phase::Phase;
13use crate::scope::SolverScope;
14use crate::termination::Termination;
15
16/// Factory for creating Solver instances.
17pub struct SolverFactory<S: PlanningSolution> {
18    config: SolverConfig,
19    _phantom: PhantomData<S>,
20}
21
22impl<S: PlanningSolution> SolverFactory<S> {
23    /// Creates a new SolverFactory from configuration.
24    pub fn create(config: SolverConfig) -> Result<Self, SolverForgeError> {
25        Ok(SolverFactory {
26            config,
27            _phantom: PhantomData,
28        })
29    }
30
31    /// Creates a SolverFactory from a TOML configuration file.
32    pub fn from_toml_file(path: impl AsRef<std::path::Path>) -> Result<Self, SolverForgeError> {
33        let config = SolverConfig::from_toml_file(path)
34            .map_err(|e| SolverForgeError::Config(e.to_string()))?;
35        Self::create(config)
36    }
37
38    /// Creates a SolverFactory from a YAML configuration file.
39    pub fn from_yaml_file(path: impl AsRef<std::path::Path>) -> Result<Self, SolverForgeError> {
40        let config = SolverConfig::from_yaml_file(path)
41            .map_err(|e| SolverForgeError::Config(e.to_string()))?;
42        Self::create(config)
43    }
44
45    /// Builds a new Solver instance.
46    pub fn build_solver(&self) -> Solver<S> {
47        Solver::from_config(self.config.clone())
48    }
49
50    /// Returns a reference to the configuration.
51    pub fn config(&self) -> &SolverConfig {
52        &self.config
53    }
54}
55
56/// The main solver that optimizes planning solutions.
57///
58/// The solver executes phases in sequence, checking termination conditions
59/// between phases and potentially within phases.
60pub struct Solver<S: PlanningSolution> {
61    /// Phases to execute in order.
62    phases: Vec<Box<dyn Phase<S>>>,
63    /// Global termination condition.
64    termination: Option<Box<dyn Termination<S>>>,
65    /// Flag for early termination requests.
66    terminate_early_flag: Arc<AtomicBool>,
67    /// Whether solver is currently running.
68    solving: Arc<AtomicBool>,
69    /// Optional configuration.
70    config: Option<SolverConfig>,
71}
72
73impl<S: PlanningSolution> Solver<S> {
74    /// Creates a new solver with the given phases.
75    pub fn new(phases: Vec<Box<dyn Phase<S>>>) -> Self {
76        Solver {
77            phases,
78            termination: None,
79            terminate_early_flag: Arc::new(AtomicBool::new(false)),
80            solving: Arc::new(AtomicBool::new(false)),
81            config: None,
82        }
83    }
84
85    /// Creates a solver from configuration (phases must be added separately).
86    pub fn from_config(config: SolverConfig) -> Self {
87        Solver {
88            phases: Vec::new(),
89            termination: None,
90            terminate_early_flag: Arc::new(AtomicBool::new(false)),
91            solving: Arc::new(AtomicBool::new(false)),
92            config: Some(config),
93        }
94    }
95
96    /// Adds a phase to the solver.
97    pub fn with_phase(mut self, phase: Box<dyn Phase<S>>) -> Self {
98        self.phases.push(phase);
99        self
100    }
101
102    /// Adds phases to the solver.
103    pub fn with_phases(mut self, phases: Vec<Box<dyn Phase<S>>>) -> Self {
104        self.phases.extend(phases);
105        self
106    }
107
108    /// Sets the termination condition.
109    pub fn with_termination(mut self, termination: Box<dyn Termination<S>>) -> Self {
110        self.termination = Some(termination);
111        self
112    }
113
114    /// Solves using the provided score director.
115    ///
116    /// This is the main solving method that executes all phases.
117    pub fn solve_with_director(&mut self, score_director: Box<dyn ScoreDirector<S>>) -> S {
118        self.solving.store(true, Ordering::SeqCst);
119        self.terminate_early_flag.store(false, Ordering::SeqCst);
120
121        let mut solver_scope = SolverScope::new(score_director);
122        solver_scope.set_terminate_early_flag(self.terminate_early_flag.clone());
123        solver_scope.start_solving();
124
125        // Note: We don't set the initial solution as "best" here because
126        // construction heuristic will create a fully assigned solution
127        // which may have a worse score than the unassigned initial state.
128        // Phases are responsible for updating best solution appropriately.
129
130        // Execute phases
131        let mut phase_index = 0;
132        while phase_index < self.phases.len() {
133            // Check termination before phase
134            if self.check_termination(&solver_scope) {
135                tracing::debug!(
136                    "Terminating before phase {} ({})",
137                    phase_index,
138                    self.phases[phase_index].phase_type_name()
139                );
140                break;
141            }
142
143            tracing::debug!(
144                "Starting phase {} ({})",
145                phase_index,
146                self.phases[phase_index].phase_type_name()
147            );
148
149            self.phases[phase_index].solve(&mut solver_scope);
150
151            tracing::debug!(
152                "Finished phase {} ({}) with score {:?}",
153                phase_index,
154                self.phases[phase_index].phase_type_name(),
155                solver_scope.best_score()
156            );
157
158            phase_index += 1;
159        }
160
161        self.solving.store(false, Ordering::SeqCst);
162
163        // Return the best solution if set, otherwise the working solution
164        // (This handles the case where construction heuristic creates an assigned
165        // solution but local search didn't find any improvements)
166        solver_scope.take_best_or_working_solution()
167    }
168
169    /// Checks if solving should terminate.
170    fn check_termination(&self, solver_scope: &SolverScope<S>) -> bool {
171        // Check early termination request
172        if self.terminate_early_flag.load(Ordering::SeqCst) {
173            return true;
174        }
175
176        // Check termination condition
177        if let Some(ref termination) = self.termination {
178            if termination.is_terminated(solver_scope) {
179                return true;
180            }
181        }
182
183        false
184    }
185
186    /// Requests early termination of the solving process.
187    ///
188    /// This method is thread-safe and can be called from another thread.
189    pub fn terminate_early(&self) -> bool {
190        if self.solving.load(Ordering::SeqCst) {
191            self.terminate_early_flag.store(true, Ordering::SeqCst);
192            true
193        } else {
194            false
195        }
196    }
197
198    /// Returns true if the solver is currently solving.
199    pub fn is_solving(&self) -> bool {
200        self.solving.load(Ordering::SeqCst)
201    }
202
203    /// Returns the configuration if set.
204    pub fn config(&self) -> Option<&SolverConfig> {
205        self.config.as_ref()
206    }
207}
208
/// Lifecycle state of a solver.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SolverStatus {
    /// No solve is in progress.
    NotSolving,
    /// Solving has been requested and the solver is still initializing.
    SolvingScheduled,
    /// The solver is actively solving.
    SolvingActive,
}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::heuristic::r#move::ChangeMove;
    use crate::heuristic::selector::ChangeMoveSelector;
    use crate::manager::SolverPhaseFactory;
    use crate::termination::StepCountTermination;
    use solverforge_core::domain::{EntityDescriptor, SolutionDescriptor, TypedEntityExtractor};
    use solverforge_core::score::SimpleScore;
    use solverforge_scoring::SimpleScoreDirector;
    use std::any::TypeId;

    /// A queen fixed to a column; the planning variable is its row.
    #[derive(Clone, Debug)]
    struct Queen {
        column: i32,
        row: Option<i32>,
    }

    /// Minimal N-Queens planning solution used to drive the solver in tests.
    #[derive(Clone, Debug)]
    struct NQueensSolution {
        n: i32,
        queens: Vec<Queen>,
        score: Option<SimpleScore>,
    }

    type NQueensMove = ChangeMove<NQueensSolution, i32>;

    impl PlanningSolution for NQueensSolution {
        type Score = SimpleScore;

        fn score(&self) -> Option<Self::Score> {
            self.score
        }

        fn set_score(&mut self, score: Option<Self::Score>) {
            self.score = score;
        }
    }

    fn get_queens(s: &NQueensSolution) -> &Vec<Queen> {
        &s.queens
    }

    fn get_queens_mut(s: &mut NQueensSolution) -> &mut Vec<Queen> {
        &mut s.queens
    }

    // Zero-erasure typed getter/setter for solution-level access
    fn get_queen_row(s: &NQueensSolution, idx: usize) -> Option<i32> {
        s.queens.get(idx).and_then(|q| q.row)
    }

    fn set_queen_row(s: &mut NQueensSolution, idx: usize, v: Option<i32>) {
        if let Some(queen) = s.queens.get_mut(idx) {
            queen.row = v;
        }
    }

    /// Scores a solution as minus the number of attacking queen pairs
    /// (same row or same diagonal). Unassigned queens are ignored.
    fn calculate_conflicts(solution: &NQueensSolution) -> SimpleScore {
        let mut conflicts = 0i64;

        for (i, q1) in solution.queens.iter().enumerate() {
            if let Some(row1) = q1.row {
                for q2 in solution.queens.iter().skip(i + 1) {
                    if let Some(row2) = q2.row {
                        // Same row conflict
                        if row1 == row2 {
                            conflicts += 1;
                        }
                        // Diagonal conflict
                        let col_diff = (q2.column - q1.column).abs();
                        let row_diff = (row2 - row1).abs();
                        if col_diff == row_diff {
                            conflicts += 1;
                        }
                    }
                }
            }
        }

        SimpleScore::of(-conflicts)
    }

    /// Builds an `n`-queens solution where the queen in column `col` sits in
    /// row `row_of(col)`, so every queen starts assigned.
    fn assigned_solution(n: i32, row_of: impl Fn(i32) -> i32) -> NQueensSolution {
        let queens = (0..n)
            .map(|col| Queen {
                column: col,
                row: Some(row_of(col)),
            })
            .collect();

        NQueensSolution {
            n,
            queens,
            score: None,
        }
    }

    fn create_test_director(
        solution: NQueensSolution,
    ) -> SimpleScoreDirector<NQueensSolution, impl Fn(&NQueensSolution) -> SimpleScore> {
        let extractor = Box::new(TypedEntityExtractor::new(
            "Queen",
            "queens",
            get_queens,
            get_queens_mut,
        ));
        let entity_desc = EntityDescriptor::new("Queen", TypeId::of::<Queen>(), "queens")
            .with_extractor(extractor);

        let descriptor =
            SolutionDescriptor::new("NQueensSolution", TypeId::of::<NQueensSolution>())
                .with_entity(entity_desc);

        SimpleScoreDirector::with_calculator(solution, descriptor, calculate_conflicts)
    }

    #[test]
    fn test_solver_new() {
        let solver: Solver<NQueensSolution> = Solver::new(vec![]);
        assert!(!solver.is_solving());
    }

    #[test]
    fn test_solver_with_phases() {
        use crate::manager::LocalSearchPhaseFactory;

        let values: Vec<i32> = (0..4).collect();
        let factory =
            LocalSearchPhaseFactory::<NQueensSolution, NQueensMove, _>::hill_climbing(move || {
                Box::new(ChangeMoveSelector::<NQueensSolution, i32>::simple(
                    get_queen_row,
                    set_queen_row,
                    0,
                    "row",
                    values.clone(),
                ))
            })
            .with_step_limit(10);
        let local_search = factory.create_phase();

        let solver: Solver<NQueensSolution> = Solver::new(vec![]).with_phase(local_search);

        assert!(!solver.is_solving());
    }

    #[test]
    fn test_solver_local_search_only() {
        use crate::manager::LocalSearchPhaseFactory;

        // Start with an already-assigned solution: all queens in row 0,
        // which produces many conflicts.
        let n = 4;
        let solution = assigned_solution(n, |_| 0);

        let director = create_test_director(solution);

        // Calculate initial conflicts
        let initial_score = calculate_conflicts(director.working_solution());
        assert!(initial_score < SimpleScore::of(0)); // Should have conflicts

        let values: Vec<i32> = (0..n).collect();
        let factory =
            LocalSearchPhaseFactory::<NQueensSolution, NQueensMove, _>::hill_climbing(move || {
                Box::new(ChangeMoveSelector::<NQueensSolution, i32>::simple(
                    get_queen_row,
                    set_queen_row,
                    0,
                    "row",
                    values.clone(),
                ))
            })
            .with_step_limit(50);
        let local_search = factory.create_phase();

        let mut solver = Solver::new(vec![local_search]);

        let result = solver.solve_with_director(Box::new(director));

        // The solving flag must be released once the solve returns.
        assert!(!solver.is_solving());

        // Should have improved or stayed the same
        let final_score = calculate_conflicts(&result);
        assert!(final_score >= initial_score);
    }

    #[test]
    fn test_solver_with_termination() {
        use crate::manager::LocalSearchPhaseFactory;

        let n = 4;
        let solution = assigned_solution(n, |_| 0);
        let initial_score = calculate_conflicts(&solution);

        let director = create_test_director(solution);

        let values: Vec<i32> = (0..n).collect();
        let factory =
            LocalSearchPhaseFactory::<NQueensSolution, NQueensMove, _>::hill_climbing(move || {
                Box::new(ChangeMoveSelector::<NQueensSolution, i32>::simple(
                    get_queen_row,
                    set_queen_row,
                    0,
                    "row",
                    values.clone(),
                ))
            })
            .with_step_limit(1000);
        let local_search = factory.create_phase();

        // Terminate after just 5 steps
        let termination = StepCountTermination::new(5);

        let mut solver = Solver::new(vec![local_search]).with_termination(Box::new(termination));

        let result = solver.solve_with_director(Box::new(director));

        // The solve must finish and release the solving flag. It must keep
        // every queen assigned and must not return anything worse than the
        // starting point.
        assert!(!solver.is_solving());
        assert_eq!(result.queens.len(), n as usize);
        assert!(result.queens.iter().all(|q| q.row.is_some()));
        assert!(calculate_conflicts(&result) >= initial_score);
    }

    #[test]
    fn test_solver_status() {
        let solver: Solver<NQueensSolution> = Solver::new(vec![]);

        assert!(!solver.is_solving());
        assert!(!solver.terminate_early()); // Can't terminate when not solving
    }

    #[test]
    fn test_solver_multiple_phases() {
        use crate::manager::LocalSearchPhaseFactory;

        let n = 4;
        let values: Vec<i32> = (0..n).collect();

        // First local search phase with limited steps
        let values1 = values.clone();
        let factory1 =
            LocalSearchPhaseFactory::<NQueensSolution, NQueensMove, _>::hill_climbing(move || {
                Box::new(ChangeMoveSelector::<NQueensSolution, i32>::simple(
                    get_queen_row,
                    set_queen_row,
                    0,
                    "row",
                    values1.clone(),
                ))
            })
            .with_step_limit(10);
        let phase1 = factory1.create_phase();

        // Second local search phase
        let values2 = values.clone();
        let factory2 =
            LocalSearchPhaseFactory::<NQueensSolution, NQueensMove, _>::hill_climbing(move || {
                Box::new(ChangeMoveSelector::<NQueensSolution, i32>::simple(
                    get_queen_row,
                    set_queen_row,
                    0,
                    "row",
                    values2.clone(),
                ))
            })
            .with_step_limit(10);
        let phase2 = factory2.create_phase();

        let mut solver = Solver::new(vec![phase1, phase2]);

        // Local search needs a fully assigned starting solution:
        // one queen per row, queen `col` in row `col`.
        let solution = assigned_solution(n, |col| col);
        assert_eq!(solution.n, n);

        let director = create_test_director(solution);
        let result = solver.solve_with_director(Box::new(director));

        // Should complete both phases and release the solving flag.
        assert!(!solver.is_solving());
        assert_eq!(result.queens.len(), n as usize);
        assert!(result.queens.iter().all(|q| q.row.is_some()));
    }
}