solverforge_scoring/constraint/
nary_incremental.rs

1//! Macro-generated N-ary incremental constraints for self-join evaluation.
2//!
3//! This module provides the `impl_incremental_nary_constraint!` macro that generates
4//! fully monomorphized incremental constraint implementations for bi/tri/quad/penta arities.
5//!
6//! Zero-erasure: all closures are concrete generic types, no trait objects, no Arc.
7
8/// Generates an incremental N-ary constraint struct and implementations.
9///
10/// This macro produces:
11/// - The constraint struct with all fields
12/// - Constructor `new()`
13/// - Private helper methods `compute_score()`, `insert_entity()`, `retract_entity()`
14/// - Full `IncrementalConstraint<S, Sc>` trait implementation
15/// - `Debug` implementation
16///
/// # Usage
///
/// Each arm takes only the arity keyword and the name of the struct to generate:
///
/// ```text
/// impl_incremental_nary_constraint!(bi, IncrementalBiConstraint);
/// impl_incremental_nary_constraint!(tri, IncrementalTriConstraint);
/// impl_incremental_nary_constraint!(quad, IncrementalQuadConstraint);
/// impl_incremental_nary_constraint!(penta, IncrementalPentaConstraint);
/// ```
///
/// The expansion refers to `HashMap`, `HashSet`, `Hash`, `Debug`, `PhantomData`,
/// `Score`, `ConstraintRef`, `ImpactType`, `IncrementalConstraint` and
/// `DetailedConstraintMatch` by their bare names, so the invoking module must have
/// all of them in scope.
25#[macro_export]
26macro_rules! impl_incremental_nary_constraint {
27    // ==================== BI (2-arity) ====================
28    (bi, $struct_name:ident) => {
29        /// Zero-erasure incremental bi-constraint for self-joins.
30        ///
31        /// All function types are concrete generics - no trait objects, no Arc.
32        /// Uses key-based indexing: entities are grouped by join key for O(k) lookups.
33        pub struct $struct_name<S, A, K, E, KE, F, W, Sc>
34        where
35            Sc: Score,
36        {
37            constraint_ref: ConstraintRef,
38            impact_type: ImpactType,
39            extractor: E,
40            key_extractor: KE,
41            filter: F,
42            weight: W,
43            is_hard: bool,
44            entity_to_matches: HashMap<usize, HashSet<(usize, usize)>>,
45            matches: HashSet<(usize, usize)>,
46            key_to_indices: HashMap<K, HashSet<usize>>,
47            index_to_key: HashMap<usize, K>,
48            _phantom: PhantomData<(S, A, Sc)>,
49        }
50
51        impl<S, A, K, E, KE, F, W, Sc> $struct_name<S, A, K, E, KE, F, W, Sc>
52        where
53            S: 'static,
54            A: Clone + 'static,
55            K: Eq + Hash + Clone,
56            E: Fn(&S) -> &[A],
57            KE: Fn(&A) -> K,
58            F: Fn(&S, &A, &A) -> bool,
59            W: Fn(&A, &A) -> Sc,
60            Sc: Score,
61        {
62            pub fn new(
63                constraint_ref: ConstraintRef,
64                impact_type: ImpactType,
65                extractor: E,
66                key_extractor: KE,
67                filter: F,
68                weight: W,
69                is_hard: bool,
70            ) -> Self {
71                Self {
72                    constraint_ref,
73                    impact_type,
74                    extractor,
75                    key_extractor,
76                    filter,
77                    weight,
78                    is_hard,
79                    entity_to_matches: HashMap::new(),
80                    matches: HashSet::new(),
81                    key_to_indices: HashMap::new(),
82                    index_to_key: HashMap::new(),
83                    _phantom: PhantomData,
84                }
85            }
86
87            #[inline]
88            fn compute_score(&self, a: &A, b: &A) -> Sc {
89                let base = (self.weight)(a, b);
90                match self.impact_type {
91                    ImpactType::Penalty => -base,
92                    ImpactType::Reward => base,
93                }
94            }
95
96            fn insert_entity(&mut self, solution: &S, entities: &[A], index: usize) -> Sc {
97                if index >= entities.len() {
98                    return Sc::zero();
99                }
100
101                let entity = &entities[index];
102                let key = (self.key_extractor)(entity);
103
104                self.index_to_key.insert(index, key.clone());
105                self.key_to_indices
106                    .entry(key.clone())
107                    .or_default()
108                    .insert(index);
109
110                let key_to_indices = &self.key_to_indices;
111                let matches = &mut self.matches;
112                let entity_to_matches = &mut self.entity_to_matches;
113                let filter = &self.filter;
114                let weight = &self.weight;
115                let impact_type = self.impact_type;
116
117                let mut total = Sc::zero();
118                if let Some(others) = key_to_indices.get(&key) {
119                    for &other_idx in others {
120                        if other_idx == index {
121                            continue;
122                        }
123
124                        let other = &entities[other_idx];
125                        let (low_idx, high_idx, low_entity, high_entity) = if index < other_idx {
126                            (index, other_idx, entity, other)
127                        } else {
128                            (other_idx, index, other, entity)
129                        };
130
131                        if filter(solution, low_entity, high_entity) {
132                            let pair = (low_idx, high_idx);
133                            if matches.insert(pair) {
134                                entity_to_matches.entry(low_idx).or_default().insert(pair);
135                                entity_to_matches.entry(high_idx).or_default().insert(pair);
136                                let base = weight(low_entity, high_entity);
137                                let score = match impact_type {
138                                    ImpactType::Penalty => -base,
139                                    ImpactType::Reward => base,
140                                };
141                                total = total + score;
142                            }
143                        }
144                    }
145                }
146
147                total
148            }
149
150            fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
151                if let Some(key) = self.index_to_key.remove(&index) {
152                    if let Some(indices) = self.key_to_indices.get_mut(&key) {
153                        indices.remove(&index);
154                        if indices.is_empty() {
155                            self.key_to_indices.remove(&key);
156                        }
157                    }
158                }
159
160                let Some(pairs) = self.entity_to_matches.remove(&index) else {
161                    return Sc::zero();
162                };
163
164                let mut total = Sc::zero();
165                for pair in pairs {
166                    self.matches.remove(&pair);
167
168                    let other = if pair.0 == index { pair.1 } else { pair.0 };
169                    if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
170                        other_set.remove(&pair);
171                        if other_set.is_empty() {
172                            self.entity_to_matches.remove(&other);
173                        }
174                    }
175
176                    let (low_idx, high_idx) = pair;
177                    if low_idx < entities.len() && high_idx < entities.len() {
178                        let score = self.compute_score(&entities[low_idx], &entities[high_idx]);
179                        total = total - score;
180                    }
181                }
182
183                total
184            }
185        }
186
187        impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
188            for $struct_name<S, A, K, E, KE, F, W, Sc>
189        where
190            S: Send + Sync + 'static,
191            A: Clone + Debug + Send + Sync + 'static,
192            K: Eq + Hash + Clone + Send + Sync,
193            E: Fn(&S) -> &[A] + Send + Sync,
194            KE: Fn(&A) -> K + Send + Sync,
195            F: Fn(&S, &A, &A) -> bool + Send + Sync,
196            W: Fn(&A, &A) -> Sc + Send + Sync,
197            Sc: Score,
198        {
199            fn evaluate(&self, solution: &S) -> Sc {
200                let entities = (self.extractor)(solution);
201                let mut total = Sc::zero();
202
203                let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
204                for (i, entity) in entities.iter().enumerate() {
205                    let key = (self.key_extractor)(entity);
206                    temp_index.entry(key).or_default().push(i);
207                }
208
209                for indices in temp_index.values() {
210                    for i in 0..indices.len() {
211                        for j in (i + 1)..indices.len() {
212                            let low = indices[i];
213                            let high = indices[j];
214                            let a = &entities[low];
215                            let b = &entities[high];
216                            if (self.filter)(solution, a, b) {
217                                total = total + self.compute_score(a, b);
218                            }
219                        }
220                    }
221                }
222
223                total
224            }
225
226            fn match_count(&self, solution: &S) -> usize {
227                let entities = (self.extractor)(solution);
228                let mut count = 0;
229
230                let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
231                for (i, entity) in entities.iter().enumerate() {
232                    let key = (self.key_extractor)(entity);
233                    temp_index.entry(key).or_default().push(i);
234                }
235
236                for indices in temp_index.values() {
237                    for i in 0..indices.len() {
238                        for j in (i + 1)..indices.len() {
239                            let low = indices[i];
240                            let high = indices[j];
241                            if (self.filter)(solution, &entities[low], &entities[high]) {
242                                count += 1;
243                            }
244                        }
245                    }
246                }
247
248                count
249            }
250
251            fn initialize(&mut self, solution: &S) -> Sc {
252                self.reset();
253                let entities = (self.extractor)(solution);
254                let mut total = Sc::zero();
255                for i in 0..entities.len() {
256                    total = total + self.insert_entity(solution, entities, i);
257                }
258                total
259            }
260
261            fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
262                let entities = (self.extractor)(solution);
263                self.insert_entity(solution, entities, entity_index)
264            }
265
266            fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
267                let entities = (self.extractor)(solution);
268                self.retract_entity(entities, entity_index)
269            }
270
271            fn reset(&mut self) {
272                self.entity_to_matches.clear();
273                self.matches.clear();
274                self.key_to_indices.clear();
275                self.index_to_key.clear();
276            }
277
278            fn name(&self) -> &str {
279                &self.constraint_ref.name
280            }
281
282            fn is_hard(&self) -> bool {
283                self.is_hard
284            }
285
286            fn constraint_ref(&self) -> ConstraintRef {
287                self.constraint_ref.clone()
288            }
289
290            fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
291                $crate::impl_get_matches_nary!(bi: self, solution)
292            }
293        }
294
295        impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
296            for $struct_name<S, A, K, E, KE, F, W, Sc>
297        {
298            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299                f.debug_struct(stringify!($struct_name))
300                    .field("name", &self.constraint_ref.name)
301                    .field("impact_type", &self.impact_type)
302                    .field("match_count", &self.matches.len())
303                    .finish()
304            }
305        }
306    };
307
308    // ==================== TRI (3-arity) ====================
309    (tri, $struct_name:ident) => {
310        /// Zero-erasure incremental tri-constraint for self-joins.
311        ///
312        /// All function types are concrete generics - no Arc, no dyn, fully monomorphized.
313        /// Uses key-based indexing: entities are grouped by join key for O(k) lookups.
314        /// Triples are ordered as (i, j, k) where i < j < k to avoid duplicates.
315        pub struct $struct_name<S, A, K, E, KE, F, W, Sc>
316        where
317            Sc: Score,
318        {
319            constraint_ref: ConstraintRef,
320            impact_type: ImpactType,
321            extractor: E,
322            key_extractor: KE,
323            filter: F,
324            weight: W,
325            is_hard: bool,
326            entity_to_matches: HashMap<usize, HashSet<(usize, usize, usize)>>,
327            matches: HashSet<(usize, usize, usize)>,
328            key_to_indices: HashMap<K, HashSet<usize>>,
329            index_to_key: HashMap<usize, K>,
330            _phantom: PhantomData<(S, A, Sc)>,
331        }
332
333        impl<S, A, K, E, KE, F, W, Sc> $struct_name<S, A, K, E, KE, F, W, Sc>
334        where
335            S: 'static,
336            A: Clone + 'static,
337            K: Eq + Hash + Clone,
338            E: Fn(&S) -> &[A],
339            KE: Fn(&A) -> K,
340            F: Fn(&S, &A, &A, &A) -> bool,
341            W: Fn(&A, &A, &A) -> Sc,
342            Sc: Score,
343        {
344            pub fn new(
345                constraint_ref: ConstraintRef,
346                impact_type: ImpactType,
347                extractor: E,
348                key_extractor: KE,
349                filter: F,
350                weight: W,
351                is_hard: bool,
352            ) -> Self {
353                Self {
354                    constraint_ref,
355                    impact_type,
356                    extractor,
357                    key_extractor,
358                    filter,
359                    weight,
360                    is_hard,
361                    entity_to_matches: HashMap::new(),
362                    matches: HashSet::new(),
363                    key_to_indices: HashMap::new(),
364                    index_to_key: HashMap::new(),
365                    _phantom: PhantomData,
366                }
367            }
368
369            #[inline]
370            fn compute_score(&self, a: &A, b: &A, c: &A) -> Sc {
371                let base = (self.weight)(a, b, c);
372                match self.impact_type {
373                    ImpactType::Penalty => -base,
374                    ImpactType::Reward => base,
375                }
376            }
377
378            fn insert_entity(&mut self, solution: &S, entities: &[A], index: usize) -> Sc {
379                if index >= entities.len() {
380                    return Sc::zero();
381                }
382
383                let entity = &entities[index];
384                let key = (self.key_extractor)(entity);
385
386                self.index_to_key.insert(index, key.clone());
387                self.key_to_indices
388                    .entry(key.clone())
389                    .or_default()
390                    .insert(index);
391
392                let key_to_indices = &self.key_to_indices;
393                let matches = &mut self.matches;
394                let entity_to_matches = &mut self.entity_to_matches;
395                let filter = &self.filter;
396                let weight = &self.weight;
397                let impact_type = self.impact_type;
398
399                let mut total = Sc::zero();
400                if let Some(others) = key_to_indices.get(&key) {
401                    for &i in others {
402                        if i == index {
403                            continue;
404                        }
405                        for &j in others {
406                            if j <= i || j == index {
407                                continue;
408                            }
409
410                            let mut arr = [index, i, j];
411                            arr.sort();
412                            let [a_idx, b_idx, c_idx] = arr;
413                            let triple = (a_idx, b_idx, c_idx);
414
415                            if matches.contains(&triple) {
416                                continue;
417                            }
418
419                            let a = &entities[a_idx];
420                            let b = &entities[b_idx];
421                            let c = &entities[c_idx];
422
423                            if filter(solution, a, b, c) && matches.insert(triple) {
424                                entity_to_matches.entry(a_idx).or_default().insert(triple);
425                                entity_to_matches.entry(b_idx).or_default().insert(triple);
426                                entity_to_matches.entry(c_idx).or_default().insert(triple);
427                                let base = weight(a, b, c);
428                                let score = match impact_type {
429                                    ImpactType::Penalty => -base,
430                                    ImpactType::Reward => base,
431                                };
432                                total = total + score;
433                            }
434                        }
435                    }
436                }
437
438                total
439            }
440
441            fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
442                if let Some(key) = self.index_to_key.remove(&index) {
443                    if let Some(indices) = self.key_to_indices.get_mut(&key) {
444                        indices.remove(&index);
445                        if indices.is_empty() {
446                            self.key_to_indices.remove(&key);
447                        }
448                    }
449                }
450
451                let Some(triples) = self.entity_to_matches.remove(&index) else {
452                    return Sc::zero();
453                };
454
455                let mut total = Sc::zero();
456                for triple in triples {
457                    self.matches.remove(&triple);
458
459                    let (i, j, k) = triple;
460                    for &other in &[i, j, k] {
461                        if other != index {
462                            if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
463                                other_set.remove(&triple);
464                                if other_set.is_empty() {
465                                    self.entity_to_matches.remove(&other);
466                                }
467                            }
468                        }
469                    }
470
471                    if i < entities.len() && j < entities.len() && k < entities.len() {
472                        let score =
473                            self.compute_score(&entities[i], &entities[j], &entities[k]);
474                        total = total - score;
475                    }
476                }
477
478                total
479            }
480        }
481
482        impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
483            for $struct_name<S, A, K, E, KE, F, W, Sc>
484        where
485            S: Send + Sync + 'static,
486            A: Clone + Debug + Send + Sync + 'static,
487            K: Eq + Hash + Clone + Send + Sync,
488            E: Fn(&S) -> &[A] + Send + Sync,
489            KE: Fn(&A) -> K + Send + Sync,
490            F: Fn(&S, &A, &A, &A) -> bool + Send + Sync,
491            W: Fn(&A, &A, &A) -> Sc + Send + Sync,
492            Sc: Score,
493        {
494            fn evaluate(&self, solution: &S) -> Sc {
495                let entities = (self.extractor)(solution);
496                let mut total = Sc::zero();
497
498                let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
499                for (i, entity) in entities.iter().enumerate() {
500                    let key = (self.key_extractor)(entity);
501                    temp_index.entry(key).or_default().push(i);
502                }
503
504                for indices in temp_index.values() {
505                    for pos_i in 0..indices.len() {
506                        for pos_j in (pos_i + 1)..indices.len() {
507                            for pos_k in (pos_j + 1)..indices.len() {
508                                let i = indices[pos_i];
509                                let j = indices[pos_j];
510                                let k = indices[pos_k];
511                                let a = &entities[i];
512                                let b = &entities[j];
513                                let c = &entities[k];
514                                if (self.filter)(solution, a, b, c) {
515                                    total = total + self.compute_score(a, b, c);
516                                }
517                            }
518                        }
519                    }
520                }
521
522                total
523            }
524
525            fn match_count(&self, solution: &S) -> usize {
526                let entities = (self.extractor)(solution);
527                let mut count = 0;
528
529                let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
530                for (i, entity) in entities.iter().enumerate() {
531                    let key = (self.key_extractor)(entity);
532                    temp_index.entry(key).or_default().push(i);
533                }
534
535                for indices in temp_index.values() {
536                    for pos_i in 0..indices.len() {
537                        for pos_j in (pos_i + 1)..indices.len() {
538                            for pos_k in (pos_j + 1)..indices.len() {
539                                let i = indices[pos_i];
540                                let j = indices[pos_j];
541                                let k = indices[pos_k];
542                                if (self.filter)(
543                                    solution,
544                                    &entities[i],
545                                    &entities[j],
546                                    &entities[k],
547                                ) {
548                                    count += 1;
549                                }
550                            }
551                        }
552                    }
553                }
554
555                count
556            }
557
558            fn initialize(&mut self, solution: &S) -> Sc {
559                self.reset();
560                let entities = (self.extractor)(solution);
561                let mut total = Sc::zero();
562                for i in 0..entities.len() {
563                    total = total + self.insert_entity(solution, entities, i);
564                }
565                total
566            }
567
568            fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
569                let entities = (self.extractor)(solution);
570                self.insert_entity(solution, entities, entity_index)
571            }
572
573            fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
574                let entities = (self.extractor)(solution);
575                self.retract_entity(entities, entity_index)
576            }
577
578            fn reset(&mut self) {
579                self.entity_to_matches.clear();
580                self.matches.clear();
581                self.key_to_indices.clear();
582                self.index_to_key.clear();
583            }
584
585            fn name(&self) -> &str {
586                &self.constraint_ref.name
587            }
588
589            fn is_hard(&self) -> bool {
590                self.is_hard
591            }
592
593            fn constraint_ref(&self) -> ConstraintRef {
594                self.constraint_ref.clone()
595            }
596
597            fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
598                $crate::impl_get_matches_nary!(tri: self, solution)
599            }
600        }
601
602        impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
603            for $struct_name<S, A, K, E, KE, F, W, Sc>
604        {
605            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
606                f.debug_struct(stringify!($struct_name))
607                    .field("name", &self.constraint_ref.name)
608                    .field("impact_type", &self.impact_type)
609                    .field("match_count", &self.matches.len())
610                    .finish()
611            }
612        }
613    };
614
615    // ==================== QUAD (4-arity) ====================
616    (quad, $struct_name:ident) => {
617        /// Zero-erasure incremental quad-constraint for self-joins.
618        ///
619        /// All function types are concrete generics - no Arc, no dyn, fully monomorphized.
620        /// Uses key-based indexing: entities are grouped by join key for O(k) lookups.
621        /// Quadruples are ordered as (i, j, k, l) where i < j < k < l to avoid duplicates.
622        pub struct $struct_name<S, A, K, E, KE, F, W, Sc>
623        where
624            Sc: Score,
625        {
626            constraint_ref: ConstraintRef,
627            impact_type: ImpactType,
628            extractor: E,
629            key_extractor: KE,
630            filter: F,
631            weight: W,
632            is_hard: bool,
633            entity_to_matches: HashMap<usize, HashSet<(usize, usize, usize, usize)>>,
634            matches: HashSet<(usize, usize, usize, usize)>,
635            key_to_indices: HashMap<K, HashSet<usize>>,
636            index_to_key: HashMap<usize, K>,
637            _phantom: PhantomData<(S, A, Sc)>,
638        }
639
640        impl<S, A, K, E, KE, F, W, Sc> $struct_name<S, A, K, E, KE, F, W, Sc>
641        where
642            S: 'static,
643            A: Clone + 'static,
644            K: Eq + Hash + Clone,
645            E: Fn(&S) -> &[A],
646            KE: Fn(&A) -> K,
647            F: Fn(&S, &A, &A, &A, &A) -> bool,
648            W: Fn(&A, &A, &A, &A) -> Sc,
649            Sc: Score,
650        {
651            pub fn new(
652                constraint_ref: ConstraintRef,
653                impact_type: ImpactType,
654                extractor: E,
655                key_extractor: KE,
656                filter: F,
657                weight: W,
658                is_hard: bool,
659            ) -> Self {
660                Self {
661                    constraint_ref,
662                    impact_type,
663                    extractor,
664                    key_extractor,
665                    filter,
666                    weight,
667                    is_hard,
668                    entity_to_matches: HashMap::new(),
669                    matches: HashSet::new(),
670                    key_to_indices: HashMap::new(),
671                    index_to_key: HashMap::new(),
672                    _phantom: PhantomData,
673                }
674            }
675
676            #[inline]
677            fn compute_score(&self, a: &A, b: &A, c: &A, d: &A) -> Sc {
678                let base = (self.weight)(a, b, c, d);
679                match self.impact_type {
680                    ImpactType::Penalty => -base,
681                    ImpactType::Reward => base,
682                }
683            }
684
685            fn insert_entity(&mut self, solution: &S, entities: &[A], index: usize) -> Sc {
686                if index >= entities.len() {
687                    return Sc::zero();
688                }
689
690                let entity = &entities[index];
691                let key = (self.key_extractor)(entity);
692
693                self.index_to_key.insert(index, key.clone());
694                self.key_to_indices
695                    .entry(key.clone())
696                    .or_default()
697                    .insert(index);
698
699                let key_to_indices = &self.key_to_indices;
700                let matches = &mut self.matches;
701                let entity_to_matches = &mut self.entity_to_matches;
702                let filter = &self.filter;
703                let weight = &self.weight;
704                let impact_type = self.impact_type;
705
706                let mut total = Sc::zero();
707                if let Some(others) = key_to_indices.get(&key) {
708                    for &i in others {
709                        if i == index {
710                            continue;
711                        }
712                        for &j in others {
713                            if j <= i || j == index {
714                                continue;
715                            }
716                            for &k in others {
717                                if k <= j || k == index {
718                                    continue;
719                                }
720
721                                let mut arr = [index, i, j, k];
722                                arr.sort();
723                                let [a_idx, b_idx, c_idx, d_idx] = arr;
724                                let quad = (a_idx, b_idx, c_idx, d_idx);
725
726                                if matches.contains(&quad) {
727                                    continue;
728                                }
729
730                                let a = &entities[a_idx];
731                                let b = &entities[b_idx];
732                                let c = &entities[c_idx];
733                                let d = &entities[d_idx];
734
735                                if filter(solution, a, b, c, d) && matches.insert(quad) {
736                                    entity_to_matches.entry(a_idx).or_default().insert(quad);
737                                    entity_to_matches.entry(b_idx).or_default().insert(quad);
738                                    entity_to_matches.entry(c_idx).or_default().insert(quad);
739                                    entity_to_matches.entry(d_idx).or_default().insert(quad);
740                                    let base = weight(a, b, c, d);
741                                    let score = match impact_type {
742                                        ImpactType::Penalty => -base,
743                                        ImpactType::Reward => base,
744                                    };
745                                    total = total + score;
746                                }
747                            }
748                        }
749                    }
750                }
751
752                total
753            }
754
755            fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
756                if let Some(key) = self.index_to_key.remove(&index) {
757                    if let Some(indices) = self.key_to_indices.get_mut(&key) {
758                        indices.remove(&index);
759                        if indices.is_empty() {
760                            self.key_to_indices.remove(&key);
761                        }
762                    }
763                }
764
765                let Some(quads) = self.entity_to_matches.remove(&index) else {
766                    return Sc::zero();
767                };
768
769                let mut total = Sc::zero();
770                for quad in quads {
771                    self.matches.remove(&quad);
772
773                    let (i, j, k, l) = quad;
774                    for &other in &[i, j, k, l] {
775                        if other != index {
776                            if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
777                                other_set.remove(&quad);
778                                if other_set.is_empty() {
779                                    self.entity_to_matches.remove(&other);
780                                }
781                            }
782                        }
783                    }
784
785                    if i < entities.len()
786                        && j < entities.len()
787                        && k < entities.len()
788                        && l < entities.len()
789                    {
790                        let score = self.compute_score(
791                            &entities[i],
792                            &entities[j],
793                            &entities[k],
794                            &entities[l],
795                        );
796                        total = total - score;
797                    }
798                }
799
800                total
801            }
802        }
803
804        impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
805            for $struct_name<S, A, K, E, KE, F, W, Sc>
806        where
807            S: Send + Sync + 'static,
808            A: Clone + Debug + Send + Sync + 'static,
809            K: Eq + Hash + Clone + Send + Sync,
810            E: Fn(&S) -> &[A] + Send + Sync,
811            KE: Fn(&A) -> K + Send + Sync,
812            F: Fn(&S, &A, &A, &A, &A) -> bool + Send + Sync,
813            W: Fn(&A, &A, &A, &A) -> Sc + Send + Sync,
814            Sc: Score,
815        {
816            fn evaluate(&self, solution: &S) -> Sc {
817                let entities = (self.extractor)(solution);
818                let mut total = Sc::zero();
819
820                let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
821                for (i, entity) in entities.iter().enumerate() {
822                    let key = (self.key_extractor)(entity);
823                    temp_index.entry(key).or_default().push(i);
824                }
825
826                for indices in temp_index.values() {
827                    for pos_i in 0..indices.len() {
828                        for pos_j in (pos_i + 1)..indices.len() {
829                            for pos_k in (pos_j + 1)..indices.len() {
830                                for pos_l in (pos_k + 1)..indices.len() {
831                                    let i = indices[pos_i];
832                                    let j = indices[pos_j];
833                                    let k = indices[pos_k];
834                                    let l = indices[pos_l];
835                                    let a = &entities[i];
836                                    let b = &entities[j];
837                                    let c = &entities[k];
838                                    let d = &entities[l];
839                                    if (self.filter)(solution, a, b, c, d) {
840                                        total = total + self.compute_score(a, b, c, d);
841                                    }
842                                }
843                            }
844                        }
845                    }
846                }
847
848                total
849            }
850
851            fn match_count(&self, solution: &S) -> usize {
852                let entities = (self.extractor)(solution);
853                let mut count = 0;
854
855                let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
856                for (i, entity) in entities.iter().enumerate() {
857                    let key = (self.key_extractor)(entity);
858                    temp_index.entry(key).or_default().push(i);
859                }
860
861                for indices in temp_index.values() {
862                    for pos_i in 0..indices.len() {
863                        for pos_j in (pos_i + 1)..indices.len() {
864                            for pos_k in (pos_j + 1)..indices.len() {
865                                for pos_l in (pos_k + 1)..indices.len() {
866                                    let i = indices[pos_i];
867                                    let j = indices[pos_j];
868                                    let k = indices[pos_k];
869                                    let l = indices[pos_l];
870                                    if (self.filter)(
871                                        solution,
872                                        &entities[i],
873                                        &entities[j],
874                                        &entities[k],
875                                        &entities[l],
876                                    ) {
877                                        count += 1;
878                                    }
879                                }
880                            }
881                        }
882                    }
883                }
884
885                count
886            }
887
888            fn initialize(&mut self, solution: &S) -> Sc {
889                self.reset();
890                let entities = (self.extractor)(solution);
891                let mut total = Sc::zero();
892                for i in 0..entities.len() {
893                    total = total + self.insert_entity(solution, entities, i);
894                }
895                total
896            }
897
898            fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
899                let entities = (self.extractor)(solution);
900                self.insert_entity(solution, entities, entity_index)
901            }
902
903            fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
904                let entities = (self.extractor)(solution);
905                self.retract_entity(entities, entity_index)
906            }
907
908            fn reset(&mut self) {
909                self.entity_to_matches.clear();
910                self.matches.clear();
911                self.key_to_indices.clear();
912                self.index_to_key.clear();
913            }
914
915            fn name(&self) -> &str {
916                &self.constraint_ref.name
917            }
918
919            fn is_hard(&self) -> bool {
920                self.is_hard
921            }
922
923            fn constraint_ref(&self) -> ConstraintRef {
924                self.constraint_ref.clone()
925            }
926
927            fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
928                $crate::impl_get_matches_nary!(quad: self, solution)
929            }
930        }
931
932        impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
933            for $struct_name<S, A, K, E, KE, F, W, Sc>
934        {
935            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
936                f.debug_struct(stringify!($struct_name))
937                    .field("name", &self.constraint_ref.name)
938                    .field("impact_type", &self.impact_type)
939                    .field("match_count", &self.matches.len())
940                    .finish()
941            }
942        }
943    };
944
945    // ==================== PENTA (5-arity) ====================
946    (penta, $struct_name:ident) => {
947        /// Zero-erasure incremental penta-constraint for self-joins.
948        ///
949        /// All function types are concrete generics - no Arc, no dyn, fully monomorphized.
950        /// Uses key-based indexing: entities are grouped by join key for O(k) lookups.
951        /// Quintuples are ordered as (i, j, k, l, m) where i < j < k < l < m to avoid duplicates.
952        pub struct $struct_name<S, A, K, E, KE, F, W, Sc>
953        where
954            Sc: Score,
955        {
956            constraint_ref: ConstraintRef,
957            impact_type: ImpactType,
958            extractor: E,
959            key_extractor: KE,
960            filter: F,
961            weight: W,
962            is_hard: bool,
963            entity_to_matches: HashMap<usize, HashSet<(usize, usize, usize, usize, usize)>>,
964            matches: HashSet<(usize, usize, usize, usize, usize)>,
965            key_to_indices: HashMap<K, HashSet<usize>>,
966            index_to_key: HashMap<usize, K>,
967            _phantom: PhantomData<(S, A, Sc)>,
968        }
969
970        impl<S, A, K, E, KE, F, W, Sc> $struct_name<S, A, K, E, KE, F, W, Sc>
971        where
972            S: 'static,
973            A: Clone + 'static,
974            K: Eq + Hash + Clone,
975            E: Fn(&S) -> &[A],
976            KE: Fn(&A) -> K,
977            F: Fn(&S, &A, &A, &A, &A, &A) -> bool,
978            W: Fn(&A, &A, &A, &A, &A) -> Sc,
979            Sc: Score,
980        {
981            pub fn new(
982                constraint_ref: ConstraintRef,
983                impact_type: ImpactType,
984                extractor: E,
985                key_extractor: KE,
986                filter: F,
987                weight: W,
988                is_hard: bool,
989            ) -> Self {
990                Self {
991                    constraint_ref,
992                    impact_type,
993                    extractor,
994                    key_extractor,
995                    filter,
996                    weight,
997                    is_hard,
998                    entity_to_matches: HashMap::new(),
999                    matches: HashSet::new(),
1000                    key_to_indices: HashMap::new(),
1001                    index_to_key: HashMap::new(),
1002                    _phantom: PhantomData,
1003                }
1004            }
1005
1006            #[inline]
1007            fn compute_score(&self, a: &A, b: &A, c: &A, d: &A, e: &A) -> Sc {
1008                let base = (self.weight)(a, b, c, d, e);
1009                match self.impact_type {
1010                    ImpactType::Penalty => -base,
1011                    ImpactType::Reward => base,
1012                }
1013            }
1014
1015            fn insert_entity(&mut self, solution: &S, entities: &[A], index: usize) -> Sc {
1016                if index >= entities.len() {
1017                    return Sc::zero();
1018                }
1019
1020                let entity = &entities[index];
1021                let key = (self.key_extractor)(entity);
1022
1023                self.index_to_key.insert(index, key.clone());
1024                self.key_to_indices
1025                    .entry(key.clone())
1026                    .or_default()
1027                    .insert(index);
1028
1029                let key_to_indices = &self.key_to_indices;
1030                let matches = &mut self.matches;
1031                let entity_to_matches = &mut self.entity_to_matches;
1032                let filter = &self.filter;
1033                let weight = &self.weight;
1034                let impact_type = self.impact_type;
1035
1036                let mut total = Sc::zero();
1037                if let Some(others) = key_to_indices.get(&key) {
1038                    for &i in others {
1039                        if i == index {
1040                            continue;
1041                        }
1042                        for &j in others {
1043                            if j <= i || j == index {
1044                                continue;
1045                            }
1046                            for &k in others {
1047                                if k <= j || k == index {
1048                                    continue;
1049                                }
1050                                for &l in others {
1051                                    if l <= k || l == index {
1052                                        continue;
1053                                    }
1054
1055                                    let mut arr = [index, i, j, k, l];
1056                                    arr.sort();
1057                                    let [a_idx, b_idx, c_idx, d_idx, e_idx] = arr;
1058                                    let penta = (a_idx, b_idx, c_idx, d_idx, e_idx);
1059
1060                                    if matches.contains(&penta) {
1061                                        continue;
1062                                    }
1063
1064                                    let a = &entities[a_idx];
1065                                    let b = &entities[b_idx];
1066                                    let c = &entities[c_idx];
1067                                    let d = &entities[d_idx];
1068                                    let e = &entities[e_idx];
1069
1070                                    if filter(solution, a, b, c, d, e) && matches.insert(penta) {
1071                                        entity_to_matches.entry(a_idx).or_default().insert(penta);
1072                                        entity_to_matches.entry(b_idx).or_default().insert(penta);
1073                                        entity_to_matches.entry(c_idx).or_default().insert(penta);
1074                                        entity_to_matches.entry(d_idx).or_default().insert(penta);
1075                                        entity_to_matches.entry(e_idx).or_default().insert(penta);
1076                                        let base = weight(a, b, c, d, e);
1077                                        let score = match impact_type {
1078                                            ImpactType::Penalty => -base,
1079                                            ImpactType::Reward => base,
1080                                        };
1081                                        total = total + score;
1082                                    }
1083                                }
1084                            }
1085                        }
1086                    }
1087                }
1088
1089                total
1090            }
1091
1092            fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
1093                if let Some(key) = self.index_to_key.remove(&index) {
1094                    if let Some(indices) = self.key_to_indices.get_mut(&key) {
1095                        indices.remove(&index);
1096                        if indices.is_empty() {
1097                            self.key_to_indices.remove(&key);
1098                        }
1099                    }
1100                }
1101
1102                let Some(pentas) = self.entity_to_matches.remove(&index) else {
1103                    return Sc::zero();
1104                };
1105
1106                let mut total = Sc::zero();
1107                for penta in pentas {
1108                    self.matches.remove(&penta);
1109
1110                    let (i, j, k, l, m) = penta;
1111                    for &other in &[i, j, k, l, m] {
1112                        if other != index {
1113                            if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
1114                                other_set.remove(&penta);
1115                                if other_set.is_empty() {
1116                                    self.entity_to_matches.remove(&other);
1117                                }
1118                            }
1119                        }
1120                    }
1121
1122                    if i < entities.len()
1123                        && j < entities.len()
1124                        && k < entities.len()
1125                        && l < entities.len()
1126                        && m < entities.len()
1127                    {
1128                        let score = self.compute_score(
1129                            &entities[i],
1130                            &entities[j],
1131                            &entities[k],
1132                            &entities[l],
1133                            &entities[m],
1134                        );
1135                        total = total - score;
1136                    }
1137                }
1138
1139                total
1140            }
1141        }
1142
1143        impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
1144            for $struct_name<S, A, K, E, KE, F, W, Sc>
1145        where
1146            S: Send + Sync + 'static,
1147            A: Clone + Debug + Send + Sync + 'static,
1148            K: Eq + Hash + Clone + Send + Sync,
1149            E: Fn(&S) -> &[A] + Send + Sync,
1150            KE: Fn(&A) -> K + Send + Sync,
1151            F: Fn(&S, &A, &A, &A, &A, &A) -> bool + Send + Sync,
1152            W: Fn(&A, &A, &A, &A, &A) -> Sc + Send + Sync,
1153            Sc: Score,
1154        {
1155            fn evaluate(&self, solution: &S) -> Sc {
1156                let entities = (self.extractor)(solution);
1157                let mut total = Sc::zero();
1158
1159                let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
1160                for (i, entity) in entities.iter().enumerate() {
1161                    let key = (self.key_extractor)(entity);
1162                    temp_index.entry(key).or_default().push(i);
1163                }
1164
1165                for indices in temp_index.values() {
1166                    for pos_i in 0..indices.len() {
1167                        for pos_j in (pos_i + 1)..indices.len() {
1168                            for pos_k in (pos_j + 1)..indices.len() {
1169                                for pos_l in (pos_k + 1)..indices.len() {
1170                                    for pos_m in (pos_l + 1)..indices.len() {
1171                                        let i = indices[pos_i];
1172                                        let j = indices[pos_j];
1173                                        let k = indices[pos_k];
1174                                        let l = indices[pos_l];
1175                                        let m = indices[pos_m];
1176                                        let a = &entities[i];
1177                                        let b = &entities[j];
1178                                        let c = &entities[k];
1179                                        let d = &entities[l];
1180                                        let e = &entities[m];
1181                                        if (self.filter)(solution, a, b, c, d, e) {
1182                                            total = total + self.compute_score(a, b, c, d, e);
1183                                        }
1184                                    }
1185                                }
1186                            }
1187                        }
1188                    }
1189                }
1190
1191                total
1192            }
1193
1194            fn match_count(&self, solution: &S) -> usize {
1195                let entities = (self.extractor)(solution);
1196                let mut count = 0;
1197
1198                let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
1199                for (i, entity) in entities.iter().enumerate() {
1200                    let key = (self.key_extractor)(entity);
1201                    temp_index.entry(key).or_default().push(i);
1202                }
1203
1204                for indices in temp_index.values() {
1205                    for pos_i in 0..indices.len() {
1206                        for pos_j in (pos_i + 1)..indices.len() {
1207                            for pos_k in (pos_j + 1)..indices.len() {
1208                                for pos_l in (pos_k + 1)..indices.len() {
1209                                    for pos_m in (pos_l + 1)..indices.len() {
1210                                        let i = indices[pos_i];
1211                                        let j = indices[pos_j];
1212                                        let k = indices[pos_k];
1213                                        let l = indices[pos_l];
1214                                        let m = indices[pos_m];
1215                                        if (self.filter)(
1216                                            solution,
1217                                            &entities[i],
1218                                            &entities[j],
1219                                            &entities[k],
1220                                            &entities[l],
1221                                            &entities[m],
1222                                        ) {
1223                                            count += 1;
1224                                        }
1225                                    }
1226                                }
1227                            }
1228                        }
1229                    }
1230                }
1231
1232                count
1233            }
1234
1235            fn initialize(&mut self, solution: &S) -> Sc {
1236                self.reset();
1237                let entities = (self.extractor)(solution);
1238                let mut total = Sc::zero();
1239                for i in 0..entities.len() {
1240                    total = total + self.insert_entity(solution, entities, i);
1241                }
1242                total
1243            }
1244
1245            fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
1246                let entities = (self.extractor)(solution);
1247                self.insert_entity(solution, entities, entity_index)
1248            }
1249
1250            fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
1251                let entities = (self.extractor)(solution);
1252                self.retract_entity(entities, entity_index)
1253            }
1254
1255            fn reset(&mut self) {
1256                self.entity_to_matches.clear();
1257                self.matches.clear();
1258                self.key_to_indices.clear();
1259                self.index_to_key.clear();
1260            }
1261
1262            fn name(&self) -> &str {
1263                &self.constraint_ref.name
1264            }
1265
1266            fn is_hard(&self) -> bool {
1267                self.is_hard
1268            }
1269
1270            fn constraint_ref(&self) -> ConstraintRef {
1271                self.constraint_ref.clone()
1272            }
1273
1274            fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
1275                $crate::impl_get_matches_nary!(penta: self, solution)
1276            }
1277        }
1278
1279        impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
1280            for $struct_name<S, A, K, E, KE, F, W, Sc>
1281        {
1282            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1283                f.debug_struct(stringify!($struct_name))
1284                    .field("name", &self.constraint_ref.name)
1285                    .field("impact_type", &self.impact_type)
1286                    .field("match_count", &self.matches.len())
1287                    .finish()
1288            }
1289        }
1290    };
1291}
1292
1293pub use impl_incremental_nary_constraint;
1294
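// Names the macro expansions below refer to unqualified. `macro_rules!` item paths resolve
// at the invocation site, so these must be in scope in this module.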
1296use std::collections::{HashMap, HashSet};
1297use std::fmt::Debug;
1298use std::hash::Hash;
1299use std::marker::PhantomData;
1300
1301use solverforge_core::score::Score;
1302use solverforge_core::{ConstraintRef, ImpactType};
1303
1304use crate::api::analysis::DetailedConstraintMatch;
1305use crate::api::constraint_set::IncrementalConstraint;
1306
1307impl_incremental_nary_constraint!(bi, IncrementalBiConstraint);
1308impl_incremental_nary_constraint!(tri, IncrementalTriConstraint);
1309impl_incremental_nary_constraint!(quad, IncrementalQuadConstraint);
1310impl_incremental_nary_constraint!(penta, IncrementalPentaConstraint);