1use std::collections::HashMap;
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collector::{Accumulator, UniCollector};
10use crate::stream::filter::UniFilter;
11use crate::stream::ProjectedSource;
12
13pub struct ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
14where
15 C: UniCollector<Out>,
16 Sc: Score,
17{
18 constraint_ref: ConstraintRef,
19 impact_type: ImpactType,
20 source: Src,
21 filter: F,
22 key_fn: KF,
23 collector: C,
24 weight_fn: W,
25 is_hard: bool,
26 groups: HashMap<K, C::Accumulator>,
27 group_counts: HashMap<K, usize>,
28 entity_values: HashMap<(usize, usize), Vec<(K, C::Value)>>,
29 _phantom: PhantomData<(fn() -> S, fn() -> Out, fn() -> Sc)>,
30}
31
32impl<S, Out, K, Src, F, KF, C, W, Sc> ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
33where
34 S: Send + Sync + 'static,
35 Out: Clone + Send + Sync + 'static,
36 K: Clone + Eq + Hash + Send + Sync + 'static,
37 Src: ProjectedSource<S, Out>,
38 F: UniFilter<S, Out>,
39 KF: Fn(&Out) -> K + Send + Sync,
40 C: UniCollector<Out> + Send + Sync + 'static,
41 C::Accumulator: Send + Sync,
42 C::Result: Send + Sync,
43 C::Value: Clone + Send + Sync,
44 W: Fn(&C::Result) -> Sc + Send + Sync,
45 Sc: Score + 'static,
46{
47 #[allow(clippy::too_many_arguments)]
48 pub fn new(
49 constraint_ref: ConstraintRef,
50 impact_type: ImpactType,
51 source: Src,
52 filter: F,
53 key_fn: KF,
54 collector: C,
55 weight_fn: W,
56 is_hard: bool,
57 ) -> Self {
58 Self {
59 constraint_ref,
60 impact_type,
61 source,
62 filter,
63 key_fn,
64 collector,
65 weight_fn,
66 is_hard,
67 groups: HashMap::new(),
68 group_counts: HashMap::new(),
69 entity_values: HashMap::new(),
70 _phantom: PhantomData,
71 }
72 }
73
74 fn compute_score(&self, result: &C::Result) -> Sc {
75 let base = (self.weight_fn)(result);
76 match self.impact_type {
77 ImpactType::Penalty => -base,
78 ImpactType::Reward => base,
79 }
80 }
81
82 fn retract_output(&mut self, key: &K, value: &C::Value) -> Sc {
83 let Some(acc) = self.groups.get_mut(key) else {
84 return Sc::zero();
85 };
86 let impact = self.impact_type;
87 let old_base = (self.weight_fn)(&acc.finish());
88 let old = match impact {
89 ImpactType::Penalty => -old_base,
90 ImpactType::Reward => old_base,
91 };
92
93 let is_empty = {
94 let count = self.group_counts.entry(key.clone()).or_insert(0);
95 *count = count.saturating_sub(1);
96 *count == 0
97 };
98 if is_empty {
99 self.group_counts.remove(key);
100 }
101
102 acc.retract(value);
103 let new_score = if is_empty {
104 self.groups.remove(key);
105 Sc::zero()
106 } else {
107 let new_base = (self.weight_fn)(&acc.finish());
108 match impact {
109 ImpactType::Penalty => -new_base,
110 ImpactType::Reward => new_base,
111 }
112 };
113
114 new_score - old
115 }
116
117 fn insert_entity_outputs(&mut self, solution: &S, slot: usize, entity_index: usize) -> Sc {
118 let mut total = Sc::zero();
119 let mut cached = Vec::new();
120 let source = &self.source;
121 let filter = &self.filter;
122 let key_fn = &self.key_fn;
123 let collector = &self.collector;
124 let weight_fn = &self.weight_fn;
125 let impact = self.impact_type;
126 let groups = &mut self.groups;
127 let group_counts = &mut self.group_counts;
128 source.collect_entity(solution, slot, entity_index, |_, output| {
129 if !filter.test(solution, &output) {
130 return;
131 }
132 let key = key_fn(&output);
133 let value = collector.extract(&output);
134 let is_new = !groups.contains_key(&key);
135 let acc = groups
136 .entry(key.clone())
137 .or_insert_with(|| collector.create_accumulator());
138 let old = if is_new {
139 Sc::zero()
140 } else {
141 let old_base = weight_fn(&acc.finish());
142 match impact {
143 ImpactType::Penalty => -old_base,
144 ImpactType::Reward => old_base,
145 }
146 };
147 acc.accumulate(&value);
148 let new_base = weight_fn(&acc.finish());
149 let new_score = match impact {
150 ImpactType::Penalty => -new_base,
151 ImpactType::Reward => new_base,
152 };
153 *group_counts.entry(key.clone()).or_insert(0) += 1;
154 cached.push((key, value));
155 total = total + (new_score - old);
156 });
157 self.entity_values.insert((slot, entity_index), cached);
158 total
159 }
160
161 fn retract_entity_outputs(&mut self, slot: usize, entity_index: usize) -> Sc {
162 let Some(cached) = self.entity_values.remove(&(slot, entity_index)) else {
163 return Sc::zero();
164 };
165 let mut total = Sc::zero();
166 for (key, value) in cached {
167 total = total + self.retract_output(&key, &value);
168 }
169 total
170 }
171
172 fn localized_slots(&self, descriptor_index: usize) -> Vec<usize> {
173 let mut slots = Vec::new();
174 for slot in 0..self.source.source_count() {
175 if self
176 .source
177 .change_source(slot)
178 .assert_localizes(descriptor_index, &self.constraint_ref.name)
179 {
180 slots.push(slot);
181 }
182 }
183 slots
184 }
185}
186
187impl<S, Out, K, Src, F, KF, C, W, Sc> IncrementalConstraint<S, Sc>
188 for ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
189where
190 S: Send + Sync + 'static,
191 Out: Clone + Send + Sync + 'static,
192 K: Clone + Eq + Hash + Send + Sync + 'static,
193 Src: ProjectedSource<S, Out>,
194 F: UniFilter<S, Out>,
195 KF: Fn(&Out) -> K + Send + Sync,
196 C: UniCollector<Out> + Send + Sync + 'static,
197 C::Accumulator: Send + Sync,
198 C::Result: Send + Sync,
199 C::Value: Clone + Send + Sync,
200 W: Fn(&C::Result) -> Sc + Send + Sync,
201 Sc: Score + 'static,
202{
203 fn evaluate(&self, solution: &S) -> Sc {
204 let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
205 self.source.collect_all(solution, |_, output| {
206 if !self.filter.test(solution, &output) {
207 return;
208 }
209 let key = (self.key_fn)(&output);
210 let value = self.collector.extract(&output);
211 groups
212 .entry(key)
213 .or_insert_with(|| self.collector.create_accumulator())
214 .accumulate(&value);
215 });
216 groups.values().fold(Sc::zero(), |total, acc| {
217 total + self.compute_score(&acc.finish())
218 })
219 }
220
221 fn match_count(&self, solution: &S) -> usize {
222 let mut keys = HashMap::<K, ()>::new();
223 self.source.collect_all(solution, |_, output| {
224 if self.filter.test(solution, &output) {
225 keys.insert((self.key_fn)(&output), ());
226 }
227 });
228 keys.len()
229 }
230
231 fn initialize(&mut self, solution: &S) -> Sc {
232 self.reset();
233 let mut total = Sc::zero();
234 let source = &self.source;
235 let filter = &self.filter;
236 let key_fn = &self.key_fn;
237 let collector = &self.collector;
238 let weight_fn = &self.weight_fn;
239 let impact = self.impact_type;
240 let groups = &mut self.groups;
241 let group_counts = &mut self.group_counts;
242 let entity_values = &mut self.entity_values;
243 source.collect_all(solution, |coordinate, output| {
244 if !filter.test(solution, &output) {
245 return;
246 }
247 let key = key_fn(&output);
248 let value = collector.extract(&output);
249 let is_new = !groups.contains_key(&key);
250 let acc = groups
251 .entry(key.clone())
252 .or_insert_with(|| collector.create_accumulator());
253 let old = if is_new {
254 Sc::zero()
255 } else {
256 let old_base = weight_fn(&acc.finish());
257 match impact {
258 ImpactType::Penalty => -old_base,
259 ImpactType::Reward => old_base,
260 }
261 };
262 acc.accumulate(&value);
263 let new_base = weight_fn(&acc.finish());
264 let new_score = match impact {
265 ImpactType::Penalty => -new_base,
266 ImpactType::Reward => new_base,
267 };
268 *group_counts.entry(key.clone()).or_insert(0) += 1;
269 entity_values
270 .entry((coordinate.source_slot, coordinate.entity_index))
271 .or_default()
272 .push((key, value));
273 total = total + (new_score - old);
274 });
275 total
276 }
277
278 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
279 let mut total = Sc::zero();
280 for slot in self.localized_slots(descriptor_index) {
281 total = total + self.insert_entity_outputs(solution, slot, entity_index);
282 }
283 total
284 }
285
286 fn on_retract(&mut self, _solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
287 let mut total = Sc::zero();
288 for slot in self.localized_slots(descriptor_index) {
289 total = total + self.retract_entity_outputs(slot, entity_index);
290 }
291 total
292 }
293
294 fn reset(&mut self) {
295 self.groups.clear();
296 self.group_counts.clear();
297 self.entity_values.clear();
298 }
299
300 fn name(&self) -> &str {
301 &self.constraint_ref.name
302 }
303
304 fn is_hard(&self) -> bool {
305 self.is_hard
306 }
307
308 fn constraint_ref(&self) -> ConstraintRef {
309 self.constraint_ref.clone()
310 }
311}