Skip to main content

solverforge_scoring/stream/
balance_stream.rs

1// Zero-erasure balance constraint stream for load distribution patterns.
2//
3// A `BalanceConstraintStream` is created from `UniConstraintStream::balance()`
4// and provides fluent finalization into a `BalanceConstraint`.
5
6use std::hash::Hash;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use super::filter::UniFilter;
13use crate::constraint::balance::BalanceConstraint;
14
15// Zero-erasure stream for building balance constraints.
16//
17// Created by `UniConstraintStream::balance()`. Provides `penalize()` and
18// `reward()` methods to finalize the constraint.
19//
20// # Type Parameters
21//
22// - `S` - Solution type
23// - `A` - Entity type
24// - `K` - Group key type
25// - `E` - Extractor function for entities
26// - `F` - Filter type
27// - `KF` - Key function (returns Option<K> to skip unassigned entities)
28// - `Sc` - Score type
29//
30// # Example
31//
32// ```
33// use solverforge_scoring::stream::ConstraintFactory;
34// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
35// use solverforge_core::score::SimpleScore;
36//
37// #[derive(Clone)]
38// struct Shift { employee_id: Option<usize> }
39//
40// #[derive(Clone)]
41// struct Solution { shifts: Vec<Shift> }
42//
43// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
44//     .for_each(|s: &Solution| &s.shifts)
45//     .balance(|shift: &Shift| shift.employee_id)
46//     .penalize(SimpleScore::of(1000))
47//     .as_constraint("Balance workload");
48//
49// let solution = Solution {
50//     shifts: vec![
51//         Shift { employee_id: Some(0) },
52//         Shift { employee_id: Some(0) },
53//         Shift { employee_id: Some(0) },
54//         Shift { employee_id: Some(1) },
55//     ],
56// };
57//
58// // std_dev = 1.0, penalty = -1000
59// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1000));
60// ```
61pub struct BalanceConstraintStream<S, A, K, E, F, KF, Sc>
62where
63    Sc: Score,
64{
65    extractor: E,
66    filter: F,
67    key_fn: KF,
68    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
69}
70
71impl<S, A, K, E, F, KF, Sc> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
72where
73    S: Send + Sync + 'static,
74    A: Clone + Send + Sync + 'static,
75    K: Clone + Eq + Hash + Send + Sync + 'static,
76    E: Fn(&S) -> &[A] + Send + Sync,
77    F: UniFilter<S, A>,
78    KF: Fn(&A) -> Option<K> + Send + Sync,
79    Sc: Score + 'static,
80{
81    // Creates a new balance constraint stream.
82    pub(crate) fn new(extractor: E, filter: F, key_fn: KF) -> Self {
83        Self {
84            extractor,
85            filter,
86            key_fn,
87            _phantom: PhantomData,
88        }
89    }
90
91    // Penalizes imbalanced distribution with the given base score per unit std_dev.
92    //
93    // The final score is `base_score.multiply(std_dev)`, negated for penalty.
94    pub fn penalize(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
95        let is_hard = base_score
96            .to_level_numbers()
97            .first()
98            .map(|&h| h != 0)
99            .unwrap_or(false);
100        BalanceConstraintBuilder {
101            extractor: self.extractor,
102            filter: self.filter,
103            key_fn: self.key_fn,
104            impact_type: ImpactType::Penalty,
105            base_score,
106            is_hard,
107            _phantom: PhantomData,
108        }
109    }
110
111    // Rewards imbalanced distribution with the given base score per unit std_dev.
112    //
113    // The final score is `base_score.multiply(std_dev)`.
114    pub fn reward(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
115        let is_hard = base_score
116            .to_level_numbers()
117            .first()
118            .map(|&h| h != 0)
119            .unwrap_or(false);
120        BalanceConstraintBuilder {
121            extractor: self.extractor,
122            filter: self.filter,
123            key_fn: self.key_fn,
124            impact_type: ImpactType::Reward,
125            base_score,
126            is_hard,
127            _phantom: PhantomData,
128        }
129    }
130}
131
132impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
133    for BalanceConstraintStream<S, A, K, E, F, KF, Sc>
134{
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        f.debug_struct("BalanceConstraintStream").finish()
137    }
138}
139
140// Zero-erasure builder for finalizing a balance constraint.
141pub struct BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
142where
143    Sc: Score,
144{
145    extractor: E,
146    filter: F,
147    key_fn: KF,
148    impact_type: ImpactType,
149    base_score: Sc,
150    is_hard: bool,
151    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K)>,
152}
153
154impl<S, A, K, E, F, KF, Sc> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
155where
156    S: Send + Sync + 'static,
157    A: Clone + Send + Sync + 'static,
158    K: Clone + Eq + Hash + Send + Sync + 'static,
159    E: Fn(&S) -> &[A] + Send + Sync,
160    F: UniFilter<S, A>,
161    KF: Fn(&A) -> Option<K> + Send + Sync,
162    Sc: Score + 'static,
163{
164    // Finalizes the builder into a zero-erasure `BalanceConstraint`.
165    pub fn as_constraint(self, name: &str) -> BalanceConstraint<S, A, K, E, F, KF, Sc> {
166        BalanceConstraint::new(
167            ConstraintRef::new("", name),
168            self.impact_type,
169            self.extractor,
170            self.filter,
171            self.key_fn,
172            self.base_score,
173            self.is_hard,
174        )
175    }
176}
177
178impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
179    for BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
180{
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        f.debug_struct("BalanceConstraintBuilder")
183            .field("impact_type", &self.impact_type)
184            .finish()
185    }
186}