Skip to main content

solverforge_scoring/stream/
complemented_stream.rs

// Zero-erasure complemented constraint stream.
//
// A `ComplementedConstraintStream` extends grouped results with an entry for
// every complement-source key that has no group, using a default value for it.
5
6use std::hash::Hash;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use super::collector::UniCollector;
13use crate::constraint::complemented::ComplementedGroupConstraint;
14
15// Zero-erasure constraint stream with complemented groups.
16//
17// `ComplementedConstraintStream` results from calling `complement` on a
18// `GroupedConstraintStream`. It ensures all keys from a complement source
19// are represented, using default values for missing keys.
20//
21// The key function for A entities returns `Option<K>` to allow skipping
22// entities without valid keys (e.g., unassigned shifts).
23//
24// # Type Parameters
25//
26// - `S` - Solution type
27// - `A` - Original entity type (e.g., Shift)
28// - `B` - Complement entity type (e.g., Employee)
29// - `K` - Group key type
30// - `EA` - Extractor for A entities
31// - `EB` - Extractor for B entities (complement source)
32// - `KA` - Key function for A (returns `Option<K>` to allow filtering)
33// - `KB` - Key function for B
34// - `C` - Collector type
35// - `D` - Default value function
36// - `Sc` - Score type
37//
38// # Example
39//
40// ```
41// use solverforge_scoring::stream::ConstraintFactory;
42// use solverforge_scoring::stream::collector::count;
43// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
44// use solverforge_core::score::SoftScore;
45//
46// #[derive(Clone, Hash, PartialEq, Eq)]
47// struct Employee { id: usize }
48//
49// #[derive(Clone, Hash, PartialEq, Eq)]
50// struct Shift { employee_id: usize }
51//
52// #[derive(Clone)]
53// struct Schedule {
54//     employees: Vec<Employee>,
55//     shifts: Vec<Shift>,
56// }
57//
58// // Count shifts per employee, including employees with 0 shifts
59// let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
60//     .for_each(|s: &Schedule| &s.shifts)
61//     .group_by(|shift: &Shift| shift.employee_id, count())
62//     .complement(
63//         |s: &Schedule| s.employees.as_slice(),
64//         |emp: &Employee| emp.id,
65//         |_emp: &Employee| 0usize,
66//     )
67//     .penalize_with(|count: &usize| SoftScore::of(*count as i64))
68//     .as_constraint("Shift count");
69//
70// let schedule = Schedule {
71//     employees: vec![Employee { id: 0 }, Employee { id: 1 }, Employee { id: 2 }],
72//     shifts: vec![
73//         Shift { employee_id: 0 },
74//         Shift { employee_id: 0 },
75//         // Employee 1 has 0 shifts, Employee 2 has 0 shifts
76//     ],
77// };
78//
79// // Employee 0: 2, Employee 1: 0, Employee 2: 0 → Total: -2
80// assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
81// ```
82pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
83where
84    Sc: Score,
85{
86    extractor_a: EA,
87    extractor_b: EB,
88    key_a: KA,
89    key_b: KB,
90    collector: C,
91    default_fn: D,
92    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
93}
94
95impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
96    ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
97where
98    S: Send + Sync + 'static,
99    A: Clone + Send + Sync + 'static,
100    B: Clone + Send + Sync + 'static,
101    K: Clone + Eq + Hash + Send + Sync + 'static,
102    EA: Fn(&S) -> &[A] + Send + Sync,
103    EB: Fn(&S) -> &[B] + Send + Sync,
104    KA: Fn(&A) -> Option<K> + Send + Sync,
105    KB: Fn(&B) -> K + Send + Sync,
106    C: UniCollector<A> + Send + Sync + 'static,
107    C::Accumulator: Send + Sync,
108    C::Result: Clone + Send + Sync,
109    D: Fn(&B) -> C::Result + Send + Sync,
110    Sc: Score + 'static,
111{
112    // Creates a new complemented constraint stream.
113    pub(crate) fn new(
114        extractor_a: EA,
115        extractor_b: EB,
116        key_a: KA,
117        key_b: KB,
118        collector: C,
119        default_fn: D,
120    ) -> Self {
121        Self {
122            extractor_a,
123            extractor_b,
124            key_a,
125            key_b,
126            collector,
127            default_fn,
128            _phantom: PhantomData,
129        }
130    }
131
132    // Penalizes each complemented group with a weight based on the result.
133    pub fn penalize_with<W>(
134        self,
135        weight_fn: W,
136    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
137    where
138        W: Fn(&C::Result) -> Sc + Send + Sync,
139    {
140        ComplementedConstraintBuilder {
141            extractor_a: self.extractor_a,
142            extractor_b: self.extractor_b,
143            key_a: self.key_a,
144            key_b: self.key_b,
145            collector: self.collector,
146            default_fn: self.default_fn,
147            impact_type: ImpactType::Penalty,
148            weight_fn,
149            is_hard: false,
150            _phantom: PhantomData,
151        }
152    }
153
154    // Penalizes each complemented group, explicitly marked as hard constraint.
155    pub fn penalize_hard_with<W>(
156        self,
157        weight_fn: W,
158    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
159    where
160        W: Fn(&C::Result) -> Sc + Send + Sync,
161    {
162        ComplementedConstraintBuilder {
163            extractor_a: self.extractor_a,
164            extractor_b: self.extractor_b,
165            key_a: self.key_a,
166            key_b: self.key_b,
167            collector: self.collector,
168            default_fn: self.default_fn,
169            impact_type: ImpactType::Penalty,
170            weight_fn,
171            is_hard: true,
172            _phantom: PhantomData,
173        }
174    }
175
176    // Rewards each complemented group with a weight based on the result.
177    pub fn reward_with<W>(
178        self,
179        weight_fn: W,
180    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
181    where
182        W: Fn(&C::Result) -> Sc + Send + Sync,
183    {
184        ComplementedConstraintBuilder {
185            extractor_a: self.extractor_a,
186            extractor_b: self.extractor_b,
187            key_a: self.key_a,
188            key_b: self.key_b,
189            collector: self.collector,
190            default_fn: self.default_fn,
191            impact_type: ImpactType::Reward,
192            weight_fn,
193            is_hard: false,
194            _phantom: PhantomData,
195        }
196    }
197
198    // Rewards each complemented group, explicitly marked as hard constraint.
199    pub fn reward_hard_with<W>(
200        self,
201        weight_fn: W,
202    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
203    where
204        W: Fn(&C::Result) -> Sc + Send + Sync,
205    {
206        ComplementedConstraintBuilder {
207            extractor_a: self.extractor_a,
208            extractor_b: self.extractor_b,
209            key_a: self.key_a,
210            key_b: self.key_b,
211            collector: self.collector,
212            default_fn: self.default_fn,
213            impact_type: ImpactType::Reward,
214            weight_fn,
215            is_hard: true,
216            _phantom: PhantomData,
217        }
218    }
219
220    // Penalizes each complemented group with one hard score unit.
221    pub fn penalize_hard(
222        self,
223    ) -> ComplementedConstraintBuilder<
224        S,
225        A,
226        B,
227        K,
228        EA,
229        EB,
230        KA,
231        KB,
232        C,
233        D,
234        impl Fn(&C::Result) -> Sc + Send + Sync,
235        Sc,
236    >
237    where
238        Sc: Copy,
239    {
240        let w = Sc::one_hard();
241        self.penalize_hard_with(move |_: &C::Result| w)
242    }
243
244    // Penalizes each complemented group with one soft score unit.
245    pub fn penalize_soft(
246        self,
247    ) -> ComplementedConstraintBuilder<
248        S,
249        A,
250        B,
251        K,
252        EA,
253        EB,
254        KA,
255        KB,
256        C,
257        D,
258        impl Fn(&C::Result) -> Sc + Send + Sync,
259        Sc,
260    >
261    where
262        Sc: Copy,
263    {
264        let w = Sc::one_soft();
265        self.penalize_with(move |_: &C::Result| w)
266    }
267
268    // Rewards each complemented group with one hard score unit.
269    pub fn reward_hard(
270        self,
271    ) -> ComplementedConstraintBuilder<
272        S,
273        A,
274        B,
275        K,
276        EA,
277        EB,
278        KA,
279        KB,
280        C,
281        D,
282        impl Fn(&C::Result) -> Sc + Send + Sync,
283        Sc,
284    >
285    where
286        Sc: Copy,
287    {
288        let w = Sc::one_hard();
289        self.reward_hard_with(move |_: &C::Result| w)
290    }
291
292    // Rewards each complemented group with one soft score unit.
293    pub fn reward_soft(
294        self,
295    ) -> ComplementedConstraintBuilder<
296        S,
297        A,
298        B,
299        K,
300        EA,
301        EB,
302        KA,
303        KB,
304        C,
305        D,
306        impl Fn(&C::Result) -> Sc + Send + Sync,
307        Sc,
308    >
309    where
310        Sc: Copy,
311    {
312        let w = Sc::one_soft();
313        self.reward_with(move |_: &C::Result| w)
314    }
315}
316
317impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc: Score> std::fmt::Debug
318    for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
319{
320    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321        f.debug_struct("ComplementedConstraintStream").finish()
322    }
323}
324
325// Zero-erasure builder for finalizing a complemented constraint.
326pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
327where
328    Sc: Score,
329{
330    extractor_a: EA,
331    extractor_b: EB,
332    key_a: KA,
333    key_b: KB,
334    collector: C,
335    default_fn: D,
336    impact_type: ImpactType,
337    weight_fn: W,
338    is_hard: bool,
339    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
340}
341
342impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
343    ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
344where
345    S: Send + Sync + 'static,
346    A: Clone + Send + Sync + 'static,
347    B: Clone + Send + Sync + 'static,
348    K: Clone + Eq + Hash + Send + Sync + 'static,
349    EA: Fn(&S) -> &[A] + Send + Sync,
350    EB: Fn(&S) -> &[B] + Send + Sync,
351    KA: Fn(&A) -> Option<K> + Send + Sync,
352    KB: Fn(&B) -> K + Send + Sync,
353    C: UniCollector<A> + Send + Sync + 'static,
354    C::Accumulator: Send + Sync,
355    C::Result: Clone + Send + Sync,
356    D: Fn(&B) -> C::Result + Send + Sync,
357    W: Fn(&C::Result) -> Sc + Send + Sync,
358    Sc: Score + 'static,
359{
360    // Alias for `as_constraint`.
361    pub fn named(
362        self,
363        name: &str,
364    ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> {
365        self.as_constraint(name)
366    }
367
368    // Finalizes the builder into a `ComplementedGroupConstraint`.
369    pub fn as_constraint(
370        self,
371        name: &str,
372    ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> {
373        ComplementedGroupConstraint::new(
374            ConstraintRef::new("", name),
375            self.impact_type,
376            self.extractor_a,
377            self.extractor_b,
378            self.key_a,
379            self.key_b,
380            self.collector,
381            self.default_fn,
382            self.weight_fn,
383            self.is_hard,
384        )
385    }
386}
387
388impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc: Score> std::fmt::Debug
389    for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
390{
391    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392        f.debug_struct("ComplementedConstraintBuilder")
393            .field("impact_type", &self.impact_type)
394            .finish()
395    }
396}