Skip to main content

solverforge_scoring/stream/
grouped_stream.rs

1// Zero-erasure grouped constraint stream for group-by constraint patterns.
2//
3// A `GroupedConstraintStream` operates on groups of entities and supports
4// filtering, weighting, and constraint finalization.
5// All type information is preserved at compile time - no Arc, no dyn.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collector::UniCollector;
14use super::complemented_stream::ComplementedConstraintStream;
15use crate::constraint::grouped::GroupedUniConstraint;
16
17// Zero-erasure constraint stream over grouped entities.
18//
19// `GroupedConstraintStream` is created by `UniConstraintStream::group_by()`
20// and operates on (key, collector_result) tuples.
21//
22// All type parameters are concrete - no trait objects, no Arc allocations.
23//
24// # Type Parameters
25//
26// - `S` - Solution type
27// - `A` - Entity type
28// - `K` - Group key type
29// - `E` - Extractor function for entities
30// - `KF` - Key function
31// - `C` - Collector type
32// - `Sc` - Score type
33//
34// # Example
35//
36// ```
37// use solverforge_scoring::stream::ConstraintFactory;
38// use solverforge_scoring::stream::collector::count;
39// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
40// use solverforge_core::score::SimpleScore;
41//
42// #[derive(Clone, Hash, PartialEq, Eq)]
43// struct Shift { employee_id: usize }
44//
45// #[derive(Clone)]
46// struct Solution { shifts: Vec<Shift> }
47//
48// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
49//     .for_each(|s: &Solution| &s.shifts)
50//     .group_by(|shift: &Shift| shift.employee_id, count())
51//     .penalize_with(|count: &usize| SimpleScore::of((*count * *count) as i64))
52//     .as_constraint("Balanced workload");
53//
54// let solution = Solution {
55//     shifts: vec![
56//         Shift { employee_id: 1 },
57//         Shift { employee_id: 1 },
58//         Shift { employee_id: 1 },
59//         Shift { employee_id: 2 },
60//     ],
61// };
62//
63// // Employee 1: 3² = 9, Employee 2: 1² = 1, Total: -10
64// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-10));
65// ```
66pub struct GroupedConstraintStream<S, A, K, E, KF, C, Sc>
67where
68    Sc: Score,
69{
70    extractor: E,
71    key_fn: KF,
72    collector: C,
73    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
74}
75
76impl<S, A, K, E, KF, C, Sc> GroupedConstraintStream<S, A, K, E, KF, C, Sc>
77where
78    S: Send + Sync + 'static,
79    A: Clone + Send + Sync + 'static,
80    K: Clone + Eq + Hash + Send + Sync + 'static,
81    E: Fn(&S) -> &[A] + Send + Sync,
82    KF: Fn(&A) -> K + Send + Sync,
83    C: UniCollector<A> + Send + Sync + 'static,
84    C::Accumulator: Send + Sync,
85    C::Result: Clone + Send + Sync,
86    Sc: Score + 'static,
87{
88    // Creates a new zero-erasure grouped constraint stream.
89    pub(crate) fn new(extractor: E, key_fn: KF, collector: C) -> Self {
90        Self {
91            extractor,
92            key_fn,
93            collector,
94            _phantom: PhantomData,
95        }
96    }
97
98    // Penalizes each group with a weight based on the collector result.
99    //
100    // # Example
101    //
102    // ```
103    // use solverforge_scoring::stream::ConstraintFactory;
104    // use solverforge_scoring::stream::collector::count;
105    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
106    // use solverforge_core::score::SimpleScore;
107    //
108    // #[derive(Clone, Hash, PartialEq, Eq)]
109    // struct Task { priority: u32 }
110    //
111    // #[derive(Clone)]
112    // struct Solution { tasks: Vec<Task> }
113    //
114    // let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
115    //     .for_each(|s: &Solution| &s.tasks)
116    //     .group_by(|t: &Task| t.priority, count())
117    //     .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
118    //     .as_constraint("Priority distribution");
119    //
120    // let solution = Solution {
121    //     tasks: vec![
122    //         Task { priority: 1 },
123    //         Task { priority: 1 },
124    //         Task { priority: 2 },
125    //     ],
126    // };
127    //
128    // // Priority 1: 2 tasks, Priority 2: 1 task, Total: -3
129    // assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-3));
130    // ```
131    pub fn penalize_with<W>(
132        self,
133        weight_fn: W,
134    ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
135    where
136        W: Fn(&C::Result) -> Sc + Send + Sync,
137    {
138        GroupedConstraintBuilder {
139            extractor: self.extractor,
140            key_fn: self.key_fn,
141            collector: self.collector,
142            impact_type: ImpactType::Penalty,
143            weight_fn,
144            is_hard: false,
145            _phantom: PhantomData,
146        }
147    }
148
149    // Penalizes each group with a weight, explicitly marked as hard constraint.
150    pub fn penalize_hard_with<W>(
151        self,
152        weight_fn: W,
153    ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
154    where
155        W: Fn(&C::Result) -> Sc + Send + Sync,
156    {
157        GroupedConstraintBuilder {
158            extractor: self.extractor,
159            key_fn: self.key_fn,
160            collector: self.collector,
161            impact_type: ImpactType::Penalty,
162            weight_fn,
163            is_hard: true,
164            _phantom: PhantomData,
165        }
166    }
167
168    // Rewards each group with a weight based on the collector result.
169    pub fn reward_with<W>(self, weight_fn: W) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
170    where
171        W: Fn(&C::Result) -> Sc + Send + Sync,
172    {
173        GroupedConstraintBuilder {
174            extractor: self.extractor,
175            key_fn: self.key_fn,
176            collector: self.collector,
177            impact_type: ImpactType::Reward,
178            weight_fn,
179            is_hard: false,
180            _phantom: PhantomData,
181        }
182    }
183
184    // Rewards each group with a weight, explicitly marked as hard constraint.
185    pub fn reward_hard_with<W>(
186        self,
187        weight_fn: W,
188    ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
189    where
190        W: Fn(&C::Result) -> Sc + Send + Sync,
191    {
192        GroupedConstraintBuilder {
193            extractor: self.extractor,
194            key_fn: self.key_fn,
195            collector: self.collector,
196            impact_type: ImpactType::Reward,
197            weight_fn,
198            is_hard: true,
199            _phantom: PhantomData,
200        }
201    }
202
203    // Adds complement entities with default values for missing keys.
204    //
205    // This ensures all keys from the complement source are represented,
206    // using the grouped value if present, or the default value otherwise.
207    //
208    // **Note:** The key function for A entities wraps the original key to
209    // return `Some(K)`. For filtering (skipping entities without valid keys),
210    // use `complement_filtered` instead.
211    //
212    // # Example
213    //
214    // ```
215    // use solverforge_scoring::stream::ConstraintFactory;
216    // use solverforge_scoring::stream::collector::count;
217    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
218    // use solverforge_core::score::SimpleScore;
219    //
220    // #[derive(Clone, Hash, PartialEq, Eq)]
221    // struct Employee { id: usize }
222    //
223    // #[derive(Clone, Hash, PartialEq, Eq)]
224    // struct Shift { employee_id: usize }
225    //
226    // #[derive(Clone)]
227    // struct Schedule {
228    //     employees: Vec<Employee>,
229    //     shifts: Vec<Shift>,
230    // }
231    //
232    // // Count shifts per employee, including employees with 0 shifts
233    // let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
234    //     .for_each(|s: &Schedule| &s.shifts)
235    //     .group_by(|shift: &Shift| shift.employee_id, count())
236    //     .complement(
237    //         |s: &Schedule| s.employees.as_slice(),
238    //         |emp: &Employee| emp.id,
239    //         |_emp: &Employee| 0usize,
240    //     )
241    //     .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
242    //     .as_constraint("Shift count");
243    //
244    // let schedule = Schedule {
245    //     employees: vec![Employee { id: 0 }, Employee { id: 1 }],
246    //     shifts: vec![
247    //         Shift { employee_id: 0 },
248    //         Shift { employee_id: 0 },
249    //     ],
250    // };
251    //
252    // // Employee 0: 2, Employee 1: 0 → Total: -2
253    // assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
254    // ```
255    pub fn complement<B, EB, KB, D>(
256        self,
257        extractor_b: EB,
258        key_b: KB,
259        default_fn: D,
260    ) -> ComplementedConstraintStream<
261        S,
262        A,
263        B,
264        K,
265        E,
266        EB,
267        impl Fn(&A) -> Option<K> + Send + Sync,
268        KB,
269        C,
270        D,
271        Sc,
272    >
273    where
274        B: Clone + Send + Sync + 'static,
275        EB: Fn(&S) -> &[B] + Send + Sync,
276        KB: Fn(&B) -> K + Send + Sync,
277        D: Fn(&B) -> C::Result + Send + Sync,
278    {
279        let key_fn = self.key_fn;
280        let wrapped_key_fn = move |a: &A| Some((key_fn)(a));
281        ComplementedConstraintStream::new(
282            self.extractor,
283            extractor_b,
284            wrapped_key_fn,
285            key_b,
286            self.collector,
287            default_fn,
288        )
289    }
290
291    // Adds complement entities with a custom key function for filtering.
292    //
293    // Like `complement`, but allows providing a custom key function for A entities
294    // that returns `Option<K>`. Entities returning `None` are skipped.
295    //
296    // # Example
297    //
298    // ```
299    // use solverforge_scoring::stream::ConstraintFactory;
300    // use solverforge_scoring::stream::collector::count;
301    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
302    // use solverforge_core::score::SimpleScore;
303    //
304    // #[derive(Clone, Hash, PartialEq, Eq)]
305    // struct Employee { id: usize }
306    //
307    // #[derive(Clone, Hash, PartialEq, Eq)]
308    // struct Shift { employee_id: Option<usize> }
309    //
310    // #[derive(Clone)]
311    // struct Schedule {
312    //     employees: Vec<Employee>,
313    //     shifts: Vec<Shift>,
314    // }
315    //
316    // // Count shifts per employee, skipping unassigned shifts
317    // // The group_by key is ignored; complement_with_key provides its own
318    // let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
319    //     .for_each(|s: &Schedule| &s.shifts)
320    //     .group_by(|_shift: &Shift| 0usize, count())  // Placeholder key, will be overridden
321    //     .complement_with_key(
322    //         |s: &Schedule| s.employees.as_slice(),
323    //         |shift: &Shift| shift.employee_id,  // Option<usize>
324    //         |emp: &Employee| emp.id,            // usize
325    //         |_emp: &Employee| 0usize,
326    //     )
327    //     .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
328    //     .as_constraint("Shift count");
329    //
330    // let schedule = Schedule {
331    //     employees: vec![Employee { id: 0 }, Employee { id: 1 }],
332    //     shifts: vec![
333    //         Shift { employee_id: Some(0) },
334    //         Shift { employee_id: Some(0) },
335    //         Shift { employee_id: None },  // Skipped
336    //     ],
337    // };
338    //
339    // // Employee 0: 2, Employee 1: 0 → Total: -2
340    // assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
341    // ```
342    pub fn complement_with_key<B, EB, KA2, KB, D>(
343        self,
344        extractor_b: EB,
345        key_a: KA2,
346        key_b: KB,
347        default_fn: D,
348    ) -> ComplementedConstraintStream<S, A, B, K, E, EB, KA2, KB, C, D, Sc>
349    where
350        B: Clone + Send + Sync + 'static,
351        EB: Fn(&S) -> &[B] + Send + Sync,
352        KA2: Fn(&A) -> Option<K> + Send + Sync,
353        KB: Fn(&B) -> K + Send + Sync,
354        D: Fn(&B) -> C::Result + Send + Sync,
355    {
356        ComplementedConstraintStream::new(
357            self.extractor,
358            extractor_b,
359            key_a,
360            key_b,
361            self.collector,
362            default_fn,
363        )
364    }
365}
366
367impl<S, A, K, E, KF, C, Sc: Score> std::fmt::Debug
368    for GroupedConstraintStream<S, A, K, E, KF, C, Sc>
369{
370    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371        f.debug_struct("GroupedConstraintStream").finish()
372    }
373}
374
375// Zero-erasure builder for finalizing a grouped constraint.
376pub struct GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
377where
378    Sc: Score,
379{
380    extractor: E,
381    key_fn: KF,
382    collector: C,
383    impact_type: ImpactType,
384    weight_fn: W,
385    is_hard: bool,
386    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
387}
388
389impl<S, A, K, E, KF, C, W, Sc> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
390where
391    S: Send + Sync + 'static,
392    A: Clone + Send + Sync + 'static,
393    K: Clone + Eq + Hash + Send + Sync + 'static,
394    E: Fn(&S) -> &[A] + Send + Sync,
395    KF: Fn(&A) -> K + Send + Sync,
396    C: UniCollector<A> + Send + Sync + 'static,
397    C::Accumulator: Send + Sync,
398    C::Result: Clone + Send + Sync,
399    W: Fn(&C::Result) -> Sc + Send + Sync,
400    Sc: Score + 'static,
401{
402    // Finalizes the builder into a zero-erasure `GroupedUniConstraint`.
403    //
404    // # Example
405    //
406    // ```
407    // use solverforge_scoring::stream::ConstraintFactory;
408    // use solverforge_scoring::stream::collector::count;
409    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
410    // use solverforge_core::score::SimpleScore;
411    //
412    // #[derive(Clone, Hash, PartialEq, Eq)]
413    // struct Item { category: u32 }
414    //
415    // #[derive(Clone)]
416    // struct Solution { items: Vec<Item> }
417    //
418    // let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
419    //     .for_each(|s: &Solution| &s.items)
420    //     .group_by(|i: &Item| i.category, count())
421    //     .penalize_with(|n: &usize| SimpleScore::of(*n as i64))
422    //     .as_constraint("Category penalty");
423    //
424    // assert_eq!(constraint.name(), "Category penalty");
425    // ```
426    pub fn as_constraint(self, name: &str) -> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc> {
427        GroupedUniConstraint::new(
428            ConstraintRef::new("", name),
429            self.impact_type,
430            self.extractor,
431            self.key_fn,
432            self.collector,
433            self.weight_fn,
434            self.is_hard,
435        )
436    }
437}
438
439impl<S, A, K, E, KF, C, W, Sc: Score> std::fmt::Debug
440    for GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
441{
442    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
443        f.debug_struct("GroupedConstraintBuilder")
444            .field("impact_type", &self.impact_type)
445            .finish()
446    }
447}