1mod partitioner;
24
25use std::fmt::Debug;
26use std::marker::PhantomData;
27
28use rayon::prelude::*;
29use solverforge_core::domain::PlanningSolution;
30use solverforge_scoring::Director;
31
32use crate::phase::Phase;
33use crate::scope::BestSolutionCallback;
34use crate::scope::SolverScope;
35
36pub use partitioner::{FunctionalPartitioner, SolutionPartitioner, ThreadCount};
37
38#[derive(Debug, Clone)]
40pub struct PartitionedSearchConfig {
41 pub thread_count: ThreadCount,
43 pub log_progress: bool,
45}
46
47impl Default for PartitionedSearchConfig {
48 fn default() -> Self {
49 Self {
50 thread_count: ThreadCount::Auto,
51 log_progress: false,
52 }
53 }
54}
55
56pub struct PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
76where
77 S: PlanningSolution,
78 D: Director<S>,
79 PD: Director<S>,
80 Part: SolutionPartitioner<S>,
81 SDF: Fn(S) -> PD + Send + Sync,
82 PF: Fn() -> CP + Send + Sync,
83 CP: ChildPhases<S, PD>,
84{
85 partitioner: Part,
87
88 score_director_factory: SDF,
90
91 phase_factory: PF,
93
94 config: PartitionedSearchConfig,
96
97 _marker: PhantomData<fn(S, D, PD, CP)>,
98}
99
100impl<S, D, PD, Part, SDF, PF, CP> PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
101where
102 S: PlanningSolution,
103 D: Director<S>,
104 PD: Director<S>,
105 Part: SolutionPartitioner<S>,
106 SDF: Fn(S) -> PD + Send + Sync,
107 PF: Fn() -> CP + Send + Sync,
108 CP: ChildPhases<S, PD>,
109{
110 pub fn new(partitioner: Part, score_director_factory: SDF, phase_factory: PF) -> Self {
112 Self {
113 partitioner,
114 score_director_factory,
115 phase_factory,
116 config: PartitionedSearchConfig::default(),
117 _marker: PhantomData,
118 }
119 }
120
121 pub fn with_config(
123 partitioner: Part,
124 score_director_factory: SDF,
125 phase_factory: PF,
126 config: PartitionedSearchConfig,
127 ) -> Self {
128 Self {
129 partitioner,
130 score_director_factory,
131 phase_factory,
132 config,
133 _marker: PhantomData,
134 }
135 }
136}
137
138impl<S, D, PD, Part, SDF, PF, CP> Debug for PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
139where
140 S: PlanningSolution,
141 D: Director<S>,
142 PD: Director<S>,
143 Part: SolutionPartitioner<S> + Debug,
144 SDF: Fn(S) -> PD + Send + Sync,
145 PF: Fn() -> CP + Send + Sync,
146 CP: ChildPhases<S, PD>,
147{
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 f.debug_struct("PartitionedSearchPhase")
150 .field("partitioner", &self.partitioner)
151 .field("config", &self.config)
152 .finish()
153 }
154}
155
156impl<S, D, BestCb, PD, Part, SDF, PF, CP> Phase<S, D, BestCb>
157 for PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
158where
159 S: PlanningSolution + 'static,
160 D: Director<S>,
161 BestCb: BestSolutionCallback<S>,
162 PD: Director<S> + 'static,
163 Part: SolutionPartitioner<S>,
164 SDF: Fn(S) -> PD + Send + Sync,
165 PF: Fn() -> CP + Send + Sync,
166 CP: ChildPhases<S, PD> + Send,
167{
168 fn solve(&mut self, solver_scope: &mut SolverScope<S, D, BestCb>) {
169 let solution = solver_scope.score_director().working_solution().clone();
171
172 let partitions = self.partitioner.partition(&solution);
174 let partition_count = partitions.len();
175
176 if partition_count == 0 {
177 return;
178 }
179
180 let thread_count = self.config.thread_count.resolve(partition_count);
182
183 if self.config.log_progress {
184 tracing::info!(event = "phase_start", phase = "PartitionedSearch",);
185 }
186
187 let solved_partitions: Vec<S> = if thread_count == 1 || partition_count == 1 {
189 partitions
190 .into_iter()
191 .map(|p| self.solve_partition(p))
192 .collect()
193 } else {
194 partitions
195 .into_par_iter()
196 .map(|partition| {
197 let director = (self.score_director_factory)(partition);
198 let mut solver_scope = SolverScope::new(director);
199 let mut phases = (self.phase_factory)();
200 phases.solve_all(&mut solver_scope);
201 solver_scope.take_best_or_working_solution()
202 })
203 .collect()
204 };
205
206 let merged = self.partitioner.merge(&solution, solved_partitions);
208
209 let director = solver_scope.score_director_mut();
211
212 let working = director.working_solution_mut();
214 *working = merged;
215
216 solver_scope.calculate_score();
218
219 solver_scope.update_best_solution();
221
222 if self.config.log_progress {
223 if let Some(score) = solver_scope.best_score() {
224 tracing::info!(
225 event = "phase_end",
226 phase = "PartitionedSearch",
227 score = %format!("{:?}", score),
228 );
229 }
230 }
231 }
232
233 fn phase_type_name(&self) -> &'static str {
234 "PartitionedSearch"
235 }
236}
237
238impl<S, D, PD, Part, SDF, PF, CP> PartitionedSearchPhase<S, D, PD, Part, SDF, PF, CP>
239where
240 S: PlanningSolution,
241 D: Director<S>,
242 PD: Director<S>,
243 Part: SolutionPartitioner<S>,
244 SDF: Fn(S) -> PD + Send + Sync,
245 PF: Fn() -> CP + Send + Sync,
246 CP: ChildPhases<S, PD>,
247{
248 fn solve_partition(&self, partition: S) -> S {
250 let director = (self.score_director_factory)(partition);
252
253 let mut solver_scope = SolverScope::new(director);
255
256 let mut phases = (self.phase_factory)();
258 phases.solve_all(&mut solver_scope);
259
260 solver_scope.take_best_or_working_solution()
262 }
263}
264
265pub trait ChildPhases<S, D>
269where
270 S: PlanningSolution,
271 D: Director<S>,
272{
273 fn solve_all(&mut self, solver_scope: &mut SolverScope<S, D>);
275}
276
// Generates a `ChildPhases` impl for one tuple arity. Each `index: Type` pair
// names a tuple field and the generic parameter used for that field's phase.
// The generated `solve_all` calls `solve` on every field, in order.
macro_rules! impl_child_phases_tuple {
    ($($index:tt: $Ph:ident),+) => {
        impl<S, D, $($Ph),+> ChildPhases<S, D> for ($($Ph,)+)
        where
            S: PlanningSolution,
            D: Director<S>,
            $($Ph: Phase<S, D>,)+
        {
            fn solve_all(&mut self, scope: &mut SolverScope<S, D>) {
                $( self.$index.solve(scope); )+
            }
        }
    };
}
294
295impl_child_phases_tuple!(0: P0);
296impl_child_phases_tuple!(0: P0, 1: P1);
297impl_child_phases_tuple!(0: P0, 1: P1, 2: P2);
298impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3);
299impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3, 4: P4);
300impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3, 4: P4, 5: P5);
301impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3, 4: P4, 5: P5, 6: P6);
302impl_child_phases_tuple!(0: P0, 1: P1, 2: P2, 3: P3, 4: P4, 5: P5, 6: P6, 7: P7);
303
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration picks the thread count automatically and
    /// keeps progress logging off.
    #[test]
    fn test_config_default() {
        let PartitionedSearchConfig {
            thread_count,
            log_progress,
        } = PartitionedSearchConfig::default();

        assert_eq!(thread_count, ThreadCount::Auto);
        assert!(!log_progress);
    }
}