Skip to main content

solverforge_solver/
run.rs

1/* Solver entry point. */
2
3use std::fmt;
4use std::marker::PhantomData;
5use std::path::Path;
6use std::time::Duration;
7
8use solverforge_config::SolverConfig;
9use solverforge_core::domain::{PlanningSolution, SolutionDescriptor};
10use solverforge_core::score::{ParseableScore, Score};
11use solverforge_scoring::{ConstraintSet, Director, ScoreDirector};
12use tracing::info;
13
14use crate::manager::{SolverRuntime, SolverTerminalReason};
15use crate::phase::{Phase, PhaseSequence};
16use crate::scope::{ProgressCallback, SolverProgressKind, SolverProgressRef, SolverScope};
17use crate::solver::{NoTermination, Solver};
18use crate::stats::{format_duration, whole_units_per_second};
19use crate::termination::{
20    BestScoreTermination, OrTermination, StepCountTermination, Termination, TimeTermination,
21    UnimprovedStepCountTermination, UnimprovedTimeTermination,
22};
23
24/// Monomorphized termination enum for config-driven solver configurations.
25///
26/// Avoids repeated branching across termination overloads by capturing the
27/// selected termination variant upfront.
28pub enum AnyTermination<S: PlanningSolution, D: Director<S>> {
29    None(NoTermination),
30    Default(OrTermination<(TimeTermination,), S, D>),
31    WithBestScore(OrTermination<(TimeTermination, BestScoreTermination<S::Score>), S, D>),
32    WithStepCount(OrTermination<(TimeTermination, StepCountTermination), S, D>),
33    WithUnimprovedStep(OrTermination<(TimeTermination, UnimprovedStepCountTermination<S>), S, D>),
34    WithUnimprovedTime(OrTermination<(TimeTermination, UnimprovedTimeTermination<S>), S, D>),
35}
36
37#[derive(Clone)]
38pub struct ChannelProgressCallback<S: PlanningSolution> {
39    runtime: SolverRuntime<S>,
40    _phantom: PhantomData<fn() -> S>,
41}
42
43impl<S: PlanningSolution> ChannelProgressCallback<S> {
44    fn new(runtime: SolverRuntime<S>) -> Self {
45        Self {
46            runtime,
47            _phantom: PhantomData,
48        }
49    }
50}
51
52impl<S: PlanningSolution> fmt::Debug for ChannelProgressCallback<S> {
53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54        f.debug_struct("ChannelProgressCallback").finish()
55    }
56}
57
58impl<S: PlanningSolution> ProgressCallback<S> for ChannelProgressCallback<S> {
59    fn invoke(&self, progress: SolverProgressRef<'_, S>) {
60        match progress.kind {
61            SolverProgressKind::Progress => {
62                self.runtime.emit_progress(
63                    progress.current_score.copied(),
64                    progress.best_score.copied(),
65                    progress.telemetry.clone(),
66                );
67            }
68            SolverProgressKind::BestSolution => {
69                if let (Some(solution), Some(score)) = (progress.solution, progress.best_score) {
70                    self.runtime.emit_best_solution(
71                        (*solution).clone(),
72                        progress.current_score.copied(),
73                        *score,
74                        progress.telemetry.clone(),
75                    );
76                }
77            }
78        }
79    }
80}
81
82impl<S: PlanningSolution, D: Director<S>> fmt::Debug for AnyTermination<S, D> {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        match self {
85            Self::None(_) => write!(f, "AnyTermination::None"),
86            Self::Default(_) => write!(f, "AnyTermination::Default"),
87            Self::WithBestScore(_) => write!(f, "AnyTermination::WithBestScore"),
88            Self::WithStepCount(_) => write!(f, "AnyTermination::WithStepCount"),
89            Self::WithUnimprovedStep(_) => write!(f, "AnyTermination::WithUnimprovedStep"),
90            Self::WithUnimprovedTime(_) => write!(f, "AnyTermination::WithUnimprovedTime"),
91        }
92    }
93}
94
95impl<S: PlanningSolution, D: Director<S>, ProgressCb: ProgressCallback<S>>
96    Termination<S, D, ProgressCb> for AnyTermination<S, D>
97where
98    S::Score: Score,
99{
100    fn is_terminated(&self, solver_scope: &SolverScope<S, D, ProgressCb>) -> bool {
101        match self {
102            Self::None(t) => t.is_terminated(solver_scope),
103            Self::Default(t) => t.is_terminated(solver_scope),
104            Self::WithBestScore(t) => t.is_terminated(solver_scope),
105            Self::WithStepCount(t) => t.is_terminated(solver_scope),
106            Self::WithUnimprovedStep(t) => t.is_terminated(solver_scope),
107            Self::WithUnimprovedTime(t) => t.is_terminated(solver_scope),
108        }
109    }
110
111    fn install_inphase_limits(&self, solver_scope: &mut SolverScope<S, D, ProgressCb>) {
112        match self {
113            Self::None(t) => t.install_inphase_limits(solver_scope),
114            Self::Default(t) => t.install_inphase_limits(solver_scope),
115            Self::WithBestScore(t) => t.install_inphase_limits(solver_scope),
116            Self::WithStepCount(t) => t.install_inphase_limits(solver_scope),
117            Self::WithUnimprovedStep(t) => t.install_inphase_limits(solver_scope),
118            Self::WithUnimprovedTime(t) => t.install_inphase_limits(solver_scope),
119        }
120    }
121}
122
123/// Builds a termination from config, returning both the termination and the time limit.
124pub fn build_termination<S, C>(
125    config: &SolverConfig,
126    default_secs: u64,
127) -> (AnyTermination<S, ScoreDirector<S, C>>, Option<Duration>)
128where
129    S: PlanningSolution,
130    S::Score: Score + ParseableScore,
131    C: ConstraintSet<S, S::Score>,
132{
133    let term_config = config.termination.as_ref();
134    let configured_time_limit = term_config.and_then(|c| c.time_limit());
135    let fallback_time_limit = Duration::from_secs(default_secs);
136
137    let best_score_target: Option<S::Score> = term_config
138        .and_then(|c| c.best_score_limit.as_ref())
139        .and_then(|s| S::Score::parse(s).ok());
140
141    let (termination, effective_time_limit) = if let Some(target) = best_score_target {
142        let effective_time_limit = configured_time_limit.unwrap_or(fallback_time_limit);
143        let time = TimeTermination::new(effective_time_limit);
144        (
145            AnyTermination::WithBestScore(OrTermination::new((
146                time,
147                BestScoreTermination::new(target),
148            ))),
149            Some(effective_time_limit),
150        )
151    } else if let Some(step_limit) = term_config.and_then(|c| c.step_count_limit) {
152        let effective_time_limit = configured_time_limit.unwrap_or(fallback_time_limit);
153        let time = TimeTermination::new(effective_time_limit);
154        (
155            AnyTermination::WithStepCount(OrTermination::new((
156                time,
157                StepCountTermination::new(step_limit),
158            ))),
159            Some(effective_time_limit),
160        )
161    } else if let Some(unimproved_step_limit) =
162        term_config.and_then(|c| c.unimproved_step_count_limit)
163    {
164        let effective_time_limit = configured_time_limit.unwrap_or(fallback_time_limit);
165        let time = TimeTermination::new(effective_time_limit);
166        (
167            AnyTermination::WithUnimprovedStep(OrTermination::new((
168                time,
169                UnimprovedStepCountTermination::<S>::new(unimproved_step_limit),
170            ))),
171            Some(effective_time_limit),
172        )
173    } else if let Some(unimproved_time) = term_config.and_then(|c| c.unimproved_time_limit()) {
174        let effective_time_limit = configured_time_limit.unwrap_or(fallback_time_limit);
175        let time = TimeTermination::new(effective_time_limit);
176        (
177            AnyTermination::WithUnimprovedTime(OrTermination::new((
178                time,
179                UnimprovedTimeTermination::<S>::new(unimproved_time),
180            ))),
181            Some(effective_time_limit),
182        )
183    } else if let Some(limit) = configured_time_limit {
184        let time = TimeTermination::new(limit);
185        (
186            AnyTermination::Default(OrTermination::new((time,))),
187            Some(limit),
188        )
189    } else {
190        (AnyTermination::None(NoTermination), None)
191    };
192
193    (termination, effective_time_limit)
194}
195
196pub fn log_solve_start(
197    entity_count: usize,
198    element_count: Option<usize>,
199    candidate_count: Option<usize>,
200) {
201    match (element_count, candidate_count) {
202        (Some(element_count), None) => {
203            info!(
204                event = "solve_start",
205                entity_count = entity_count,
206                element_count = element_count,
207                solve_shape = "list",
208            );
209        }
210        (None, Some(candidate_count)) => {
211            info!(
212                event = "solve_start",
213                entity_count = entity_count,
214                candidate_count = candidate_count,
215                solve_shape = "scalar",
216            );
217        }
218        _ => {
219            panic!("log_solve_start requires exactly one solve scale: list elements or scalar candidates");
220        }
221    }
222}
223
224fn load_solver_config_from(path: impl AsRef<Path>) -> SolverConfig {
225    SolverConfig::load(path).unwrap_or_default()
226}
227
228fn load_solver_config() -> SolverConfig {
229    load_solver_config_from("solver.toml")
230}
231
232#[allow(clippy::too_many_arguments)]
233pub fn run_solver<S, C, P>(
234    solution: S,
235    constraints_fn: fn() -> C,
236    descriptor: fn() -> SolutionDescriptor,
237    entity_count_by_descriptor: fn(&S, usize) -> usize,
238    runtime: SolverRuntime<S>,
239    default_time_limit_secs: u64,
240    is_trivial: fn(&S) -> bool,
241    log_scale: fn(&S),
242    build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
243) -> S
244where
245    S: PlanningSolution,
246    S::Score: Score + ParseableScore,
247    C: ConstraintSet<S, S::Score>,
248    P: Send + std::fmt::Debug,
249    PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
250{
251    let config = load_solver_config();
252    run_solver_with_config(
253        solution,
254        constraints_fn,
255        descriptor,
256        entity_count_by_descriptor,
257        runtime,
258        config,
259        default_time_limit_secs,
260        is_trivial,
261        log_scale,
262        build_phases,
263    )
264}
265
266#[allow(clippy::too_many_arguments)]
267pub fn run_solver_with_config<S, C, P>(
268    solution: S,
269    constraints_fn: fn() -> C,
270    descriptor: fn() -> SolutionDescriptor,
271    entity_count_by_descriptor: fn(&S, usize) -> usize,
272    runtime: SolverRuntime<S>,
273    config: SolverConfig,
274    default_time_limit_secs: u64,
275    is_trivial: fn(&S) -> bool,
276    log_scale: fn(&S),
277    build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
278) -> S
279where
280    S: PlanningSolution,
281    S::Score: Score + ParseableScore,
282    C: ConstraintSet<S, S::Score>,
283    P: Send + std::fmt::Debug,
284    PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
285{
286    log_scale(&solution);
287    let trivial = is_trivial(&solution);
288
289    let constraints = constraints_fn();
290    let director = ScoreDirector::with_descriptor(
291        solution,
292        constraints,
293        descriptor(),
294        entity_count_by_descriptor,
295    );
296
297    if trivial {
298        let mut solver_scope = SolverScope::new(director);
299        solver_scope = solver_scope.with_runtime(Some(runtime));
300        if let Some(seed) = config.random_seed {
301            solver_scope = solver_scope.with_seed(seed);
302        }
303        solver_scope.start_solving();
304        let score = solver_scope.calculate_score();
305        let solution = solver_scope.score_director().clone_working_solution();
306        solver_scope.set_best_solution(solution.clone(), score);
307        solver_scope.report_best_solution();
308        solver_scope.pause_if_requested();
309        info!(event = "solve_end", score = %score);
310        let telemetry = solver_scope.stats().snapshot();
311        if runtime.is_cancel_requested() {
312            runtime.emit_cancelled(Some(score), Some(score), telemetry);
313        } else {
314            runtime.emit_completed(
315                solution.clone(),
316                Some(score),
317                score,
318                telemetry,
319                SolverTerminalReason::Completed,
320            );
321        }
322        return solution;
323    }
324
325    let (termination, time_limit) = build_termination::<S, C>(&config, default_time_limit_secs);
326
327    let callback = ChannelProgressCallback::new(runtime);
328
329    let phases = build_phases(&config);
330    let mut solver = Solver::new((phases,))
331        .with_config(config.clone())
332        .with_termination(termination)
333        .with_runtime(runtime)
334        .with_progress_callback(callback);
335    if let Some(time_limit) = time_limit {
336        solver = solver.with_time_limit(time_limit);
337    }
338
339    let result = solver.with_terminate(runtime.cancel_flag()).solve(director);
340
341    let crate::solver::SolveResult {
342        solution,
343        current_score,
344        best_score: final_score,
345        terminal_reason,
346        stats,
347    } = result;
348    let final_telemetry = stats.snapshot();
349    let final_move_speed = whole_units_per_second(stats.moves_evaluated, stats.elapsed());
350    match terminal_reason {
351        SolverTerminalReason::Completed | SolverTerminalReason::TerminatedByConfig => {
352            runtime.emit_completed(
353                solution.clone(),
354                current_score,
355                final_score,
356                final_telemetry,
357                terminal_reason,
358            );
359        }
360        SolverTerminalReason::Cancelled => {
361            runtime.emit_cancelled(current_score, Some(final_score), final_telemetry);
362        }
363        SolverTerminalReason::Failed => unreachable!("solver completion cannot report failure"),
364    }
365
366    info!(
367        event = "solve_end",
368        score = %final_score,
369        steps = stats.step_count,
370        moves_generated = stats.moves_generated,
371        moves_evaluated = stats.moves_evaluated,
372        moves_accepted = stats.moves_accepted,
373        score_calculations = stats.score_calculations,
374        generation_time = %format_duration(stats.generation_time()),
375        evaluation_time = %format_duration(stats.evaluation_time()),
376        moves_speed = final_move_speed,
377        acceptance_rate = format!("{:.1}%", stats.acceptance_rate() * 100.0),
378    );
379    solution
380}
381
382#[cfg(test)]
383#[path = "run_tests.rs"]
384mod tests;