Skip to main content

solverforge_scoring/stream/
balance_stream.rs

1/* Zero-erasure balance constraint stream for load distribution patterns.
2
3A `BalanceConstraintStream` is created from `UniConstraintStream::balance()`
4and provides fluent finalization into a `BalanceConstraint`.
5*/
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collection_extract::CollectionExtract;
14use super::filter::UniFilter;
15use super::weighting_support::fixed_weight_is_hard;
16use crate::constraint::balance::BalanceConstraint;
17
18/* Zero-erasure stream for building balance constraints.
19
20Created by `UniConstraintStream::balance()`. Provides `penalize()` and
21`reward()` methods to finalize the constraint.
22
23# Type Parameters
24
25- `S` - Solution type
26- `A` - Entity type
27- `K` - Group key type
28- `E` - Extractor function for entities
29- `F` - Filter type
30- `KF` - Key function (returns Option<K> to skip unassigned entities)
31- `Sc` - Score type
32
33# Example
34
35```
36use solverforge_scoring::stream::ConstraintFactory;
37use solverforge_scoring::api::constraint_set::IncrementalConstraint;
38use solverforge_core::score::SoftScore;
39
40#[derive(Clone)]
41struct Shift { employee_id: Option<usize> }
42
43#[derive(Clone)]
44struct Solution { shifts: Vec<Shift> }
45
46let constraint = ConstraintFactory::<Solution, SoftScore>::new()
47.for_each(|s: &Solution| &s.shifts)
48.balance(|shift: &Shift| shift.employee_id)
49.penalize(SoftScore::of(1000))
50.named("Balance workload");
51
52let solution = Solution {
53shifts: vec![
54Shift { employee_id: Some(0) },
55Shift { employee_id: Some(0) },
56Shift { employee_id: Some(0) },
57Shift { employee_id: Some(1) },
58],
59};
60
61// std_dev = 1.0, penalty = -1000
62assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1000));
63```
64*/
65pub struct BalanceConstraintStream<S, A, K, E, F, KF, Sc>
66where
67    Sc: Score,
68{
69    extractor: E,
70    filter: F,
71    key_fn: KF,
72    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
73}
74
75impl<S, A, K, E, F, KF, Sc> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
76where
77    S: Send + Sync + 'static,
78    A: Clone + Send + Sync + 'static,
79    K: Clone + Eq + Hash + Send + Sync + 'static,
80    E: CollectionExtract<S, Item = A>,
81    F: UniFilter<S, A>,
82    KF: Fn(&A) -> Option<K> + Send + Sync,
83    Sc: Score + 'static,
84{
85    fn into_builder(
86        self,
87        impact_type: ImpactType,
88        base_score: Sc,
89    ) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
90        BalanceConstraintBuilder {
91            extractor: self.extractor,
92            filter: self.filter,
93            key_fn: self.key_fn,
94            impact_type,
95            base_score,
96            is_hard: fixed_weight_is_hard(base_score),
97            _phantom: PhantomData,
98        }
99    }
100
101    // Creates a new balance constraint stream.
102    pub(crate) fn new(extractor: E, filter: F, key_fn: KF) -> Self {
103        Self {
104            extractor,
105            filter,
106            key_fn,
107            _phantom: PhantomData,
108        }
109    }
110
111    /* Penalizes imbalanced distribution with the given base score per unit std_dev.
112
113    The final score is `base_score.multiply(std_dev)`, negated for penalty.
114    */
115    pub fn penalize(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
116        self.into_builder(ImpactType::Penalty, base_score)
117    }
118
119    // Penalizes imbalanced distribution with one hard score unit per unit std_dev.
120    pub fn penalize_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
121    where
122        Sc: Copy,
123    {
124        self.penalize(Sc::one_hard())
125    }
126
127    // Penalizes imbalanced distribution with one soft score unit per unit std_dev.
128    pub fn penalize_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
129    where
130        Sc: Copy,
131    {
132        self.penalize(Sc::one_soft())
133    }
134
135    // Rewards imbalanced distribution with one hard score unit per unit std_dev.
136    pub fn reward_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
137    where
138        Sc: Copy,
139    {
140        self.reward(Sc::one_hard())
141    }
142
143    // Rewards imbalanced distribution with one soft score unit per unit std_dev.
144    pub fn reward_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
145    where
146        Sc: Copy,
147    {
148        self.reward(Sc::one_soft())
149    }
150
151    /* Rewards imbalanced distribution with the given base score per unit std_dev.
152
153    The final score is `base_score.multiply(std_dev)`.
154    */
155    pub fn reward(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
156        self.into_builder(ImpactType::Reward, base_score)
157    }
158}
159
160impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
161    for BalanceConstraintStream<S, A, K, E, F, KF, Sc>
162{
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        f.debug_struct("BalanceConstraintStream").finish()
165    }
166}
167
168// Zero-erasure builder for finalizing a balance constraint.
169pub struct BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
170where
171    Sc: Score,
172{
173    extractor: E,
174    filter: F,
175    key_fn: KF,
176    impact_type: ImpactType,
177    base_score: Sc,
178    is_hard: bool,
179    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K)>,
180}
181
182impl<S, A, K, E, F, KF, Sc> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
183where
184    S: Send + Sync + 'static,
185    A: Clone + Send + Sync + 'static,
186    K: Clone + Eq + Hash + Send + Sync + 'static,
187    E: CollectionExtract<S, Item = A>,
188    F: UniFilter<S, A>,
189    KF: Fn(&A) -> Option<K> + Send + Sync,
190    Sc: Score + 'static,
191{
192    pub fn named(self, name: &str) -> BalanceConstraint<S, A, K, E, F, KF, Sc> {
193        BalanceConstraint::new(
194            ConstraintRef::new("", name),
195            self.impact_type,
196            self.extractor,
197            self.filter,
198            self.key_fn,
199            self.base_score,
200            self.is_hard,
201        )
202    }
203
204    // Finalizes the builder into a zero-erasure `BalanceConstraint`.
205}
206
207impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
208    for BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
209{
210    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211        f.debug_struct("BalanceConstraintBuilder")
212            .field("impact_type", &self.impact_type)
213            .finish()
214    }
215}