1#[macro_export]
26macro_rules! impl_incremental_nary_constraint {
27 (bi, $struct_name:ident) => {
29 pub struct $struct_name<S, A, K, E, KE, F, W, Sc>
34 where
35 Sc: Score,
36 {
37 constraint_ref: ConstraintRef,
38 impact_type: ImpactType,
39 extractor: E,
40 key_extractor: KE,
41 filter: F,
42 weight: W,
43 is_hard: bool,
44 entity_to_matches: HashMap<usize, HashSet<(usize, usize)>>,
45 matches: HashSet<(usize, usize)>,
46 key_to_indices: HashMap<K, HashSet<usize>>,
47 index_to_key: HashMap<usize, K>,
48 _phantom: PhantomData<(S, A, Sc)>,
49 }
50
51 impl<S, A, K, E, KE, F, W, Sc> $struct_name<S, A, K, E, KE, F, W, Sc>
52 where
53 S: 'static,
54 A: Clone + 'static,
55 K: Eq + Hash + Clone,
56 E: Fn(&S) -> &[A],
57 KE: Fn(&A) -> K,
58 F: Fn(&S, &A, &A) -> bool,
59 W: Fn(&A, &A) -> Sc,
60 Sc: Score,
61 {
62 pub fn new(
63 constraint_ref: ConstraintRef,
64 impact_type: ImpactType,
65 extractor: E,
66 key_extractor: KE,
67 filter: F,
68 weight: W,
69 is_hard: bool,
70 ) -> Self {
71 Self {
72 constraint_ref,
73 impact_type,
74 extractor,
75 key_extractor,
76 filter,
77 weight,
78 is_hard,
79 entity_to_matches: HashMap::new(),
80 matches: HashSet::new(),
81 key_to_indices: HashMap::new(),
82 index_to_key: HashMap::new(),
83 _phantom: PhantomData,
84 }
85 }
86
87 #[inline]
88 fn compute_score(&self, a: &A, b: &A) -> Sc {
89 let base = (self.weight)(a, b);
90 match self.impact_type {
91 ImpactType::Penalty => -base,
92 ImpactType::Reward => base,
93 }
94 }
95
96 fn insert_entity(&mut self, solution: &S, entities: &[A], index: usize) -> Sc {
97 if index >= entities.len() {
98 return Sc::zero();
99 }
100
101 let entity = &entities[index];
102 let key = (self.key_extractor)(entity);
103
104 self.index_to_key.insert(index, key.clone());
105 self.key_to_indices
106 .entry(key.clone())
107 .or_default()
108 .insert(index);
109
110 let key_to_indices = &self.key_to_indices;
111 let matches = &mut self.matches;
112 let entity_to_matches = &mut self.entity_to_matches;
113 let filter = &self.filter;
114 let weight = &self.weight;
115 let impact_type = self.impact_type;
116
117 let mut total = Sc::zero();
118 if let Some(others) = key_to_indices.get(&key) {
119 for &other_idx in others {
120 if other_idx == index {
121 continue;
122 }
123
124 let other = &entities[other_idx];
125 let (low_idx, high_idx, low_entity, high_entity) = if index < other_idx {
126 (index, other_idx, entity, other)
127 } else {
128 (other_idx, index, other, entity)
129 };
130
131 if filter(solution, low_entity, high_entity) {
132 let pair = (low_idx, high_idx);
133 if matches.insert(pair) {
134 entity_to_matches.entry(low_idx).or_default().insert(pair);
135 entity_to_matches.entry(high_idx).or_default().insert(pair);
136 let base = weight(low_entity, high_entity);
137 let score = match impact_type {
138 ImpactType::Penalty => -base,
139 ImpactType::Reward => base,
140 };
141 total = total + score;
142 }
143 }
144 }
145 }
146
147 total
148 }
149
150 fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
151 if let Some(key) = self.index_to_key.remove(&index) {
152 if let Some(indices) = self.key_to_indices.get_mut(&key) {
153 indices.remove(&index);
154 if indices.is_empty() {
155 self.key_to_indices.remove(&key);
156 }
157 }
158 }
159
160 let Some(pairs) = self.entity_to_matches.remove(&index) else {
161 return Sc::zero();
162 };
163
164 let mut total = Sc::zero();
165 for pair in pairs {
166 self.matches.remove(&pair);
167
168 let other = if pair.0 == index { pair.1 } else { pair.0 };
169 if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
170 other_set.remove(&pair);
171 if other_set.is_empty() {
172 self.entity_to_matches.remove(&other);
173 }
174 }
175
176 let (low_idx, high_idx) = pair;
177 if low_idx < entities.len() && high_idx < entities.len() {
178 let score = self.compute_score(&entities[low_idx], &entities[high_idx]);
179 total = total - score;
180 }
181 }
182
183 total
184 }
185 }
186
187 impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
188 for $struct_name<S, A, K, E, KE, F, W, Sc>
189 where
190 S: Send + Sync + 'static,
191 A: Clone + Debug + Send + Sync + 'static,
192 K: Eq + Hash + Clone + Send + Sync,
193 E: Fn(&S) -> &[A] + Send + Sync,
194 KE: Fn(&A) -> K + Send + Sync,
195 F: Fn(&S, &A, &A) -> bool + Send + Sync,
196 W: Fn(&A, &A) -> Sc + Send + Sync,
197 Sc: Score,
198 {
199 fn evaluate(&self, solution: &S) -> Sc {
200 let entities = (self.extractor)(solution);
201 let mut total = Sc::zero();
202
203 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
204 for (i, entity) in entities.iter().enumerate() {
205 let key = (self.key_extractor)(entity);
206 temp_index.entry(key).or_default().push(i);
207 }
208
209 for indices in temp_index.values() {
210 for i in 0..indices.len() {
211 for j in (i + 1)..indices.len() {
212 let low = indices[i];
213 let high = indices[j];
214 let a = &entities[low];
215 let b = &entities[high];
216 if (self.filter)(solution, a, b) {
217 total = total + self.compute_score(a, b);
218 }
219 }
220 }
221 }
222
223 total
224 }
225
226 fn match_count(&self, solution: &S) -> usize {
227 let entities = (self.extractor)(solution);
228 let mut count = 0;
229
230 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
231 for (i, entity) in entities.iter().enumerate() {
232 let key = (self.key_extractor)(entity);
233 temp_index.entry(key).or_default().push(i);
234 }
235
236 for indices in temp_index.values() {
237 for i in 0..indices.len() {
238 for j in (i + 1)..indices.len() {
239 let low = indices[i];
240 let high = indices[j];
241 if (self.filter)(solution, &entities[low], &entities[high]) {
242 count += 1;
243 }
244 }
245 }
246 }
247
248 count
249 }
250
251 fn initialize(&mut self, solution: &S) -> Sc {
252 self.reset();
253 let entities = (self.extractor)(solution);
254 let mut total = Sc::zero();
255 for i in 0..entities.len() {
256 total = total + self.insert_entity(solution, entities, i);
257 }
258 total
259 }
260
261 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
262 let entities = (self.extractor)(solution);
263 self.insert_entity(solution, entities, entity_index)
264 }
265
266 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
267 let entities = (self.extractor)(solution);
268 self.retract_entity(entities, entity_index)
269 }
270
271 fn reset(&mut self) {
272 self.entity_to_matches.clear();
273 self.matches.clear();
274 self.key_to_indices.clear();
275 self.index_to_key.clear();
276 }
277
278 fn name(&self) -> &str {
279 &self.constraint_ref.name
280 }
281
282 fn is_hard(&self) -> bool {
283 self.is_hard
284 }
285
286 fn constraint_ref(&self) -> ConstraintRef {
287 self.constraint_ref.clone()
288 }
289
290 fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
291 $crate::impl_get_matches_nary!(bi: self, solution)
292 }
293 }
294
295 impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
296 for $struct_name<S, A, K, E, KE, F, W, Sc>
297 {
298 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299 f.debug_struct(stringify!($struct_name))
300 .field("name", &self.constraint_ref.name)
301 .field("impact_type", &self.impact_type)
302 .field("match_count", &self.matches.len())
303 .finish()
304 }
305 }
306 };
307
308 (tri, $struct_name:ident) => {
310 pub struct $struct_name<S, A, K, E, KE, F, W, Sc>
316 where
317 Sc: Score,
318 {
319 constraint_ref: ConstraintRef,
320 impact_type: ImpactType,
321 extractor: E,
322 key_extractor: KE,
323 filter: F,
324 weight: W,
325 is_hard: bool,
326 entity_to_matches: HashMap<usize, HashSet<(usize, usize, usize)>>,
327 matches: HashSet<(usize, usize, usize)>,
328 key_to_indices: HashMap<K, HashSet<usize>>,
329 index_to_key: HashMap<usize, K>,
330 _phantom: PhantomData<(S, A, Sc)>,
331 }
332
333 impl<S, A, K, E, KE, F, W, Sc> $struct_name<S, A, K, E, KE, F, W, Sc>
334 where
335 S: 'static,
336 A: Clone + 'static,
337 K: Eq + Hash + Clone,
338 E: Fn(&S) -> &[A],
339 KE: Fn(&A) -> K,
340 F: Fn(&S, &A, &A, &A) -> bool,
341 W: Fn(&A, &A, &A) -> Sc,
342 Sc: Score,
343 {
344 pub fn new(
345 constraint_ref: ConstraintRef,
346 impact_type: ImpactType,
347 extractor: E,
348 key_extractor: KE,
349 filter: F,
350 weight: W,
351 is_hard: bool,
352 ) -> Self {
353 Self {
354 constraint_ref,
355 impact_type,
356 extractor,
357 key_extractor,
358 filter,
359 weight,
360 is_hard,
361 entity_to_matches: HashMap::new(),
362 matches: HashSet::new(),
363 key_to_indices: HashMap::new(),
364 index_to_key: HashMap::new(),
365 _phantom: PhantomData,
366 }
367 }
368
369 #[inline]
370 fn compute_score(&self, a: &A, b: &A, c: &A) -> Sc {
371 let base = (self.weight)(a, b, c);
372 match self.impact_type {
373 ImpactType::Penalty => -base,
374 ImpactType::Reward => base,
375 }
376 }
377
378 fn insert_entity(&mut self, solution: &S, entities: &[A], index: usize) -> Sc {
379 if index >= entities.len() {
380 return Sc::zero();
381 }
382
383 let entity = &entities[index];
384 let key = (self.key_extractor)(entity);
385
386 self.index_to_key.insert(index, key.clone());
387 self.key_to_indices
388 .entry(key.clone())
389 .or_default()
390 .insert(index);
391
392 let key_to_indices = &self.key_to_indices;
393 let matches = &mut self.matches;
394 let entity_to_matches = &mut self.entity_to_matches;
395 let filter = &self.filter;
396 let weight = &self.weight;
397 let impact_type = self.impact_type;
398
399 let mut total = Sc::zero();
400 if let Some(others) = key_to_indices.get(&key) {
401 for &i in others {
402 if i == index {
403 continue;
404 }
405 for &j in others {
406 if j <= i || j == index {
407 continue;
408 }
409
410 let mut arr = [index, i, j];
411 arr.sort();
412 let [a_idx, b_idx, c_idx] = arr;
413 let triple = (a_idx, b_idx, c_idx);
414
415 if matches.contains(&triple) {
416 continue;
417 }
418
419 let a = &entities[a_idx];
420 let b = &entities[b_idx];
421 let c = &entities[c_idx];
422
423 if filter(solution, a, b, c) && matches.insert(triple) {
424 entity_to_matches.entry(a_idx).or_default().insert(triple);
425 entity_to_matches.entry(b_idx).or_default().insert(triple);
426 entity_to_matches.entry(c_idx).or_default().insert(triple);
427 let base = weight(a, b, c);
428 let score = match impact_type {
429 ImpactType::Penalty => -base,
430 ImpactType::Reward => base,
431 };
432 total = total + score;
433 }
434 }
435 }
436 }
437
438 total
439 }
440
441 fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
442 if let Some(key) = self.index_to_key.remove(&index) {
443 if let Some(indices) = self.key_to_indices.get_mut(&key) {
444 indices.remove(&index);
445 if indices.is_empty() {
446 self.key_to_indices.remove(&key);
447 }
448 }
449 }
450
451 let Some(triples) = self.entity_to_matches.remove(&index) else {
452 return Sc::zero();
453 };
454
455 let mut total = Sc::zero();
456 for triple in triples {
457 self.matches.remove(&triple);
458
459 let (i, j, k) = triple;
460 for &other in &[i, j, k] {
461 if other != index {
462 if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
463 other_set.remove(&triple);
464 if other_set.is_empty() {
465 self.entity_to_matches.remove(&other);
466 }
467 }
468 }
469 }
470
471 if i < entities.len() && j < entities.len() && k < entities.len() {
472 let score =
473 self.compute_score(&entities[i], &entities[j], &entities[k]);
474 total = total - score;
475 }
476 }
477
478 total
479 }
480 }
481
482 impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
483 for $struct_name<S, A, K, E, KE, F, W, Sc>
484 where
485 S: Send + Sync + 'static,
486 A: Clone + Debug + Send + Sync + 'static,
487 K: Eq + Hash + Clone + Send + Sync,
488 E: Fn(&S) -> &[A] + Send + Sync,
489 KE: Fn(&A) -> K + Send + Sync,
490 F: Fn(&S, &A, &A, &A) -> bool + Send + Sync,
491 W: Fn(&A, &A, &A) -> Sc + Send + Sync,
492 Sc: Score,
493 {
494 fn evaluate(&self, solution: &S) -> Sc {
495 let entities = (self.extractor)(solution);
496 let mut total = Sc::zero();
497
498 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
499 for (i, entity) in entities.iter().enumerate() {
500 let key = (self.key_extractor)(entity);
501 temp_index.entry(key).or_default().push(i);
502 }
503
504 for indices in temp_index.values() {
505 for pos_i in 0..indices.len() {
506 for pos_j in (pos_i + 1)..indices.len() {
507 for pos_k in (pos_j + 1)..indices.len() {
508 let i = indices[pos_i];
509 let j = indices[pos_j];
510 let k = indices[pos_k];
511 let a = &entities[i];
512 let b = &entities[j];
513 let c = &entities[k];
514 if (self.filter)(solution, a, b, c) {
515 total = total + self.compute_score(a, b, c);
516 }
517 }
518 }
519 }
520 }
521
522 total
523 }
524
525 fn match_count(&self, solution: &S) -> usize {
526 let entities = (self.extractor)(solution);
527 let mut count = 0;
528
529 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
530 for (i, entity) in entities.iter().enumerate() {
531 let key = (self.key_extractor)(entity);
532 temp_index.entry(key).or_default().push(i);
533 }
534
535 for indices in temp_index.values() {
536 for pos_i in 0..indices.len() {
537 for pos_j in (pos_i + 1)..indices.len() {
538 for pos_k in (pos_j + 1)..indices.len() {
539 let i = indices[pos_i];
540 let j = indices[pos_j];
541 let k = indices[pos_k];
542 if (self.filter)(
543 solution,
544 &entities[i],
545 &entities[j],
546 &entities[k],
547 ) {
548 count += 1;
549 }
550 }
551 }
552 }
553 }
554
555 count
556 }
557
558 fn initialize(&mut self, solution: &S) -> Sc {
559 self.reset();
560 let entities = (self.extractor)(solution);
561 let mut total = Sc::zero();
562 for i in 0..entities.len() {
563 total = total + self.insert_entity(solution, entities, i);
564 }
565 total
566 }
567
568 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
569 let entities = (self.extractor)(solution);
570 self.insert_entity(solution, entities, entity_index)
571 }
572
573 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
574 let entities = (self.extractor)(solution);
575 self.retract_entity(entities, entity_index)
576 }
577
578 fn reset(&mut self) {
579 self.entity_to_matches.clear();
580 self.matches.clear();
581 self.key_to_indices.clear();
582 self.index_to_key.clear();
583 }
584
585 fn name(&self) -> &str {
586 &self.constraint_ref.name
587 }
588
589 fn is_hard(&self) -> bool {
590 self.is_hard
591 }
592
593 fn constraint_ref(&self) -> ConstraintRef {
594 self.constraint_ref.clone()
595 }
596
597 fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
598 $crate::impl_get_matches_nary!(tri: self, solution)
599 }
600 }
601
602 impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
603 for $struct_name<S, A, K, E, KE, F, W, Sc>
604 {
605 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
606 f.debug_struct(stringify!($struct_name))
607 .field("name", &self.constraint_ref.name)
608 .field("impact_type", &self.impact_type)
609 .field("match_count", &self.matches.len())
610 .finish()
611 }
612 }
613 };
614
615 (quad, $struct_name:ident) => {
617 pub struct $struct_name<S, A, K, E, KE, F, W, Sc>
623 where
624 Sc: Score,
625 {
626 constraint_ref: ConstraintRef,
627 impact_type: ImpactType,
628 extractor: E,
629 key_extractor: KE,
630 filter: F,
631 weight: W,
632 is_hard: bool,
633 entity_to_matches: HashMap<usize, HashSet<(usize, usize, usize, usize)>>,
634 matches: HashSet<(usize, usize, usize, usize)>,
635 key_to_indices: HashMap<K, HashSet<usize>>,
636 index_to_key: HashMap<usize, K>,
637 _phantom: PhantomData<(S, A, Sc)>,
638 }
639
640 impl<S, A, K, E, KE, F, W, Sc> $struct_name<S, A, K, E, KE, F, W, Sc>
641 where
642 S: 'static,
643 A: Clone + 'static,
644 K: Eq + Hash + Clone,
645 E: Fn(&S) -> &[A],
646 KE: Fn(&A) -> K,
647 F: Fn(&S, &A, &A, &A, &A) -> bool,
648 W: Fn(&A, &A, &A, &A) -> Sc,
649 Sc: Score,
650 {
651 pub fn new(
652 constraint_ref: ConstraintRef,
653 impact_type: ImpactType,
654 extractor: E,
655 key_extractor: KE,
656 filter: F,
657 weight: W,
658 is_hard: bool,
659 ) -> Self {
660 Self {
661 constraint_ref,
662 impact_type,
663 extractor,
664 key_extractor,
665 filter,
666 weight,
667 is_hard,
668 entity_to_matches: HashMap::new(),
669 matches: HashSet::new(),
670 key_to_indices: HashMap::new(),
671 index_to_key: HashMap::new(),
672 _phantom: PhantomData,
673 }
674 }
675
676 #[inline]
677 fn compute_score(&self, a: &A, b: &A, c: &A, d: &A) -> Sc {
678 let base = (self.weight)(a, b, c, d);
679 match self.impact_type {
680 ImpactType::Penalty => -base,
681 ImpactType::Reward => base,
682 }
683 }
684
685 fn insert_entity(&mut self, solution: &S, entities: &[A], index: usize) -> Sc {
686 if index >= entities.len() {
687 return Sc::zero();
688 }
689
690 let entity = &entities[index];
691 let key = (self.key_extractor)(entity);
692
693 self.index_to_key.insert(index, key.clone());
694 self.key_to_indices
695 .entry(key.clone())
696 .or_default()
697 .insert(index);
698
699 let key_to_indices = &self.key_to_indices;
700 let matches = &mut self.matches;
701 let entity_to_matches = &mut self.entity_to_matches;
702 let filter = &self.filter;
703 let weight = &self.weight;
704 let impact_type = self.impact_type;
705
706 let mut total = Sc::zero();
707 if let Some(others) = key_to_indices.get(&key) {
708 for &i in others {
709 if i == index {
710 continue;
711 }
712 for &j in others {
713 if j <= i || j == index {
714 continue;
715 }
716 for &k in others {
717 if k <= j || k == index {
718 continue;
719 }
720
721 let mut arr = [index, i, j, k];
722 arr.sort();
723 let [a_idx, b_idx, c_idx, d_idx] = arr;
724 let quad = (a_idx, b_idx, c_idx, d_idx);
725
726 if matches.contains(&quad) {
727 continue;
728 }
729
730 let a = &entities[a_idx];
731 let b = &entities[b_idx];
732 let c = &entities[c_idx];
733 let d = &entities[d_idx];
734
735 if filter(solution, a, b, c, d) && matches.insert(quad) {
736 entity_to_matches.entry(a_idx).or_default().insert(quad);
737 entity_to_matches.entry(b_idx).or_default().insert(quad);
738 entity_to_matches.entry(c_idx).or_default().insert(quad);
739 entity_to_matches.entry(d_idx).or_default().insert(quad);
740 let base = weight(a, b, c, d);
741 let score = match impact_type {
742 ImpactType::Penalty => -base,
743 ImpactType::Reward => base,
744 };
745 total = total + score;
746 }
747 }
748 }
749 }
750 }
751
752 total
753 }
754
755 fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
756 if let Some(key) = self.index_to_key.remove(&index) {
757 if let Some(indices) = self.key_to_indices.get_mut(&key) {
758 indices.remove(&index);
759 if indices.is_empty() {
760 self.key_to_indices.remove(&key);
761 }
762 }
763 }
764
765 let Some(quads) = self.entity_to_matches.remove(&index) else {
766 return Sc::zero();
767 };
768
769 let mut total = Sc::zero();
770 for quad in quads {
771 self.matches.remove(&quad);
772
773 let (i, j, k, l) = quad;
774 for &other in &[i, j, k, l] {
775 if other != index {
776 if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
777 other_set.remove(&quad);
778 if other_set.is_empty() {
779 self.entity_to_matches.remove(&other);
780 }
781 }
782 }
783 }
784
785 if i < entities.len()
786 && j < entities.len()
787 && k < entities.len()
788 && l < entities.len()
789 {
790 let score = self.compute_score(
791 &entities[i],
792 &entities[j],
793 &entities[k],
794 &entities[l],
795 );
796 total = total - score;
797 }
798 }
799
800 total
801 }
802 }
803
804 impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
805 for $struct_name<S, A, K, E, KE, F, W, Sc>
806 where
807 S: Send + Sync + 'static,
808 A: Clone + Debug + Send + Sync + 'static,
809 K: Eq + Hash + Clone + Send + Sync,
810 E: Fn(&S) -> &[A] + Send + Sync,
811 KE: Fn(&A) -> K + Send + Sync,
812 F: Fn(&S, &A, &A, &A, &A) -> bool + Send + Sync,
813 W: Fn(&A, &A, &A, &A) -> Sc + Send + Sync,
814 Sc: Score,
815 {
816 fn evaluate(&self, solution: &S) -> Sc {
817 let entities = (self.extractor)(solution);
818 let mut total = Sc::zero();
819
820 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
821 for (i, entity) in entities.iter().enumerate() {
822 let key = (self.key_extractor)(entity);
823 temp_index.entry(key).or_default().push(i);
824 }
825
826 for indices in temp_index.values() {
827 for pos_i in 0..indices.len() {
828 for pos_j in (pos_i + 1)..indices.len() {
829 for pos_k in (pos_j + 1)..indices.len() {
830 for pos_l in (pos_k + 1)..indices.len() {
831 let i = indices[pos_i];
832 let j = indices[pos_j];
833 let k = indices[pos_k];
834 let l = indices[pos_l];
835 let a = &entities[i];
836 let b = &entities[j];
837 let c = &entities[k];
838 let d = &entities[l];
839 if (self.filter)(solution, a, b, c, d) {
840 total = total + self.compute_score(a, b, c, d);
841 }
842 }
843 }
844 }
845 }
846 }
847
848 total
849 }
850
851 fn match_count(&self, solution: &S) -> usize {
852 let entities = (self.extractor)(solution);
853 let mut count = 0;
854
855 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
856 for (i, entity) in entities.iter().enumerate() {
857 let key = (self.key_extractor)(entity);
858 temp_index.entry(key).or_default().push(i);
859 }
860
861 for indices in temp_index.values() {
862 for pos_i in 0..indices.len() {
863 for pos_j in (pos_i + 1)..indices.len() {
864 for pos_k in (pos_j + 1)..indices.len() {
865 for pos_l in (pos_k + 1)..indices.len() {
866 let i = indices[pos_i];
867 let j = indices[pos_j];
868 let k = indices[pos_k];
869 let l = indices[pos_l];
870 if (self.filter)(
871 solution,
872 &entities[i],
873 &entities[j],
874 &entities[k],
875 &entities[l],
876 ) {
877 count += 1;
878 }
879 }
880 }
881 }
882 }
883 }
884
885 count
886 }
887
888 fn initialize(&mut self, solution: &S) -> Sc {
889 self.reset();
890 let entities = (self.extractor)(solution);
891 let mut total = Sc::zero();
892 for i in 0..entities.len() {
893 total = total + self.insert_entity(solution, entities, i);
894 }
895 total
896 }
897
898 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
899 let entities = (self.extractor)(solution);
900 self.insert_entity(solution, entities, entity_index)
901 }
902
903 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
904 let entities = (self.extractor)(solution);
905 self.retract_entity(entities, entity_index)
906 }
907
908 fn reset(&mut self) {
909 self.entity_to_matches.clear();
910 self.matches.clear();
911 self.key_to_indices.clear();
912 self.index_to_key.clear();
913 }
914
915 fn name(&self) -> &str {
916 &self.constraint_ref.name
917 }
918
919 fn is_hard(&self) -> bool {
920 self.is_hard
921 }
922
923 fn constraint_ref(&self) -> ConstraintRef {
924 self.constraint_ref.clone()
925 }
926
927 fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
928 $crate::impl_get_matches_nary!(quad: self, solution)
929 }
930 }
931
932 impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
933 for $struct_name<S, A, K, E, KE, F, W, Sc>
934 {
935 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
936 f.debug_struct(stringify!($struct_name))
937 .field("name", &self.constraint_ref.name)
938 .field("impact_type", &self.impact_type)
939 .field("match_count", &self.matches.len())
940 .finish()
941 }
942 }
943 };
944
945 (penta, $struct_name:ident) => {
947 pub struct $struct_name<S, A, K, E, KE, F, W, Sc>
953 where
954 Sc: Score,
955 {
956 constraint_ref: ConstraintRef,
957 impact_type: ImpactType,
958 extractor: E,
959 key_extractor: KE,
960 filter: F,
961 weight: W,
962 is_hard: bool,
963 entity_to_matches: HashMap<usize, HashSet<(usize, usize, usize, usize, usize)>>,
964 matches: HashSet<(usize, usize, usize, usize, usize)>,
965 key_to_indices: HashMap<K, HashSet<usize>>,
966 index_to_key: HashMap<usize, K>,
967 _phantom: PhantomData<(S, A, Sc)>,
968 }
969
970 impl<S, A, K, E, KE, F, W, Sc> $struct_name<S, A, K, E, KE, F, W, Sc>
971 where
972 S: 'static,
973 A: Clone + 'static,
974 K: Eq + Hash + Clone,
975 E: Fn(&S) -> &[A],
976 KE: Fn(&A) -> K,
977 F: Fn(&S, &A, &A, &A, &A, &A) -> bool,
978 W: Fn(&A, &A, &A, &A, &A) -> Sc,
979 Sc: Score,
980 {
981 pub fn new(
982 constraint_ref: ConstraintRef,
983 impact_type: ImpactType,
984 extractor: E,
985 key_extractor: KE,
986 filter: F,
987 weight: W,
988 is_hard: bool,
989 ) -> Self {
990 Self {
991 constraint_ref,
992 impact_type,
993 extractor,
994 key_extractor,
995 filter,
996 weight,
997 is_hard,
998 entity_to_matches: HashMap::new(),
999 matches: HashSet::new(),
1000 key_to_indices: HashMap::new(),
1001 index_to_key: HashMap::new(),
1002 _phantom: PhantomData,
1003 }
1004 }
1005
1006 #[inline]
1007 fn compute_score(&self, a: &A, b: &A, c: &A, d: &A, e: &A) -> Sc {
1008 let base = (self.weight)(a, b, c, d, e);
1009 match self.impact_type {
1010 ImpactType::Penalty => -base,
1011 ImpactType::Reward => base,
1012 }
1013 }
1014
1015 fn insert_entity(&mut self, solution: &S, entities: &[A], index: usize) -> Sc {
1016 if index >= entities.len() {
1017 return Sc::zero();
1018 }
1019
1020 let entity = &entities[index];
1021 let key = (self.key_extractor)(entity);
1022
1023 self.index_to_key.insert(index, key.clone());
1024 self.key_to_indices
1025 .entry(key.clone())
1026 .or_default()
1027 .insert(index);
1028
1029 let key_to_indices = &self.key_to_indices;
1030 let matches = &mut self.matches;
1031 let entity_to_matches = &mut self.entity_to_matches;
1032 let filter = &self.filter;
1033 let weight = &self.weight;
1034 let impact_type = self.impact_type;
1035
1036 let mut total = Sc::zero();
1037 if let Some(others) = key_to_indices.get(&key) {
1038 for &i in others {
1039 if i == index {
1040 continue;
1041 }
1042 for &j in others {
1043 if j <= i || j == index {
1044 continue;
1045 }
1046 for &k in others {
1047 if k <= j || k == index {
1048 continue;
1049 }
1050 for &l in others {
1051 if l <= k || l == index {
1052 continue;
1053 }
1054
1055 let mut arr = [index, i, j, k, l];
1056 arr.sort();
1057 let [a_idx, b_idx, c_idx, d_idx, e_idx] = arr;
1058 let penta = (a_idx, b_idx, c_idx, d_idx, e_idx);
1059
1060 if matches.contains(&penta) {
1061 continue;
1062 }
1063
1064 let a = &entities[a_idx];
1065 let b = &entities[b_idx];
1066 let c = &entities[c_idx];
1067 let d = &entities[d_idx];
1068 let e = &entities[e_idx];
1069
1070 if filter(solution, a, b, c, d, e) && matches.insert(penta) {
1071 entity_to_matches.entry(a_idx).or_default().insert(penta);
1072 entity_to_matches.entry(b_idx).or_default().insert(penta);
1073 entity_to_matches.entry(c_idx).or_default().insert(penta);
1074 entity_to_matches.entry(d_idx).or_default().insert(penta);
1075 entity_to_matches.entry(e_idx).or_default().insert(penta);
1076 let base = weight(a, b, c, d, e);
1077 let score = match impact_type {
1078 ImpactType::Penalty => -base,
1079 ImpactType::Reward => base,
1080 };
1081 total = total + score;
1082 }
1083 }
1084 }
1085 }
1086 }
1087 }
1088
1089 total
1090 }
1091
1092 fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
1093 if let Some(key) = self.index_to_key.remove(&index) {
1094 if let Some(indices) = self.key_to_indices.get_mut(&key) {
1095 indices.remove(&index);
1096 if indices.is_empty() {
1097 self.key_to_indices.remove(&key);
1098 }
1099 }
1100 }
1101
1102 let Some(pentas) = self.entity_to_matches.remove(&index) else {
1103 return Sc::zero();
1104 };
1105
1106 let mut total = Sc::zero();
1107 for penta in pentas {
1108 self.matches.remove(&penta);
1109
1110 let (i, j, k, l, m) = penta;
1111 for &other in &[i, j, k, l, m] {
1112 if other != index {
1113 if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
1114 other_set.remove(&penta);
1115 if other_set.is_empty() {
1116 self.entity_to_matches.remove(&other);
1117 }
1118 }
1119 }
1120 }
1121
1122 if i < entities.len()
1123 && j < entities.len()
1124 && k < entities.len()
1125 && l < entities.len()
1126 && m < entities.len()
1127 {
1128 let score = self.compute_score(
1129 &entities[i],
1130 &entities[j],
1131 &entities[k],
1132 &entities[l],
1133 &entities[m],
1134 );
1135 total = total - score;
1136 }
1137 }
1138
1139 total
1140 }
1141 }
1142
1143 impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
1144 for $struct_name<S, A, K, E, KE, F, W, Sc>
1145 where
1146 S: Send + Sync + 'static,
1147 A: Clone + Debug + Send + Sync + 'static,
1148 K: Eq + Hash + Clone + Send + Sync,
1149 E: Fn(&S) -> &[A] + Send + Sync,
1150 KE: Fn(&A) -> K + Send + Sync,
1151 F: Fn(&S, &A, &A, &A, &A, &A) -> bool + Send + Sync,
1152 W: Fn(&A, &A, &A, &A, &A) -> Sc + Send + Sync,
1153 Sc: Score,
1154 {
1155 fn evaluate(&self, solution: &S) -> Sc {
1156 let entities = (self.extractor)(solution);
1157 let mut total = Sc::zero();
1158
1159 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
1160 for (i, entity) in entities.iter().enumerate() {
1161 let key = (self.key_extractor)(entity);
1162 temp_index.entry(key).or_default().push(i);
1163 }
1164
1165 for indices in temp_index.values() {
1166 for pos_i in 0..indices.len() {
1167 for pos_j in (pos_i + 1)..indices.len() {
1168 for pos_k in (pos_j + 1)..indices.len() {
1169 for pos_l in (pos_k + 1)..indices.len() {
1170 for pos_m in (pos_l + 1)..indices.len() {
1171 let i = indices[pos_i];
1172 let j = indices[pos_j];
1173 let k = indices[pos_k];
1174 let l = indices[pos_l];
1175 let m = indices[pos_m];
1176 let a = &entities[i];
1177 let b = &entities[j];
1178 let c = &entities[k];
1179 let d = &entities[l];
1180 let e = &entities[m];
1181 if (self.filter)(solution, a, b, c, d, e) {
1182 total = total + self.compute_score(a, b, c, d, e);
1183 }
1184 }
1185 }
1186 }
1187 }
1188 }
1189 }
1190
1191 total
1192 }
1193
1194 fn match_count(&self, solution: &S) -> usize {
1195 let entities = (self.extractor)(solution);
1196 let mut count = 0;
1197
1198 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
1199 for (i, entity) in entities.iter().enumerate() {
1200 let key = (self.key_extractor)(entity);
1201 temp_index.entry(key).or_default().push(i);
1202 }
1203
1204 for indices in temp_index.values() {
1205 for pos_i in 0..indices.len() {
1206 for pos_j in (pos_i + 1)..indices.len() {
1207 for pos_k in (pos_j + 1)..indices.len() {
1208 for pos_l in (pos_k + 1)..indices.len() {
1209 for pos_m in (pos_l + 1)..indices.len() {
1210 let i = indices[pos_i];
1211 let j = indices[pos_j];
1212 let k = indices[pos_k];
1213 let l = indices[pos_l];
1214 let m = indices[pos_m];
1215 if (self.filter)(
1216 solution,
1217 &entities[i],
1218 &entities[j],
1219 &entities[k],
1220 &entities[l],
1221 &entities[m],
1222 ) {
1223 count += 1;
1224 }
1225 }
1226 }
1227 }
1228 }
1229 }
1230 }
1231
1232 count
1233 }
1234
1235 fn initialize(&mut self, solution: &S) -> Sc {
1236 self.reset();
1237 let entities = (self.extractor)(solution);
1238 let mut total = Sc::zero();
1239 for i in 0..entities.len() {
1240 total = total + self.insert_entity(solution, entities, i);
1241 }
1242 total
1243 }
1244
1245 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
1246 let entities = (self.extractor)(solution);
1247 self.insert_entity(solution, entities, entity_index)
1248 }
1249
1250 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
1251 let entities = (self.extractor)(solution);
1252 self.retract_entity(entities, entity_index)
1253 }
1254
1255 fn reset(&mut self) {
1256 self.entity_to_matches.clear();
1257 self.matches.clear();
1258 self.key_to_indices.clear();
1259 self.index_to_key.clear();
1260 }
1261
1262 fn name(&self) -> &str {
1263 &self.constraint_ref.name
1264 }
1265
1266 fn is_hard(&self) -> bool {
1267 self.is_hard
1268 }
1269
1270 fn constraint_ref(&self) -> ConstraintRef {
1271 self.constraint_ref.clone()
1272 }
1273
1274 fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
1275 $crate::impl_get_matches_nary!(penta: self, solution)
1276 }
1277 }
1278
1279 impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
1280 for $struct_name<S, A, K, E, KE, F, W, Sc>
1281 {
1282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1283 f.debug_struct(stringify!($struct_name))
1284 .field("name", &self.constraint_ref.name)
1285 .field("impact_type", &self.impact_type)
1286 .field("match_count", &self.matches.len())
1287 .finish()
1288 }
1289 }
1290 };
1291}
1292
1293pub use impl_incremental_nary_constraint;
1294
1295use std::collections::{HashMap, HashSet};
1297use std::fmt::Debug;
1298use std::hash::Hash;
1299use std::marker::PhantomData;
1300
1301use solverforge_core::score::Score;
1302use solverforge_core::{ConstraintRef, ImpactType};
1303
1304use crate::api::analysis::DetailedConstraintMatch;
1305use crate::api::constraint_set::IncrementalConstraint;
1306
// Concrete incremental constraint types, one per supported arity (2 to 5 entities that
// share a key from `key_extractor`). Each invocation expands the matching arm of
// `impl_incremental_nary_constraint!` into a struct plus its inherent, trait and Debug impls.
impl_incremental_nary_constraint!(bi, IncrementalBiConstraint);
impl_incremental_nary_constraint!(tri, IncrementalTriConstraint);
impl_incremental_nary_constraint!(quad, IncrementalQuadConstraint);
impl_incremental_nary_constraint!(penta, IncrementalPentaConstraint);