Skip to main content

solverforge_solver/builder/selector/
conflict_repair.rs

1use std::collections::HashSet;
2
3use solverforge_config::{
4    CompoundConflictRepairMoveSelectorConfig, ConflictRepairMoveSelectorConfig,
5};
6use solverforge_scoring::ConstraintMetadata;
7
8use crate::builder::context::{
9    ConflictRepairEdit, ConflictRepairLimits, ConflictRepairProviderEntry, ScalarVariableContext,
10};
11use crate::heuristic::r#move::{CompoundScalarEdit, CompoundScalarMove};
12use crate::heuristic::selector::move_selector::CandidateStore;
13
14pub struct ConflictRepairSelector<S> {
15    config: ConflictRepairMoveSelectorConfig,
16    scalar_variables: Vec<ScalarVariableContext<S>>,
17    providers: Vec<ConflictRepairProviderEntry<S>>,
18}
19
20impl<S> ConflictRepairSelector<S> {
21    pub fn new(
22        config: ConflictRepairMoveSelectorConfig,
23        scalar_variables: Vec<ScalarVariableContext<S>>,
24        providers: Vec<ConflictRepairProviderEntry<S>>,
25    ) -> Self {
26        Self {
27            config,
28            scalar_variables,
29            providers,
30        }
31    }
32
33    pub fn new_compound(
34        config: CompoundConflictRepairMoveSelectorConfig,
35        scalar_variables: Vec<ScalarVariableContext<S>>,
36        providers: Vec<ConflictRepairProviderEntry<S>>,
37    ) -> Self {
38        Self {
39            config: ConflictRepairMoveSelectorConfig {
40                constraints: config.constraints,
41                max_matches_per_step: config.max_matches_per_step,
42                max_repairs_per_match: config.max_repairs_per_match,
43                max_moves_per_step: config.max_moves_per_step,
44                require_hard_improvement: config.require_hard_improvement,
45                include_soft_matches: config.include_soft_matches,
46            },
47            scalar_variables,
48            providers,
49        }
50    }
51
52    fn limits(&self) -> ConflictRepairLimits {
53        ConflictRepairLimits {
54            max_matches_per_step: self.config.max_matches_per_step,
55            max_repairs_per_match: self.config.max_repairs_per_match,
56            max_moves_per_step: self.config.max_moves_per_step,
57        }
58    }
59
60    fn variable_for_edit(&self, edit: &ConflictRepairEdit) -> Option<ScalarVariableContext<S>> {
61        self.scalar_variables.iter().copied().find(|ctx| {
62            ctx.descriptor_index == edit.descriptor_index && ctx.variable_name == edit.variable_name
63        })
64    }
65
66    fn validate_constraint_hardness<D>(&self, score_director: &D)
67    where
68        S: PlanningSolution,
69        D: solverforge_scoring::Director<S>,
70    {
71        for constraint_name in &self.config.constraints {
72            let Some(metadata) =
73                resolve_configured_constraint(score_director.constraint_metadata(), constraint_name)
74            else {
75                panic!(
76                    "conflict_repair_move_selector configured for `{constraint_name}`, but no matching scoring constraint was found"
77                );
78            };
79            assert!(
80                metadata.is_hard || self.config.include_soft_matches,
81                "conflict_repair_move_selector configured for non-hard constraint `{constraint_name}` while include_soft_matches is false"
82            );
83        }
84    }
85}
86
87fn resolve_configured_constraint<'a>(
88    metadata: &'a [ConstraintMetadata],
89    constraint_name: &str,
90) -> Option<&'a ConstraintMetadata> {
91    metadata
92        .iter()
93        .find(|metadata| metadata.full_name() == constraint_name)
94        .or_else(|| {
95            if constraint_name.contains('/') {
96                None
97            } else {
98                metadata.iter().find(|metadata| {
99                    metadata.constraint_ref.package.is_empty() && metadata.name() == constraint_name
100                })
101            }
102        })
103}
104
105impl<S> std::fmt::Debug for ConflictRepairSelector<S> {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        f.debug_struct("ConflictRepairSelector")
108            .field("constraints", &self.config.constraints)
109            .field("max_matches_per_step", &self.config.max_matches_per_step)
110            .field("max_repairs_per_match", &self.config.max_repairs_per_match)
111            .field("max_moves_per_step", &self.config.max_moves_per_step)
112            .field(
113                "require_hard_improvement",
114                &self.config.require_hard_improvement,
115            )
116            .finish()
117    }
118}
119
120pub struct ConflictRepairCursor<S>
121where
122    S: PlanningSolution + 'static,
123{
124    store: CandidateStore<S, ScalarMoveUnion<S, usize>>,
125    next_index: usize,
126}
127
128impl<S> ConflictRepairCursor<S>
129where
130    S: PlanningSolution + 'static,
131{
132    fn new(store: CandidateStore<S, ScalarMoveUnion<S, usize>>) -> Self {
133        Self {
134            store,
135            next_index: 0,
136        }
137    }
138}
139
140impl<S> MoveCursor<S, ScalarMoveUnion<S, usize>> for ConflictRepairCursor<S>
141where
142    S: PlanningSolution + 'static,
143{
144    fn next_candidate(&mut self) -> Option<CandidateId> {
145        if self.next_index >= self.store.len() {
146            return None;
147        }
148        let id = CandidateId::new(self.next_index);
149        self.next_index += 1;
150        Some(id)
151    }
152
153    fn candidate(
154        &self,
155        id: CandidateId,
156    ) -> Option<MoveCandidateRef<'_, S, ScalarMoveUnion<S, usize>>> {
157        self.store.candidate(id)
158    }
159
160    fn take_candidate(&mut self, id: CandidateId) -> ScalarMoveUnion<S, usize> {
161        self.store.take_candidate(id)
162    }
163}
164
165impl<S> MoveSelector<S, ScalarMoveUnion<S, usize>> for ConflictRepairSelector<S>
166where
167    S: PlanningSolution + 'static,
168{
169    type Cursor<'a>
170        = ConflictRepairCursor<S>
171    where
172        Self: 'a;
173
174    fn open_cursor<'a, D: solverforge_scoring::Director<S>>(
175        &'a self,
176        score_director: &D,
177    ) -> Self::Cursor<'a> {
178        self.validate_constraint_hardness(score_director);
179        let solution = score_director.working_solution();
180        let limits = self.limits();
181        let mut store = CandidateStore::with_capacity(self.config.max_moves_per_step);
182        let mut seen = HashSet::new();
183
184        for constraint_name in &self.config.constraints {
185            for provider in self
186                .providers
187                .iter()
188                .filter(|provider| provider.constraint_name == constraint_name)
189            {
190                for spec in (provider.provider)(solution, limits)
191                    .into_iter()
192                    .take(self.config.max_repairs_per_match)
193                {
194                    if store.len() >= self.config.max_moves_per_step {
195                        return ConflictRepairCursor::new(store);
196                    }
197                    if spec.edits.is_empty()
198                        || spec_has_duplicate_scalar_targets(&spec.edits)
199                        || !seen.insert(spec.clone())
200                    {
201                        continue;
202                    }
203                    let mut edits = Vec::with_capacity(spec.edits.len());
204                    let mut legal = true;
205                    for edit in &spec.edits {
206                        let Some(ctx) = self.variable_for_edit(edit) else {
207                            legal = false;
208                            break;
209                        };
210                        if !ctx.value_is_legal(solution, edit.entity_index, edit.to_value) {
211                            legal = false;
212                            break;
213                        }
214                        edits.push(CompoundScalarEdit {
215                            descriptor_index: ctx.descriptor_index,
216                            entity_index: edit.entity_index,
217                            variable_index: ctx.variable_index,
218                            variable_name: ctx.variable_name,
219                            to_value: edit.to_value,
220                            getter: ctx.getter,
221                            setter: ctx.setter,
222                            value_is_legal: None,
223                        });
224                    }
225                    if legal {
226                        let mov = CompoundScalarMove::with_label(
227                            spec.reason,
228                            "conflict_repair",
229                            edits,
230                        )
231                        .with_require_hard_improvement(self.config.require_hard_improvement);
232                        store.push(ScalarMoveUnion::CompoundScalar(mov));
233                    }
234                }
235            }
236        }
237
238        ConflictRepairCursor::new(store)
239    }
240
241    fn size<D: solverforge_scoring::Director<S>>(&self, _score_director: &D) -> usize {
242        self.config.max_moves_per_step
243    }
244}
245
246fn spec_has_duplicate_scalar_targets(edits: &[ConflictRepairEdit]) -> bool {
247    let mut targets = HashSet::new();
248    edits
249        .iter()
250        .any(|edit| !targets.insert((edit.descriptor_index, edit.entity_index, edit.variable_name)))
251}