1use std::collections::{hash_map::Entry, HashMap, HashSet};
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collector::{Accumulator, Collector};
10use crate::stream::filter::UniFilter;
11use crate::stream::{ProjectedRowCoordinate, ProjectedRowOwner, ProjectedSource};
12
/// Aggregation state kept for one live group key.
///
/// `count` tracks how many accumulated values are currently held by
/// `accumulator` (incremented on every accumulate, decremented on every
/// retract); a group whose count drops to zero is discarded by its owner.
struct GroupState<Acc> {
    count: usize,
    accumulator: Acc,
}
17
/// Handle returned by `Accumulator::accumulate` for accumulator type `Acc`.
///
/// One handle is stored per accepted row so that exactly that row's
/// contribution can later be handed back to `Accumulator::retract`.
type CollectorRetraction<Acc, V, R> = <Acc as Accumulator<V, R>>::Retraction;
19
/// Incremental constraint that groups the rows of a `ProjectedSource` by key.
///
/// Every row that passes `filter` is mapped to a group key with `key_fn`. Its
/// value, extracted by `collector`, is folded into that group's accumulator.
/// Each group is then scored with `weight_fn`, which is negated for
/// `ImpactType::Penalty` and kept as-is for `ImpactType::Reward`.
pub struct ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, V, R, Acc, W, Sc>
where
    Src: ProjectedSource<S, Out>,
    Acc: Accumulator<V, R>,
    Sc: Score,
{
    /// Identity of this constraint; its `name` is also passed to `assert_localizes`.
    constraint_ref: ConstraintRef,
    /// Whether group weights are subtracted (penalty) or added (reward).
    impact_type: ImpactType,
    /// Producer of projected rows.
    source: Src,
    /// Predicate a row must pass to take part in any group.
    filter: F,
    /// Maps a projected row to its group key.
    key_fn: KF,
    /// Creates accumulators and extracts the per-row value fed into them.
    collector: C,
    /// Computes the unsigned weight of a group from its key and collected result.
    weight_fn: W,
    /// Reported verbatim by `IncrementalConstraint::is_hard`.
    is_hard: bool,
    /// Incremental source state; `None` until built and again after `reset`.
    source_state: Option<Src::State>,
    /// Live groups; a group is removed once its row count reaches zero.
    groups: HashMap<K, GroupState<Acc>>,
    /// Output of every row currently accepted into a group, keyed by coordinate.
    row_outputs: HashMap<ProjectedRowCoordinate, Out>,
    /// Retraction handle for every accepted row, keyed by coordinate.
    row_retractions: HashMap<ProjectedRowCoordinate, CollectorRetraction<Acc, V, R>>,
    /// Reverse index: each owner to the coordinates of accepted rows that report it.
    rows_by_owner: HashMap<ProjectedRowOwner, Vec<ProjectedRowCoordinate>>,
    // Function-pointer markers bind the otherwise-unused type parameters without
    // making the struct own values of those types (no effect on auto traits or dropck).
    _phantom: PhantomData<(
        fn() -> S,
        fn() -> Out,
        fn() -> V,
        fn() -> R,
        fn() -> Acc,
        fn() -> Sc,
    )>,
}
48
49impl<S, Out, K, Src, F, KF, C, V, R, Acc, W, Sc>
50 ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, V, R, Acc, W, Sc>
51where
52 S: Send + Sync + 'static,
53 Out: Send + Sync + 'static,
54 K: Eq + Hash + Send + Sync + 'static,
55 Src: ProjectedSource<S, Out>,
56 F: UniFilter<S, Out>,
57 KF: Fn(&Out) -> K + Send + Sync,
58 C: for<'i> Collector<&'i Out, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
59 V: Send + Sync,
60 R: Send + Sync,
61 Acc: Accumulator<V, R> + Send + Sync,
62 W: Fn(&K, &R) -> Sc + Send + Sync,
63 Sc: Score + 'static,
64{
65 #[allow(clippy::too_many_arguments)]
66 pub fn new(
67 constraint_ref: ConstraintRef,
68 impact_type: ImpactType,
69 source: Src,
70 filter: F,
71 key_fn: KF,
72 collector: C,
73 weight_fn: W,
74 is_hard: bool,
75 ) -> Self {
76 Self {
77 constraint_ref,
78 impact_type,
79 source,
80 filter,
81 key_fn,
82 collector,
83 weight_fn,
84 is_hard,
85 source_state: None,
86 groups: HashMap::new(),
87 row_outputs: HashMap::new(),
88 row_retractions: HashMap::new(),
89 rows_by_owner: HashMap::new(),
90 _phantom: PhantomData,
91 }
92 }
93
94 fn compute_score(&self, key: &K, result: &R) -> Sc {
95 let base = (self.weight_fn)(key, result);
96 match self.impact_type {
97 ImpactType::Penalty => -base,
98 ImpactType::Reward => base,
99 }
100 }
101
102 fn ensure_source_state(&mut self, solution: &S) {
103 if self.source_state.is_none() {
104 self.source_state = Some(self.source.build_state(solution));
105 }
106 }
107
108 fn index_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
109 coordinate.for_each_owner(|owner| {
110 self.rows_by_owner
111 .entry(owner)
112 .or_default()
113 .push(coordinate);
114 });
115 }
116
117 fn unindex_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
118 coordinate.for_each_owner(|owner| {
119 let mut remove_bucket = false;
120 if let Some(rows) = self.rows_by_owner.get_mut(&owner) {
121 rows.retain(|candidate| *candidate != coordinate);
122 remove_bucket = rows.is_empty();
123 }
124 if remove_bucket {
125 self.rows_by_owner.remove(&owner);
126 }
127 });
128 }
129
130 fn insert_value(&mut self, key: K, value: V) -> (Sc, CollectorRetraction<Acc, V, R>) {
131 let impact = self.impact_type;
132 let weight_fn = &self.weight_fn;
133 match self.groups.entry(key) {
134 Entry::Occupied(mut entry) => {
135 let old_base = entry
136 .get()
137 .accumulator
138 .with_result(|result| weight_fn(entry.key(), result));
139 let old = match impact {
140 ImpactType::Penalty => -old_base,
141 ImpactType::Reward => old_base,
142 };
143 let group = entry.get_mut();
144 let retraction = group.accumulator.accumulate(value);
145 group.count += 1;
146 let new_base = entry
147 .get()
148 .accumulator
149 .with_result(|result| weight_fn(entry.key(), result));
150 let new_score = match impact {
151 ImpactType::Penalty => -new_base,
152 ImpactType::Reward => new_base,
153 };
154 (new_score - old, retraction)
155 }
156 Entry::Vacant(entry) => {
157 let mut entry = entry.insert_entry(GroupState {
158 accumulator: self.collector.create_accumulator(),
159 count: 0,
160 });
161 let group = entry.get_mut();
162 let retraction = group.accumulator.accumulate(value);
163 group.count += 1;
164 let new_base = entry
165 .get()
166 .accumulator
167 .with_result(|result| weight_fn(entry.key(), result));
168 let score = match impact {
169 ImpactType::Penalty => -new_base,
170 ImpactType::Reward => new_base,
171 };
172 (score, retraction)
173 }
174 }
175 }
176
177 fn retract_value(&mut self, key: K, retraction: CollectorRetraction<Acc, V, R>) -> Sc {
178 let impact = self.impact_type;
179 let weight_fn = &self.weight_fn;
180 let Entry::Occupied(mut entry) = self.groups.entry(key) else {
181 return Sc::zero();
182 };
183 let old_base = entry
184 .get()
185 .accumulator
186 .with_result(|result| weight_fn(entry.key(), result));
187 let old = match impact {
188 ImpactType::Penalty => -old_base,
189 ImpactType::Reward => old_base,
190 };
191 let group = entry.get_mut();
192 group.accumulator.retract(retraction);
193 group.count = group.count.saturating_sub(1);
194 let new_score = if group.count == 0 {
195 entry.remove();
196 Sc::zero()
197 } else {
198 let new_base = entry
199 .get()
200 .accumulator
201 .with_result(|result| weight_fn(entry.key(), result));
202 match impact {
203 ImpactType::Penalty => -new_base,
204 ImpactType::Reward => new_base,
205 }
206 };
207
208 new_score - old
209 }
210
211 fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
212 if self.row_outputs.contains_key(&coordinate) || !self.filter.test(solution, &output) {
213 return Sc::zero();
214 }
215 let key = (self.key_fn)(&output);
216 let value = self.collector.extract(&output);
217 let (delta, retraction) = self.insert_value(key, value);
218 self.row_outputs.insert(coordinate, output);
219 self.row_retractions.insert(coordinate, retraction);
220 self.index_coordinate(coordinate);
221 delta
222 }
223
224 fn retract_row(&mut self, coordinate: ProjectedRowCoordinate) -> Sc {
225 let Some(output) = self.row_outputs.remove(&coordinate) else {
226 return Sc::zero();
227 };
228 self.unindex_coordinate(coordinate);
229 let key = (self.key_fn)(&output);
230 let Some(retraction) = self.row_retractions.remove(&coordinate) else {
231 return Sc::zero();
232 };
233 self.retract_value(key, retraction)
234 }
235
236 fn localized_owners(
237 &self,
238 descriptor_index: usize,
239 entity_index: usize,
240 ) -> Vec<ProjectedRowOwner> {
241 let mut owners = Vec::new();
242 for slot in 0..self.source.source_count() {
243 if self
244 .source
245 .change_source(slot)
246 .assert_localizes(descriptor_index, &self.constraint_ref.name)
247 {
248 owners.push(ProjectedRowOwner {
249 source_slot: slot,
250 entity_index,
251 });
252 }
253 }
254 owners
255 }
256
257 fn coordinates_for_owners(&self, owners: &[ProjectedRowOwner]) -> Vec<ProjectedRowCoordinate> {
258 let mut seen = HashSet::new();
259 let mut coordinates = Vec::new();
260 for owner in owners {
261 let Some(rows) = self.rows_by_owner.get(owner) else {
262 continue;
263 };
264 for &coordinate in rows {
265 if seen.insert(coordinate) {
266 coordinates.push(coordinate);
267 }
268 }
269 }
270 coordinates
271 }
272}
273
274impl<S, Out, K, Src, F, KF, C, V, R, Acc, W, Sc> IncrementalConstraint<S, Sc>
275 for ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, V, R, Acc, W, Sc>
276where
277 S: Send + Sync + 'static,
278 Out: Send + Sync + 'static,
279 K: Eq + Hash + Send + Sync + 'static,
280 Src: ProjectedSource<S, Out>,
281 F: UniFilter<S, Out>,
282 KF: Fn(&Out) -> K + Send + Sync,
283 C: for<'i> Collector<&'i Out, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
284 V: Send + Sync,
285 R: Send + Sync,
286 Acc: Accumulator<V, R> + Send + Sync,
287 W: Fn(&K, &R) -> Sc + Send + Sync,
288 Sc: Score + 'static,
289{
290 fn evaluate(&self, solution: &S) -> Sc {
291 let state = self.source.build_state(solution);
292 let mut groups: HashMap<K, Acc> = HashMap::new();
293 self.source.collect_all(solution, &state, |_, output| {
294 if !self.filter.test(solution, &output) {
295 return;
296 }
297 let key = (self.key_fn)(&output);
298 let value = self.collector.extract(&output);
299 groups
300 .entry(key)
301 .or_insert_with(|| self.collector.create_accumulator())
302 .accumulate(value);
303 });
304 groups.iter().fold(Sc::zero(), |total, (key, acc)| {
305 total + acc.with_result(|result| self.compute_score(key, result))
306 })
307 }
308
309 fn match_count(&self, solution: &S) -> usize {
310 let state = self.source.build_state(solution);
311 let mut keys = HashMap::<K, ()>::new();
312 self.source.collect_all(solution, &state, |_, output| {
313 if self.filter.test(solution, &output) {
314 keys.insert((self.key_fn)(&output), ());
315 }
316 });
317 keys.len()
318 }
319
320 fn initialize(&mut self, solution: &S) -> Sc {
321 self.reset();
322 let state = self.source.build_state(solution);
323 let mut total = Sc::zero();
324 let mut rows = Vec::new();
325 self.source
326 .collect_all(solution, &state, |coordinate, output| {
327 rows.push((coordinate, output));
328 });
329 self.source_state = Some(state);
330 for (coordinate, output) in rows {
331 total = total + self.insert_row(solution, coordinate, output);
332 }
333 total
334 }
335
336 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
337 let owners = self.localized_owners(descriptor_index, entity_index);
338 self.ensure_source_state(solution);
339 {
340 let state = self.source_state.as_mut().expect("projected source state");
341 for owner in &owners {
342 self.source.insert_entity_state(
343 solution,
344 state,
345 owner.source_slot,
346 owner.entity_index,
347 );
348 }
349 }
350 let mut rows = Vec::new();
351 let state = self.source_state.as_ref().expect("projected source state");
352 for owner in &owners {
353 self.source.collect_entity(
354 solution,
355 state,
356 owner.source_slot,
357 owner.entity_index,
358 |coordinate, output| rows.push((coordinate, output)),
359 );
360 }
361 let mut total = Sc::zero();
362 for (coordinate, output) in rows {
363 total = total + self.insert_row(solution, coordinate, output);
364 }
365 total
366 }
367
368 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
369 let owners = self.localized_owners(descriptor_index, entity_index);
370 let mut total = Sc::zero();
371 for coordinate in self.coordinates_for_owners(&owners) {
372 total = total + self.retract_row(coordinate);
373 }
374 if let Some(state) = self.source_state.as_mut() {
375 for owner in &owners {
376 self.source.retract_entity_state(
377 solution,
378 state,
379 owner.source_slot,
380 owner.entity_index,
381 );
382 }
383 }
384 total
385 }
386
387 fn reset(&mut self) {
388 self.source_state = None;
389 self.groups.clear();
390 self.row_outputs.clear();
391 self.row_retractions.clear();
392 self.rows_by_owner.clear();
393 }
394
395 fn name(&self) -> &str {
396 &self.constraint_ref.name
397 }
398
399 fn constraint_ref(&self) -> &ConstraintRef {
400 &self.constraint_ref
401 }
402
403 fn is_hard(&self) -> bool {
404 self.is_hard
405 }
406}