1use std::collections::HashMap;
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use solverforge_core::score::Score;
13use solverforge_core::{ConstraintRef, ImpactType};
14
15use crate::api::constraint_set::IncrementalConstraint;
16use crate::stream::collector::{Accumulator, Collector};
17use crate::stream::filter::UniFilter;
18
/// Per-key incremental state: the collector accumulator for one group plus the
/// number of tracked entities currently folded into it.
struct GroupState<Acc> {
    /// How many tracked entities have been accumulated into this group.
    count: usize,
    /// Collector accumulator holding the group's combined value.
    accumulator: Acc,
}
23
24type CollectorRetraction<Acc, V, R> = <Acc as Accumulator<V, R>>::Retraction;
25
26pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
91where
92 Acc: Accumulator<V, R>,
93 Sc: Score,
94{
95 constraint_ref: ConstraintRef,
96 impact_type: ImpactType,
97 extractor: E,
98 filter: Fi,
99 key_fn: KF,
100 collector: C,
101 weight_fn: W,
102 is_hard: bool,
103 change_source: crate::stream::collection_extract::ChangeSource,
104 groups: HashMap<K, GroupState<Acc>>,
106 entity_groups: HashMap<usize, K>,
108 entity_retractions: HashMap<usize, CollectorRetraction<Acc, V, R>>,
110 _phantom: PhantomData<(
111 fn() -> S,
112 fn() -> A,
113 fn() -> V,
114 fn() -> R,
115 fn() -> Acc,
116 fn() -> Sc,
117 )>,
118}
119
120impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
121 GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
122where
123 S: Send + Sync + 'static,
124 A: Clone + Send + Sync + 'static,
125 K: Clone + Eq + Hash + Send + Sync + 'static,
126 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
127 Fi: UniFilter<S, A>,
128 KF: Fn(&A) -> K + Send + Sync,
129 C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
130 V: Send + Sync + 'static,
131 R: Send + Sync + 'static,
132 Acc: Accumulator<V, R> + Send + Sync + 'static,
133 W: Fn(&K, &R) -> Sc + Send + Sync,
134 Sc: Score + 'static,
135{
136 #[allow(clippy::too_many_arguments)]
150 pub fn new(
151 constraint_ref: ConstraintRef,
152 impact_type: ImpactType,
153 extractor: E,
154 filter: Fi,
155 key_fn: KF,
156 collector: C,
157 weight_fn: W,
158 is_hard: bool,
159 ) -> Self {
160 let change_source = extractor.change_source();
161 Self {
162 constraint_ref,
163 impact_type,
164 extractor,
165 filter,
166 key_fn,
167 collector,
168 weight_fn,
169 is_hard,
170 change_source,
171 groups: HashMap::new(),
172 entity_groups: HashMap::new(),
173 entity_retractions: HashMap::new(),
174 _phantom: PhantomData,
175 }
176 }
177
178 fn compute_score(&self, key: &K, result: &R) -> Sc {
180 let base = (self.weight_fn)(key, result);
181 match self.impact_type {
182 ImpactType::Penalty => -base,
183 ImpactType::Reward => base,
184 }
185 }
186}
187
188impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc> IncrementalConstraint<S, Sc>
189 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
190where
191 S: Send + Sync + 'static,
192 A: Clone + Send + Sync + 'static,
193 K: Clone + Eq + Hash + Send + Sync + 'static,
194 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
195 Fi: UniFilter<S, A>,
196 KF: Fn(&A) -> K + Send + Sync,
197 C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
198 V: Send + Sync + 'static,
199 R: Send + Sync + 'static,
200 Acc: Accumulator<V, R> + Send + Sync + 'static,
201 W: Fn(&K, &R) -> Sc + Send + Sync,
202 Sc: Score + 'static,
203{
204 fn evaluate(&self, solution: &S) -> Sc {
205 let entities = self.extractor.extract(solution);
206
207 let mut groups: HashMap<K, Acc> = HashMap::new();
209
210 for entity in entities {
211 if !self.filter.test(solution, entity) {
212 continue;
213 }
214 let key = (self.key_fn)(entity);
215 let value = self.collector.extract(entity);
216 let acc = groups
217 .entry(key)
218 .or_insert_with(|| self.collector.create_accumulator());
219 acc.accumulate(value);
220 }
221
222 let mut total = Sc::zero();
224 for (key, acc) in &groups {
225 total = total + acc.with_result(|result| self.compute_score(key, result));
226 }
227
228 total
229 }
230
231 fn match_count(&self, solution: &S) -> usize {
232 let entities = self.extractor.extract(solution);
233
234 let mut groups: HashMap<K, ()> = HashMap::new();
236 for entity in entities {
237 if !self.filter.test(solution, entity) {
238 continue;
239 }
240 let key = (self.key_fn)(entity);
241 groups.insert(key, ());
242 }
243
244 groups.len()
245 }
246
247 fn initialize(&mut self, solution: &S) -> Sc {
248 self.reset();
249
250 let entities = self.extractor.extract(solution);
251 let mut total = Sc::zero();
252
253 for (idx, entity) in entities.iter().enumerate() {
254 if !self.filter.test(solution, entity) {
255 continue;
256 }
257 total = total + self.insert_entity(entities, idx, entity);
258 }
259
260 total
261 }
262
263 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
264 if !self
265 .change_source
266 .assert_localizes(descriptor_index, &self.constraint_ref.name)
267 {
268 return Sc::zero();
269 }
270 let entities = self.extractor.extract(solution);
271 if entity_index >= entities.len() {
272 return Sc::zero();
273 }
274
275 let entity = &entities[entity_index];
276 if !self.filter.test(solution, entity) {
277 return Sc::zero();
278 }
279 self.insert_entity(entities, entity_index, entity)
280 }
281
282 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
283 if !self
284 .change_source
285 .assert_localizes(descriptor_index, &self.constraint_ref.name)
286 {
287 return Sc::zero();
288 }
289 let entities = self.extractor.extract(solution);
290 self.retract_entity(entities, entity_index)
291 }
292
293 fn reset(&mut self) {
294 self.groups.clear();
295 self.entity_groups.clear();
296 self.entity_retractions.clear();
297 }
298
299 fn name(&self) -> &str {
300 &self.constraint_ref.name
301 }
302
303 fn is_hard(&self) -> bool {
304 self.is_hard
305 }
306
307 fn constraint_ref(&self) -> &ConstraintRef {
308 &self.constraint_ref
309 }
310}
311
312impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
313 GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
314where
315 S: Send + Sync + 'static,
316 A: Clone + Send + Sync + 'static,
317 K: Clone + Eq + Hash + Send + Sync + 'static,
318 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
319 Fi: UniFilter<S, A>,
320 KF: Fn(&A) -> K + Send + Sync,
321 C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
322 V: Send + Sync + 'static,
323 R: Send + Sync + 'static,
324 Acc: Accumulator<V, R> + Send + Sync + 'static,
325 W: Fn(&K, &R) -> Sc + Send + Sync,
326 Sc: Score + 'static,
327{
328 fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
329 let key = (self.key_fn)(entity);
330 let entity_key = key.clone();
331 let value = self.collector.extract(entity);
332 let impact = self.impact_type;
333
334 let weight_fn = &self.weight_fn;
335 let (old, new_score) = match self.groups.entry(key) {
336 std::collections::hash_map::Entry::Occupied(mut entry) => {
337 let old_base = entry
338 .get()
339 .accumulator
340 .with_result(|result| weight_fn(entry.key(), result));
341 let old = match impact {
342 ImpactType::Penalty => -old_base,
343 ImpactType::Reward => old_base,
344 };
345 let group = entry.get_mut();
346 let retraction = group.accumulator.accumulate(value);
347 group.count += 1;
348 let new_base = entry
349 .get()
350 .accumulator
351 .with_result(|result| weight_fn(entry.key(), result));
352 let new_score = match impact {
353 ImpactType::Penalty => -new_base,
354 ImpactType::Reward => new_base,
355 };
356 self.entity_retractions.insert(entity_index, retraction);
357 (old, new_score)
358 }
359 std::collections::hash_map::Entry::Vacant(entry) => {
360 let mut entry = entry.insert_entry(GroupState {
361 accumulator: self.collector.create_accumulator(),
362 count: 0,
363 });
364 let group = entry.get_mut();
365 let retraction = group.accumulator.accumulate(value);
366 group.count += 1;
367 let new_base = entry
368 .get()
369 .accumulator
370 .with_result(|result| weight_fn(entry.key(), result));
371 let new_score = match impact {
372 ImpactType::Penalty => -new_base,
373 ImpactType::Reward => new_base,
374 };
375 self.entity_retractions.insert(entity_index, retraction);
376 (Sc::zero(), new_score)
377 }
378 };
379
380 self.entity_groups.insert(entity_index, entity_key);
382
383 new_score - old
385 }
386
387 fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
388 let Some(key) = self.entity_groups.remove(&entity_index) else {
390 return Sc::zero();
391 };
392
393 let Some(retraction) = self.entity_retractions.remove(&entity_index) else {
395 return Sc::zero();
396 };
397 let impact = self.impact_type;
398
399 let weight_fn = &self.weight_fn;
400 let std::collections::hash_map::Entry::Occupied(mut entry) = self.groups.entry(key) else {
401 return Sc::zero();
402 };
403
404 let old_base = entry
405 .get()
406 .accumulator
407 .with_result(|result| weight_fn(entry.key(), result));
408 let old = match impact {
409 ImpactType::Penalty => -old_base,
410 ImpactType::Reward => old_base,
411 };
412
413 let group = entry.get_mut();
414 group.accumulator.retract(retraction);
415 group.count = group.count.saturating_sub(1);
416 let is_empty = group.count == 0;
417 let new_score = if is_empty {
418 entry.remove();
419 Sc::zero()
420 } else {
421 let new_base = entry
422 .get()
423 .accumulator
424 .with_result(|result| weight_fn(entry.key(), result));
425 match impact {
426 ImpactType::Penalty => -new_base,
427 ImpactType::Reward => new_base,
428 }
429 };
430
431 new_score - old
433 }
434}
435
436impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc> std::fmt::Debug
437 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
438where
439 Acc: Accumulator<V, R>,
440 Sc: Score,
441{
442 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
443 f.debug_struct("GroupedUniConstraint")
444 .field("name", &self.constraint_ref.name)
445 .field("impact_type", &self.impact_type)
446 .field("groups", &self.groups.len())
447 .finish()
448 }
449}