solverforge_scoring/constraint/
bi_incremental.rs1use std::collections::{HashMap, HashSet};
7use std::fmt::Debug;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::analysis::DetailedConstraintMatch;
15use crate::api::constraint_set::IncrementalConstraint;
16
17pub struct IncrementalBiConstraint<S, A, K, E, KE, F, W, Sc>
22where
23 Sc: Score,
24{
25 constraint_ref: ConstraintRef,
26 impact_type: ImpactType,
27 extractor: E,
28 key_extractor: KE,
29 filter: F,
30 weight: W,
31 is_hard: bool,
32 entity_to_matches: HashMap<usize, HashSet<(usize, usize)>>,
34 matches: HashSet<(usize, usize)>,
36 key_to_indices: HashMap<K, HashSet<usize>>,
38 index_to_key: HashMap<usize, K>,
40 _phantom: PhantomData<(S, A, Sc)>,
41}
42
43impl<S, A, K, E, KE, F, W, Sc> IncrementalBiConstraint<S, A, K, E, KE, F, W, Sc>
44where
45 S: 'static,
46 A: Clone + 'static,
47 K: Eq + Hash + Clone,
48 E: Fn(&S) -> &[A],
49 KE: Fn(&A) -> K,
50 F: Fn(&A, &A) -> bool,
51 W: Fn(&A, &A) -> Sc,
52 Sc: Score,
53{
54 pub fn new(
55 constraint_ref: ConstraintRef,
56 impact_type: ImpactType,
57 extractor: E,
58 key_extractor: KE,
59 filter: F,
60 weight: W,
61 is_hard: bool,
62 ) -> Self {
63 Self {
64 constraint_ref,
65 impact_type,
66 extractor,
67 key_extractor,
68 filter,
69 weight,
70 is_hard,
71 entity_to_matches: HashMap::new(),
72 matches: HashSet::new(),
73 key_to_indices: HashMap::new(),
74 index_to_key: HashMap::new(),
75 _phantom: PhantomData,
76 }
77 }
78
79 #[inline]
80 fn compute_score(&self, a: &A, b: &A) -> Sc {
81 let base = (self.weight)(a, b);
82 match self.impact_type {
83 ImpactType::Penalty => -base,
84 ImpactType::Reward => base,
85 }
86 }
87
88 fn insert_entity(&mut self, entities: &[A], index: usize) -> Sc {
90 if index >= entities.len() {
91 return Sc::zero();
92 }
93
94 let entity = &entities[index];
95 let key = (self.key_extractor)(entity);
96
97 self.index_to_key.insert(index, key.clone());
99
100 self.key_to_indices
102 .entry(key.clone())
103 .or_default()
104 .insert(index);
105
106 let key_to_indices = &self.key_to_indices;
108 let matches = &mut self.matches;
109 let entity_to_matches = &mut self.entity_to_matches;
110 let filter = &self.filter;
111 let weight = &self.weight;
112 let impact_type = self.impact_type;
113
114 let mut total = Sc::zero();
116 if let Some(others) = key_to_indices.get(&key) {
117 for &other_idx in others {
118 if other_idx == index {
119 continue;
120 }
121
122 let other = &entities[other_idx];
123
124 let (low_idx, high_idx, low_entity, high_entity) = if index < other_idx {
126 (index, other_idx, entity, other)
127 } else {
128 (other_idx, index, other, entity)
129 };
130
131 if filter(low_entity, high_entity) {
132 let pair = (low_idx, high_idx);
133 if matches.insert(pair) {
134 entity_to_matches.entry(low_idx).or_default().insert(pair);
135 entity_to_matches.entry(high_idx).or_default().insert(pair);
136 let base = weight(low_entity, high_entity);
137 let score = match impact_type {
138 ImpactType::Penalty => -base,
139 ImpactType::Reward => base,
140 };
141 total = total + score;
142 }
143 }
144 }
145 }
146
147 total
148 }
149
150 fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
152 if let Some(key) = self.index_to_key.remove(&index) {
154 if let Some(indices) = self.key_to_indices.get_mut(&key) {
155 indices.remove(&index);
156 if indices.is_empty() {
157 self.key_to_indices.remove(&key);
158 }
159 }
160 }
161
162 let Some(pairs) = self.entity_to_matches.remove(&index) else {
164 return Sc::zero();
165 };
166
167 let mut total = Sc::zero();
168 for pair in pairs {
169 self.matches.remove(&pair);
170
171 let other = if pair.0 == index { pair.1 } else { pair.0 };
173 if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
174 other_set.remove(&pair);
175 if other_set.is_empty() {
176 self.entity_to_matches.remove(&other);
177 }
178 }
179
180 let (low_idx, high_idx) = pair;
182 if low_idx < entities.len() && high_idx < entities.len() {
183 let score = self.compute_score(&entities[low_idx], &entities[high_idx]);
184 total = total - score;
185 }
186 }
187
188 total
189 }
190}
191
192impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
193 for IncrementalBiConstraint<S, A, K, E, KE, F, W, Sc>
194where
195 S: Send + Sync + 'static,
196 A: Clone + Debug + Send + Sync + 'static,
197 K: Eq + Hash + Clone + Send + Sync,
198 E: Fn(&S) -> &[A] + Send + Sync,
199 KE: Fn(&A) -> K + Send + Sync,
200 F: Fn(&A, &A) -> bool + Send + Sync,
201 W: Fn(&A, &A) -> Sc + Send + Sync,
202 Sc: Score,
203{
204 fn evaluate(&self, solution: &S) -> Sc {
205 let entities = (self.extractor)(solution);
206 let mut total = Sc::zero();
207
208 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
210 for (i, entity) in entities.iter().enumerate() {
211 let key = (self.key_extractor)(entity);
212 temp_index.entry(key).or_default().push(i);
213 }
214
215 for indices in temp_index.values() {
217 for i in 0..indices.len() {
218 for j in (i + 1)..indices.len() {
219 let low = indices[i];
220 let high = indices[j];
221 let a = &entities[low];
222 let b = &entities[high];
223 if (self.filter)(a, b) {
224 total = total + self.compute_score(a, b);
225 }
226 }
227 }
228 }
229
230 total
231 }
232
233 fn match_count(&self, solution: &S) -> usize {
234 let entities = (self.extractor)(solution);
235 let mut count = 0;
236
237 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
239 for (i, entity) in entities.iter().enumerate() {
240 let key = (self.key_extractor)(entity);
241 temp_index.entry(key).or_default().push(i);
242 }
243
244 for indices in temp_index.values() {
246 for i in 0..indices.len() {
247 for j in (i + 1)..indices.len() {
248 let low = indices[i];
249 let high = indices[j];
250 if (self.filter)(&entities[low], &entities[high]) {
251 count += 1;
252 }
253 }
254 }
255 }
256
257 count
258 }
259
260 fn initialize(&mut self, solution: &S) -> Sc {
261 self.reset();
262
263 let entities = (self.extractor)(solution);
264 let mut total = Sc::zero();
265 for i in 0..entities.len() {
266 total = total + self.insert_entity(entities, i);
267 }
268 total
269 }
270
271 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
272 let entities = (self.extractor)(solution);
273 self.insert_entity(entities, entity_index)
274 }
275
276 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
277 let entities = (self.extractor)(solution);
278 self.retract_entity(entities, entity_index)
279 }
280
281 fn reset(&mut self) {
282 self.entity_to_matches.clear();
283 self.matches.clear();
284 self.key_to_indices.clear();
285 self.index_to_key.clear();
286 }
287
288 fn name(&self) -> &str {
289 &self.constraint_ref.name
290 }
291
292 fn is_hard(&self) -> bool {
293 self.is_hard
294 }
295
296 fn constraint_ref(&self) -> ConstraintRef {
297 self.constraint_ref.clone()
298 }
299
300 fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
301 impl_get_matches_nary!(bi: self, solution)
302 }
303}
304
305impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
306 for IncrementalBiConstraint<S, A, K, E, KE, F, W, Sc>
307{
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 f.debug_struct("IncrementalBiConstraint")
310 .field("name", &self.constraint_ref.name)
311 .field("impact_type", &self.impact_type)
312 .field("match_count", &self.matches.len())
313 .finish()
314 }
315}