1use std::collections::HashMap;
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::api::constraint_set::IncrementalConstraint;
14use crate::stream::collector::{Accumulator, UniCollector};
15
/// Constraint that groups A entities by key, collects each group with `C`, and
/// scores exactly one match per B entity.
///
/// A B entity is scored with the collected result of the A entities whose
/// `key_a` equals its `key_b`. If no A entity has that key, it is scored with
/// `default_fn(b)` instead; this is the "complemented" part.
pub struct ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
    C: UniCollector<A>,
    Sc: Score,
{
    /// Identity of the constraint; its `name` is what `name()` reports.
    constraint_ref: ConstraintRef,
    /// Penalty negates `weight_fn`'s output, reward keeps it as-is.
    impact_type: ImpactType,
    /// Extracts the A entities (the ones that get grouped) from the solution.
    extractor_a: EA,
    /// Extracts the B entities (one scored match each) from the solution.
    extractor_b: EB,
    /// Group key of an A entity; `None` leaves the entity out of every group.
    key_a: KA,
    /// Group key of a B entity.
    key_b: KB,
    /// Collector run over the A entities of each group.
    collector: C,
    /// Result used for a B entity whose key has no A entities.
    default_fn: D,
    /// Turns a group result into a score before the impact sign is applied.
    weight_fn: W,
    /// Reported by `is_hard()`.
    is_hard: bool,
    /// Incremental state: running accumulator per group key.
    groups: HashMap<K, C::Accumulator>,
    /// Incremental state: A-entity index -> key it was accumulated under.
    entity_groups: HashMap<usize, K>,
    /// Incremental state: A-entity index -> value it contributed, kept so the
    /// exact same value can be retracted later.
    entity_values: HashMap<usize, C::Value>,
    /// Incremental state: B key -> index of the B entity with that key
    /// (built in `initialize`; a later B with the same key overwrites an earlier one).
    b_by_key: HashMap<K, usize>,
    _phantom: PhantomData<(S, A, B, Sc)>,
}
111
112impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
113 ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
114where
115 S: 'static,
116 A: Clone + 'static,
117 B: Clone + 'static,
118 K: Clone + Eq + Hash,
119 EA: Fn(&S) -> &[A],
120 EB: Fn(&S) -> &[B],
121 KA: Fn(&A) -> Option<K>,
122 KB: Fn(&B) -> K,
123 C: UniCollector<A>,
124 C::Result: Clone,
125 D: Fn(&B) -> C::Result,
126 W: Fn(&C::Result) -> Sc,
127 Sc: Score,
128{
129 #[allow(clippy::too_many_arguments)]
131 pub fn new(
132 constraint_ref: ConstraintRef,
133 impact_type: ImpactType,
134 extractor_a: EA,
135 extractor_b: EB,
136 key_a: KA,
137 key_b: KB,
138 collector: C,
139 default_fn: D,
140 weight_fn: W,
141 is_hard: bool,
142 ) -> Self {
143 Self {
144 constraint_ref,
145 impact_type,
146 extractor_a,
147 extractor_b,
148 key_a,
149 key_b,
150 collector,
151 default_fn,
152 weight_fn,
153 is_hard,
154 groups: HashMap::new(),
155 entity_groups: HashMap::new(),
156 entity_values: HashMap::new(),
157 b_by_key: HashMap::new(),
158 _phantom: PhantomData,
159 }
160 }
161
162 #[inline]
163 fn compute_score(&self, result: &C::Result) -> Sc {
164 let base = (self.weight_fn)(result);
165 match self.impact_type {
166 ImpactType::Penalty => -base,
167 ImpactType::Reward => base,
168 }
169 }
170
171 fn build_groups(&self, entities_a: &[A]) -> HashMap<K, C::Result> {
173 let mut accumulators: HashMap<K, C::Accumulator> = HashMap::new();
174
175 for a in entities_a {
176 let Some(key) = (self.key_a)(a) else {
178 continue;
179 };
180 let value = self.collector.extract(a);
181 accumulators
182 .entry(key)
183 .or_insert_with(|| self.collector.create_accumulator())
184 .accumulate(&value);
185 }
186
187 accumulators
188 .into_iter()
189 .map(|(k, acc)| (k, acc.finish()))
190 .collect()
191 }
192}
193
194impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> IncrementalConstraint<S, Sc>
195 for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
196where
197 S: Send + Sync + 'static,
198 A: Clone + Send + Sync + 'static,
199 B: Clone + Send + Sync + 'static,
200 K: Clone + Eq + Hash + Send + Sync,
201 EA: Fn(&S) -> &[A] + Send + Sync,
202 EB: Fn(&S) -> &[B] + Send + Sync,
203 KA: Fn(&A) -> Option<K> + Send + Sync,
204 KB: Fn(&B) -> K + Send + Sync,
205 C: UniCollector<A> + Send + Sync,
206 C::Accumulator: Send + Sync,
207 C::Result: Clone + Send + Sync,
208 C::Value: Send + Sync,
209 D: Fn(&B) -> C::Result + Send + Sync,
210 W: Fn(&C::Result) -> Sc + Send + Sync,
211 Sc: Score,
212{
213 fn evaluate(&self, solution: &S) -> Sc {
214 let entities_a = (self.extractor_a)(solution);
215 let entities_b = (self.extractor_b)(solution);
216
217 let groups = self.build_groups(entities_a);
218
219 let mut total = Sc::zero();
220 for b in entities_b {
221 let key = (self.key_b)(b);
222 let result = groups
223 .get(&key)
224 .cloned()
225 .unwrap_or_else(|| (self.default_fn)(b));
226 total = total + self.compute_score(&result);
227 }
228
229 total
230 }
231
232 fn match_count(&self, solution: &S) -> usize {
233 let entities_b = (self.extractor_b)(solution);
234 entities_b.len()
235 }
236
237 fn initialize(&mut self, solution: &S) -> Sc {
238 self.reset();
239
240 let entities_a = (self.extractor_a)(solution);
241 let entities_b = (self.extractor_b)(solution);
242
243 for (idx, b) in entities_b.iter().enumerate() {
245 let key = (self.key_b)(b);
246 self.b_by_key.insert(key, idx);
247 }
248
249 let mut total = Sc::zero();
251 for b in entities_b {
252 let default_result = (self.default_fn)(b);
253 total = total + self.compute_score(&default_result);
254 }
255
256 for (idx, a) in entities_a.iter().enumerate() {
258 total = total + self.insert_entity(entities_b, idx, a);
259 }
260
261 total
262 }
263
264 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
265 let entities_a = (self.extractor_a)(solution);
266 let entities_b = (self.extractor_b)(solution);
267
268 if entity_index >= entities_a.len() {
269 return Sc::zero();
270 }
271
272 let entity = &entities_a[entity_index];
273 self.insert_entity(entities_b, entity_index, entity)
274 }
275
276 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
277 let entities_a = (self.extractor_a)(solution);
278 let entities_b = (self.extractor_b)(solution);
279
280 self.retract_entity(entities_a, entities_b, entity_index)
281 }
282
283 fn reset(&mut self) {
284 self.groups.clear();
285 self.entity_groups.clear();
286 self.entity_values.clear();
287 self.b_by_key.clear();
288 }
289
290 fn name(&self) -> &str {
291 &self.constraint_ref.name
292 }
293
294 fn is_hard(&self) -> bool {
295 self.is_hard
296 }
297
298 fn constraint_ref(&self) -> ConstraintRef {
299 self.constraint_ref.clone()
300 }
301}
302
303impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
304 ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
305where
306 S: Send + Sync + 'static,
307 A: Clone + Send + Sync + 'static,
308 B: Clone + Send + Sync + 'static,
309 K: Clone + Eq + Hash + Send + Sync,
310 EA: Fn(&S) -> &[A] + Send + Sync,
311 EB: Fn(&S) -> &[B] + Send + Sync,
312 KA: Fn(&A) -> Option<K> + Send + Sync,
313 KB: Fn(&B) -> K + Send + Sync,
314 C: UniCollector<A> + Send + Sync,
315 C::Accumulator: Send + Sync,
316 C::Result: Clone + Send + Sync,
317 C::Value: Send + Sync,
318 D: Fn(&B) -> C::Result + Send + Sync,
319 W: Fn(&C::Result) -> Sc + Send + Sync,
320 Sc: Score,
321{
322 fn insert_entity(&mut self, entities_b: &[B], entity_index: usize, entity: &A) -> Sc {
324 let Some(key) = (self.key_a)(entity) else {
326 return Sc::zero();
327 };
328 let value = self.collector.extract(entity);
329 let impact = self.impact_type;
330
331 let b_idx = self.b_by_key.get(&key).copied();
333 let Some(b_idx) = b_idx else {
334 let acc = self
337 .groups
338 .entry(key.clone())
339 .or_insert_with(|| self.collector.create_accumulator());
340 acc.accumulate(&value);
341 self.entity_groups.insert(entity_index, key);
342 self.entity_values.insert(entity_index, value);
343 return Sc::zero();
344 };
345
346 let b = &entities_b[b_idx];
347
348 let old_result = self
350 .groups
351 .get(&key)
352 .map(|acc| acc.finish())
353 .unwrap_or_else(|| (self.default_fn)(b));
354 let old_base = (self.weight_fn)(&old_result);
355 let old = match impact {
356 ImpactType::Penalty => -old_base,
357 ImpactType::Reward => old_base,
358 };
359
360 let acc = self
362 .groups
363 .entry(key.clone())
364 .or_insert_with(|| self.collector.create_accumulator());
365 acc.accumulate(&value);
366
367 let new_result = acc.finish();
369 let new_base = (self.weight_fn)(&new_result);
370 let new_score = match impact {
371 ImpactType::Penalty => -new_base,
372 ImpactType::Reward => new_base,
373 };
374
375 self.entity_groups.insert(entity_index, key);
377 self.entity_values.insert(entity_index, value);
378
379 new_score - old
381 }
382
383 fn retract_entity(&mut self, _entities_a: &[A], entities_b: &[B], entity_index: usize) -> Sc {
385 let Some(key) = self.entity_groups.remove(&entity_index) else {
387 return Sc::zero();
388 };
389
390 let Some(value) = self.entity_values.remove(&entity_index) else {
392 return Sc::zero();
393 };
394 let impact = self.impact_type;
395
396 let b_idx = self.b_by_key.get(&key).copied();
398 let Some(b_idx) = b_idx else {
399 if let Some(acc) = self.groups.get_mut(&key) {
401 acc.retract(&value);
402 }
403 return Sc::zero();
404 };
405
406 let b = &entities_b[b_idx];
407
408 let Some(acc) = self.groups.get_mut(&key) else {
410 return Sc::zero();
411 };
412
413 let old_result = acc.finish();
415 let old_base = (self.weight_fn)(&old_result);
416 let old = match impact {
417 ImpactType::Penalty => -old_base,
418 ImpactType::Reward => old_base,
419 };
420
421 acc.retract(&value);
423
424 let new_result = acc.finish();
426 let default_result = (self.default_fn)(b);
428 let new_base = (self.weight_fn)(&new_result);
429 let new_score = match impact {
430 ImpactType::Penalty => -new_base,
431 ImpactType::Reward => new_base,
432 };
433
434 let _ = default_result; new_score - old
440 }
441}
442
443impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> std::fmt::Debug
444 for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
445where
446 C: UniCollector<A>,
447 Sc: Score,
448{
449 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
450 f.debug_struct("ComplementedGroupConstraint")
451 .field("name", &self.constraint_ref.name)
452 .field("impact_type", &self.impact_type)
453 .field("groups", &self.groups.len())
454 .finish()
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::collector::count;
    use solverforge_core::score::SimpleScore;

    #[derive(Clone, Hash, PartialEq, Eq)]
    struct Employee {
        id: usize,
    }

    #[derive(Clone)]
    struct Shift {
        employee_id: Option<usize>,
    }

    #[derive(Clone)]
    struct Schedule {
        employees: Vec<Employee>,
        shifts: Vec<Shift>,
    }

    /// Weight: the shift count itself.
    fn linear(count: &usize) -> SimpleScore {
        SimpleScore::of(*count as i64)
    }

    /// Weight: the shift count squared.
    fn squared(count: &usize) -> SimpleScore {
        SimpleScore::of((*count as i64).pow(2))
    }

    /// Penalty constraint counting shifts per employee. Employees without
    /// shifts default to a count of 0, and each count is scored with `weight`.
    fn shift_count_constraint<W>(
        name: &'static str,
        weight: W,
    ) -> impl IncrementalConstraint<Schedule, SimpleScore>
    where
        W: Fn(&usize) -> SimpleScore + Send + Sync,
    {
        ComplementedGroupConstraint::new(
            ConstraintRef::new("", name),
            ImpactType::Penalty,
            |s: &Schedule| s.shifts.as_slice(),
            |s: &Schedule| s.employees.as_slice(),
            |shift: &Shift| shift.employee_id,
            |emp: &Employee| emp.id,
            count::<Shift>(),
            |_emp: &Employee| 0usize,
            weight,
            false,
        )
    }

    /// Employees `0..employee_count`, plus one shift per entry of `assignments`.
    fn schedule(employee_count: usize, assignments: &[Option<usize>]) -> Schedule {
        Schedule {
            employees: (0..employee_count).map(|id| Employee { id }).collect(),
            shifts: assignments
                .iter()
                .map(|&employee_id| Shift { employee_id })
                .collect(),
        }
    }

    #[test]
    fn test_complemented_evaluate() {
        let constraint = shift_count_constraint("Shift count", linear);
        let plan = schedule(2, &[Some(0), Some(0)]);

        assert_eq!(constraint.evaluate(&plan), SimpleScore::of(-2));
    }

    #[test]
    fn test_complemented_skips_none_keys() {
        let constraint = shift_count_constraint("Shift count", linear);
        let plan = schedule(2, &[Some(0), Some(0), None, None]);

        assert_eq!(constraint.evaluate(&plan), SimpleScore::of(-2));
    }

    #[test]
    fn test_complemented_incremental() {
        let mut constraint = shift_count_constraint("Shift count", linear);
        let plan = schedule(3, &[Some(0), Some(0), Some(1)]);

        assert_eq!(constraint.initialize(&plan), SimpleScore::of(-3));
        assert_eq!(constraint.on_retract(&plan, 0), SimpleScore::of(1));
        assert_eq!(constraint.on_insert(&plan, 0), SimpleScore::of(-1));
    }

    #[test]
    fn test_complemented_incremental_with_none_keys() {
        let mut constraint = shift_count_constraint("Shift count", linear);
        let plan = schedule(2, &[Some(0), None, Some(0)]);

        assert_eq!(constraint.initialize(&plan), SimpleScore::of(-2));
        // The unassigned shift belongs to no group, so moving it changes nothing.
        assert_eq!(constraint.on_retract(&plan, 1), SimpleScore::of(0));
        assert_eq!(constraint.on_insert(&plan, 1), SimpleScore::of(0));
    }

    #[test]
    fn test_complemented_with_default() {
        let constraint = shift_count_constraint("Workload balance", squared);
        let plan = schedule(3, &[Some(0), Some(0), Some(0)]);

        // 3^2 for employee 0, default 0^2 for employees 1 and 2.
        assert_eq!(constraint.evaluate(&plan), SimpleScore::of(-9));
    }

    #[test]
    fn test_complemented_incremental_matches_evaluate() {
        let mut constraint = shift_count_constraint("Shift count", squared);
        let plan = schedule(2, &[Some(0), Some(0), Some(1)]);

        let init_total = constraint.initialize(&plan);
        assert_eq!(init_total, constraint.evaluate(&plan));
        assert_eq!(init_total, SimpleScore::of(-5));

        let after_retract = init_total + constraint.on_retract(&plan, 2);
        assert_eq!(after_retract, SimpleScore::of(-4));

        let after_insert = after_retract + constraint.on_insert(&plan, 2);
        assert_eq!(after_insert, SimpleScore::of(-5));
    }
}