solverforge_scoring/constraint/
balance.rs

1//! Zero-erasure balance constraint for load distribution scoring.
2//!
3//! Provides a constraint that penalizes uneven distribution across groups
4//! using standard deviation. Unlike grouped constraints which score per-group,
5//! the balance constraint computes a GLOBAL statistic across all groups.
6//!
7//! All type information is preserved at compile time - no Arc, no dyn.
8
9use std::collections::HashMap;
10use std::hash::Hash;
11use std::marker::PhantomData;
12
13use solverforge_core::score::Score;
14use solverforge_core::{ConstraintRef, ImpactType};
15
16use crate::api::constraint_set::IncrementalConstraint;
17use crate::stream::filter::UniFilter;
18
19/// Zero-erasure balance constraint that penalizes uneven load distribution.
20///
21/// This constraint:
22/// 1. Groups entities by key (e.g., employee_id)
23/// 2. Counts how many entities belong to each group
24/// 3. Computes population standard deviation across all group counts
25/// 4. Multiplies the base score by std_dev to produce the final score
26///
27/// The key difference from `GroupedUniConstraint` is that balance computes
28/// a GLOBAL statistic, not per-group scores.
29///
30/// # Type Parameters
31///
32/// - `S` - Solution type
33/// - `A` - Entity type
34/// - `K` - Group key type
35/// - `E` - Extractor function for entities
36/// - `F` - Filter type
37/// - `KF` - Key function
38/// - `Sc` - Score type
39///
40/// # Example
41///
42/// ```
43/// use solverforge_scoring::constraint::balance::BalanceConstraint;
44/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
45/// use solverforge_scoring::stream::filter::TrueFilter;
46/// use solverforge_core::{ConstraintRef, ImpactType, HardSoftDecimalScore};
47///
48/// #[derive(Clone)]
49/// struct Shift { employee_id: Option<usize> }
50///
51/// #[derive(Clone)]
52/// struct Solution { shifts: Vec<Shift> }
53///
54/// // Base score of 1000 soft per unit of std_dev
55/// let constraint = BalanceConstraint::new(
56///     ConstraintRef::new("", "Balance workload"),
57///     ImpactType::Penalty,
58///     |s: &Solution| &s.shifts,
59///     TrueFilter,
60///     |shift: &Shift| shift.employee_id,
61///     HardSoftDecimalScore::of_soft(1),  // 1 soft per unit std_dev (scaled internally)
62///     false,
63/// );
64///
65/// let solution = Solution {
66///     shifts: vec![
67///         Shift { employee_id: Some(0) },
68///         Shift { employee_id: Some(0) },
69///         Shift { employee_id: Some(0) },
70///         Shift { employee_id: Some(1) },
71///         Shift { employee_id: None },  // Unassigned, filtered out
72///     ],
73/// };
74///
75/// // Employee 0: 3 shifts, Employee 1: 1 shift
76/// // Mean = 2, Variance = ((3-2)² + (1-2)²) / 2 = 1
77/// // StdDev = 1.0, Score = -1 soft (base_score * std_dev, negated for penalty)
78/// let score = constraint.evaluate(&solution);
79/// assert_eq!(score, HardSoftDecimalScore::of_soft(-1));
80/// ```
81pub struct BalanceConstraint<S, A, K, E, F, KF, Sc>
82where
83    Sc: Score,
84{
85    constraint_ref: ConstraintRef,
86    impact_type: ImpactType,
87    extractor: E,
88    filter: F,
89    key_fn: KF,
90    /// Base score representing 1 unit of standard deviation
91    base_score: Sc,
92    is_hard: bool,
93    /// Group key → count of entities in that group
94    counts: HashMap<K, i64>,
95    /// Entity index → group key (for tracking assignments)
96    entity_keys: HashMap<usize, K>,
97    /// Cached statistics for incremental updates
98    /// Number of groups (employees with at least one shift)
99    group_count: i64,
100    /// Sum of all counts (total assignments)
101    total_count: i64,
102    /// Sum of squared counts (for variance calculation)
103    sum_squared: i64,
104    _phantom: PhantomData<(S, A)>,
105}
106
107impl<S, A, K, E, F, KF, Sc> BalanceConstraint<S, A, K, E, F, KF, Sc>
108where
109    S: Send + Sync + 'static,
110    A: Clone + Send + Sync + 'static,
111    K: Clone + Eq + Hash + Send + Sync + 'static,
112    E: Fn(&S) -> &[A] + Send + Sync,
113    F: UniFilter<S, A>,
114    KF: Fn(&A) -> Option<K> + Send + Sync,
115    Sc: Score + 'static,
116{
117    /// Creates a new zero-erasure balance constraint.
118    ///
119    /// # Arguments
120    ///
121    /// * `constraint_ref` - Identifier for this constraint
122    /// * `impact_type` - Whether to penalize or reward
123    /// * `extractor` - Function to get entity slice from solution
124    /// * `filter` - Filter to select which entities to consider
125    /// * `key_fn` - Function to extract group key (returns None to skip entity)
126    /// * `base_score` - Score per unit of standard deviation
127    /// * `is_hard` - Whether this is a hard constraint
128    pub fn new(
129        constraint_ref: ConstraintRef,
130        impact_type: ImpactType,
131        extractor: E,
132        filter: F,
133        key_fn: KF,
134        base_score: Sc,
135        is_hard: bool,
136    ) -> Self {
137        Self {
138            constraint_ref,
139            impact_type,
140            extractor,
141            filter,
142            key_fn,
143            base_score,
144            is_hard,
145            counts: HashMap::new(),
146            entity_keys: HashMap::new(),
147            group_count: 0,
148            total_count: 0,
149            sum_squared: 0,
150            _phantom: PhantomData,
151        }
152    }
153
154    /// Computes standard deviation from cached statistics.
155    fn compute_std_dev(&self) -> f64 {
156        if self.group_count == 0 {
157            return 0.0;
158        }
159        let n = self.group_count as f64;
160        let mean = self.total_count as f64 / n;
161        let variance = (self.sum_squared as f64 / n) - (mean * mean);
162        if variance <= 0.0 {
163            return 0.0;
164        }
165        variance.sqrt()
166    }
167
168    /// Computes the score from standard deviation.
169    fn compute_score(&self) -> Sc {
170        let std_dev = self.compute_std_dev();
171        let base = self.base_score.multiply(std_dev);
172        match self.impact_type {
173            ImpactType::Penalty => -base,
174            ImpactType::Reward => base,
175        }
176    }
177
178    /// Computes std_dev from raw counts (for stateless evaluate).
179    fn compute_std_dev_from_counts(counts: &HashMap<K, i64>) -> f64 {
180        if counts.is_empty() {
181            return 0.0;
182        }
183        let n = counts.len() as f64;
184        let total: i64 = counts.values().sum();
185        let sum_sq: i64 = counts.values().map(|&c| c * c).sum();
186        let mean = total as f64 / n;
187        let variance = (sum_sq as f64 / n) - (mean * mean);
188        if variance > 0.0 {
189            variance.sqrt()
190        } else {
191            0.0
192        }
193    }
194}
195
196impl<S, A, K, E, F, KF, Sc> IncrementalConstraint<S, Sc>
197    for BalanceConstraint<S, A, K, E, F, KF, Sc>
198where
199    S: Send + Sync + 'static,
200    A: Clone + Send + Sync + 'static,
201    K: Clone + Eq + Hash + Send + Sync + 'static,
202    E: Fn(&S) -> &[A] + Send + Sync,
203    F: UniFilter<S, A>,
204    KF: Fn(&A) -> Option<K> + Send + Sync,
205    Sc: Score + 'static,
206{
207    fn evaluate(&self, solution: &S) -> Sc {
208        let entities = (self.extractor)(solution);
209
210        // Build counts from scratch
211        let mut counts: HashMap<K, i64> = HashMap::new();
212        for entity in entities {
213            if !self.filter.test(solution, entity) {
214                continue;
215            }
216            if let Some(key) = (self.key_fn)(entity) {
217                *counts.entry(key).or_insert(0) += 1;
218            }
219        }
220
221        if counts.is_empty() {
222            return Sc::zero();
223        }
224
225        let std_dev = Self::compute_std_dev_from_counts(&counts);
226        let base = self.base_score.multiply(std_dev);
227        match self.impact_type {
228            ImpactType::Penalty => -base,
229            ImpactType::Reward => base,
230        }
231    }
232
233    fn match_count(&self, solution: &S) -> usize {
234        let entities = (self.extractor)(solution);
235
236        // Count groups that deviate from mean
237        let mut counts: HashMap<K, i64> = HashMap::new();
238        for entity in entities {
239            if !self.filter.test(solution, entity) {
240                continue;
241            }
242            if let Some(key) = (self.key_fn)(entity) {
243                *counts.entry(key).or_insert(0) += 1;
244            }
245        }
246
247        if counts.is_empty() {
248            return 0;
249        }
250
251        let total: i64 = counts.values().sum();
252        let mean = total as f64 / counts.len() as f64;
253
254        // Count groups that deviate significantly from mean
255        counts
256            .values()
257            .filter(|&&c| (c as f64 - mean).abs() > 0.5)
258            .count()
259    }
260
261    fn initialize(&mut self, solution: &S) -> Sc {
262        self.reset();
263
264        let entities = (self.extractor)(solution);
265
266        for (idx, entity) in entities.iter().enumerate() {
267            if !self.filter.test(solution, entity) {
268                continue;
269            }
270            if let Some(key) = (self.key_fn)(entity) {
271                let old_count = *self.counts.get(&key).unwrap_or(&0);
272                let new_count = old_count + 1;
273                self.counts.insert(key.clone(), new_count);
274                self.entity_keys.insert(idx, key);
275
276                if old_count == 0 {
277                    self.group_count += 1;
278                }
279                self.total_count += 1;
280                self.sum_squared += new_count * new_count - old_count * old_count;
281            }
282        }
283
284        self.compute_score()
285    }
286
287    fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
288        let entities = (self.extractor)(solution);
289        if entity_index >= entities.len() {
290            return Sc::zero();
291        }
292
293        let entity = &entities[entity_index];
294        if !self.filter.test(solution, entity) {
295            return Sc::zero();
296        }
297
298        let Some(key) = (self.key_fn)(entity) else {
299            return Sc::zero();
300        };
301
302        let old_score = self.compute_score();
303
304        let old_count = *self.counts.get(&key).unwrap_or(&0);
305        let new_count = old_count + 1;
306        self.counts.insert(key.clone(), new_count);
307        self.entity_keys.insert(entity_index, key);
308
309        if old_count == 0 {
310            self.group_count += 1;
311        }
312        self.total_count += 1;
313        self.sum_squared += new_count * new_count - old_count * old_count;
314
315        let new_score = self.compute_score();
316        new_score - old_score
317    }
318
319    fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
320        let entities = (self.extractor)(solution);
321        if entity_index >= entities.len() {
322            return Sc::zero();
323        }
324
325        // Check if this entity was tracked
326        let Some(key) = self.entity_keys.remove(&entity_index) else {
327            return Sc::zero();
328        };
329
330        let old_score = self.compute_score();
331
332        let old_count = *self.counts.get(&key).unwrap_or(&0);
333        if old_count == 0 {
334            return Sc::zero();
335        }
336
337        let new_count = old_count - 1;
338        if new_count == 0 {
339            self.counts.remove(&key);
340            self.group_count -= 1;
341        } else {
342            self.counts.insert(key, new_count);
343        }
344        self.total_count -= 1;
345        self.sum_squared += new_count * new_count - old_count * old_count;
346
347        let new_score = self.compute_score();
348        new_score - old_score
349    }
350
351    fn reset(&mut self) {
352        self.counts.clear();
353        self.entity_keys.clear();
354        self.group_count = 0;
355        self.total_count = 0;
356        self.sum_squared = 0;
357    }
358
359    fn name(&self) -> &str {
360        &self.constraint_ref.name
361    }
362
363    fn is_hard(&self) -> bool {
364        self.is_hard
365    }
366
367    fn constraint_ref(&self) -> ConstraintRef {
368        self.constraint_ref.clone()
369    }
370}
371
372impl<S, A, K, E, F, KF, Sc> std::fmt::Debug for BalanceConstraint<S, A, K, E, F, KF, Sc>
373where
374    Sc: Score,
375{
376    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
377        f.debug_struct("BalanceConstraint")
378            .field("name", &self.constraint_ref.name)
379            .field("impact_type", &self.impact_type)
380            .field("groups", &self.counts.len())
381            .finish()
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::filter::TrueFilter;
    use solverforge_core::score::SimpleScore;

    #[derive(Clone)]
    struct Shift {
        employee_id: Option<usize>,
    }

    #[derive(Clone)]
    struct Solution {
        shifts: Vec<Shift>,
    }

    /// Builds a solution with one shift per entry, assigned to the given employee.
    fn solution(employee_ids: &[Option<usize>]) -> Solution {
        Solution {
            shifts: employee_ids
                .iter()
                .map(|&employee_id| Shift { employee_id })
                .collect(),
        }
    }

    /// Balance constraint over shifts keyed by employee, 1000 per unit std_dev.
    fn balance(impact_type: ImpactType) -> impl IncrementalConstraint<Solution, SimpleScore> {
        BalanceConstraint::new(
            ConstraintRef::new("", "Balance"),
            impact_type,
            |s: &Solution| &s.shifts,
            TrueFilter,
            |shift: &Shift| shift.employee_id,
            SimpleScore::of(1000), // 1000 per unit std_dev
            false,
        )
    }

    #[test]
    fn test_balance_evaluate_equal_distribution() {
        let constraint = balance(ImpactType::Penalty);

        // Equal distribution: 2 shifts each
        let sol = solution(&[Some(0), Some(0), Some(1), Some(1)]);

        // Mean = 2, all counts = 2, variance = 0, std_dev = 0
        assert_eq!(constraint.evaluate(&sol), SimpleScore::of(0));
    }

    #[test]
    fn test_balance_evaluate_unequal_distribution() {
        let constraint = balance(ImpactType::Penalty);

        // Unequal: employee 0 has 3, employee 1 has 1
        let sol = solution(&[Some(0), Some(0), Some(0), Some(1)]);

        // Mean = 2, variance = ((3-2)² + (1-2)²) / 2 = 1, std_dev = 1.0
        // base_score * std_dev = 1000 * 1.0 = 1000, negated = -1000
        assert_eq!(constraint.evaluate(&sol), SimpleScore::of(-1000));
    }

    #[test]
    fn test_balance_filters_unassigned() {
        let constraint = balance(ImpactType::Penalty);

        // Employee 0: 2, Employee 1: 2, plus unassigned (ignored)
        let sol = solution(&[Some(0), Some(0), Some(1), Some(1), None]);

        // Balanced, std_dev = 0
        assert_eq!(constraint.evaluate(&sol), SimpleScore::of(0));
    }

    #[test]
    fn test_balance_incremental() {
        let mut constraint = balance(ImpactType::Penalty);
        let sol = solution(&[Some(0), Some(0), Some(1), Some(1)]);

        // Initialize with balanced state (std_dev = 0)
        let initial = constraint.initialize(&sol);
        assert_eq!(initial, SimpleScore::of(0));

        // Retract one shift from employee 0
        let delta = constraint.on_retract(&sol, 0);
        // Now: employee 0 has 1, employee 1 has 2
        // Mean = 1.5, variance = (0.25 + 0.25) / 2 = 0.25, std_dev = 0.5
        // Score = -1000 * 0.5 = -500
        assert_eq!(delta, SimpleScore::of(-500));

        // Insert it back
        let delta = constraint.on_insert(&sol, 0);
        // Back to balanced: delta = +500
        assert_eq!(delta, SimpleScore::of(500));
    }

    #[test]
    fn test_balance_incremental_matches_evaluate() {
        let mut constraint = balance(ImpactType::Penalty);

        // Employee 0: 2, employee 1: 1, employee 2: 3, plus one unassigned shift
        let full = solution(&[Some(0), Some(0), Some(1), None, Some(2), Some(2), Some(2)]);
        // Same solution with shift 4 unassigned
        let without = solution(&[Some(0), Some(0), Some(1), None, None, Some(2), Some(2)]);

        let initial = constraint.initialize(&full);
        assert_eq!(initial, constraint.evaluate(&full));

        // The unassigned shift was never tracked, so retracting it is a no-op
        assert_eq!(constraint.on_retract(&full, 3), SimpleScore::of(0));

        // Retracting shift 4 moves the score exactly as a full re-evaluation would
        let expected = constraint.evaluate(&without) - constraint.evaluate(&full);
        assert_eq!(constraint.on_retract(&full, 4), expected);

        // Re-inserting it restores the original score
        let expected = constraint.evaluate(&full) - constraint.evaluate(&without);
        assert_eq!(constraint.on_insert(&full, 4), expected);
    }

    #[test]
    fn test_balance_reset_clears_tracking() {
        let mut constraint = balance(ImpactType::Penalty);
        let sol = solution(&[Some(0), Some(0), Some(0), Some(1)]);

        assert_eq!(constraint.initialize(&sol), SimpleScore::of(-1000));
        constraint.reset();

        // Nothing is tracked after a reset, so retracting is a no-op
        assert_eq!(constraint.on_retract(&sol, 0), SimpleScore::of(0));
    }

    #[test]
    fn test_balance_match_count() {
        let constraint = balance(ImpactType::Penalty);

        // Counts 3 and 1 both deviate from the mean (2) by 1 > 0.5
        assert_eq!(
            constraint.match_count(&solution(&[Some(0), Some(0), Some(0), Some(1)])),
            2
        );
        // Balanced: no group deviates
        assert_eq!(
            constraint.match_count(&solution(&[Some(0), Some(0), Some(1), Some(1)])),
            0
        );
        // No groups at all
        assert_eq!(constraint.match_count(&solution(&[None])), 0);
    }

    #[test]
    fn test_balance_empty_solution() {
        let constraint = balance(ImpactType::Penalty);
        assert_eq!(constraint.evaluate(&solution(&[])), SimpleScore::of(0));
    }

    #[test]
    fn test_balance_single_employee() {
        let constraint = balance(ImpactType::Penalty);

        // Single employee with 5 shifts - no variance possible
        let sol = solution(&[Some(0); 5]);

        // With only one group, variance = 0
        assert_eq!(constraint.evaluate(&sol), SimpleScore::of(0));
    }

    #[test]
    fn test_balance_reward() {
        let constraint = balance(ImpactType::Reward);
        let sol = solution(&[Some(0), Some(0), Some(0), Some(1)]);

        // std_dev = 1.0, reward = +1000
        assert_eq!(constraint.evaluate(&sol), SimpleScore::of(1000));
    }
}