1use std::collections::{hash_map::Entry, HashMap, HashSet};
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collector::{Accumulator, UniCollector};
10use crate::stream::filter::UniFilter;
11use crate::stream::{ProjectedRowCoordinate, ProjectedRowOwner, ProjectedSource};
12
/// Aggregation state for a single group key.
struct GroupState<Acc> {
    /// How many accepted rows are currently folded into `accumulator`.
    count: usize,
    /// The collector accumulator holding the group's running aggregate.
    accumulator: Acc,
}
17
18type CollectorRetraction<C, Out> = <<C as UniCollector<Out>>::Accumulator as Accumulator<
19 <C as UniCollector<Out>>::Value,
20 <C as UniCollector<Out>>::Result,
21>>::Retraction;
22
23pub struct ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
24where
25 Src: ProjectedSource<S, Out>,
26 C: UniCollector<Out>,
27 Sc: Score,
28{
29 constraint_ref: ConstraintRef,
30 impact_type: ImpactType,
31 source: Src,
32 filter: F,
33 key_fn: KF,
34 collector: C,
35 weight_fn: W,
36 is_hard: bool,
37 source_state: Option<Src::State>,
38 groups: HashMap<K, GroupState<C::Accumulator>>,
39 row_outputs: HashMap<ProjectedRowCoordinate, Out>,
40 row_retractions: HashMap<ProjectedRowCoordinate, CollectorRetraction<C, Out>>,
41 rows_by_owner: HashMap<ProjectedRowOwner, Vec<ProjectedRowCoordinate>>,
42 _phantom: PhantomData<(fn() -> S, fn() -> Out, fn() -> Sc)>,
43}
44
45impl<S, Out, K, Src, F, KF, C, W, Sc> ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
46where
47 S: Send + Sync + 'static,
48 Out: Send + Sync + 'static,
49 K: Eq + Hash + Send + Sync + 'static,
50 Src: ProjectedSource<S, Out>,
51 F: UniFilter<S, Out>,
52 KF: Fn(&Out) -> K + Send + Sync,
53 C: UniCollector<Out> + Send + Sync + 'static,
54 C::Accumulator: Send + Sync,
55 C::Result: Send + Sync,
56 C::Value: Send + Sync,
57 W: Fn(&K, &C::Result) -> Sc + Send + Sync,
58 Sc: Score + 'static,
59{
60 #[allow(clippy::too_many_arguments)]
61 pub fn new(
62 constraint_ref: ConstraintRef,
63 impact_type: ImpactType,
64 source: Src,
65 filter: F,
66 key_fn: KF,
67 collector: C,
68 weight_fn: W,
69 is_hard: bool,
70 ) -> Self {
71 Self {
72 constraint_ref,
73 impact_type,
74 source,
75 filter,
76 key_fn,
77 collector,
78 weight_fn,
79 is_hard,
80 source_state: None,
81 groups: HashMap::new(),
82 row_outputs: HashMap::new(),
83 row_retractions: HashMap::new(),
84 rows_by_owner: HashMap::new(),
85 _phantom: PhantomData,
86 }
87 }
88
89 fn compute_score(&self, key: &K, result: &C::Result) -> Sc {
90 let base = (self.weight_fn)(key, result);
91 match self.impact_type {
92 ImpactType::Penalty => -base,
93 ImpactType::Reward => base,
94 }
95 }
96
97 fn ensure_source_state(&mut self, solution: &S) {
98 if self.source_state.is_none() {
99 self.source_state = Some(self.source.build_state(solution));
100 }
101 }
102
103 fn index_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
104 coordinate.for_each_owner(|owner| {
105 self.rows_by_owner
106 .entry(owner)
107 .or_default()
108 .push(coordinate);
109 });
110 }
111
112 fn unindex_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
113 coordinate.for_each_owner(|owner| {
114 let mut remove_bucket = false;
115 if let Some(rows) = self.rows_by_owner.get_mut(&owner) {
116 rows.retain(|candidate| *candidate != coordinate);
117 remove_bucket = rows.is_empty();
118 }
119 if remove_bucket {
120 self.rows_by_owner.remove(&owner);
121 }
122 });
123 }
124
125 fn insert_value(&mut self, key: K, value: C::Value) -> (Sc, CollectorRetraction<C, Out>) {
126 let impact = self.impact_type;
127 let weight_fn = &self.weight_fn;
128 match self.groups.entry(key) {
129 Entry::Occupied(mut entry) => {
130 let old_base = entry
131 .get()
132 .accumulator
133 .with_result(|result| weight_fn(entry.key(), result));
134 let old = match impact {
135 ImpactType::Penalty => -old_base,
136 ImpactType::Reward => old_base,
137 };
138 let group = entry.get_mut();
139 let retraction = group.accumulator.accumulate(value);
140 group.count += 1;
141 let new_base = entry
142 .get()
143 .accumulator
144 .with_result(|result| weight_fn(entry.key(), result));
145 let new_score = match impact {
146 ImpactType::Penalty => -new_base,
147 ImpactType::Reward => new_base,
148 };
149 (new_score - old, retraction)
150 }
151 Entry::Vacant(entry) => {
152 let mut entry = entry.insert_entry(GroupState {
153 accumulator: self.collector.create_accumulator(),
154 count: 0,
155 });
156 let group = entry.get_mut();
157 let retraction = group.accumulator.accumulate(value);
158 group.count += 1;
159 let new_base = entry
160 .get()
161 .accumulator
162 .with_result(|result| weight_fn(entry.key(), result));
163 let score = match impact {
164 ImpactType::Penalty => -new_base,
165 ImpactType::Reward => new_base,
166 };
167 (score, retraction)
168 }
169 }
170 }
171
172 fn retract_value(&mut self, key: K, retraction: CollectorRetraction<C, Out>) -> Sc {
173 let impact = self.impact_type;
174 let weight_fn = &self.weight_fn;
175 let Entry::Occupied(mut entry) = self.groups.entry(key) else {
176 return Sc::zero();
177 };
178 let old_base = entry
179 .get()
180 .accumulator
181 .with_result(|result| weight_fn(entry.key(), result));
182 let old = match impact {
183 ImpactType::Penalty => -old_base,
184 ImpactType::Reward => old_base,
185 };
186 let group = entry.get_mut();
187 group.accumulator.retract(retraction);
188 group.count = group.count.saturating_sub(1);
189 let new_score = if group.count == 0 {
190 entry.remove();
191 Sc::zero()
192 } else {
193 let new_base = entry
194 .get()
195 .accumulator
196 .with_result(|result| weight_fn(entry.key(), result));
197 match impact {
198 ImpactType::Penalty => -new_base,
199 ImpactType::Reward => new_base,
200 }
201 };
202
203 new_score - old
204 }
205
206 fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
207 if self.row_outputs.contains_key(&coordinate) || !self.filter.test(solution, &output) {
208 return Sc::zero();
209 }
210 let key = (self.key_fn)(&output);
211 let value = self.collector.extract(&output);
212 let (delta, retraction) = self.insert_value(key, value);
213 self.row_outputs.insert(coordinate, output);
214 self.row_retractions.insert(coordinate, retraction);
215 self.index_coordinate(coordinate);
216 delta
217 }
218
219 fn retract_row(&mut self, coordinate: ProjectedRowCoordinate) -> Sc {
220 let Some(output) = self.row_outputs.remove(&coordinate) else {
221 return Sc::zero();
222 };
223 self.unindex_coordinate(coordinate);
224 let key = (self.key_fn)(&output);
225 let Some(retraction) = self.row_retractions.remove(&coordinate) else {
226 return Sc::zero();
227 };
228 self.retract_value(key, retraction)
229 }
230
231 fn localized_owners(
232 &self,
233 descriptor_index: usize,
234 entity_index: usize,
235 ) -> Vec<ProjectedRowOwner> {
236 let mut owners = Vec::new();
237 for slot in 0..self.source.source_count() {
238 if self
239 .source
240 .change_source(slot)
241 .assert_localizes(descriptor_index, &self.constraint_ref.name)
242 {
243 owners.push(ProjectedRowOwner {
244 source_slot: slot,
245 entity_index,
246 });
247 }
248 }
249 owners
250 }
251
252 fn coordinates_for_owners(&self, owners: &[ProjectedRowOwner]) -> Vec<ProjectedRowCoordinate> {
253 let mut seen = HashSet::new();
254 let mut coordinates = Vec::new();
255 for owner in owners {
256 let Some(rows) = self.rows_by_owner.get(owner) else {
257 continue;
258 };
259 for &coordinate in rows {
260 if seen.insert(coordinate) {
261 coordinates.push(coordinate);
262 }
263 }
264 }
265 coordinates
266 }
267}
268
269impl<S, Out, K, Src, F, KF, C, W, Sc> IncrementalConstraint<S, Sc>
270 for ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
271where
272 S: Send + Sync + 'static,
273 Out: Send + Sync + 'static,
274 K: Eq + Hash + Send + Sync + 'static,
275 Src: ProjectedSource<S, Out>,
276 F: UniFilter<S, Out>,
277 KF: Fn(&Out) -> K + Send + Sync,
278 C: UniCollector<Out> + Send + Sync + 'static,
279 C::Accumulator: Send + Sync,
280 C::Result: Send + Sync,
281 C::Value: Send + Sync,
282 W: Fn(&K, &C::Result) -> Sc + Send + Sync,
283 Sc: Score + 'static,
284{
285 fn evaluate(&self, solution: &S) -> Sc {
286 let state = self.source.build_state(solution);
287 let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
288 self.source.collect_all(solution, &state, |_, output| {
289 if !self.filter.test(solution, &output) {
290 return;
291 }
292 let key = (self.key_fn)(&output);
293 let value = self.collector.extract(&output);
294 groups
295 .entry(key)
296 .or_insert_with(|| self.collector.create_accumulator())
297 .accumulate(value);
298 });
299 groups.iter().fold(Sc::zero(), |total, (key, acc)| {
300 total + acc.with_result(|result| self.compute_score(key, result))
301 })
302 }
303
304 fn match_count(&self, solution: &S) -> usize {
305 let state = self.source.build_state(solution);
306 let mut keys = HashMap::<K, ()>::new();
307 self.source.collect_all(solution, &state, |_, output| {
308 if self.filter.test(solution, &output) {
309 keys.insert((self.key_fn)(&output), ());
310 }
311 });
312 keys.len()
313 }
314
315 fn initialize(&mut self, solution: &S) -> Sc {
316 self.reset();
317 let state = self.source.build_state(solution);
318 let mut total = Sc::zero();
319 let mut rows = Vec::new();
320 self.source
321 .collect_all(solution, &state, |coordinate, output| {
322 rows.push((coordinate, output));
323 });
324 self.source_state = Some(state);
325 for (coordinate, output) in rows {
326 total = total + self.insert_row(solution, coordinate, output);
327 }
328 total
329 }
330
331 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
332 let owners = self.localized_owners(descriptor_index, entity_index);
333 self.ensure_source_state(solution);
334 {
335 let state = self.source_state.as_mut().expect("projected source state");
336 for owner in &owners {
337 self.source.insert_entity_state(
338 solution,
339 state,
340 owner.source_slot,
341 owner.entity_index,
342 );
343 }
344 }
345 let mut rows = Vec::new();
346 let state = self.source_state.as_ref().expect("projected source state");
347 for owner in &owners {
348 self.source.collect_entity(
349 solution,
350 state,
351 owner.source_slot,
352 owner.entity_index,
353 |coordinate, output| rows.push((coordinate, output)),
354 );
355 }
356 let mut total = Sc::zero();
357 for (coordinate, output) in rows {
358 total = total + self.insert_row(solution, coordinate, output);
359 }
360 total
361 }
362
363 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
364 let owners = self.localized_owners(descriptor_index, entity_index);
365 let mut total = Sc::zero();
366 for coordinate in self.coordinates_for_owners(&owners) {
367 total = total + self.retract_row(coordinate);
368 }
369 if let Some(state) = self.source_state.as_mut() {
370 for owner in &owners {
371 self.source.retract_entity_state(
372 solution,
373 state,
374 owner.source_slot,
375 owner.entity_index,
376 );
377 }
378 }
379 total
380 }
381
382 fn reset(&mut self) {
383 self.source_state = None;
384 self.groups.clear();
385 self.row_outputs.clear();
386 self.row_retractions.clear();
387 self.rows_by_owner.clear();
388 }
389
390 fn name(&self) -> &str {
391 &self.constraint_ref.name
392 }
393
394 fn constraint_ref(&self) -> &ConstraintRef {
395 &self.constraint_ref
396 }
397
398 fn is_hard(&self) -> bool {
399 self.is_hard
400 }
401}