solverforge_scoring/constraint/complemented/
incremental.rs1use std::hash::Hash;
2
3use crate::api::constraint_set::IncrementalConstraint;
4use crate::stream::collector::UniCollector;
5use solverforge_core::score::Score;
6use solverforge_core::ConstraintRef;
7
8use super::ComplementedGroupConstraint;
9
10impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> IncrementalConstraint<S, Sc>
11 for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
12where
13 S: Send + Sync + 'static,
14 A: Clone + Send + Sync + 'static,
15 B: Clone + Send + Sync + 'static,
16 K: Clone + Eq + Hash + Send + Sync,
17 EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
18 EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
19 KA: Fn(&A) -> Option<K> + Send + Sync,
20 KB: Fn(&B) -> K + Send + Sync,
21 C: UniCollector<A> + Send + Sync,
22 C::Accumulator: Send + Sync,
23 C::Result: Clone + Send + Sync,
24 C::Value: Send + Sync,
25 D: Fn(&B) -> C::Result + Send + Sync,
26 W: Fn(&C::Result) -> Sc + Send + Sync,
27 Sc: Score,
28{
29 fn evaluate(&self, solution: &S) -> Sc {
30 let entities_a = self.extractor_a.extract(solution);
31 let entities_b = self.extractor_b.extract(solution);
32
33 let groups = self.build_groups(entities_a);
34
35 let mut total = Sc::zero();
36 for b in entities_b {
37 let key = (self.key_b)(b);
38 let result = groups
39 .get(&key)
40 .cloned()
41 .unwrap_or_else(|| (self.default_fn)(b));
42 total = total + self.compute_score(&result);
43 }
44
45 total
46 }
47
48 fn match_count(&self, solution: &S) -> usize {
49 let entities_b = self.extractor_b.extract(solution);
50 entities_b.len()
51 }
52
53 fn initialize(&mut self, solution: &S) -> Sc {
54 self.reset();
55
56 let entities_a = self.extractor_a.extract(solution);
57 let entities_b = self.extractor_b.extract(solution);
58
59 for (idx, b) in entities_b.iter().enumerate() {
61 let key = (self.key_b)(b);
62 self.b_by_key.insert(key.clone(), idx);
63 self.b_index_to_key.insert(idx, key);
64 }
65
66 let mut total = Sc::zero();
68 for b in entities_b {
69 let default_result = (self.default_fn)(b);
70 total = total + self.compute_score(&default_result);
71 }
72
73 for (idx, a) in entities_a.iter().enumerate() {
75 total = total + self.insert_entity(entities_b, idx, a);
76 }
77
78 total
79 }
80
81 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
82 let a_changed = self
83 .a_source
84 .assert_localizes(descriptor_index, &self.constraint_ref.name);
85 let b_changed = self
86 .b_source
87 .assert_localizes(descriptor_index, &self.constraint_ref.name);
88 let entities_a = self.extractor_a.extract(solution);
89 let entities_b = self.extractor_b.extract(solution);
90
91 let mut total = Sc::zero();
92 if a_changed && entity_index < entities_a.len() {
93 let entity = &entities_a[entity_index];
94 total = total + self.insert_entity(entities_b, entity_index, entity);
95 }
96 if b_changed {
97 total = total + self.insert_b(entities_b, entity_index);
98 }
99 total
100 }
101
102 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
103 let a_changed = self
104 .a_source
105 .assert_localizes(descriptor_index, &self.constraint_ref.name);
106 let b_changed = self
107 .b_source
108 .assert_localizes(descriptor_index, &self.constraint_ref.name);
109 let entities_a = self.extractor_a.extract(solution);
110 let entities_b = self.extractor_b.extract(solution);
111
112 let mut total = Sc::zero();
113 if a_changed {
114 total = total + self.retract_entity(entities_a, entities_b, entity_index);
115 }
116 if b_changed {
117 total = total + self.retract_b(entities_b, entity_index);
118 }
119 total
120 }
121
122 fn reset(&mut self) {
123 self.groups.clear();
124 self.entity_groups.clear();
125 self.entity_values.clear();
126 self.b_by_key.clear();
127 self.b_index_to_key.clear();
128 }
129
130 fn name(&self) -> &str {
131 &self.constraint_ref.name
132 }
133
134 fn is_hard(&self) -> bool {
135 self.is_hard
136 }
137
138 fn constraint_ref(&self) -> &ConstraintRef {
139 &self.constraint_ref
140 }
141}