Skip to main content

solverforge_solver/
run.rs

1/* Solver entry point. */
2
3use std::fmt;
4use std::marker::PhantomData;
5use std::path::Path;
6use std::time::Duration;
7
8use solverforge_config::SolverConfig;
9use solverforge_core::domain::{PlanningSolution, SolutionDescriptor};
10use solverforge_core::score::{ParseableScore, Score};
11use solverforge_scoring::{ConstraintSet, Director, ScoreDirector};
12use tracing::info;
13
14use crate::manager::{SolverRuntime, SolverTerminalReason};
15use crate::phase::{Phase, PhaseSequence};
16use crate::scope::{ProgressCallback, SolverProgressKind, SolverProgressRef, SolverScope};
17use crate::solver::Solver;
18use crate::stats::{format_duration, whole_units_per_second};
19use crate::termination::{
20    BestScoreTermination, OrTermination, StepCountTermination, Termination, TimeTermination,
21    UnimprovedStepCountTermination, UnimprovedTimeTermination,
22};
23
24/// Monomorphized termination enum for config-driven solver configurations.
25///
26/// Avoids repeated branching across termination overloads by capturing the
27/// selected termination variant upfront.
28pub enum AnyTermination<S: PlanningSolution, D: Director<S>> {
29    Default(OrTermination<(TimeTermination,), S, D>),
30    WithBestScore(OrTermination<(TimeTermination, BestScoreTermination<S::Score>), S, D>),
31    WithStepCount(OrTermination<(TimeTermination, StepCountTermination), S, D>),
32    WithUnimprovedStep(OrTermination<(TimeTermination, UnimprovedStepCountTermination<S>), S, D>),
33    WithUnimprovedTime(OrTermination<(TimeTermination, UnimprovedTimeTermination<S>), S, D>),
34}
35
36#[derive(Clone)]
37pub struct ChannelProgressCallback<S: PlanningSolution> {
38    runtime: SolverRuntime<S>,
39    _phantom: PhantomData<fn() -> S>,
40}
41
42impl<S: PlanningSolution> ChannelProgressCallback<S> {
43    fn new(runtime: SolverRuntime<S>) -> Self {
44        Self {
45            runtime,
46            _phantom: PhantomData,
47        }
48    }
49}
50
51impl<S: PlanningSolution> fmt::Debug for ChannelProgressCallback<S> {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        f.debug_struct("ChannelProgressCallback").finish()
54    }
55}
56
57impl<S: PlanningSolution> ProgressCallback<S> for ChannelProgressCallback<S> {
58    fn invoke(&self, progress: SolverProgressRef<'_, S>) {
59        match progress.kind {
60            SolverProgressKind::Progress => {
61                self.runtime.emit_progress(
62                    progress.current_score.copied(),
63                    progress.best_score.copied(),
64                    progress.telemetry,
65                );
66            }
67            SolverProgressKind::BestSolution => {
68                if let (Some(solution), Some(score)) = (progress.solution, progress.best_score) {
69                    self.runtime.emit_best_solution(
70                        (*solution).clone(),
71                        progress.current_score.copied(),
72                        *score,
73                        progress.telemetry,
74                    );
75                }
76            }
77        }
78    }
79}
80
81impl<S: PlanningSolution, D: Director<S>> fmt::Debug for AnyTermination<S, D> {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        match self {
84            Self::Default(_) => write!(f, "AnyTermination::Default"),
85            Self::WithBestScore(_) => write!(f, "AnyTermination::WithBestScore"),
86            Self::WithStepCount(_) => write!(f, "AnyTermination::WithStepCount"),
87            Self::WithUnimprovedStep(_) => write!(f, "AnyTermination::WithUnimprovedStep"),
88            Self::WithUnimprovedTime(_) => write!(f, "AnyTermination::WithUnimprovedTime"),
89        }
90    }
91}
92
93impl<S: PlanningSolution, D: Director<S>, ProgressCb: ProgressCallback<S>>
94    Termination<S, D, ProgressCb> for AnyTermination<S, D>
95where
96    S::Score: Score,
97{
98    fn is_terminated(&self, solver_scope: &SolverScope<S, D, ProgressCb>) -> bool {
99        match self {
100            Self::Default(t) => t.is_terminated(solver_scope),
101            Self::WithBestScore(t) => t.is_terminated(solver_scope),
102            Self::WithStepCount(t) => t.is_terminated(solver_scope),
103            Self::WithUnimprovedStep(t) => t.is_terminated(solver_scope),
104            Self::WithUnimprovedTime(t) => t.is_terminated(solver_scope),
105        }
106    }
107
108    fn install_inphase_limits(&self, solver_scope: &mut SolverScope<S, D, ProgressCb>) {
109        match self {
110            Self::Default(t) => t.install_inphase_limits(solver_scope),
111            Self::WithBestScore(t) => t.install_inphase_limits(solver_scope),
112            Self::WithStepCount(t) => t.install_inphase_limits(solver_scope),
113            Self::WithUnimprovedStep(t) => t.install_inphase_limits(solver_scope),
114            Self::WithUnimprovedTime(t) => t.install_inphase_limits(solver_scope),
115        }
116    }
117}
118
119/// Builds a termination from config, returning both the termination and the time limit.
120pub fn build_termination<S, C>(
121    config: &SolverConfig,
122    default_secs: u64,
123) -> (AnyTermination<S, ScoreDirector<S, C>>, Duration)
124where
125    S: PlanningSolution,
126    S::Score: Score + ParseableScore,
127    C: ConstraintSet<S, S::Score>,
128{
129    let term_config = config.termination.as_ref();
130    let time_limit = term_config
131        .and_then(|c| c.time_limit())
132        .unwrap_or(Duration::from_secs(default_secs));
133    let time = TimeTermination::new(time_limit);
134
135    let best_score_target: Option<S::Score> = term_config
136        .and_then(|c| c.best_score_limit.as_ref())
137        .and_then(|s| S::Score::parse(s).ok());
138
139    let termination = if let Some(target) = best_score_target {
140        AnyTermination::WithBestScore(OrTermination::new((
141            time,
142            BestScoreTermination::new(target),
143        )))
144    } else if let Some(step_limit) = term_config.and_then(|c| c.step_count_limit) {
145        AnyTermination::WithStepCount(OrTermination::new((
146            time,
147            StepCountTermination::new(step_limit),
148        )))
149    } else if let Some(unimproved_step_limit) =
150        term_config.and_then(|c| c.unimproved_step_count_limit)
151    {
152        AnyTermination::WithUnimprovedStep(OrTermination::new((
153            time,
154            UnimprovedStepCountTermination::<S>::new(unimproved_step_limit),
155        )))
156    } else if let Some(unimproved_time) = term_config.and_then(|c| c.unimproved_time_limit()) {
157        AnyTermination::WithUnimprovedTime(OrTermination::new((
158            time,
159            UnimprovedTimeTermination::<S>::new(unimproved_time),
160        )))
161    } else {
162        AnyTermination::Default(OrTermination::new((time,)))
163    };
164
165    (termination, time_limit)
166}
167
168pub fn log_solve_start(
169    entity_count: usize,
170    element_count: Option<usize>,
171    has_standard: Option<bool>,
172    candidate_count: Option<usize>,
173) {
174    if let Some(element_count) = element_count {
175        info!(
176            event = "solve_start",
177            entity_count = entity_count,
178            element_count = element_count,
179            solve_shape = "list",
180        );
181    } else if has_standard.unwrap_or(candidate_count.is_some()) {
182        info!(
183            event = "solve_start",
184            entity_count = entity_count,
185            candidate_count = candidate_count.unwrap_or(0),
186            solve_shape = "standard",
187        );
188    } else {
189        info!(
190            event = "solve_start",
191            entity_count = entity_count,
192            value_count = candidate_count.unwrap_or(0),
193            solve_shape = "solution",
194        );
195    }
196}
197
198fn load_solver_config_from(path: impl AsRef<Path>) -> SolverConfig {
199    SolverConfig::load(path).unwrap_or_default()
200}
201
202fn load_solver_config() -> SolverConfig {
203    load_solver_config_from("solver.toml")
204}
205
206#[allow(clippy::too_many_arguments)]
207pub fn run_solver<S, C, P>(
208    solution: S,
209    constraints_fn: fn() -> C,
210    descriptor: fn() -> SolutionDescriptor,
211    entity_count_by_descriptor: fn(&S, usize) -> usize,
212    runtime: SolverRuntime<S>,
213    default_time_limit_secs: u64,
214    is_trivial: fn(&S) -> bool,
215    log_scale: fn(&S),
216    build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
217) -> S
218where
219    S: PlanningSolution,
220    S::Score: Score + ParseableScore,
221    C: ConstraintSet<S, S::Score>,
222    P: Send + std::fmt::Debug,
223    PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
224{
225    let config = load_solver_config();
226    run_solver_with_config(
227        solution,
228        constraints_fn,
229        descriptor,
230        entity_count_by_descriptor,
231        runtime,
232        config,
233        default_time_limit_secs,
234        is_trivial,
235        log_scale,
236        build_phases,
237    )
238}
239
240#[allow(clippy::too_many_arguments)]
241pub fn run_solver_with_config<S, C, P>(
242    solution: S,
243    constraints_fn: fn() -> C,
244    descriptor: fn() -> SolutionDescriptor,
245    entity_count_by_descriptor: fn(&S, usize) -> usize,
246    runtime: SolverRuntime<S>,
247    config: SolverConfig,
248    default_time_limit_secs: u64,
249    is_trivial: fn(&S) -> bool,
250    log_scale: fn(&S),
251    build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
252) -> S
253where
254    S: PlanningSolution,
255    S::Score: Score + ParseableScore,
256    C: ConstraintSet<S, S::Score>,
257    P: Send + std::fmt::Debug,
258    PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
259{
260    log_scale(&solution);
261    let trivial = is_trivial(&solution);
262
263    let constraints = constraints_fn();
264    let director = ScoreDirector::with_descriptor(
265        solution,
266        constraints,
267        descriptor(),
268        entity_count_by_descriptor,
269    );
270
271    if trivial {
272        let mut solver_scope = SolverScope::new(director);
273        solver_scope = solver_scope.with_runtime(Some(runtime));
274        if let Some(seed) = config.random_seed {
275            solver_scope = solver_scope.with_seed(seed);
276        }
277        solver_scope.start_solving();
278        let score = solver_scope.calculate_score();
279        let solution = solver_scope.score_director().clone_working_solution();
280        solver_scope.set_best_solution(solution.clone(), score);
281        solver_scope.report_best_solution();
282        solver_scope.pause_if_requested();
283        info!(event = "solve_end", score = %score);
284        let telemetry = solver_scope.stats().snapshot();
285        if runtime.is_cancel_requested() {
286            runtime.emit_cancelled(Some(score), Some(score), telemetry);
287        } else {
288            runtime.emit_completed(
289                solution.clone(),
290                Some(score),
291                score,
292                telemetry,
293                SolverTerminalReason::Completed,
294            );
295        }
296        return solution;
297    }
298
299    let (termination, time_limit) = build_termination::<S, C>(&config, default_time_limit_secs);
300
301    let callback = ChannelProgressCallback::new(runtime);
302
303    let phases = build_phases(&config);
304    let solver = Solver::new((phases,))
305        .with_config(config.clone())
306        .with_termination(termination)
307        .with_time_limit(time_limit)
308        .with_runtime(runtime)
309        .with_progress_callback(callback);
310
311    let result = solver.with_terminate(runtime.cancel_flag()).solve(director);
312
313    let crate::solver::SolveResult {
314        solution,
315        current_score,
316        best_score: final_score,
317        terminal_reason,
318        stats,
319    } = result;
320    let final_telemetry = stats.snapshot();
321    let final_move_speed = whole_units_per_second(stats.moves_evaluated, stats.elapsed());
322    match terminal_reason {
323        SolverTerminalReason::Completed | SolverTerminalReason::TerminatedByConfig => {
324            runtime.emit_completed(
325                solution.clone(),
326                current_score,
327                final_score,
328                final_telemetry,
329                terminal_reason,
330            );
331        }
332        SolverTerminalReason::Cancelled => {
333            runtime.emit_cancelled(current_score, Some(final_score), final_telemetry);
334        }
335        SolverTerminalReason::Failed => unreachable!("solver completion cannot report failure"),
336    }
337
338    info!(
339        event = "solve_end",
340        score = %final_score,
341        steps = stats.step_count,
342        moves_generated = stats.moves_generated,
343        moves_evaluated = stats.moves_evaluated,
344        moves_accepted = stats.moves_accepted,
345        score_calculations = stats.score_calculations,
346        generation_time = %format_duration(stats.generation_time()),
347        evaluation_time = %format_duration(stats.evaluation_time()),
348        moves_speed = final_move_speed,
349        acceptance_rate = format!("{:.1}%", stats.acceptance_rate() * 100.0),
350    );
351    solution
352}
353
354#[cfg(test)]
355#[path = "run_tests.rs"]
356mod tests;