solverforge_solver/builder/selector/
conflict_repair.rs1use std::collections::HashSet;
2
3use solverforge_config::{
4 CompoundConflictRepairMoveSelectorConfig, ConflictRepairMoveSelectorConfig,
5};
6use solverforge_scoring::ConstraintMetadata;
7
8use crate::builder::context::{ConflictRepair, RepairLimits, ScalarVariableSlot};
9use crate::heuristic::r#move::{CompoundScalarEdit, CompoundScalarMove};
10use crate::heuristic::selector::move_selector::CandidateStore;
11use crate::planning::ScalarEdit;
12
13pub struct ConflictRepairSelector<S> {
14 config: ConflictRepairMoveSelectorConfig,
15 scalar_variables: Vec<ScalarVariableSlot<S>>,
16 repairs: Vec<ConflictRepair<S>>,
17}
18
19impl<S> ConflictRepairSelector<S> {
20 pub fn new(
21 config: ConflictRepairMoveSelectorConfig,
22 scalar_variables: Vec<ScalarVariableSlot<S>>,
23 repairs: Vec<ConflictRepair<S>>,
24 ) -> Self {
25 Self {
26 config,
27 scalar_variables,
28 repairs,
29 }
30 }
31
32 pub fn new_compound(
33 config: CompoundConflictRepairMoveSelectorConfig,
34 scalar_variables: Vec<ScalarVariableSlot<S>>,
35 repairs: Vec<ConflictRepair<S>>,
36 ) -> Self {
37 Self {
38 config: ConflictRepairMoveSelectorConfig {
39 constraints: config.constraints,
40 max_matches_per_step: config.max_matches_per_step,
41 max_repairs_per_match: config.max_repairs_per_match,
42 max_moves_per_step: config.max_moves_per_step,
43 require_hard_improvement: config.require_hard_improvement,
44 include_soft_matches: config.include_soft_matches,
45 },
46 scalar_variables,
47 repairs,
48 }
49 }
50
51 fn limits(&self) -> RepairLimits {
52 RepairLimits {
53 max_matches_per_step: self.config.max_matches_per_step,
54 max_repairs_per_match: self.config.max_repairs_per_match,
55 max_moves_per_step: self.config.max_moves_per_step,
56 }
57 }
58
59 fn variable_for_edit(&self, edit: &ScalarEdit<S>) -> Option<ScalarVariableSlot<S>> {
60 self.scalar_variables.iter().copied().find(|ctx| {
61 ctx.descriptor_index == edit.descriptor_index()
62 && ctx.variable_name == edit.variable_name()
63 })
64 }
65
66 fn validate_constraint_hardness<D>(&self, score_director: &D)
67 where
68 S: PlanningSolution,
69 D: solverforge_scoring::Director<S>,
70 {
71 for constraint_name in &self.config.constraints {
72 let metadata = score_director.constraint_metadata();
73 let Some(metadata) = resolve_configured_constraint(&metadata, constraint_name) else {
74 panic!(
75 "conflict_repair_move_selector configured for `{constraint_name}`, but no matching scoring constraint was found"
76 );
77 };
78 assert!(
79 metadata.is_hard || self.config.include_soft_matches,
80 "conflict_repair_move_selector configured for non-hard constraint `{constraint_name}` while include_soft_matches is false"
81 );
82 }
83 }
84}
85
86fn resolve_configured_constraint<'metadata, 'constraint>(
87 metadata: &'metadata [ConstraintMetadata<'constraint>],
88 constraint_name: &str,
89) -> Option<&'metadata ConstraintMetadata<'constraint>> {
90 metadata
91 .iter()
92 .find(|metadata| metadata.full_name() == constraint_name)
93 .or_else(|| {
94 if constraint_name.contains('/') {
95 None
96 } else {
97 metadata.iter().find(|metadata| {
98 metadata.constraint_ref.package.is_empty() && metadata.name() == constraint_name
99 })
100 }
101 })
102}
103
104impl<S> std::fmt::Debug for ConflictRepairSelector<S> {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 f.debug_struct("ConflictRepairSelector")
107 .field("constraints", &self.config.constraints)
108 .field("max_matches_per_step", &self.config.max_matches_per_step)
109 .field("max_repairs_per_match", &self.config.max_repairs_per_match)
110 .field("max_moves_per_step", &self.config.max_moves_per_step)
111 .field(
112 "require_hard_improvement",
113 &self.config.require_hard_improvement,
114 )
115 .finish()
116 }
117}
118
119pub struct ConflictRepairCursor<S>
120where
121 S: PlanningSolution + 'static,
122{
123 store: CandidateStore<S, ScalarMoveUnion<S, usize>>,
124 next_index: usize,
125}
126
127impl<S> ConflictRepairCursor<S>
128where
129 S: PlanningSolution + 'static,
130{
131 fn new(store: CandidateStore<S, ScalarMoveUnion<S, usize>>) -> Self {
132 Self {
133 store,
134 next_index: 0,
135 }
136 }
137}
138
139impl<S> MoveCursor<S, ScalarMoveUnion<S, usize>> for ConflictRepairCursor<S>
140where
141 S: PlanningSolution + 'static,
142{
143 fn next_candidate(&mut self) -> Option<CandidateId> {
144 if self.next_index >= self.store.len() {
145 return None;
146 }
147 let id = CandidateId::new(self.next_index);
148 self.next_index += 1;
149 Some(id)
150 }
151
152 fn candidate(
153 &self,
154 id: CandidateId,
155 ) -> Option<MoveCandidateRef<'_, S, ScalarMoveUnion<S, usize>>> {
156 self.store.candidate(id)
157 }
158
159 fn take_candidate(&mut self, id: CandidateId) -> ScalarMoveUnion<S, usize> {
160 self.store.take_candidate(id)
161 }
162}
163
164impl<S> MoveSelector<S, ScalarMoveUnion<S, usize>> for ConflictRepairSelector<S>
165where
166 S: PlanningSolution + 'static,
167{
168 type Cursor<'a>
169 = ConflictRepairCursor<S>
170 where
171 Self: 'a;
172
173 fn open_cursor<'a, D: solverforge_scoring::Director<S>>(
174 &'a self,
175 score_director: &D,
176 ) -> Self::Cursor<'a> {
177 self.validate_constraint_hardness(score_director);
178 let solution = score_director.working_solution();
179 let limits = self.limits();
180 let mut store = CandidateStore::with_capacity(self.config.max_moves_per_step);
181 let mut seen = HashSet::new();
182
183 for constraint_name in &self.config.constraints {
184 for repair in self
185 .repairs
186 .iter()
187 .filter(|repair| repair.constraint_name() == constraint_name)
188 {
189 for spec in (repair.provider())(solution, limits)
190 .into_iter()
191 .take(self.config.max_repairs_per_match)
192 {
193 if store.len() >= self.config.max_moves_per_step {
194 return ConflictRepairCursor::new(store);
195 }
196 if spec.edits().is_empty()
197 || spec_has_duplicate_scalar_targets(spec.edits())
198 || !seen.insert(spec.clone())
199 {
200 continue;
201 }
202 let mut edits = Vec::with_capacity(spec.edits().len());
203 let mut legal = true;
204 for edit in spec.edits() {
205 let Some(ctx) = self.variable_for_edit(edit) else {
206 legal = false;
207 break;
208 };
209 if !ctx.value_is_legal(solution, edit.entity_index(), edit.to_value()) {
210 legal = false;
211 break;
212 }
213 edits.push(CompoundScalarEdit {
214 descriptor_index: ctx.descriptor_index,
215 entity_index: edit.entity_index(),
216 variable_index: ctx.variable_index,
217 variable_name: ctx.variable_name,
218 to_value: edit.to_value(),
219 getter: ctx.getter,
220 setter: ctx.setter,
221 value_is_legal: None,
222 });
223 }
224 if legal {
225 let mov = CompoundScalarMove::with_label(
226 spec.reason(),
227 "conflict_repair",
228 edits,
229 )
230 .with_require_hard_improvement(self.config.require_hard_improvement);
231 store.push(ScalarMoveUnion::CompoundScalar(mov));
232 }
233 }
234 }
235 }
236
237 ConflictRepairCursor::new(store)
238 }
239
240 fn size<D: solverforge_scoring::Director<S>>(&self, _score_director: &D) -> usize {
241 self.config.max_moves_per_step
242 }
243}
244
245fn spec_has_duplicate_scalar_targets<S>(edits: &[ScalarEdit<S>]) -> bool {
246 let mut targets = HashSet::new();
247 edits.iter().any(|edit| {
248 !targets.insert((
249 edit.descriptor_index(),
250 edit.entity_index(),
251 edit.variable_name(),
252 ))
253 })
254}