Skip to main content

solverforge_scoring/stream/
complemented_stream.rs

// Zero-erasure complemented constraint stream.
//
// A `ComplementedConstraintStream` makes sure every key produced by a
// complement source appears in the grouped results; keys that no grouped
// entity mapped to are filled in with a default value.
5
6use std::hash::Hash;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use super::collector::UniCollector;
13use crate::constraint::complemented::ComplementedGroupConstraint;
14
15// Zero-erasure constraint stream with complemented groups.
16//
17// `ComplementedConstraintStream` results from calling `complement` on a
18// `GroupedConstraintStream`. It ensures all keys from a complement source
19// are represented, using default values for missing keys.
20//
21// The key function for A entities returns `Option<K>` to allow skipping
22// entities without valid keys (e.g., unassigned shifts).
23//
24// # Type Parameters
25//
26// - `S` - Solution type
27// - `A` - Original entity type (e.g., Shift)
28// - `B` - Complement entity type (e.g., Employee)
29// - `K` - Group key type
30// - `EA` - Extractor for A entities
31// - `EB` - Extractor for B entities (complement source)
32// - `KA` - Key function for A (returns `Option<K>` to allow filtering)
33// - `KB` - Key function for B
34// - `C` - Collector type
35// - `D` - Default value function
36// - `Sc` - Score type
37//
38// # Example
39//
40// ```
41// use solverforge_scoring::stream::ConstraintFactory;
42// use solverforge_scoring::stream::collector::count;
43// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
44// use solverforge_core::score::SimpleScore;
45//
46// #[derive(Clone, Hash, PartialEq, Eq)]
47// struct Employee { id: usize }
48//
49// #[derive(Clone, Hash, PartialEq, Eq)]
50// struct Shift { employee_id: usize }
51//
52// #[derive(Clone)]
53// struct Schedule {
54//     employees: Vec<Employee>,
55//     shifts: Vec<Shift>,
56// }
57//
58// // Count shifts per employee, including employees with 0 shifts
59// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
60//     .for_each(|s: &Schedule| &s.shifts)
61//     .group_by(|shift: &Shift| shift.employee_id, count())
62//     .complement(
63//         |s: &Schedule| s.employees.as_slice(),
64//         |emp: &Employee| emp.id,
65//         |_emp: &Employee| 0usize,
66//     )
67//     .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
68//     .as_constraint("Shift count");
69//
70// let schedule = Schedule {
71//     employees: vec![Employee { id: 0 }, Employee { id: 1 }, Employee { id: 2 }],
72//     shifts: vec![
73//         Shift { employee_id: 0 },
74//         Shift { employee_id: 0 },
75//         // Employee 1 has 0 shifts, Employee 2 has 0 shifts
76//     ],
77// };
78//
79// // Employee 0: 2, Employee 1: 0, Employee 2: 0 → Total: -2
80// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
81// ```
82pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
83where
84    Sc: Score,
85{
86    extractor_a: EA,
87    extractor_b: EB,
88    key_a: KA,
89    key_b: KB,
90    collector: C,
91    default_fn: D,
92    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
93}
94
95impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
96    ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
97where
98    S: Send + Sync + 'static,
99    A: Clone + Send + Sync + 'static,
100    B: Clone + Send + Sync + 'static,
101    K: Clone + Eq + Hash + Send + Sync + 'static,
102    EA: Fn(&S) -> &[A] + Send + Sync,
103    EB: Fn(&S) -> &[B] + Send + Sync,
104    KA: Fn(&A) -> Option<K> + Send + Sync,
105    KB: Fn(&B) -> K + Send + Sync,
106    C: UniCollector<A> + Send + Sync + 'static,
107    C::Accumulator: Send + Sync,
108    C::Result: Clone + Send + Sync,
109    D: Fn(&B) -> C::Result + Send + Sync,
110    Sc: Score + 'static,
111{
112    // Creates a new complemented constraint stream.
113    pub(crate) fn new(
114        extractor_a: EA,
115        extractor_b: EB,
116        key_a: KA,
117        key_b: KB,
118        collector: C,
119        default_fn: D,
120    ) -> Self {
121        Self {
122            extractor_a,
123            extractor_b,
124            key_a,
125            key_b,
126            collector,
127            default_fn,
128            _phantom: PhantomData,
129        }
130    }
131
132    // Penalizes each complemented group with a weight based on the result.
133    pub fn penalize_with<W>(
134        self,
135        weight_fn: W,
136    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
137    where
138        W: Fn(&C::Result) -> Sc + Send + Sync,
139    {
140        ComplementedConstraintBuilder {
141            extractor_a: self.extractor_a,
142            extractor_b: self.extractor_b,
143            key_a: self.key_a,
144            key_b: self.key_b,
145            collector: self.collector,
146            default_fn: self.default_fn,
147            impact_type: ImpactType::Penalty,
148            weight_fn,
149            is_hard: false,
150            _phantom: PhantomData,
151        }
152    }
153
154    // Penalizes each complemented group, explicitly marked as hard constraint.
155    pub fn penalize_hard_with<W>(
156        self,
157        weight_fn: W,
158    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
159    where
160        W: Fn(&C::Result) -> Sc + Send + Sync,
161    {
162        ComplementedConstraintBuilder {
163            extractor_a: self.extractor_a,
164            extractor_b: self.extractor_b,
165            key_a: self.key_a,
166            key_b: self.key_b,
167            collector: self.collector,
168            default_fn: self.default_fn,
169            impact_type: ImpactType::Penalty,
170            weight_fn,
171            is_hard: true,
172            _phantom: PhantomData,
173        }
174    }
175
176    // Rewards each complemented group with a weight based on the result.
177    pub fn reward_with<W>(
178        self,
179        weight_fn: W,
180    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
181    where
182        W: Fn(&C::Result) -> Sc + Send + Sync,
183    {
184        ComplementedConstraintBuilder {
185            extractor_a: self.extractor_a,
186            extractor_b: self.extractor_b,
187            key_a: self.key_a,
188            key_b: self.key_b,
189            collector: self.collector,
190            default_fn: self.default_fn,
191            impact_type: ImpactType::Reward,
192            weight_fn,
193            is_hard: false,
194            _phantom: PhantomData,
195        }
196    }
197
198    // Rewards each complemented group, explicitly marked as hard constraint.
199    pub fn reward_hard_with<W>(
200        self,
201        weight_fn: W,
202    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
203    where
204        W: Fn(&C::Result) -> Sc + Send + Sync,
205    {
206        ComplementedConstraintBuilder {
207            extractor_a: self.extractor_a,
208            extractor_b: self.extractor_b,
209            key_a: self.key_a,
210            key_b: self.key_b,
211            collector: self.collector,
212            default_fn: self.default_fn,
213            impact_type: ImpactType::Reward,
214            weight_fn,
215            is_hard: true,
216            _phantom: PhantomData,
217        }
218    }
219}
220
221impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc: Score> std::fmt::Debug
222    for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
223{
224    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225        f.debug_struct("ComplementedConstraintStream").finish()
226    }
227}
228
229// Zero-erasure builder for finalizing a complemented constraint.
230pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
231where
232    Sc: Score,
233{
234    extractor_a: EA,
235    extractor_b: EB,
236    key_a: KA,
237    key_b: KB,
238    collector: C,
239    default_fn: D,
240    impact_type: ImpactType,
241    weight_fn: W,
242    is_hard: bool,
243    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
244}
245
246impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
247    ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
248where
249    S: Send + Sync + 'static,
250    A: Clone + Send + Sync + 'static,
251    B: Clone + Send + Sync + 'static,
252    K: Clone + Eq + Hash + Send + Sync + 'static,
253    EA: Fn(&S) -> &[A] + Send + Sync,
254    EB: Fn(&S) -> &[B] + Send + Sync,
255    KA: Fn(&A) -> Option<K> + Send + Sync,
256    KB: Fn(&B) -> K + Send + Sync,
257    C: UniCollector<A> + Send + Sync + 'static,
258    C::Accumulator: Send + Sync,
259    C::Result: Clone + Send + Sync,
260    D: Fn(&B) -> C::Result + Send + Sync,
261    W: Fn(&C::Result) -> Sc + Send + Sync,
262    Sc: Score + 'static,
263{
264    // Finalizes the builder into a `ComplementedGroupConstraint`.
265    pub fn as_constraint(
266        self,
267        name: &str,
268    ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> {
269        ComplementedGroupConstraint::new(
270            ConstraintRef::new("", name),
271            self.impact_type,
272            self.extractor_a,
273            self.extractor_b,
274            self.key_a,
275            self.key_b,
276            self.collector,
277            self.default_fn,
278            self.weight_fn,
279            self.is_hard,
280        )
281    }
282}
283
284impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc: Score> std::fmt::Debug
285    for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
286{
287    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        f.debug_struct("ComplementedConstraintBuilder")
289            .field("impact_type", &self.impact_type)
290            .finish()
291    }
292}