1use std::fmt;
4use std::marker::PhantomData;
5use std::path::Path;
6use std::time::Duration;
7
8use solverforge_config::SolverConfig;
9use solverforge_core::domain::{PlanningSolution, SolutionDescriptor};
10use solverforge_core::score::{ParseableScore, Score};
11use solverforge_scoring::{ConstraintSet, Director, ScoreDirector};
12use tracing::info;
13
14use crate::manager::{SolverRuntime, SolverTerminalReason};
15use crate::phase::{Phase, PhaseSequence};
16use crate::scope::{ProgressCallback, SolverProgressKind, SolverProgressRef, SolverScope};
17use crate::solver::Solver;
18use crate::termination::{
19 BestScoreTermination, OrTermination, StepCountTermination, Termination, TimeTermination,
20 UnimprovedStepCountTermination, UnimprovedTimeTermination,
21};
22
23pub enum AnyTermination<S: PlanningSolution, D: Director<S>> {
28 Default(OrTermination<(TimeTermination,), S, D>),
29 WithBestScore(OrTermination<(TimeTermination, BestScoreTermination<S::Score>), S, D>),
30 WithStepCount(OrTermination<(TimeTermination, StepCountTermination), S, D>),
31 WithUnimprovedStep(OrTermination<(TimeTermination, UnimprovedStepCountTermination<S>), S, D>),
32 WithUnimprovedTime(OrTermination<(TimeTermination, UnimprovedTimeTermination<S>), S, D>),
33}
34
35#[derive(Clone)]
36pub struct ChannelProgressCallback<S: PlanningSolution> {
37 runtime: SolverRuntime<S>,
38 _phantom: PhantomData<fn() -> S>,
39}
40
41impl<S: PlanningSolution> ChannelProgressCallback<S> {
42 fn new(runtime: SolverRuntime<S>) -> Self {
43 Self {
44 runtime,
45 _phantom: PhantomData,
46 }
47 }
48}
49
50impl<S: PlanningSolution> fmt::Debug for ChannelProgressCallback<S> {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 f.debug_struct("ChannelProgressCallback").finish()
53 }
54}
55
56impl<S: PlanningSolution> ProgressCallback<S> for ChannelProgressCallback<S> {
57 fn invoke(&self, progress: SolverProgressRef<'_, S>) {
58 match progress.kind {
59 SolverProgressKind::Progress => {
60 self.runtime.emit_progress(
61 progress.current_score.copied(),
62 progress.best_score.copied(),
63 progress.telemetry,
64 );
65 }
66 SolverProgressKind::BestSolution => {
67 if let (Some(solution), Some(score)) = (progress.solution, progress.best_score) {
68 self.runtime.emit_best_solution(
69 (*solution).clone(),
70 progress.current_score.copied(),
71 *score,
72 progress.telemetry,
73 );
74 }
75 }
76 }
77 }
78}
79
80impl<S: PlanningSolution, D: Director<S>> fmt::Debug for AnyTermination<S, D> {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 match self {
83 Self::Default(_) => write!(f, "AnyTermination::Default"),
84 Self::WithBestScore(_) => write!(f, "AnyTermination::WithBestScore"),
85 Self::WithStepCount(_) => write!(f, "AnyTermination::WithStepCount"),
86 Self::WithUnimprovedStep(_) => write!(f, "AnyTermination::WithUnimprovedStep"),
87 Self::WithUnimprovedTime(_) => write!(f, "AnyTermination::WithUnimprovedTime"),
88 }
89 }
90}
91
92impl<S: PlanningSolution, D: Director<S>, ProgressCb: ProgressCallback<S>>
93 Termination<S, D, ProgressCb> for AnyTermination<S, D>
94where
95 S::Score: Score,
96{
97 fn is_terminated(&self, solver_scope: &SolverScope<S, D, ProgressCb>) -> bool {
98 match self {
99 Self::Default(t) => t.is_terminated(solver_scope),
100 Self::WithBestScore(t) => t.is_terminated(solver_scope),
101 Self::WithStepCount(t) => t.is_terminated(solver_scope),
102 Self::WithUnimprovedStep(t) => t.is_terminated(solver_scope),
103 Self::WithUnimprovedTime(t) => t.is_terminated(solver_scope),
104 }
105 }
106
107 fn install_inphase_limits(&self, solver_scope: &mut SolverScope<S, D, ProgressCb>) {
108 match self {
109 Self::Default(t) => t.install_inphase_limits(solver_scope),
110 Self::WithBestScore(t) => t.install_inphase_limits(solver_scope),
111 Self::WithStepCount(t) => t.install_inphase_limits(solver_scope),
112 Self::WithUnimprovedStep(t) => t.install_inphase_limits(solver_scope),
113 Self::WithUnimprovedTime(t) => t.install_inphase_limits(solver_scope),
114 }
115 }
116}
117
118pub fn build_termination<S, C>(
120 config: &SolverConfig,
121 default_secs: u64,
122) -> (AnyTermination<S, ScoreDirector<S, C>>, Duration)
123where
124 S: PlanningSolution,
125 S::Score: Score + ParseableScore,
126 C: ConstraintSet<S, S::Score>,
127{
128 let term_config = config.termination.as_ref();
129 let time_limit = term_config
130 .and_then(|c| c.time_limit())
131 .unwrap_or(Duration::from_secs(default_secs));
132 let time = TimeTermination::new(time_limit);
133
134 let best_score_target: Option<S::Score> = term_config
135 .and_then(|c| c.best_score_limit.as_ref())
136 .and_then(|s| S::Score::parse(s).ok());
137
138 let termination = if let Some(target) = best_score_target {
139 AnyTermination::WithBestScore(OrTermination::new((
140 time,
141 BestScoreTermination::new(target),
142 )))
143 } else if let Some(step_limit) = term_config.and_then(|c| c.step_count_limit) {
144 AnyTermination::WithStepCount(OrTermination::new((
145 time,
146 StepCountTermination::new(step_limit),
147 )))
148 } else if let Some(unimproved_step_limit) =
149 term_config.and_then(|c| c.unimproved_step_count_limit)
150 {
151 AnyTermination::WithUnimprovedStep(OrTermination::new((
152 time,
153 UnimprovedStepCountTermination::<S>::new(unimproved_step_limit),
154 )))
155 } else if let Some(unimproved_time) = term_config.and_then(|c| c.unimproved_time_limit()) {
156 AnyTermination::WithUnimprovedTime(OrTermination::new((
157 time,
158 UnimprovedTimeTermination::<S>::new(unimproved_time),
159 )))
160 } else {
161 AnyTermination::Default(OrTermination::new((time,)))
162 };
163
164 (termination, time_limit)
165}
166
167pub fn log_solve_start(
168 entity_count: usize,
169 element_count: Option<usize>,
170 has_standard: Option<bool>,
171 variable_count: Option<usize>,
172) {
173 info!(
174 event = "solve_start",
175 entity_count = entity_count,
176 value_count = element_count.or(variable_count).unwrap_or(0),
177 solve_shape = if element_count.is_some() {
178 "list"
179 } else if has_standard.unwrap_or(false) {
180 "standard"
181 } else {
182 "solution"
183 },
184 );
185}
186
187fn load_solver_config_from(path: impl AsRef<Path>) -> SolverConfig {
188 SolverConfig::load(path).unwrap_or_default()
189}
190
191fn load_solver_config() -> SolverConfig {
192 load_solver_config_from("solver.toml")
193}
194
195#[allow(clippy::too_many_arguments)]
196pub fn run_solver<S, C, P>(
197 solution: S,
198 constraints_fn: fn() -> C,
199 descriptor: fn() -> SolutionDescriptor,
200 entity_count_by_descriptor: fn(&S, usize) -> usize,
201 runtime: SolverRuntime<S>,
202 default_time_limit_secs: u64,
203 is_trivial: fn(&S) -> bool,
204 log_scale: fn(&S),
205 build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
206) -> S
207where
208 S: PlanningSolution,
209 S::Score: Score + ParseableScore,
210 C: ConstraintSet<S, S::Score>,
211 P: Send + std::fmt::Debug,
212 PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
213{
214 let config = load_solver_config();
215 run_solver_with_config(
216 solution,
217 constraints_fn,
218 descriptor,
219 entity_count_by_descriptor,
220 runtime,
221 config,
222 default_time_limit_secs,
223 is_trivial,
224 log_scale,
225 build_phases,
226 )
227}
228
229#[allow(clippy::too_many_arguments)]
230pub fn run_solver_with_config<S, C, P>(
231 solution: S,
232 constraints_fn: fn() -> C,
233 descriptor: fn() -> SolutionDescriptor,
234 entity_count_by_descriptor: fn(&S, usize) -> usize,
235 runtime: SolverRuntime<S>,
236 config: SolverConfig,
237 default_time_limit_secs: u64,
238 is_trivial: fn(&S) -> bool,
239 log_scale: fn(&S),
240 build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
241) -> S
242where
243 S: PlanningSolution,
244 S::Score: Score + ParseableScore,
245 C: ConstraintSet<S, S::Score>,
246 P: Send + std::fmt::Debug,
247 PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
248{
249 log_scale(&solution);
250 let trivial = is_trivial(&solution);
251
252 let constraints = constraints_fn();
253 let director = ScoreDirector::with_descriptor(
254 solution,
255 constraints,
256 descriptor(),
257 entity_count_by_descriptor,
258 );
259
260 if trivial {
261 let mut solver_scope = SolverScope::new(director);
262 solver_scope = solver_scope.with_runtime(Some(runtime));
263 if let Some(seed) = config.random_seed {
264 solver_scope = solver_scope.with_seed(seed);
265 }
266 solver_scope.start_solving();
267 let score = solver_scope.calculate_score();
268 let solution = solver_scope.score_director().clone_working_solution();
269 solver_scope.set_best_solution(solution.clone(), score);
270 solver_scope.report_best_solution();
271 solver_scope.pause_if_requested();
272 info!(event = "solve_end", score = %score);
273 let telemetry = solver_scope.stats().snapshot();
274 if runtime.is_cancel_requested() {
275 runtime.emit_cancelled(Some(score), Some(score), telemetry);
276 } else {
277 runtime.emit_completed(
278 solution.clone(),
279 Some(score),
280 score,
281 telemetry,
282 SolverTerminalReason::Completed,
283 );
284 }
285 return solution;
286 }
287
288 let (termination, time_limit) = build_termination::<S, C>(&config, default_time_limit_secs);
289
290 let callback = ChannelProgressCallback::new(runtime);
291
292 let phases = build_phases(&config);
293 let solver = Solver::new((phases,))
294 .with_config(config.clone())
295 .with_termination(termination)
296 .with_time_limit(time_limit)
297 .with_runtime(runtime)
298 .with_progress_callback(callback);
299
300 let result = solver.with_terminate(runtime.cancel_flag()).solve(director);
301
302 let crate::solver::SolveResult {
303 solution,
304 current_score,
305 best_score: final_score,
306 terminal_reason,
307 stats,
308 } = result;
309 let final_telemetry = stats.snapshot();
310 match terminal_reason {
311 SolverTerminalReason::Completed | SolverTerminalReason::TerminatedByConfig => {
312 runtime.emit_completed(
313 solution.clone(),
314 current_score,
315 final_score,
316 final_telemetry,
317 terminal_reason,
318 );
319 }
320 SolverTerminalReason::Cancelled => {
321 runtime.emit_cancelled(current_score, Some(final_score), final_telemetry);
322 }
323 SolverTerminalReason::Failed => unreachable!("solver completion cannot report failure"),
324 }
325
326 info!(
327 event = "solve_end",
328 score = %final_score,
329 steps = stats.step_count,
330 moves_evaluated = stats.moves_evaluated,
331 moves_accepted = stats.moves_accepted,
332 score_calculations = stats.score_calculations,
333 moves_speed = final_telemetry.moves_per_second,
334 acceptance_rate = format!("{:.1}%", stats.acceptance_rate() * 100.0),
335 );
336 solution
337}
338
339#[cfg(test)]
340#[path = "run_tests.rs"]
341mod tests;