Skip to main content

solverforge_scoring/stream/
complemented_stream.rs

1/* Zero-erasure complemented constraint stream.
2
3A `ComplementedConstraintStream` adds entities from a complement source
4that are not present in grouped results, with default values.
5*/
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collection_extract::CollectionExtract;
14use super::collector::UniCollector;
15use crate::constraint::complemented::ComplementedGroupConstraint;
16
17/* Zero-erasure constraint stream with complemented groups.
18
19`ComplementedConstraintStream` results from calling `complement` on a
20`GroupedConstraintStream`. It ensures all keys from a complement source
21are represented, using default values for missing keys.
22
23The key function for A entities returns `Option<K>` to allow skipping
24entities without valid keys (e.g., unassigned shifts).
25
26# Type Parameters
27
28- `S` - Solution type
29- `A` - Original entity type (e.g., Shift)
30- `B` - Complement entity type (e.g., Employee)
31- `K` - Group key type
32- `EA` - Extractor for A entities
33- `EB` - Extractor for B entities (complement source)
34- `KA` - Key function for A (returns `Option<K>` to allow filtering)
35- `KB` - Key function for B
36- `C` - Collector type
37- `D` - Default value function
38- `Sc` - Score type
39
40# Example
41
42```
43use solverforge_scoring::stream::ConstraintFactory;
44use solverforge_scoring::stream::collector::count;
45use solverforge_scoring::api::constraint_set::IncrementalConstraint;
46use solverforge_core::score::SoftScore;
47
48#[derive(Clone, Hash, PartialEq, Eq)]
49struct Employee { id: usize }
50
51#[derive(Clone, Hash, PartialEq, Eq)]
52struct Shift { employee_id: usize }
53
54#[derive(Clone)]
55struct Schedule {
56employees: Vec<Employee>,
57shifts: Vec<Shift>,
58}
59
60// Count shifts per employee, including employees with 0 shifts
61let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
62.for_each(|s: &Schedule| &s.shifts)
63.group_by(|shift: &Shift| shift.employee_id, count())
64.complement(
65|s: &Schedule| s.employees.as_slice(),
66|emp: &Employee| emp.id,
67|_emp: &Employee| 0usize,
68)
69.penalize_with(|count: &usize| SoftScore::of(*count as i64))
70.named("Shift count");
71
72let schedule = Schedule {
73employees: vec![Employee { id: 0 }, Employee { id: 1 }, Employee { id: 2 }],
74shifts: vec![
75Shift { employee_id: 0 },
76Shift { employee_id: 0 },
77// Employee 1 has 0 shifts, Employee 2 has 0 shifts
78],
79};
80
81// Employee 0: 2, Employee 1: 0, Employee 2: 0 → Total: -2
82assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
83```
84*/
85pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
86where
87    Sc: Score,
88{
89    extractor_a: EA,
90    extractor_b: EB,
91    key_a: KA,
92    key_b: KB,
93    collector: C,
94    default_fn: D,
95    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
96}
97
98impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
99    ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
100where
101    S: Send + Sync + 'static,
102    A: Clone + Send + Sync + 'static,
103    B: Clone + Send + Sync + 'static,
104    K: Clone + Eq + Hash + Send + Sync + 'static,
105    EA: CollectionExtract<S, Item = A>,
106    EB: CollectionExtract<S, Item = B>,
107    KA: Fn(&A) -> Option<K> + Send + Sync,
108    KB: Fn(&B) -> K + Send + Sync,
109    C: UniCollector<A> + Send + Sync + 'static,
110    C::Accumulator: Send + Sync,
111    C::Result: Clone + Send + Sync,
112    D: Fn(&B) -> C::Result + Send + Sync,
113    Sc: Score + 'static,
114{
115    // Creates a new complemented constraint stream.
116    pub(crate) fn new(
117        extractor_a: EA,
118        extractor_b: EB,
119        key_a: KA,
120        key_b: KB,
121        collector: C,
122        default_fn: D,
123    ) -> Self {
124        Self {
125            extractor_a,
126            extractor_b,
127            key_a,
128            key_b,
129            collector,
130            default_fn,
131            _phantom: PhantomData,
132        }
133    }
134
135    // Penalizes each complemented group with a weight based on the result.
136    pub fn penalize_with<W>(
137        self,
138        weight_fn: W,
139    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
140    where
141        W: Fn(&C::Result) -> Sc + Send + Sync,
142    {
143        ComplementedConstraintBuilder {
144            extractor_a: self.extractor_a,
145            extractor_b: self.extractor_b,
146            key_a: self.key_a,
147            key_b: self.key_b,
148            collector: self.collector,
149            default_fn: self.default_fn,
150            impact_type: ImpactType::Penalty,
151            weight_fn,
152            is_hard: false,
153            _phantom: PhantomData,
154        }
155    }
156
157    // Penalizes each complemented group, explicitly marked as hard constraint.
158    pub fn penalize_hard_with<W>(
159        self,
160        weight_fn: W,
161    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
162    where
163        W: Fn(&C::Result) -> Sc + Send + Sync,
164    {
165        ComplementedConstraintBuilder {
166            extractor_a: self.extractor_a,
167            extractor_b: self.extractor_b,
168            key_a: self.key_a,
169            key_b: self.key_b,
170            collector: self.collector,
171            default_fn: self.default_fn,
172            impact_type: ImpactType::Penalty,
173            weight_fn,
174            is_hard: true,
175            _phantom: PhantomData,
176        }
177    }
178
179    // Rewards each complemented group with a weight based on the result.
180    pub fn reward_with<W>(
181        self,
182        weight_fn: W,
183    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
184    where
185        W: Fn(&C::Result) -> Sc + Send + Sync,
186    {
187        ComplementedConstraintBuilder {
188            extractor_a: self.extractor_a,
189            extractor_b: self.extractor_b,
190            key_a: self.key_a,
191            key_b: self.key_b,
192            collector: self.collector,
193            default_fn: self.default_fn,
194            impact_type: ImpactType::Reward,
195            weight_fn,
196            is_hard: false,
197            _phantom: PhantomData,
198        }
199    }
200
201    // Rewards each complemented group, explicitly marked as hard constraint.
202    pub fn reward_hard_with<W>(
203        self,
204        weight_fn: W,
205    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
206    where
207        W: Fn(&C::Result) -> Sc + Send + Sync,
208    {
209        ComplementedConstraintBuilder {
210            extractor_a: self.extractor_a,
211            extractor_b: self.extractor_b,
212            key_a: self.key_a,
213            key_b: self.key_b,
214            collector: self.collector,
215            default_fn: self.default_fn,
216            impact_type: ImpactType::Reward,
217            weight_fn,
218            is_hard: true,
219            _phantom: PhantomData,
220        }
221    }
222
223    // Penalizes each complemented group with one hard score unit.
224    pub fn penalize_hard(
225        self,
226    ) -> ComplementedConstraintBuilder<
227        S,
228        A,
229        B,
230        K,
231        EA,
232        EB,
233        KA,
234        KB,
235        C,
236        D,
237        impl Fn(&C::Result) -> Sc + Send + Sync,
238        Sc,
239    >
240    where
241        Sc: Copy,
242    {
243        let w = Sc::one_hard();
244        self.penalize_hard_with(move |_: &C::Result| w)
245    }
246
247    // Penalizes each complemented group with one soft score unit.
248    pub fn penalize_soft(
249        self,
250    ) -> ComplementedConstraintBuilder<
251        S,
252        A,
253        B,
254        K,
255        EA,
256        EB,
257        KA,
258        KB,
259        C,
260        D,
261        impl Fn(&C::Result) -> Sc + Send + Sync,
262        Sc,
263    >
264    where
265        Sc: Copy,
266    {
267        let w = Sc::one_soft();
268        self.penalize_with(move |_: &C::Result| w)
269    }
270
271    // Rewards each complemented group with one hard score unit.
272    pub fn reward_hard(
273        self,
274    ) -> ComplementedConstraintBuilder<
275        S,
276        A,
277        B,
278        K,
279        EA,
280        EB,
281        KA,
282        KB,
283        C,
284        D,
285        impl Fn(&C::Result) -> Sc + Send + Sync,
286        Sc,
287    >
288    where
289        Sc: Copy,
290    {
291        let w = Sc::one_hard();
292        self.reward_hard_with(move |_: &C::Result| w)
293    }
294
295    // Rewards each complemented group with one soft score unit.
296    pub fn reward_soft(
297        self,
298    ) -> ComplementedConstraintBuilder<
299        S,
300        A,
301        B,
302        K,
303        EA,
304        EB,
305        KA,
306        KB,
307        C,
308        D,
309        impl Fn(&C::Result) -> Sc + Send + Sync,
310        Sc,
311    >
312    where
313        Sc: Copy,
314    {
315        let w = Sc::one_soft();
316        self.reward_with(move |_: &C::Result| w)
317    }
318}
319
320impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc: Score> std::fmt::Debug
321    for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
322{
323    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324        f.debug_struct("ComplementedConstraintStream").finish()
325    }
326}
327
328// Zero-erasure builder for finalizing a complemented constraint.
329pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
330where
331    Sc: Score,
332{
333    extractor_a: EA,
334    extractor_b: EB,
335    key_a: KA,
336    key_b: KB,
337    collector: C,
338    default_fn: D,
339    impact_type: ImpactType,
340    weight_fn: W,
341    is_hard: bool,
342    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
343}
344
345impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
346    ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
347where
348    S: Send + Sync + 'static,
349    A: Clone + Send + Sync + 'static,
350    B: Clone + Send + Sync + 'static,
351    K: Clone + Eq + Hash + Send + Sync + 'static,
352    EA: CollectionExtract<S, Item = A>,
353    EB: CollectionExtract<S, Item = B>,
354    KA: Fn(&A) -> Option<K> + Send + Sync,
355    KB: Fn(&B) -> K + Send + Sync,
356    C: UniCollector<A> + Send + Sync + 'static,
357    C::Accumulator: Send + Sync,
358    C::Result: Clone + Send + Sync,
359    D: Fn(&B) -> C::Result + Send + Sync,
360    W: Fn(&C::Result) -> Sc + Send + Sync,
361    Sc: Score + 'static,
362{
363    pub fn named(
364        self,
365        name: &str,
366    ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> {
367        ComplementedGroupConstraint::new(
368            ConstraintRef::new("", name),
369            self.impact_type,
370            self.extractor_a,
371            self.extractor_b,
372            self.key_a,
373            self.key_b,
374            self.collector,
375            self.default_fn,
376            self.weight_fn,
377            self.is_hard,
378        )
379    }
380
381    // Finalizes the builder into a `ComplementedGroupConstraint`.
382}
383
384impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc: Score> std::fmt::Debug
385    for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
386{
387    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388        f.debug_struct("ComplementedConstraintBuilder")
389            .field("impact_type", &self.impact_type)
390            .finish()
391    }
392}