1use std::collections::HashMap;
10use std::hash::Hash;
11use std::marker::PhantomData;
12
13use solverforge_core::score::Score;
14use solverforge_core::{ConstraintRef, ImpactType};
15
16use crate::api::constraint_set::IncrementalConstraint;
17use crate::stream::filter::UniFilter;
18
19pub struct BalanceConstraint<S, A, K, E, F, KF, Sc>
82where
83 Sc: Score,
84{
85 constraint_ref: ConstraintRef,
86 impact_type: ImpactType,
87 extractor: E,
88 filter: F,
89 key_fn: KF,
90 base_score: Sc,
92 is_hard: bool,
93 counts: HashMap<K, i64>,
95 entity_keys: HashMap<usize, K>,
97 group_count: i64,
100 total_count: i64,
102 sum_squared: i64,
104 _phantom: PhantomData<(S, A)>,
105}
106
107impl<S, A, K, E, F, KF, Sc> BalanceConstraint<S, A, K, E, F, KF, Sc>
108where
109 S: Send + Sync + 'static,
110 A: Clone + Send + Sync + 'static,
111 K: Clone + Eq + Hash + Send + Sync + 'static,
112 E: Fn(&S) -> &[A] + Send + Sync,
113 F: UniFilter<A>,
114 KF: Fn(&A) -> Option<K> + Send + Sync,
115 Sc: Score + 'static,
116{
117 pub fn new(
129 constraint_ref: ConstraintRef,
130 impact_type: ImpactType,
131 extractor: E,
132 filter: F,
133 key_fn: KF,
134 base_score: Sc,
135 is_hard: bool,
136 ) -> Self {
137 Self {
138 constraint_ref,
139 impact_type,
140 extractor,
141 filter,
142 key_fn,
143 base_score,
144 is_hard,
145 counts: HashMap::new(),
146 entity_keys: HashMap::new(),
147 group_count: 0,
148 total_count: 0,
149 sum_squared: 0,
150 _phantom: PhantomData,
151 }
152 }
153
154 fn compute_std_dev(&self) -> f64 {
156 if self.group_count == 0 {
157 return 0.0;
158 }
159 let n = self.group_count as f64;
160 let mean = self.total_count as f64 / n;
161 let variance = (self.sum_squared as f64 / n) - (mean * mean);
162 if variance <= 0.0 {
163 return 0.0;
164 }
165 variance.sqrt()
166 }
167
168 fn compute_score(&self) -> Sc {
170 let std_dev = self.compute_std_dev();
171 let base = self.base_score.multiply(std_dev);
172 match self.impact_type {
173 ImpactType::Penalty => -base,
174 ImpactType::Reward => base,
175 }
176 }
177
178 fn compute_std_dev_from_counts(counts: &HashMap<K, i64>) -> f64 {
180 if counts.is_empty() {
181 return 0.0;
182 }
183 let n = counts.len() as f64;
184 let total: i64 = counts.values().sum();
185 let sum_sq: i64 = counts.values().map(|&c| c * c).sum();
186 let mean = total as f64 / n;
187 let variance = (sum_sq as f64 / n) - (mean * mean);
188 if variance > 0.0 {
189 variance.sqrt()
190 } else {
191 0.0
192 }
193 }
194}
195
196impl<S, A, K, E, F, KF, Sc> IncrementalConstraint<S, Sc>
197 for BalanceConstraint<S, A, K, E, F, KF, Sc>
198where
199 S: Send + Sync + 'static,
200 A: Clone + Send + Sync + 'static,
201 K: Clone + Eq + Hash + Send + Sync + 'static,
202 E: Fn(&S) -> &[A] + Send + Sync,
203 F: UniFilter<A>,
204 KF: Fn(&A) -> Option<K> + Send + Sync,
205 Sc: Score + 'static,
206{
207 fn evaluate(&self, solution: &S) -> Sc {
208 let entities = (self.extractor)(solution);
209
210 let mut counts: HashMap<K, i64> = HashMap::new();
212 for entity in entities {
213 if !self.filter.test(entity) {
214 continue;
215 }
216 if let Some(key) = (self.key_fn)(entity) {
217 *counts.entry(key).or_insert(0) += 1;
218 }
219 }
220
221 if counts.is_empty() {
222 return Sc::zero();
223 }
224
225 let std_dev = Self::compute_std_dev_from_counts(&counts);
226 let base = self.base_score.multiply(std_dev);
227 match self.impact_type {
228 ImpactType::Penalty => -base,
229 ImpactType::Reward => base,
230 }
231 }
232
233 fn match_count(&self, solution: &S) -> usize {
234 let entities = (self.extractor)(solution);
235
236 let mut counts: HashMap<K, i64> = HashMap::new();
238 for entity in entities {
239 if !self.filter.test(entity) {
240 continue;
241 }
242 if let Some(key) = (self.key_fn)(entity) {
243 *counts.entry(key).or_insert(0) += 1;
244 }
245 }
246
247 if counts.is_empty() {
248 return 0;
249 }
250
251 let total: i64 = counts.values().sum();
252 let mean = total as f64 / counts.len() as f64;
253
254 counts
256 .values()
257 .filter(|&&c| (c as f64 - mean).abs() > 0.5)
258 .count()
259 }
260
261 fn initialize(&mut self, solution: &S) -> Sc {
262 self.reset();
263
264 let entities = (self.extractor)(solution);
265
266 for (idx, entity) in entities.iter().enumerate() {
267 if !self.filter.test(entity) {
268 continue;
269 }
270 if let Some(key) = (self.key_fn)(entity) {
271 let old_count = *self.counts.get(&key).unwrap_or(&0);
272 let new_count = old_count + 1;
273 self.counts.insert(key.clone(), new_count);
274 self.entity_keys.insert(idx, key);
275
276 if old_count == 0 {
277 self.group_count += 1;
278 }
279 self.total_count += 1;
280 self.sum_squared += new_count * new_count - old_count * old_count;
281 }
282 }
283
284 self.compute_score()
285 }
286
287 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
288 let entities = (self.extractor)(solution);
289 if entity_index >= entities.len() {
290 return Sc::zero();
291 }
292
293 let entity = &entities[entity_index];
294 if !self.filter.test(entity) {
295 return Sc::zero();
296 }
297
298 let Some(key) = (self.key_fn)(entity) else {
299 return Sc::zero();
300 };
301
302 let old_score = self.compute_score();
303
304 let old_count = *self.counts.get(&key).unwrap_or(&0);
305 let new_count = old_count + 1;
306 self.counts.insert(key.clone(), new_count);
307 self.entity_keys.insert(entity_index, key);
308
309 if old_count == 0 {
310 self.group_count += 1;
311 }
312 self.total_count += 1;
313 self.sum_squared += new_count * new_count - old_count * old_count;
314
315 let new_score = self.compute_score();
316 new_score - old_score
317 }
318
319 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
320 let entities = (self.extractor)(solution);
321 if entity_index >= entities.len() {
322 return Sc::zero();
323 }
324
325 let Some(key) = self.entity_keys.remove(&entity_index) else {
327 return Sc::zero();
328 };
329
330 let old_score = self.compute_score();
331
332 let old_count = *self.counts.get(&key).unwrap_or(&0);
333 if old_count == 0 {
334 return Sc::zero();
335 }
336
337 let new_count = old_count - 1;
338 if new_count == 0 {
339 self.counts.remove(&key);
340 self.group_count -= 1;
341 } else {
342 self.counts.insert(key, new_count);
343 }
344 self.total_count -= 1;
345 self.sum_squared += new_count * new_count - old_count * old_count;
346
347 let new_score = self.compute_score();
348 new_score - old_score
349 }
350
351 fn reset(&mut self) {
352 self.counts.clear();
353 self.entity_keys.clear();
354 self.group_count = 0;
355 self.total_count = 0;
356 self.sum_squared = 0;
357 }
358
359 fn name(&self) -> &str {
360 &self.constraint_ref.name
361 }
362
363 fn is_hard(&self) -> bool {
364 self.is_hard
365 }
366
367 fn constraint_ref(&self) -> ConstraintRef {
368 self.constraint_ref.clone()
369 }
370}
371
372impl<S, A, K, E, F, KF, Sc> std::fmt::Debug for BalanceConstraint<S, A, K, E, F, KF, Sc>
373where
374 Sc: Score,
375{
376 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
377 f.debug_struct("BalanceConstraint")
378 .field("name", &self.constraint_ref.name)
379 .field("impact_type", &self.impact_type)
380 .field("groups", &self.counts.len())
381 .finish()
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::filter::TrueFilter;
    use solverforge_core::score::SimpleScore;

    #[derive(Clone)]
    struct Shift {
        employee_id: Option<usize>,
    }

    #[derive(Clone)]
    struct Solution {
        shifts: Vec<Shift>,
    }

    /// Builds a solution with one shift per entry, assigned to that employee (or unassigned).
    fn solution_with(employee_ids: &[Option<usize>]) -> Solution {
        Solution {
            shifts: employee_ids
                .iter()
                .map(|&employee_id| Shift { employee_id })
                .collect(),
        }
    }

    /// Builds a balance constraint grouping shifts by employee, weighted 1000 per unit of
    /// standard deviation.
    fn balance(
        name: &'static str,
        impact_type: ImpactType,
    ) -> impl IncrementalConstraint<Solution, SimpleScore> {
        BalanceConstraint::new(
            ConstraintRef::new("", name),
            impact_type,
            |s: &Solution| &s.shifts,
            TrueFilter,
            |shift: &Shift| shift.employee_id,
            SimpleScore::of(1000),
            false,
        )
    }

    #[test]
    fn test_balance_evaluate_equal_distribution() {
        let constraint = balance("Balance", ImpactType::Penalty);
        let solution = solution_with(&[Some(0), Some(0), Some(1), Some(1)]);

        // Counts [2, 2]: std dev 0.
        assert_eq!(constraint.evaluate(&solution), SimpleScore::of(0));
    }

    #[test]
    fn test_balance_evaluate_unequal_distribution() {
        let constraint = balance("Balance", ImpactType::Penalty);
        let solution = solution_with(&[Some(0), Some(0), Some(0), Some(1)]);

        // Counts [3, 1]: mean 2, std dev 1.
        assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1000));
    }

    #[test]
    fn test_balance_filters_unassigned() {
        let constraint = balance("Balance", ImpactType::Penalty);
        let solution = solution_with(&[Some(0), Some(0), Some(1), Some(1), None]);

        // The unassigned shift is not a group of its own.
        assert_eq!(constraint.evaluate(&solution), SimpleScore::of(0));
    }

    #[test]
    fn test_balance_incremental() {
        let mut constraint = balance("Balance", ImpactType::Penalty);
        let solution = solution_with(&[Some(0), Some(0), Some(1), Some(1)]);

        assert_eq!(constraint.initialize(&solution), SimpleScore::of(0));
        // Counts become [1, 2]: std dev 0.5.
        assert_eq!(constraint.on_retract(&solution, 0), SimpleScore::of(-500));
        assert_eq!(constraint.on_insert(&solution, 0), SimpleScore::of(500));
    }

    #[test]
    fn test_balance_incremental_matches_evaluate_after_reassignment() {
        let mut constraint = balance("Balance", ImpactType::Penalty);
        let before = solution_with(&[Some(0), Some(0), Some(1), Some(1)]);
        let after = solution_with(&[Some(1), Some(0), Some(1), Some(1)]);

        // 0 + (-500) + (-500) must equal the full evaluation of `after`.
        assert_eq!(constraint.initialize(&before), SimpleScore::of(0));
        assert_eq!(constraint.on_retract(&before, 0), SimpleScore::of(-500));
        assert_eq!(constraint.on_insert(&after, 0), SimpleScore::of(-500));
        assert_eq!(constraint.evaluate(&after), SimpleScore::of(-1000));
    }

    #[test]
    fn test_balance_retracting_last_member_removes_group() {
        let mut constraint = balance("Balance", ImpactType::Penalty);
        let solution = solution_with(&[Some(0), Some(0), Some(1)]);

        // Counts [2, 1]: std dev 0.5.
        assert_eq!(constraint.initialize(&solution), SimpleScore::of(-500));
        // Group 1 disappears, leaving a single group with std dev 0.
        assert_eq!(constraint.on_retract(&solution, 2), SimpleScore::of(500));
    }

    #[test]
    fn test_balance_ignores_ungrouped_entities_incrementally() {
        let mut constraint = balance("Balance", ImpactType::Penalty);
        let solution = solution_with(&[Some(0), None]);

        assert_eq!(constraint.initialize(&solution), SimpleScore::of(0));
        assert_eq!(constraint.on_retract(&solution, 1), SimpleScore::of(0));
        assert_eq!(constraint.on_insert(&solution, 1), SimpleScore::of(0));
    }

    #[test]
    fn test_balance_match_count() {
        let constraint = balance("Balance", ImpactType::Penalty);

        let unequal = solution_with(&[Some(0), Some(0), Some(0), Some(1)]);
        assert_eq!(constraint.match_count(&unequal), 2);

        let equal = solution_with(&[Some(0), Some(0), Some(1), Some(1)]);
        assert_eq!(constraint.match_count(&equal), 0);

        assert_eq!(constraint.match_count(&solution_with(&[])), 0);
    }

    #[test]
    fn test_balance_empty_solution() {
        let constraint = balance("Balance", ImpactType::Penalty);
        let solution = solution_with(&[]);

        assert_eq!(constraint.evaluate(&solution), SimpleScore::of(0));
    }

    #[test]
    fn test_balance_single_employee() {
        let constraint = balance("Balance", ImpactType::Penalty);
        let solution = solution_with(&[Some(0); 5]);

        // A single group has no spread.
        assert_eq!(constraint.evaluate(&solution), SimpleScore::of(0));
    }

    #[test]
    fn test_balance_reward() {
        let constraint = balance("Balance reward", ImpactType::Reward);
        let solution = solution_with(&[Some(0), Some(0), Some(0), Some(1)]);

        assert_eq!(constraint.evaluate(&solution), SimpleScore::of(1000));
    }
}