Skip to main content

solverforge_scoring/stream/
balance_stream.rs

/* Zero-erasure balance constraint stream for load distribution patterns.

A `BalanceConstraintStream` is created from `UniConstraintStream::balance()`
and provides fluent finalization into a `BalanceConstraint`.
*/
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collection_extract::CollectionExtract;
14use super::filter::UniFilter;
15use crate::constraint::balance::BalanceConstraint;
16
17/* Zero-erasure stream for building balance constraints.
18
19Created by `UniConstraintStream::balance()`. Provides `penalize()` and
20`reward()` methods to finalize the constraint.
21
22# Type Parameters
23
24- `S` - Solution type
25- `A` - Entity type
26- `K` - Group key type
27- `E` - Extractor function for entities
28- `F` - Filter type
29- `KF` - Key function (returns Option<K> to skip unassigned entities)
30- `Sc` - Score type
31
32# Example
33
34```
35use solverforge_scoring::stream::ConstraintFactory;
36use solverforge_scoring::api::constraint_set::IncrementalConstraint;
37use solverforge_core::score::SoftScore;
38
39#[derive(Clone)]
40struct Shift { employee_id: Option<usize> }
41
42#[derive(Clone)]
43struct Solution { shifts: Vec<Shift> }
44
45let constraint = ConstraintFactory::<Solution, SoftScore>::new()
46.for_each(|s: &Solution| &s.shifts)
47.balance(|shift: &Shift| shift.employee_id)
48.penalize(SoftScore::of(1000))
49.named("Balance workload");
50
51let solution = Solution {
52shifts: vec![
53Shift { employee_id: Some(0) },
54Shift { employee_id: Some(0) },
55Shift { employee_id: Some(0) },
56Shift { employee_id: Some(1) },
57],
58};
59
60// std_dev = 1.0, penalty = -1000
61assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1000));
62```
63*/
64pub struct BalanceConstraintStream<S, A, K, E, F, KF, Sc>
65where
66    Sc: Score,
67{
68    extractor: E,
69    filter: F,
70    key_fn: KF,
71    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
72}
73
74impl<S, A, K, E, F, KF, Sc> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
75where
76    S: Send + Sync + 'static,
77    A: Clone + Send + Sync + 'static,
78    K: Clone + Eq + Hash + Send + Sync + 'static,
79    E: CollectionExtract<S, Item = A>,
80    F: UniFilter<S, A>,
81    KF: Fn(&A) -> Option<K> + Send + Sync,
82    Sc: Score + 'static,
83{
84    // Creates a new balance constraint stream.
85    pub(crate) fn new(extractor: E, filter: F, key_fn: KF) -> Self {
86        Self {
87            extractor,
88            filter,
89            key_fn,
90            _phantom: PhantomData,
91        }
92    }
93
94    /* Penalizes imbalanced distribution with the given base score per unit std_dev.
95
96    The final score is `base_score.multiply(std_dev)`, negated for penalty.
97    */
98    pub fn penalize(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
99        let is_hard = base_score
100            .to_level_numbers()
101            .first()
102            .map(|&h| h != 0)
103            .unwrap_or(false);
104        BalanceConstraintBuilder {
105            extractor: self.extractor,
106            filter: self.filter,
107            key_fn: self.key_fn,
108            impact_type: ImpactType::Penalty,
109            base_score,
110            is_hard,
111            _phantom: PhantomData,
112        }
113    }
114
115    // Penalizes imbalanced distribution with one hard score unit per unit std_dev.
116    pub fn penalize_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
117    where
118        Sc: Copy,
119    {
120        self.penalize(Sc::one_hard())
121    }
122
123    // Penalizes imbalanced distribution with one soft score unit per unit std_dev.
124    pub fn penalize_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
125    where
126        Sc: Copy,
127    {
128        self.penalize(Sc::one_soft())
129    }
130
131    // Rewards imbalanced distribution with one hard score unit per unit std_dev.
132    pub fn reward_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
133    where
134        Sc: Copy,
135    {
136        self.reward(Sc::one_hard())
137    }
138
139    // Rewards imbalanced distribution with one soft score unit per unit std_dev.
140    pub fn reward_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
141    where
142        Sc: Copy,
143    {
144        self.reward(Sc::one_soft())
145    }
146
147    /* Rewards imbalanced distribution with the given base score per unit std_dev.
148
149    The final score is `base_score.multiply(std_dev)`.
150    */
151    pub fn reward(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
152        let is_hard = base_score
153            .to_level_numbers()
154            .first()
155            .map(|&h| h != 0)
156            .unwrap_or(false);
157        BalanceConstraintBuilder {
158            extractor: self.extractor,
159            filter: self.filter,
160            key_fn: self.key_fn,
161            impact_type: ImpactType::Reward,
162            base_score,
163            is_hard,
164            _phantom: PhantomData,
165        }
166    }
167}
168
169impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
170    for BalanceConstraintStream<S, A, K, E, F, KF, Sc>
171{
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        f.debug_struct("BalanceConstraintStream").finish()
174    }
175}
176
177// Zero-erasure builder for finalizing a balance constraint.
178pub struct BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
179where
180    Sc: Score,
181{
182    extractor: E,
183    filter: F,
184    key_fn: KF,
185    impact_type: ImpactType,
186    base_score: Sc,
187    is_hard: bool,
188    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K)>,
189}
190
191impl<S, A, K, E, F, KF, Sc> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
192where
193    S: Send + Sync + 'static,
194    A: Clone + Send + Sync + 'static,
195    K: Clone + Eq + Hash + Send + Sync + 'static,
196    E: CollectionExtract<S, Item = A>,
197    F: UniFilter<S, A>,
198    KF: Fn(&A) -> Option<K> + Send + Sync,
199    Sc: Score + 'static,
200{
201    pub fn named(self, name: &str) -> BalanceConstraint<S, A, K, E, F, KF, Sc> {
202        BalanceConstraint::new(
203            ConstraintRef::new("", name),
204            self.impact_type,
205            self.extractor,
206            self.filter,
207            self.key_fn,
208            self.base_score,
209            self.is_hard,
210        )
211    }
212
213    // Finalizes the builder into a zero-erasure `BalanceConstraint`.
214}
215
216impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
217    for BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
218{
219    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220        f.debug_struct("BalanceConstraintBuilder")
221            .field("impact_type", &self.impact_type)
222            .finish()
223    }
224}