Skip to main content

solverforge_scoring/constraint/
exists.rs

1use std::hash::Hash;
2use std::marker::PhantomData;
3use std::slice;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collection_extract::{ChangeSource, TrackedCollectionExtract};
10use crate::stream::filter::UniFilter;
11use crate::stream::{ExistenceMode, FlattenExtract};
12
13mod key_state;
14
15use key_state::ExistsKeyState;
16#[cfg(test)]
17pub(crate) use key_state::ExistsStorageKind;
18
19#[derive(Debug, Clone)]
20struct ASlot<K, Sc>
21where
22    Sc: Score,
23{
24    key: Option<K>,
25    bucket_pos: usize,
26    contribution: Sc,
27}
28
29impl<K, Sc> Default for ASlot<K, Sc>
30where
31    Sc: Score,
32{
33    fn default() -> Self {
34        Self {
35            key: None,
36            bucket_pos: 0,
37            contribution: Sc::zero(),
38        }
39    }
40}
41
42pub struct IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
43where
44    Sc: Score,
45{
46    constraint_ref: ConstraintRef,
47    impact_type: ImpactType,
48    mode: ExistenceMode,
49    extractor_a: EA,
50    extractor_parent: EP,
51    key_a: KA,
52    key_b: KB,
53    filter_a: FA,
54    filter_parent: FP,
55    flatten: Flatten,
56    weight: W,
57    is_hard: bool,
58    a_source: ChangeSource,
59    parent_source: ChangeSource,
60    a_slots: Vec<ASlot<K, Sc>>,
61    key_state: ExistsKeyState<K>,
62    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> P, fn() -> B)>,
63}
64
65impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
66    IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
67where
68    S: 'static,
69    A: Clone + 'static,
70    P: Clone + 'static,
71    B: Clone + 'static,
72    K: Eq + Hash + Clone + 'static,
73    EA: TrackedCollectionExtract<S, Item = A>,
74    EP: TrackedCollectionExtract<S, Item = P>,
75    KA: Fn(&A) -> K,
76    KB: Fn(&B) -> K,
77    FA: UniFilter<S, A>,
78    FP: UniFilter<S, P>,
79    Flatten: FlattenExtract<P, Item = B>,
80    W: Fn(&A) -> Sc,
81    Sc: Score,
82{
83    #[allow(clippy::too_many_arguments)]
84    pub fn new(
85        constraint_ref: ConstraintRef,
86        impact_type: ImpactType,
87        mode: ExistenceMode,
88        extractor_a: EA,
89        extractor_parent: EP,
90        key_a: KA,
91        key_b: KB,
92        filter_a: FA,
93        filter_parent: FP,
94        flatten: Flatten,
95        weight: W,
96        is_hard: bool,
97    ) -> Self {
98        let a_source = extractor_a.change_source();
99        let parent_source = extractor_parent.change_source();
100        Self {
101            constraint_ref,
102            impact_type,
103            mode,
104            extractor_a,
105            extractor_parent,
106            key_a,
107            key_b,
108            filter_a,
109            filter_parent,
110            flatten,
111            weight,
112            is_hard,
113            a_source,
114            parent_source,
115            a_slots: Vec::new(),
116            key_state: ExistsKeyState::new(),
117            _phantom: PhantomData,
118        }
119    }
120
121    #[cfg(test)]
122    pub(crate) fn storage_kind(&self) -> ExistsStorageKind {
123        self.key_state.storage_kind()
124    }
125
126    #[inline]
127    fn compute_score(&self, a: &A) -> Sc {
128        let base = (self.weight)(a);
129        match self.impact_type {
130            ImpactType::Penalty => -base,
131            ImpactType::Reward => base,
132        }
133    }
134
135    #[inline]
136    fn matches_existence(&self, key: &K) -> bool {
137        self.matches_count(self.key_state.b_count(key))
138    }
139
140    #[inline]
141    fn matches_count(&self, count: usize) -> bool {
142        match self.mode {
143            ExistenceMode::Exists => count > 0,
144            ExistenceMode::NotExists => count == 0,
145        }
146    }
147
148    fn rebuild_b_counts(&mut self, solution: &S) {
149        self.key_state.clear_b_counts();
150        for parent in self.extractor_parent.extract(solution) {
151            if !self.filter_parent.test(solution, parent) {
152                continue;
153            }
154            for item in self.flatten.extract(parent) {
155                let key = (self.key_b)(item);
156                self.key_state.increment_b_count(&key, 1);
157            }
158        }
159    }
160
161    fn remove_a_from_bucket(&mut self, idx: usize, key: &K, bucket_pos: usize) {
162        if let Some(moved) = self.key_state.remove_a_index(key, idx, bucket_pos) {
163            self.a_slots[moved.idx].bucket_pos = moved.bucket_pos;
164        }
165    }
166
167    fn retract_a(&mut self, idx: usize) -> Sc {
168        if idx >= self.a_slots.len() {
169            return Sc::zero();
170        }
171        let slot = self.a_slots[idx].clone();
172        let Some(key) = slot.key.clone() else {
173            return Sc::zero();
174        };
175        self.remove_a_from_bucket(idx, &key, slot.bucket_pos);
176        self.a_slots[idx] = ASlot::default();
177        -slot.contribution
178    }
179
180    fn insert_a(&mut self, solution: &S, idx: usize) -> Sc {
181        let entities_a = self.extractor_a.extract(solution);
182        if idx >= entities_a.len() {
183            return Sc::zero();
184        }
185        if self.a_slots.len() < entities_a.len() {
186            self.a_slots.resize(entities_a.len(), ASlot::default());
187        }
188
189        let a = &entities_a[idx];
190        if !self.filter_a.test(solution, a) {
191            self.a_slots[idx] = ASlot::default();
192            return Sc::zero();
193        }
194
195        let key = (self.key_a)(a);
196        let bucket_pos = self.key_state.insert_a_index(key.clone(), idx);
197
198        let contribution = if self.matches_existence(&key) {
199            self.compute_score(a)
200        } else {
201            Sc::zero()
202        };
203
204        self.a_slots[idx] = ASlot {
205            key: Some(key),
206            bucket_pos,
207            contribution,
208        };
209        contribution
210    }
211
212    fn reevaluate_key(&mut self, solution: &S, key: &K) -> Sc {
213        let indices = self.key_state.a_indices(key);
214        if indices.is_empty() {
215            return Sc::zero();
216        };
217        let entities_a = self.extractor_a.extract(solution);
218        let mut total = Sc::zero();
219        let exists = self.matches_existence(key);
220
221        for idx in indices {
222            let a = &entities_a[idx];
223            let new_contribution = if exists {
224                self.compute_score(a)
225            } else {
226                Sc::zero()
227            };
228            let old_contribution = self.a_slots[idx].contribution;
229            self.a_slots[idx].contribution = new_contribution;
230            total = total + (new_contribution - old_contribution);
231        }
232
233        total
234    }
235
236    fn update_key_counts(&mut self, solution: &S, key_counts: &[(K, usize)], insert: bool) -> Sc {
237        let mut total = Sc::zero();
238
239        for (key, count) in key_counts {
240            if insert {
241                self.key_state.increment_b_count(key, *count);
242            } else {
243                self.key_state.decrement_b_count(key, *count);
244            }
245        }
246
247        for (key, _) in key_counts {
248            total = total + self.reevaluate_key(solution, key);
249        }
250
251        total
252    }
253
254    fn parent_key_counts(&self, solution: &S, idx: usize) -> Vec<(K, usize)> {
255        let parents = self.extractor_parent.extract(solution);
256        if idx >= parents.len() {
257            return Vec::new();
258        }
259        let parent = &parents[idx];
260        if !self.filter_parent.test(solution, parent) {
261            return Vec::new();
262        }
263
264        let mut key_counts = Vec::<(K, usize)>::new();
265        for item in self.flatten.extract(parent) {
266            let key = (self.key_b)(item);
267            if let Some((_, count)) = key_counts
268                .iter_mut()
269                .find(|(existing_key, _)| existing_key == &key)
270            {
271                *count += 1;
272            } else {
273                key_counts.push((key, 1));
274            }
275        }
276        key_counts
277    }
278
279    fn initialize_a_state(&mut self, solution: &S) -> Sc {
280        self.a_slots.clear();
281        self.key_state.clear_a_buckets();
282
283        let len = self.extractor_a.extract(solution).len();
284        self.a_slots.resize(len, ASlot::default());
285
286        let mut total = Sc::zero();
287        for idx in 0..len {
288            total = total + self.insert_a(solution, idx);
289        }
290        total
291    }
292
293    fn build_b_counts(&self, solution: &S) -> ExistsKeyState<K> {
294        let mut key_state = ExistsKeyState::new();
295        for parent in self.extractor_parent.extract(solution) {
296            if !self.filter_parent.test(solution, parent) {
297                continue;
298            }
299            for item in self.flatten.extract(parent) {
300                let key = (self.key_b)(item);
301                key_state.increment_b_count(&key, 1);
302            }
303        }
304        key_state
305    }
306
307    fn full_match_count(&self, solution: &S) -> usize {
308        let key_state = self.build_b_counts(solution);
309
310        self.extractor_a
311            .extract(solution)
312            .iter()
313            .filter(|a| {
314                if !self.filter_a.test(solution, a) {
315                    return false;
316                }
317                let key = (self.key_a)(a);
318                self.matches_count(key_state.b_count(&key))
319            })
320            .count()
321    }
322}
323
324impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc> IncrementalConstraint<S, Sc>
325    for IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
326where
327    S: Send + Sync + 'static,
328    A: Clone + Send + Sync + 'static,
329    P: Clone + Send + Sync + 'static,
330    B: Clone + Send + Sync + 'static,
331    K: Eq + Hash + Clone + Send + Sync + 'static,
332    EA: TrackedCollectionExtract<S, Item = A> + Send + Sync,
333    EP: TrackedCollectionExtract<S, Item = P> + Send + Sync,
334    KA: Fn(&A) -> K + Send + Sync,
335    KB: Fn(&B) -> K + Send + Sync,
336    FA: UniFilter<S, A> + Send + Sync,
337    FP: UniFilter<S, P> + Send + Sync,
338    Flatten: FlattenExtract<P, Item = B> + Send + Sync,
339    W: Fn(&A) -> Sc + Send + Sync,
340    Sc: Score,
341{
342    fn evaluate(&self, solution: &S) -> Sc {
343        let key_state = self.build_b_counts(solution);
344
345        let mut total = Sc::zero();
346        for a in self.extractor_a.extract(solution) {
347            if !self.filter_a.test(solution, a) {
348                continue;
349            }
350            let key = (self.key_a)(a);
351            if self.matches_count(key_state.b_count(&key)) {
352                total = total + self.compute_score(a);
353            }
354        }
355        total
356    }
357
358    fn match_count(&self, solution: &S) -> usize {
359        self.full_match_count(solution)
360    }
361
362    fn initialize(&mut self, solution: &S) -> Sc {
363        self.reset();
364        self.rebuild_b_counts(solution);
365        self.initialize_a_state(solution)
366    }
367
368    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
369        let a_changed =
370            matches!(self.a_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
371        let parent_changed =
372            matches!(self.parent_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
373        let same_source = self.a_source == self.parent_source && a_changed && parent_changed;
374
375        let mut total = Sc::zero();
376        if same_source {
377            let keys = self.parent_key_counts(solution, entity_index);
378            total = total + self.update_key_counts(solution, &keys, true);
379            total = total + self.insert_a(solution, entity_index);
380            return total;
381        }
382
383        if parent_changed {
384            let keys = self.parent_key_counts(solution, entity_index);
385            total = total + self.update_key_counts(solution, &keys, true);
386        }
387        if a_changed {
388            total = total + self.insert_a(solution, entity_index);
389        }
390        total
391    }
392
393    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
394        let a_changed =
395            matches!(self.a_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
396        let parent_changed =
397            matches!(self.parent_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
398        let same_source = self.a_source == self.parent_source && a_changed && parent_changed;
399
400        let mut total = Sc::zero();
401        if same_source {
402            let keys = self.parent_key_counts(solution, entity_index);
403            total = total + self.retract_a(entity_index);
404            total = total + self.update_key_counts(solution, &keys, false);
405            return total;
406        }
407
408        if a_changed {
409            total = total + self.retract_a(entity_index);
410        }
411        if parent_changed {
412            let keys = self.parent_key_counts(solution, entity_index);
413            total = total + self.update_key_counts(solution, &keys, false);
414        }
415        total
416    }
417
418    fn reset(&mut self) {
419        self.a_slots.clear();
420        self.key_state.clear_a_buckets();
421        self.key_state.clear_b_counts();
422    }
423
424    fn name(&self) -> &str {
425        &self.constraint_ref.name
426    }
427
428    fn is_hard(&self) -> bool {
429        self.is_hard
430    }
431
432    fn constraint_ref(&self) -> ConstraintRef {
433        self.constraint_ref.clone()
434    }
435}
436
/// Flattening strategy that treats a parent as its own single item.
#[derive(Clone, Copy, Debug, Default)]
pub struct SelfFlatten;
439
440impl<T> FlattenExtract<T> for SelfFlatten
441where
442    T: Send + Sync,
443{
444    type Item = T;
445
446    fn extract<'a>(&self, parent: &'a T) -> &'a [T] {
447        slice::from_ref(parent)
448    }
449}