Skip to main content

solverforge_scoring/constraint/complemented/
incremental.rs

1use std::hash::Hash;
2
3use crate::api::constraint_set::IncrementalConstraint;
4use crate::stream::collector::UniCollector;
5use solverforge_core::score::Score;
6use solverforge_core::ConstraintRef;
7
8use super::ComplementedGroupConstraint;
9
10impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> IncrementalConstraint<S, Sc>
11    for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
12where
13    S: Send + Sync + 'static,
14    A: Clone + Send + Sync + 'static,
15    B: Clone + Send + Sync + 'static,
16    K: Clone + Eq + Hash + Send + Sync,
17    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
18    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
19    KA: Fn(&A) -> Option<K> + Send + Sync,
20    KB: Fn(&B) -> K + Send + Sync,
21    C: UniCollector<A> + Send + Sync,
22    C::Accumulator: Send + Sync,
23    C::Result: Clone + Send + Sync,
24    C::Value: Send + Sync,
25    D: Fn(&B) -> C::Result + Send + Sync,
26    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
27    Sc: Score,
28{
29    fn evaluate(&self, solution: &S) -> Sc {
30        let entities_a = self.extractor_a.extract(solution);
31        let entities_b = self.extractor_b.extract(solution);
32
33        let groups = self.build_groups(entities_a);
34
35        let mut total = Sc::zero();
36        for b in entities_b {
37            let key = (self.key_b)(b);
38            total = total
39                + match groups.get(&key) {
40                    Some(result) => self.compute_score(&key, result),
41                    None => {
42                        let default_result = (self.default_fn)(b);
43                        self.compute_score(&key, &default_result)
44                    }
45                };
46        }
47
48        total
49    }
50
51    fn match_count(&self, solution: &S) -> usize {
52        let entities_b = self.extractor_b.extract(solution);
53        entities_b.len()
54    }
55
56    fn initialize(&mut self, solution: &S) -> Sc {
57        self.reset();
58
59        let entities_a = self.extractor_a.extract(solution);
60        let entities_b = self.extractor_b.extract(solution);
61
62        // Build B key -> index mapping
63        for (idx, b) in entities_b.iter().enumerate() {
64            let key = (self.key_b)(b);
65            self.b_by_key.insert(key.clone(), idx);
66            self.b_index_to_key.insert(idx, key);
67        }
68
69        // Initialize all B entities with default scores
70        let mut total = Sc::zero();
71        for b in entities_b {
72            let key = (self.key_b)(b);
73            let default_result = (self.default_fn)(b);
74            total = total + self.compute_score(&key, &default_result);
75        }
76
77        // Now insert all A entities incrementally
78        for (idx, a) in entities_a.iter().enumerate() {
79            total = total + self.insert_entity(entities_b, idx, a);
80        }
81
82        total
83    }
84
85    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
86        let a_changed = self
87            .a_source
88            .assert_localizes(descriptor_index, &self.constraint_ref.name);
89        let b_changed = self
90            .b_source
91            .assert_localizes(descriptor_index, &self.constraint_ref.name);
92        let entities_a = self.extractor_a.extract(solution);
93        let entities_b = self.extractor_b.extract(solution);
94
95        let mut total = Sc::zero();
96        if a_changed && entity_index < entities_a.len() {
97            let entity = &entities_a[entity_index];
98            total = total + self.insert_entity(entities_b, entity_index, entity);
99        }
100        if b_changed {
101            total = total + self.insert_b(entities_b, entity_index);
102        }
103        total
104    }
105
106    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
107        let a_changed = self
108            .a_source
109            .assert_localizes(descriptor_index, &self.constraint_ref.name);
110        let b_changed = self
111            .b_source
112            .assert_localizes(descriptor_index, &self.constraint_ref.name);
113        let entities_a = self.extractor_a.extract(solution);
114        let entities_b = self.extractor_b.extract(solution);
115
116        let mut total = Sc::zero();
117        if a_changed {
118            total = total + self.retract_entity(entities_a, entities_b, entity_index);
119        }
120        if b_changed {
121            total = total + self.retract_b(entities_b, entity_index);
122        }
123        total
124    }
125
126    fn reset(&mut self) {
127        self.groups.clear();
128        self.entity_groups.clear();
129        self.entity_values.clear();
130        self.b_by_key.clear();
131        self.b_index_to_key.clear();
132    }
133
134    fn name(&self) -> &str {
135        &self.constraint_ref.name
136    }
137
138    fn is_hard(&self) -> bool {
139        self.is_hard
140    }
141
142    fn constraint_ref(&self) -> &ConstraintRef {
143        &self.constraint_ref
144    }
145}