// solverforge_scoring/constraint/tri_incremental.rs

1//! Zero-erasure incremental tri-constraint for self-join triple evaluation.
2//!
3//! All function types are concrete generics - no trait objects, no Arc.
//! Uses key-based indexing: entities are grouped by join key, so inserting an entity only
//! examines pairs of entities sharing its key (O(k²) for a key group of size k).
5
6use std::collections::{HashMap, HashSet};
7use std::fmt::Debug;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::analysis::DetailedConstraintMatch;
15use crate::api::constraint_set::IncrementalConstraint;
16
17/// Zero-erasure incremental tri-constraint for self-joins.
18///
19/// All function types are concrete generics - no Arc, no dyn, fully monomorphized.
20/// Uses key-based indexing: entities are grouped by join key for O(k) lookups
21/// where k is the number of entities sharing the same key.
22///
23/// Triples are ordered as (i, j, k) where i < j < k to avoid duplicates.
24///
25/// # Type Parameters
26///
27/// - `S` - Solution type
28/// - `A` - Entity type
29/// - `K` - Key type for grouping (entities with same key form triples)
30/// - `E` - Extractor function `Fn(&S) -> &[A]`
31/// - `KE` - Key extractor function `Fn(&A) -> K`
32/// - `F` - Filter function `Fn(&A, &A, &A) -> bool`
33/// - `W` - Weight function `Fn(&A, &A, &A) -> Sc`
34/// - `Sc` - Score type
35///
36/// # Example
37///
38/// ```
39/// use solverforge_scoring::constraint::tri_incremental::IncrementalTriConstraint;
40/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
41/// use solverforge_core::{ConstraintRef, ImpactType};
42/// use solverforge_core::score::SimpleScore;
43///
44/// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
45/// struct Task { team: u32 }
46///
47/// #[derive(Clone)]
48/// struct Solution { tasks: Vec<Task> }
49///
50/// // Penalize when three tasks are on the same team
51/// let constraint = IncrementalTriConstraint::new(
52///     ConstraintRef::new("", "Team clustering"),
53///     ImpactType::Penalty,
54///     |s: &Solution| s.tasks.as_slice(),
55///     |t: &Task| t.team,  // Group by team
56///     |_a: &Task, _b: &Task, _c: &Task| true,  // All triples in same group match
57///     |_a: &Task, _b: &Task, _c: &Task| SimpleScore::of(1),
58///     false,
59/// );
60///
61/// let solution = Solution {
62///     tasks: vec![
63///         Task { team: 1 },
64///         Task { team: 1 },
65///         Task { team: 1 },
66///         Task { team: 2 },
67///     ],
68/// };
69///
70/// // One triple: (0, 1, 2) all on team 1 = -1 penalty
71/// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
72/// ```
73pub struct IncrementalTriConstraint<S, A, K, E, KE, F, W, Sc>
74where
75    Sc: Score,
76{
77    constraint_ref: ConstraintRef,
78    impact_type: ImpactType,
79    extractor: E,
80    key_extractor: KE,
81    filter: F,
82    weight: W,
83    is_hard: bool,
84    /// entity_index -> set of (i, j, k) triples involving this entity
85    entity_to_matches: HashMap<usize, HashSet<(usize, usize, usize)>>,
86    /// All matched triples (i, j, k) where i < j < k
87    matches: HashSet<(usize, usize, usize)>,
88    /// Key -> set of entity indices with that key (for O(k) lookup)
89    key_to_indices: HashMap<K, HashSet<usize>>,
90    /// entity_index -> key (for cleanup on retract)
91    index_to_key: HashMap<usize, K>,
92    _phantom: PhantomData<(S, A, Sc)>,
93}
94
95impl<S, A, K, E, KE, F, W, Sc> IncrementalTriConstraint<S, A, K, E, KE, F, W, Sc>
96where
97    S: 'static,
98    A: Clone + 'static,
99    K: Eq + Hash + Clone,
100    E: Fn(&S) -> &[A],
101    KE: Fn(&A) -> K,
102    F: Fn(&A, &A, &A) -> bool,
103    W: Fn(&A, &A, &A) -> Sc,
104    Sc: Score,
105{
106    /// Creates a new zero-erasure incremental tri-constraint.
107    pub fn new(
108        constraint_ref: ConstraintRef,
109        impact_type: ImpactType,
110        extractor: E,
111        key_extractor: KE,
112        filter: F,
113        weight: W,
114        is_hard: bool,
115    ) -> Self {
116        Self {
117            constraint_ref,
118            impact_type,
119            extractor,
120            key_extractor,
121            filter,
122            weight,
123            is_hard,
124            entity_to_matches: HashMap::new(),
125            matches: HashSet::new(),
126            key_to_indices: HashMap::new(),
127            index_to_key: HashMap::new(),
128            _phantom: PhantomData,
129        }
130    }
131
132    #[inline]
133    fn compute_score(&self, a: &A, b: &A, c: &A) -> Sc {
134        let base = (self.weight)(a, b, c);
135        match self.impact_type {
136            ImpactType::Penalty => -base,
137            ImpactType::Reward => base,
138        }
139    }
140
141    /// Insert entity and find matches with other entity pairs sharing the same key.
142    fn insert_entity(&mut self, entities: &[A], index: usize) -> Sc {
143        if index >= entities.len() {
144            return Sc::zero();
145        }
146
147        let entity = &entities[index];
148        let key = (self.key_extractor)(entity);
149
150        // Track this entity's key
151        self.index_to_key.insert(index, key.clone());
152
153        // Add this entity to the key index FIRST
154        self.key_to_indices
155            .entry(key.clone())
156            .or_default()
157            .insert(index);
158
159        // Split borrows to allow simultaneous read of key_to_indices and mutation of matches
160        let key_to_indices = &self.key_to_indices;
161        let matches = &mut self.matches;
162        let entity_to_matches = &mut self.entity_to_matches;
163        let filter = &self.filter;
164        let weight = &self.weight;
165        let impact_type = self.impact_type;
166
167        // Find matches with all pairs of other entities having the same key (zero allocation)
168        let mut total = Sc::zero();
169        if let Some(others) = key_to_indices.get(&key) {
170            // Iterate over all pairs (i, j) where i < j, excluding current index
171            for &i in others {
172                if i == index {
173                    continue;
174                }
175                for &j in others {
176                    // Ensure i < j and j is not current index
177                    if j <= i || j == index {
178                        continue;
179                    }
180
181                    // Determine canonical ordering for this triple
182                    let mut indices = [index, i, j];
183                    indices.sort();
184                    let [a_idx, b_idx, c_idx] = indices;
185
186                    let triple = (a_idx, b_idx, c_idx);
187
188                    // Skip if already matched
189                    if matches.contains(&triple) {
190                        continue;
191                    }
192
193                    let a = &entities[a_idx];
194                    let b = &entities[b_idx];
195                    let c = &entities[c_idx];
196
197                    if filter(a, b, c) && matches.insert(triple) {
198                        entity_to_matches.entry(a_idx).or_default().insert(triple);
199                        entity_to_matches.entry(b_idx).or_default().insert(triple);
200                        entity_to_matches.entry(c_idx).or_default().insert(triple);
201                        let base = weight(a, b, c);
202                        let score = match impact_type {
203                            ImpactType::Penalty => -base,
204                            ImpactType::Reward => base,
205                        };
206                        total = total + score;
207                    }
208                }
209            }
210        }
211
212        total
213    }
214
215    /// Retract entity and remove all its matches.
216    fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
217        // Remove from key index
218        if let Some(key) = self.index_to_key.remove(&index) {
219            if let Some(indices) = self.key_to_indices.get_mut(&key) {
220                indices.remove(&index);
221                if indices.is_empty() {
222                    self.key_to_indices.remove(&key);
223                }
224            }
225        }
226
227        // Remove all matches involving this entity
228        let Some(triples) = self.entity_to_matches.remove(&index) else {
229            return Sc::zero();
230        };
231
232        let mut total = Sc::zero();
233        for triple in triples {
234            self.matches.remove(&triple);
235
236            // Remove from other entities' match sets
237            let (i, j, k) = triple;
238            for &other in &[i, j, k] {
239                if other != index {
240                    if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
241                        other_set.remove(&triple);
242                        if other_set.is_empty() {
243                            self.entity_to_matches.remove(&other);
244                        }
245                    }
246                }
247            }
248
249            // Compute reverse delta
250            if i < entities.len() && j < entities.len() && k < entities.len() {
251                let score = self.compute_score(&entities[i], &entities[j], &entities[k]);
252                total = total - score;
253            }
254        }
255
256        total
257    }
258}
259
260impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
261    for IncrementalTriConstraint<S, A, K, E, KE, F, W, Sc>
262where
263    S: Send + Sync + 'static,
264    A: Clone + Debug + Send + Sync + 'static,
265    K: Eq + Hash + Clone + Send + Sync,
266    E: Fn(&S) -> &[A] + Send + Sync,
267    KE: Fn(&A) -> K + Send + Sync,
268    F: Fn(&A, &A, &A) -> bool + Send + Sync,
269    W: Fn(&A, &A, &A) -> Sc + Send + Sync,
270    Sc: Score,
271{
272    fn evaluate(&self, solution: &S) -> Sc {
273        let entities = (self.extractor)(solution);
274        let mut total = Sc::zero();
275
276        // Build temporary key index for evaluation
277        let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
278        for (i, entity) in entities.iter().enumerate() {
279            let key = (self.key_extractor)(entity);
280            temp_index.entry(key).or_default().push(i);
281        }
282
283        // Evaluate triples within each key group
284        for indices in temp_index.values() {
285            for pos_i in 0..indices.len() {
286                for pos_j in (pos_i + 1)..indices.len() {
287                    for pos_k in (pos_j + 1)..indices.len() {
288                        let i = indices[pos_i];
289                        let j = indices[pos_j];
290                        let k = indices[pos_k];
291                        let a = &entities[i];
292                        let b = &entities[j];
293                        let c = &entities[k];
294                        if (self.filter)(a, b, c) {
295                            total = total + self.compute_score(a, b, c);
296                        }
297                    }
298                }
299            }
300        }
301
302        total
303    }
304
305    fn match_count(&self, solution: &S) -> usize {
306        let entities = (self.extractor)(solution);
307        let mut count = 0;
308
309        // Build temporary key index
310        let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
311        for (i, entity) in entities.iter().enumerate() {
312            let key = (self.key_extractor)(entity);
313            temp_index.entry(key).or_default().push(i);
314        }
315
316        // Count matches within each key group
317        for indices in temp_index.values() {
318            for pos_i in 0..indices.len() {
319                for pos_j in (pos_i + 1)..indices.len() {
320                    for pos_k in (pos_j + 1)..indices.len() {
321                        let i = indices[pos_i];
322                        let j = indices[pos_j];
323                        let k = indices[pos_k];
324                        if (self.filter)(&entities[i], &entities[j], &entities[k]) {
325                            count += 1;
326                        }
327                    }
328                }
329            }
330        }
331
332        count
333    }
334
335    fn initialize(&mut self, solution: &S) -> Sc {
336        self.reset();
337
338        let entities = (self.extractor)(solution);
339        let mut total = Sc::zero();
340        for i in 0..entities.len() {
341            total = total + self.insert_entity(entities, i);
342        }
343        total
344    }
345
346    fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
347        let entities = (self.extractor)(solution);
348        self.insert_entity(entities, entity_index)
349    }
350
351    fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
352        let entities = (self.extractor)(solution);
353        self.retract_entity(entities, entity_index)
354    }
355
356    fn reset(&mut self) {
357        self.entity_to_matches.clear();
358        self.matches.clear();
359        self.key_to_indices.clear();
360        self.index_to_key.clear();
361    }
362
363    fn name(&self) -> &str {
364        &self.constraint_ref.name
365    }
366
367    fn is_hard(&self) -> bool {
368        self.is_hard
369    }
370
371    fn constraint_ref(&self) -> ConstraintRef {
372        self.constraint_ref.clone()
373    }
374
375    fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
376        impl_get_matches_nary!(tri: self, solution)
377    }
378}
379
380impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
381    for IncrementalTriConstraint<S, A, K, E, KE, F, W, Sc>
382{
383    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384        f.debug_struct("IncrementalTriConstraint")
385            .field("name", &self.constraint_ref.name)
386            .field("impact_type", &self.impact_type)
387            .field("match_count", &self.matches.len())
388            .finish()
389    }
390}