1use std::fmt;
4use std::marker::PhantomData;
5use std::path::Path;
6use std::time::Duration;
7
8use solverforge_config::SolverConfig;
9use solverforge_core::domain::{PlanningSolution, SolutionDescriptor};
10use solverforge_core::score::{ParseableScore, Score};
11use solverforge_scoring::{ConstraintSet, Director, ScoreDirector};
12use tracing::info;
13
14use crate::manager::{SolverRuntime, SolverTerminalReason};
15use crate::phase::{Phase, PhaseSequence};
16use crate::scope::{ProgressCallback, SolverProgressKind, SolverProgressRef, SolverScope};
17use crate::solver::Solver;
18use crate::stats::{format_duration, whole_units_per_second};
19use crate::termination::{
20 BestScoreTermination, OrTermination, StepCountTermination, Termination, TimeTermination,
21 UnimprovedStepCountTermination, UnimprovedTimeTermination,
22};
23
/// Termination policy assembled from the solver configuration.
///
/// Every variant ORs a wall-clock [`TimeTermination`] with at most one
/// additional criterion; [`build_termination`] decides which variant is built.
pub enum AnyTermination<S: PlanningSolution, D: Director<S>> {
    /// Time limit only.
    Default(OrTermination<(TimeTermination,), S, D>),
    /// Time limit, or reaching a target best score.
    WithBestScore(OrTermination<(TimeTermination, BestScoreTermination<S::Score>), S, D>),
    /// Time limit, or a fixed number of steps.
    WithStepCount(OrTermination<(TimeTermination, StepCountTermination), S, D>),
    /// Time limit, or a number of steps without improvement.
    WithUnimprovedStep(OrTermination<(TimeTermination, UnimprovedStepCountTermination<S>), S, D>),
    /// Time limit, or a span of time without improvement.
    WithUnimprovedTime(OrTermination<(TimeTermination, UnimprovedTimeTermination<S>), S, D>),
}
35
/// [`ProgressCallback`] that forwards solver progress and new best solutions
/// to a [`SolverRuntime`] (via `emit_progress` / `emit_best_solution`).
#[derive(Clone)]
pub struct ChannelProgressCallback<S: PlanningSolution> {
    // Runtime that receives the forwarded events.
    runtime: SolverRuntime<S>,
    // Ties the callback to `S` without storing one; a `fn() -> S` marker is
    // `Send`/`Sync` regardless of `S`.
    _phantom: PhantomData<fn() -> S>,
}
41
42impl<S: PlanningSolution> ChannelProgressCallback<S> {
43 fn new(runtime: SolverRuntime<S>) -> Self {
44 Self {
45 runtime,
46 _phantom: PhantomData,
47 }
48 }
49}
50
51impl<S: PlanningSolution> fmt::Debug for ChannelProgressCallback<S> {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 f.debug_struct("ChannelProgressCallback").finish()
54 }
55}
56
57impl<S: PlanningSolution> ProgressCallback<S> for ChannelProgressCallback<S> {
58 fn invoke(&self, progress: SolverProgressRef<'_, S>) {
59 match progress.kind {
60 SolverProgressKind::Progress => {
61 self.runtime.emit_progress(
62 progress.current_score.copied(),
63 progress.best_score.copied(),
64 progress.telemetry,
65 );
66 }
67 SolverProgressKind::BestSolution => {
68 if let (Some(solution), Some(score)) = (progress.solution, progress.best_score) {
69 self.runtime.emit_best_solution(
70 (*solution).clone(),
71 progress.current_score.copied(),
72 *score,
73 progress.telemetry,
74 );
75 }
76 }
77 }
78 }
79}
80
81impl<S: PlanningSolution, D: Director<S>> fmt::Debug for AnyTermination<S, D> {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 match self {
84 Self::Default(_) => write!(f, "AnyTermination::Default"),
85 Self::WithBestScore(_) => write!(f, "AnyTermination::WithBestScore"),
86 Self::WithStepCount(_) => write!(f, "AnyTermination::WithStepCount"),
87 Self::WithUnimprovedStep(_) => write!(f, "AnyTermination::WithUnimprovedStep"),
88 Self::WithUnimprovedTime(_) => write!(f, "AnyTermination::WithUnimprovedTime"),
89 }
90 }
91}
92
93impl<S: PlanningSolution, D: Director<S>, ProgressCb: ProgressCallback<S>>
94 Termination<S, D, ProgressCb> for AnyTermination<S, D>
95where
96 S::Score: Score,
97{
98 fn is_terminated(&self, solver_scope: &SolverScope<S, D, ProgressCb>) -> bool {
99 match self {
100 Self::Default(t) => t.is_terminated(solver_scope),
101 Self::WithBestScore(t) => t.is_terminated(solver_scope),
102 Self::WithStepCount(t) => t.is_terminated(solver_scope),
103 Self::WithUnimprovedStep(t) => t.is_terminated(solver_scope),
104 Self::WithUnimprovedTime(t) => t.is_terminated(solver_scope),
105 }
106 }
107
108 fn install_inphase_limits(&self, solver_scope: &mut SolverScope<S, D, ProgressCb>) {
109 match self {
110 Self::Default(t) => t.install_inphase_limits(solver_scope),
111 Self::WithBestScore(t) => t.install_inphase_limits(solver_scope),
112 Self::WithStepCount(t) => t.install_inphase_limits(solver_scope),
113 Self::WithUnimprovedStep(t) => t.install_inphase_limits(solver_scope),
114 Self::WithUnimprovedTime(t) => t.install_inphase_limits(solver_scope),
115 }
116 }
117}
118
119pub fn build_termination<S, C>(
121 config: &SolverConfig,
122 default_secs: u64,
123) -> (AnyTermination<S, ScoreDirector<S, C>>, Duration)
124where
125 S: PlanningSolution,
126 S::Score: Score + ParseableScore,
127 C: ConstraintSet<S, S::Score>,
128{
129 let term_config = config.termination.as_ref();
130 let time_limit = term_config
131 .and_then(|c| c.time_limit())
132 .unwrap_or(Duration::from_secs(default_secs));
133 let time = TimeTermination::new(time_limit);
134
135 let best_score_target: Option<S::Score> = term_config
136 .and_then(|c| c.best_score_limit.as_ref())
137 .and_then(|s| S::Score::parse(s).ok());
138
139 let termination = if let Some(target) = best_score_target {
140 AnyTermination::WithBestScore(OrTermination::new((
141 time,
142 BestScoreTermination::new(target),
143 )))
144 } else if let Some(step_limit) = term_config.and_then(|c| c.step_count_limit) {
145 AnyTermination::WithStepCount(OrTermination::new((
146 time,
147 StepCountTermination::new(step_limit),
148 )))
149 } else if let Some(unimproved_step_limit) =
150 term_config.and_then(|c| c.unimproved_step_count_limit)
151 {
152 AnyTermination::WithUnimprovedStep(OrTermination::new((
153 time,
154 UnimprovedStepCountTermination::<S>::new(unimproved_step_limit),
155 )))
156 } else if let Some(unimproved_time) = term_config.and_then(|c| c.unimproved_time_limit()) {
157 AnyTermination::WithUnimprovedTime(OrTermination::new((
158 time,
159 UnimprovedTimeTermination::<S>::new(unimproved_time),
160 )))
161 } else {
162 AnyTermination::Default(OrTermination::new((time,)))
163 };
164
165 (termination, time_limit)
166}
167
168pub fn log_solve_start(
169 entity_count: usize,
170 element_count: Option<usize>,
171 candidate_count: Option<usize>,
172) {
173 match (element_count, candidate_count) {
174 (Some(element_count), None) => {
175 info!(
176 event = "solve_start",
177 entity_count = entity_count,
178 element_count = element_count,
179 solve_shape = "list",
180 );
181 }
182 (None, Some(candidate_count)) => {
183 info!(
184 event = "solve_start",
185 entity_count = entity_count,
186 candidate_count = candidate_count,
187 solve_shape = "scalar",
188 );
189 }
190 _ => {
191 panic!("log_solve_start requires exactly one solve scale: list elements or scalar candidates");
192 }
193 }
194}
195
196fn load_solver_config_from(path: impl AsRef<Path>) -> SolverConfig {
197 SolverConfig::load(path).unwrap_or_default()
198}
199
200fn load_solver_config() -> SolverConfig {
201 load_solver_config_from("solver.toml")
202}
203
204#[allow(clippy::too_many_arguments)]
205pub fn run_solver<S, C, P>(
206 solution: S,
207 constraints_fn: fn() -> C,
208 descriptor: fn() -> SolutionDescriptor,
209 entity_count_by_descriptor: fn(&S, usize) -> usize,
210 runtime: SolverRuntime<S>,
211 default_time_limit_secs: u64,
212 is_trivial: fn(&S) -> bool,
213 log_scale: fn(&S),
214 build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
215) -> S
216where
217 S: PlanningSolution,
218 S::Score: Score + ParseableScore,
219 C: ConstraintSet<S, S::Score>,
220 P: Send + std::fmt::Debug,
221 PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
222{
223 let config = load_solver_config();
224 run_solver_with_config(
225 solution,
226 constraints_fn,
227 descriptor,
228 entity_count_by_descriptor,
229 runtime,
230 config,
231 default_time_limit_secs,
232 is_trivial,
233 log_scale,
234 build_phases,
235 )
236}
237
238#[allow(clippy::too_many_arguments)]
239pub fn run_solver_with_config<S, C, P>(
240 solution: S,
241 constraints_fn: fn() -> C,
242 descriptor: fn() -> SolutionDescriptor,
243 entity_count_by_descriptor: fn(&S, usize) -> usize,
244 runtime: SolverRuntime<S>,
245 config: SolverConfig,
246 default_time_limit_secs: u64,
247 is_trivial: fn(&S) -> bool,
248 log_scale: fn(&S),
249 build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
250) -> S
251where
252 S: PlanningSolution,
253 S::Score: Score + ParseableScore,
254 C: ConstraintSet<S, S::Score>,
255 P: Send + std::fmt::Debug,
256 PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
257{
258 log_scale(&solution);
259 let trivial = is_trivial(&solution);
260
261 let constraints = constraints_fn();
262 let director = ScoreDirector::with_descriptor(
263 solution,
264 constraints,
265 descriptor(),
266 entity_count_by_descriptor,
267 );
268
269 if trivial {
270 let mut solver_scope = SolverScope::new(director);
271 solver_scope = solver_scope.with_runtime(Some(runtime));
272 if let Some(seed) = config.random_seed {
273 solver_scope = solver_scope.with_seed(seed);
274 }
275 solver_scope.start_solving();
276 let score = solver_scope.calculate_score();
277 let solution = solver_scope.score_director().clone_working_solution();
278 solver_scope.set_best_solution(solution.clone(), score);
279 solver_scope.report_best_solution();
280 solver_scope.pause_if_requested();
281 info!(event = "solve_end", score = %score);
282 let telemetry = solver_scope.stats().snapshot();
283 if runtime.is_cancel_requested() {
284 runtime.emit_cancelled(Some(score), Some(score), telemetry);
285 } else {
286 runtime.emit_completed(
287 solution.clone(),
288 Some(score),
289 score,
290 telemetry,
291 SolverTerminalReason::Completed,
292 );
293 }
294 return solution;
295 }
296
297 let (termination, time_limit) = build_termination::<S, C>(&config, default_time_limit_secs);
298
299 let callback = ChannelProgressCallback::new(runtime);
300
301 let phases = build_phases(&config);
302 let solver = Solver::new((phases,))
303 .with_config(config.clone())
304 .with_termination(termination)
305 .with_time_limit(time_limit)
306 .with_runtime(runtime)
307 .with_progress_callback(callback);
308
309 let result = solver.with_terminate(runtime.cancel_flag()).solve(director);
310
311 let crate::solver::SolveResult {
312 solution,
313 current_score,
314 best_score: final_score,
315 terminal_reason,
316 stats,
317 } = result;
318 let final_telemetry = stats.snapshot();
319 let final_move_speed = whole_units_per_second(stats.moves_evaluated, stats.elapsed());
320 match terminal_reason {
321 SolverTerminalReason::Completed | SolverTerminalReason::TerminatedByConfig => {
322 runtime.emit_completed(
323 solution.clone(),
324 current_score,
325 final_score,
326 final_telemetry,
327 terminal_reason,
328 );
329 }
330 SolverTerminalReason::Cancelled => {
331 runtime.emit_cancelled(current_score, Some(final_score), final_telemetry);
332 }
333 SolverTerminalReason::Failed => unreachable!("solver completion cannot report failure"),
334 }
335
336 info!(
337 event = "solve_end",
338 score = %final_score,
339 steps = stats.step_count,
340 moves_generated = stats.moves_generated,
341 moves_evaluated = stats.moves_evaluated,
342 moves_accepted = stats.moves_accepted,
343 score_calculations = stats.score_calculations,
344 generation_time = %format_duration(stats.generation_time()),
345 evaluation_time = %format_duration(stats.evaluation_time()),
346 moves_speed = final_move_speed,
347 acceptance_rate = format!("{:.1}%", stats.acceptance_rate() * 100.0),
348 );
349 solution
350}
351
// Unit tests for this module live in `run_tests.rs`, next to this file.
#[cfg(test)]
#[path = "run_tests.rs"]
mod tests;