1use std::collections::{HashMap, HashSet};
7use std::fmt::Debug;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::analysis::DetailedConstraintMatch;
15use crate::api::constraint_set::IncrementalConstraint;
16
/// Five entity indices forming one match, always stored in ascending order.
type Quintuple = (usize, usize, usize, usize, usize);
19
20pub struct IncrementalPentaConstraint<S, A, K, E, KE, F, W, Sc>
79where
80 Sc: Score,
81{
82 constraint_ref: ConstraintRef,
83 impact_type: ImpactType,
84 extractor: E,
85 key_extractor: KE,
86 filter: F,
87 weight: W,
88 is_hard: bool,
89 entity_to_matches: HashMap<usize, HashSet<Quintuple>>,
91 matches: HashSet<Quintuple>,
93 key_to_indices: HashMap<K, HashSet<usize>>,
95 index_to_key: HashMap<usize, K>,
97 _phantom: PhantomData<(S, A, Sc)>,
98}
99
100impl<S, A, K, E, KE, F, W, Sc> IncrementalPentaConstraint<S, A, K, E, KE, F, W, Sc>
101where
102 S: 'static,
103 A: Clone + 'static,
104 K: Eq + Hash + Clone,
105 E: Fn(&S) -> &[A],
106 KE: Fn(&A) -> K,
107 F: Fn(&A, &A, &A, &A, &A) -> bool,
108 W: Fn(&A, &A, &A, &A, &A) -> Sc,
109 Sc: Score,
110{
111 pub fn new(
113 constraint_ref: ConstraintRef,
114 impact_type: ImpactType,
115 extractor: E,
116 key_extractor: KE,
117 filter: F,
118 weight: W,
119 is_hard: bool,
120 ) -> Self {
121 Self {
122 constraint_ref,
123 impact_type,
124 extractor,
125 key_extractor,
126 filter,
127 weight,
128 is_hard,
129 entity_to_matches: HashMap::new(),
130 matches: HashSet::new(),
131 key_to_indices: HashMap::new(),
132 index_to_key: HashMap::new(),
133 _phantom: PhantomData,
134 }
135 }
136
137 #[inline]
138 fn compute_score(&self, a: &A, b: &A, c: &A, d: &A, e: &A) -> Sc {
139 let base = (self.weight)(a, b, c, d, e);
140 match self.impact_type {
141 ImpactType::Penalty => -base,
142 ImpactType::Reward => base,
143 }
144 }
145
146 fn insert_entity(&mut self, entities: &[A], index: usize) -> Sc {
148 if index >= entities.len() {
149 return Sc::zero();
150 }
151
152 let entity = &entities[index];
153 let key = (self.key_extractor)(entity);
154
155 self.index_to_key.insert(index, key.clone());
157
158 self.key_to_indices
160 .entry(key.clone())
161 .or_default()
162 .insert(index);
163
164 let key_to_indices = &self.key_to_indices;
166 let matches = &mut self.matches;
167 let entity_to_matches = &mut self.entity_to_matches;
168 let filter = &self.filter;
169 let weight = &self.weight;
170 let impact_type = self.impact_type;
171
172 let mut total = Sc::zero();
174 if let Some(others) = key_to_indices.get(&key) {
175 for &i in others {
177 if i == index {
178 continue;
179 }
180 for &j in others {
181 if j <= i || j == index {
182 continue;
183 }
184 for &k in others {
185 if k <= j || k == index {
186 continue;
187 }
188 for &l in others {
189 if l <= k || l == index {
190 continue;
191 }
192
193 let mut indices = [index, i, j, k, l];
195 indices.sort();
196 let [a_idx, b_idx, c_idx, d_idx, e_idx] = indices;
197
198 let penta = (a_idx, b_idx, c_idx, d_idx, e_idx);
199
200 if matches.contains(&penta) {
202 continue;
203 }
204
205 let a = &entities[a_idx];
206 let b = &entities[b_idx];
207 let c = &entities[c_idx];
208 let d = &entities[d_idx];
209 let e = &entities[e_idx];
210
211 if filter(a, b, c, d, e) && matches.insert(penta) {
212 entity_to_matches.entry(a_idx).or_default().insert(penta);
213 entity_to_matches.entry(b_idx).or_default().insert(penta);
214 entity_to_matches.entry(c_idx).or_default().insert(penta);
215 entity_to_matches.entry(d_idx).or_default().insert(penta);
216 entity_to_matches.entry(e_idx).or_default().insert(penta);
217 let base = weight(a, b, c, d, e);
218 let score = match impact_type {
219 ImpactType::Penalty => -base,
220 ImpactType::Reward => base,
221 };
222 total = total + score;
223 }
224 }
225 }
226 }
227 }
228 }
229
230 total
231 }
232
233 fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
235 if let Some(key) = self.index_to_key.remove(&index) {
237 if let Some(indices) = self.key_to_indices.get_mut(&key) {
238 indices.remove(&index);
239 if indices.is_empty() {
240 self.key_to_indices.remove(&key);
241 }
242 }
243 }
244
245 let Some(pentas) = self.entity_to_matches.remove(&index) else {
247 return Sc::zero();
248 };
249
250 let mut total = Sc::zero();
251 for penta in pentas {
252 self.matches.remove(&penta);
253
254 let (i, j, k, l, m) = penta;
256 for &other in &[i, j, k, l, m] {
257 if other != index {
258 if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
259 other_set.remove(&penta);
260 if other_set.is_empty() {
261 self.entity_to_matches.remove(&other);
262 }
263 }
264 }
265 }
266
267 if i < entities.len()
269 && j < entities.len()
270 && k < entities.len()
271 && l < entities.len()
272 && m < entities.len()
273 {
274 let score = self.compute_score(
275 &entities[i],
276 &entities[j],
277 &entities[k],
278 &entities[l],
279 &entities[m],
280 );
281 total = total - score;
282 }
283 }
284
285 total
286 }
287}
288
289impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
290 for IncrementalPentaConstraint<S, A, K, E, KE, F, W, Sc>
291where
292 S: Send + Sync + 'static,
293 A: Clone + Debug + Send + Sync + 'static,
294 K: Eq + Hash + Clone + Send + Sync,
295 E: Fn(&S) -> &[A] + Send + Sync,
296 KE: Fn(&A) -> K + Send + Sync,
297 F: Fn(&A, &A, &A, &A, &A) -> bool + Send + Sync,
298 W: Fn(&A, &A, &A, &A, &A) -> Sc + Send + Sync,
299 Sc: Score,
300{
301 fn evaluate(&self, solution: &S) -> Sc {
302 let entities = (self.extractor)(solution);
303 let mut total = Sc::zero();
304
305 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
307 for (i, entity) in entities.iter().enumerate() {
308 let key = (self.key_extractor)(entity);
309 temp_index.entry(key).or_default().push(i);
310 }
311
312 for indices in temp_index.values() {
314 for pos_i in 0..indices.len() {
315 for pos_j in (pos_i + 1)..indices.len() {
316 for pos_k in (pos_j + 1)..indices.len() {
317 for pos_l in (pos_k + 1)..indices.len() {
318 for pos_m in (pos_l + 1)..indices.len() {
319 let i = indices[pos_i];
320 let j = indices[pos_j];
321 let k = indices[pos_k];
322 let l = indices[pos_l];
323 let m = indices[pos_m];
324 let a = &entities[i];
325 let b = &entities[j];
326 let c = &entities[k];
327 let d = &entities[l];
328 let e = &entities[m];
329 if (self.filter)(a, b, c, d, e) {
330 total = total + self.compute_score(a, b, c, d, e);
331 }
332 }
333 }
334 }
335 }
336 }
337 }
338
339 total
340 }
341
342 fn match_count(&self, solution: &S) -> usize {
343 let entities = (self.extractor)(solution);
344 let mut count = 0;
345
346 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
348 for (i, entity) in entities.iter().enumerate() {
349 let key = (self.key_extractor)(entity);
350 temp_index.entry(key).or_default().push(i);
351 }
352
353 for indices in temp_index.values() {
355 for pos_i in 0..indices.len() {
356 for pos_j in (pos_i + 1)..indices.len() {
357 for pos_k in (pos_j + 1)..indices.len() {
358 for pos_l in (pos_k + 1)..indices.len() {
359 for pos_m in (pos_l + 1)..indices.len() {
360 let i = indices[pos_i];
361 let j = indices[pos_j];
362 let k = indices[pos_k];
363 let l = indices[pos_l];
364 let m = indices[pos_m];
365 if (self.filter)(
366 &entities[i],
367 &entities[j],
368 &entities[k],
369 &entities[l],
370 &entities[m],
371 ) {
372 count += 1;
373 }
374 }
375 }
376 }
377 }
378 }
379 }
380
381 count
382 }
383
384 fn initialize(&mut self, solution: &S) -> Sc {
385 self.reset();
386
387 let entities = (self.extractor)(solution);
388 let mut total = Sc::zero();
389 for i in 0..entities.len() {
390 total = total + self.insert_entity(entities, i);
391 }
392 total
393 }
394
395 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
396 let entities = (self.extractor)(solution);
397 self.insert_entity(entities, entity_index)
398 }
399
400 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
401 let entities = (self.extractor)(solution);
402 self.retract_entity(entities, entity_index)
403 }
404
405 fn reset(&mut self) {
406 self.entity_to_matches.clear();
407 self.matches.clear();
408 self.key_to_indices.clear();
409 self.index_to_key.clear();
410 }
411
412 fn name(&self) -> &str {
413 &self.constraint_ref.name
414 }
415
416 fn is_hard(&self) -> bool {
417 self.is_hard
418 }
419
420 fn constraint_ref(&self) -> ConstraintRef {
421 self.constraint_ref.clone()
422 }
423
424 fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
425 impl_get_matches_nary!(penta: self, solution)
426 }
427}
428
429impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
430 for IncrementalPentaConstraint<S, A, K, E, KE, F, W, Sc>
431{
432 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
433 f.debug_struct("IncrementalPentaConstraint")
434 .field("name", &self.constraint_ref.name)
435 .field("impact_type", &self.impact_type)
436 .field("match_count", &self.matches.len())
437 .finish()
438 }
439}