1use std::fmt;
4use std::marker::PhantomData;
5use std::sync::atomic::AtomicBool;
6use std::time::Duration;
7
8use solverforge_config::SolverConfig;
9use solverforge_core::domain::{PlanningSolution, SolutionDescriptor};
10use solverforge_core::score::{ParseableScore, Score};
11use solverforge_scoring::{ConstraintSet, Director, ScoreDirector};
12use tokio::sync::mpsc;
13use tracing::info;
14
15use crate::manager::SolverEvent;
16use crate::phase::{Phase, PhaseSequence};
17use crate::scope::{ProgressCallback, SolverProgressKind, SolverProgressRef, SolverScope};
18use crate::solver::Solver;
19use crate::termination::{
20 BestScoreTermination, OrTermination, StepCountTermination, Termination, TimeTermination,
21 UnimprovedStepCountTermination, UnimprovedTimeTermination,
22};
23
/// Termination policy chosen at runtime by [`build_termination`].
///
/// `build_termination` can produce several differently-typed [`OrTermination`]s depending on
/// configuration. This enum gives them a single nameable type.
///
/// Every variant contains a [`TimeTermination`]. All variants except `Default` pair it with
/// exactly one additional termination. NOTE(review): judging by the name, `OrTermination`
/// stops when any member termination fires — confirm against its implementation.
pub enum AnyTermination<S: PlanningSolution, D: Director<S>> {
    /// Time limit only; built when no additional limit is configured.
    Default(OrTermination<(TimeTermination,), S, D>),
    /// Time limit plus a best-score target; built from a `best_score_limit` that parses.
    WithBestScore(OrTermination<(TimeTermination, BestScoreTermination<S::Score>), S, D>),
    /// Time limit plus a step-count limit; built from `step_count_limit`.
    WithStepCount(OrTermination<(TimeTermination, StepCountTermination), S, D>),
    /// Time limit plus an unimproved-step-count limit; built from `unimproved_step_count_limit`.
    WithUnimprovedStep(OrTermination<(TimeTermination, UnimprovedStepCountTermination<S>), S, D>),
    /// Time limit plus an unimproved-time limit; built from `unimproved_time_limit()`.
    WithUnimprovedTime(OrTermination<(TimeTermination, UnimprovedTimeTermination<S>), S, D>),
}
35
/// Progress callback that forwards solver notifications as [`SolverEvent`]s over an
/// unbounded tokio mpsc channel.
#[derive(Clone)]
pub struct ChannelProgressCallback<S: PlanningSolution> {
    /// Destination for `Progress` and `BestSolution` events.
    sender: mpsc::UnboundedSender<SolverEvent<S>>,
    // Marks `S` without storing a value of it. NOTE(review): `S` already appears in
    // `sender`'s type, so this marker may be redundant.
    _phantom: PhantomData<fn() -> S>,
}
41
42impl<S: PlanningSolution> ChannelProgressCallback<S> {
43 fn new(sender: mpsc::UnboundedSender<SolverEvent<S>>) -> Self {
44 Self {
45 sender,
46 _phantom: PhantomData,
47 }
48 }
49}
50
51impl<S: PlanningSolution> fmt::Debug for ChannelProgressCallback<S> {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 f.debug_struct("ChannelProgressCallback").finish()
54 }
55}
56
57impl<S: PlanningSolution> ProgressCallback<S> for ChannelProgressCallback<S> {
58 fn invoke(&self, progress: SolverProgressRef<'_, S>) {
59 match progress.kind {
60 SolverProgressKind::Progress => {
61 let _ = self.sender.send(SolverEvent::Progress {
62 current_score: progress.current_score.cloned(),
63 best_score: progress.best_score.cloned(),
64 telemetry: progress.telemetry,
65 });
66 }
67 SolverProgressKind::BestSolution => {
68 if let (Some(solution), Some(score)) = (progress.solution, progress.best_score) {
69 let _ = self.sender.send(SolverEvent::BestSolution {
70 solution: (*solution).clone(),
71 score: *score,
72 telemetry: progress.telemetry,
73 });
74 }
75 }
76 }
77 }
78}
79
80impl<S: PlanningSolution, D: Director<S>> fmt::Debug for AnyTermination<S, D> {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 match self {
83 Self::Default(_) => write!(f, "AnyTermination::Default"),
84 Self::WithBestScore(_) => write!(f, "AnyTermination::WithBestScore"),
85 Self::WithStepCount(_) => write!(f, "AnyTermination::WithStepCount"),
86 Self::WithUnimprovedStep(_) => write!(f, "AnyTermination::WithUnimprovedStep"),
87 Self::WithUnimprovedTime(_) => write!(f, "AnyTermination::WithUnimprovedTime"),
88 }
89 }
90}
91
92impl<S: PlanningSolution, D: Director<S>, ProgressCb: ProgressCallback<S>>
93 Termination<S, D, ProgressCb> for AnyTermination<S, D>
94where
95 S::Score: Score,
96{
97 fn is_terminated(&self, solver_scope: &SolverScope<S, D, ProgressCb>) -> bool {
98 match self {
99 Self::Default(t) => t.is_terminated(solver_scope),
100 Self::WithBestScore(t) => t.is_terminated(solver_scope),
101 Self::WithStepCount(t) => t.is_terminated(solver_scope),
102 Self::WithUnimprovedStep(t) => t.is_terminated(solver_scope),
103 Self::WithUnimprovedTime(t) => t.is_terminated(solver_scope),
104 }
105 }
106
107 fn install_inphase_limits(&self, solver_scope: &mut SolverScope<S, D, ProgressCb>) {
108 match self {
109 Self::Default(t) => t.install_inphase_limits(solver_scope),
110 Self::WithBestScore(t) => t.install_inphase_limits(solver_scope),
111 Self::WithStepCount(t) => t.install_inphase_limits(solver_scope),
112 Self::WithUnimprovedStep(t) => t.install_inphase_limits(solver_scope),
113 Self::WithUnimprovedTime(t) => t.install_inphase_limits(solver_scope),
114 }
115 }
116}
117
118pub fn build_termination<S, C>(
120 config: &SolverConfig,
121 default_secs: u64,
122) -> (AnyTermination<S, ScoreDirector<S, C>>, Duration)
123where
124 S: PlanningSolution,
125 S::Score: Score + ParseableScore,
126 C: ConstraintSet<S, S::Score>,
127{
128 let term_config = config.termination.as_ref();
129 let time_limit = term_config
130 .and_then(|c| c.time_limit())
131 .unwrap_or(Duration::from_secs(default_secs));
132 let time = TimeTermination::new(time_limit);
133
134 let best_score_target: Option<S::Score> = term_config
135 .and_then(|c| c.best_score_limit.as_ref())
136 .and_then(|s| S::Score::parse(s).ok());
137
138 let termination = if let Some(target) = best_score_target {
139 AnyTermination::WithBestScore(OrTermination::new((
140 time,
141 BestScoreTermination::new(target),
142 )))
143 } else if let Some(step_limit) = term_config.and_then(|c| c.step_count_limit) {
144 AnyTermination::WithStepCount(OrTermination::new((
145 time,
146 StepCountTermination::new(step_limit),
147 )))
148 } else if let Some(unimproved_step_limit) =
149 term_config.and_then(|c| c.unimproved_step_count_limit)
150 {
151 AnyTermination::WithUnimprovedStep(OrTermination::new((
152 time,
153 UnimprovedStepCountTermination::<S>::new(unimproved_step_limit),
154 )))
155 } else if let Some(unimproved_time) = term_config.and_then(|c| c.unimproved_time_limit()) {
156 AnyTermination::WithUnimprovedTime(OrTermination::new((
157 time,
158 UnimprovedTimeTermination::<S>::new(unimproved_time),
159 )))
160 } else {
161 AnyTermination::Default(OrTermination::new((time,)))
162 };
163
164 (termination, time_limit)
165}
166
167pub fn log_solve_start(
168 entity_count: usize,
169 element_count: Option<usize>,
170 has_standard: Option<bool>,
171 variable_count: Option<usize>,
172) {
173 info!(
174 event = "solve_start",
175 entity_count = entity_count,
176 value_count = element_count.or(variable_count).unwrap_or(0),
177 solve_shape = if element_count.is_some() {
178 "list"
179 } else if has_standard.unwrap_or(false) {
180 "standard"
181 } else {
182 "solution"
183 },
184 );
185}
186
187#[allow(clippy::too_many_arguments)]
188pub fn run_solver<S, C, P>(
189 solution: S,
190 constraints_fn: fn() -> C,
191 descriptor: fn() -> SolutionDescriptor,
192 entity_count_by_descriptor: fn(&S, usize) -> usize,
193 terminate: Option<&AtomicBool>,
194 sender: mpsc::UnboundedSender<SolverEvent<S>>,
195 default_time_limit_secs: u64,
196 is_trivial: fn(&S) -> bool,
197 log_scale: fn(&S),
198 build_phases: fn(&SolverConfig) -> PhaseSequence<P>,
199) -> S
200where
201 S: PlanningSolution,
202 S::Score: Score + ParseableScore,
203 C: ConstraintSet<S, S::Score>,
204 P: Send + std::fmt::Debug,
205 PhaseSequence<P>: Phase<S, ScoreDirector<S, C>, ChannelProgressCallback<S>>,
206{
207 let config = SolverConfig::load("solver.toml").unwrap_or_default();
208
209 log_scale(&solution);
210 let trivial = is_trivial(&solution);
211
212 let constraints = constraints_fn();
213 let director = ScoreDirector::with_descriptor(
214 solution,
215 constraints,
216 descriptor(),
217 entity_count_by_descriptor,
218 );
219
220 if trivial {
221 let mut solver_scope = SolverScope::new(director);
222 solver_scope.start_solving();
223 let score = solver_scope.calculate_score();
224 let solution = solver_scope.score_director().clone_working_solution();
225 solver_scope.set_best_solution(solution.clone(), score);
226 info!(event = "solve_end", score = %score);
227 let telemetry = solver_scope.stats().snapshot();
228 let _ = sender.send(SolverEvent::Finished {
229 solution: solution.clone(),
230 score,
231 telemetry,
232 });
233 return solution;
234 }
235
236 let (termination, time_limit) = build_termination::<S, C>(&config, default_time_limit_secs);
237
238 let callback = ChannelProgressCallback::new(sender.clone());
239
240 let phases = build_phases(&config);
241 let solver = Solver::new((phases,))
242 .with_termination(termination)
243 .with_time_limit(time_limit)
244 .with_progress_callback(callback);
245
246 let result = if let Some(flag) = terminate {
247 solver.with_terminate(flag).solve(director)
248 } else {
249 solver.solve(director)
250 };
251
252 let crate::solver::SolveResult {
253 solution,
254 best_score: final_score,
255 stats,
256 } = result;
257 let final_telemetry = stats.snapshot();
258 let _ = sender.send(SolverEvent::Finished {
259 solution: solution.clone(),
260 score: final_score,
261 telemetry: final_telemetry,
262 });
263
264 info!(
265 event = "solve_end",
266 score = %final_score,
267 steps = stats.step_count,
268 moves_evaluated = stats.moves_evaluated,
269 moves_accepted = stats.moves_accepted,
270 score_calculations = stats.score_calculations,
271 moves_speed = final_telemetry.moves_per_second,
272 acceptance_rate = format!("{:.1}%", stats.acceptance_rate() * 100.0),
273 );
274 solution
275}