// solverforge_scoring/stream/grouped_stream.rs
//! Zero-erasure grouped constraint stream for group-by constraint patterns.
//!
//! A `GroupedConstraintStream` operates on groups of entities and supports
//! filtering, weighting, and constraint finalization.
//! All type information is preserved at compile time - no Arc, no dyn.

use std::hash::Hash;
use std::marker::PhantomData;

use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};

use super::collector::UniCollector;
use super::complemented_stream::ComplementedConstraintStream;
use super::filter::UniFilter;
use crate::constraint::grouped::GroupedUniConstraint;
/// Zero-erasure constraint stream over grouped entities.
///
/// `GroupedConstraintStream` is created by `UniConstraintStream::group_by()`
/// and operates on (key, collector_result) tuples.
///
/// All type parameters are concrete - no trait objects, no Arc allocations.
///
/// # Type Parameters
///
/// - `S` - Solution type
/// - `A` - Entity type
/// - `K` - Group key type
/// - `E` - Extractor function for entities
/// - `Fi` - Filter type (preserved from upstream stream)
/// - `KF` - Key function
/// - `C` - Collector type
/// - `Sc` - Score type
///
/// # Example
///
/// ```
/// use solverforge_scoring::stream::ConstraintFactory;
/// use solverforge_scoring::stream::collector::count;
/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
/// use solverforge_core::score::SoftScore;
///
/// #[derive(Clone, Hash, PartialEq, Eq)]
/// struct Shift { employee_id: usize }
///
/// #[derive(Clone)]
/// struct Solution { shifts: Vec<Shift> }
///
/// let constraint = ConstraintFactory::<Solution, SoftScore>::new()
///     .for_each(|s: &Solution| &s.shifts)
///     .group_by(|shift: &Shift| shift.employee_id, count())
///     .penalize_with(|count: &usize| SoftScore::of((*count * *count) as i64))
///     .as_constraint("Balanced workload");
///
/// let solution = Solution {
///     shifts: vec![
///         Shift { employee_id: 1 },
///         Shift { employee_id: 1 },
///         Shift { employee_id: 1 },
///         Shift { employee_id: 2 },
///     ],
/// };
///
/// // Employee 1: 3² = 9, Employee 2: 1² = 1, Total: -10
/// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-10));
/// ```
pub struct GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
where
    Sc: Score,
{
    // Extractor yielding the entity slice from the solution.
    extractor: E,
    // Filter preserved from the upstream (pre-group_by) stream.
    filter: Fi,
    // Maps each entity to its group key.
    key_fn: KF,
    // Aggregates the entities of each group into a single result.
    collector: C,
    // `fn() -> T` keeps the phantom covariant in S/A/K/Sc and `Send + Sync`
    // without requiring those types to implement either trait themselves.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
}
78
impl<S, A, K, E, Fi, KF, C, Sc> GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    K: Clone + Eq + Hash + Send + Sync + 'static,
    E: Fn(&S) -> &[A] + Send + Sync,
    Fi: UniFilter<S, A>,
    KF: Fn(&A) -> K + Send + Sync,
    C: UniCollector<A> + Send + Sync + 'static,
    C::Accumulator: Send + Sync,
    C::Result: Clone + Send + Sync,
    Sc: Score + 'static,
{
    /// Creates a new zero-erasure grouped constraint stream.
    pub(crate) fn new(extractor: E, filter: Fi, key_fn: KF, collector: C) -> Self {
        Self {
            extractor,
            filter,
            key_fn,
            collector,
            _phantom: PhantomData,
        }
    }

    /// Penalizes each group with a weight based on the collector result.
    ///
    /// The resulting builder records `ImpactType::Penalty` and is not marked
    /// hard; finalize it with `as_constraint`.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::collector::count;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SoftScore;
    ///
    /// #[derive(Clone, Hash, PartialEq, Eq)]
    /// struct Task { priority: u32 }
    ///
    /// #[derive(Clone)]
    /// struct Solution { tasks: Vec<Task> }
    ///
    /// let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    ///     .for_each(|s: &Solution| &s.tasks)
    ///     .group_by(|t: &Task| t.priority, count())
    ///     .penalize_with(|count: &usize| SoftScore::of(*count as i64))
    ///     .as_constraint("Priority distribution");
    ///
    /// let solution = Solution {
    ///     tasks: vec![
    ///         Task { priority: 1 },
    ///         Task { priority: 1 },
    ///         Task { priority: 2 },
    ///     ],
    /// };
    ///
    /// // Priority 1: 2 tasks, Priority 2: 1 task, Total: -3
    /// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-3));
    /// ```
    pub fn penalize_with<W>(
        self,
        weight_fn: W,
    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
    where
        W: Fn(&C::Result) -> Sc + Send + Sync,
    {
        GroupedConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            key_fn: self.key_fn,
            collector: self.collector,
            impact_type: ImpactType::Penalty,
            weight_fn,
            is_hard: false,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Penalizes each group with a weight, explicitly marked as hard constraint.
    ///
    /// Identical to `penalize_with` except the builder's `is_hard` flag is set.
    pub fn penalize_hard_with<W>(
        self,
        weight_fn: W,
    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
    where
        W: Fn(&C::Result) -> Sc + Send + Sync,
    {
        GroupedConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            key_fn: self.key_fn,
            collector: self.collector,
            impact_type: ImpactType::Penalty,
            weight_fn,
            is_hard: true,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Rewards each group with a weight based on the collector result.
    ///
    /// Same shape as `penalize_with`, but records `ImpactType::Reward`.
    pub fn reward_with<W>(
        self,
        weight_fn: W,
    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
    where
        W: Fn(&C::Result) -> Sc + Send + Sync,
    {
        GroupedConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            key_fn: self.key_fn,
            collector: self.collector,
            impact_type: ImpactType::Reward,
            weight_fn,
            is_hard: false,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Rewards each group with a weight, explicitly marked as hard constraint.
    ///
    /// Identical to `reward_with` except the builder's `is_hard` flag is set.
    pub fn reward_hard_with<W>(
        self,
        weight_fn: W,
    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
    where
        W: Fn(&C::Result) -> Sc + Send + Sync,
    {
        GroupedConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            key_fn: self.key_fn,
            collector: self.collector,
            impact_type: ImpactType::Reward,
            weight_fn,
            is_hard: true,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Penalizes each group with one hard score unit.
    ///
    /// Convenience wrapper over `penalize_hard_with` with a constant weight of
    /// `Sc::one_hard()` (hence the `Sc: Copy` bound — the captured weight is
    /// returned by value from the closure).
    pub fn penalize_hard(
        self,
    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        let w = Sc::one_hard();
        self.penalize_hard_with(move |_: &C::Result| w)
    }

    /// Penalizes each group with one soft score unit.
    ///
    /// Convenience wrapper over `penalize_with` with a constant weight of
    /// `Sc::one_soft()`.
    pub fn penalize_soft(
        self,
    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        let w = Sc::one_soft();
        self.penalize_with(move |_: &C::Result| w)
    }

    /// Rewards each group with one hard score unit.
    ///
    /// Convenience wrapper over `reward_hard_with` with a constant weight of
    /// `Sc::one_hard()`.
    pub fn reward_hard(
        self,
    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        let w = Sc::one_hard();
        self.reward_hard_with(move |_: &C::Result| w)
    }

    /// Rewards each group with one soft score unit.
    ///
    /// Convenience wrapper over `reward_with` with a constant weight of
    /// `Sc::one_soft()`.
    pub fn reward_soft(
        self,
    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        let w = Sc::one_soft();
        self.reward_with(move |_: &C::Result| w)
    }

    /// Adds complement entities with default values for missing keys.
    ///
    /// This ensures all keys from the complement source are represented,
    /// using the grouped value if present, or the default value otherwise.
    ///
    /// **Note:** The key function for A entities wraps the original key to
    /// return `Some(K)`. For filtering (skipping entities without valid keys),
    /// use `complement_filtered` instead.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::collector::count;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SoftScore;
    ///
    /// #[derive(Clone, Hash, PartialEq, Eq)]
    /// struct Employee { id: usize }
    ///
    /// #[derive(Clone, Hash, PartialEq, Eq)]
    /// struct Shift { employee_id: usize }
    ///
    /// #[derive(Clone)]
    /// struct Schedule {
    ///     employees: Vec<Employee>,
    ///     shifts: Vec<Shift>,
    /// }
    ///
    /// // Count shifts per employee, including employees with 0 shifts
    /// let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
    ///     .for_each(|s: &Schedule| &s.shifts)
    ///     .group_by(|shift: &Shift| shift.employee_id, count())
    ///     .complement(
    ///         |s: &Schedule| s.employees.as_slice(),
    ///         |emp: &Employee| emp.id,
    ///         |_emp: &Employee| 0usize,
    ///     )
    ///     .penalize_with(|count: &usize| SoftScore::of(*count as i64))
    ///     .as_constraint("Shift count");
    ///
    /// let schedule = Schedule {
    ///     employees: vec![Employee { id: 0 }, Employee { id: 1 }],
    ///     shifts: vec![
    ///         Shift { employee_id: 0 },
    ///         Shift { employee_id: 0 },
    ///     ],
    /// };
    ///
    /// // Employee 0: 2, Employee 1: 0 → Total: -2
    /// assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
    /// ```
    pub fn complement<B, EB, KB, D>(
        self,
        extractor_b: EB,
        key_b: KB,
        default_fn: D,
    ) -> ComplementedConstraintStream<
        S,
        A,
        B,
        K,
        E,
        EB,
        impl Fn(&A) -> Option<K> + Send + Sync,
        KB,
        C,
        D,
        Sc,
    >
    where
        B: Clone + Send + Sync + 'static,
        EB: Fn(&S) -> &[B] + Send + Sync,
        KB: Fn(&B) -> K + Send + Sync,
        D: Fn(&B) -> C::Result + Send + Sync,
    {
        // Lift the infallible key fn into the Option-returning shape that
        // ComplementedConstraintStream expects; every A keeps its key.
        let key_fn = self.key_fn;
        let wrapped_key_fn = move |a: &A| Some((key_fn)(a));
        // NOTE(review): `self.filter` is not forwarded here — the upstream
        // filter appears to be dropped when complementing. Confirm intended.
        ComplementedConstraintStream::new(
            self.extractor,
            extractor_b,
            wrapped_key_fn,
            key_b,
            self.collector,
            default_fn,
        )
    }

    /// Adds complement entities with a custom key function for filtering.
    ///
    /// Like `complement`, but allows providing a custom key function for A entities
    /// that returns `Option<K>`. Entities returning `None` are skipped.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::collector::count;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SoftScore;
    ///
    /// #[derive(Clone, Hash, PartialEq, Eq)]
    /// struct Employee { id: usize }
    ///
    /// #[derive(Clone, Hash, PartialEq, Eq)]
    /// struct Shift { employee_id: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Schedule {
    ///     employees: Vec<Employee>,
    ///     shifts: Vec<Shift>,
    /// }
    ///
    /// // Count shifts per employee, skipping unassigned shifts
    /// // The group_by key is ignored; complement_with_key provides its own
    /// let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
    ///     .for_each(|s: &Schedule| &s.shifts)
    ///     .group_by(|_shift: &Shift| 0usize, count())  // Placeholder key, will be overridden
    ///     .complement_with_key(
    ///         |s: &Schedule| s.employees.as_slice(),
    ///         |shift: &Shift| shift.employee_id,  // Option<usize>
    ///         |emp: &Employee| emp.id,            // usize
    ///         |_emp: &Employee| 0usize,
    ///     )
    ///     .penalize_with(|count: &usize| SoftScore::of(*count as i64))
    ///     .as_constraint("Shift count");
    ///
    /// let schedule = Schedule {
    ///     employees: vec![Employee { id: 0 }, Employee { id: 1 }],
    ///     shifts: vec![
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: None },  // Skipped
    ///     ],
    /// };
    ///
    /// // Employee 0: 2, Employee 1: 0 → Total: -2
    /// assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
    /// ```
    pub fn complement_with_key<B, EB, KA2, KB, D>(
        self,
        extractor_b: EB,
        key_a: KA2,
        key_b: KB,
        default_fn: D,
    ) -> ComplementedConstraintStream<S, A, B, K, E, EB, KA2, KB, C, D, Sc>
    where
        B: Clone + Send + Sync + 'static,
        EB: Fn(&S) -> &[B] + Send + Sync,
        KA2: Fn(&A) -> Option<K> + Send + Sync,
        KB: Fn(&B) -> K + Send + Sync,
        D: Fn(&B) -> C::Result + Send + Sync,
    {
        // `self.key_fn` (from group_by) is intentionally discarded here; the
        // caller-supplied `key_a` replaces it, as documented above.
        ComplementedConstraintStream::new(
            self.extractor,
            extractor_b,
            key_a,
            key_b,
            self.collector,
            default_fn,
        )
    }
}
426
427impl<S, A, K, E, Fi, KF, C, Sc: Score> std::fmt::Debug
428    for GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
429{
430    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
431        f.debug_struct("GroupedConstraintStream").finish()
432    }
433}
434
/// Zero-erasure builder for finalizing a grouped constraint.
///
/// Produced by the `penalize_*` / `reward_*` methods of
/// `GroupedConstraintStream`; call `as_constraint` (or its alias `named`)
/// to obtain the finished `GroupedUniConstraint`.
pub struct GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
where
    Sc: Score,
{
    // Components carried over unchanged from the originating stream.
    extractor: E,
    filter: Fi,
    key_fn: KF,
    collector: C,
    // Whether the weight is applied as a penalty or a reward.
    impact_type: ImpactType,
    // Maps each group's collector result to a score weight.
    weight_fn: W,
    // Marks the finished constraint as hard (vs. soft).
    is_hard: bool,
    // Optional descriptor index, set via `for_descriptor`.
    expected_descriptor: Option<usize>,
    // fn() pointers: covariant, Send + Sync phantom without extra bounds.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
}
450
impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    K: Clone + Eq + Hash + Send + Sync + 'static,
    E: Fn(&S) -> &[A] + Send + Sync,
    Fi: UniFilter<S, A>,
    KF: Fn(&A) -> K + Send + Sync,
    C: UniCollector<A> + Send + Sync + 'static,
    C::Accumulator: Send + Sync,
    C::Result: Clone + Send + Sync,
    W: Fn(&C::Result) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    /// Alias for [`as_constraint`](Self::as_constraint).
    pub fn named(self, name: &str) -> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc> {
        self.as_constraint(name)
    }

    /// Records the descriptor index the finished constraint should be
    /// associated with; applied in `as_constraint` via `with_descriptor`.
    pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
        self.expected_descriptor = Some(descriptor_index);
        self
    }

    /// Finalizes the builder into a zero-erasure `GroupedUniConstraint`.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::collector::count;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SoftScore;
    ///
    /// #[derive(Clone, Hash, PartialEq, Eq)]
    /// struct Item { category: u32 }
    ///
    /// #[derive(Clone)]
    /// struct Solution { items: Vec<Item> }
    ///
    /// let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    ///     .for_each(|s: &Solution| &s.items)
    ///     .group_by(|i: &Item| i.category, count())
    ///     .penalize_with(|n: &usize| SoftScore::of(*n as i64))
    ///     .as_constraint("Category penalty");
    ///
    /// assert_eq!(constraint.name(), "Category penalty");
    /// ```
    pub fn as_constraint(self, name: &str) -> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc> {
        let mut constraint = GroupedUniConstraint::new(
            // Empty package segment; only `name` identifies the constraint here.
            ConstraintRef::new("", name),
            self.impact_type,
            self.extractor,
            self.filter,
            self.key_fn,
            self.collector,
            self.weight_fn,
            self.is_hard,
        );
        if let Some(d) = self.expected_descriptor {
            constraint = constraint.with_descriptor(d);
        }
        constraint
    }
}
516
517impl<S, A, K, E, Fi, KF, C, W, Sc: Score> std::fmt::Debug
518    for GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
519{
520    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
521        f.debug_struct("GroupedConstraintBuilder")
522            .field("impact_type", &self.impact_type)
523            .finish()
524    }
525}