solverforge_scoring/stream/
grouped_stream.rs

1//! Zero-erasure grouped constraint stream for group-by constraint patterns.
2//!
3//! A `GroupedConstraintStream` operates on groups of entities and supports
4//! filtering, weighting, and constraint finalization.
5//! All type information is preserved at compile time - no Arc, no dyn.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collector::UniCollector;
14use super::complemented_stream::ComplementedConstraintStream;
15use crate::constraint::grouped::GroupedUniConstraint;
16
17/// Zero-erasure constraint stream over grouped entities.
18///
19/// `GroupedConstraintStream` is created by `UniConstraintStream::group_by()`
20/// and operates on (key, collector_result) tuples.
21///
22/// All type parameters are concrete - no trait objects, no Arc allocations.
23///
24/// # Type Parameters
25///
26/// - `S` - Solution type
27/// - `A` - Entity type
28/// - `K` - Group key type
29/// - `E` - Extractor function for entities
30/// - `KF` - Key function
31/// - `C` - Collector type
32/// - `Sc` - Score type
33///
34/// # Example
35///
36/// ```
37/// use solverforge_scoring::stream::ConstraintFactory;
38/// use solverforge_scoring::stream::collector::count;
39/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
40/// use solverforge_core::score::SimpleScore;
41///
42/// #[derive(Clone, Hash, PartialEq, Eq)]
43/// struct Shift { employee_id: usize }
44///
45/// #[derive(Clone)]
46/// struct Solution { shifts: Vec<Shift> }
47///
48/// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
49///     .for_each(|s: &Solution| &s.shifts)
50///     .group_by(|shift: &Shift| shift.employee_id, count())
51///     .penalize_with(|count: &usize| SimpleScore::of((*count * *count) as i64))
52///     .as_constraint("Balanced workload");
53///
54/// let solution = Solution {
55///     shifts: vec![
56///         Shift { employee_id: 1 },
57///         Shift { employee_id: 1 },
58///         Shift { employee_id: 1 },
59///         Shift { employee_id: 2 },
60///     ],
61/// };
62///
63/// // Employee 1: 3² = 9, Employee 2: 1² = 1, Total: -10
64/// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-10));
65/// ```
66pub struct GroupedConstraintStream<S, A, K, E, KF, C, Sc>
67where
68    Sc: Score,
69{
70    extractor: E,
71    key_fn: KF,
72    collector: C,
73    _phantom: PhantomData<(S, A, K, Sc)>,
74}
75
76impl<S, A, K, E, KF, C, Sc> GroupedConstraintStream<S, A, K, E, KF, C, Sc>
77where
78    S: Send + Sync + 'static,
79    A: Clone + Send + Sync + 'static,
80    K: Clone + Eq + Hash + Send + Sync + 'static,
81    E: Fn(&S) -> &[A] + Send + Sync,
82    KF: Fn(&A) -> K + Send + Sync,
83    C: UniCollector<A> + Send + Sync + 'static,
84    C::Accumulator: Send + Sync,
85    C::Result: Clone + Send + Sync,
86    Sc: Score + 'static,
87{
88    /// Creates a new zero-erasure grouped constraint stream.
89    pub(crate) fn new(extractor: E, key_fn: KF, collector: C) -> Self {
90        Self {
91            extractor,
92            key_fn,
93            collector,
94            _phantom: PhantomData,
95        }
96    }
97
98    /// Penalizes each group with a weight based on the collector result.
99    ///
100    /// # Example
101    ///
102    /// ```
103    /// use solverforge_scoring::stream::ConstraintFactory;
104    /// use solverforge_scoring::stream::collector::count;
105    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
106    /// use solverforge_core::score::SimpleScore;
107    ///
108    /// #[derive(Clone, Hash, PartialEq, Eq)]
109    /// struct Task { priority: u32 }
110    ///
111    /// #[derive(Clone)]
112    /// struct Solution { tasks: Vec<Task> }
113    ///
114    /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
115    ///     .for_each(|s: &Solution| &s.tasks)
116    ///     .group_by(|t: &Task| t.priority, count())
117    ///     .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
118    ///     .as_constraint("Priority distribution");
119    ///
120    /// let solution = Solution {
121    ///     tasks: vec![
122    ///         Task { priority: 1 },
123    ///         Task { priority: 1 },
124    ///         Task { priority: 2 },
125    ///     ],
126    /// };
127    ///
128    /// // Priority 1: 2 tasks, Priority 2: 1 task, Total: -3
129    /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-3));
130    /// ```
131    pub fn penalize_with<W>(
132        self,
133        weight_fn: W,
134    ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
135    where
136        W: Fn(&C::Result) -> Sc + Send + Sync,
137    {
138        GroupedConstraintBuilder {
139            extractor: self.extractor,
140            key_fn: self.key_fn,
141            collector: self.collector,
142            impact_type: ImpactType::Penalty,
143            weight_fn,
144            is_hard: false,
145            _phantom: PhantomData,
146        }
147    }
148
149    /// Penalizes each group with a weight, explicitly marked as hard constraint.
150    pub fn penalize_hard_with<W>(
151        self,
152        weight_fn: W,
153    ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
154    where
155        W: Fn(&C::Result) -> Sc + Send + Sync,
156    {
157        GroupedConstraintBuilder {
158            extractor: self.extractor,
159            key_fn: self.key_fn,
160            collector: self.collector,
161            impact_type: ImpactType::Penalty,
162            weight_fn,
163            is_hard: true,
164            _phantom: PhantomData,
165        }
166    }
167
168    /// Rewards each group with a weight based on the collector result.
169    pub fn reward_with<W>(self, weight_fn: W) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
170    where
171        W: Fn(&C::Result) -> Sc + Send + Sync,
172    {
173        GroupedConstraintBuilder {
174            extractor: self.extractor,
175            key_fn: self.key_fn,
176            collector: self.collector,
177            impact_type: ImpactType::Reward,
178            weight_fn,
179            is_hard: false,
180            _phantom: PhantomData,
181        }
182    }
183
184    /// Rewards each group with a weight, explicitly marked as hard constraint.
185    pub fn reward_hard_with<W>(
186        self,
187        weight_fn: W,
188    ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
189    where
190        W: Fn(&C::Result) -> Sc + Send + Sync,
191    {
192        GroupedConstraintBuilder {
193            extractor: self.extractor,
194            key_fn: self.key_fn,
195            collector: self.collector,
196            impact_type: ImpactType::Reward,
197            weight_fn,
198            is_hard: true,
199            _phantom: PhantomData,
200        }
201    }
202
203    /// Adds complement entities with default values for missing keys.
204    ///
205    /// This ensures all keys from the complement source are represented,
206    /// using the grouped value if present, or the default value otherwise.
207    ///
208    /// **Note:** The key function for A entities wraps the original key to
209    /// return `Some(K)`. For filtering (skipping entities without valid keys),
210    /// use `complement_filtered` instead.
211    ///
212    /// # Example
213    ///
214    /// ```
215    /// use solverforge_scoring::stream::ConstraintFactory;
216    /// use solverforge_scoring::stream::collector::count;
217    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
218    /// use solverforge_core::score::SimpleScore;
219    ///
220    /// #[derive(Clone, Hash, PartialEq, Eq)]
221    /// struct Employee { id: usize }
222    ///
223    /// #[derive(Clone, Hash, PartialEq, Eq)]
224    /// struct Shift { employee_id: usize }
225    ///
226    /// #[derive(Clone)]
227    /// struct Schedule {
228    ///     employees: Vec<Employee>,
229    ///     shifts: Vec<Shift>,
230    /// }
231    ///
232    /// // Count shifts per employee, including employees with 0 shifts
233    /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
234    ///     .for_each(|s: &Schedule| &s.shifts)
235    ///     .group_by(|shift: &Shift| shift.employee_id, count())
236    ///     .complement(
237    ///         |s: &Schedule| s.employees.as_slice(),
238    ///         |emp: &Employee| emp.id,
239    ///         |_emp: &Employee| 0usize,
240    ///     )
241    ///     .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
242    ///     .as_constraint("Shift count");
243    ///
244    /// let schedule = Schedule {
245    ///     employees: vec![Employee { id: 0 }, Employee { id: 1 }],
246    ///     shifts: vec![
247    ///         Shift { employee_id: 0 },
248    ///         Shift { employee_id: 0 },
249    ///     ],
250    /// };
251    ///
252    /// // Employee 0: 2, Employee 1: 0 → Total: -2
253    /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
254    /// ```
255    pub fn complement<B, EB, KB, D>(
256        self,
257        extractor_b: EB,
258        key_b: KB,
259        default_fn: D,
260    ) -> ComplementedConstraintStream<
261        S,
262        A,
263        B,
264        K,
265        E,
266        EB,
267        impl Fn(&A) -> Option<K> + Send + Sync,
268        KB,
269        C,
270        D,
271        Sc,
272    >
273    where
274        B: Clone + Send + Sync + 'static,
275        EB: Fn(&S) -> &[B] + Send + Sync,
276        KB: Fn(&B) -> K + Send + Sync,
277        D: Fn(&B) -> C::Result + Send + Sync,
278    {
279        let key_fn = self.key_fn;
280        let wrapped_key_fn = move |a: &A| Some((key_fn)(a));
281        ComplementedConstraintStream::new(
282            self.extractor,
283            extractor_b,
284            wrapped_key_fn,
285            key_b,
286            self.collector,
287            default_fn,
288        )
289    }
290
291    /// Adds complement entities with a custom key function for filtering.
292    ///
293    /// Like `complement`, but allows providing a custom key function for A entities
294    /// that returns `Option<K>`. Entities returning `None` are skipped.
295    ///
296    /// # Example
297    ///
298    /// ```
299    /// use solverforge_scoring::stream::ConstraintFactory;
300    /// use solverforge_scoring::stream::collector::count;
301    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
302    /// use solverforge_core::score::SimpleScore;
303    ///
304    /// #[derive(Clone, Hash, PartialEq, Eq)]
305    /// struct Employee { id: usize }
306    ///
307    /// #[derive(Clone, Hash, PartialEq, Eq)]
308    /// struct Shift { employee_id: Option<usize> }
309    ///
310    /// #[derive(Clone)]
311    /// struct Schedule {
312    ///     employees: Vec<Employee>,
313    ///     shifts: Vec<Shift>,
314    /// }
315    ///
316    /// // Count shifts per employee, skipping unassigned shifts
317    /// // The group_by key is ignored; complement_with_key provides its own
318    /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
319    ///     .for_each(|s: &Schedule| &s.shifts)
320    ///     .group_by(|_shift: &Shift| 0usize, count())  // Placeholder key, will be overridden
321    ///     .complement_with_key(
322    ///         |s: &Schedule| s.employees.as_slice(),
323    ///         |shift: &Shift| shift.employee_id,  // Option<usize>
324    ///         |emp: &Employee| emp.id,            // usize
325    ///         |_emp: &Employee| 0usize,
326    ///     )
327    ///     .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
328    ///     .as_constraint("Shift count");
329    ///
330    /// let schedule = Schedule {
331    ///     employees: vec![Employee { id: 0 }, Employee { id: 1 }],
332    ///     shifts: vec![
333    ///         Shift { employee_id: Some(0) },
334    ///         Shift { employee_id: Some(0) },
335    ///         Shift { employee_id: None },  // Skipped
336    ///     ],
337    /// };
338    ///
339    /// // Employee 0: 2, Employee 1: 0 → Total: -2
340    /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
341    /// ```
342    pub fn complement_with_key<B, EB, KA2, KB, D>(
343        self,
344        extractor_b: EB,
345        key_a: KA2,
346        key_b: KB,
347        default_fn: D,
348    ) -> ComplementedConstraintStream<S, A, B, K, E, EB, KA2, KB, C, D, Sc>
349    where
350        B: Clone + Send + Sync + 'static,
351        EB: Fn(&S) -> &[B] + Send + Sync,
352        KA2: Fn(&A) -> Option<K> + Send + Sync,
353        KB: Fn(&B) -> K + Send + Sync,
354        D: Fn(&B) -> C::Result + Send + Sync,
355    {
356        ComplementedConstraintStream::new(
357            self.extractor,
358            extractor_b,
359            key_a,
360            key_b,
361            self.collector,
362            default_fn,
363        )
364    }
365}
366
367impl<S, A, K, E, KF, C, Sc: Score> std::fmt::Debug
368    for GroupedConstraintStream<S, A, K, E, KF, C, Sc>
369{
370    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371        f.debug_struct("GroupedConstraintStream").finish()
372    }
373}
374
375/// Zero-erasure builder for finalizing a grouped constraint.
376pub struct GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
377where
378    Sc: Score,
379{
380    extractor: E,
381    key_fn: KF,
382    collector: C,
383    impact_type: ImpactType,
384    weight_fn: W,
385    is_hard: bool,
386    _phantom: PhantomData<(S, A, K, Sc)>,
387}
388
389impl<S, A, K, E, KF, C, W, Sc> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
390where
391    S: Send + Sync + 'static,
392    A: Clone + Send + Sync + 'static,
393    K: Clone + Eq + Hash + Send + Sync + 'static,
394    E: Fn(&S) -> &[A] + Send + Sync,
395    KF: Fn(&A) -> K + Send + Sync,
396    C: UniCollector<A> + Send + Sync + 'static,
397    C::Accumulator: Send + Sync,
398    C::Result: Clone + Send + Sync,
399    W: Fn(&C::Result) -> Sc + Send + Sync,
400    Sc: Score + 'static,
401{
402    /// Finalizes the builder into a zero-erasure `GroupedUniConstraint`.
403    ///
404    /// # Example
405    ///
406    /// ```
407    /// use solverforge_scoring::stream::ConstraintFactory;
408    /// use solverforge_scoring::stream::collector::count;
409    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
410    /// use solverforge_core::score::SimpleScore;
411    ///
412    /// #[derive(Clone, Hash, PartialEq, Eq)]
413    /// struct Item { category: u32 }
414    ///
415    /// #[derive(Clone)]
416    /// struct Solution { items: Vec<Item> }
417    ///
418    /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
419    ///     .for_each(|s: &Solution| &s.items)
420    ///     .group_by(|i: &Item| i.category, count())
421    ///     .penalize_with(|n: &usize| SimpleScore::of(*n as i64))
422    ///     .as_constraint("Category penalty");
423    ///
424    /// assert_eq!(constraint.name(), "Category penalty");
425    /// ```
426    pub fn as_constraint(self, name: &str) -> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc> {
427        GroupedUniConstraint::new(
428            ConstraintRef::new("", name),
429            self.impact_type,
430            self.extractor,
431            self.key_fn,
432            self.collector,
433            self.weight_fn,
434            self.is_hard,
435        )
436    }
437}
438
439impl<S, A, K, E, KF, C, W, Sc: Score> std::fmt::Debug
440    for GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
441{
442    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
443        f.debug_struct("GroupedConstraintBuilder")
444            .field("impact_type", &self.impact_type)
445            .finish()
446    }
447}