Skip to main content

solverforge_solver/
run.rs

1/* Unified solver entry point. */
2
3use std::fmt;
4use std::marker::PhantomData;
5use std::path::Path;
6use std::time::Duration;
7
8use solverforge_config::SolverConfig;
9use solverforge_core::domain::{PlanningSolution, SolutionDescriptor};
10use solverforge_core::score::{ParseableScore, Score};
11use solverforge_scoring::{ConstraintSet, Director, ScoreDirector};
12use tracing::info;
13
14use crate::manager::{SolverRuntime, SolverTerminalReason};
15use crate::phase::{Phase, PhaseSequence};
16use crate::scope::{ProgressCallback, SolverProgressKind, SolverProgressRef, SolverScope};
17use crate::solver::Solver;
18use crate::termination::{
19    BestScoreTermination, OrTermination, StepCountTermination, Termination, TimeTermination,
20    UnimprovedStepCountTermination, UnimprovedTimeTermination,
21};
22
23/// Monomorphized termination enum for config-driven solver configurations.
24///
25/// Avoids repeated branching across termination overloads by capturing the
26/// selected termination variant upfront.
27pub enum AnyTermination<S: PlanningSolution, D: Director<S>> {
28    Default(OrTermination<(TimeTermination,), S, D>),
29    WithBestScore(OrTermination<(TimeTermination, BestScoreTermination<S::Score>), S, D>),
30    WithStepCount(OrTermination<(TimeTermination, StepCountTermination), S, D>),
31    WithUnimprovedStep(OrTermination<(TimeTermination, UnimprovedStepCountTermination<S>), S, D>),
32    WithUnimprovedTime(OrTermination<(TimeTermination, UnimprovedTimeTermination<S>), S, D>),
33}
34
35#[derive(Clone)]
36pub struct ChannelProgressCallback<S: PlanningSolution> {
37    runtime: SolverRuntime<S>,
38    _phantom: PhantomData<fn() -> S>,
39}
40
41impl<S: PlanningSolution> ChannelProgressCallback<S> {
42    fn new(runtime: SolverRuntime<S>) -> Self {
43        Self {
44            runtime,
45            _phantom: PhantomData,
46        }
47    }
48}
49
50impl<S: PlanningSolution> fmt::Debug for ChannelProgressCallback<S> {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        f.debug_struct("ChannelProgressCallback").finish()
53    }
54}
55
56impl<S: PlanningSolution> ProgressCallback<S> for ChannelProgressCallback<S> {
57    fn invoke(&self, progress: SolverProgressRef<'_, S>) {
58        match progress.kind {
59            SolverProgressKind::Progress => {
60                self.runtime.emit_progress(
61                    progress.current_score.copied(),
62                    progress.best_score.copied(),
63                    progress.telemetry,
64                );
65            }
66            SolverProgressKind::BestSolution => {
67                if let (Some(solution), Some(score)) = (progress.solution, progress.best_score) {
68                    self.runtime.emit_best_solution(
69                        (*solution).clone(),
70                        progress.current_score.copied(),
71                        *score,
72                        progress.telemetry,
73                    );
74                }
75            }
76        }
77    }
78}
79
80impl<S: PlanningSolution, D: Director<S>> fmt::Debug for AnyTermination<S, D> {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        match self {
83            Self::Default(_) => write!(f, "AnyTermination::Default"),
84            Self::WithBestScore(_) => write!(f, "AnyTermination::WithBestScore"),
85            Self::WithStepCount(_) => write!(f, "AnyTermination::WithStepCount"),
86            Self::WithUnimprovedStep(_) => write!(f, "AnyTermination::WithUnimprovedStep"),
87            Self::WithUnimprovedTime(_) => write!(f, "AnyTermination::WithUnimprovedTime"),
88        }
89    }
90}
91
92impl<S: PlanningSolution, D: Director<S>, ProgressCb: ProgressCallback<S>>
93    Termination<S, D, ProgressCb> for AnyTermination<S, D>
94where
95    S::Score: Score,
96{
97    fn is_terminated(&self, solver_scope: &SolverScope<S, D, ProgressCb>) -> bool {
98        match self {
99            Self::Default(t) => t.is_terminated(solver_scope),
100            Self::WithBestScore(t) => t.is_terminated(solver_scope),
101            Self::WithStepCount(t) => t.is_terminated(solver_scope),
102            Self::WithUnimprovedStep(t) => t.is_terminated(solver_scope),
103            Self::WithUnimprovedTime(t) => t.is_terminated(solver_scope),
104        }
105    }
106
107    fn install_inphase_limits(&self, solver_scope: &mut SolverScope<S, D, ProgressCb>) {
108        match self {
109            Self::Default(t) => t.install_inphase_limits(solver_scope),
110            Self::WithBestScore(t) => t.install_inphase_limits(solver_scope),
111            Self::WithStepCount(t) => t.install_inphase_limits(solver_scope),
112            Self::WithUnimprovedStep(t) => t.install_inphase_limits(solver_scope),
113            Self::WithUnimprovedTime(t) => t.install_inphase_limits(solver_scope),
114        }
115    }
116}
117
118/// Builds a termination from config, returning both the termination and the time limit.
119pub fn build_termination<S, C>(
120    config: &SolverConfig,
121    default_secs: u64,
122) -> (AnyTermination<S, ScoreDirector<S, C>>, Duration)
123where
124    S: PlanningSolution,
125    S::Score: Score + ParseableScore,
126    C: ConstraintSet<S, S::Score>,
127{
128    let term_config = config.termination.as_ref();
129    let time_limit = term_config
130        .and_then(|c| c.time_limit())
131        .unwrap_or(Duration::from_secs(default_secs));
132    let time = TimeTermination::new(time_limit);
133
134    let best_score_target: Option<S::Score> = term_config
135        .and_then(|c| c.best_score_limit.as_ref())
136        .and_then(|s| S::Score::parse(s).ok());
137
138    let termination = if let Some(target) = best_score_target {
139        AnyTermination::WithBestScore(OrTermination::new((
140            time,
141            BestScoreTermination::new(target),
142        )))
143    } else if let Some(step_limit) = term_config.and_then(|c| c.step_count_limit) {
144        AnyTermination::WithStepCount(OrTermination::new((
145            time,
146            StepCountTermination::new(step_limit),
147        )))
148    } else if let Some(unimproved_step_limit) =
149        term_config.and_then(|c| c.unimproved_step_count_limit)
150    {
151        AnyTermination::WithUnimprovedStep(OrTermination::new((
152            time,
153            UnimprovedStepCountTermination::<S>::new(unimproved_step_limit),
154        )))
155    } else if let Some(unimproved_time) = term_config.and_then(|c| c.unimproved_time_limit()) {
156        AnyTermination::WithUnimprovedTime(OrTermination::new((
157            time,
158            UnimprovedTimeTermination::<S>::new(unimproved_time),
159        )))
160    } else {
161        AnyTermination::Default(OrTermination::new((time,)))
162    };
163
164    (termination, time_limit)
165}
166
167pub fn log_solve_start(
168    entity_count: usize,
169    element_count: Option<usize>,
170    has_standard: Option<bool>,
171    variable_count: Option<usize>,
172) {
173    info!(
174        event = "solve_start",
175        entity_count = entity_count,
176        value_count = element_count.or(variable_count).unwrap_or(0),
177        solve_shape = if element_count.is_some() {
178            "list"
179        } else if has_standard.unwrap_or(false) {
180            "standard"
181        } else {
182            "solution"
183        },
184    );
185}
186
187fn load_solver_config_from(path: impl AsRef<Path>) -> SolverConfig {
188    SolverConfig::load(path).unwrap_or_default()
189}
190
191fn load_solver_config() -> SolverConfig {
192    load_solver_config_from("solver.toml")
193}
194
195#[allow(clippy::too_many_arguments)]
196pub fn run_solver<S, C, P>(
197    solution: S,
198    constraints_fn: fn() -> C,
199    descriptor: fn() -> SolutionDescriptor,
200    entity_count_by_descriptor: fn(&S, usize) -> usize,
201    runtime: SolverRuntime<S>,
202    default_time_limit_secs: u64,
203    is_trivial: fn(&S) -> bool,
204    log_scale: fn(&S),
205    build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
206) -> S
207where
208    S: PlanningSolution,
209    S::Score: Score + ParseableScore,
210    C: ConstraintSet<S, S::Score>,
211    P: Send + std::fmt::Debug,
212    PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
213{
214    let config = load_solver_config();
215    run_solver_with_config(
216        solution,
217        constraints_fn,
218        descriptor,
219        entity_count_by_descriptor,
220        runtime,
221        config,
222        default_time_limit_secs,
223        is_trivial,
224        log_scale,
225        build_phases,
226    )
227}
228
229#[allow(clippy::too_many_arguments)]
230pub fn run_solver_with_config<S, C, P>(
231    solution: S,
232    constraints_fn: fn() -> C,
233    descriptor: fn() -> SolutionDescriptor,
234    entity_count_by_descriptor: fn(&S, usize) -> usize,
235    runtime: SolverRuntime<S>,
236    config: SolverConfig,
237    default_time_limit_secs: u64,
238    is_trivial: fn(&S) -> bool,
239    log_scale: fn(&S),
240    build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
241) -> S
242where
243    S: PlanningSolution,
244    S::Score: Score + ParseableScore,
245    C: ConstraintSet<S, S::Score>,
246    P: Send + std::fmt::Debug,
247    PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
248{
249    log_scale(&solution);
250    let trivial = is_trivial(&solution);
251
252    let constraints = constraints_fn();
253    let director = ScoreDirector::with_descriptor(
254        solution,
255        constraints,
256        descriptor(),
257        entity_count_by_descriptor,
258    );
259
260    if trivial {
261        let mut solver_scope = SolverScope::new(director);
262        solver_scope = solver_scope.with_runtime(Some(runtime));
263        if let Some(seed) = config.random_seed {
264            solver_scope = solver_scope.with_seed(seed);
265        }
266        solver_scope.start_solving();
267        let score = solver_scope.calculate_score();
268        let solution = solver_scope.score_director().clone_working_solution();
269        solver_scope.set_best_solution(solution.clone(), score);
270        solver_scope.report_best_solution();
271        solver_scope.pause_if_requested();
272        info!(event = "solve_end", score = %score);
273        let telemetry = solver_scope.stats().snapshot();
274        if runtime.is_cancel_requested() {
275            runtime.emit_cancelled(Some(score), Some(score), telemetry);
276        } else {
277            runtime.emit_completed(
278                solution.clone(),
279                Some(score),
280                score,
281                telemetry,
282                SolverTerminalReason::Completed,
283            );
284        }
285        return solution;
286    }
287
288    let (termination, time_limit) = build_termination::<S, C>(&config, default_time_limit_secs);
289
290    let callback = ChannelProgressCallback::new(runtime);
291
292    let phases = build_phases(&config);
293    let solver = Solver::new((phases,))
294        .with_config(config.clone())
295        .with_termination(termination)
296        .with_time_limit(time_limit)
297        .with_runtime(runtime)
298        .with_progress_callback(callback);
299
300    let result = solver.with_terminate(runtime.cancel_flag()).solve(director);
301
302    let crate::solver::SolveResult {
303        solution,
304        current_score,
305        best_score: final_score,
306        terminal_reason,
307        stats,
308    } = result;
309    let final_telemetry = stats.snapshot();
310    match terminal_reason {
311        SolverTerminalReason::Completed | SolverTerminalReason::TerminatedByConfig => {
312            runtime.emit_completed(
313                solution.clone(),
314                current_score,
315                final_score,
316                final_telemetry,
317                terminal_reason,
318            );
319        }
320        SolverTerminalReason::Cancelled => {
321            runtime.emit_cancelled(current_score, Some(final_score), final_telemetry);
322        }
323        SolverTerminalReason::Failed => unreachable!("solver completion cannot report failure"),
324    }
325
326    info!(
327        event = "solve_end",
328        score = %final_score,
329        steps = stats.step_count,
330        moves_evaluated = stats.moves_evaluated,
331        moves_accepted = stats.moves_accepted,
332        score_calculations = stats.score_calculations,
333        moves_speed = final_telemetry.moves_per_second,
334        acceptance_rate = format!("{:.1}%", stats.acceptance_rate() * 100.0),
335    );
336    solution
337}
338
#[cfg(test)]
mod tests {
    use super::load_solver_config_from;
    use std::fs;
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    /// Returns a `solver.toml` path inside a uniquely named subdirectory of the system
    /// temp directory. The name combines the process id and the current time in
    /// nanoseconds. Nothing is created on disk here.
    fn temp_config_path() -> std::path::PathBuf {
        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time should be after epoch")
            .as_nanos();
        std::env::temp_dir()
            .join(format!(
                "solverforge-run-tests-{}-{suffix}",
                std::process::id()
            ))
            .join("solver.toml")
    }

    /// Settings written to a config file must survive loading rather than being replaced
    /// by defaults.
    #[test]
    fn load_solver_config_from_preserves_file_settings() {
        let path = temp_config_path();
        let parent = path.parent().expect("temp file should have a parent");
        fs::create_dir_all(parent).expect("temp directory should be created");
        fs::write(
            &path,
            r#"
random_seed = 41

[termination]
seconds_spent_limit = 5

[[phases]]
type = "construction_heuristic"
construction_heuristic_type = "first_fit"
"#,
        )
        .expect("solver.toml should be written");

        let config = load_solver_config_from(&path);

        // Clean up before asserting so a failing assertion does not leak the temp directory.
        fs::remove_dir_all(parent).expect("temp directory should be removed");

        assert_eq!(config.random_seed, Some(41));
        assert_eq!(config.time_limit(), Some(Duration::from_secs(5)));
        assert_eq!(config.phases.len(), 1);
    }
}