1mod partitioner;
24
25use std::fmt::Debug;
26use std::marker::PhantomData;
27
28use rayon::prelude::*;
29use solverforge_core::domain::PlanningSolution;
30use solverforge_scoring::ScoreDirector;
31
32use crate::phase::Phase;
33use crate::scope::SolverScope;
34
35pub use partitioner::{FunctionalPartitioner, SolutionPartitioner, ThreadCount};
36
37#[derive(Debug, Clone)]
39pub struct PartitionedSearchConfig {
40 pub thread_count: ThreadCount,
42 pub log_progress: bool,
44}
45
46impl Default for PartitionedSearchConfig {
47 fn default() -> Self {
48 Self {
49 thread_count: ThreadCount::Auto,
50 log_progress: false,
51 }
52 }
53}
54
55pub struct PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
75where
76 S: PlanningSolution,
77 D: ScoreDirector<S>,
78 PD: ScoreDirector<S>,
79 Part: SolutionPartitioner<S>,
80 SDF: Fn(S) -> PD + Send + Sync,
81 PF: Fn() -> CP + Send + Sync,
82 CP: ChildPhases<S, PD>,
83{
84 partitioner: Part,
86
87 score_director_factory: SDF,
89
90 phase_factory: PF,
92
93 config: PartitionedSearchConfig,
95
96 _marker: PhantomData<fn(S, D, PD, CP)>,
97}
98
99impl<S, D, PD, Part, SDF, PF, CP> PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
100where
101 S: PlanningSolution,
102 D: ScoreDirector<S>,
103 PD: ScoreDirector<S>,
104 Part: SolutionPartitioner<S>,
105 SDF: Fn(S) -> PD + Send + Sync,
106 PF: Fn() -> CP + Send + Sync,
107 CP: ChildPhases<S, PD>,
108{
109 pub fn new(partitioner: Part, score_director_factory: SDF, phase_factory: PF) -> Self {
111 Self {
112 partitioner,
113 score_director_factory,
114 phase_factory,
115 config: PartitionedSearchConfig::default(),
116 _marker: PhantomData,
117 }
118 }
119
120 pub fn with_config(
122 partitioner: Part,
123 score_director_factory: SDF,
124 phase_factory: PF,
125 config: PartitionedSearchConfig,
126 ) -> Self {
127 Self {
128 partitioner,
129 score_director_factory,
130 phase_factory,
131 config,
132 _marker: PhantomData,
133 }
134 }
135}
136
137impl<S, D, PD, Part, SDF, PF, CP> Debug for PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
138where
139 S: PlanningSolution,
140 D: ScoreDirector<S>,
141 PD: ScoreDirector<S>,
142 Part: SolutionPartitioner<S> + Debug,
143 SDF: Fn(S) -> PD + Send + Sync,
144 PF: Fn() -> CP + Send + Sync,
145 CP: ChildPhases<S, PD>,
146{
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 f.debug_struct("PartitionedSearchPhase")
149 .field("partitioner", &self.partitioner)
150 .field("config", &self.config)
151 .finish()
152 }
153}
154
155impl<S, D, PD, Part, SDF, PF, CP> Phase<S, D>
156 for PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
157where
158 S: PlanningSolution + 'static,
159 D: ScoreDirector<S>,
160 PD: ScoreDirector<S> + 'static,
161 Part: SolutionPartitioner<S>,
162 SDF: Fn(S) -> PD + Send + Sync,
163 PF: Fn() -> CP + Send + Sync,
164 CP: ChildPhases<S, PD> + Send,
165{
166 fn solve(&mut self, solver_scope: &mut SolverScope<S, D>) {
167 let solution = solver_scope.score_director().working_solution().clone();
169
170 let partitions = self.partitioner.partition(&solution);
172 let partition_count = partitions.len();
173
174 if partition_count == 0 {
175 return;
176 }
177
178 let thread_count = self.config.thread_count.resolve(partition_count);
180
181 if self.config.log_progress {
182 println!(
183 "[PartitionedSearch] Solving {} partitions with {} threads",
184 partition_count, thread_count
185 );
186 }
187
188 let solved_partitions: Vec<S> = if thread_count == 1 || partition_count == 1 {
190 partitions
191 .into_iter()
192 .map(|p| self.solve_partition(p))
193 .collect()
194 } else {
195 partitions
196 .into_par_iter()
197 .map(|partition| {
198 let director = (self.score_director_factory)(partition);
199 let mut solver_scope = SolverScope::new(director);
200 let mut phases = (self.phase_factory)();
201 phases.solve_all(&mut solver_scope);
202 solver_scope.take_best_or_working_solution()
203 })
204 .collect()
205 };
206
207 let merged = self.partitioner.merge(&solution, solved_partitions);
209
210 let director = solver_scope.score_director_mut();
212
213 let working = director.working_solution_mut();
215 *working = merged;
216
217 solver_scope.calculate_score();
219
220 solver_scope.update_best_solution();
222
223 if self.config.log_progress {
224 if let Some(score) = solver_scope.best_score() {
225 println!(
226 "[PartitionedSearch] Completed with merged score: {:?}",
227 score
228 );
229 }
230 }
231 }
232
233 fn phase_type_name(&self) -> &'static str {
234 "PartitionedSearch"
235 }
236}
237
238impl<S, D, PD, Part, SDF, PF, CP> PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
239where
240 S: PlanningSolution,
241 D: ScoreDirector<S>,
242 PD: ScoreDirector<S>,
243 Part: SolutionPartitioner<S>,
244 SDF: Fn(S) -> PD + Send + Sync,
245 PF: Fn() -> CP + Send + Sync,
246 CP: ChildPhases<S, PD>,
247{
248 fn solve_partition(&self, partition: S) -> S {
250 let director = (self.score_director_factory)(partition);
252
253 let mut solver_scope = SolverScope::new(director);
255
256 let mut phases = (self.phase_factory)();
258 phases.solve_all(&mut solver_scope);
259
260 solver_scope.take_best_or_working_solution()
262 }
263}
264
265pub trait ChildPhases<S, D>
269where
270 S: PlanningSolution,
271 D: ScoreDirector<S>,
272{
273 fn solve_all(&mut self, solver_scope: &mut SolverScope<S, D>);
275}
276
277macro_rules! impl_child_phases_tuple {
279 ($($idx:tt: $P:ident),+) => {
280 impl<S, D, $($P),+> ChildPhases<S, D> for ($($P,)+)
281 where
282 S: PlanningSolution,
283 D: ScoreDirector<S>,
284 $($P: Phase<S, D>,)+
285 {
286 fn solve_all(&mut self, solver_scope: &mut SolverScope<S, D>) {
287 $(
288 self.$idx.solve(solver_scope);
289 )+
290 }
291 }
292 };
293}
294
295impl_child_phases_tuple!(0: P0);
296impl_child_phases_tuple!(0: P0, 1: P1);
297impl_child_phases_tuple!(0: P0, 1: P1, 2: P2);
298impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3);
299impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3, 4: P4);
300impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3, 4: P4, 5: P5);
301impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3, 4: P4, 5: P5, 6: P6);
302impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3, 4: P4, 5: P5, 6: P6, 7: P7);
303
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses an automatic thread count and keeps
    /// progress logging off.
    #[test]
    fn test_config_default() {
        let PartitionedSearchConfig {
            thread_count,
            log_progress,
        } = PartitionedSearchConfig::default();

        assert_eq!(thread_count, ThreadCount::Auto);
        assert!(!log_progress);
    }
}