// Source: solverforge_solver/builder/selector/conflict_repair.rs

1use std::collections::HashSet;
2
3use solverforge_config::{
4    CompoundConflictRepairMoveSelectorConfig, ConflictRepairMoveSelectorConfig,
5};
6use solverforge_scoring::ConstraintMetadata;
7
8use crate::builder::context::{ConflictRepair, RepairLimits, ScalarVariableSlot};
9use crate::heuristic::r#move::{CompoundScalarEdit, CompoundScalarMove};
10use crate::heuristic::selector::move_selector::CandidateStore;
11use crate::planning::ScalarEdit;
12
13pub struct ConflictRepairSelector<S> {
14    config: ConflictRepairMoveSelectorConfig,
15    scalar_variables: Vec<ScalarVariableSlot<S>>,
16    repairs: Vec<ConflictRepair<S>>,
17}
18
19impl<S> ConflictRepairSelector<S> {
20    pub fn new(
21        config: ConflictRepairMoveSelectorConfig,
22        scalar_variables: Vec<ScalarVariableSlot<S>>,
23        repairs: Vec<ConflictRepair<S>>,
24    ) -> Self {
25        Self {
26            config,
27            scalar_variables,
28            repairs,
29        }
30    }
31
32    pub fn new_compound(
33        config: CompoundConflictRepairMoveSelectorConfig,
34        scalar_variables: Vec<ScalarVariableSlot<S>>,
35        repairs: Vec<ConflictRepair<S>>,
36    ) -> Self {
37        Self {
38            config: ConflictRepairMoveSelectorConfig {
39                constraints: config.constraints,
40                max_matches_per_step: config.max_matches_per_step,
41                max_repairs_per_match: config.max_repairs_per_match,
42                max_moves_per_step: config.max_moves_per_step,
43                require_hard_improvement: config.require_hard_improvement,
44                include_soft_matches: config.include_soft_matches,
45            },
46            scalar_variables,
47            repairs,
48        }
49    }
50
51    fn limits(&self) -> RepairLimits {
52        RepairLimits {
53            max_matches_per_step: self.config.max_matches_per_step,
54            max_repairs_per_match: self.config.max_repairs_per_match,
55            max_moves_per_step: self.config.max_moves_per_step,
56        }
57    }
58
59    fn variable_for_edit(&self, edit: &ScalarEdit<S>) -> Option<ScalarVariableSlot<S>> {
60        self.scalar_variables.iter().copied().find(|ctx| {
61            ctx.descriptor_index == edit.descriptor_index()
62                && ctx.variable_name == edit.variable_name()
63        })
64    }
65
66    fn validate_constraint_hardness<D>(&self, score_director: &D)
67    where
68        S: PlanningSolution,
69        D: solverforge_scoring::Director<S>,
70    {
71        for constraint_name in &self.config.constraints {
72            let metadata = score_director.constraint_metadata();
73            let Some(metadata) = resolve_configured_constraint(&metadata, constraint_name) else {
74                panic!(
75                    "conflict_repair_move_selector configured for `{constraint_name}`, but no matching scoring constraint was found"
76                );
77            };
78            assert!(
79                metadata.is_hard || self.config.include_soft_matches,
80                "conflict_repair_move_selector configured for non-hard constraint `{constraint_name}` while include_soft_matches is false"
81            );
82        }
83    }
84}
85
86fn resolve_configured_constraint<'metadata, 'constraint>(
87    metadata: &'metadata [ConstraintMetadata<'constraint>],
88    constraint_name: &str,
89) -> Option<&'metadata ConstraintMetadata<'constraint>> {
90    metadata
91        .iter()
92        .find(|metadata| metadata.full_name() == constraint_name)
93        .or_else(|| {
94            if constraint_name.contains('/') {
95                None
96            } else {
97                metadata.iter().find(|metadata| {
98                    metadata.constraint_ref.package.is_empty() && metadata.name() == constraint_name
99                })
100            }
101        })
102}
103
104impl<S> std::fmt::Debug for ConflictRepairSelector<S> {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        f.debug_struct("ConflictRepairSelector")
107            .field("constraints", &self.config.constraints)
108            .field("max_matches_per_step", &self.config.max_matches_per_step)
109            .field("max_repairs_per_match", &self.config.max_repairs_per_match)
110            .field("max_moves_per_step", &self.config.max_moves_per_step)
111            .field(
112                "require_hard_improvement",
113                &self.config.require_hard_improvement,
114            )
115            .finish()
116    }
117}
118
119pub struct ConflictRepairCursor<S>
120where
121    S: PlanningSolution + 'static,
122{
123    store: CandidateStore<S, ScalarMoveUnion<S, usize>>,
124    next_index: usize,
125}
126
127impl<S> ConflictRepairCursor<S>
128where
129    S: PlanningSolution + 'static,
130{
131    fn new(store: CandidateStore<S, ScalarMoveUnion<S, usize>>) -> Self {
132        Self {
133            store,
134            next_index: 0,
135        }
136    }
137}
138
139impl<S> MoveCursor<S, ScalarMoveUnion<S, usize>> for ConflictRepairCursor<S>
140where
141    S: PlanningSolution + 'static,
142{
143    fn next_candidate(&mut self) -> Option<CandidateId> {
144        if self.next_index >= self.store.len() {
145            return None;
146        }
147        let id = CandidateId::new(self.next_index);
148        self.next_index += 1;
149        Some(id)
150    }
151
152    fn candidate(
153        &self,
154        id: CandidateId,
155    ) -> Option<MoveCandidateRef<'_, S, ScalarMoveUnion<S, usize>>> {
156        self.store.candidate(id)
157    }
158
159    fn take_candidate(&mut self, id: CandidateId) -> ScalarMoveUnion<S, usize> {
160        self.store.take_candidate(id)
161    }
162}
163
164impl<S> MoveSelector<S, ScalarMoveUnion<S, usize>> for ConflictRepairSelector<S>
165where
166    S: PlanningSolution + 'static,
167{
168    type Cursor<'a>
169        = ConflictRepairCursor<S>
170    where
171        Self: 'a;
172
173    fn open_cursor<'a, D: solverforge_scoring::Director<S>>(
174        &'a self,
175        score_director: &D,
176    ) -> Self::Cursor<'a> {
177        self.open_cursor_with_context(score_director, MoveStreamContext::default())
178    }
179
180    fn open_cursor_with_context<'a, D: solverforge_scoring::Director<S>>(
181        &'a self,
182        score_director: &D,
183        context: MoveStreamContext,
184    ) -> Self::Cursor<'a> {
185        self.validate_constraint_hardness(score_director);
186        let solution = score_director.working_solution();
187        let limits = self.limits();
188        if limits.max_moves_per_step == 0
189            || limits.max_matches_per_step == 0
190            || limits.max_repairs_per_match == 0
191        {
192            return ConflictRepairCursor::new(CandidateStore::new());
193        }
194
195        let mut store = CandidateStore::with_capacity(self.config.max_moves_per_step);
196        let mut seen = HashSet::new();
197        let mut provider_invocations = 0usize;
198        let mut constraint_indices = (0..self.config.constraints.len()).collect::<Vec<_>>();
199        let constraint_offset = context.start_offset(
200            constraint_indices.len(),
201            0xC0AF_11C7_0000_0001 ^ self.config.max_moves_per_step as u64,
202        );
203        constraint_indices.rotate_left(constraint_offset);
204
205        for constraint_index in constraint_indices {
206            let constraint_name = &self.config.constraints[constraint_index];
207            let mut repair_indices = self
208                .repairs
209                .iter()
210                .enumerate()
211                .filter_map(|(index, repair)| {
212                    (repair.constraint_name() == constraint_name).then_some(index)
213                })
214                .collect::<Vec<_>>();
215            let repair_offset = context.start_offset(
216                repair_indices.len(),
217                0xC0AF_11C7_0000_0002 ^ constraint_index as u64,
218            );
219            repair_indices.rotate_left(repair_offset);
220            for repair_index in repair_indices {
221                if store.len() >= self.config.max_moves_per_step
222                    || provider_invocations >= self.config.max_matches_per_step
223                {
224                    return ConflictRepairCursor::new(store);
225                }
226                provider_invocations += 1;
227                let repair = &self.repairs[repair_index];
228                let mut specs = (repair.provider())(solution, limits);
229                let spec_offset = context.start_offset(
230                    specs.len(),
231                    0xC0AF_11C7_0000_0003 ^ repair_index as u64,
232                );
233                specs.rotate_left(spec_offset);
234                for spec in specs.into_iter().take(self.config.max_repairs_per_match) {
235                    if store.len() >= self.config.max_moves_per_step {
236                        return ConflictRepairCursor::new(store);
237                    }
238                    if spec.edits().is_empty()
239                        || spec_has_duplicate_scalar_targets(spec.edits())
240                        || !seen.insert(spec.clone())
241                    {
242                        continue;
243                    }
244                    let mut edits = Vec::with_capacity(spec.edits().len());
245                    let mut legal = true;
246                    for edit in spec.edits() {
247                        let Some(ctx) = self.variable_for_edit(edit) else {
248                            legal = false;
249                            break;
250                        };
251                        if !ctx.value_is_legal(solution, edit.entity_index(), edit.to_value()) {
252                            legal = false;
253                            break;
254                        }
255                        edits.push(CompoundScalarEdit {
256                            descriptor_index: ctx.descriptor_index,
257                            entity_index: edit.entity_index(),
258                            variable_index: ctx.variable_index,
259                            variable_name: ctx.variable_name,
260                            to_value: edit.to_value(),
261                            getter: ctx.getter,
262                            setter: ctx.setter,
263                            value_is_legal: None,
264                        });
265                    }
266                    if legal {
267                        let mov = CompoundScalarMove::with_label(
268                            spec.reason(),
269                            "conflict_repair",
270                            edits,
271                        )
272                        .with_require_hard_improvement(self.config.require_hard_improvement);
273                        store.push(ScalarMoveUnion::CompoundScalar(mov));
274                    }
275                }
276            }
277        }
278
279        ConflictRepairCursor::new(store)
280    }
281
282    fn size<D: solverforge_scoring::Director<S>>(&self, _score_director: &D) -> usize {
283        self.config.max_moves_per_step
284    }
285}
286
287fn spec_has_duplicate_scalar_targets<S>(edits: &[ScalarEdit<S>]) -> bool {
288    let mut targets = HashSet::new();
289    edits.iter().any(|edit| {
290        !targets.insert((
291            edit.descriptor_index(),
292            edit.entity_index(),
293            edit.variable_name(),
294        ))
295    })
296}