Skip to main content

solverforge_scoring/stream/
complemented_stream.rs

/* Zero-erasure complemented constraint stream.

A `ComplementedConstraintStream` extends grouped results with every key from a
complement source that is missing from those results, giving each missing key
a default value.
*/
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collection_extract::CollectionExtract;
14use super::collector::{Accumulator, Collector};
15use super::weighting_support::ConstraintWeight;
16use crate::constraint::complemented::ComplementedGroupConstraint;
17
18/* Zero-erasure constraint stream with complemented groups.
19
20`ComplementedConstraintStream` results from calling `complement` on a
21`GroupedConstraintStream`. It ensures all keys from a complement source
22are represented, using default values for missing keys.
23
24The key function for A entities returns `Option<K>` to allow skipping
25entities without valid keys (e.g., unassigned shifts).
26
27# Type Parameters
28
29- `S` - Solution type
30- `A` - Original entity type (e.g., Shift)
31- `B` - Complement entity type (e.g., Employee)
32- `K` - Group key type
33- `EA` - Extractor for A entities
34- `EB` - Extractor for B entities (complement source)
35- `KA` - Key function for A (returns `Option<K>` to allow filtering)
36- `KB` - Key function for B
37- `C` - Collector type
38- `D` - Default value function
39- `Sc` - Score type
40
41# Example
42
43```
44use solverforge_scoring::stream::ConstraintFactory;
45use solverforge_scoring::stream::collector::count;
46use solverforge_scoring::api::constraint_set::IncrementalConstraint;
47use solverforge_core::score::SoftScore;
48
49#[derive(Clone, Hash, PartialEq, Eq)]
50struct Employee { id: usize }
51
52#[derive(Clone, Hash, PartialEq, Eq)]
53struct Shift { employee_id: usize }
54
55#[derive(Clone)]
56struct Schedule {
57employees: Vec<Employee>,
58shifts: Vec<Shift>,
59}
60
61// Count shifts per employee, including employees with 0 shifts
62let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
63.for_each(|s: &Schedule| &s.shifts)
64.group_by(|shift: &Shift| shift.employee_id, count())
65.complement(
66|s: &Schedule| s.employees.as_slice(),
67|emp: &Employee| emp.id,
68|_emp: &Employee| 0usize,
69)
70.penalize(|_employee_id: &usize, count: &usize| SoftScore::of(*count as i64))
71.named("Shift count");
72
73let schedule = Schedule {
74employees: vec![Employee { id: 0 }, Employee { id: 1 }, Employee { id: 2 }],
75shifts: vec![
76Shift { employee_id: 0 },
77Shift { employee_id: 0 },
78// Employee 1 has 0 shifts, Employee 2 has 0 shifts
79],
80};
81
82// Employee 0: 2, Employee 1: 0, Employee 2: 0 → Total: -2
83assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
84```
85*/
86pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, Sc>
87where
88    Sc: Score,
89{
90    extractor_a: EA,
91    extractor_b: EB,
92    key_a: KA,
93    key_b: KB,
94    collector: C,
95    default_fn: D,
96    _phantom: PhantomData<(
97        fn() -> S,
98        fn() -> A,
99        fn() -> B,
100        fn() -> K,
101        fn() -> V,
102        fn() -> R,
103        fn() -> Acc,
104        fn() -> Sc,
105    )>,
106}
107
108impl<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, Sc>
109    ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, Sc>
110where
111    S: Send + Sync + 'static,
112    A: Clone + Send + Sync + 'static,
113    B: Clone + Send + Sync + 'static,
114    K: Clone + Eq + Hash + Send + Sync + 'static,
115    EA: CollectionExtract<S, Item = A>,
116    EB: CollectionExtract<S, Item = B>,
117    KA: Fn(&A) -> Option<K> + Send + Sync,
118    KB: Fn(&B) -> K + Send + Sync,
119    C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
120    V: Send + Sync + 'static,
121    R: Send + Sync + 'static,
122    Acc: Accumulator<V, R> + Send + Sync + 'static,
123    D: Fn(&B) -> R + Send + Sync,
124    Sc: Score + 'static,
125{
126    fn into_weighted_builder<W>(
127        self,
128        impact_type: ImpactType,
129        weight_fn: W,
130        is_hard: bool,
131    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc>
132    where
133        W: Fn(&K, &R) -> Sc + Send + Sync,
134    {
135        ComplementedConstraintBuilder {
136            extractor_a: self.extractor_a,
137            extractor_b: self.extractor_b,
138            key_a: self.key_a,
139            key_b: self.key_b,
140            collector: self.collector,
141            default_fn: self.default_fn,
142            impact_type,
143            weight_fn,
144            is_hard,
145            _phantom: PhantomData,
146        }
147    }
148
149    // Creates a new complemented constraint stream.
150    pub(crate) fn new(
151        extractor_a: EA,
152        extractor_b: EB,
153        key_a: KA,
154        key_b: KB,
155        collector: C,
156        default_fn: D,
157    ) -> Self {
158        Self {
159            extractor_a,
160            extractor_b,
161            key_a,
162            key_b,
163            collector,
164            default_fn,
165            _phantom: PhantomData,
166        }
167    }
168
169    pub fn penalize<W>(
170        self,
171        weight: W,
172    ) -> ComplementedConstraintBuilder<
173        S,
174        A,
175        B,
176        K,
177        EA,
178        EB,
179        KA,
180        KB,
181        C,
182        V,
183        R,
184        Acc,
185        D,
186        impl Fn(&K, &R) -> Sc + Send + Sync,
187        Sc,
188    >
189    where
190        W: for<'w> ConstraintWeight<(&'w K, &'w R), Sc> + Send + Sync,
191    {
192        let is_hard = weight.is_hard();
193        self.into_weighted_builder(
194            ImpactType::Penalty,
195            move |key: &K, result: &R| weight.score((key, result)),
196            is_hard,
197        )
198    }
199
200    pub fn reward<W>(
201        self,
202        weight: W,
203    ) -> ComplementedConstraintBuilder<
204        S,
205        A,
206        B,
207        K,
208        EA,
209        EB,
210        KA,
211        KB,
212        C,
213        V,
214        R,
215        Acc,
216        D,
217        impl Fn(&K, &R) -> Sc + Send + Sync,
218        Sc,
219    >
220    where
221        W: for<'w> ConstraintWeight<(&'w K, &'w R), Sc> + Send + Sync,
222    {
223        let is_hard = weight.is_hard();
224        self.into_weighted_builder(
225            ImpactType::Reward,
226            move |key: &K, result: &R| weight.score((key, result)),
227            is_hard,
228        )
229    }
230}
231
232impl<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, Sc: Score> std::fmt::Debug
233    for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, Sc>
234{
235    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236        f.debug_struct("ComplementedConstraintStream").finish()
237    }
238}
239
240// Zero-erasure builder for finalizing a complemented constraint.
241pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc>
242where
243    Sc: Score,
244{
245    extractor_a: EA,
246    extractor_b: EB,
247    key_a: KA,
248    key_b: KB,
249    collector: C,
250    default_fn: D,
251    impact_type: ImpactType,
252    weight_fn: W,
253    is_hard: bool,
254    _phantom: PhantomData<(
255        fn() -> S,
256        fn() -> A,
257        fn() -> B,
258        fn() -> K,
259        fn() -> V,
260        fn() -> R,
261        fn() -> Acc,
262        fn() -> Sc,
263    )>,
264}
265
266impl<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc>
267    ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc>
268where
269    S: Send + Sync + 'static,
270    A: Clone + Send + Sync + 'static,
271    B: Clone + Send + Sync + 'static,
272    K: Clone + Eq + Hash + Send + Sync + 'static,
273    EA: CollectionExtract<S, Item = A>,
274    EB: CollectionExtract<S, Item = B>,
275    KA: Fn(&A) -> Option<K> + Send + Sync,
276    KB: Fn(&B) -> K + Send + Sync,
277    C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
278    V: Send + Sync + 'static,
279    R: Send + Sync + 'static,
280    Acc: Accumulator<V, R> + Send + Sync + 'static,
281    D: Fn(&B) -> R + Send + Sync,
282    W: Fn(&K, &R) -> Sc + Send + Sync,
283    Sc: Score + 'static,
284{
285    pub fn named(
286        self,
287        name: &str,
288    ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc> {
289        ComplementedGroupConstraint::new(
290            ConstraintRef::new("", name),
291            self.impact_type,
292            self.extractor_a,
293            self.extractor_b,
294            self.key_a,
295            self.key_b,
296            self.collector,
297            self.default_fn,
298            self.weight_fn,
299            self.is_hard,
300        )
301    }
302
303    // Finalizes the builder into a `ComplementedGroupConstraint`.
304}
305
306impl<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc: Score> std::fmt::Debug
307    for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc>
308{
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        f.debug_struct("ComplementedConstraintBuilder")
311            .field("impact_type", &self.impact_type)
312            .finish()
313    }
314}