1mod partitioner;
24
25use std::fmt::Debug;
26use std::marker::PhantomData;
27
28use rayon::prelude::*;
29use solverforge_core::domain::PlanningSolution;
30use solverforge_scoring::ScoreDirector;
31
32use crate::phase::Phase;
33use crate::scope::SolverScope;
34
35pub use partitioner::{FunctionalPartitioner, SolutionPartitioner, ThreadCount};
36
37#[derive(Debug, Clone)]
39pub struct PartitionedSearchConfig {
40 pub thread_count: ThreadCount,
42 pub log_progress: bool,
44}
45
46impl Default for PartitionedSearchConfig {
47 fn default() -> Self {
48 Self {
49 thread_count: ThreadCount::Auto,
50 log_progress: false,
51 }
52 }
53}
54
55pub struct PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
75where
76 S: PlanningSolution,
77 D: ScoreDirector<S>,
78 PD: ScoreDirector<S>,
79 Part: SolutionPartitioner<S>,
80 SDF: Fn(S) -> PD + Send + Sync,
81 PF: Fn() -> CP + Send + Sync,
82 CP: ChildPhases<S, PD>,
83{
84 partitioner: Part,
86
87 score_director_factory: SDF,
89
90 phase_factory: PF,
92
93 config: PartitionedSearchConfig,
95
96 _marker: PhantomData<fn(S, D, PD, CP)>,
97}
98
99impl<S, D, PD, Part, SDF, PF, CP> PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
100where
101 S: PlanningSolution,
102 D: ScoreDirector<S>,
103 PD: ScoreDirector<S>,
104 Part: SolutionPartitioner<S>,
105 SDF: Fn(S) -> PD + Send + Sync,
106 PF: Fn() -> CP + Send + Sync,
107 CP: ChildPhases<S, PD>,
108{
109 pub fn new(partitioner: Part, score_director_factory: SDF, phase_factory: PF) -> Self {
111 Self {
112 partitioner,
113 score_director_factory,
114 phase_factory,
115 config: PartitionedSearchConfig::default(),
116 _marker: PhantomData,
117 }
118 }
119
120 pub fn with_config(
122 partitioner: Part,
123 score_director_factory: SDF,
124 phase_factory: PF,
125 config: PartitionedSearchConfig,
126 ) -> Self {
127 Self {
128 partitioner,
129 score_director_factory,
130 phase_factory,
131 config,
132 _marker: PhantomData,
133 }
134 }
135}
136
137impl<S, D, PD, Part, SDF, PF, CP> Debug for PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
138where
139 S: PlanningSolution,
140 D: ScoreDirector<S>,
141 PD: ScoreDirector<S>,
142 Part: SolutionPartitioner<S> + Debug,
143 SDF: Fn(S) -> PD + Send + Sync,
144 PF: Fn() -> CP + Send + Sync,
145 CP: ChildPhases<S, PD>,
146{
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 f.debug_struct("PartitionedSearchPhase")
149 .field("partitioner", &self.partitioner)
150 .field("config", &self.config)
151 .finish()
152 }
153}
154
155impl<S, D, PD, Part, SDF, PF, CP> Phase<S, D>
156 for PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
157where
158 S: PlanningSolution + 'static,
159 D: ScoreDirector<S>,
160 PD: ScoreDirector<S> + 'static,
161 Part: SolutionPartitioner<S>,
162 SDF: Fn(S) -> PD + Send + Sync,
163 PF: Fn() -> CP + Send + Sync,
164 CP: ChildPhases<S, PD> + Send,
165{
166 fn solve(&mut self, solver_scope: &mut SolverScope<S, D>) {
167 let solution = solver_scope.score_director().working_solution().clone();
169
170 let partitions = self.partitioner.partition(&solution);
172 let partition_count = partitions.len();
173
174 if partition_count == 0 {
175 return;
176 }
177
178 let thread_count = self.config.thread_count.resolve(partition_count);
180
181 if self.config.log_progress {
182 tracing::info!(event = "phase_start", phase = "PartitionedSearch",);
183 }
184
185 let solved_partitions: Vec<S> = if thread_count == 1 || partition_count == 1 {
187 partitions
188 .into_iter()
189 .map(|p| self.solve_partition(p))
190 .collect()
191 } else {
192 partitions
193 .into_par_iter()
194 .map(|partition| {
195 let director = (self.score_director_factory)(partition);
196 let mut solver_scope = SolverScope::new(director);
197 let mut phases = (self.phase_factory)();
198 phases.solve_all(&mut solver_scope);
199 solver_scope.take_best_or_working_solution()
200 })
201 .collect()
202 };
203
204 let merged = self.partitioner.merge(&solution, solved_partitions);
206
207 let director = solver_scope.score_director_mut();
209
210 let working = director.working_solution_mut();
212 *working = merged;
213
214 solver_scope.calculate_score();
216
217 solver_scope.update_best_solution();
219
220 if self.config.log_progress {
221 if let Some(score) = solver_scope.best_score() {
222 tracing::info!(
223 event = "phase_end",
224 phase = "PartitionedSearch",
225 score = %format!("{:?}", score),
226 );
227 }
228 }
229 }
230
231 fn phase_type_name(&self) -> &'static str {
232 "PartitionedSearch"
233 }
234}
235
236impl<S, D, PD, Part, SDF, PF, CP> PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
237where
238 S: PlanningSolution,
239 D: ScoreDirector<S>,
240 PD: ScoreDirector<S>,
241 Part: SolutionPartitioner<S>,
242 SDF: Fn(S) -> PD + Send + Sync,
243 PF: Fn() -> CP + Send + Sync,
244 CP: ChildPhases<S, PD>,
245{
246 fn solve_partition(&self, partition: S) -> S {
248 let director = (self.score_director_factory)(partition);
250
251 let mut solver_scope = SolverScope::new(director);
253
254 let mut phases = (self.phase_factory)();
256 phases.solve_all(&mut solver_scope);
257
258 solver_scope.take_best_or_working_solution()
260 }
261}
262
263pub trait ChildPhases<S, D>
267where
268 S: PlanningSolution,
269 D: ScoreDirector<S>,
270{
271 fn solve_all(&mut self, solver_scope: &mut SolverScope<S, D>);
273}
274
/// Generates a [`ChildPhases`] implementation for a tuple of phases.
///
/// Each argument pairs a tuple index with a fresh type-parameter name; the
/// generated `solve_all` calls `solve` on every element in the order given.
macro_rules! impl_child_phases_tuple {
    ($($index:tt: $Child:ident),+) => {
        impl<S, D, $($Child),+> ChildPhases<S, D> for ($($Child,)+)
        where
            S: PlanningSolution,
            D: ScoreDirector<S>,
            $($Child: Phase<S, D>,)+
        {
            fn solve_all(&mut self, solver_scope: &mut SolverScope<S, D>) {
                $( self.$index.solve(solver_scope); )+
            }
        }
    };
}
292
293impl_child_phases_tuple!(0: P0);
294impl_child_phases_tuple!(0: P0, 1: P1);
295impl_child_phases_tuple!(0: P0, 1: P1, 2: P2);
296impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3);
297impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3, 4: P4);
298impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3, 4: P4, 5: P5);
299impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3, 4: P4, 5: P5, 6: P6);
300impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3, 4: P4, 5: P5, 6: P6, 7: P7);
301
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration resolves threads automatically and is silent.
    #[test]
    fn test_config_default() {
        let PartitionedSearchConfig {
            thread_count,
            log_progress,
        } = PartitionedSearchConfig::default();

        assert_eq!(thread_count, ThreadCount::Auto);
        assert!(!log_progress);
    }
}