1use std::fmt;
4use std::marker::PhantomData;
5use std::path::Path;
6use std::time::Duration;
7
8use solverforge_config::SolverConfig;
9use solverforge_core::domain::{PlanningSolution, SolutionDescriptor};
10use solverforge_core::score::{ParseableScore, Score};
11use solverforge_scoring::{ConstraintSet, Director, ScoreDirector};
12use tracing::info;
13
14use crate::manager::{SolverRuntime, SolverTerminalReason};
15use crate::phase::{Phase, PhaseSequence};
16use crate::scope::{ProgressCallback, SolverProgressKind, SolverProgressRef, SolverScope};
17use crate::solver::{NoTermination, Solver};
18use crate::stats::{format_duration, whole_units_per_second};
19use crate::termination::{
20 BestScoreTermination, OrTermination, StepCountTermination, Termination, TimeTermination,
21 UnimprovedStepCountTermination, UnimprovedTimeTermination,
22};
23
24pub enum AnyTermination<S: PlanningSolution, D: Director<S>> {
29 None(NoTermination),
30 Default(OrTermination<(TimeTermination,), S, D>),
31 WithBestScore(OrTermination<(TimeTermination, BestScoreTermination<S::Score>), S, D>),
32 WithStepCount(OrTermination<(TimeTermination, StepCountTermination), S, D>),
33 WithUnimprovedStep(OrTermination<(TimeTermination, UnimprovedStepCountTermination<S>), S, D>),
34 WithUnimprovedTime(OrTermination<(TimeTermination, UnimprovedTimeTermination<S>), S, D>),
35}
36
37#[derive(Clone)]
38pub struct ChannelProgressCallback<S: PlanningSolution> {
39 runtime: SolverRuntime<S>,
40 _phantom: PhantomData<fn() -> S>,
41}
42
43impl<S: PlanningSolution> ChannelProgressCallback<S> {
44 fn new(runtime: SolverRuntime<S>) -> Self {
45 Self {
46 runtime,
47 _phantom: PhantomData,
48 }
49 }
50}
51
52impl<S: PlanningSolution> fmt::Debug for ChannelProgressCallback<S> {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 f.debug_struct("ChannelProgressCallback").finish()
55 }
56}
57
58impl<S: PlanningSolution> ProgressCallback<S> for ChannelProgressCallback<S> {
59 fn invoke(&self, progress: SolverProgressRef<'_, S>) {
60 match progress.kind {
61 SolverProgressKind::Progress => {
62 self.runtime.emit_progress(
63 progress.current_score.copied(),
64 progress.best_score.copied(),
65 progress.telemetry.clone(),
66 );
67 }
68 SolverProgressKind::BestSolution => {
69 if let (Some(solution), Some(score)) = (progress.solution, progress.best_score) {
70 self.runtime.emit_best_solution(
71 (*solution).clone(),
72 progress.current_score.copied(),
73 *score,
74 progress.telemetry.clone(),
75 );
76 }
77 }
78 }
79 }
80}
81
82impl<S: PlanningSolution, D: Director<S>> fmt::Debug for AnyTermination<S, D> {
83 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84 match self {
85 Self::None(_) => write!(f, "AnyTermination::None"),
86 Self::Default(_) => write!(f, "AnyTermination::Default"),
87 Self::WithBestScore(_) => write!(f, "AnyTermination::WithBestScore"),
88 Self::WithStepCount(_) => write!(f, "AnyTermination::WithStepCount"),
89 Self::WithUnimprovedStep(_) => write!(f, "AnyTermination::WithUnimprovedStep"),
90 Self::WithUnimprovedTime(_) => write!(f, "AnyTermination::WithUnimprovedTime"),
91 }
92 }
93}
94
95impl<S: PlanningSolution, D: Director<S>, ProgressCb: ProgressCallback<S>>
96 Termination<S, D, ProgressCb> for AnyTermination<S, D>
97where
98 S::Score: Score,
99{
100 fn is_terminated(&self, solver_scope: &SolverScope<S, D, ProgressCb>) -> bool {
101 match self {
102 Self::None(t) => t.is_terminated(solver_scope),
103 Self::Default(t) => t.is_terminated(solver_scope),
104 Self::WithBestScore(t) => t.is_terminated(solver_scope),
105 Self::WithStepCount(t) => t.is_terminated(solver_scope),
106 Self::WithUnimprovedStep(t) => t.is_terminated(solver_scope),
107 Self::WithUnimprovedTime(t) => t.is_terminated(solver_scope),
108 }
109 }
110
111 fn install_inphase_limits(&self, solver_scope: &mut SolverScope<S, D, ProgressCb>) {
112 match self {
113 Self::None(t) => t.install_inphase_limits(solver_scope),
114 Self::Default(t) => t.install_inphase_limits(solver_scope),
115 Self::WithBestScore(t) => t.install_inphase_limits(solver_scope),
116 Self::WithStepCount(t) => t.install_inphase_limits(solver_scope),
117 Self::WithUnimprovedStep(t) => t.install_inphase_limits(solver_scope),
118 Self::WithUnimprovedTime(t) => t.install_inphase_limits(solver_scope),
119 }
120 }
121}
122
123pub fn build_termination<S, C>(
125 config: &SolverConfig,
126 default_secs: u64,
127) -> (AnyTermination<S, ScoreDirector<S, C>>, Option<Duration>)
128where
129 S: PlanningSolution,
130 S::Score: Score + ParseableScore,
131 C: ConstraintSet<S, S::Score>,
132{
133 let term_config = config.termination.as_ref();
134 let configured_time_limit = term_config.and_then(|c| c.time_limit());
135 let fallback_time_limit = Duration::from_secs(default_secs);
136
137 let best_score_target: Option<S::Score> = term_config
138 .and_then(|c| c.best_score_limit.as_ref())
139 .and_then(|s| S::Score::parse(s).ok());
140
141 let (termination, effective_time_limit) = if let Some(target) = best_score_target {
142 let effective_time_limit = configured_time_limit.unwrap_or(fallback_time_limit);
143 let time = TimeTermination::new(effective_time_limit);
144 (
145 AnyTermination::WithBestScore(OrTermination::new((
146 time,
147 BestScoreTermination::new(target),
148 ))),
149 Some(effective_time_limit),
150 )
151 } else if let Some(step_limit) = term_config.and_then(|c| c.step_count_limit) {
152 let effective_time_limit = configured_time_limit.unwrap_or(fallback_time_limit);
153 let time = TimeTermination::new(effective_time_limit);
154 (
155 AnyTermination::WithStepCount(OrTermination::new((
156 time,
157 StepCountTermination::new(step_limit),
158 ))),
159 Some(effective_time_limit),
160 )
161 } else if let Some(unimproved_step_limit) =
162 term_config.and_then(|c| c.unimproved_step_count_limit)
163 {
164 let effective_time_limit = configured_time_limit.unwrap_or(fallback_time_limit);
165 let time = TimeTermination::new(effective_time_limit);
166 (
167 AnyTermination::WithUnimprovedStep(OrTermination::new((
168 time,
169 UnimprovedStepCountTermination::<S>::new(unimproved_step_limit),
170 ))),
171 Some(effective_time_limit),
172 )
173 } else if let Some(unimproved_time) = term_config.and_then(|c| c.unimproved_time_limit()) {
174 let effective_time_limit = configured_time_limit.unwrap_or(fallback_time_limit);
175 let time = TimeTermination::new(effective_time_limit);
176 (
177 AnyTermination::WithUnimprovedTime(OrTermination::new((
178 time,
179 UnimprovedTimeTermination::<S>::new(unimproved_time),
180 ))),
181 Some(effective_time_limit),
182 )
183 } else if let Some(limit) = configured_time_limit {
184 let time = TimeTermination::new(limit);
185 (
186 AnyTermination::Default(OrTermination::new((time,))),
187 Some(limit),
188 )
189 } else {
190 (AnyTermination::None(NoTermination), None)
191 };
192
193 (termination, effective_time_limit)
194}
195
196pub fn log_solve_start(
197 entity_count: usize,
198 element_count: Option<usize>,
199 candidate_count: Option<usize>,
200) {
201 match (element_count, candidate_count) {
202 (Some(element_count), None) => {
203 info!(
204 event = "solve_start",
205 entity_count = entity_count,
206 element_count = element_count,
207 solve_shape = "list",
208 );
209 }
210 (None, Some(candidate_count)) => {
211 info!(
212 event = "solve_start",
213 entity_count = entity_count,
214 candidate_count = candidate_count,
215 solve_shape = "scalar",
216 );
217 }
218 _ => {
219 panic!("log_solve_start requires exactly one solve scale: list elements or scalar candidates");
220 }
221 }
222}
223
224fn load_solver_config_from(path: impl AsRef<Path>) -> SolverConfig {
225 SolverConfig::load(path).unwrap_or_default()
226}
227
228fn load_solver_config() -> SolverConfig {
229 load_solver_config_from("solver.toml")
230}
231
232#[allow(clippy::too_many_arguments)]
233pub fn run_solver<S, C, P>(
234 solution: S,
235 constraints_fn: fn() -> C,
236 descriptor: fn() -> SolutionDescriptor,
237 entity_count_by_descriptor: fn(&S, usize) -> usize,
238 runtime: SolverRuntime<S>,
239 default_time_limit_secs: u64,
240 is_trivial: fn(&S) -> bool,
241 log_scale: fn(&S),
242 build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
243) -> S
244where
245 S: PlanningSolution,
246 S::Score: Score + ParseableScore,
247 C: ConstraintSet<S, S::Score>,
248 P: Send + std::fmt::Debug,
249 PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
250{
251 let config = load_solver_config();
252 run_solver_with_config(
253 solution,
254 constraints_fn,
255 descriptor,
256 entity_count_by_descriptor,
257 runtime,
258 config,
259 default_time_limit_secs,
260 is_trivial,
261 log_scale,
262 build_phases,
263 )
264}
265
266#[allow(clippy::too_many_arguments)]
267pub fn run_solver_with_config<S, C, P>(
268 solution: S,
269 constraints_fn: fn() -> C,
270 descriptor: fn() -> SolutionDescriptor,
271 entity_count_by_descriptor: fn(&S, usize) -> usize,
272 runtime: SolverRuntime<S>,
273 config: SolverConfig,
274 default_time_limit_secs: u64,
275 is_trivial: fn(&S) -> bool,
276 log_scale: fn(&S),
277 build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
278) -> S
279where
280 S: PlanningSolution,
281 S::Score: Score + ParseableScore,
282 C: ConstraintSet<S, S::Score>,
283 P: Send + std::fmt::Debug,
284 PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
285{
286 log_scale(&solution);
287 let trivial = is_trivial(&solution);
288
289 let constraints = constraints_fn();
290 let director = ScoreDirector::with_descriptor(
291 solution,
292 constraints,
293 descriptor(),
294 entity_count_by_descriptor,
295 );
296
297 if trivial {
298 let mut solver_scope = SolverScope::new(director);
299 solver_scope = solver_scope.with_runtime(Some(runtime));
300 if let Some(seed) = config.random_seed {
301 solver_scope = solver_scope.with_seed(seed);
302 }
303 solver_scope.start_solving();
304 let score = solver_scope.calculate_score();
305 let solution = solver_scope.score_director().clone_working_solution();
306 solver_scope.set_best_solution(solution.clone(), score);
307 solver_scope.report_best_solution();
308 solver_scope.pause_if_requested();
309 info!(event = "solve_end", score = %score);
310 let telemetry = solver_scope.stats().snapshot();
311 if runtime.is_cancel_requested() {
312 runtime.emit_cancelled(Some(score), Some(score), telemetry);
313 } else {
314 runtime.emit_completed(
315 solution.clone(),
316 Some(score),
317 score,
318 telemetry,
319 SolverTerminalReason::Completed,
320 );
321 }
322 return solution;
323 }
324
325 let (termination, time_limit) = build_termination::<S, C>(&config, default_time_limit_secs);
326
327 let callback = ChannelProgressCallback::new(runtime);
328
329 let phases = build_phases(&config);
330 let mut solver = Solver::new((phases,))
331 .with_config(config.clone())
332 .with_termination(termination)
333 .with_runtime(runtime)
334 .with_progress_callback(callback);
335 if let Some(time_limit) = time_limit {
336 solver = solver.with_time_limit(time_limit);
337 }
338
339 let result = solver.with_terminate(runtime.cancel_flag()).solve(director);
340
341 let crate::solver::SolveResult {
342 solution,
343 current_score,
344 best_score: final_score,
345 terminal_reason,
346 stats,
347 } = result;
348 let final_telemetry = stats.snapshot();
349 let final_move_speed = whole_units_per_second(stats.moves_evaluated, stats.elapsed());
350 match terminal_reason {
351 SolverTerminalReason::Completed | SolverTerminalReason::TerminatedByConfig => {
352 runtime.emit_completed(
353 solution.clone(),
354 current_score,
355 final_score,
356 final_telemetry,
357 terminal_reason,
358 );
359 }
360 SolverTerminalReason::Cancelled => {
361 runtime.emit_cancelled(current_score, Some(final_score), final_telemetry);
362 }
363 SolverTerminalReason::Failed => unreachable!("solver completion cannot report failure"),
364 }
365
366 info!(
367 event = "solve_end",
368 score = %final_score,
369 steps = stats.step_count,
370 moves_generated = stats.moves_generated,
371 moves_evaluated = stats.moves_evaluated,
372 moves_accepted = stats.moves_accepted,
373 score_calculations = stats.score_calculations,
374 generation_time = %format_duration(stats.generation_time()),
375 evaluation_time = %format_duration(stats.evaluation_time()),
376 moves_speed = final_move_speed,
377 acceptance_rate = format!("{:.1}%", stats.acceptance_rate() * 100.0),
378 );
379 solution
380}
381
382#[cfg(test)]
383#[path = "run_tests.rs"]
384mod tests;