Skip to main content

solverforge_solver/builder/selector/conflict_repair.rs

1use std::collections::HashSet;
2
3use solverforge_config::{
4    CompoundConflictRepairMoveSelectorConfig, ConflictRepairMoveSelectorConfig,
5};
6use solverforge_scoring::ConstraintMetadata;
7
8use crate::builder::context::{
9    ConflictRepairEdit, ConflictRepairLimits, ConflictRepairProviderEntry, ScalarVariableContext,
10};
11use crate::heuristic::r#move::{CompoundScalarEdit, CompoundScalarMove};
12use crate::heuristic::selector::move_selector::CandidateStore;
13
14pub struct ConflictRepairSelector<S> {
15    config: ConflictRepairMoveSelectorConfig,
16    scalar_variables: Vec<ScalarVariableContext<S>>,
17    providers: Vec<ConflictRepairProviderEntry<S>>,
18}
19
20impl<S> ConflictRepairSelector<S> {
21    pub fn new(
22        config: ConflictRepairMoveSelectorConfig,
23        scalar_variables: Vec<ScalarVariableContext<S>>,
24        providers: Vec<ConflictRepairProviderEntry<S>>,
25    ) -> Self {
26        Self {
27            config,
28            scalar_variables,
29            providers,
30        }
31    }
32
33    pub fn new_compound(
34        config: CompoundConflictRepairMoveSelectorConfig,
35        scalar_variables: Vec<ScalarVariableContext<S>>,
36        providers: Vec<ConflictRepairProviderEntry<S>>,
37    ) -> Self {
38        Self {
39            config: ConflictRepairMoveSelectorConfig {
40                constraints: config.constraints,
41                max_matches_per_step: config.max_matches_per_step,
42                max_repairs_per_match: config.max_repairs_per_match,
43                max_moves_per_step: config.max_moves_per_step,
44                require_hard_improvement: config.require_hard_improvement,
45                include_soft_matches: config.include_soft_matches,
46            },
47            scalar_variables,
48            providers,
49        }
50    }
51
52    fn limits(&self) -> ConflictRepairLimits {
53        ConflictRepairLimits {
54            max_matches_per_step: self.config.max_matches_per_step,
55            max_repairs_per_match: self.config.max_repairs_per_match,
56            max_moves_per_step: self.config.max_moves_per_step,
57        }
58    }
59
60    fn variable_for_edit(&self, edit: &ConflictRepairEdit) -> Option<ScalarVariableContext<S>> {
61        self.scalar_variables.iter().copied().find(|ctx| {
62            ctx.descriptor_index == edit.descriptor_index && ctx.variable_name == edit.variable_name
63        })
64    }
65
66    fn validate_constraint_hardness<D>(&self, score_director: &D)
67    where
68        S: PlanningSolution,
69        D: solverforge_scoring::Director<S>,
70    {
71        for constraint_name in &self.config.constraints {
72            let metadata = score_director.constraint_metadata();
73            let Some(metadata) = resolve_configured_constraint(&metadata, constraint_name) else {
74                panic!(
75                    "conflict_repair_move_selector configured for `{constraint_name}`, but no matching scoring constraint was found"
76                );
77            };
78            assert!(
79                metadata.is_hard || self.config.include_soft_matches,
80                "conflict_repair_move_selector configured for non-hard constraint `{constraint_name}` while include_soft_matches is false"
81            );
82        }
83    }
84}
85
86fn resolve_configured_constraint<'metadata, 'constraint>(
87    metadata: &'metadata [ConstraintMetadata<'constraint>],
88    constraint_name: &str,
89) -> Option<&'metadata ConstraintMetadata<'constraint>> {
90    metadata
91        .iter()
92        .find(|metadata| metadata.full_name() == constraint_name)
93        .or_else(|| {
94            if constraint_name.contains('/') {
95                None
96            } else {
97                metadata.iter().find(|metadata| {
98                    metadata.constraint_ref.package.is_empty() && metadata.name() == constraint_name
99                })
100            }
101        })
102}
103
104impl<S> std::fmt::Debug for ConflictRepairSelector<S> {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        f.debug_struct("ConflictRepairSelector")
107            .field("constraints", &self.config.constraints)
108            .field("max_matches_per_step", &self.config.max_matches_per_step)
109            .field("max_repairs_per_match", &self.config.max_repairs_per_match)
110            .field("max_moves_per_step", &self.config.max_moves_per_step)
111            .field(
112                "require_hard_improvement",
113                &self.config.require_hard_improvement,
114            )
115            .finish()
116    }
117}
118
119pub struct ConflictRepairCursor<S>
120where
121    S: PlanningSolution + 'static,
122{
123    store: CandidateStore<S, ScalarMoveUnion<S, usize>>,
124    next_index: usize,
125}
126
127impl<S> ConflictRepairCursor<S>
128where
129    S: PlanningSolution + 'static,
130{
131    fn new(store: CandidateStore<S, ScalarMoveUnion<S, usize>>) -> Self {
132        Self {
133            store,
134            next_index: 0,
135        }
136    }
137}
138
139impl<S> MoveCursor<S, ScalarMoveUnion<S, usize>> for ConflictRepairCursor<S>
140where
141    S: PlanningSolution + 'static,
142{
143    fn next_candidate(&mut self) -> Option<CandidateId> {
144        if self.next_index >= self.store.len() {
145            return None;
146        }
147        let id = CandidateId::new(self.next_index);
148        self.next_index += 1;
149        Some(id)
150    }
151
152    fn candidate(
153        &self,
154        id: CandidateId,
155    ) -> Option<MoveCandidateRef<'_, S, ScalarMoveUnion<S, usize>>> {
156        self.store.candidate(id)
157    }
158
159    fn take_candidate(&mut self, id: CandidateId) -> ScalarMoveUnion<S, usize> {
160        self.store.take_candidate(id)
161    }
162}
163
164impl<S> MoveSelector<S, ScalarMoveUnion<S, usize>> for ConflictRepairSelector<S>
165where
166    S: PlanningSolution + 'static,
167{
168    type Cursor<'a>
169        = ConflictRepairCursor<S>
170    where
171        Self: 'a;
172
173    fn open_cursor<'a, D: solverforge_scoring::Director<S>>(
174        &'a self,
175        score_director: &D,
176    ) -> Self::Cursor<'a> {
177        self.validate_constraint_hardness(score_director);
178        let solution = score_director.working_solution();
179        let limits = self.limits();
180        let mut store = CandidateStore::with_capacity(self.config.max_moves_per_step);
181        let mut seen = HashSet::new();
182
183        for constraint_name in &self.config.constraints {
184            for provider in self
185                .providers
186                .iter()
187                .filter(|provider| provider.constraint_name == constraint_name)
188            {
189                for spec in (provider.provider)(solution, limits)
190                    .into_iter()
191                    .take(self.config.max_repairs_per_match)
192                {
193                    if store.len() >= self.config.max_moves_per_step {
194                        return ConflictRepairCursor::new(store);
195                    }
196                    if spec.edits.is_empty()
197                        || spec_has_duplicate_scalar_targets(&spec.edits)
198                        || !seen.insert(spec.clone())
199                    {
200                        continue;
201                    }
202                    let mut edits = Vec::with_capacity(spec.edits.len());
203                    let mut legal = true;
204                    for edit in &spec.edits {
205                        let Some(ctx) = self.variable_for_edit(edit) else {
206                            legal = false;
207                            break;
208                        };
209                        if !ctx.value_is_legal(solution, edit.entity_index, edit.to_value) {
210                            legal = false;
211                            break;
212                        }
213                        edits.push(CompoundScalarEdit {
214                            descriptor_index: ctx.descriptor_index,
215                            entity_index: edit.entity_index,
216                            variable_index: ctx.variable_index,
217                            variable_name: ctx.variable_name,
218                            to_value: edit.to_value,
219                            getter: ctx.getter,
220                            setter: ctx.setter,
221                            value_is_legal: None,
222                        });
223                    }
224                    if legal {
225                        let mov = CompoundScalarMove::with_label(
226                            spec.reason,
227                            "conflict_repair",
228                            edits,
229                        )
230                        .with_require_hard_improvement(self.config.require_hard_improvement);
231                        store.push(ScalarMoveUnion::CompoundScalar(mov));
232                    }
233                }
234            }
235        }
236
237        ConflictRepairCursor::new(store)
238    }
239
240    fn size<D: solverforge_scoring::Director<S>>(&self, _score_director: &D) -> usize {
241        self.config.max_moves_per_step
242    }
243}
244
245fn spec_has_duplicate_scalar_targets(edits: &[ConflictRepairEdit]) -> bool {
246    let mut targets = HashSet::new();
247    edits
248        .iter()
249        .any(|edit| !targets.insert((edit.descriptor_index, edit.entity_index, edit.variable_name)))
250}