1use std::collections::{hash_map::Entry, HashMap, HashSet};
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collector::{Accumulator, UniCollector};
10use crate::stream::filter::UniFilter;
11use crate::stream::{ProjectedRowCoordinate, ProjectedRowOwner, ProjectedSource};
12
/// Bookkeeping for one live group: how many rows currently feed it, plus the
/// collector accumulator that folds their values.
struct GroupState<Acc> {
    /// Number of rows currently accumulated into this group.
    count: usize,
    /// Running collector state for the group's rows.
    accumulator: Acc,
}
17
/// Incremental grouped constraint over rows produced by a [`ProjectedSource`].
///
/// Each row that passes `filter` is assigned to a group by `key_fn`. The group
/// folds the row's collector value into an accumulator. Every live group
/// contributes `weight_fn(key, result)` to the score, negated for
/// [`ImpactType::Penalty`] and kept as-is for [`ImpactType::Reward`].
pub struct ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
where
    Src: ProjectedSource<S, Out>,
    C: UniCollector<Out>,
    Sc: Score,
{
    /// Constraint identity; its `name` is also passed to `assert_localizes`.
    constraint_ref: ConstraintRef,
    /// Sign applied to each group weight (penalty negates, reward keeps).
    impact_type: ImpactType,
    /// Produces the projected rows (`Out`) and maintains their source state.
    source: Src,
    /// Rows rejected by this filter never enter a group.
    filter: F,
    /// Maps a row to the key of the group it belongs to.
    key_fn: KF,
    /// Extracts a row's value and creates per-group accumulators.
    collector: C,
    /// Raw (unsigned) weight of a group, from its key and collected result.
    weight_fn: W,
    /// Whether this is a hard constraint, as reported by `is_hard`.
    is_hard: bool,
    /// Source state: `None` until `initialize` or `on_insert` builds it, and
    /// again after `reset`.
    source_state: Option<Src::State>,
    /// Live groups by key; a group is removed once its row count drops to zero.
    groups: HashMap<K, GroupState<C::Accumulator>>,
    /// Inserted (filter-passing) rows by coordinate, kept so that retraction
    /// can recompute each row's key and value.
    row_outputs: HashMap<ProjectedRowCoordinate, Out>,
    /// Reverse index: owner (source slot + entity index) -> coordinates of the
    /// stored rows it contributes to.
    rows_by_owner: HashMap<ProjectedRowOwner, Vec<ProjectedRowCoordinate>>,
    /// Ties `S`, `Out` and `Sc` to the type without storing values of them;
    /// the `fn() -> T` form keeps the marker `Send + Sync` regardless of `T`.
    _phantom: PhantomData<(fn() -> S, fn() -> Out, fn() -> Sc)>,
}
38
39impl<S, Out, K, Src, F, KF, C, W, Sc> ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
40where
41 S: Send + Sync + 'static,
42 Out: Send + Sync + 'static,
43 K: Eq + Hash + Send + Sync + 'static,
44 Src: ProjectedSource<S, Out>,
45 F: UniFilter<S, Out>,
46 KF: Fn(&Out) -> K + Send + Sync,
47 C: UniCollector<Out> + Send + Sync + 'static,
48 C::Accumulator: Send + Sync,
49 C::Result: Send + Sync,
50 C::Value: Send + Sync,
51 W: Fn(&K, &C::Result) -> Sc + Send + Sync,
52 Sc: Score + 'static,
53{
54 #[allow(clippy::too_many_arguments)]
55 pub fn new(
56 constraint_ref: ConstraintRef,
57 impact_type: ImpactType,
58 source: Src,
59 filter: F,
60 key_fn: KF,
61 collector: C,
62 weight_fn: W,
63 is_hard: bool,
64 ) -> Self {
65 Self {
66 constraint_ref,
67 impact_type,
68 source,
69 filter,
70 key_fn,
71 collector,
72 weight_fn,
73 is_hard,
74 source_state: None,
75 groups: HashMap::new(),
76 row_outputs: HashMap::new(),
77 rows_by_owner: HashMap::new(),
78 _phantom: PhantomData,
79 }
80 }
81
82 fn compute_score(&self, key: &K, result: &C::Result) -> Sc {
83 let base = (self.weight_fn)(key, result);
84 match self.impact_type {
85 ImpactType::Penalty => -base,
86 ImpactType::Reward => base,
87 }
88 }
89
90 fn ensure_source_state(&mut self, solution: &S) {
91 if self.source_state.is_none() {
92 self.source_state = Some(self.source.build_state(solution));
93 }
94 }
95
96 fn index_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
97 coordinate.for_each_owner(|owner| {
98 self.rows_by_owner
99 .entry(owner)
100 .or_default()
101 .push(coordinate);
102 });
103 }
104
105 fn unindex_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
106 coordinate.for_each_owner(|owner| {
107 let mut remove_bucket = false;
108 if let Some(rows) = self.rows_by_owner.get_mut(&owner) {
109 rows.retain(|candidate| *candidate != coordinate);
110 remove_bucket = rows.is_empty();
111 }
112 if remove_bucket {
113 self.rows_by_owner.remove(&owner);
114 }
115 });
116 }
117
118 fn insert_value(&mut self, key: K, value: &C::Value) -> Sc {
119 let impact = self.impact_type;
120 let weight_fn = &self.weight_fn;
121 match self.groups.entry(key) {
122 Entry::Occupied(mut entry) => {
123 let old_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
124 let old = match impact {
125 ImpactType::Penalty => -old_base,
126 ImpactType::Reward => old_base,
127 };
128 let group = entry.get_mut();
129 group.accumulator.accumulate(value);
130 group.count += 1;
131 let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
132 let new_score = match impact {
133 ImpactType::Penalty => -new_base,
134 ImpactType::Reward => new_base,
135 };
136 new_score - old
137 }
138 Entry::Vacant(entry) => {
139 let mut entry = entry.insert_entry(GroupState {
140 accumulator: self.collector.create_accumulator(),
141 count: 0,
142 });
143 let group = entry.get_mut();
144 group.accumulator.accumulate(value);
145 group.count += 1;
146 let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
147 match impact {
148 ImpactType::Penalty => -new_base,
149 ImpactType::Reward => new_base,
150 }
151 }
152 }
153 }
154
155 fn retract_value(&mut self, key: K, value: &C::Value) -> Sc {
156 let impact = self.impact_type;
157 let weight_fn = &self.weight_fn;
158 let Entry::Occupied(mut entry) = self.groups.entry(key) else {
159 return Sc::zero();
160 };
161 let old_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
162 let old = match impact {
163 ImpactType::Penalty => -old_base,
164 ImpactType::Reward => old_base,
165 };
166 let group = entry.get_mut();
167 group.accumulator.retract(value);
168 group.count = group.count.saturating_sub(1);
169 let new_score = if group.count == 0 {
170 entry.remove();
171 Sc::zero()
172 } else {
173 let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
174 match impact {
175 ImpactType::Penalty => -new_base,
176 ImpactType::Reward => new_base,
177 }
178 };
179
180 new_score - old
181 }
182
183 fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
184 if self.row_outputs.contains_key(&coordinate) || !self.filter.test(solution, &output) {
185 return Sc::zero();
186 }
187 let key = (self.key_fn)(&output);
188 let value = self.collector.extract(&output);
189 let delta = self.insert_value(key, &value);
190 self.row_outputs.insert(coordinate, output);
191 self.index_coordinate(coordinate);
192 delta
193 }
194
195 fn retract_row(&mut self, coordinate: ProjectedRowCoordinate) -> Sc {
196 let Some(output) = self.row_outputs.remove(&coordinate) else {
197 return Sc::zero();
198 };
199 self.unindex_coordinate(coordinate);
200 let key = (self.key_fn)(&output);
201 let value = self.collector.extract(&output);
202 self.retract_value(key, &value)
203 }
204
205 fn localized_owners(
206 &self,
207 descriptor_index: usize,
208 entity_index: usize,
209 ) -> Vec<ProjectedRowOwner> {
210 let mut owners = Vec::new();
211 for slot in 0..self.source.source_count() {
212 if self
213 .source
214 .change_source(slot)
215 .assert_localizes(descriptor_index, &self.constraint_ref.name)
216 {
217 owners.push(ProjectedRowOwner {
218 source_slot: slot,
219 entity_index,
220 });
221 }
222 }
223 owners
224 }
225
226 fn coordinates_for_owners(&self, owners: &[ProjectedRowOwner]) -> Vec<ProjectedRowCoordinate> {
227 let mut seen = HashSet::new();
228 let mut coordinates = Vec::new();
229 for owner in owners {
230 let Some(rows) = self.rows_by_owner.get(owner) else {
231 continue;
232 };
233 for &coordinate in rows {
234 if seen.insert(coordinate) {
235 coordinates.push(coordinate);
236 }
237 }
238 }
239 coordinates
240 }
241}
242
243impl<S, Out, K, Src, F, KF, C, W, Sc> IncrementalConstraint<S, Sc>
244 for ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
245where
246 S: Send + Sync + 'static,
247 Out: Send + Sync + 'static,
248 K: Eq + Hash + Send + Sync + 'static,
249 Src: ProjectedSource<S, Out>,
250 F: UniFilter<S, Out>,
251 KF: Fn(&Out) -> K + Send + Sync,
252 C: UniCollector<Out> + Send + Sync + 'static,
253 C::Accumulator: Send + Sync,
254 C::Result: Send + Sync,
255 C::Value: Send + Sync,
256 W: Fn(&K, &C::Result) -> Sc + Send + Sync,
257 Sc: Score + 'static,
258{
259 fn evaluate(&self, solution: &S) -> Sc {
260 let state = self.source.build_state(solution);
261 let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
262 self.source.collect_all(solution, &state, |_, output| {
263 if !self.filter.test(solution, &output) {
264 return;
265 }
266 let key = (self.key_fn)(&output);
267 let value = self.collector.extract(&output);
268 groups
269 .entry(key)
270 .or_insert_with(|| self.collector.create_accumulator())
271 .accumulate(&value);
272 });
273 groups.iter().fold(Sc::zero(), |total, (key, acc)| {
274 total + self.compute_score(key, &acc.finish())
275 })
276 }
277
278 fn match_count(&self, solution: &S) -> usize {
279 let state = self.source.build_state(solution);
280 let mut keys = HashMap::<K, ()>::new();
281 self.source.collect_all(solution, &state, |_, output| {
282 if self.filter.test(solution, &output) {
283 keys.insert((self.key_fn)(&output), ());
284 }
285 });
286 keys.len()
287 }
288
289 fn initialize(&mut self, solution: &S) -> Sc {
290 self.reset();
291 let state = self.source.build_state(solution);
292 let mut total = Sc::zero();
293 let mut rows = Vec::new();
294 self.source
295 .collect_all(solution, &state, |coordinate, output| {
296 rows.push((coordinate, output));
297 });
298 self.source_state = Some(state);
299 for (coordinate, output) in rows {
300 total = total + self.insert_row(solution, coordinate, output);
301 }
302 total
303 }
304
305 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
306 let owners = self.localized_owners(descriptor_index, entity_index);
307 self.ensure_source_state(solution);
308 {
309 let state = self.source_state.as_mut().expect("projected source state");
310 for owner in &owners {
311 self.source.insert_entity_state(
312 solution,
313 state,
314 owner.source_slot,
315 owner.entity_index,
316 );
317 }
318 }
319 let mut rows = Vec::new();
320 let state = self.source_state.as_ref().expect("projected source state");
321 for owner in &owners {
322 self.source.collect_entity(
323 solution,
324 state,
325 owner.source_slot,
326 owner.entity_index,
327 |coordinate, output| rows.push((coordinate, output)),
328 );
329 }
330 let mut total = Sc::zero();
331 for (coordinate, output) in rows {
332 total = total + self.insert_row(solution, coordinate, output);
333 }
334 total
335 }
336
337 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
338 let owners = self.localized_owners(descriptor_index, entity_index);
339 let mut total = Sc::zero();
340 for coordinate in self.coordinates_for_owners(&owners) {
341 total = total + self.retract_row(coordinate);
342 }
343 if let Some(state) = self.source_state.as_mut() {
344 for owner in &owners {
345 self.source.retract_entity_state(
346 solution,
347 state,
348 owner.source_slot,
349 owner.entity_index,
350 );
351 }
352 }
353 total
354 }
355
356 fn reset(&mut self) {
357 self.source_state = None;
358 self.groups.clear();
359 self.row_outputs.clear();
360 self.rows_by_owner.clear();
361 }
362
363 fn name(&self) -> &str {
364 &self.constraint_ref.name
365 }
366
367 fn constraint_ref(&self) -> &ConstraintRef {
368 &self.constraint_ref
369 }
370
371 fn is_hard(&self) -> bool {
372 self.is_hard
373 }
374}