solverforge_scoring/constraint/
bi_incremental.rs

//! Incremental bi-constraint for self-join evaluation.
//!
//! Zero-erasure: all closures are concrete generic types, fully monomorphized.
//! Uses key-based indexing: a lookup visits only the k entities that share a join key
//! (O(k)) instead of iterating over all n entities (O(n)).
5
6use std::collections::{HashMap, HashSet};
7use std::fmt::Debug;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::analysis::DetailedConstraintMatch;
15use crate::api::constraint_set::IncrementalConstraint;
16
17/// Zero-erasure incremental bi-constraint for self-joins.
18///
19/// All function types are concrete generics - no trait objects, no Arc.
20/// Uses key-based indexing: entities are grouped by join key for O(k) lookups.
21pub struct IncrementalBiConstraint<S, A, K, E, KE, F, W, Sc>
22where
23    Sc: Score,
24{
25    constraint_ref: ConstraintRef,
26    impact_type: ImpactType,
27    extractor: E,
28    key_extractor: KE,
29    filter: F,
30    weight: W,
31    is_hard: bool,
32    /// entity_index -> set of (low_idx, high_idx) pairs involving this entity
33    entity_to_matches: HashMap<usize, HashSet<(usize, usize)>>,
34    /// All matched pairs (low_idx, high_idx) where low_idx < high_idx
35    matches: HashSet<(usize, usize)>,
36    /// Key -> set of entity indices with that key (for O(k) lookup)
37    key_to_indices: HashMap<K, HashSet<usize>>,
38    /// entity_index -> key (for cleanup on retract)
39    index_to_key: HashMap<usize, K>,
40    _phantom: PhantomData<(S, A, Sc)>,
41}
42
43impl<S, A, K, E, KE, F, W, Sc> IncrementalBiConstraint<S, A, K, E, KE, F, W, Sc>
44where
45    S: 'static,
46    A: Clone + 'static,
47    K: Eq + Hash + Clone,
48    E: Fn(&S) -> &[A],
49    KE: Fn(&A) -> K,
50    F: Fn(&A, &A) -> bool,
51    W: Fn(&A, &A) -> Sc,
52    Sc: Score,
53{
54    pub fn new(
55        constraint_ref: ConstraintRef,
56        impact_type: ImpactType,
57        extractor: E,
58        key_extractor: KE,
59        filter: F,
60        weight: W,
61        is_hard: bool,
62    ) -> Self {
63        Self {
64            constraint_ref,
65            impact_type,
66            extractor,
67            key_extractor,
68            filter,
69            weight,
70            is_hard,
71            entity_to_matches: HashMap::new(),
72            matches: HashSet::new(),
73            key_to_indices: HashMap::new(),
74            index_to_key: HashMap::new(),
75            _phantom: PhantomData,
76        }
77    }
78
79    #[inline]
80    fn compute_score(&self, a: &A, b: &A) -> Sc {
81        let base = (self.weight)(a, b);
82        match self.impact_type {
83            ImpactType::Penalty => -base,
84            ImpactType::Reward => base,
85        }
86    }
87
88    /// Insert entity and find matches with other entities sharing the same key.
89    fn insert_entity(&mut self, entities: &[A], index: usize) -> Sc {
90        if index >= entities.len() {
91            return Sc::zero();
92        }
93
94        let entity = &entities[index];
95        let key = (self.key_extractor)(entity);
96
97        // Track this entity's key
98        self.index_to_key.insert(index, key.clone());
99
100        // Add this entity to the key index FIRST
101        self.key_to_indices
102            .entry(key.clone())
103            .or_default()
104            .insert(index);
105
106        // Split borrows to allow simultaneous read of key_to_indices and mutation of matches
107        let key_to_indices = &self.key_to_indices;
108        let matches = &mut self.matches;
109        let entity_to_matches = &mut self.entity_to_matches;
110        let filter = &self.filter;
111        let weight = &self.weight;
112        let impact_type = self.impact_type;
113
114        // Find matches with other entities having the same key (zero allocation)
115        let mut total = Sc::zero();
116        if let Some(others) = key_to_indices.get(&key) {
117            for &other_idx in others {
118                if other_idx == index {
119                    continue;
120                }
121
122                let other = &entities[other_idx];
123
124                // Canonical ordering: (low, high) where low < high
125                let (low_idx, high_idx, low_entity, high_entity) = if index < other_idx {
126                    (index, other_idx, entity, other)
127                } else {
128                    (other_idx, index, other, entity)
129                };
130
131                if filter(low_entity, high_entity) {
132                    let pair = (low_idx, high_idx);
133                    if matches.insert(pair) {
134                        entity_to_matches.entry(low_idx).or_default().insert(pair);
135                        entity_to_matches.entry(high_idx).or_default().insert(pair);
136                        let base = weight(low_entity, high_entity);
137                        let score = match impact_type {
138                            ImpactType::Penalty => -base,
139                            ImpactType::Reward => base,
140                        };
141                        total = total + score;
142                    }
143                }
144            }
145        }
146
147        total
148    }
149
150    /// Retract entity and remove all its matches.
151    fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
152        // Remove from key index
153        if let Some(key) = self.index_to_key.remove(&index) {
154            if let Some(indices) = self.key_to_indices.get_mut(&key) {
155                indices.remove(&index);
156                if indices.is_empty() {
157                    self.key_to_indices.remove(&key);
158                }
159            }
160        }
161
162        // Remove all matches involving this entity
163        let Some(pairs) = self.entity_to_matches.remove(&index) else {
164            return Sc::zero();
165        };
166
167        let mut total = Sc::zero();
168        for pair in pairs {
169            self.matches.remove(&pair);
170
171            // Remove from other entity's match set
172            let other = if pair.0 == index { pair.1 } else { pair.0 };
173            if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
174                other_set.remove(&pair);
175                if other_set.is_empty() {
176                    self.entity_to_matches.remove(&other);
177                }
178            }
179
180            // Compute reverse delta
181            let (low_idx, high_idx) = pair;
182            if low_idx < entities.len() && high_idx < entities.len() {
183                let score = self.compute_score(&entities[low_idx], &entities[high_idx]);
184                total = total - score;
185            }
186        }
187
188        total
189    }
190}
191
192impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
193    for IncrementalBiConstraint<S, A, K, E, KE, F, W, Sc>
194where
195    S: Send + Sync + 'static,
196    A: Clone + Debug + Send + Sync + 'static,
197    K: Eq + Hash + Clone + Send + Sync,
198    E: Fn(&S) -> &[A] + Send + Sync,
199    KE: Fn(&A) -> K + Send + Sync,
200    F: Fn(&A, &A) -> bool + Send + Sync,
201    W: Fn(&A, &A) -> Sc + Send + Sync,
202    Sc: Score,
203{
204    fn evaluate(&self, solution: &S) -> Sc {
205        let entities = (self.extractor)(solution);
206        let mut total = Sc::zero();
207
208        // Build temporary key index for evaluation
209        let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
210        for (i, entity) in entities.iter().enumerate() {
211            let key = (self.key_extractor)(entity);
212            temp_index.entry(key).or_default().push(i);
213        }
214
215        // Evaluate pairs within each key group
216        for indices in temp_index.values() {
217            for i in 0..indices.len() {
218                for j in (i + 1)..indices.len() {
219                    let low = indices[i];
220                    let high = indices[j];
221                    let a = &entities[low];
222                    let b = &entities[high];
223                    if (self.filter)(a, b) {
224                        total = total + self.compute_score(a, b);
225                    }
226                }
227            }
228        }
229
230        total
231    }
232
233    fn match_count(&self, solution: &S) -> usize {
234        let entities = (self.extractor)(solution);
235        let mut count = 0;
236
237        // Build temporary key index
238        let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
239        for (i, entity) in entities.iter().enumerate() {
240            let key = (self.key_extractor)(entity);
241            temp_index.entry(key).or_default().push(i);
242        }
243
244        // Count matches within each key group
245        for indices in temp_index.values() {
246            for i in 0..indices.len() {
247                for j in (i + 1)..indices.len() {
248                    let low = indices[i];
249                    let high = indices[j];
250                    if (self.filter)(&entities[low], &entities[high]) {
251                        count += 1;
252                    }
253                }
254            }
255        }
256
257        count
258    }
259
260    fn initialize(&mut self, solution: &S) -> Sc {
261        self.reset();
262
263        let entities = (self.extractor)(solution);
264        let mut total = Sc::zero();
265        for i in 0..entities.len() {
266            total = total + self.insert_entity(entities, i);
267        }
268        total
269    }
270
271    fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
272        let entities = (self.extractor)(solution);
273        self.insert_entity(entities, entity_index)
274    }
275
276    fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
277        let entities = (self.extractor)(solution);
278        self.retract_entity(entities, entity_index)
279    }
280
281    fn reset(&mut self) {
282        self.entity_to_matches.clear();
283        self.matches.clear();
284        self.key_to_indices.clear();
285        self.index_to_key.clear();
286    }
287
288    fn name(&self) -> &str {
289        &self.constraint_ref.name
290    }
291
292    fn is_hard(&self) -> bool {
293        self.is_hard
294    }
295
296    fn constraint_ref(&self) -> ConstraintRef {
297        self.constraint_ref.clone()
298    }
299
300    fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
301        impl_get_matches_nary!(bi: self, solution)
302    }
303}
304
305impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
306    for IncrementalBiConstraint<S, A, K, E, KE, F, W, Sc>
307{
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        f.debug_struct("IncrementalBiConstraint")
310            .field("name", &self.constraint_ref.name)
311            .field("impact_type", &self.impact_type)
312            .field("match_count", &self.matches.len())
313            .finish()
314    }
315}