solverforge_scoring/constraint/
quad_incremental.rs

1//! Zero-erasure incremental quad-constraint for self-join quadruple evaluation.
2//!
3//! All function types are concrete generics - no trait objects, no Arc.
4//! Uses key-based indexing: entities are grouped by join key for O(k) lookups.
5
6use std::collections::{HashMap, HashSet};
7use std::fmt::Debug;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::analysis::DetailedConstraintMatch;
15use crate::api::constraint_set::IncrementalConstraint;
16
17/// Zero-erasure incremental quad-constraint for self-joins.
18///
19/// All function types are concrete generics - no Arc, no dyn, fully monomorphized.
20/// Uses key-based indexing: entities are grouped by join key for O(k) lookups
21/// where k is the number of entities sharing the same key.
22///
23/// Quadruples are ordered as (i, j, k, l) where i < j < k < l to avoid duplicates.
24///
25/// # Type Parameters
26///
27/// - `S` - Solution type
28/// - `A` - Entity type
29/// - `K` - Key type for grouping (entities with same key form quadruples)
30/// - `E` - Extractor function `Fn(&S) -> &[A]`
31/// - `KE` - Key extractor function `Fn(&A) -> K`
32/// - `F` - Filter function `Fn(&A, &A, &A, &A) -> bool`
33/// - `W` - Weight function `Fn(&A, &A, &A, &A) -> Sc`
34/// - `Sc` - Score type
35///
36/// # Example
37///
38/// ```
39/// use solverforge_scoring::constraint::quad_incremental::IncrementalQuadConstraint;
40/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
41/// use solverforge_core::{ConstraintRef, ImpactType};
42/// use solverforge_core::score::SimpleScore;
43///
44/// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
45/// struct Task { team: u32 }
46///
47/// #[derive(Clone)]
48/// struct Solution { tasks: Vec<Task> }
49///
50/// // Penalize when four tasks are on the same team
51/// let constraint = IncrementalQuadConstraint::new(
52///     ConstraintRef::new("", "Team clustering"),
53///     ImpactType::Penalty,
54///     |s: &Solution| s.tasks.as_slice(),
55///     |t: &Task| t.team,  // Group by team
56///     |_a: &Task, _b: &Task, _c: &Task, _d: &Task| true,  // All quads in same group match
57///     |_a: &Task, _b: &Task, _c: &Task, _d: &Task| SimpleScore::of(1),
58///     false,
59/// );
60///
61/// let solution = Solution {
62///     tasks: vec![
63///         Task { team: 1 },
64///         Task { team: 1 },
65///         Task { team: 1 },
66///         Task { team: 1 },
67///         Task { team: 2 },
68///     ],
69/// };
70///
71/// // One quadruple: (0, 1, 2, 3) all on team 1 = -1 penalty
72/// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
73/// ```
74pub struct IncrementalQuadConstraint<S, A, K, E, KE, F, W, Sc>
75where
76    Sc: Score,
77{
78    constraint_ref: ConstraintRef,
79    impact_type: ImpactType,
80    extractor: E,
81    key_extractor: KE,
82    filter: F,
83    weight: W,
84    is_hard: bool,
85    /// entity_index -> set of (i, j, k, l) quadruples involving this entity
86    entity_to_matches: HashMap<usize, HashSet<(usize, usize, usize, usize)>>,
87    /// All matched quadruples (i, j, k, l) where i < j < k < l
88    matches: HashSet<(usize, usize, usize, usize)>,
89    /// Key -> set of entity indices with that key (for O(k) lookup)
90    key_to_indices: HashMap<K, HashSet<usize>>,
91    /// entity_index -> key (for cleanup on retract)
92    index_to_key: HashMap<usize, K>,
93    _phantom: PhantomData<(S, A, Sc)>,
94}
95
96impl<S, A, K, E, KE, F, W, Sc> IncrementalQuadConstraint<S, A, K, E, KE, F, W, Sc>
97where
98    S: 'static,
99    A: Clone + 'static,
100    K: Eq + Hash + Clone,
101    E: Fn(&S) -> &[A],
102    KE: Fn(&A) -> K,
103    F: Fn(&A, &A, &A, &A) -> bool,
104    W: Fn(&A, &A, &A, &A) -> Sc,
105    Sc: Score,
106{
107    /// Creates a new zero-erasure incremental quad-constraint.
108    pub fn new(
109        constraint_ref: ConstraintRef,
110        impact_type: ImpactType,
111        extractor: E,
112        key_extractor: KE,
113        filter: F,
114        weight: W,
115        is_hard: bool,
116    ) -> Self {
117        Self {
118            constraint_ref,
119            impact_type,
120            extractor,
121            key_extractor,
122            filter,
123            weight,
124            is_hard,
125            entity_to_matches: HashMap::new(),
126            matches: HashSet::new(),
127            key_to_indices: HashMap::new(),
128            index_to_key: HashMap::new(),
129            _phantom: PhantomData,
130        }
131    }
132
133    #[inline]
134    fn compute_score(&self, a: &A, b: &A, c: &A, d: &A) -> Sc {
135        let base = (self.weight)(a, b, c, d);
136        match self.impact_type {
137            ImpactType::Penalty => -base,
138            ImpactType::Reward => base,
139        }
140    }
141
142    /// Insert entity and find matches with other entity triples sharing the same key.
143    fn insert_entity(&mut self, entities: &[A], index: usize) -> Sc {
144        if index >= entities.len() {
145            return Sc::zero();
146        }
147
148        let entity = &entities[index];
149        let key = (self.key_extractor)(entity);
150
151        // Track this entity's key
152        self.index_to_key.insert(index, key.clone());
153
154        // Add this entity to the key index FIRST
155        self.key_to_indices
156            .entry(key.clone())
157            .or_default()
158            .insert(index);
159
160        // Split borrows to allow simultaneous read of key_to_indices and mutation of matches
161        let key_to_indices = &self.key_to_indices;
162        let matches = &mut self.matches;
163        let entity_to_matches = &mut self.entity_to_matches;
164        let filter = &self.filter;
165        let weight = &self.weight;
166        let impact_type = self.impact_type;
167
168        // Find matches with all triples of other entities having the same key (zero allocation)
169        let mut total = Sc::zero();
170        if let Some(others) = key_to_indices.get(&key) {
171            // Iterate over all triples (i, j, k) where i < j < k, excluding current index
172            for &i in others {
173                if i == index {
174                    continue;
175                }
176                for &j in others {
177                    if j <= i || j == index {
178                        continue;
179                    }
180                    for &k in others {
181                        if k <= j || k == index {
182                            continue;
183                        }
184
185                        // Determine canonical ordering for this quadruple
186                        let mut indices = [index, i, j, k];
187                        indices.sort();
188                        let [a_idx, b_idx, c_idx, d_idx] = indices;
189
190                        let quad = (a_idx, b_idx, c_idx, d_idx);
191
192                        // Skip if already matched
193                        if matches.contains(&quad) {
194                            continue;
195                        }
196
197                        let a = &entities[a_idx];
198                        let b = &entities[b_idx];
199                        let c = &entities[c_idx];
200                        let d = &entities[d_idx];
201
202                        if filter(a, b, c, d) && matches.insert(quad) {
203                            entity_to_matches.entry(a_idx).or_default().insert(quad);
204                            entity_to_matches.entry(b_idx).or_default().insert(quad);
205                            entity_to_matches.entry(c_idx).or_default().insert(quad);
206                            entity_to_matches.entry(d_idx).or_default().insert(quad);
207                            let base = weight(a, b, c, d);
208                            let score = match impact_type {
209                                ImpactType::Penalty => -base,
210                                ImpactType::Reward => base,
211                            };
212                            total = total + score;
213                        }
214                    }
215                }
216            }
217        }
218
219        total
220    }
221
222    /// Retract entity and remove all its matches.
223    fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
224        // Remove from key index
225        if let Some(key) = self.index_to_key.remove(&index) {
226            if let Some(indices) = self.key_to_indices.get_mut(&key) {
227                indices.remove(&index);
228                if indices.is_empty() {
229                    self.key_to_indices.remove(&key);
230                }
231            }
232        }
233
234        // Remove all matches involving this entity
235        let Some(quads) = self.entity_to_matches.remove(&index) else {
236            return Sc::zero();
237        };
238
239        let mut total = Sc::zero();
240        for quad in quads {
241            self.matches.remove(&quad);
242
243            // Remove from other entities' match sets
244            let (i, j, k, l) = quad;
245            for &other in &[i, j, k, l] {
246                if other != index {
247                    if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
248                        other_set.remove(&quad);
249                        if other_set.is_empty() {
250                            self.entity_to_matches.remove(&other);
251                        }
252                    }
253                }
254            }
255
256            // Compute reverse delta
257            if i < entities.len() && j < entities.len() && k < entities.len() && l < entities.len()
258            {
259                let score =
260                    self.compute_score(&entities[i], &entities[j], &entities[k], &entities[l]);
261                total = total - score;
262            }
263        }
264
265        total
266    }
267}
268
269impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
270    for IncrementalQuadConstraint<S, A, K, E, KE, F, W, Sc>
271where
272    S: Send + Sync + 'static,
273    A: Clone + Debug + Send + Sync + 'static,
274    K: Eq + Hash + Clone + Send + Sync,
275    E: Fn(&S) -> &[A] + Send + Sync,
276    KE: Fn(&A) -> K + Send + Sync,
277    F: Fn(&A, &A, &A, &A) -> bool + Send + Sync,
278    W: Fn(&A, &A, &A, &A) -> Sc + Send + Sync,
279    Sc: Score,
280{
281    fn evaluate(&self, solution: &S) -> Sc {
282        let entities = (self.extractor)(solution);
283        let mut total = Sc::zero();
284
285        // Build temporary key index for evaluation
286        let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
287        for (i, entity) in entities.iter().enumerate() {
288            let key = (self.key_extractor)(entity);
289            temp_index.entry(key).or_default().push(i);
290        }
291
292        // Evaluate quadruples within each key group
293        for indices in temp_index.values() {
294            for pos_i in 0..indices.len() {
295                for pos_j in (pos_i + 1)..indices.len() {
296                    for pos_k in (pos_j + 1)..indices.len() {
297                        for pos_l in (pos_k + 1)..indices.len() {
298                            let i = indices[pos_i];
299                            let j = indices[pos_j];
300                            let k = indices[pos_k];
301                            let l = indices[pos_l];
302                            let a = &entities[i];
303                            let b = &entities[j];
304                            let c = &entities[k];
305                            let d = &entities[l];
306                            if (self.filter)(a, b, c, d) {
307                                total = total + self.compute_score(a, b, c, d);
308                            }
309                        }
310                    }
311                }
312            }
313        }
314
315        total
316    }
317
318    fn match_count(&self, solution: &S) -> usize {
319        let entities = (self.extractor)(solution);
320        let mut count = 0;
321
322        // Build temporary key index
323        let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
324        for (i, entity) in entities.iter().enumerate() {
325            let key = (self.key_extractor)(entity);
326            temp_index.entry(key).or_default().push(i);
327        }
328
329        // Count matches within each key group
330        for indices in temp_index.values() {
331            for pos_i in 0..indices.len() {
332                for pos_j in (pos_i + 1)..indices.len() {
333                    for pos_k in (pos_j + 1)..indices.len() {
334                        for pos_l in (pos_k + 1)..indices.len() {
335                            let i = indices[pos_i];
336                            let j = indices[pos_j];
337                            let k = indices[pos_k];
338                            let l = indices[pos_l];
339                            if (self.filter)(&entities[i], &entities[j], &entities[k], &entities[l])
340                            {
341                                count += 1;
342                            }
343                        }
344                    }
345                }
346            }
347        }
348
349        count
350    }
351
352    fn initialize(&mut self, solution: &S) -> Sc {
353        self.reset();
354
355        let entities = (self.extractor)(solution);
356        let mut total = Sc::zero();
357        for i in 0..entities.len() {
358            total = total + self.insert_entity(entities, i);
359        }
360        total
361    }
362
363    fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
364        let entities = (self.extractor)(solution);
365        self.insert_entity(entities, entity_index)
366    }
367
368    fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
369        let entities = (self.extractor)(solution);
370        self.retract_entity(entities, entity_index)
371    }
372
373    fn reset(&mut self) {
374        self.entity_to_matches.clear();
375        self.matches.clear();
376        self.key_to_indices.clear();
377        self.index_to_key.clear();
378    }
379
380    fn name(&self) -> &str {
381        &self.constraint_ref.name
382    }
383
384    fn is_hard(&self) -> bool {
385        self.is_hard
386    }
387
388    fn constraint_ref(&self) -> ConstraintRef {
389        self.constraint_ref.clone()
390    }
391
392    fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
393        impl_get_matches_nary!(quad: self, solution)
394    }
395}
396
397impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
398    for IncrementalQuadConstraint<S, A, K, E, KE, F, W, Sc>
399{
400    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401        f.debug_struct("IncrementalQuadConstraint")
402            .field("name", &self.constraint_ref.name)
403            .field("impact_type", &self.impact_type)
404            .field("match_count", &self.matches.len())
405            .finish()
406    }
407}