Skip to main content

solverforge_scoring/stream/
complemented_stream.rs

1/* Zero-erasure complemented constraint stream.
2
3A `ComplementedConstraintStream` adds entities from a complement source
4that are not present in grouped results, with default values.
5*/
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collection_extract::CollectionExtract;
14use super::collector::UniCollector;
15use super::weighting_support::ConstraintWeight;
16use crate::constraint::complemented::ComplementedGroupConstraint;
17
18/* Zero-erasure constraint stream with complemented groups.
19
20`ComplementedConstraintStream` results from calling `complement` on a
21`GroupedConstraintStream`. It ensures all keys from a complement source
22are represented, using default values for missing keys.
23
24The key function for A entities returns `Option<K>` to allow skipping
25entities without valid keys (e.g., unassigned shifts).
26
27# Type Parameters
28
29- `S` - Solution type
30- `A` - Original entity type (e.g., Shift)
31- `B` - Complement entity type (e.g., Employee)
32- `K` - Group key type
33- `EA` - Extractor for A entities
34- `EB` - Extractor for B entities (complement source)
35- `KA` - Key function for A (returns `Option<K>` to allow filtering)
36- `KB` - Key function for B
37- `C` - Collector type
38- `D` - Default value function
39- `Sc` - Score type
40
41# Example
42
43```
44use solverforge_scoring::stream::ConstraintFactory;
45use solverforge_scoring::stream::collector::count;
46use solverforge_scoring::api::constraint_set::IncrementalConstraint;
47use solverforge_core::score::SoftScore;
48
49#[derive(Clone, Hash, PartialEq, Eq)]
50struct Employee { id: usize }
51
52#[derive(Clone, Hash, PartialEq, Eq)]
53struct Shift { employee_id: usize }
54
55#[derive(Clone)]
56struct Schedule {
57employees: Vec<Employee>,
58shifts: Vec<Shift>,
59}
60
61// Count shifts per employee, including employees with 0 shifts
62let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
63.for_each(|s: &Schedule| &s.shifts)
64.group_by(|shift: &Shift| shift.employee_id, count())
65.complement(
66|s: &Schedule| s.employees.as_slice(),
67|emp: &Employee| emp.id,
68|_emp: &Employee| 0usize,
69)
70.penalize(|_employee_id: &usize, count: &usize| SoftScore::of(*count as i64))
71.named("Shift count");
72
73let schedule = Schedule {
74employees: vec![Employee { id: 0 }, Employee { id: 1 }, Employee { id: 2 }],
75shifts: vec![
76Shift { employee_id: 0 },
77Shift { employee_id: 0 },
78// Employee 1 has 0 shifts, Employee 2 has 0 shifts
79],
80};
81
82// Employee 0: 2, Employee 1: 0, Employee 2: 0 → Total: -2
83assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
84```
85*/
86pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
87where
88    Sc: Score,
89{
90    extractor_a: EA,
91    extractor_b: EB,
92    key_a: KA,
93    key_b: KB,
94    collector: C,
95    default_fn: D,
96    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
97}
98
99impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
100    ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
101where
102    S: Send + Sync + 'static,
103    A: Clone + Send + Sync + 'static,
104    B: Clone + Send + Sync + 'static,
105    K: Clone + Eq + Hash + Send + Sync + 'static,
106    EA: CollectionExtract<S, Item = A>,
107    EB: CollectionExtract<S, Item = B>,
108    KA: Fn(&A) -> Option<K> + Send + Sync,
109    KB: Fn(&B) -> K + Send + Sync,
110    C: UniCollector<A> + Send + Sync + 'static,
111    C::Accumulator: Send + Sync,
112    C::Result: Send + Sync,
113    D: Fn(&B) -> C::Result + Send + Sync,
114    Sc: Score + 'static,
115{
116    fn into_weighted_builder<W>(
117        self,
118        impact_type: ImpactType,
119        weight_fn: W,
120        is_hard: bool,
121    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
122    where
123        W: Fn(&K, &C::Result) -> Sc + Send + Sync,
124    {
125        ComplementedConstraintBuilder {
126            extractor_a: self.extractor_a,
127            extractor_b: self.extractor_b,
128            key_a: self.key_a,
129            key_b: self.key_b,
130            collector: self.collector,
131            default_fn: self.default_fn,
132            impact_type,
133            weight_fn,
134            is_hard,
135            _phantom: PhantomData,
136        }
137    }
138
139    // Creates a new complemented constraint stream.
140    pub(crate) fn new(
141        extractor_a: EA,
142        extractor_b: EB,
143        key_a: KA,
144        key_b: KB,
145        collector: C,
146        default_fn: D,
147    ) -> Self {
148        Self {
149            extractor_a,
150            extractor_b,
151            key_a,
152            key_b,
153            collector,
154            default_fn,
155            _phantom: PhantomData,
156        }
157    }
158
159    pub fn penalize<W>(
160        self,
161        weight: W,
162    ) -> ComplementedConstraintBuilder<
163        S,
164        A,
165        B,
166        K,
167        EA,
168        EB,
169        KA,
170        KB,
171        C,
172        D,
173        impl Fn(&K, &C::Result) -> Sc + Send + Sync,
174        Sc,
175    >
176    where
177        W: for<'w> ConstraintWeight<(&'w K, &'w C::Result), Sc> + Send + Sync,
178    {
179        let is_hard = weight.is_hard();
180        self.into_weighted_builder(
181            ImpactType::Penalty,
182            move |key: &K, result: &C::Result| weight.score((key, result)),
183            is_hard,
184        )
185    }
186
187    pub fn reward<W>(
188        self,
189        weight: W,
190    ) -> ComplementedConstraintBuilder<
191        S,
192        A,
193        B,
194        K,
195        EA,
196        EB,
197        KA,
198        KB,
199        C,
200        D,
201        impl Fn(&K, &C::Result) -> Sc + Send + Sync,
202        Sc,
203    >
204    where
205        W: for<'w> ConstraintWeight<(&'w K, &'w C::Result), Sc> + Send + Sync,
206    {
207        let is_hard = weight.is_hard();
208        self.into_weighted_builder(
209            ImpactType::Reward,
210            move |key: &K, result: &C::Result| weight.score((key, result)),
211            is_hard,
212        )
213    }
214}
215
216impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc: Score> std::fmt::Debug
217    for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
218{
219    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220        f.debug_struct("ComplementedConstraintStream").finish()
221    }
222}
223
224// Zero-erasure builder for finalizing a complemented constraint.
225pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
226where
227    Sc: Score,
228{
229    extractor_a: EA,
230    extractor_b: EB,
231    key_a: KA,
232    key_b: KB,
233    collector: C,
234    default_fn: D,
235    impact_type: ImpactType,
236    weight_fn: W,
237    is_hard: bool,
238    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
239}
240
241impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
242    ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
243where
244    S: Send + Sync + 'static,
245    A: Clone + Send + Sync + 'static,
246    B: Clone + Send + Sync + 'static,
247    K: Clone + Eq + Hash + Send + Sync + 'static,
248    EA: CollectionExtract<S, Item = A>,
249    EB: CollectionExtract<S, Item = B>,
250    KA: Fn(&A) -> Option<K> + Send + Sync,
251    KB: Fn(&B) -> K + Send + Sync,
252    C: UniCollector<A> + Send + Sync + 'static,
253    C::Accumulator: Send + Sync,
254    C::Result: Send + Sync,
255    D: Fn(&B) -> C::Result + Send + Sync,
256    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
257    Sc: Score + 'static,
258{
259    pub fn named(
260        self,
261        name: &str,
262    ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> {
263        ComplementedGroupConstraint::new(
264            ConstraintRef::new("", name),
265            self.impact_type,
266            self.extractor_a,
267            self.extractor_b,
268            self.key_a,
269            self.key_b,
270            self.collector,
271            self.default_fn,
272            self.weight_fn,
273            self.is_hard,
274        )
275    }
276
277    // Finalizes the builder into a `ComplementedGroupConstraint`.
278}
279
280impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc: Score> std::fmt::Debug
281    for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
282{
283    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284        f.debug_struct("ComplementedConstraintBuilder")
285            .field("impact_type", &self.impact_type)
286            .finish()
287    }
288}