Skip to main content

solverforge_scoring/stream/
balance_stream.rs

1// Zero-erasure balance constraint stream for load distribution patterns.
2//
3// A `BalanceConstraintStream` is created from `UniConstraintStream::balance()`
4// and provides fluent finalization into a `BalanceConstraint`.
5
6use std::hash::Hash;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use super::filter::UniFilter;
13use crate::constraint::balance::BalanceConstraint;
14
15// Zero-erasure stream for building balance constraints.
16//
17// Created by `UniConstraintStream::balance()`. Provides `penalize()` and
18// `reward()` methods to finalize the constraint.
19//
20// # Type Parameters
21//
22// - `S` - Solution type
23// - `A` - Entity type
24// - `K` - Group key type
25// - `E` - Extractor function for entities
26// - `F` - Filter type
27// - `KF` - Key function (returns Option<K> to skip unassigned entities)
28// - `Sc` - Score type
29//
30// # Example
31//
32// ```
33// use solverforge_scoring::stream::ConstraintFactory;
34// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
35// use solverforge_core::score::SoftScore;
36//
37// #[derive(Clone)]
38// struct Shift { employee_id: Option<usize> }
39//
40// #[derive(Clone)]
41// struct Solution { shifts: Vec<Shift> }
42//
43// let constraint = ConstraintFactory::<Solution, SoftScore>::new()
44//     .for_each(|s: &Solution| &s.shifts)
45//     .balance(|shift: &Shift| shift.employee_id)
46//     .penalize(SoftScore::of(1000))
47//     .as_constraint("Balance workload");
48//
49// let solution = Solution {
50//     shifts: vec![
51//         Shift { employee_id: Some(0) },
52//         Shift { employee_id: Some(0) },
53//         Shift { employee_id: Some(0) },
54//         Shift { employee_id: Some(1) },
55//     ],
56// };
57//
58// // std_dev = 1.0, penalty = -1000
59// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1000));
60// ```
61pub struct BalanceConstraintStream<S, A, K, E, F, KF, Sc>
62where
63    Sc: Score,
64{
65    extractor: E,
66    filter: F,
67    key_fn: KF,
68    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
69}
70
71impl<S, A, K, E, F, KF, Sc> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
72where
73    S: Send + Sync + 'static,
74    A: Clone + Send + Sync + 'static,
75    K: Clone + Eq + Hash + Send + Sync + 'static,
76    E: Fn(&S) -> &[A] + Send + Sync,
77    F: UniFilter<S, A>,
78    KF: Fn(&A) -> Option<K> + Send + Sync,
79    Sc: Score + 'static,
80{
81    // Creates a new balance constraint stream.
82    pub(crate) fn new(extractor: E, filter: F, key_fn: KF) -> Self {
83        Self {
84            extractor,
85            filter,
86            key_fn,
87            _phantom: PhantomData,
88        }
89    }
90
91    // Penalizes imbalanced distribution with the given base score per unit std_dev.
92    //
93    // The final score is `base_score.multiply(std_dev)`, negated for penalty.
94    pub fn penalize(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
95        let is_hard = base_score
96            .to_level_numbers()
97            .first()
98            .map(|&h| h != 0)
99            .unwrap_or(false);
100        BalanceConstraintBuilder {
101            extractor: self.extractor,
102            filter: self.filter,
103            key_fn: self.key_fn,
104            impact_type: ImpactType::Penalty,
105            base_score,
106            is_hard,
107            _phantom: PhantomData,
108        }
109    }
110
111    // Penalizes imbalanced distribution with one hard score unit per unit std_dev.
112    pub fn penalize_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
113    where
114        Sc: Copy,
115    {
116        self.penalize(Sc::one_hard())
117    }
118
119    // Penalizes imbalanced distribution with one soft score unit per unit std_dev.
120    pub fn penalize_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
121    where
122        Sc: Copy,
123    {
124        self.penalize(Sc::one_soft())
125    }
126
127    // Rewards imbalanced distribution with one hard score unit per unit std_dev.
128    pub fn reward_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
129    where
130        Sc: Copy,
131    {
132        self.reward(Sc::one_hard())
133    }
134
135    // Rewards imbalanced distribution with one soft score unit per unit std_dev.
136    pub fn reward_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
137    where
138        Sc: Copy,
139    {
140        self.reward(Sc::one_soft())
141    }
142
143    // Rewards imbalanced distribution with the given base score per unit std_dev.
144    //
145    // The final score is `base_score.multiply(std_dev)`.
146    pub fn reward(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
147        let is_hard = base_score
148            .to_level_numbers()
149            .first()
150            .map(|&h| h != 0)
151            .unwrap_or(false);
152        BalanceConstraintBuilder {
153            extractor: self.extractor,
154            filter: self.filter,
155            key_fn: self.key_fn,
156            impact_type: ImpactType::Reward,
157            base_score,
158            is_hard,
159            _phantom: PhantomData,
160        }
161    }
162}
163
164impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
165    for BalanceConstraintStream<S, A, K, E, F, KF, Sc>
166{
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        f.debug_struct("BalanceConstraintStream").finish()
169    }
170}
171
172// Zero-erasure builder for finalizing a balance constraint.
173pub struct BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
174where
175    Sc: Score,
176{
177    extractor: E,
178    filter: F,
179    key_fn: KF,
180    impact_type: ImpactType,
181    base_score: Sc,
182    is_hard: bool,
183    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K)>,
184}
185
186impl<S, A, K, E, F, KF, Sc> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
187where
188    S: Send + Sync + 'static,
189    A: Clone + Send + Sync + 'static,
190    K: Clone + Eq + Hash + Send + Sync + 'static,
191    E: Fn(&S) -> &[A] + Send + Sync,
192    F: UniFilter<S, A>,
193    KF: Fn(&A) -> Option<K> + Send + Sync,
194    Sc: Score + 'static,
195{
196    // Alias for `as_constraint`.
197    pub fn named(self, name: &str) -> BalanceConstraint<S, A, K, E, F, KF, Sc> {
198        self.as_constraint(name)
199    }
200
201    // Finalizes the builder into a zero-erasure `BalanceConstraint`.
202    pub fn as_constraint(self, name: &str) -> BalanceConstraint<S, A, K, E, F, KF, Sc> {
203        BalanceConstraint::new(
204            ConstraintRef::new("", name),
205            self.impact_type,
206            self.extractor,
207            self.filter,
208            self.key_fn,
209            self.base_score,
210            self.is_hard,
211        )
212    }
213}
214
215impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
216    for BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
217{
218    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219        f.debug_struct("BalanceConstraintBuilder")
220            .field("impact_type", &self.impact_type)
221            .finish()
222    }
223}