Skip to main content

solverforge_scoring/stream/
grouped_stream.rs

1/* Zero-erasure grouped constraint stream for group-by constraint patterns.
2
3A `GroupedConstraintStream` operates on groups of entities and supports
4filtering, weighting, and constraint finalization.
5All type information is preserved at compile time - no Arc, no dyn.
6*/
7
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use super::collection_extract::CollectionExtract;
15use super::collector::UniCollector;
16use super::complemented_stream::ComplementedConstraintStream;
17use super::filter::UniFilter;
18use crate::constraint::grouped::GroupedUniConstraint;
19
20/* Zero-erasure constraint stream over grouped entities.
21
22`GroupedConstraintStream` is created by `UniConstraintStream::group_by()`
23and operates on (key, collector_result) tuples.
24
25All type parameters are concrete - no trait objects, no Arc allocations.
26
27# Type Parameters
28
29- `S` - Solution type
30- `A` - Entity type
31- `K` - Group key type
32- `E` - Extractor function for entities
33- `Fi` - Filter type (preserved from upstream stream)
34- `KF` - Key function
35- `C` - Collector type
36- `Sc` - Score type
37
38# Example
39
40```
41use solverforge_scoring::stream::ConstraintFactory;
42use solverforge_scoring::stream::collector::count;
43use solverforge_scoring::api::constraint_set::IncrementalConstraint;
44use solverforge_core::score::SoftScore;
45
46#[derive(Clone, Hash, PartialEq, Eq)]
47struct Shift { employee_id: usize }
48
49#[derive(Clone)]
50struct Solution { shifts: Vec<Shift> }
51
52let constraint = ConstraintFactory::<Solution, SoftScore>::new()
53.for_each(|s: &Solution| &s.shifts)
54.group_by(|shift: &Shift| shift.employee_id, count())
55.penalize_with(|count: &usize| SoftScore::of((*count * *count) as i64))
56.named("Balanced workload");
57
58let solution = Solution {
59shifts: vec![
60Shift { employee_id: 1 },
61Shift { employee_id: 1 },
62Shift { employee_id: 1 },
63Shift { employee_id: 2 },
64],
65};
66
67// Employee 1: 3² = 9, Employee 2: 1² = 1, Total: -10
68assert_eq!(constraint.evaluate(&solution), SoftScore::of(-10));
69```
70*/
71pub struct GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
72where
73    Sc: Score,
74{
75    extractor: E,
76    filter: Fi,
77    key_fn: KF,
78    collector: C,
79    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
80}
81
82impl<S, A, K, E, Fi, KF, C, Sc> GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
83where
84    S: Send + Sync + 'static,
85    A: Clone + Send + Sync + 'static,
86    K: Clone + Eq + Hash + Send + Sync + 'static,
87    E: CollectionExtract<S, Item = A>,
88    Fi: UniFilter<S, A>,
89    KF: Fn(&A) -> K + Send + Sync,
90    C: UniCollector<A> + Send + Sync + 'static,
91    C::Accumulator: Send + Sync,
92    C::Result: Clone + Send + Sync,
93    Sc: Score + 'static,
94{
95    // Creates a new zero-erasure grouped constraint stream.
96    pub(crate) fn new(extractor: E, filter: Fi, key_fn: KF, collector: C) -> Self {
97        Self {
98            extractor,
99            filter,
100            key_fn,
101            collector,
102            _phantom: PhantomData,
103        }
104    }
105
106    /* Penalizes each group with a weight based on the collector result.
107
108    # Example
109
110    ```
111    use solverforge_scoring::stream::ConstraintFactory;
112    use solverforge_scoring::stream::collector::count;
113    use solverforge_scoring::api::constraint_set::IncrementalConstraint;
114    use solverforge_core::score::SoftScore;
115
116    #[derive(Clone, Hash, PartialEq, Eq)]
117    struct Task { priority: u32 }
118
119    #[derive(Clone)]
120    struct Solution { tasks: Vec<Task> }
121
122    let constraint = ConstraintFactory::<Solution, SoftScore>::new()
123    .for_each(|s: &Solution| &s.tasks)
124    .group_by(|t: &Task| t.priority, count())
125    .penalize_with(|count: &usize| SoftScore::of(*count as i64))
126    .named("Priority distribution");
127
128    let solution = Solution {
129    tasks: vec![
130    Task { priority: 1 },
131    Task { priority: 1 },
132    Task { priority: 2 },
133    ],
134    };
135
136    // Priority 1: 2 tasks, Priority 2: 1 task, Total: -3
137    assert_eq!(constraint.evaluate(&solution), SoftScore::of(-3));
138    ```
139    */
140    pub fn penalize_with<W>(
141        self,
142        weight_fn: W,
143    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
144    where
145        W: Fn(&C::Result) -> Sc + Send + Sync,
146    {
147        GroupedConstraintBuilder {
148            extractor: self.extractor,
149            filter: self.filter,
150            key_fn: self.key_fn,
151            collector: self.collector,
152            impact_type: ImpactType::Penalty,
153            weight_fn,
154            is_hard: false,
155            expected_descriptor: None,
156            _phantom: PhantomData,
157        }
158    }
159
160    // Penalizes each group with a weight, explicitly marked as hard constraint.
161    pub fn penalize_hard_with<W>(
162        self,
163        weight_fn: W,
164    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
165    where
166        W: Fn(&C::Result) -> Sc + Send + Sync,
167    {
168        GroupedConstraintBuilder {
169            extractor: self.extractor,
170            filter: self.filter,
171            key_fn: self.key_fn,
172            collector: self.collector,
173            impact_type: ImpactType::Penalty,
174            weight_fn,
175            is_hard: true,
176            expected_descriptor: None,
177            _phantom: PhantomData,
178        }
179    }
180
181    // Rewards each group with a weight based on the collector result.
182    pub fn reward_with<W>(
183        self,
184        weight_fn: W,
185    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
186    where
187        W: Fn(&C::Result) -> Sc + Send + Sync,
188    {
189        GroupedConstraintBuilder {
190            extractor: self.extractor,
191            filter: self.filter,
192            key_fn: self.key_fn,
193            collector: self.collector,
194            impact_type: ImpactType::Reward,
195            weight_fn,
196            is_hard: false,
197            expected_descriptor: None,
198            _phantom: PhantomData,
199        }
200    }
201
202    // Rewards each group with a weight, explicitly marked as hard constraint.
203    pub fn reward_hard_with<W>(
204        self,
205        weight_fn: W,
206    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
207    where
208        W: Fn(&C::Result) -> Sc + Send + Sync,
209    {
210        GroupedConstraintBuilder {
211            extractor: self.extractor,
212            filter: self.filter,
213            key_fn: self.key_fn,
214            collector: self.collector,
215            impact_type: ImpactType::Reward,
216            weight_fn,
217            is_hard: true,
218            expected_descriptor: None,
219            _phantom: PhantomData,
220        }
221    }
222
223    // Penalizes each group with one hard score unit.
224    pub fn penalize_hard(
225        self,
226    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
227    where
228        Sc: Copy,
229    {
230        let w = Sc::one_hard();
231        self.penalize_hard_with(move |_: &C::Result| w)
232    }
233
234    // Penalizes each group with one soft score unit.
235    pub fn penalize_soft(
236        self,
237    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
238    where
239        Sc: Copy,
240    {
241        let w = Sc::one_soft();
242        self.penalize_with(move |_: &C::Result| w)
243    }
244
245    // Rewards each group with one hard score unit.
246    pub fn reward_hard(
247        self,
248    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
249    where
250        Sc: Copy,
251    {
252        let w = Sc::one_hard();
253        self.reward_hard_with(move |_: &C::Result| w)
254    }
255
256    // Rewards each group with one soft score unit.
257    pub fn reward_soft(
258        self,
259    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
260    where
261        Sc: Copy,
262    {
263        let w = Sc::one_soft();
264        self.reward_with(move |_: &C::Result| w)
265    }
266
267    /* Adds complement entities with default values for missing keys.
268
269    This ensures all keys from the complement source are represented,
270    using the grouped value if present, or the default value otherwise.
271
272    **Note:** The key function for A entities wraps the original key to
273    return `Some(K)`. For filtering (skipping entities without valid keys),
274    use `complement_filtered` instead.
275
276    # Example
277
278    ```
279    use solverforge_scoring::stream::ConstraintFactory;
280    use solverforge_scoring::stream::collector::count;
281    use solverforge_scoring::api::constraint_set::IncrementalConstraint;
282    use solverforge_core::score::SoftScore;
283
284    #[derive(Clone, Hash, PartialEq, Eq)]
285    struct Employee { id: usize }
286
287    #[derive(Clone, Hash, PartialEq, Eq)]
288    struct Shift { employee_id: usize }
289
290    #[derive(Clone)]
291    struct Schedule {
292    employees: Vec<Employee>,
293    shifts: Vec<Shift>,
294    }
295
296    // Count shifts per employee, including employees with 0 shifts
297    let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
298    .for_each(|s: &Schedule| &s.shifts)
299    .group_by(|shift: &Shift| shift.employee_id, count())
300    .complement(
301    |s: &Schedule| s.employees.as_slice(),
302    |emp: &Employee| emp.id,
303    |_emp: &Employee| 0usize,
304    )
305    .penalize_with(|count: &usize| SoftScore::of(*count as i64))
306    .named("Shift count");
307
308    let schedule = Schedule {
309    employees: vec![Employee { id: 0 }, Employee { id: 1 }],
310    shifts: vec![
311    Shift { employee_id: 0 },
312    Shift { employee_id: 0 },
313    ],
314    };
315
316    // Employee 0: 2, Employee 1: 0 → Total: -2
317    assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
318    ```
319    */
320    pub fn complement<B, EB, KB, D>(
321        self,
322        extractor_b: EB,
323        key_b: KB,
324        default_fn: D,
325    ) -> ComplementedConstraintStream<
326        S,
327        A,
328        B,
329        K,
330        E,
331        EB,
332        impl Fn(&A) -> Option<K> + Send + Sync,
333        KB,
334        C,
335        D,
336        Sc,
337    >
338    where
339        B: Clone + Send + Sync + 'static,
340        EB: CollectionExtract<S, Item = B>,
341        KB: Fn(&B) -> K + Send + Sync,
342        D: Fn(&B) -> C::Result + Send + Sync,
343    {
344        let key_fn = self.key_fn;
345        let wrapped_key_fn = move |a: &A| Some((key_fn)(a));
346        ComplementedConstraintStream::new(
347            self.extractor,
348            extractor_b,
349            wrapped_key_fn,
350            key_b,
351            self.collector,
352            default_fn,
353        )
354    }
355
356    /* Adds complement entities with a custom key function for filtering.
357
358    Like `complement`, but allows providing a custom key function for A entities
359    that returns `Option<K>`. Entities returning `None` are skipped.
360
361    # Example
362
363    ```
364    use solverforge_scoring::stream::ConstraintFactory;
365    use solverforge_scoring::stream::collector::count;
366    use solverforge_scoring::api::constraint_set::IncrementalConstraint;
367    use solverforge_core::score::SoftScore;
368
369    #[derive(Clone, Hash, PartialEq, Eq)]
370    struct Employee { id: usize }
371
372    #[derive(Clone, Hash, PartialEq, Eq)]
373    struct Shift { employee_id: Option<usize> }
374
375    #[derive(Clone)]
376    struct Schedule {
377    employees: Vec<Employee>,
378    shifts: Vec<Shift>,
379    }
380
381    // Count shifts per employee, skipping unassigned shifts
382    // The group_by key is ignored; complement_with_key provides its own
383    let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
384    .for_each(|s: &Schedule| &s.shifts)
385    .group_by(|_shift: &Shift| 0usize, count())  // Placeholder key, will be overridden
386    .complement_with_key(
387    |s: &Schedule| s.employees.as_slice(),
388    |shift: &Shift| shift.employee_id,  // Option<usize>
389    |emp: &Employee| emp.id,            // usize
390    |_emp: &Employee| 0usize,
391    )
392    .penalize_with(|count: &usize| SoftScore::of(*count as i64))
393    .named("Shift count");
394
395    let schedule = Schedule {
396    employees: vec![Employee { id: 0 }, Employee { id: 1 }],
397    shifts: vec![
398    Shift { employee_id: Some(0) },
399    Shift { employee_id: Some(0) },
400    Shift { employee_id: None },  // Skipped
401    ],
402    };
403
404    // Employee 0: 2, Employee 1: 0 → Total: -2
405    assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
406    ```
407    */
408    pub fn complement_with_key<B, EB, KA2, KB, D>(
409        self,
410        extractor_b: EB,
411        key_a: KA2,
412        key_b: KB,
413        default_fn: D,
414    ) -> ComplementedConstraintStream<S, A, B, K, E, EB, KA2, KB, C, D, Sc>
415    where
416        B: Clone + Send + Sync + 'static,
417        EB: CollectionExtract<S, Item = B>,
418        KA2: Fn(&A) -> Option<K> + Send + Sync,
419        KB: Fn(&B) -> K + Send + Sync,
420        D: Fn(&B) -> C::Result + Send + Sync,
421    {
422        ComplementedConstraintStream::new(
423            self.extractor,
424            extractor_b,
425            key_a,
426            key_b,
427            self.collector,
428            default_fn,
429        )
430    }
431}
432
433impl<S, A, K, E, Fi, KF, C, Sc: Score> std::fmt::Debug
434    for GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
435{
436    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
437        f.debug_struct("GroupedConstraintStream").finish()
438    }
439}
440
441// Zero-erasure builder for finalizing a grouped constraint.
442pub struct GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
443where
444    Sc: Score,
445{
446    extractor: E,
447    filter: Fi,
448    key_fn: KF,
449    collector: C,
450    impact_type: ImpactType,
451    weight_fn: W,
452    is_hard: bool,
453    expected_descriptor: Option<usize>,
454    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
455}
456
457impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
458where
459    S: Send + Sync + 'static,
460    A: Clone + Send + Sync + 'static,
461    K: Clone + Eq + Hash + Send + Sync + 'static,
462    E: CollectionExtract<S, Item = A>,
463    Fi: UniFilter<S, A>,
464    KF: Fn(&A) -> K + Send + Sync,
465    C: UniCollector<A> + Send + Sync + 'static,
466    C::Accumulator: Send + Sync,
467    C::Result: Clone + Send + Sync,
468    W: Fn(&C::Result) -> Sc + Send + Sync,
469    Sc: Score + 'static,
470{
471    /* Finalizes the builder into a zero-erasure `GroupedUniConstraint`.
472
473    # Example
474
475    ```
476    use solverforge_scoring::stream::ConstraintFactory;
477    use solverforge_scoring::stream::collector::count;
478    use solverforge_scoring::api::constraint_set::IncrementalConstraint;
479    use solverforge_core::score::SoftScore;
480
481    #[derive(Clone, Hash, PartialEq, Eq)]
482    struct Item { category: u32 }
483
484    #[derive(Clone)]
485    struct Solution { items: Vec<Item> }
486
487    let constraint = ConstraintFactory::<Solution, SoftScore>::new()
488    .for_each(|s: &Solution| &s.items)
489    .group_by(|i: &Item| i.category, count())
490    .penalize_with(|n: &usize| SoftScore::of(*n as i64))
491    .named("Category penalty");
492
493    assert_eq!(constraint.name(), "Category penalty");
494    ```
495    */
496    pub fn named(self, name: &str) -> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc> {
497        let mut constraint = GroupedUniConstraint::new(
498            ConstraintRef::new("", name),
499            self.impact_type,
500            self.extractor,
501            self.filter,
502            self.key_fn,
503            self.collector,
504            self.weight_fn,
505            self.is_hard,
506        );
507        if let Some(d) = self.expected_descriptor {
508            constraint = constraint.with_descriptor(d);
509        }
510        constraint
511    }
512
513    pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
514        self.expected_descriptor = Some(descriptor_index);
515        self
516    }
517}
518
519impl<S, A, K, E, Fi, KF, C, W, Sc: Score> std::fmt::Debug
520    for GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
521{
522    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523        f.debug_struct("GroupedConstraintBuilder")
524            .field("impact_type", &self.impact_type)
525            .finish()
526    }
527}