Skip to main content

solverforge_scoring/constraint/
balance.rs

1/* Zero-erasure balance constraint for load distribution scoring.
2
3Provides a constraint that penalizes uneven distribution across groups
4using standard deviation. Unlike grouped constraints which score per-group,
5the balance constraint computes a GLOBAL statistic across all groups.
6
7All type information is preserved at compile time - no Arc, no dyn.
8*/
9
10use std::collections::HashMap;
11use std::hash::Hash;
12use std::marker::PhantomData;
13
14use solverforge_core::score::Score;
15use solverforge_core::{ConstraintRef, ImpactType};
16
17use crate::api::constraint_set::IncrementalConstraint;
18use crate::stream::filter::UniFilter;
19
20/* Zero-erasure balance constraint that penalizes uneven load distribution.
21
22This constraint:
1. Groups the entities that pass the filter by key (e.g., employee_id); entities whose key is `None` are skipped
2. Counts how many entities belong to each group
3. Computes the population standard deviation across all group counts
4. Multiplies the base score by std_dev to produce the final score (negated for penalties)
27
28The key difference from `GroupedUniConstraint` is that balance computes
29a GLOBAL statistic, not per-group scores.
30
31# Type Parameters
32
33- `S` - Solution type
34- `A` - Entity type
35- `K` - Group key type
36- `E` - Extractor function for entities
37- `F` - Filter type
38- `KF` - Key function
39- `Sc` - Score type
40
41# Example
42
43```
44use solverforge_scoring::constraint::balance::BalanceConstraint;
45use solverforge_scoring::api::constraint_set::IncrementalConstraint;
46use solverforge_scoring::stream::filter::TrueFilter;
47use solverforge_core::{ConstraintRef, ImpactType, HardSoftDecimalScore};
48
49#[derive(Clone)]
50struct Shift { employee_id: Option<usize> }
51
52#[derive(Clone)]
53struct Solution { shifts: Vec<Shift> }
54
// Base score of 1 soft per unit of std_dev
56let constraint = BalanceConstraint::new(
57ConstraintRef::new("", "Balance workload"),
58ImpactType::Penalty,
59|s: &Solution| &s.shifts,
60TrueFilter,
61|shift: &Shift| shift.employee_id,
62HardSoftDecimalScore::of_soft(1),  // 1 soft per unit std_dev (scaled internally)
63false,
64);
65
66let solution = Solution {
67shifts: vec![
68Shift { employee_id: Some(0) },
69Shift { employee_id: Some(0) },
70Shift { employee_id: Some(0) },
71Shift { employee_id: Some(1) },
72Shift { employee_id: None },  // Unassigned, filtered out
73],
74};
75
76// Employee 0: 3 shifts, Employee 1: 1 shift
77// Mean = 2, Variance = ((3-2)² + (1-2)²) / 2 = 1
78// StdDev = 1.0, Score = -1 soft (base_score * std_dev, negated for penalty)
79let score = constraint.evaluate(&solution);
80assert_eq!(score, HardSoftDecimalScore::of_soft(-1));
81```
82*/
83pub struct BalanceConstraint<S, A, K, E, F, KF, Sc>
84where
85    Sc: Score,
86{
87    constraint_ref: ConstraintRef,
88    impact_type: ImpactType,
89    extractor: E,
90    filter: F,
91    key_fn: KF,
92    // Base score representing 1 unit of standard deviation
93    base_score: Sc,
94    is_hard: bool,
95    // Group key → count of entities in that group
96    counts: HashMap<K, i64>,
97    // Entity index → group key (for tracking assignments)
98    entity_keys: HashMap<usize, K>,
99    // Cached statistics for incremental updates
100    // Number of groups (employees with at least one shift)
101    group_count: i64,
102    // Sum of all counts (total assignments)
103    total_count: i64,
104    // Sum of squared counts (for variance calculation)
105    sum_squared: i64,
106    _phantom: PhantomData<(fn() -> S, fn() -> A)>,
107}
108
109impl<S, A, K, E, F, KF, Sc> BalanceConstraint<S, A, K, E, F, KF, Sc>
110where
111    S: Send + Sync + 'static,
112    A: Clone + Send + Sync + 'static,
113    K: Clone + Eq + Hash + Send + Sync + 'static,
114    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
115    F: UniFilter<S, A>,
116    KF: Fn(&A) -> Option<K> + Send + Sync,
117    Sc: Score + 'static,
118{
119    /* Creates a new zero-erasure balance constraint.
120
121    # Arguments
122
123    * `constraint_ref` - Identifier for this constraint
124    * `impact_type` - Whether to penalize or reward
125    * `extractor` - Function to get entity slice from solution
126    * `filter` - Filter to select which entities to consider
127    * `key_fn` - Function to extract group key (returns None to skip entity)
128    * `base_score` - Score per unit of standard deviation
129    * `is_hard` - Whether this is a hard constraint
130    */
131    pub fn new(
132        constraint_ref: ConstraintRef,
133        impact_type: ImpactType,
134        extractor: E,
135        filter: F,
136        key_fn: KF,
137        base_score: Sc,
138        is_hard: bool,
139    ) -> Self {
140        Self {
141            constraint_ref,
142            impact_type,
143            extractor,
144            filter,
145            key_fn,
146            base_score,
147            is_hard,
148            counts: HashMap::new(),
149            entity_keys: HashMap::new(),
150            group_count: 0,
151            total_count: 0,
152            sum_squared: 0,
153            _phantom: PhantomData,
154        }
155    }
156
157    // Computes standard deviation from cached statistics.
158    fn compute_std_dev(&self) -> f64 {
159        if self.group_count == 0 {
160            return 0.0;
161        }
162        let n = self.group_count as f64;
163        let mean = self.total_count as f64 / n;
164        let variance = (self.sum_squared as f64 / n) - (mean * mean);
165        if variance <= 0.0 {
166            return 0.0;
167        }
168        variance.sqrt()
169    }
170
171    // Computes the score from standard deviation.
172    fn compute_score(&self) -> Sc {
173        let std_dev = self.compute_std_dev();
174        let base = self.base_score.multiply(std_dev);
175        match self.impact_type {
176            ImpactType::Penalty => -base,
177            ImpactType::Reward => base,
178        }
179    }
180
181    // Computes std_dev from raw counts (for stateless evaluate).
182    fn compute_std_dev_from_counts(counts: &HashMap<K, i64>) -> f64 {
183        if counts.is_empty() {
184            return 0.0;
185        }
186        let n = counts.len() as f64;
187        let total: i64 = counts.values().sum();
188        let sum_sq: i64 = counts.values().map(|&c| c * c).sum();
189        let mean = total as f64 / n;
190        let variance = (sum_sq as f64 / n) - (mean * mean);
191        if variance > 0.0 {
192            variance.sqrt()
193        } else {
194            0.0
195        }
196    }
197}
198
199impl<S, A, K, E, F, KF, Sc> IncrementalConstraint<S, Sc>
200    for BalanceConstraint<S, A, K, E, F, KF, Sc>
201where
202    S: Send + Sync + 'static,
203    A: Clone + Send + Sync + 'static,
204    K: Clone + Eq + Hash + Send + Sync + 'static,
205    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
206    F: UniFilter<S, A>,
207    KF: Fn(&A) -> Option<K> + Send + Sync,
208    Sc: Score + 'static,
209{
210    fn evaluate(&self, solution: &S) -> Sc {
211        let entities = self.extractor.extract(solution);
212
213        // Build counts from scratch
214        let mut counts: HashMap<K, i64> = HashMap::new();
215        for entity in entities {
216            if !self.filter.test(solution, entity) {
217                continue;
218            }
219            if let Some(key) = (self.key_fn)(entity) {
220                *counts.entry(key).or_insert(0) += 1;
221            }
222        }
223
224        if counts.is_empty() {
225            return Sc::zero();
226        }
227
228        let std_dev = Self::compute_std_dev_from_counts(&counts);
229        let base = self.base_score.multiply(std_dev);
230        match self.impact_type {
231            ImpactType::Penalty => -base,
232            ImpactType::Reward => base,
233        }
234    }
235
236    fn match_count(&self, solution: &S) -> usize {
237        let entities = self.extractor.extract(solution);
238
239        // Count groups that deviate from mean
240        let mut counts: HashMap<K, i64> = HashMap::new();
241        for entity in entities {
242            if !self.filter.test(solution, entity) {
243                continue;
244            }
245            if let Some(key) = (self.key_fn)(entity) {
246                *counts.entry(key).or_insert(0) += 1;
247            }
248        }
249
250        if counts.is_empty() {
251            return 0;
252        }
253
254        let total: i64 = counts.values().sum();
255        let mean = total as f64 / counts.len() as f64;
256
257        // Count groups that deviate significantly from mean
258        counts
259            .values()
260            .filter(|&&c| (c as f64 - mean).abs() > 0.5)
261            .count()
262    }
263
264    fn initialize(&mut self, solution: &S) -> Sc {
265        self.reset();
266
267        let entities = self.extractor.extract(solution);
268
269        for (idx, entity) in entities.iter().enumerate() {
270            if !self.filter.test(solution, entity) {
271                continue;
272            }
273            if let Some(key) = (self.key_fn)(entity) {
274                let old_count = *self.counts.get(&key).unwrap_or(&0);
275                let new_count = old_count + 1;
276                self.counts.insert(key.clone(), new_count);
277                self.entity_keys.insert(idx, key);
278
279                if old_count == 0 {
280                    self.group_count += 1;
281                }
282                self.total_count += 1;
283                self.sum_squared += new_count * new_count - old_count * old_count;
284            }
285        }
286
287        self.compute_score()
288    }
289
290    fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
291        let entities = self.extractor.extract(solution);
292        if entity_index >= entities.len() {
293            return Sc::zero();
294        }
295
296        let entity = &entities[entity_index];
297        if !self.filter.test(solution, entity) {
298            return Sc::zero();
299        }
300
301        let Some(key) = (self.key_fn)(entity) else {
302            return Sc::zero();
303        };
304
305        let old_score = self.compute_score();
306
307        let old_count = *self.counts.get(&key).unwrap_or(&0);
308        let new_count = old_count + 1;
309        self.counts.insert(key.clone(), new_count);
310        self.entity_keys.insert(entity_index, key);
311
312        if old_count == 0 {
313            self.group_count += 1;
314        }
315        self.total_count += 1;
316        self.sum_squared += new_count * new_count - old_count * old_count;
317
318        let new_score = self.compute_score();
319        new_score - old_score
320    }
321
322    fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
323        let entities = self.extractor.extract(solution);
324        if entity_index >= entities.len() {
325            return Sc::zero();
326        }
327
328        // Check if this entity was tracked
329        let Some(key) = self.entity_keys.remove(&entity_index) else {
330            return Sc::zero();
331        };
332
333        let old_score = self.compute_score();
334
335        let old_count = *self.counts.get(&key).unwrap_or(&0);
336        if old_count == 0 {
337            return Sc::zero();
338        }
339
340        let new_count = old_count - 1;
341        if new_count == 0 {
342            self.counts.remove(&key);
343            self.group_count -= 1;
344        } else {
345            self.counts.insert(key, new_count);
346        }
347        self.total_count -= 1;
348        self.sum_squared += new_count * new_count - old_count * old_count;
349
350        let new_score = self.compute_score();
351        new_score - old_score
352    }
353
354    fn reset(&mut self) {
355        self.counts.clear();
356        self.entity_keys.clear();
357        self.group_count = 0;
358        self.total_count = 0;
359        self.sum_squared = 0;
360    }
361
362    fn name(&self) -> &str {
363        &self.constraint_ref.name
364    }
365
366    fn is_hard(&self) -> bool {
367        self.is_hard
368    }
369
370    fn constraint_ref(&self) -> ConstraintRef {
371        self.constraint_ref.clone()
372    }
373}
374
375impl<S, A, K, E, F, KF, Sc> std::fmt::Debug for BalanceConstraint<S, A, K, E, F, KF, Sc>
376where
377    Sc: Score,
378{
379    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380        f.debug_struct("BalanceConstraint")
381            .field("name", &self.constraint_ref.name)
382            .field("impact_type", &self.impact_type)
383            .field("groups", &self.counts.len())
384            .finish()
385    }
386}