//! Zero-erasure balance constraint for load distribution scoring.
//!
//! Provides a constraint that penalizes uneven distribution across groups
//! using standard deviation. Unlike grouped constraints which score per-group,
//! the balance constraint computes a GLOBAL statistic across all groups.
//!
//! All type information is preserved at compile time - no `Arc`, no `dyn`.
9
use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;

use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};

use crate::api::constraint_set::IncrementalConstraint;
use crate::stream::collection_extract::ChangeSource;
use crate::stream::filter::UniFilter;
21/* Zero-erasure balance constraint that penalizes uneven load distribution.
22
23This constraint:
241. Groups entities by key (e.g., employee_id)
252. Counts how many entities belong to each group
263. Computes population standard deviation across all group counts
274. Multiplies the base score by std_dev to produce the final score
28
29The key difference from `GroupedUniConstraint` is that balance computes
30a GLOBAL statistic, not per-group scores.
31
32# Type Parameters
33
34- `S` - Solution type
35- `A` - Entity type
36- `K` - Group key type
37- `E` - Extractor function for entities
38- `F` - Filter type
39- `KF` - Key function
40- `Sc` - Score type
41
42# Example
43
44```
45use solverforge_scoring::constraint::balance::BalanceConstraint;
46use solverforge_scoring::api::constraint_set::IncrementalConstraint;
47use solverforge_scoring::stream::filter::TrueFilter;
48use solverforge_core::{ConstraintRef, ImpactType, HardSoftDecimalScore};
49
50#[derive(Clone)]
51struct Shift { employee_id: Option<usize> }
52
53#[derive(Clone)]
54struct Solution { shifts: Vec<Shift> }
55
56// Base score of 1000 soft per unit of std_dev
57let constraint = BalanceConstraint::new(
58ConstraintRef::new("", "Balance workload"),
59ImpactType::Penalty,
60|s: &Solution| &s.shifts,
61TrueFilter,
62|shift: &Shift| shift.employee_id,
63HardSoftDecimalScore::of_soft(1),  // 1 soft per unit std_dev (scaled internally)
64false,
65);
66
67let solution = Solution {
68shifts: vec![
69Shift { employee_id: Some(0) },
70Shift { employee_id: Some(0) },
71Shift { employee_id: Some(0) },
72Shift { employee_id: Some(1) },
73Shift { employee_id: None },  // Unassigned, filtered out
74],
75};
76
77// Employee 0: 3 shifts, Employee 1: 1 shift
78// Mean = 2, Variance = ((3-2)² + (1-2)²) / 2 = 1
79// StdDev = 1.0, Score = -1 soft (base_score * std_dev, negated for penalty)
80let score = constraint.evaluate(&solution);
81assert_eq!(score, HardSoftDecimalScore::of_soft(-1));
82```
83*/
84pub struct BalanceConstraint<S, A, K, E, F, KF, Sc>
85where
86    Sc: Score,
87{
88    constraint_ref: ConstraintRef,
89    impact_type: ImpactType,
90    extractor: E,
91    filter: F,
92    key_fn: KF,
93    change_source: ChangeSource,
94    // Base score representing 1 unit of standard deviation
95    base_score: Sc,
96    is_hard: bool,
97    // Group key → count of entities in that group
98    counts: HashMap<K, i64>,
99    // Entity index → group key (for tracking assignments)
100    entity_keys: HashMap<usize, K>,
101    // Cached statistics for incremental updates
102    // Number of groups (employees with at least one shift)
103    group_count: i64,
104    // Sum of all counts (total assignments)
105    total_count: i64,
106    // Sum of squared counts (for variance calculation)
107    sum_squared: i64,
108    _phantom: PhantomData<(fn() -> S, fn() -> A)>,
109}
110
111impl<S, A, K, E, F, KF, Sc> BalanceConstraint<S, A, K, E, F, KF, Sc>
112where
113    S: Send + Sync + 'static,
114    A: Clone + Send + Sync + 'static,
115    K: Clone + Eq + Hash + Send + Sync + 'static,
116    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
117    F: UniFilter<S, A>,
118    KF: Fn(&A) -> Option<K> + Send + Sync,
119    Sc: Score + 'static,
120{
121    /* Creates a new zero-erasure balance constraint.
122
123    # Arguments
124
125    * `constraint_ref` - Identifier for this constraint
126    * `impact_type` - Whether to penalize or reward
127    * `extractor` - Function to get entity slice from solution
128    * `filter` - Filter to select which entities to consider
129    * `key_fn` - Function to extract group key (returns None to skip entity)
130    * `base_score` - Score per unit of standard deviation
131    * `is_hard` - Whether this is a hard constraint
132    */
133    pub fn new(
134        constraint_ref: ConstraintRef,
135        impact_type: ImpactType,
136        extractor: E,
137        filter: F,
138        key_fn: KF,
139        base_score: Sc,
140        is_hard: bool,
141    ) -> Self {
142        let change_source = extractor.change_source();
143        Self {
144            constraint_ref,
145            impact_type,
146            extractor,
147            filter,
148            key_fn,
149            change_source,
150            base_score,
151            is_hard,
152            counts: HashMap::new(),
153            entity_keys: HashMap::new(),
154            group_count: 0,
155            total_count: 0,
156            sum_squared: 0,
157            _phantom: PhantomData,
158        }
159    }
160
161    // Computes standard deviation from cached statistics.
162    fn compute_std_dev(&self) -> f64 {
163        if self.group_count == 0 {
164            return 0.0;
165        }
166        let n = self.group_count as f64;
167        let mean = self.total_count as f64 / n;
168        let variance = (self.sum_squared as f64 / n) - (mean * mean);
169        if variance <= 0.0 {
170            return 0.0;
171        }
172        variance.sqrt()
173    }
174
175    // Computes the score from standard deviation.
176    fn compute_score(&self) -> Sc {
177        let std_dev = self.compute_std_dev();
178        let base = self.base_score.multiply(std_dev);
179        match self.impact_type {
180            ImpactType::Penalty => -base,
181            ImpactType::Reward => base,
182        }
183    }
184
185    // Computes std_dev from raw counts (for stateless evaluate).
186    fn compute_std_dev_from_counts(counts: &HashMap<K, i64>) -> f64 {
187        if counts.is_empty() {
188            return 0.0;
189        }
190        let n = counts.len() as f64;
191        let total: i64 = counts.values().sum();
192        let sum_sq: i64 = counts.values().map(|&c| c * c).sum();
193        let mean = total as f64 / n;
194        let variance = (sum_sq as f64 / n) - (mean * mean);
195        if variance > 0.0 {
196            variance.sqrt()
197        } else {
198            0.0
199        }
200    }
201}
202
203impl<S, A, K, E, F, KF, Sc> IncrementalConstraint<S, Sc>
204    for BalanceConstraint<S, A, K, E, F, KF, Sc>
205where
206    S: Send + Sync + 'static,
207    A: Clone + Send + Sync + 'static,
208    K: Clone + Eq + Hash + Send + Sync + 'static,
209    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
210    F: UniFilter<S, A>,
211    KF: Fn(&A) -> Option<K> + Send + Sync,
212    Sc: Score + 'static,
213{
214    fn evaluate(&self, solution: &S) -> Sc {
215        let entities = self.extractor.extract(solution);
216
217        // Build counts from scratch
218        let mut counts: HashMap<K, i64> = HashMap::new();
219        for entity in entities {
220            if !self.filter.test(solution, entity) {
221                continue;
222            }
223            if let Some(key) = (self.key_fn)(entity) {
224                *counts.entry(key).or_insert(0) += 1;
225            }
226        }
227
228        if counts.is_empty() {
229            return Sc::zero();
230        }
231
232        let std_dev = Self::compute_std_dev_from_counts(&counts);
233        let base = self.base_score.multiply(std_dev);
234        match self.impact_type {
235            ImpactType::Penalty => -base,
236            ImpactType::Reward => base,
237        }
238    }
239
240    fn match_count(&self, solution: &S) -> usize {
241        let entities = self.extractor.extract(solution);
242
243        // Count groups that deviate from mean
244        let mut counts: HashMap<K, i64> = HashMap::new();
245        for entity in entities {
246            if !self.filter.test(solution, entity) {
247                continue;
248            }
249            if let Some(key) = (self.key_fn)(entity) {
250                *counts.entry(key).or_insert(0) += 1;
251            }
252        }
253
254        if counts.is_empty() {
255            return 0;
256        }
257
258        let total: i64 = counts.values().sum();
259        let mean = total as f64 / counts.len() as f64;
260
261        // Count groups that deviate significantly from mean
262        counts
263            .values()
264            .filter(|&&c| (c as f64 - mean).abs() > 0.5)
265            .count()
266    }
267
268    fn initialize(&mut self, solution: &S) -> Sc {
269        self.reset();
270
271        let entities = self.extractor.extract(solution);
272
273        for (idx, entity) in entities.iter().enumerate() {
274            if !self.filter.test(solution, entity) {
275                continue;
276            }
277            if let Some(key) = (self.key_fn)(entity) {
278                let old_count = *self.counts.get(&key).unwrap_or(&0);
279                let new_count = old_count + 1;
280                self.counts.insert(key.clone(), new_count);
281                self.entity_keys.insert(idx, key);
282
283                if old_count == 0 {
284                    self.group_count += 1;
285                }
286                self.total_count += 1;
287                self.sum_squared += new_count * new_count - old_count * old_count;
288            }
289        }
290
291        self.compute_score()
292    }
293
294    fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
295        if !self
296            .change_source
297            .assert_localizes(_descriptor_index, &self.constraint_ref.name)
298        {
299            return Sc::zero();
300        }
301        let entities = self.extractor.extract(solution);
302        if entity_index >= entities.len() {
303            return Sc::zero();
304        }
305
306        let entity = &entities[entity_index];
307        if !self.filter.test(solution, entity) {
308            return Sc::zero();
309        }
310
311        let Some(key) = (self.key_fn)(entity) else {
312            return Sc::zero();
313        };
314
315        let old_score = self.compute_score();
316
317        let old_count = *self.counts.get(&key).unwrap_or(&0);
318        let new_count = old_count + 1;
319        self.counts.insert(key.clone(), new_count);
320        self.entity_keys.insert(entity_index, key);
321
322        if old_count == 0 {
323            self.group_count += 1;
324        }
325        self.total_count += 1;
326        self.sum_squared += new_count * new_count - old_count * old_count;
327
328        let new_score = self.compute_score();
329        new_score - old_score
330    }
331
332    fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
333        if !self
334            .change_source
335            .assert_localizes(_descriptor_index, &self.constraint_ref.name)
336        {
337            return Sc::zero();
338        }
339        let entities = self.extractor.extract(solution);
340        if entity_index >= entities.len() {
341            return Sc::zero();
342        }
343
344        // Check if this entity was tracked
345        let Some(key) = self.entity_keys.remove(&entity_index) else {
346            return Sc::zero();
347        };
348
349        let old_score = self.compute_score();
350
351        let old_count = *self.counts.get(&key).unwrap_or(&0);
352        if old_count == 0 {
353            return Sc::zero();
354        }
355
356        let new_count = old_count - 1;
357        if new_count == 0 {
358            self.counts.remove(&key);
359            self.group_count -= 1;
360        } else {
361            self.counts.insert(key, new_count);
362        }
363        self.total_count -= 1;
364        self.sum_squared += new_count * new_count - old_count * old_count;
365
366        let new_score = self.compute_score();
367        new_score - old_score
368    }
369
370    fn reset(&mut self) {
371        self.counts.clear();
372        self.entity_keys.clear();
373        self.group_count = 0;
374        self.total_count = 0;
375        self.sum_squared = 0;
376    }
377
378    fn name(&self) -> &str {
379        &self.constraint_ref.name
380    }
381
382    fn is_hard(&self) -> bool {
383        self.is_hard
384    }
385
386    fn constraint_ref(&self) -> &ConstraintRef {
387        &self.constraint_ref
388    }
389}
390
391impl<S, A, K, E, F, KF, Sc> std::fmt::Debug for BalanceConstraint<S, A, K, E, F, KF, Sc>
392where
393    Sc: Score,
394{
395    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396        f.debug_struct("BalanceConstraint")
397            .field("name", &self.constraint_ref.name)
398            .field("impact_type", &self.impact_type)
399            .field("groups", &self.counts.len())
400            .finish()
401    }
402}