Skip to main content

solverforge_solver/builder/selector/conflict_repair.rs

1use std::collections::HashSet;
2
3use solverforge_config::{
4    CompoundConflictRepairMoveSelectorConfig, ConflictRepairMoveSelectorConfig,
5};
6use solverforge_scoring::ConstraintMetadata;
7
8use crate::builder::context::{ConflictRepair, RepairLimits, ScalarVariableSlot};
9use crate::heuristic::r#move::{CompoundScalarEdit, CompoundScalarMove};
10use crate::heuristic::selector::move_selector::CandidateStore;
11use crate::planning::ScalarEdit;
12
/// Move selector that turns constraint-specific repair suggestions into compound scalar moves.
///
/// For each constraint named in the configuration, the `ConflictRepair` providers registered
/// for that constraint are asked for repair specs against the working solution; each accepted
/// spec becomes one compound scalar move.
pub struct ConflictRepairSelector<S> {
    /// Constraint names, per-step limits and acceptance flags for this selector.
    config: ConflictRepairMoveSelectorConfig,
    /// Scalar planning variables that repair edits are allowed to target.
    scalar_variables: Vec<ScalarVariableSlot<S>>,
    /// Registered repair providers, each reporting the constraint name it repairs.
    repairs: Vec<ConflictRepair<S>>,
}
18
19impl<S> ConflictRepairSelector<S> {
20    pub fn new(
21        config: ConflictRepairMoveSelectorConfig,
22        scalar_variables: Vec<ScalarVariableSlot<S>>,
23        repairs: Vec<ConflictRepair<S>>,
24    ) -> Self {
25        Self {
26            config,
27            scalar_variables,
28            repairs,
29        }
30    }
31
32    pub fn new_compound(
33        config: CompoundConflictRepairMoveSelectorConfig,
34        scalar_variables: Vec<ScalarVariableSlot<S>>,
35        repairs: Vec<ConflictRepair<S>>,
36    ) -> Self {
37        Self {
38            config: ConflictRepairMoveSelectorConfig {
39                constraints: config.constraints,
40                max_matches_per_step: config.max_matches_per_step,
41                max_repairs_per_match: config.max_repairs_per_match,
42                max_moves_per_step: config.max_moves_per_step,
43                require_hard_improvement: config.require_hard_improvement,
44                include_soft_matches: config.include_soft_matches,
45            },
46            scalar_variables,
47            repairs,
48        }
49    }
50
51    fn limits(&self) -> RepairLimits {
52        RepairLimits {
53            max_matches_per_step: self.config.max_matches_per_step,
54            max_repairs_per_match: self.config.max_repairs_per_match,
55            max_moves_per_step: self.config.max_moves_per_step,
56        }
57    }
58
59    fn variable_for_edit(&self, edit: &ScalarEdit<S>) -> Option<ScalarVariableSlot<S>> {
60        self.scalar_variables.iter().copied().find(|ctx| {
61            ctx.descriptor_index == edit.descriptor_index()
62                && ctx.variable_name == edit.variable_name()
63        })
64    }
65
66    fn validate_constraint_hardness<D>(&self, score_director: &D)
67    where
68        S: PlanningSolution,
69        D: solverforge_scoring::Director<S>,
70    {
71        for constraint_name in &self.config.constraints {
72            let metadata = score_director.constraint_metadata();
73            let Some(metadata) = resolve_configured_constraint(&metadata, constraint_name) else {
74                panic!(
75                    "conflict_repair_move_selector configured for `{constraint_name}`, but no matching scoring constraint was found"
76                );
77            };
78            assert!(
79                metadata.is_hard || self.config.include_soft_matches,
80                "conflict_repair_move_selector configured for non-hard constraint `{constraint_name}` while include_soft_matches is false"
81            );
82        }
83    }
84}
85
86fn resolve_configured_constraint<'metadata, 'constraint>(
87    metadata: &'metadata [ConstraintMetadata<'constraint>],
88    constraint_name: &str,
89) -> Option<&'metadata ConstraintMetadata<'constraint>> {
90    metadata
91        .iter()
92        .find(|metadata| metadata.full_name() == constraint_name)
93        .or_else(|| {
94            if constraint_name.contains('/') {
95                None
96            } else {
97                metadata.iter().find(|metadata| {
98                    metadata.constraint_ref.package.is_empty() && metadata.name() == constraint_name
99                })
100            }
101        })
102}
103
104impl<S> std::fmt::Debug for ConflictRepairSelector<S> {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        f.debug_struct("ConflictRepairSelector")
107            .field("constraints", &self.config.constraints)
108            .field("max_matches_per_step", &self.config.max_matches_per_step)
109            .field("max_repairs_per_match", &self.config.max_repairs_per_match)
110            .field("max_moves_per_step", &self.config.max_moves_per_step)
111            .field(
112                "require_hard_improvement",
113                &self.config.require_hard_improvement,
114            )
115            .finish()
116    }
117}
118
/// Cursor over the repair moves that `ConflictRepairSelector` builds when a cursor is opened.
///
/// Candidate ids are handed out in store order, one per stored move.
pub struct ConflictRepairCursor<S>
where
    S: PlanningSolution + 'static,
{
    /// Moves materialized when the cursor was opened.
    store: CandidateStore<S, ScalarMoveUnion<S, usize>>,
    /// Index of the next candidate id `next_candidate` will yield.
    next_index: usize,
}
126
127impl<S> ConflictRepairCursor<S>
128where
129    S: PlanningSolution + 'static,
130{
131    fn new(store: CandidateStore<S, ScalarMoveUnion<S, usize>>) -> Self {
132        Self {
133            store,
134            next_index: 0,
135        }
136    }
137}
138
139impl<S> MoveCursor<S, ScalarMoveUnion<S, usize>> for ConflictRepairCursor<S>
140where
141    S: PlanningSolution + 'static,
142{
143    fn next_candidate(&mut self) -> Option<CandidateId> {
144        if self.next_index >= self.store.len() {
145            return None;
146        }
147        let id = CandidateId::new(self.next_index);
148        self.next_index += 1;
149        Some(id)
150    }
151
152    fn candidate(
153        &self,
154        id: CandidateId,
155    ) -> Option<MoveCandidateRef<'_, S, ScalarMoveUnion<S, usize>>> {
156        self.store.candidate(id)
157    }
158
159    fn take_candidate(&mut self, id: CandidateId) -> ScalarMoveUnion<S, usize> {
160        self.store.take_candidate(id)
161    }
162}
163
164impl<S> MoveSelector<S, ScalarMoveUnion<S, usize>> for ConflictRepairSelector<S>
165where
166    S: PlanningSolution + 'static,
167{
168    type Cursor<'a>
169        = ConflictRepairCursor<S>
170    where
171        Self: 'a;
172
173    fn open_cursor<'a, D: solverforge_scoring::Director<S>>(
174        &'a self,
175        score_director: &D,
176    ) -> Self::Cursor<'a> {
177        self.validate_constraint_hardness(score_director);
178        let solution = score_director.working_solution();
179        let limits = self.limits();
180        let mut store = CandidateStore::with_capacity(self.config.max_moves_per_step);
181        let mut seen = HashSet::new();
182
183        for constraint_name in &self.config.constraints {
184            for repair in self
185                .repairs
186                .iter()
187                .filter(|repair| repair.constraint_name() == constraint_name)
188            {
189                for spec in (repair.provider())(solution, limits)
190                    .into_iter()
191                    .take(self.config.max_repairs_per_match)
192                {
193                    if store.len() >= self.config.max_moves_per_step {
194                        return ConflictRepairCursor::new(store);
195                    }
196                    if spec.edits().is_empty()
197                        || spec_has_duplicate_scalar_targets(spec.edits())
198                        || !seen.insert(spec.clone())
199                    {
200                        continue;
201                    }
202                    let mut edits = Vec::with_capacity(spec.edits().len());
203                    let mut legal = true;
204                    for edit in spec.edits() {
205                        let Some(ctx) = self.variable_for_edit(edit) else {
206                            legal = false;
207                            break;
208                        };
209                        if !ctx.value_is_legal(solution, edit.entity_index(), edit.to_value()) {
210                            legal = false;
211                            break;
212                        }
213                        edits.push(CompoundScalarEdit {
214                            descriptor_index: ctx.descriptor_index,
215                            entity_index: edit.entity_index(),
216                            variable_index: ctx.variable_index,
217                            variable_name: ctx.variable_name,
218                            to_value: edit.to_value(),
219                            getter: ctx.getter,
220                            setter: ctx.setter,
221                            value_is_legal: None,
222                        });
223                    }
224                    if legal {
225                        let mov = CompoundScalarMove::with_label(
226                            spec.reason(),
227                            "conflict_repair",
228                            edits,
229                        )
230                        .with_require_hard_improvement(self.config.require_hard_improvement);
231                        store.push(ScalarMoveUnion::CompoundScalar(mov));
232                    }
233                }
234            }
235        }
236
237        ConflictRepairCursor::new(store)
238    }
239
240    fn size<D: solverforge_scoring::Director<S>>(&self, _score_director: &D) -> usize {
241        self.config.max_moves_per_step
242    }
243}
244
245fn spec_has_duplicate_scalar_targets<S>(edits: &[ScalarEdit<S>]) -> bool {
246    let mut targets = HashSet::new();
247    edits.iter().any(|edit| {
248        !targets.insert((
249            edit.descriptor_index(),
250            edit.entity_index(),
251            edit.variable_name(),
252        ))
253    })
254}