solverforge_scoring/constraint/
penta_incremental.rs

1//! Zero-erasure incremental penta-constraint for self-join quintuple evaluation.
2//!
3//! All function types are concrete generics - no trait objects, no Arc.
4//! Uses key-based indexing: entities are grouped by join key for O(k) lookups.
5
6use std::collections::{HashMap, HashSet};
7use std::fmt::Debug;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::analysis::DetailedConstraintMatch;
15use crate::api::constraint_set::IncrementalConstraint;
16
/// Indices of five entities forming one match, written in canonical order
/// `(i, j, k, l, m)` with `i < j < k < l < m` so each combination has exactly
/// one representation.
type Quintuple = (usize, usize, usize, usize, usize);
19
20/// Zero-erasure incremental penta-constraint for self-joins.
21///
22/// All function types are concrete generics - no Arc, no dyn, fully monomorphized.
23/// Uses key-based indexing: entities are grouped by join key for O(k) lookups
24/// where k is the number of entities sharing the same key.
25///
26/// Quintuples are ordered as (i, j, k, l, m) where i < j < k < l < m to avoid duplicates.
27///
28/// # Type Parameters
29///
30/// - `S` - Solution type
31/// - `A` - Entity type
32/// - `K` - Key type for grouping (entities with same key form quintuples)
33/// - `E` - Extractor function `Fn(&S) -> &[A]`
34/// - `KE` - Key extractor function `Fn(&A) -> K`
35/// - `F` - Filter function `Fn(&A, &A, &A, &A, &A) -> bool`
36/// - `W` - Weight function `Fn(&A, &A, &A, &A, &A) -> Sc`
37/// - `Sc` - Score type
38///
39/// # Example
40///
41/// ```
42/// use solverforge_scoring::constraint::penta_incremental::IncrementalPentaConstraint;
43/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
44/// use solverforge_core::{ConstraintRef, ImpactType};
45/// use solverforge_core::score::SimpleScore;
46///
47/// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
48/// struct Task { team: u32 }
49///
50/// #[derive(Clone)]
51/// struct Solution { tasks: Vec<Task> }
52///
53/// // Penalize when five tasks are on the same team
54/// let constraint = IncrementalPentaConstraint::new(
55///     ConstraintRef::new("", "Team clustering"),
56///     ImpactType::Penalty,
57///     |s: &Solution| s.tasks.as_slice(),
58///     |t: &Task| t.team,  // Group by team
59///     |_a: &Task, _b: &Task, _c: &Task, _d: &Task, _e: &Task| true,
60///     |_a: &Task, _b: &Task, _c: &Task, _d: &Task, _e: &Task| SimpleScore::of(1),
61///     false,
62/// );
63///
64/// let solution = Solution {
65///     tasks: vec![
66///         Task { team: 1 },
67///         Task { team: 1 },
68///         Task { team: 1 },
69///         Task { team: 1 },
70///         Task { team: 1 },
71///         Task { team: 2 },
72///     ],
73/// };
74///
75/// // One quintuple: (0, 1, 2, 3, 4) all on team 1 = -1 penalty
76/// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
77/// ```
78pub struct IncrementalPentaConstraint<S, A, K, E, KE, F, W, Sc>
79where
80    Sc: Score,
81{
82    constraint_ref: ConstraintRef,
83    impact_type: ImpactType,
84    extractor: E,
85    key_extractor: KE,
86    filter: F,
87    weight: W,
88    is_hard: bool,
89    /// entity_index -> set of quintuples involving this entity
90    entity_to_matches: HashMap<usize, HashSet<Quintuple>>,
91    /// All matched quintuples (i, j, k, l, m) where i < j < k < l < m
92    matches: HashSet<Quintuple>,
93    /// Key -> set of entity indices with that key (for O(k) lookup)
94    key_to_indices: HashMap<K, HashSet<usize>>,
95    /// entity_index -> key (for cleanup on retract)
96    index_to_key: HashMap<usize, K>,
97    _phantom: PhantomData<(S, A, Sc)>,
98}
99
100impl<S, A, K, E, KE, F, W, Sc> IncrementalPentaConstraint<S, A, K, E, KE, F, W, Sc>
101where
102    S: 'static,
103    A: Clone + 'static,
104    K: Eq + Hash + Clone,
105    E: Fn(&S) -> &[A],
106    KE: Fn(&A) -> K,
107    F: Fn(&A, &A, &A, &A, &A) -> bool,
108    W: Fn(&A, &A, &A, &A, &A) -> Sc,
109    Sc: Score,
110{
111    /// Creates a new zero-erasure incremental penta-constraint.
112    pub fn new(
113        constraint_ref: ConstraintRef,
114        impact_type: ImpactType,
115        extractor: E,
116        key_extractor: KE,
117        filter: F,
118        weight: W,
119        is_hard: bool,
120    ) -> Self {
121        Self {
122            constraint_ref,
123            impact_type,
124            extractor,
125            key_extractor,
126            filter,
127            weight,
128            is_hard,
129            entity_to_matches: HashMap::new(),
130            matches: HashSet::new(),
131            key_to_indices: HashMap::new(),
132            index_to_key: HashMap::new(),
133            _phantom: PhantomData,
134        }
135    }
136
137    #[inline]
138    fn compute_score(&self, a: &A, b: &A, c: &A, d: &A, e: &A) -> Sc {
139        let base = (self.weight)(a, b, c, d, e);
140        match self.impact_type {
141            ImpactType::Penalty => -base,
142            ImpactType::Reward => base,
143        }
144    }
145
146    /// Insert entity and find matches with other entity quads sharing the same key.
147    fn insert_entity(&mut self, entities: &[A], index: usize) -> Sc {
148        if index >= entities.len() {
149            return Sc::zero();
150        }
151
152        let entity = &entities[index];
153        let key = (self.key_extractor)(entity);
154
155        // Track this entity's key
156        self.index_to_key.insert(index, key.clone());
157
158        // Add this entity to the key index FIRST
159        self.key_to_indices
160            .entry(key.clone())
161            .or_default()
162            .insert(index);
163
164        // Split borrows to allow simultaneous read of key_to_indices and mutation of matches
165        let key_to_indices = &self.key_to_indices;
166        let matches = &mut self.matches;
167        let entity_to_matches = &mut self.entity_to_matches;
168        let filter = &self.filter;
169        let weight = &self.weight;
170        let impact_type = self.impact_type;
171
172        // Find matches with all quads of other entities having the same key (zero allocation)
173        let mut total = Sc::zero();
174        if let Some(others) = key_to_indices.get(&key) {
175            // Iterate over all quads (i, j, k, l) where i < j < k < l, excluding current index
176            for &i in others {
177                if i == index {
178                    continue;
179                }
180                for &j in others {
181                    if j <= i || j == index {
182                        continue;
183                    }
184                    for &k in others {
185                        if k <= j || k == index {
186                            continue;
187                        }
188                        for &l in others {
189                            if l <= k || l == index {
190                                continue;
191                            }
192
193                            // Determine canonical ordering for this quintuple
194                            let mut indices = [index, i, j, k, l];
195                            indices.sort();
196                            let [a_idx, b_idx, c_idx, d_idx, e_idx] = indices;
197
198                            let penta = (a_idx, b_idx, c_idx, d_idx, e_idx);
199
200                            // Skip if already matched
201                            if matches.contains(&penta) {
202                                continue;
203                            }
204
205                            let a = &entities[a_idx];
206                            let b = &entities[b_idx];
207                            let c = &entities[c_idx];
208                            let d = &entities[d_idx];
209                            let e = &entities[e_idx];
210
211                            if filter(a, b, c, d, e) && matches.insert(penta) {
212                                entity_to_matches.entry(a_idx).or_default().insert(penta);
213                                entity_to_matches.entry(b_idx).or_default().insert(penta);
214                                entity_to_matches.entry(c_idx).or_default().insert(penta);
215                                entity_to_matches.entry(d_idx).or_default().insert(penta);
216                                entity_to_matches.entry(e_idx).or_default().insert(penta);
217                                let base = weight(a, b, c, d, e);
218                                let score = match impact_type {
219                                    ImpactType::Penalty => -base,
220                                    ImpactType::Reward => base,
221                                };
222                                total = total + score;
223                            }
224                        }
225                    }
226                }
227            }
228        }
229
230        total
231    }
232
233    /// Retract entity and remove all its matches.
234    fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
235        // Remove from key index
236        if let Some(key) = self.index_to_key.remove(&index) {
237            if let Some(indices) = self.key_to_indices.get_mut(&key) {
238                indices.remove(&index);
239                if indices.is_empty() {
240                    self.key_to_indices.remove(&key);
241                }
242            }
243        }
244
245        // Remove all matches involving this entity
246        let Some(pentas) = self.entity_to_matches.remove(&index) else {
247            return Sc::zero();
248        };
249
250        let mut total = Sc::zero();
251        for penta in pentas {
252            self.matches.remove(&penta);
253
254            // Remove from other entities' match sets
255            let (i, j, k, l, m) = penta;
256            for &other in &[i, j, k, l, m] {
257                if other != index {
258                    if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
259                        other_set.remove(&penta);
260                        if other_set.is_empty() {
261                            self.entity_to_matches.remove(&other);
262                        }
263                    }
264                }
265            }
266
267            // Compute reverse delta
268            if i < entities.len()
269                && j < entities.len()
270                && k < entities.len()
271                && l < entities.len()
272                && m < entities.len()
273            {
274                let score = self.compute_score(
275                    &entities[i],
276                    &entities[j],
277                    &entities[k],
278                    &entities[l],
279                    &entities[m],
280                );
281                total = total - score;
282            }
283        }
284
285        total
286    }
287}
288
289impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
290    for IncrementalPentaConstraint<S, A, K, E, KE, F, W, Sc>
291where
292    S: Send + Sync + 'static,
293    A: Clone + Debug + Send + Sync + 'static,
294    K: Eq + Hash + Clone + Send + Sync,
295    E: Fn(&S) -> &[A] + Send + Sync,
296    KE: Fn(&A) -> K + Send + Sync,
297    F: Fn(&A, &A, &A, &A, &A) -> bool + Send + Sync,
298    W: Fn(&A, &A, &A, &A, &A) -> Sc + Send + Sync,
299    Sc: Score,
300{
301    fn evaluate(&self, solution: &S) -> Sc {
302        let entities = (self.extractor)(solution);
303        let mut total = Sc::zero();
304
305        // Build temporary key index for evaluation
306        let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
307        for (i, entity) in entities.iter().enumerate() {
308            let key = (self.key_extractor)(entity);
309            temp_index.entry(key).or_default().push(i);
310        }
311
312        // Evaluate quintuples within each key group
313        for indices in temp_index.values() {
314            for pos_i in 0..indices.len() {
315                for pos_j in (pos_i + 1)..indices.len() {
316                    for pos_k in (pos_j + 1)..indices.len() {
317                        for pos_l in (pos_k + 1)..indices.len() {
318                            for pos_m in (pos_l + 1)..indices.len() {
319                                let i = indices[pos_i];
320                                let j = indices[pos_j];
321                                let k = indices[pos_k];
322                                let l = indices[pos_l];
323                                let m = indices[pos_m];
324                                let a = &entities[i];
325                                let b = &entities[j];
326                                let c = &entities[k];
327                                let d = &entities[l];
328                                let e = &entities[m];
329                                if (self.filter)(a, b, c, d, e) {
330                                    total = total + self.compute_score(a, b, c, d, e);
331                                }
332                            }
333                        }
334                    }
335                }
336            }
337        }
338
339        total
340    }
341
342    fn match_count(&self, solution: &S) -> usize {
343        let entities = (self.extractor)(solution);
344        let mut count = 0;
345
346        // Build temporary key index
347        let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
348        for (i, entity) in entities.iter().enumerate() {
349            let key = (self.key_extractor)(entity);
350            temp_index.entry(key).or_default().push(i);
351        }
352
353        // Count matches within each key group
354        for indices in temp_index.values() {
355            for pos_i in 0..indices.len() {
356                for pos_j in (pos_i + 1)..indices.len() {
357                    for pos_k in (pos_j + 1)..indices.len() {
358                        for pos_l in (pos_k + 1)..indices.len() {
359                            for pos_m in (pos_l + 1)..indices.len() {
360                                let i = indices[pos_i];
361                                let j = indices[pos_j];
362                                let k = indices[pos_k];
363                                let l = indices[pos_l];
364                                let m = indices[pos_m];
365                                if (self.filter)(
366                                    &entities[i],
367                                    &entities[j],
368                                    &entities[k],
369                                    &entities[l],
370                                    &entities[m],
371                                ) {
372                                    count += 1;
373                                }
374                            }
375                        }
376                    }
377                }
378            }
379        }
380
381        count
382    }
383
384    fn initialize(&mut self, solution: &S) -> Sc {
385        self.reset();
386
387        let entities = (self.extractor)(solution);
388        let mut total = Sc::zero();
389        for i in 0..entities.len() {
390            total = total + self.insert_entity(entities, i);
391        }
392        total
393    }
394
395    fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
396        let entities = (self.extractor)(solution);
397        self.insert_entity(entities, entity_index)
398    }
399
400    fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
401        let entities = (self.extractor)(solution);
402        self.retract_entity(entities, entity_index)
403    }
404
405    fn reset(&mut self) {
406        self.entity_to_matches.clear();
407        self.matches.clear();
408        self.key_to_indices.clear();
409        self.index_to_key.clear();
410    }
411
412    fn name(&self) -> &str {
413        &self.constraint_ref.name
414    }
415
416    fn is_hard(&self) -> bool {
417        self.is_hard
418    }
419
420    fn constraint_ref(&self) -> ConstraintRef {
421        self.constraint_ref.clone()
422    }
423
424    fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
425        impl_get_matches_nary!(penta: self, solution)
426    }
427}
428
429impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
430    for IncrementalPentaConstraint<S, A, K, E, KE, F, W, Sc>
431{
432    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
433        f.debug_struct("IncrementalPentaConstraint")
434            .field("name", &self.constraint_ref.name)
435            .field("impact_type", &self.impact_type)
436            .field("match_count", &self.matches.len())
437            .finish()
438    }
439}