Skip to main content

solverforge_scoring/constraint/
exists.rs

1use std::collections::HashMap;
2use std::hash::Hash;
3use std::marker::PhantomData;
4use std::slice;
5
6use solverforge_core::score::Score;
7use solverforge_core::{ConstraintRef, ImpactType};
8
9use crate::api::constraint_set::IncrementalConstraint;
10use crate::stream::collection_extract::{ChangeSource, TrackedCollectionExtract};
11use crate::stream::filter::UniFilter;
12use crate::stream::{ExistenceMode, FlattenExtract};
13
14#[derive(Debug, Clone)]
15struct ASlot<K, Sc>
16where
17    Sc: Score,
18{
19    key: Option<K>,
20    bucket_pos: usize,
21    contribution: Sc,
22}
23
24impl<K, Sc> Default for ASlot<K, Sc>
25where
26    Sc: Score,
27{
28    fn default() -> Self {
29        Self {
30            key: None,
31            bucket_pos: 0,
32            contribution: Sc::zero(),
33        }
34    }
35}
36
37pub struct IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
38where
39    Sc: Score,
40{
41    constraint_ref: ConstraintRef,
42    impact_type: ImpactType,
43    mode: ExistenceMode,
44    extractor_a: EA,
45    extractor_parent: EP,
46    key_a: KA,
47    key_b: KB,
48    filter_a: FA,
49    filter_parent: FP,
50    flatten: Flatten,
51    weight: W,
52    is_hard: bool,
53    a_source: ChangeSource,
54    parent_source: ChangeSource,
55    a_slots: Vec<ASlot<K, Sc>>,
56    a_indices_by_key: HashMap<K, Vec<usize>>,
57    b_key_counts: HashMap<K, usize>,
58    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> P, fn() -> B)>,
59}
60
61impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
62    IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
63where
64    S: 'static,
65    A: Clone + 'static,
66    P: Clone + 'static,
67    B: Clone + 'static,
68    K: Eq + Hash + Clone,
69    EA: TrackedCollectionExtract<S, Item = A>,
70    EP: TrackedCollectionExtract<S, Item = P>,
71    KA: Fn(&A) -> K,
72    KB: Fn(&B) -> K,
73    FA: UniFilter<S, A>,
74    FP: UniFilter<S, P>,
75    Flatten: FlattenExtract<P, Item = B>,
76    W: Fn(&A) -> Sc,
77    Sc: Score,
78{
79    #[allow(clippy::too_many_arguments)]
80    pub fn new(
81        constraint_ref: ConstraintRef,
82        impact_type: ImpactType,
83        mode: ExistenceMode,
84        extractor_a: EA,
85        extractor_parent: EP,
86        key_a: KA,
87        key_b: KB,
88        filter_a: FA,
89        filter_parent: FP,
90        flatten: Flatten,
91        weight: W,
92        is_hard: bool,
93    ) -> Self {
94        let a_source = extractor_a.change_source();
95        let parent_source = extractor_parent.change_source();
96        Self {
97            constraint_ref,
98            impact_type,
99            mode,
100            extractor_a,
101            extractor_parent,
102            key_a,
103            key_b,
104            filter_a,
105            filter_parent,
106            flatten,
107            weight,
108            is_hard,
109            a_source,
110            parent_source,
111            a_slots: Vec::new(),
112            a_indices_by_key: HashMap::new(),
113            b_key_counts: HashMap::new(),
114            _phantom: PhantomData,
115        }
116    }
117
118    #[inline]
119    fn compute_score(&self, a: &A) -> Sc {
120        let base = (self.weight)(a);
121        match self.impact_type {
122            ImpactType::Penalty => -base,
123            ImpactType::Reward => base,
124        }
125    }
126
127    #[inline]
128    fn matches_existence(&self, key: &K) -> bool {
129        let count = self.b_key_counts.get(key).copied().unwrap_or(0);
130        match self.mode {
131            ExistenceMode::Exists => count > 0,
132            ExistenceMode::NotExists => count == 0,
133        }
134    }
135
136    fn rebuild_b_counts(&mut self, solution: &S) {
137        self.b_key_counts.clear();
138        for parent in self.extractor_parent.extract(solution) {
139            if !self.filter_parent.test(solution, parent) {
140                continue;
141            }
142            for item in self.flatten.extract(parent) {
143                *self.b_key_counts.entry((self.key_b)(item)).or_insert(0) += 1;
144            }
145        }
146    }
147
148    fn remove_a_from_bucket(&mut self, idx: usize, key: &K, bucket_pos: usize) {
149        let mut remove_key = false;
150        if let Some(bucket) = self.a_indices_by_key.get_mut(key) {
151            let removed = bucket.swap_remove(bucket_pos);
152            debug_assert_eq!(removed, idx);
153            if bucket_pos < bucket.len() {
154                let moved_idx = bucket[bucket_pos];
155                self.a_slots[moved_idx].bucket_pos = bucket_pos;
156            }
157            remove_key = bucket.is_empty();
158        }
159        if remove_key {
160            self.a_indices_by_key.remove(key);
161        }
162    }
163
164    fn retract_a(&mut self, idx: usize) -> Sc {
165        if idx >= self.a_slots.len() {
166            return Sc::zero();
167        }
168        let slot = self.a_slots[idx].clone();
169        let Some(key) = slot.key.clone() else {
170            return Sc::zero();
171        };
172        self.remove_a_from_bucket(idx, &key, slot.bucket_pos);
173        self.a_slots[idx] = ASlot::default();
174        -slot.contribution
175    }
176
177    fn insert_a(&mut self, solution: &S, idx: usize) -> Sc {
178        let entities_a = self.extractor_a.extract(solution);
179        if idx >= entities_a.len() {
180            return Sc::zero();
181        }
182        if self.a_slots.len() < entities_a.len() {
183            self.a_slots.resize(entities_a.len(), ASlot::default());
184        }
185
186        let a = &entities_a[idx];
187        if !self.filter_a.test(solution, a) {
188            self.a_slots[idx] = ASlot::default();
189            return Sc::zero();
190        }
191
192        let key = (self.key_a)(a);
193        let bucket = self.a_indices_by_key.entry(key.clone()).or_default();
194        let bucket_pos = bucket.len();
195        bucket.push(idx);
196
197        let contribution = if self.matches_existence(&key) {
198            self.compute_score(a)
199        } else {
200            Sc::zero()
201        };
202
203        self.a_slots[idx] = ASlot {
204            key: Some(key),
205            bucket_pos,
206            contribution,
207        };
208        contribution
209    }
210
211    fn reevaluate_key(&mut self, solution: &S, key: &K) -> Sc {
212        let Some(indices) = self.a_indices_by_key.get(key).cloned() else {
213            return Sc::zero();
214        };
215        let entities_a = self.extractor_a.extract(solution);
216        let mut total = Sc::zero();
217        let exists = self.matches_existence(key);
218
219        for idx in indices {
220            let a = &entities_a[idx];
221            let new_contribution = if exists {
222                self.compute_score(a)
223            } else {
224                Sc::zero()
225            };
226            let old_contribution = self.a_slots[idx].contribution;
227            self.a_slots[idx].contribution = new_contribution;
228            total = total + (new_contribution - old_contribution);
229        }
230
231        total
232    }
233
234    fn update_key_counts(
235        &mut self,
236        solution: &S,
237        key_multiset: &HashMap<K, usize>,
238        insert: bool,
239    ) -> Sc {
240        let mut total = Sc::zero();
241
242        for (key, count) in key_multiset {
243            if insert {
244                *self.b_key_counts.entry(key.clone()).or_insert(0) += *count;
245            } else {
246                let mut remove_key = false;
247                if let Some(entry) = self.b_key_counts.get_mut(key) {
248                    *entry = entry.saturating_sub(*count);
249                    remove_key = *entry == 0;
250                }
251                if remove_key {
252                    self.b_key_counts.remove(key);
253                }
254            }
255        }
256
257        for key in key_multiset.keys() {
258            total = total + self.reevaluate_key(solution, key);
259        }
260
261        total
262    }
263
264    fn parent_key_multiset(&self, solution: &S, idx: usize) -> HashMap<K, usize> {
265        let parents = self.extractor_parent.extract(solution);
266        if idx >= parents.len() {
267            return HashMap::new();
268        }
269        let parent = &parents[idx];
270        if !self.filter_parent.test(solution, parent) {
271            return HashMap::new();
272        }
273
274        let mut multiset = HashMap::new();
275        for item in self.flatten.extract(parent) {
276            *multiset.entry((self.key_b)(item)).or_insert(0) += 1;
277        }
278        multiset
279    }
280
281    fn initialize_a_state(&mut self, solution: &S) -> Sc {
282        self.a_slots.clear();
283        self.a_indices_by_key.clear();
284
285        let len = self.extractor_a.extract(solution).len();
286        self.a_slots.resize(len, ASlot::default());
287
288        let mut total = Sc::zero();
289        for idx in 0..len {
290            total = total + self.insert_a(solution, idx);
291        }
292        total
293    }
294
295    fn full_match_count(&self, solution: &S) -> usize {
296        let mut key_counts = HashMap::<K, usize>::new();
297        for parent in self.extractor_parent.extract(solution) {
298            if !self.filter_parent.test(solution, parent) {
299                continue;
300            }
301            for item in self.flatten.extract(parent) {
302                *key_counts.entry((self.key_b)(item)).or_insert(0) += 1;
303            }
304        }
305
306        self.extractor_a
307            .extract(solution)
308            .iter()
309            .filter(|a| {
310                self.filter_a.test(solution, a)
311                    && match self.mode {
312                        ExistenceMode::Exists => {
313                            key_counts.get(&(self.key_a)(a)).copied().unwrap_or(0) > 0
314                        }
315                        ExistenceMode::NotExists => {
316                            key_counts.get(&(self.key_a)(a)).copied().unwrap_or(0) == 0
317                        }
318                    }
319            })
320            .count()
321    }
322}
323
324impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc> IncrementalConstraint<S, Sc>
325    for IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
326where
327    S: Send + Sync + 'static,
328    A: Clone + Send + Sync + 'static,
329    P: Clone + Send + Sync + 'static,
330    B: Clone + Send + Sync + 'static,
331    K: Eq + Hash + Clone + Send + Sync,
332    EA: TrackedCollectionExtract<S, Item = A> + Send + Sync,
333    EP: TrackedCollectionExtract<S, Item = P> + Send + Sync,
334    KA: Fn(&A) -> K + Send + Sync,
335    KB: Fn(&B) -> K + Send + Sync,
336    FA: UniFilter<S, A> + Send + Sync,
337    FP: UniFilter<S, P> + Send + Sync,
338    Flatten: FlattenExtract<P, Item = B> + Send + Sync,
339    W: Fn(&A) -> Sc + Send + Sync,
340    Sc: Score,
341{
342    fn evaluate(&self, solution: &S) -> Sc {
343        let mut counts = HashMap::<K, usize>::new();
344        for parent in self.extractor_parent.extract(solution) {
345            if !self.filter_parent.test(solution, parent) {
346                continue;
347            }
348            for item in self.flatten.extract(parent) {
349                *counts.entry((self.key_b)(item)).or_insert(0) += 1;
350            }
351        }
352
353        let mut total = Sc::zero();
354        for a in self.extractor_a.extract(solution) {
355            if !self.filter_a.test(solution, a) {
356                continue;
357            }
358            let key = (self.key_a)(a);
359            let matches = match self.mode {
360                ExistenceMode::Exists => counts.get(&key).copied().unwrap_or(0) > 0,
361                ExistenceMode::NotExists => counts.get(&key).copied().unwrap_or(0) == 0,
362            };
363            if matches {
364                total = total + self.compute_score(a);
365            }
366        }
367        total
368    }
369
370    fn match_count(&self, solution: &S) -> usize {
371        self.full_match_count(solution)
372    }
373
374    fn initialize(&mut self, solution: &S) -> Sc {
375        self.reset();
376        self.rebuild_b_counts(solution);
377        self.initialize_a_state(solution)
378    }
379
380    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
381        let a_changed =
382            matches!(self.a_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
383        let parent_changed =
384            matches!(self.parent_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
385        let same_source = self.a_source == self.parent_source && a_changed && parent_changed;
386
387        let mut total = Sc::zero();
388        if same_source {
389            let keys = self.parent_key_multiset(solution, entity_index);
390            total = total + self.update_key_counts(solution, &keys, true);
391            total = total + self.insert_a(solution, entity_index);
392            return total;
393        }
394
395        if parent_changed {
396            let keys = self.parent_key_multiset(solution, entity_index);
397            total = total + self.update_key_counts(solution, &keys, true);
398        }
399        if a_changed {
400            total = total + self.insert_a(solution, entity_index);
401        }
402        total
403    }
404
405    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
406        let a_changed =
407            matches!(self.a_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
408        let parent_changed =
409            matches!(self.parent_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
410        let same_source = self.a_source == self.parent_source && a_changed && parent_changed;
411
412        let mut total = Sc::zero();
413        if same_source {
414            let keys = self.parent_key_multiset(solution, entity_index);
415            total = total + self.retract_a(entity_index);
416            total = total + self.update_key_counts(solution, &keys, false);
417            return total;
418        }
419
420        if a_changed {
421            total = total + self.retract_a(entity_index);
422        }
423        if parent_changed {
424            let keys = self.parent_key_multiset(solution, entity_index);
425            total = total + self.update_key_counts(solution, &keys, false);
426        }
427        total
428    }
429
430    fn reset(&mut self) {
431        self.a_slots.clear();
432        self.a_indices_by_key.clear();
433        self.b_key_counts.clear();
434    }
435
436    fn name(&self) -> &str {
437        &self.constraint_ref.name
438    }
439
440    fn is_hard(&self) -> bool {
441        self.is_hard
442    }
443
444    fn constraint_ref(&self) -> ConstraintRef {
445        self.constraint_ref.clone()
446    }
447}
448
449#[derive(Debug, Clone, Copy, Default)]
450pub struct SelfFlatten;
451
452impl<T> FlattenExtract<T> for SelfFlatten
453where
454    T: Send + Sync,
455{
456    type Item = T;
457
458    fn extract<'a>(&self, parent: &'a T) -> &'a [T] {
459        slice::from_ref(parent)
460    }
461}