solverforge_solver/builder/selector/
conflict_repair.rs1use std::collections::HashSet;
2
3use solverforge_config::{
4 CompoundConflictRepairMoveSelectorConfig, ConflictRepairMoveSelectorConfig,
5};
6use solverforge_scoring::ConstraintMetadata;
7
8use crate::builder::context::{
9 ConflictRepairEdit, ConflictRepairLimits, ConflictRepairProviderEntry, ScalarVariableContext,
10};
11use crate::heuristic::r#move::{CompoundScalarEdit, CompoundScalarMove};
12use crate::heuristic::selector::move_selector::CandidateStore;
13
14pub struct ConflictRepairSelector<S> {
15 config: ConflictRepairMoveSelectorConfig,
16 scalar_variables: Vec<ScalarVariableContext<S>>,
17 providers: Vec<ConflictRepairProviderEntry<S>>,
18}
19
20impl<S> ConflictRepairSelector<S> {
21 pub fn new(
22 config: ConflictRepairMoveSelectorConfig,
23 scalar_variables: Vec<ScalarVariableContext<S>>,
24 providers: Vec<ConflictRepairProviderEntry<S>>,
25 ) -> Self {
26 Self {
27 config,
28 scalar_variables,
29 providers,
30 }
31 }
32
33 pub fn new_compound(
34 config: CompoundConflictRepairMoveSelectorConfig,
35 scalar_variables: Vec<ScalarVariableContext<S>>,
36 providers: Vec<ConflictRepairProviderEntry<S>>,
37 ) -> Self {
38 Self {
39 config: ConflictRepairMoveSelectorConfig {
40 constraints: config.constraints,
41 max_matches_per_step: config.max_matches_per_step,
42 max_repairs_per_match: config.max_repairs_per_match,
43 max_moves_per_step: config.max_moves_per_step,
44 require_hard_improvement: config.require_hard_improvement,
45 include_soft_matches: config.include_soft_matches,
46 },
47 scalar_variables,
48 providers,
49 }
50 }
51
52 fn limits(&self) -> ConflictRepairLimits {
53 ConflictRepairLimits {
54 max_matches_per_step: self.config.max_matches_per_step,
55 max_repairs_per_match: self.config.max_repairs_per_match,
56 max_moves_per_step: self.config.max_moves_per_step,
57 }
58 }
59
60 fn variable_for_edit(&self, edit: &ConflictRepairEdit) -> Option<ScalarVariableContext<S>> {
61 self.scalar_variables.iter().copied().find(|ctx| {
62 ctx.descriptor_index == edit.descriptor_index && ctx.variable_name == edit.variable_name
63 })
64 }
65
66 fn validate_constraint_hardness<D>(&self, score_director: &D)
67 where
68 S: PlanningSolution,
69 D: solverforge_scoring::Director<S>,
70 {
71 for constraint_name in &self.config.constraints {
72 let metadata = score_director.constraint_metadata();
73 let Some(metadata) = resolve_configured_constraint(&metadata, constraint_name) else {
74 panic!(
75 "conflict_repair_move_selector configured for `{constraint_name}`, but no matching scoring constraint was found"
76 );
77 };
78 assert!(
79 metadata.is_hard || self.config.include_soft_matches,
80 "conflict_repair_move_selector configured for non-hard constraint `{constraint_name}` while include_soft_matches is false"
81 );
82 }
83 }
84}
85
86fn resolve_configured_constraint<'metadata, 'constraint>(
87 metadata: &'metadata [ConstraintMetadata<'constraint>],
88 constraint_name: &str,
89) -> Option<&'metadata ConstraintMetadata<'constraint>> {
90 metadata
91 .iter()
92 .find(|metadata| metadata.full_name() == constraint_name)
93 .or_else(|| {
94 if constraint_name.contains('/') {
95 None
96 } else {
97 metadata.iter().find(|metadata| {
98 metadata.constraint_ref.package.is_empty() && metadata.name() == constraint_name
99 })
100 }
101 })
102}
103
104impl<S> std::fmt::Debug for ConflictRepairSelector<S> {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 f.debug_struct("ConflictRepairSelector")
107 .field("constraints", &self.config.constraints)
108 .field("max_matches_per_step", &self.config.max_matches_per_step)
109 .field("max_repairs_per_match", &self.config.max_repairs_per_match)
110 .field("max_moves_per_step", &self.config.max_moves_per_step)
111 .field(
112 "require_hard_improvement",
113 &self.config.require_hard_improvement,
114 )
115 .finish()
116 }
117}
118
119pub struct ConflictRepairCursor<S>
120where
121 S: PlanningSolution + 'static,
122{
123 store: CandidateStore<S, ScalarMoveUnion<S, usize>>,
124 next_index: usize,
125}
126
127impl<S> ConflictRepairCursor<S>
128where
129 S: PlanningSolution + 'static,
130{
131 fn new(store: CandidateStore<S, ScalarMoveUnion<S, usize>>) -> Self {
132 Self {
133 store,
134 next_index: 0,
135 }
136 }
137}
138
139impl<S> MoveCursor<S, ScalarMoveUnion<S, usize>> for ConflictRepairCursor<S>
140where
141 S: PlanningSolution + 'static,
142{
143 fn next_candidate(&mut self) -> Option<CandidateId> {
144 if self.next_index >= self.store.len() {
145 return None;
146 }
147 let id = CandidateId::new(self.next_index);
148 self.next_index += 1;
149 Some(id)
150 }
151
152 fn candidate(
153 &self,
154 id: CandidateId,
155 ) -> Option<MoveCandidateRef<'_, S, ScalarMoveUnion<S, usize>>> {
156 self.store.candidate(id)
157 }
158
159 fn take_candidate(&mut self, id: CandidateId) -> ScalarMoveUnion<S, usize> {
160 self.store.take_candidate(id)
161 }
162}
163
164impl<S> MoveSelector<S, ScalarMoveUnion<S, usize>> for ConflictRepairSelector<S>
165where
166 S: PlanningSolution + 'static,
167{
168 type Cursor<'a>
169 = ConflictRepairCursor<S>
170 where
171 Self: 'a;
172
173 fn open_cursor<'a, D: solverforge_scoring::Director<S>>(
174 &'a self,
175 score_director: &D,
176 ) -> Self::Cursor<'a> {
177 self.validate_constraint_hardness(score_director);
178 let solution = score_director.working_solution();
179 let limits = self.limits();
180 let mut store = CandidateStore::with_capacity(self.config.max_moves_per_step);
181 let mut seen = HashSet::new();
182
183 for constraint_name in &self.config.constraints {
184 for provider in self
185 .providers
186 .iter()
187 .filter(|provider| provider.constraint_name == constraint_name)
188 {
189 for spec in (provider.provider)(solution, limits)
190 .into_iter()
191 .take(self.config.max_repairs_per_match)
192 {
193 if store.len() >= self.config.max_moves_per_step {
194 return ConflictRepairCursor::new(store);
195 }
196 if spec.edits.is_empty()
197 || spec_has_duplicate_scalar_targets(&spec.edits)
198 || !seen.insert(spec.clone())
199 {
200 continue;
201 }
202 let mut edits = Vec::with_capacity(spec.edits.len());
203 let mut legal = true;
204 for edit in &spec.edits {
205 let Some(ctx) = self.variable_for_edit(edit) else {
206 legal = false;
207 break;
208 };
209 if !ctx.value_is_legal(solution, edit.entity_index, edit.to_value) {
210 legal = false;
211 break;
212 }
213 edits.push(CompoundScalarEdit {
214 descriptor_index: ctx.descriptor_index,
215 entity_index: edit.entity_index,
216 variable_index: ctx.variable_index,
217 variable_name: ctx.variable_name,
218 to_value: edit.to_value,
219 getter: ctx.getter,
220 setter: ctx.setter,
221 value_is_legal: None,
222 });
223 }
224 if legal {
225 let mov = CompoundScalarMove::with_label(
226 spec.reason,
227 "conflict_repair",
228 edits,
229 )
230 .with_require_hard_improvement(self.config.require_hard_improvement);
231 store.push(ScalarMoveUnion::CompoundScalar(mov));
232 }
233 }
234 }
235 }
236
237 ConflictRepairCursor::new(store)
238 }
239
240 fn size<D: solverforge_scoring::Director<S>>(&self, _score_director: &D) -> usize {
241 self.config.max_moves_per_step
242 }
243}
244
245fn spec_has_duplicate_scalar_targets(edits: &[ConflictRepairEdit]) -> bool {
246 let mut targets = HashSet::new();
247 edits
248 .iter()
249 .any(|edit| !targets.insert((edit.descriptor_index, edit.entity_index, edit.variable_name)))
250}