1use std::collections::HashMap;
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use solverforge_core::score::Score;
13use solverforge_core::{ConstraintRef, ImpactType};
14
15use crate::api::constraint_set::IncrementalConstraint;
16use crate::stream::collector::{Accumulator, UniCollector};
17use crate::stream::filter::UniFilter;
18
/// Per-key bookkeeping for one live group of the grouped constraint.
struct GroupState<Acc> {
    /// How many entities are currently accumulated into this group; the
    /// incremental retract path drops the group once this returns to zero.
    count: usize,
    /// Running collector state holding the values of every member entity.
    accumulator: Acc,
}
23
24type CollectorRetraction<C, A> = <<C as UniCollector<A>>::Accumulator as Accumulator<
25 <C as UniCollector<A>>::Value,
26 <C as UniCollector<A>>::Result,
27>>::Retraction;
28
29pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
94where
95 C: UniCollector<A>,
96 Sc: Score,
97{
98 constraint_ref: ConstraintRef,
99 impact_type: ImpactType,
100 extractor: E,
101 filter: Fi,
102 key_fn: KF,
103 collector: C,
104 weight_fn: W,
105 is_hard: bool,
106 change_source: crate::stream::collection_extract::ChangeSource,
107 groups: HashMap<K, GroupState<C::Accumulator>>,
109 entity_groups: HashMap<usize, K>,
111 entity_retractions: HashMap<usize, CollectorRetraction<C, A>>,
113 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
114}
115
116impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
117where
118 S: Send + Sync + 'static,
119 A: Clone + Send + Sync + 'static,
120 K: Clone + Eq + Hash + Send + Sync + 'static,
121 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
122 Fi: UniFilter<S, A>,
123 KF: Fn(&A) -> K + Send + Sync,
124 C: UniCollector<A> + Send + Sync + 'static,
125 C::Accumulator: Send + Sync,
126 C::Result: Send + Sync,
127 W: Fn(&K, &C::Result) -> Sc + Send + Sync,
128 Sc: Score + 'static,
129{
130 #[allow(clippy::too_many_arguments)]
144 pub fn new(
145 constraint_ref: ConstraintRef,
146 impact_type: ImpactType,
147 extractor: E,
148 filter: Fi,
149 key_fn: KF,
150 collector: C,
151 weight_fn: W,
152 is_hard: bool,
153 ) -> Self {
154 let change_source = extractor.change_source();
155 Self {
156 constraint_ref,
157 impact_type,
158 extractor,
159 filter,
160 key_fn,
161 collector,
162 weight_fn,
163 is_hard,
164 change_source,
165 groups: HashMap::new(),
166 entity_groups: HashMap::new(),
167 entity_retractions: HashMap::new(),
168 _phantom: PhantomData,
169 }
170 }
171
172 fn compute_score(&self, key: &K, result: &C::Result) -> Sc {
174 let base = (self.weight_fn)(key, result);
175 match self.impact_type {
176 ImpactType::Penalty => -base,
177 ImpactType::Reward => base,
178 }
179 }
180}
181
182impl<S, A, K, E, Fi, KF, C, W, Sc> IncrementalConstraint<S, Sc>
183 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
184where
185 S: Send + Sync + 'static,
186 A: Clone + Send + Sync + 'static,
187 K: Clone + Eq + Hash + Send + Sync + 'static,
188 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
189 Fi: UniFilter<S, A>,
190 KF: Fn(&A) -> K + Send + Sync,
191 C: UniCollector<A> + Send + Sync + 'static,
192 C::Accumulator: Send + Sync,
193 C::Result: Send + Sync,
194 C::Value: Send + Sync,
195 W: Fn(&K, &C::Result) -> Sc + Send + Sync,
196 Sc: Score + 'static,
197{
198 fn evaluate(&self, solution: &S) -> Sc {
199 let entities = self.extractor.extract(solution);
200
201 let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
203
204 for entity in entities {
205 if !self.filter.test(solution, entity) {
206 continue;
207 }
208 let key = (self.key_fn)(entity);
209 let value = self.collector.extract(entity);
210 let acc = groups
211 .entry(key)
212 .or_insert_with(|| self.collector.create_accumulator());
213 acc.accumulate(value);
214 }
215
216 let mut total = Sc::zero();
218 for (key, acc) in &groups {
219 total = total + acc.with_result(|result| self.compute_score(key, result));
220 }
221
222 total
223 }
224
225 fn match_count(&self, solution: &S) -> usize {
226 let entities = self.extractor.extract(solution);
227
228 let mut groups: HashMap<K, ()> = HashMap::new();
230 for entity in entities {
231 if !self.filter.test(solution, entity) {
232 continue;
233 }
234 let key = (self.key_fn)(entity);
235 groups.insert(key, ());
236 }
237
238 groups.len()
239 }
240
241 fn initialize(&mut self, solution: &S) -> Sc {
242 self.reset();
243
244 let entities = self.extractor.extract(solution);
245 let mut total = Sc::zero();
246
247 for (idx, entity) in entities.iter().enumerate() {
248 if !self.filter.test(solution, entity) {
249 continue;
250 }
251 total = total + self.insert_entity(entities, idx, entity);
252 }
253
254 total
255 }
256
257 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
258 if !self
259 .change_source
260 .assert_localizes(descriptor_index, &self.constraint_ref.name)
261 {
262 return Sc::zero();
263 }
264 let entities = self.extractor.extract(solution);
265 if entity_index >= entities.len() {
266 return Sc::zero();
267 }
268
269 let entity = &entities[entity_index];
270 if !self.filter.test(solution, entity) {
271 return Sc::zero();
272 }
273 self.insert_entity(entities, entity_index, entity)
274 }
275
276 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
277 if !self
278 .change_source
279 .assert_localizes(descriptor_index, &self.constraint_ref.name)
280 {
281 return Sc::zero();
282 }
283 let entities = self.extractor.extract(solution);
284 self.retract_entity(entities, entity_index)
285 }
286
287 fn reset(&mut self) {
288 self.groups.clear();
289 self.entity_groups.clear();
290 self.entity_retractions.clear();
291 }
292
293 fn name(&self) -> &str {
294 &self.constraint_ref.name
295 }
296
297 fn is_hard(&self) -> bool {
298 self.is_hard
299 }
300
301 fn constraint_ref(&self) -> &ConstraintRef {
302 &self.constraint_ref
303 }
304}
305
306impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
307where
308 S: Send + Sync + 'static,
309 A: Clone + Send + Sync + 'static,
310 K: Clone + Eq + Hash + Send + Sync + 'static,
311 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
312 Fi: UniFilter<S, A>,
313 KF: Fn(&A) -> K + Send + Sync,
314 C: UniCollector<A> + Send + Sync + 'static,
315 C::Accumulator: Send + Sync,
316 C::Result: Send + Sync,
317 C::Value: Send + Sync,
318 W: Fn(&K, &C::Result) -> Sc + Send + Sync,
319 Sc: Score + 'static,
320{
321 fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
322 let key = (self.key_fn)(entity);
323 let entity_key = key.clone();
324 let value = self.collector.extract(entity);
325 let impact = self.impact_type;
326
327 let weight_fn = &self.weight_fn;
328 let (old, new_score) = match self.groups.entry(key) {
329 std::collections::hash_map::Entry::Occupied(mut entry) => {
330 let old_base = entry
331 .get()
332 .accumulator
333 .with_result(|result| weight_fn(entry.key(), result));
334 let old = match impact {
335 ImpactType::Penalty => -old_base,
336 ImpactType::Reward => old_base,
337 };
338 let group = entry.get_mut();
339 let retraction = group.accumulator.accumulate(value);
340 group.count += 1;
341 let new_base = entry
342 .get()
343 .accumulator
344 .with_result(|result| weight_fn(entry.key(), result));
345 let new_score = match impact {
346 ImpactType::Penalty => -new_base,
347 ImpactType::Reward => new_base,
348 };
349 self.entity_retractions.insert(entity_index, retraction);
350 (old, new_score)
351 }
352 std::collections::hash_map::Entry::Vacant(entry) => {
353 let mut entry = entry.insert_entry(GroupState {
354 accumulator: self.collector.create_accumulator(),
355 count: 0,
356 });
357 let group = entry.get_mut();
358 let retraction = group.accumulator.accumulate(value);
359 group.count += 1;
360 let new_base = entry
361 .get()
362 .accumulator
363 .with_result(|result| weight_fn(entry.key(), result));
364 let new_score = match impact {
365 ImpactType::Penalty => -new_base,
366 ImpactType::Reward => new_base,
367 };
368 self.entity_retractions.insert(entity_index, retraction);
369 (Sc::zero(), new_score)
370 }
371 };
372
373 self.entity_groups.insert(entity_index, entity_key);
375
376 new_score - old
378 }
379
380 fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
381 let Some(key) = self.entity_groups.remove(&entity_index) else {
383 return Sc::zero();
384 };
385
386 let Some(retraction) = self.entity_retractions.remove(&entity_index) else {
388 return Sc::zero();
389 };
390 let impact = self.impact_type;
391
392 let weight_fn = &self.weight_fn;
393 let std::collections::hash_map::Entry::Occupied(mut entry) = self.groups.entry(key) else {
394 return Sc::zero();
395 };
396
397 let old_base = entry
398 .get()
399 .accumulator
400 .with_result(|result| weight_fn(entry.key(), result));
401 let old = match impact {
402 ImpactType::Penalty => -old_base,
403 ImpactType::Reward => old_base,
404 };
405
406 let group = entry.get_mut();
407 group.accumulator.retract(retraction);
408 group.count = group.count.saturating_sub(1);
409 let is_empty = group.count == 0;
410 let new_score = if is_empty {
411 entry.remove();
412 Sc::zero()
413 } else {
414 let new_base = entry
415 .get()
416 .accumulator
417 .with_result(|result| weight_fn(entry.key(), result));
418 match impact {
419 ImpactType::Penalty => -new_base,
420 ImpactType::Reward => new_base,
421 }
422 };
423
424 new_score - old
426 }
427}
428
429impl<S, A, K, E, Fi, KF, C, W, Sc> std::fmt::Debug
430 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
431where
432 C: UniCollector<A>,
433 Sc: Score,
434{
435 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
436 f.debug_struct("GroupedUniConstraint")
437 .field("name", &self.constraint_ref.name)
438 .field("impact_type", &self.impact_type)
439 .field("groups", &self.groups.len())
440 .finish()
441 }
442}