Skip to main content

solverforge_scoring/constraint/
exists.rs

1use std::hash::Hash;
2use std::marker::PhantomData;
3use std::slice;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collection_extract::{ChangeSource, CollectionExtract};
10use crate::stream::filter::UniFilter;
11use crate::stream::{ExistenceMode, FlattenExtract};
12
13mod key_state;
14
15use key_state::ExistsKeyState;
16#[cfg(test)]
17pub(crate) use key_state::ExistsStorageKind;
18
19#[derive(Debug, Clone)]
20struct ASlot<K, Sc>
21where
22    Sc: Score,
23{
24    key: Option<K>,
25    bucket_pos: usize,
26    score: Sc,
27}
28
29impl<K, Sc> Default for ASlot<K, Sc>
30where
31    Sc: Score,
32{
33    fn default() -> Self {
34        Self {
35            key: None,
36            bucket_pos: 0,
37            score: Sc::zero(),
38        }
39    }
40}
41
/// Incremental constraint that scores each A entity by whether B items sharing its join key
/// exist.
///
/// B items are obtained by flattening every parent entity (from `extractor_parent`, kept only
/// if it passes `filter_parent`) through `flatten`. An A entity that passes `filter_a`
/// contributes `weight(a)` — negated for [`ImpactType::Penalty`] — when the number of B items
/// with its key satisfies `mode`: at least one for [`ExistenceMode::Exists`], none for
/// [`ExistenceMode::NotExists`].
pub struct IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
where
    Sc: Score,
{
    /// Identity of the constraint; its `name` is also passed to change-localization checks.
    constraint_ref: ConstraintRef,
    /// Whether matched weights are subtracted (penalty) or added (reward).
    impact_type: ImpactType,
    /// Whether an A entity matches when B items with its key exist, or when none exist.
    mode: ExistenceMode,
    /// Extracts the A entities from the solution.
    extractor_a: EA,
    /// Extracts the parent entities whose flattened items form the B side.
    extractor_parent: EP,
    /// Join key of an A entity.
    key_a: KA,
    /// Join key of a B item.
    key_b: KB,
    /// A entities rejected by this filter contribute nothing.
    filter_a: FA,
    /// Parents rejected by this filter contribute no B items.
    filter_parent: FP,
    /// Maps a parent to the slice of B items it contains.
    flatten: Flatten,
    /// Weight of a matched A entity before the impact type's sign is applied.
    weight: W,
    /// Value reported by `IncrementalConstraint::is_hard`.
    is_hard: bool,
    /// Change source of `extractor_a`; decides whether a descriptor change affects the A side.
    a_source: ChangeSource,
    /// Change source of `extractor_parent`; decides whether a descriptor change affects B.
    parent_source: ChangeSource,
    /// Per-A-entity state, indexed by the entity's position in the A collection.
    a_slots: Vec<ASlot<K, Sc>>,
    /// Per-key B counts plus A buckets and A score totals.
    key_state: ExistsKeyState<K, Sc>,
    // Marks the otherwise-unused type parameters as used; `fn() -> T` avoids implying
    // ownership of those types.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> P, fn() -> B)>,
}
64
65impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
66    IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
67where
68    S: 'static,
69    A: Clone + 'static,
70    P: Clone + 'static,
71    B: Clone + 'static,
72    K: Eq + Hash + Clone + 'static,
73    EA: CollectionExtract<S, Item = A>,
74    EP: CollectionExtract<S, Item = P>,
75    KA: Fn(&A) -> K,
76    KB: Fn(&B) -> K,
77    FA: UniFilter<S, A>,
78    FP: UniFilter<S, P>,
79    Flatten: FlattenExtract<P, Item = B>,
80    W: Fn(&A) -> Sc,
81    Sc: Score,
82{
83    #[allow(clippy::too_many_arguments)]
84    pub fn new(
85        constraint_ref: ConstraintRef,
86        impact_type: ImpactType,
87        mode: ExistenceMode,
88        extractor_a: EA,
89        extractor_parent: EP,
90        key_a: KA,
91        key_b: KB,
92        filter_a: FA,
93        filter_parent: FP,
94        flatten: Flatten,
95        weight: W,
96        is_hard: bool,
97    ) -> Self {
98        let a_source = extractor_a.change_source();
99        let parent_source = extractor_parent.change_source();
100        Self {
101            constraint_ref,
102            impact_type,
103            mode,
104            extractor_a,
105            extractor_parent,
106            key_a,
107            key_b,
108            filter_a,
109            filter_parent,
110            flatten,
111            weight,
112            is_hard,
113            a_source,
114            parent_source,
115            a_slots: Vec::new(),
116            key_state: ExistsKeyState::new(),
117            _phantom: PhantomData,
118        }
119    }
120
121    #[cfg(test)]
122    pub(crate) fn storage_kind(&self) -> ExistsStorageKind {
123        self.key_state.storage_kind()
124    }
125
126    #[inline]
127    fn compute_score(&self, a: &A) -> Sc {
128        let base = (self.weight)(a);
129        match self.impact_type {
130            ImpactType::Penalty => -base,
131            ImpactType::Reward => base,
132        }
133    }
134
135    #[inline]
136    fn matches_existence(&self, key: &K) -> bool {
137        self.matches_count(self.key_state.b_count(key))
138    }
139
140    #[inline]
141    fn matches_count(&self, count: usize) -> bool {
142        match self.mode {
143            ExistenceMode::Exists => count > 0,
144            ExistenceMode::NotExists => count == 0,
145        }
146    }
147
148    fn rebuild_b_counts(&mut self, solution: &S) {
149        self.key_state.clear_b_counts();
150        for parent in self.extractor_parent.extract(solution) {
151            if !self.filter_parent.test(solution, parent) {
152                continue;
153            }
154            for item in self.flatten.extract(parent) {
155                let key = (self.key_b)(item);
156                self.key_state.increment_b_count(&key, 1);
157            }
158        }
159    }
160
161    fn remove_a_from_bucket(&mut self, idx: usize, key: &K, bucket_pos: usize) {
162        if let Some(moved) = self.key_state.remove_a_index(key, idx, bucket_pos) {
163            self.a_slots[moved.idx].bucket_pos = moved.bucket_pos;
164        }
165    }
166
167    fn retract_a(&mut self, idx: usize) -> Sc {
168        if idx >= self.a_slots.len() {
169            return Sc::zero();
170        }
171        let bucket_pos = self.a_slots[idx].bucket_pos;
172        let score = self.a_slots[idx].score;
173        let Some(key) = self.a_slots[idx].key.take() else {
174            return Sc::zero();
175        };
176        let contribution = if self.matches_existence(&key) {
177            score
178        } else {
179            Sc::zero()
180        };
181        self.remove_a_from_bucket(idx, &key, bucket_pos);
182        self.key_state.subtract_a_score(&key, score);
183        self.a_slots[idx] = ASlot::default();
184        -contribution
185    }
186
187    fn insert_a(&mut self, solution: &S, idx: usize) -> Sc {
188        let entities_a = self.extractor_a.extract(solution);
189        if idx >= entities_a.len() {
190            return Sc::zero();
191        }
192        if self.a_slots.len() < entities_a.len() {
193            self.a_slots.resize(entities_a.len(), ASlot::default());
194        }
195
196        let a = &entities_a[idx];
197        if !self.filter_a.test(solution, a) {
198            self.a_slots[idx] = ASlot::default();
199            return Sc::zero();
200        }
201
202        let key = (self.key_a)(a);
203        let bucket_pos = self.key_state.insert_a_index(key.clone(), idx);
204        let score = self.compute_score(a);
205        self.key_state.add_a_score(&key, score);
206
207        let contribution = if self.matches_existence(&key) {
208            score
209        } else {
210            Sc::zero()
211        };
212
213        self.a_slots[idx] = ASlot {
214            key: Some(key),
215            bucket_pos,
216            score,
217        };
218        contribution
219    }
220
221    fn key_existence_delta(&self, key: &K, old_count: usize, new_count: usize) -> Sc {
222        let old_matches = self.matches_count(old_count);
223        let new_matches = self.matches_count(new_count);
224        if old_matches == new_matches {
225            Sc::zero()
226        } else if new_matches {
227            self.key_state.a_score_total(key)
228        } else {
229            -self.key_state.a_score_total(key)
230        }
231    }
232
233    fn update_key_counts(&mut self, key_counts: &[(K, usize)], insert: bool) -> Sc {
234        let mut total = Sc::zero();
235
236        for (key, count) in key_counts {
237            let old_count = self.key_state.b_count(key);
238            if insert {
239                self.key_state.increment_b_count(key, *count);
240            } else {
241                self.key_state.decrement_b_count(key, *count);
242            }
243            total = total + self.key_existence_delta(key, old_count, self.key_state.b_count(key));
244        }
245
246        total
247    }
248
249    fn parent_key_counts(&self, solution: &S, idx: usize) -> Vec<(K, usize)> {
250        let parents = self.extractor_parent.extract(solution);
251        if idx >= parents.len() {
252            return Vec::new();
253        }
254        let parent = &parents[idx];
255        if !self.filter_parent.test(solution, parent) {
256            return Vec::new();
257        }
258
259        let mut key_counts = Vec::<(K, usize)>::new();
260        for item in self.flatten.extract(parent) {
261            let key = (self.key_b)(item);
262            if let Some((_, count)) = key_counts
263                .iter_mut()
264                .find(|(existing_key, _)| existing_key == &key)
265            {
266                *count += 1;
267            } else {
268                key_counts.push((key, 1));
269            }
270        }
271        key_counts
272    }
273
274    fn initialize_a_state(&mut self, solution: &S) -> Sc {
275        self.a_slots.clear();
276        self.key_state.clear_a_buckets();
277
278        let len = self.extractor_a.extract(solution).len();
279        self.a_slots.resize(len, ASlot::default());
280
281        let mut total = Sc::zero();
282        for idx in 0..len {
283            total = total + self.insert_a(solution, idx);
284        }
285        total
286    }
287
288    fn build_b_counts(&self, solution: &S) -> ExistsKeyState<K, Sc> {
289        let mut key_state = ExistsKeyState::new();
290        for parent in self.extractor_parent.extract(solution) {
291            if !self.filter_parent.test(solution, parent) {
292                continue;
293            }
294            for item in self.flatten.extract(parent) {
295                let key = (self.key_b)(item);
296                key_state.increment_b_count(&key, 1);
297            }
298        }
299        key_state
300    }
301
302    fn full_match_count(&self, solution: &S) -> usize {
303        let key_state = self.build_b_counts(solution);
304
305        self.extractor_a
306            .extract(solution)
307            .iter()
308            .filter(|a| {
309                if !self.filter_a.test(solution, a) {
310                    return false;
311                }
312                let key = (self.key_a)(a);
313                self.matches_count(key_state.b_count(&key))
314            })
315            .count()
316    }
317}
318
319impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc> IncrementalConstraint<S, Sc>
320    for IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
321where
322    S: Send + Sync + 'static,
323    A: Clone + Send + Sync + 'static,
324    P: Clone + Send + Sync + 'static,
325    B: Clone + Send + Sync + 'static,
326    K: Eq + Hash + Clone + Send + Sync + 'static,
327    EA: CollectionExtract<S, Item = A> + Send + Sync,
328    EP: CollectionExtract<S, Item = P> + Send + Sync,
329    KA: Fn(&A) -> K + Send + Sync,
330    KB: Fn(&B) -> K + Send + Sync,
331    FA: UniFilter<S, A> + Send + Sync,
332    FP: UniFilter<S, P> + Send + Sync,
333    Flatten: FlattenExtract<P, Item = B> + Send + Sync,
334    W: Fn(&A) -> Sc + Send + Sync,
335    Sc: Score,
336{
337    fn evaluate(&self, solution: &S) -> Sc {
338        let key_state = self.build_b_counts(solution);
339
340        let mut total = Sc::zero();
341        for a in self.extractor_a.extract(solution) {
342            if !self.filter_a.test(solution, a) {
343                continue;
344            }
345            let key = (self.key_a)(a);
346            if self.matches_count(key_state.b_count(&key)) {
347                total = total + self.compute_score(a);
348            }
349        }
350        total
351    }
352
353    fn match_count(&self, solution: &S) -> usize {
354        self.full_match_count(solution)
355    }
356
357    fn initialize(&mut self, solution: &S) -> Sc {
358        self.reset();
359        self.rebuild_b_counts(solution);
360        self.initialize_a_state(solution)
361    }
362
363    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
364        let a_changed = self
365            .a_source
366            .assert_localizes(descriptor_index, &self.constraint_ref.name);
367        let parent_changed = self
368            .parent_source
369            .assert_localizes(descriptor_index, &self.constraint_ref.name);
370        let same_source =
371            self.a_source.same_index_domain(self.parent_source) && a_changed && parent_changed;
372
373        let mut total = Sc::zero();
374        if same_source {
375            let keys = self.parent_key_counts(solution, entity_index);
376            total = total + self.update_key_counts(&keys, true);
377            total = total + self.insert_a(solution, entity_index);
378            return total;
379        }
380
381        if parent_changed {
382            let keys = self.parent_key_counts(solution, entity_index);
383            total = total + self.update_key_counts(&keys, true);
384        }
385        if a_changed {
386            total = total + self.insert_a(solution, entity_index);
387        }
388        total
389    }
390
391    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
392        let a_changed = self
393            .a_source
394            .assert_localizes(descriptor_index, &self.constraint_ref.name);
395        let parent_changed = self
396            .parent_source
397            .assert_localizes(descriptor_index, &self.constraint_ref.name);
398        let same_source =
399            self.a_source.same_index_domain(self.parent_source) && a_changed && parent_changed;
400
401        let mut total = Sc::zero();
402        if same_source {
403            let keys = self.parent_key_counts(solution, entity_index);
404            total = total + self.retract_a(entity_index);
405            total = total + self.update_key_counts(&keys, false);
406            return total;
407        }
408
409        if a_changed {
410            total = total + self.retract_a(entity_index);
411        }
412        if parent_changed {
413            let keys = self.parent_key_counts(solution, entity_index);
414            total = total + self.update_key_counts(&keys, false);
415        }
416        total
417    }
418
419    fn reset(&mut self) {
420        self.a_slots.clear();
421        self.key_state.clear_a_buckets();
422        self.key_state.clear_b_counts();
423    }
424
425    fn name(&self) -> &str {
426        &self.constraint_ref.name
427    }
428
429    fn is_hard(&self) -> bool {
430        self.is_hard
431    }
432
433    fn constraint_ref(&self) -> ConstraintRef {
434        self.constraint_ref.clone()
435    }
436}
437
438#[derive(Debug, Clone, Copy, Default)]
439pub struct SelfFlatten;
440
441impl<T> FlattenExtract<T> for SelfFlatten
442where
443    T: Send + Sync,
444{
445    type Item = T;
446
447    fn extract<'a>(&self, parent: &'a T) -> &'a [T] {
448        slice::from_ref(parent)
449    }
450}