Skip to main content

solverforge_scoring/stream/
complemented_stream.rs

1/* Zero-erasure complemented constraint stream.
2
A `ComplementedConstraintStream` extends grouped results with the keys of a
complement source that have no group yet, giving each of them a default value.
5*/
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collection_extract::CollectionExtract;
14use super::collector::UniCollector;
15use crate::constraint::complemented::ComplementedGroupConstraint;
16
17/* Zero-erasure constraint stream with complemented groups.
18
19`ComplementedConstraintStream` results from calling `complement` on a
20`GroupedConstraintStream`. It ensures all keys from a complement source
21are represented, using default values for missing keys.
22
23The key function for A entities returns `Option<K>` to allow skipping
24entities without valid keys (e.g., unassigned shifts).
25
26# Type Parameters
27
28- `S` - Solution type
29- `A` - Original entity type (e.g., Shift)
30- `B` - Complement entity type (e.g., Employee)
31- `K` - Group key type
32- `EA` - Extractor for A entities
33- `EB` - Extractor for B entities (complement source)
34- `KA` - Key function for A (returns `Option<K>` to allow filtering)
35- `KB` - Key function for B
36- `C` - Collector type
37- `D` - Default value function
38- `Sc` - Score type
39
40# Example
41
```
use solverforge_scoring::stream::ConstraintFactory;
use solverforge_scoring::stream::collector::count;
use solverforge_scoring::api::constraint_set::IncrementalConstraint;
use solverforge_core::score::SoftScore;

#[derive(Clone, Hash, PartialEq, Eq)]
struct Employee { id: usize }

#[derive(Clone, Hash, PartialEq, Eq)]
struct Shift { employee_id: usize }

#[derive(Clone)]
struct Schedule {
    employees: Vec<Employee>,
    shifts: Vec<Shift>,
}

// Count shifts per employee, including employees with 0 shifts
let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
    .for_each(|s: &Schedule| &s.shifts)
    .group_by(|shift: &Shift| shift.employee_id, count())
    .complement(
        |s: &Schedule| s.employees.as_slice(),
        |emp: &Employee| emp.id,
        |_emp: &Employee| 0usize,
    )
    .penalize_with(|count: &usize| SoftScore::of(*count as i64))
    .named("Shift count");

let schedule = Schedule {
    employees: vec![Employee { id: 0 }, Employee { id: 1 }, Employee { id: 2 }],
    shifts: vec![
        Shift { employee_id: 0 },
        Shift { employee_id: 0 },
        // Employee 1 has 0 shifts, Employee 2 has 0 shifts
    ],
};

// Employee 0: 2, Employee 1: 0, Employee 2: 0 → Total: -2
assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
```
84*/
// State of a complemented stream: the two extractors, the two key functions,
// the collector and the default-value function, all stored by value.
//
// The type itself declares no trait bounds; each `impl` block states the
// bounds it needs. `Sc` only appears inside `PhantomData`, so nothing in the
// definition depends on it implementing `Score`.
pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc> {
    // Extractor for the `A` entities.
    extractor_a: EA,
    // Extractor for the `B` entities (the complement source).
    extractor_b: EB,
    // Key function for `A`.
    key_a: KA,
    // Key function for `B`.
    key_b: KB,
    // Collector applied to the `A` entities.
    collector: C,
    // Default value function.
    default_fn: D,
    // Mentions the type parameters that no field stores, via `fn() -> T`
    // markers, so the struct holds no values of those types.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
}
97
98impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
99    ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
100where
101    S: Send + Sync + 'static,
102    A: Clone + Send + Sync + 'static,
103    B: Clone + Send + Sync + 'static,
104    K: Clone + Eq + Hash + Send + Sync + 'static,
105    EA: CollectionExtract<S, Item = A>,
106    EB: CollectionExtract<S, Item = B>,
107    KA: Fn(&A) -> Option<K> + Send + Sync,
108    KB: Fn(&B) -> K + Send + Sync,
109    C: UniCollector<A> + Send + Sync + 'static,
110    C::Accumulator: Send + Sync,
111    C::Result: Clone + Send + Sync,
112    D: Fn(&B) -> C::Result + Send + Sync,
113    Sc: Score + 'static,
114{
115    fn into_weighted_builder<W>(
116        self,
117        impact_type: ImpactType,
118        weight_fn: W,
119        is_hard: bool,
120    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
121    where
122        W: Fn(&C::Result) -> Sc + Send + Sync,
123    {
124        ComplementedConstraintBuilder {
125            extractor_a: self.extractor_a,
126            extractor_b: self.extractor_b,
127            key_a: self.key_a,
128            key_b: self.key_b,
129            collector: self.collector,
130            default_fn: self.default_fn,
131            impact_type,
132            weight_fn,
133            is_hard,
134            _phantom: PhantomData,
135        }
136    }
137
138    // Creates a new complemented constraint stream.
139    pub(crate) fn new(
140        extractor_a: EA,
141        extractor_b: EB,
142        key_a: KA,
143        key_b: KB,
144        collector: C,
145        default_fn: D,
146    ) -> Self {
147        Self {
148            extractor_a,
149            extractor_b,
150            key_a,
151            key_b,
152            collector,
153            default_fn,
154            _phantom: PhantomData,
155        }
156    }
157
158    // Penalizes each complemented group with a weight based on the result.
159    pub fn penalize_with<W>(
160        self,
161        weight_fn: W,
162    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
163    where
164        W: Fn(&C::Result) -> Sc + Send + Sync,
165    {
166        self.into_weighted_builder(ImpactType::Penalty, weight_fn, false)
167    }
168
169    // Penalizes each complemented group, explicitly marked as hard constraint.
170    pub fn penalize_hard_with<W>(
171        self,
172        weight_fn: W,
173    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
174    where
175        W: Fn(&C::Result) -> Sc + Send + Sync,
176    {
177        self.into_weighted_builder(ImpactType::Penalty, weight_fn, true)
178    }
179
180    // Rewards each complemented group with a weight based on the result.
181    pub fn reward_with<W>(
182        self,
183        weight_fn: W,
184    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
185    where
186        W: Fn(&C::Result) -> Sc + Send + Sync,
187    {
188        self.into_weighted_builder(ImpactType::Reward, weight_fn, false)
189    }
190
191    // Rewards each complemented group, explicitly marked as hard constraint.
192    pub fn reward_hard_with<W>(
193        self,
194        weight_fn: W,
195    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
196    where
197        W: Fn(&C::Result) -> Sc + Send + Sync,
198    {
199        self.into_weighted_builder(ImpactType::Reward, weight_fn, true)
200    }
201
202    // Penalizes each complemented group with one hard score unit.
203    pub fn penalize_hard(
204        self,
205    ) -> ComplementedConstraintBuilder<
206        S,
207        A,
208        B,
209        K,
210        EA,
211        EB,
212        KA,
213        KB,
214        C,
215        D,
216        impl Fn(&C::Result) -> Sc + Send + Sync,
217        Sc,
218    >
219    where
220        Sc: Copy,
221    {
222        let w = Sc::one_hard();
223        self.penalize_hard_with(move |_: &C::Result| w)
224    }
225
226    // Penalizes each complemented group with one soft score unit.
227    pub fn penalize_soft(
228        self,
229    ) -> ComplementedConstraintBuilder<
230        S,
231        A,
232        B,
233        K,
234        EA,
235        EB,
236        KA,
237        KB,
238        C,
239        D,
240        impl Fn(&C::Result) -> Sc + Send + Sync,
241        Sc,
242    >
243    where
244        Sc: Copy,
245    {
246        let w = Sc::one_soft();
247        self.penalize_with(move |_: &C::Result| w)
248    }
249
250    // Rewards each complemented group with one hard score unit.
251    pub fn reward_hard(
252        self,
253    ) -> ComplementedConstraintBuilder<
254        S,
255        A,
256        B,
257        K,
258        EA,
259        EB,
260        KA,
261        KB,
262        C,
263        D,
264        impl Fn(&C::Result) -> Sc + Send + Sync,
265        Sc,
266    >
267    where
268        Sc: Copy,
269    {
270        let w = Sc::one_hard();
271        self.reward_hard_with(move |_: &C::Result| w)
272    }
273
274    // Rewards each complemented group with one soft score unit.
275    pub fn reward_soft(
276        self,
277    ) -> ComplementedConstraintBuilder<
278        S,
279        A,
280        B,
281        K,
282        EA,
283        EB,
284        KA,
285        KB,
286        C,
287        D,
288        impl Fn(&C::Result) -> Sc + Send + Sync,
289        Sc,
290    >
291    where
292        Sc: Copy,
293    {
294        let w = Sc::one_soft();
295        self.reward_with(move |_: &C::Result| w)
296    }
297}
298
299impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc: Score> std::fmt::Debug
300    for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
301{
302    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303        f.debug_struct("ComplementedConstraintStream").finish()
304    }
305}
306
307// Zero-erasure builder for finalizing a complemented constraint.
308pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
309where
310    Sc: Score,
311{
312    extractor_a: EA,
313    extractor_b: EB,
314    key_a: KA,
315    key_b: KB,
316    collector: C,
317    default_fn: D,
318    impact_type: ImpactType,
319    weight_fn: W,
320    is_hard: bool,
321    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
322}
323
324impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
325    ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
326where
327    S: Send + Sync + 'static,
328    A: Clone + Send + Sync + 'static,
329    B: Clone + Send + Sync + 'static,
330    K: Clone + Eq + Hash + Send + Sync + 'static,
331    EA: CollectionExtract<S, Item = A>,
332    EB: CollectionExtract<S, Item = B>,
333    KA: Fn(&A) -> Option<K> + Send + Sync,
334    KB: Fn(&B) -> K + Send + Sync,
335    C: UniCollector<A> + Send + Sync + 'static,
336    C::Accumulator: Send + Sync,
337    C::Result: Clone + Send + Sync,
338    D: Fn(&B) -> C::Result + Send + Sync,
339    W: Fn(&C::Result) -> Sc + Send + Sync,
340    Sc: Score + 'static,
341{
342    pub fn named(
343        self,
344        name: &str,
345    ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> {
346        ComplementedGroupConstraint::new(
347            ConstraintRef::new("", name),
348            self.impact_type,
349            self.extractor_a,
350            self.extractor_b,
351            self.key_a,
352            self.key_b,
353            self.collector,
354            self.default_fn,
355            self.weight_fn,
356            self.is_hard,
357        )
358    }
359
360    // Finalizes the builder into a `ComplementedGroupConstraint`.
361}
362
363impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc: Score> std::fmt::Debug
364    for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
365{
366    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367        f.debug_struct("ComplementedConstraintBuilder")
368            .field("impact_type", &self.impact_type)
369            .finish()
370    }
371}