Skip to main content

solverforge_scoring/constraint/
balance.rs

1// Zero-erasure balance constraint for load distribution scoring.
2//
3// Provides a constraint that penalizes uneven distribution across groups
4// using standard deviation. Unlike grouped constraints which score per-group,
5// the balance constraint computes a GLOBAL statistic across all groups.
6//
7// All type information is preserved at compile time - no Arc, no dyn.
8
9use std::collections::HashMap;
10use std::hash::Hash;
11use std::marker::PhantomData;
12
13use solverforge_core::score::Score;
14use solverforge_core::{ConstraintRef, ImpactType};
15
16use crate::api::constraint_set::IncrementalConstraint;
17use crate::stream::filter::UniFilter;
18
19// Zero-erasure balance constraint that penalizes uneven load distribution.
20//
21// This constraint:
22// 1. Groups entities by key (e.g., employee_id)
23// 2. Counts how many entities belong to each group
24// 3. Computes population standard deviation across all group counts
25// 4. Multiplies the base score by std_dev to produce the final score
26//
27// The key difference from `GroupedUniConstraint` is that balance computes
28// a GLOBAL statistic, not per-group scores.
29//
30// # Type Parameters
31//
32// - `S` - Solution type
33// - `A` - Entity type
34// - `K` - Group key type
35// - `E` - Extractor function for entities
36// - `F` - Filter type
37// - `KF` - Key function
38// - `Sc` - Score type
39//
40// # Example
41//
42// ```
43// use solverforge_scoring::constraint::balance::BalanceConstraint;
44// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
45// use solverforge_scoring::stream::filter::TrueFilter;
46// use solverforge_core::{ConstraintRef, ImpactType, HardSoftDecimalScore};
47//
48// #[derive(Clone)]
49// struct Shift { employee_id: Option<usize> }
50//
51// #[derive(Clone)]
52// struct Solution { shifts: Vec<Shift> }
53//
54// // Base score of 1000 soft per unit of std_dev
55// let constraint = BalanceConstraint::new(
56//     ConstraintRef::new("", "Balance workload"),
57//     ImpactType::Penalty,
58//     |s: &Solution| &s.shifts,
59//     TrueFilter,
60//     |shift: &Shift| shift.employee_id,
61//     HardSoftDecimalScore::of_soft(1),  // 1 soft per unit std_dev (scaled internally)
62//     false,
63// );
64//
65// let solution = Solution {
66//     shifts: vec![
67//         Shift { employee_id: Some(0) },
68//         Shift { employee_id: Some(0) },
69//         Shift { employee_id: Some(0) },
70//         Shift { employee_id: Some(1) },
71//         Shift { employee_id: None },  // Unassigned, filtered out
72//     ],
73// };
74//
75// // Employee 0: 3 shifts, Employee 1: 1 shift
76// // Mean = 2, Variance = ((3-2)² + (1-2)²) / 2 = 1
77// // StdDev = 1.0, Score = -1 soft (base_score * std_dev, negated for penalty)
78// let score = constraint.evaluate(&solution);
79// assert_eq!(score, HardSoftDecimalScore::of_soft(-1));
80// ```
81pub struct BalanceConstraint<S, A, K, E, F, KF, Sc>
82where
83    Sc: Score,
84{
85    constraint_ref: ConstraintRef,
86    impact_type: ImpactType,
87    extractor: E,
88    filter: F,
89    key_fn: KF,
90    // Base score representing 1 unit of standard deviation
91    base_score: Sc,
92    is_hard: bool,
93    // Group key → count of entities in that group
94    counts: HashMap<K, i64>,
95    // Entity index → group key (for tracking assignments)
96    entity_keys: HashMap<usize, K>,
97    // Cached statistics for incremental updates
98    // Number of groups (employees with at least one shift)
99    group_count: i64,
100    // Sum of all counts (total assignments)
101    total_count: i64,
102    // Sum of squared counts (for variance calculation)
103    sum_squared: i64,
104    _phantom: PhantomData<(fn() -> S, fn() -> A)>,
105}
106
107impl<S, A, K, E, F, KF, Sc> BalanceConstraint<S, A, K, E, F, KF, Sc>
108where
109    S: Send + Sync + 'static,
110    A: Clone + Send + Sync + 'static,
111    K: Clone + Eq + Hash + Send + Sync + 'static,
112    E: Fn(&S) -> &[A] + Send + Sync,
113    F: UniFilter<S, A>,
114    KF: Fn(&A) -> Option<K> + Send + Sync,
115    Sc: Score + 'static,
116{
117    // Creates a new zero-erasure balance constraint.
118    //
119    // # Arguments
120    //
121    // * `constraint_ref` - Identifier for this constraint
122    // * `impact_type` - Whether to penalize or reward
123    // * `extractor` - Function to get entity slice from solution
124    // * `filter` - Filter to select which entities to consider
125    // * `key_fn` - Function to extract group key (returns None to skip entity)
126    // * `base_score` - Score per unit of standard deviation
127    // * `is_hard` - Whether this is a hard constraint
128    pub fn new(
129        constraint_ref: ConstraintRef,
130        impact_type: ImpactType,
131        extractor: E,
132        filter: F,
133        key_fn: KF,
134        base_score: Sc,
135        is_hard: bool,
136    ) -> Self {
137        Self {
138            constraint_ref,
139            impact_type,
140            extractor,
141            filter,
142            key_fn,
143            base_score,
144            is_hard,
145            counts: HashMap::new(),
146            entity_keys: HashMap::new(),
147            group_count: 0,
148            total_count: 0,
149            sum_squared: 0,
150            _phantom: PhantomData,
151        }
152    }
153
154    // Computes standard deviation from cached statistics.
155    fn compute_std_dev(&self) -> f64 {
156        if self.group_count == 0 {
157            return 0.0;
158        }
159        let n = self.group_count as f64;
160        let mean = self.total_count as f64 / n;
161        let variance = (self.sum_squared as f64 / n) - (mean * mean);
162        if variance <= 0.0 {
163            return 0.0;
164        }
165        variance.sqrt()
166    }
167
168    // Computes the score from standard deviation.
169    fn compute_score(&self) -> Sc {
170        let std_dev = self.compute_std_dev();
171        let base = self.base_score.multiply(std_dev);
172        match self.impact_type {
173            ImpactType::Penalty => -base,
174            ImpactType::Reward => base,
175        }
176    }
177
178    // Computes std_dev from raw counts (for stateless evaluate).
179    fn compute_std_dev_from_counts(counts: &HashMap<K, i64>) -> f64 {
180        if counts.is_empty() {
181            return 0.0;
182        }
183        let n = counts.len() as f64;
184        let total: i64 = counts.values().sum();
185        let sum_sq: i64 = counts.values().map(|&c| c * c).sum();
186        let mean = total as f64 / n;
187        let variance = (sum_sq as f64 / n) - (mean * mean);
188        if variance > 0.0 {
189            variance.sqrt()
190        } else {
191            0.0
192        }
193    }
194}
195
196impl<S, A, K, E, F, KF, Sc> IncrementalConstraint<S, Sc>
197    for BalanceConstraint<S, A, K, E, F, KF, Sc>
198where
199    S: Send + Sync + 'static,
200    A: Clone + Send + Sync + 'static,
201    K: Clone + Eq + Hash + Send + Sync + 'static,
202    E: Fn(&S) -> &[A] + Send + Sync,
203    F: UniFilter<S, A>,
204    KF: Fn(&A) -> Option<K> + Send + Sync,
205    Sc: Score + 'static,
206{
207    fn evaluate(&self, solution: &S) -> Sc {
208        let entities = (self.extractor)(solution);
209
210        // Build counts from scratch
211        let mut counts: HashMap<K, i64> = HashMap::new();
212        for entity in entities {
213            if !self.filter.test(solution, entity) {
214                continue;
215            }
216            if let Some(key) = (self.key_fn)(entity) {
217                *counts.entry(key).or_insert(0) += 1;
218            }
219        }
220
221        if counts.is_empty() {
222            return Sc::zero();
223        }
224
225        let std_dev = Self::compute_std_dev_from_counts(&counts);
226        let base = self.base_score.multiply(std_dev);
227        match self.impact_type {
228            ImpactType::Penalty => -base,
229            ImpactType::Reward => base,
230        }
231    }
232
233    fn match_count(&self, solution: &S) -> usize {
234        let entities = (self.extractor)(solution);
235
236        // Count groups that deviate from mean
237        let mut counts: HashMap<K, i64> = HashMap::new();
238        for entity in entities {
239            if !self.filter.test(solution, entity) {
240                continue;
241            }
242            if let Some(key) = (self.key_fn)(entity) {
243                *counts.entry(key).or_insert(0) += 1;
244            }
245        }
246
247        if counts.is_empty() {
248            return 0;
249        }
250
251        let total: i64 = counts.values().sum();
252        let mean = total as f64 / counts.len() as f64;
253
254        // Count groups that deviate significantly from mean
255        counts
256            .values()
257            .filter(|&&c| (c as f64 - mean).abs() > 0.5)
258            .count()
259    }
260
261    fn initialize(&mut self, solution: &S) -> Sc {
262        self.reset();
263
264        let entities = (self.extractor)(solution);
265
266        for (idx, entity) in entities.iter().enumerate() {
267            if !self.filter.test(solution, entity) {
268                continue;
269            }
270            if let Some(key) = (self.key_fn)(entity) {
271                let old_count = *self.counts.get(&key).unwrap_or(&0);
272                let new_count = old_count + 1;
273                self.counts.insert(key.clone(), new_count);
274                self.entity_keys.insert(idx, key);
275
276                if old_count == 0 {
277                    self.group_count += 1;
278                }
279                self.total_count += 1;
280                self.sum_squared += new_count * new_count - old_count * old_count;
281            }
282        }
283
284        self.compute_score()
285    }
286
287    fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
288        let entities = (self.extractor)(solution);
289        if entity_index >= entities.len() {
290            return Sc::zero();
291        }
292
293        let entity = &entities[entity_index];
294        if !self.filter.test(solution, entity) {
295            return Sc::zero();
296        }
297
298        let Some(key) = (self.key_fn)(entity) else {
299            return Sc::zero();
300        };
301
302        let old_score = self.compute_score();
303
304        let old_count = *self.counts.get(&key).unwrap_or(&0);
305        let new_count = old_count + 1;
306        self.counts.insert(key.clone(), new_count);
307        self.entity_keys.insert(entity_index, key);
308
309        if old_count == 0 {
310            self.group_count += 1;
311        }
312        self.total_count += 1;
313        self.sum_squared += new_count * new_count - old_count * old_count;
314
315        let new_score = self.compute_score();
316        new_score - old_score
317    }
318
319    fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
320        let entities = (self.extractor)(solution);
321        if entity_index >= entities.len() {
322            return Sc::zero();
323        }
324
325        // Check if this entity was tracked
326        let Some(key) = self.entity_keys.remove(&entity_index) else {
327            return Sc::zero();
328        };
329
330        let old_score = self.compute_score();
331
332        let old_count = *self.counts.get(&key).unwrap_or(&0);
333        if old_count == 0 {
334            return Sc::zero();
335        }
336
337        let new_count = old_count - 1;
338        if new_count == 0 {
339            self.counts.remove(&key);
340            self.group_count -= 1;
341        } else {
342            self.counts.insert(key, new_count);
343        }
344        self.total_count -= 1;
345        self.sum_squared += new_count * new_count - old_count * old_count;
346
347        let new_score = self.compute_score();
348        new_score - old_score
349    }
350
351    fn reset(&mut self) {
352        self.counts.clear();
353        self.entity_keys.clear();
354        self.group_count = 0;
355        self.total_count = 0;
356        self.sum_squared = 0;
357    }
358
359    fn name(&self) -> &str {
360        &self.constraint_ref.name
361    }
362
363    fn is_hard(&self) -> bool {
364        self.is_hard
365    }
366
367    fn constraint_ref(&self) -> ConstraintRef {
368        self.constraint_ref.clone()
369    }
370}
371
372impl<S, A, K, E, F, KF, Sc> std::fmt::Debug for BalanceConstraint<S, A, K, E, F, KF, Sc>
373where
374    Sc: Score,
375{
376    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
377        f.debug_struct("BalanceConstraint")
378            .field("name", &self.constraint_ref.name)
379            .field("impact_type", &self.impact_type)
380            .field("groups", &self.counts.len())
381            .finish()
382    }
383}