Skip to main content

solverforge_solver/phase/partitioned/
phase.rs

1// PartitionedSearchPhase implementation.
2
3use std::fmt::Debug;
4use std::marker::PhantomData;
5
6use rand::RngExt;
7use rayon::prelude::*;
8use rayon::ThreadPoolBuilder;
9use solverforge_core::domain::PlanningSolution;
10use solverforge_scoring::Director;
11
12use crate::manager::SolverTerminalReason;
13use crate::phase::Phase;
14use crate::scope::ProgressCallback;
15use crate::scope::{PendingControl, SolverScope, SolverScopeChildConfig};
16
17use super::child_phases::ChildPhases;
18use super::config::PartitionedSearchConfig;
19use super::partitioner::SolutionPartitioner;
20
/// How solving a single partition ended, as reported back to the parent phase.
enum PartitionOutcome<S> {
    /// The partition finished; carries its best (or, failing that, working) solution.
    Complete(S),
    /// The child yielded or saw a pause request; the parent honours the pause and
    /// re-partitions from scratch.
    Pause,
    /// Cancellation was requested; the parent marks its own scope as cancelled.
    Cancelled,
    /// Configured termination was reached; the parent marks itself terminated by config.
    Terminated,
}
27
28/// Partitioned search phase that solves partitions in parallel.
29///
30/// This phase:
31/// 1. Partitions the solution using the provided partitioner
32/// 2. Creates a solver for each partition
33/// 3. Runs child phases on each partition in parallel
34/// 4. Merges the solved partitions back together
35///
36/// Each partition runs independently with its own solver scope.
37///
38/// # Type Parameters
39///
40/// * `S` - The solution type
41/// * `PD` - The score director type for partition solvers
42/// * `Part` - The partitioner type (implements `SolutionPartitioner<S>`)
43/// * `SDF` - The score director factory function type
44/// * `PF` - The phase factory function type
45/// * `CP` - The child phases type (tuple of phases)
46pub struct PartitionedSearchPhase<S, PD, Part, SDF, PF, CP>
47where
48    S: PlanningSolution,
49    PD: Director<S>,
50    Part: SolutionPartitioner<S>,
51    SDF: Fn(S) -> PD + Send + Sync,
52    PF: Fn() -> CP + Send + Sync,
53    CP: ChildPhases<S, PD>,
54{
55    // The partitioner that splits and merges solutions.
56    partitioner: Part,
57
58    // Factory for creating score directors for each partition.
59    score_director_factory: SDF,
60
61    // Factory for creating child phases for each partition.
62    phase_factory: PF,
63
64    // Configuration for this phase.
65    config: PartitionedSearchConfig,
66
67    _marker: PhantomData<(fn() -> S, fn() -> PD, fn() -> CP)>,
68}
69
70impl<S, PD, Part, SDF, PF, CP> PartitionedSearchPhase<S, PD, Part, SDF, PF, CP>
71where
72    S: PlanningSolution,
73    PD: Director<S>,
74    Part: SolutionPartitioner<S>,
75    SDF: Fn(S) -> PD + Send + Sync,
76    PF: Fn() -> CP + Send + Sync,
77    CP: ChildPhases<S, PD>,
78{
79    pub fn new(partitioner: Part, score_director_factory: SDF, phase_factory: PF) -> Self {
80        Self {
81            partitioner,
82            score_director_factory,
83            phase_factory,
84            config: PartitionedSearchConfig::default(),
85            _marker: PhantomData,
86        }
87    }
88
89    pub fn with_config(
90        partitioner: Part,
91        score_director_factory: SDF,
92        phase_factory: PF,
93        config: PartitionedSearchConfig,
94    ) -> Self {
95        Self {
96            partitioner,
97            score_director_factory,
98            phase_factory,
99            config,
100            _marker: PhantomData,
101        }
102    }
103}
104
105impl<S, PD, Part, SDF, PF, CP> Debug for PartitionedSearchPhase<S, PD, Part, SDF, PF, CP>
106where
107    S: PlanningSolution,
108    PD: Director<S>,
109    Part: SolutionPartitioner<S> + Debug,
110    SDF: Fn(S) -> PD + Send + Sync,
111    PF: Fn() -> CP + Send + Sync,
112    CP: ChildPhases<S, PD>,
113{
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("PartitionedSearchPhase")
116            .field("partitioner", &self.partitioner)
117            .field("config", &self.config)
118            .finish()
119    }
120}
121
122impl<S, D, BestCb, PD, Part, SDF, PF, CP> Phase<S, D, BestCb>
123    for PartitionedSearchPhase<S, PD, Part, SDF, PF, CP>
124where
125    S: PlanningSolution + 'static,
126    D: Director<S>,
127    BestCb: ProgressCallback<S>,
128    PD: Director<S> + 'static,
129    Part: SolutionPartitioner<S>,
130    SDF: Fn(S) -> PD + Send + Sync,
131    PF: Fn() -> CP + Send + Sync,
132    CP: ChildPhases<S, PD> + Send,
133{
134    fn solve(&mut self, solver_scope: &mut SolverScope<S, D, BestCb>) {
135        'partitioning: loop {
136            if solver_scope.should_terminate() {
137                return;
138            }
139
140            let solution = solver_scope.score_director().working_solution().clone();
141            let partitions = self.partitioner.partition(&solution);
142            let partition_count = partitions.len();
143
144            if partition_count == 0 {
145                return;
146            }
147
148            let thread_count = self.config.thread_count.resolve(partition_count);
149
150            if self.config.log_progress {
151                tracing::info!(event = "phase_start", phase = "PartitionedSearch",);
152            }
153
154            let child_seeds: Vec<u64> = (0..partition_count)
155                .map(|_| solver_scope.rng().random())
156                .collect();
157            let phase_budget = solver_scope.child_phase_budget();
158            let child_config = solver_scope.child_config(Some(&phase_budget));
159            let outcomes =
160                self.solve_partitions(partitions, thread_count, child_config, child_seeds);
161
162            let mut solved_partitions = Vec::with_capacity(outcomes.len());
163            for outcome in outcomes {
164                match outcome {
165                    PartitionOutcome::Complete(partition) => solved_partitions.push(partition),
166                    PartitionOutcome::Pause => {
167                        solver_scope.pause_if_requested();
168                        continue 'partitioning;
169                    }
170                    PartitionOutcome::Cancelled => {
171                        solver_scope.mark_cancelled();
172                        return;
173                    }
174                    PartitionOutcome::Terminated => {
175                        solver_scope.mark_terminated_by_config();
176                        return;
177                    }
178                }
179            }
180
181            if solver_scope.should_terminate() {
182                return;
183            }
184
185            let merged = self.partitioner.merge(&solution, solved_partitions);
186            solver_scope.replace_working_solution_and_reinitialize(merged);
187            solver_scope.update_best_solution();
188
189            if self.config.log_progress {
190                if let Some(score) = solver_scope.best_score() {
191                    tracing::info!(
192                        event = "phase_end",
193                        phase = "PartitionedSearch",
194                        score = %format!("{:?}", score),
195                    );
196                }
197            }
198
199            return;
200        }
201    }
202
203    fn phase_type_name(&self) -> &'static str {
204        "PartitionedSearch"
205    }
206}
207
208impl<S, PD, Part, SDF, PF, CP> PartitionedSearchPhase<S, PD, Part, SDF, PF, CP>
209where
210    S: PlanningSolution,
211    PD: Director<S>,
212    Part: SolutionPartitioner<S>,
213    SDF: Fn(S) -> PD + Send + Sync,
214    PF: Fn() -> CP + Send + Sync,
215    CP: ChildPhases<S, PD>,
216{
217    // Solves a single partition and returns the solved solution.
218    fn solve_partition<'t>(
219        &self,
220        partition: S,
221        child_config: SolverScopeChildConfig<'t, S>,
222        seed: u64,
223    ) -> PartitionOutcome<S> {
224        // Create score director for this partition
225        let director = (self.score_director_factory)(partition);
226
227        // Create solver scope
228        let mut solver_scope = child_config.build_scope(director, seed);
229        if solver_scope.should_terminate() {
230            return PartitionOutcome::Terminated;
231        }
232        solver_scope.initialize_working_solution_as_best();
233
234        // Create and run child phases
235        let mut phases = (self.phase_factory)();
236        phases.solve_all(&mut solver_scope);
237
238        match solver_scope.pending_control() {
239            PendingControl::PauseRequested => return PartitionOutcome::Pause,
240            PendingControl::CancelRequested => return PartitionOutcome::Cancelled,
241            PendingControl::ConfigTerminationRequested => return PartitionOutcome::Terminated,
242            PendingControl::Continue => {}
243        }
244        if solver_scope.yielded_to_parent() {
245            return PartitionOutcome::Pause;
246        }
247        match solver_scope.terminal_reason() {
248            SolverTerminalReason::Cancelled => return PartitionOutcome::Cancelled,
249            SolverTerminalReason::TerminatedByConfig => return PartitionOutcome::Terminated,
250            SolverTerminalReason::Completed | SolverTerminalReason::Failed => {}
251        }
252
253        PartitionOutcome::Complete(solver_scope.take_best_or_working_solution())
254    }
255
256    fn solve_partitions<'t>(
257        &self,
258        partitions: Vec<S>,
259        thread_count: usize,
260        child_config: SolverScopeChildConfig<'t, S>,
261        child_seeds: Vec<u64>,
262    ) -> Vec<PartitionOutcome<S>> {
263        if thread_count <= 1 || partitions.len() <= 1 {
264            return partitions
265                .into_iter()
266                .zip(child_seeds)
267                .map(|(partition, seed)| {
268                    self.solve_partition(partition, child_config.clone(), seed)
269                })
270                .collect();
271        }
272
273        ThreadPoolBuilder::new()
274            .num_threads(thread_count)
275            .build()
276            .expect("failed to build partitioned search rayon pool")
277            .install(|| {
278                partitions
279                    .into_par_iter()
280                    .zip(child_seeds.into_par_iter())
281                    .map(|(partition, seed)| {
282                        self.solve_partition(partition, child_config.clone(), seed)
283                    })
284                    .collect()
285            })
286    }
287}
288
289#[cfg(test)]
290#[path = "phase_tests.rs"]
291mod tests;