Skip to main content

solverforge_solver/
run.rs

1/* Unified solver entry point. */
2
3use std::fmt;
4use std::marker::PhantomData;
5use std::time::Duration;
6
7use solverforge_config::SolverConfig;
8use solverforge_core::domain::{PlanningSolution, SolutionDescriptor};
9use solverforge_core::score::{ParseableScore, Score};
10use solverforge_scoring::{ConstraintSet, Director, ScoreDirector};
11use tracing::info;
12
13use crate::manager::{SolverRuntime, SolverTerminalReason};
14use crate::phase::{Phase, PhaseSequence};
15use crate::scope::{ProgressCallback, SolverProgressKind, SolverProgressRef, SolverScope};
16use crate::solver::Solver;
17use crate::termination::{
18    BestScoreTermination, OrTermination, StepCountTermination, Termination, TimeTermination,
19    UnimprovedStepCountTermination, UnimprovedTimeTermination,
20};
21
22/// Monomorphized termination enum for config-driven solver configurations.
23///
24/// Avoids repeated branching across termination overloads by capturing the
25/// selected termination variant upfront.
26pub enum AnyTermination<S: PlanningSolution, D: Director<S>> {
27    Default(OrTermination<(TimeTermination,), S, D>),
28    WithBestScore(OrTermination<(TimeTermination, BestScoreTermination<S::Score>), S, D>),
29    WithStepCount(OrTermination<(TimeTermination, StepCountTermination), S, D>),
30    WithUnimprovedStep(OrTermination<(TimeTermination, UnimprovedStepCountTermination<S>), S, D>),
31    WithUnimprovedTime(OrTermination<(TimeTermination, UnimprovedTimeTermination<S>), S, D>),
32}
33
34#[derive(Clone)]
35pub struct ChannelProgressCallback<S: PlanningSolution> {
36    runtime: SolverRuntime<S>,
37    _phantom: PhantomData<fn() -> S>,
38}
39
40impl<S: PlanningSolution> ChannelProgressCallback<S> {
41    fn new(runtime: SolverRuntime<S>) -> Self {
42        Self {
43            runtime,
44            _phantom: PhantomData,
45        }
46    }
47}
48
49impl<S: PlanningSolution> fmt::Debug for ChannelProgressCallback<S> {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        f.debug_struct("ChannelProgressCallback").finish()
52    }
53}
54
55impl<S: PlanningSolution> ProgressCallback<S> for ChannelProgressCallback<S> {
56    fn invoke(&self, progress: SolverProgressRef<'_, S>) {
57        match progress.kind {
58            SolverProgressKind::Progress => {
59                self.runtime.emit_progress(
60                    progress.current_score.copied(),
61                    progress.best_score.copied(),
62                    progress.telemetry,
63                );
64            }
65            SolverProgressKind::BestSolution => {
66                if let (Some(solution), Some(score)) = (progress.solution, progress.best_score) {
67                    self.runtime.emit_best_solution(
68                        (*solution).clone(),
69                        progress.current_score.copied(),
70                        *score,
71                        progress.telemetry,
72                    );
73                }
74            }
75        }
76    }
77}
78
79impl<S: PlanningSolution, D: Director<S>> fmt::Debug for AnyTermination<S, D> {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        match self {
82            Self::Default(_) => write!(f, "AnyTermination::Default"),
83            Self::WithBestScore(_) => write!(f, "AnyTermination::WithBestScore"),
84            Self::WithStepCount(_) => write!(f, "AnyTermination::WithStepCount"),
85            Self::WithUnimprovedStep(_) => write!(f, "AnyTermination::WithUnimprovedStep"),
86            Self::WithUnimprovedTime(_) => write!(f, "AnyTermination::WithUnimprovedTime"),
87        }
88    }
89}
90
91impl<S: PlanningSolution, D: Director<S>, ProgressCb: ProgressCallback<S>>
92    Termination<S, D, ProgressCb> for AnyTermination<S, D>
93where
94    S::Score: Score,
95{
96    fn is_terminated(&self, solver_scope: &SolverScope<S, D, ProgressCb>) -> bool {
97        match self {
98            Self::Default(t) => t.is_terminated(solver_scope),
99            Self::WithBestScore(t) => t.is_terminated(solver_scope),
100            Self::WithStepCount(t) => t.is_terminated(solver_scope),
101            Self::WithUnimprovedStep(t) => t.is_terminated(solver_scope),
102            Self::WithUnimprovedTime(t) => t.is_terminated(solver_scope),
103        }
104    }
105
106    fn install_inphase_limits(&self, solver_scope: &mut SolverScope<S, D, ProgressCb>) {
107        match self {
108            Self::Default(t) => t.install_inphase_limits(solver_scope),
109            Self::WithBestScore(t) => t.install_inphase_limits(solver_scope),
110            Self::WithStepCount(t) => t.install_inphase_limits(solver_scope),
111            Self::WithUnimprovedStep(t) => t.install_inphase_limits(solver_scope),
112            Self::WithUnimprovedTime(t) => t.install_inphase_limits(solver_scope),
113        }
114    }
115}
116
117/// Builds a termination from config, returning both the termination and the time limit.
118pub fn build_termination<S, C>(
119    config: &SolverConfig,
120    default_secs: u64,
121) -> (AnyTermination<S, ScoreDirector<S, C>>, Duration)
122where
123    S: PlanningSolution,
124    S::Score: Score + ParseableScore,
125    C: ConstraintSet<S, S::Score>,
126{
127    let term_config = config.termination.as_ref();
128    let time_limit = term_config
129        .and_then(|c| c.time_limit())
130        .unwrap_or(Duration::from_secs(default_secs));
131    let time = TimeTermination::new(time_limit);
132
133    let best_score_target: Option<S::Score> = term_config
134        .and_then(|c| c.best_score_limit.as_ref())
135        .and_then(|s| S::Score::parse(s).ok());
136
137    let termination = if let Some(target) = best_score_target {
138        AnyTermination::WithBestScore(OrTermination::new((
139            time,
140            BestScoreTermination::new(target),
141        )))
142    } else if let Some(step_limit) = term_config.and_then(|c| c.step_count_limit) {
143        AnyTermination::WithStepCount(OrTermination::new((
144            time,
145            StepCountTermination::new(step_limit),
146        )))
147    } else if let Some(unimproved_step_limit) =
148        term_config.and_then(|c| c.unimproved_step_count_limit)
149    {
150        AnyTermination::WithUnimprovedStep(OrTermination::new((
151            time,
152            UnimprovedStepCountTermination::<S>::new(unimproved_step_limit),
153        )))
154    } else if let Some(unimproved_time) = term_config.and_then(|c| c.unimproved_time_limit()) {
155        AnyTermination::WithUnimprovedTime(OrTermination::new((
156            time,
157            UnimprovedTimeTermination::<S>::new(unimproved_time),
158        )))
159    } else {
160        AnyTermination::Default(OrTermination::new((time,)))
161    };
162
163    (termination, time_limit)
164}
165
166pub fn log_solve_start(
167    entity_count: usize,
168    element_count: Option<usize>,
169    has_standard: Option<bool>,
170    variable_count: Option<usize>,
171) {
172    info!(
173        event = "solve_start",
174        entity_count = entity_count,
175        value_count = element_count.or(variable_count).unwrap_or(0),
176        solve_shape = if element_count.is_some() {
177            "list"
178        } else if has_standard.unwrap_or(false) {
179            "standard"
180        } else {
181            "solution"
182        },
183    );
184}
185
186#[allow(clippy::too_many_arguments)]
187pub fn run_solver<S, C, P>(
188    solution: S,
189    constraints_fn: fn() -> C,
190    descriptor: fn() -> SolutionDescriptor,
191    entity_count_by_descriptor: fn(&S, usize) -> usize,
192    runtime: SolverRuntime<S>,
193    default_time_limit_secs: u64,
194    is_trivial: fn(&S) -> bool,
195    log_scale: fn(&S),
196    build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
197) -> S
198where
199    S: PlanningSolution,
200    S::Score: Score + ParseableScore,
201    C: ConstraintSet<S, S::Score>,
202    P: Send + std::fmt::Debug,
203    PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
204{
205    let config = SolverConfig::load("solver.toml").unwrap_or_default();
206
207    log_scale(&solution);
208    let trivial = is_trivial(&solution);
209
210    let constraints = constraints_fn();
211    let director = ScoreDirector::with_descriptor(
212        solution,
213        constraints,
214        descriptor(),
215        entity_count_by_descriptor,
216    );
217
218    if trivial {
219        let mut solver_scope = SolverScope::new(director);
220        solver_scope = solver_scope.with_runtime(Some(runtime));
221        solver_scope.start_solving();
222        let score = solver_scope.calculate_score();
223        let solution = solver_scope.score_director().clone_working_solution();
224        solver_scope.set_best_solution(solution.clone(), score);
225        solver_scope.report_best_solution();
226        solver_scope.pause_if_requested();
227        info!(event = "solve_end", score = %score);
228        let telemetry = solver_scope.stats().snapshot();
229        if runtime.is_cancel_requested() {
230            runtime.emit_cancelled(Some(score), Some(score), telemetry);
231        } else {
232            runtime.emit_completed(
233                solution.clone(),
234                Some(score),
235                score,
236                telemetry,
237                SolverTerminalReason::Completed,
238            );
239        }
240        return solution;
241    }
242
243    let (termination, time_limit) = build_termination::<S, C>(&config, default_time_limit_secs);
244
245    let callback = ChannelProgressCallback::new(runtime);
246
247    let phases = build_phases(&config);
248    let solver = Solver::new((phases,))
249        .with_termination(termination)
250        .with_time_limit(time_limit)
251        .with_runtime(runtime)
252        .with_progress_callback(callback);
253
254    let result = solver.with_terminate(runtime.cancel_flag()).solve(director);
255
256    let crate::solver::SolveResult {
257        solution,
258        current_score,
259        best_score: final_score,
260        terminal_reason,
261        stats,
262    } = result;
263    let final_telemetry = stats.snapshot();
264    match terminal_reason {
265        SolverTerminalReason::Completed | SolverTerminalReason::TerminatedByConfig => {
266            runtime.emit_completed(
267                solution.clone(),
268                current_score,
269                final_score,
270                final_telemetry,
271                terminal_reason,
272            );
273        }
274        SolverTerminalReason::Cancelled => {
275            runtime.emit_cancelled(current_score, Some(final_score), final_telemetry);
276        }
277        SolverTerminalReason::Failed => unreachable!("solver completion cannot report failure"),
278    }
279
280    info!(
281        event = "solve_end",
282        score = %final_score,
283        steps = stats.step_count,
284        moves_evaluated = stats.moves_evaluated,
285        moves_accepted = stats.moves_accepted,
286        score_calculations = stats.score_calculations,
287        moves_speed = final_telemetry.moves_per_second,
288        acceptance_rate = format!("{:.1}%", stats.acceptance_rate() * 100.0),
289    );
290    solution
291}