solverforge_solver/builder/selector/
conflict_repair.rs1use std::collections::HashSet;
2
3use solverforge_config::{
4 CompoundConflictRepairMoveSelectorConfig, ConflictRepairMoveSelectorConfig,
5};
6use solverforge_scoring::ConstraintMetadata;
7
8use crate::builder::context::{
9 ConflictRepairEdit, ConflictRepairLimits, ConflictRepairProviderEntry, ScalarVariableContext,
10};
11use crate::heuristic::r#move::{CompoundScalarEdit, CompoundScalarMove};
12use crate::heuristic::selector::move_selector::CandidateStore;
13
14pub struct ConflictRepairSelector<S> {
15 config: ConflictRepairMoveSelectorConfig,
16 scalar_variables: Vec<ScalarVariableContext<S>>,
17 providers: Vec<ConflictRepairProviderEntry<S>>,
18}
19
20impl<S> ConflictRepairSelector<S> {
21 pub fn new(
22 config: ConflictRepairMoveSelectorConfig,
23 scalar_variables: Vec<ScalarVariableContext<S>>,
24 providers: Vec<ConflictRepairProviderEntry<S>>,
25 ) -> Self {
26 Self {
27 config,
28 scalar_variables,
29 providers,
30 }
31 }
32
33 pub fn new_compound(
34 config: CompoundConflictRepairMoveSelectorConfig,
35 scalar_variables: Vec<ScalarVariableContext<S>>,
36 providers: Vec<ConflictRepairProviderEntry<S>>,
37 ) -> Self {
38 Self {
39 config: ConflictRepairMoveSelectorConfig {
40 constraints: config.constraints,
41 max_matches_per_step: config.max_matches_per_step,
42 max_repairs_per_match: config.max_repairs_per_match,
43 max_moves_per_step: config.max_moves_per_step,
44 require_hard_improvement: config.require_hard_improvement,
45 include_soft_matches: config.include_soft_matches,
46 },
47 scalar_variables,
48 providers,
49 }
50 }
51
52 fn limits(&self) -> ConflictRepairLimits {
53 ConflictRepairLimits {
54 max_matches_per_step: self.config.max_matches_per_step,
55 max_repairs_per_match: self.config.max_repairs_per_match,
56 max_moves_per_step: self.config.max_moves_per_step,
57 }
58 }
59
60 fn variable_for_edit(&self, edit: &ConflictRepairEdit) -> Option<ScalarVariableContext<S>> {
61 self.scalar_variables.iter().copied().find(|ctx| {
62 ctx.descriptor_index == edit.descriptor_index && ctx.variable_name == edit.variable_name
63 })
64 }
65
66 fn validate_constraint_hardness<D>(&self, score_director: &D)
67 where
68 S: PlanningSolution,
69 D: solverforge_scoring::Director<S>,
70 {
71 for constraint_name in &self.config.constraints {
72 let Some(metadata) =
73 resolve_configured_constraint(score_director.constraint_metadata(), constraint_name)
74 else {
75 panic!(
76 "conflict_repair_move_selector configured for `{constraint_name}`, but no matching scoring constraint was found"
77 );
78 };
79 assert!(
80 metadata.is_hard || self.config.include_soft_matches,
81 "conflict_repair_move_selector configured for non-hard constraint `{constraint_name}` while include_soft_matches is false"
82 );
83 }
84 }
85}
86
87fn resolve_configured_constraint<'a>(
88 metadata: &'a [ConstraintMetadata],
89 constraint_name: &str,
90) -> Option<&'a ConstraintMetadata> {
91 metadata
92 .iter()
93 .find(|metadata| metadata.full_name() == constraint_name)
94 .or_else(|| {
95 if constraint_name.contains('/') {
96 None
97 } else {
98 metadata.iter().find(|metadata| {
99 metadata.constraint_ref.package.is_empty() && metadata.name() == constraint_name
100 })
101 }
102 })
103}
104
105impl<S> std::fmt::Debug for ConflictRepairSelector<S> {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 f.debug_struct("ConflictRepairSelector")
108 .field("constraints", &self.config.constraints)
109 .field("max_matches_per_step", &self.config.max_matches_per_step)
110 .field("max_repairs_per_match", &self.config.max_repairs_per_match)
111 .field("max_moves_per_step", &self.config.max_moves_per_step)
112 .field(
113 "require_hard_improvement",
114 &self.config.require_hard_improvement,
115 )
116 .finish()
117 }
118}
119
120pub struct ConflictRepairCursor<S>
121where
122 S: PlanningSolution + 'static,
123{
124 store: CandidateStore<S, ScalarMoveUnion<S, usize>>,
125 next_index: usize,
126}
127
128impl<S> ConflictRepairCursor<S>
129where
130 S: PlanningSolution + 'static,
131{
132 fn new(store: CandidateStore<S, ScalarMoveUnion<S, usize>>) -> Self {
133 Self {
134 store,
135 next_index: 0,
136 }
137 }
138}
139
140impl<S> MoveCursor<S, ScalarMoveUnion<S, usize>> for ConflictRepairCursor<S>
141where
142 S: PlanningSolution + 'static,
143{
144 fn next_candidate(&mut self) -> Option<CandidateId> {
145 if self.next_index >= self.store.len() {
146 return None;
147 }
148 let id = CandidateId::new(self.next_index);
149 self.next_index += 1;
150 Some(id)
151 }
152
153 fn candidate(
154 &self,
155 id: CandidateId,
156 ) -> Option<MoveCandidateRef<'_, S, ScalarMoveUnion<S, usize>>> {
157 self.store.candidate(id)
158 }
159
160 fn take_candidate(&mut self, id: CandidateId) -> ScalarMoveUnion<S, usize> {
161 self.store.take_candidate(id)
162 }
163}
164
165impl<S> MoveSelector<S, ScalarMoveUnion<S, usize>> for ConflictRepairSelector<S>
166where
167 S: PlanningSolution + 'static,
168{
169 type Cursor<'a>
170 = ConflictRepairCursor<S>
171 where
172 Self: 'a;
173
174 fn open_cursor<'a, D: solverforge_scoring::Director<S>>(
175 &'a self,
176 score_director: &D,
177 ) -> Self::Cursor<'a> {
178 self.validate_constraint_hardness(score_director);
179 let solution = score_director.working_solution();
180 let limits = self.limits();
181 let mut store = CandidateStore::with_capacity(self.config.max_moves_per_step);
182 let mut seen = HashSet::new();
183
184 for constraint_name in &self.config.constraints {
185 for provider in self
186 .providers
187 .iter()
188 .filter(|provider| provider.constraint_name == constraint_name)
189 {
190 for spec in (provider.provider)(solution, limits)
191 .into_iter()
192 .take(self.config.max_repairs_per_match)
193 {
194 if store.len() >= self.config.max_moves_per_step {
195 return ConflictRepairCursor::new(store);
196 }
197 if spec.edits.is_empty()
198 || spec_has_duplicate_scalar_targets(&spec.edits)
199 || !seen.insert(spec.clone())
200 {
201 continue;
202 }
203 let mut edits = Vec::with_capacity(spec.edits.len());
204 let mut legal = true;
205 for edit in &spec.edits {
206 let Some(ctx) = self.variable_for_edit(edit) else {
207 legal = false;
208 break;
209 };
210 if !ctx.value_is_legal(solution, edit.entity_index, edit.to_value) {
211 legal = false;
212 break;
213 }
214 edits.push(CompoundScalarEdit {
215 descriptor_index: ctx.descriptor_index,
216 entity_index: edit.entity_index,
217 variable_index: ctx.variable_index,
218 variable_name: ctx.variable_name,
219 to_value: edit.to_value,
220 getter: ctx.getter,
221 setter: ctx.setter,
222 value_is_legal: None,
223 });
224 }
225 if legal {
226 let mov = CompoundScalarMove::with_label(
227 spec.reason,
228 "conflict_repair",
229 edits,
230 )
231 .with_require_hard_improvement(self.config.require_hard_improvement);
232 store.push(ScalarMoveUnion::CompoundScalar(mov));
233 }
234 }
235 }
236 }
237
238 ConflictRepairCursor::new(store)
239 }
240
241 fn size<D: solverforge_scoring::Director<S>>(&self, _score_director: &D) -> usize {
242 self.config.max_moves_per_step
243 }
244}
245
246fn spec_has_duplicate_scalar_targets(edits: &[ConflictRepairEdit]) -> bool {
247 let mut targets = HashSet::new();
248 edits
249 .iter()
250 .any(|edit| !targets.insert((edit.descriptor_index, edit.entity_index, edit.variable_name)))
251}