solverforge_scoring/stream/
complemented_stream.rs

1//! Zero-erasure complemented constraint stream.
2//!
//! A `ComplementedConstraintStream` fills in every key from a complement source
//! that is missing from the grouped results, giving each such key a default value.
5
6use std::hash::Hash;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use super::collector::UniCollector;
13use crate::constraint::complemented::ComplementedGroupConstraint;
14
15/// Zero-erasure constraint stream with complemented groups.
16///
17/// `ComplementedConstraintStream` results from calling `complement` on a
18/// `GroupedConstraintStream`. It ensures all keys from a complement source
19/// are represented, using default values for missing keys.
20///
21/// The key function for A entities returns `Option<K>` to allow skipping
22/// entities without valid keys (e.g., unassigned shifts).
23///
24/// # Type Parameters
25///
26/// - `S` - Solution type
27/// - `A` - Original entity type (e.g., Shift)
28/// - `B` - Complement entity type (e.g., Employee)
29/// - `K` - Group key type
30/// - `EA` - Extractor for A entities
31/// - `EB` - Extractor for B entities (complement source)
32/// - `KA` - Key function for A (returns `Option<K>` to allow filtering)
33/// - `KB` - Key function for B
34/// - `C` - Collector type
35/// - `D` - Default value function
36/// - `Sc` - Score type
37///
38/// # Example
39///
40/// ```
41/// use solverforge_scoring::stream::ConstraintFactory;
42/// use solverforge_scoring::stream::collector::count;
43/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
44/// use solverforge_core::score::SimpleScore;
45///
46/// #[derive(Clone, Hash, PartialEq, Eq)]
47/// struct Employee { id: usize }
48///
49/// #[derive(Clone, Hash, PartialEq, Eq)]
50/// struct Shift { employee_id: usize }
51///
52/// #[derive(Clone)]
53/// struct Schedule {
54///     employees: Vec<Employee>,
55///     shifts: Vec<Shift>,
56/// }
57///
58/// // Count shifts per employee, including employees with 0 shifts
59/// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
60///     .for_each(|s: &Schedule| &s.shifts)
61///     .group_by(|shift: &Shift| shift.employee_id, count())
62///     .complement(
63///         |s: &Schedule| s.employees.as_slice(),
64///         |emp: &Employee| emp.id,
65///         |_emp: &Employee| 0usize,
66///     )
67///     .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
68///     .as_constraint("Shift count");
69///
70/// let schedule = Schedule {
71///     employees: vec![Employee { id: 0 }, Employee { id: 1 }, Employee { id: 2 }],
72///     shifts: vec![
73///         Shift { employee_id: 0 },
74///         Shift { employee_id: 0 },
75///         // Employee 1 has 0 shifts, Employee 2 has 0 shifts
76///     ],
77/// };
78///
79/// // Employee 0: 2, Employee 1: 0, Employee 2: 0 → Total: -2
80/// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
81/// ```
82pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
83where
84    Sc: Score,
85{
86    extractor_a: EA,
87    extractor_b: EB,
88    key_a: KA,
89    key_b: KB,
90    collector: C,
91    default_fn: D,
92    _phantom: PhantomData<(S, A, B, K, Sc)>,
93}
94
95impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
96    ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
97where
98    S: Send + Sync + 'static,
99    A: Clone + Send + Sync + 'static,
100    B: Clone + Send + Sync + 'static,
101    K: Clone + Eq + Hash + Send + Sync + 'static,
102    EA: Fn(&S) -> &[A] + Send + Sync,
103    EB: Fn(&S) -> &[B] + Send + Sync,
104    KA: Fn(&A) -> Option<K> + Send + Sync,
105    KB: Fn(&B) -> K + Send + Sync,
106    C: UniCollector<A> + Send + Sync + 'static,
107    C::Accumulator: Send + Sync,
108    C::Result: Clone + Send + Sync,
109    D: Fn(&B) -> C::Result + Send + Sync,
110    Sc: Score + 'static,
111{
112    /// Creates a new complemented constraint stream.
113    pub(crate) fn new(
114        extractor_a: EA,
115        extractor_b: EB,
116        key_a: KA,
117        key_b: KB,
118        collector: C,
119        default_fn: D,
120    ) -> Self {
121        Self {
122            extractor_a,
123            extractor_b,
124            key_a,
125            key_b,
126            collector,
127            default_fn,
128            _phantom: PhantomData,
129        }
130    }
131
132    /// Penalizes each complemented group with a weight based on the result.
133    pub fn penalize_with<W>(
134        self,
135        weight_fn: W,
136    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
137    where
138        W: Fn(&C::Result) -> Sc + Send + Sync,
139    {
140        ComplementedConstraintBuilder {
141            extractor_a: self.extractor_a,
142            extractor_b: self.extractor_b,
143            key_a: self.key_a,
144            key_b: self.key_b,
145            collector: self.collector,
146            default_fn: self.default_fn,
147            impact_type: ImpactType::Penalty,
148            weight_fn,
149            is_hard: false,
150            _phantom: PhantomData,
151        }
152    }
153
154    /// Penalizes each complemented group, explicitly marked as hard constraint.
155    pub fn penalize_hard_with<W>(
156        self,
157        weight_fn: W,
158    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
159    where
160        W: Fn(&C::Result) -> Sc + Send + Sync,
161    {
162        ComplementedConstraintBuilder {
163            extractor_a: self.extractor_a,
164            extractor_b: self.extractor_b,
165            key_a: self.key_a,
166            key_b: self.key_b,
167            collector: self.collector,
168            default_fn: self.default_fn,
169            impact_type: ImpactType::Penalty,
170            weight_fn,
171            is_hard: true,
172            _phantom: PhantomData,
173        }
174    }
175
176    /// Rewards each complemented group with a weight based on the result.
177    pub fn reward_with<W>(
178        self,
179        weight_fn: W,
180    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
181    where
182        W: Fn(&C::Result) -> Sc + Send + Sync,
183    {
184        ComplementedConstraintBuilder {
185            extractor_a: self.extractor_a,
186            extractor_b: self.extractor_b,
187            key_a: self.key_a,
188            key_b: self.key_b,
189            collector: self.collector,
190            default_fn: self.default_fn,
191            impact_type: ImpactType::Reward,
192            weight_fn,
193            is_hard: false,
194            _phantom: PhantomData,
195        }
196    }
197
198    /// Rewards each complemented group, explicitly marked as hard constraint.
199    pub fn reward_hard_with<W>(
200        self,
201        weight_fn: W,
202    ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
203    where
204        W: Fn(&C::Result) -> Sc + Send + Sync,
205    {
206        ComplementedConstraintBuilder {
207            extractor_a: self.extractor_a,
208            extractor_b: self.extractor_b,
209            key_a: self.key_a,
210            key_b: self.key_b,
211            collector: self.collector,
212            default_fn: self.default_fn,
213            impact_type: ImpactType::Reward,
214            weight_fn,
215            is_hard: true,
216            _phantom: PhantomData,
217        }
218    }
219}
220
221impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc: Score> std::fmt::Debug
222    for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
223{
224    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225        f.debug_struct("ComplementedConstraintStream").finish()
226    }
227}
228
229/// Zero-erasure builder for finalizing a complemented constraint.
230pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
231where
232    Sc: Score,
233{
234    extractor_a: EA,
235    extractor_b: EB,
236    key_a: KA,
237    key_b: KB,
238    collector: C,
239    default_fn: D,
240    impact_type: ImpactType,
241    weight_fn: W,
242    is_hard: bool,
243    _phantom: PhantomData<(S, A, B, K, Sc)>,
244}
245
246impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
247    ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
248where
249    S: Send + Sync + 'static,
250    A: Clone + Send + Sync + 'static,
251    B: Clone + Send + Sync + 'static,
252    K: Clone + Eq + Hash + Send + Sync + 'static,
253    EA: Fn(&S) -> &[A] + Send + Sync,
254    EB: Fn(&S) -> &[B] + Send + Sync,
255    KA: Fn(&A) -> Option<K> + Send + Sync,
256    KB: Fn(&B) -> K + Send + Sync,
257    C: UniCollector<A> + Send + Sync + 'static,
258    C::Accumulator: Send + Sync,
259    C::Result: Clone + Send + Sync,
260    D: Fn(&B) -> C::Result + Send + Sync,
261    W: Fn(&C::Result) -> Sc + Send + Sync,
262    Sc: Score + 'static,
263{
264    /// Finalizes the builder into a `ComplementedGroupConstraint`.
265    pub fn as_constraint(
266        self,
267        name: &str,
268    ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> {
269        ComplementedGroupConstraint::new(
270            ConstraintRef::new("", name),
271            self.impact_type,
272            self.extractor_a,
273            self.extractor_b,
274            self.key_a,
275            self.key_b,
276            self.collector,
277            self.default_fn,
278            self.weight_fn,
279            self.is_hard,
280        )
281    }
282}
283
284impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc: Score> std::fmt::Debug
285    for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
286{
287    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        f.debug_struct("ComplementedConstraintBuilder")
289            .field("impact_type", &self.impact_type)
290            .finish()
291    }
292}