Skip to main content

solverforge_scoring/stream/
grouped_stream.rs

1// Zero-erasure grouped constraint stream for group-by constraint patterns.
2//
3// A `GroupedConstraintStream` operates on groups of entities and supports
4// filtering, weighting, and constraint finalization.
5// All type information is preserved at compile time - no Arc, no dyn.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collector::UniCollector;
14use super::complemented_stream::ComplementedConstraintStream;
15use super::filter::UniFilter;
16use crate::constraint::grouped::GroupedUniConstraint;
17
18// Zero-erasure constraint stream over grouped entities.
19//
20// `GroupedConstraintStream` is created by `UniConstraintStream::group_by()`
21// and operates on (key, collector_result) tuples.
22//
23// All type parameters are concrete - no trait objects, no Arc allocations.
24//
25// # Type Parameters
26//
27// - `S` - Solution type
28// - `A` - Entity type
29// - `K` - Group key type
30// - `E` - Extractor function for entities
31// - `Fi` - Filter type (preserved from upstream stream)
32// - `KF` - Key function
33// - `C` - Collector type
34// - `Sc` - Score type
35//
36// # Example
37//
38// ```
39// use solverforge_scoring::stream::ConstraintFactory;
40// use solverforge_scoring::stream::collector::count;
41// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
42// use solverforge_core::score::SoftScore;
43//
44// #[derive(Clone, Hash, PartialEq, Eq)]
45// struct Shift { employee_id: usize }
46//
47// #[derive(Clone)]
48// struct Solution { shifts: Vec<Shift> }
49//
50// let constraint = ConstraintFactory::<Solution, SoftScore>::new()
51//     .for_each(|s: &Solution| &s.shifts)
52//     .group_by(|shift: &Shift| shift.employee_id, count())
53//     .penalize_with(|count: &usize| SoftScore::of((*count * *count) as i64))
54//     .as_constraint("Balanced workload");
55//
56// let solution = Solution {
57//     shifts: vec![
58//         Shift { employee_id: 1 },
59//         Shift { employee_id: 1 },
60//         Shift { employee_id: 1 },
61//         Shift { employee_id: 2 },
62//     ],
63// };
64//
65// // Employee 1: 3² = 9, Employee 2: 1² = 1, Total: -10
66// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-10));
67// ```
68pub struct GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
69where
70    Sc: Score,
71{
72    extractor: E,
73    filter: Fi,
74    key_fn: KF,
75    collector: C,
76    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
77}
78
79impl<S, A, K, E, Fi, KF, C, Sc> GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
80where
81    S: Send + Sync + 'static,
82    A: Clone + Send + Sync + 'static,
83    K: Clone + Eq + Hash + Send + Sync + 'static,
84    E: Fn(&S) -> &[A] + Send + Sync,
85    Fi: UniFilter<S, A>,
86    KF: Fn(&A) -> K + Send + Sync,
87    C: UniCollector<A> + Send + Sync + 'static,
88    C::Accumulator: Send + Sync,
89    C::Result: Clone + Send + Sync,
90    Sc: Score + 'static,
91{
92    // Creates a new zero-erasure grouped constraint stream.
93    pub(crate) fn new(extractor: E, filter: Fi, key_fn: KF, collector: C) -> Self {
94        Self {
95            extractor,
96            filter,
97            key_fn,
98            collector,
99            _phantom: PhantomData,
100        }
101    }
102
103    // Penalizes each group with a weight based on the collector result.
104    //
105    // # Example
106    //
107    // ```
108    // use solverforge_scoring::stream::ConstraintFactory;
109    // use solverforge_scoring::stream::collector::count;
110    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
111    // use solverforge_core::score::SoftScore;
112    //
113    // #[derive(Clone, Hash, PartialEq, Eq)]
114    // struct Task { priority: u32 }
115    //
116    // #[derive(Clone)]
117    // struct Solution { tasks: Vec<Task> }
118    //
119    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
120    //     .for_each(|s: &Solution| &s.tasks)
121    //     .group_by(|t: &Task| t.priority, count())
122    //     .penalize_with(|count: &usize| SoftScore::of(*count as i64))
123    //     .as_constraint("Priority distribution");
124    //
125    // let solution = Solution {
126    //     tasks: vec![
127    //         Task { priority: 1 },
128    //         Task { priority: 1 },
129    //         Task { priority: 2 },
130    //     ],
131    // };
132    //
133    // // Priority 1: 2 tasks, Priority 2: 1 task, Total: -3
134    // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-3));
135    // ```
136    pub fn penalize_with<W>(
137        self,
138        weight_fn: W,
139    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
140    where
141        W: Fn(&C::Result) -> Sc + Send + Sync,
142    {
143        GroupedConstraintBuilder {
144            extractor: self.extractor,
145            filter: self.filter,
146            key_fn: self.key_fn,
147            collector: self.collector,
148            impact_type: ImpactType::Penalty,
149            weight_fn,
150            is_hard: false,
151            expected_descriptor: None,
152            _phantom: PhantomData,
153        }
154    }
155
156    // Penalizes each group with a weight, explicitly marked as hard constraint.
157    pub fn penalize_hard_with<W>(
158        self,
159        weight_fn: W,
160    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
161    where
162        W: Fn(&C::Result) -> Sc + Send + Sync,
163    {
164        GroupedConstraintBuilder {
165            extractor: self.extractor,
166            filter: self.filter,
167            key_fn: self.key_fn,
168            collector: self.collector,
169            impact_type: ImpactType::Penalty,
170            weight_fn,
171            is_hard: true,
172            expected_descriptor: None,
173            _phantom: PhantomData,
174        }
175    }
176
177    // Rewards each group with a weight based on the collector result.
178    pub fn reward_with<W>(
179        self,
180        weight_fn: W,
181    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
182    where
183        W: Fn(&C::Result) -> Sc + Send + Sync,
184    {
185        GroupedConstraintBuilder {
186            extractor: self.extractor,
187            filter: self.filter,
188            key_fn: self.key_fn,
189            collector: self.collector,
190            impact_type: ImpactType::Reward,
191            weight_fn,
192            is_hard: false,
193            expected_descriptor: None,
194            _phantom: PhantomData,
195        }
196    }
197
198    // Rewards each group with a weight, explicitly marked as hard constraint.
199    pub fn reward_hard_with<W>(
200        self,
201        weight_fn: W,
202    ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
203    where
204        W: Fn(&C::Result) -> Sc + Send + Sync,
205    {
206        GroupedConstraintBuilder {
207            extractor: self.extractor,
208            filter: self.filter,
209            key_fn: self.key_fn,
210            collector: self.collector,
211            impact_type: ImpactType::Reward,
212            weight_fn,
213            is_hard: true,
214            expected_descriptor: None,
215            _phantom: PhantomData,
216        }
217    }
218
219    // Adds complement entities with default values for missing keys.
220    //
221    // This ensures all keys from the complement source are represented,
222    // using the grouped value if present, or the default value otherwise.
223    //
224    // **Note:** The key function for A entities wraps the original key to
225    // return `Some(K)`. For filtering (skipping entities without valid keys),
226    // use `complement_filtered` instead.
227    //
228    // # Example
229    //
230    // ```
231    // use solverforge_scoring::stream::ConstraintFactory;
232    // use solverforge_scoring::stream::collector::count;
233    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
234    // use solverforge_core::score::SoftScore;
235    //
236    // #[derive(Clone, Hash, PartialEq, Eq)]
237    // struct Employee { id: usize }
238    //
239    // #[derive(Clone, Hash, PartialEq, Eq)]
240    // struct Shift { employee_id: usize }
241    //
242    // #[derive(Clone)]
243    // struct Schedule {
244    //     employees: Vec<Employee>,
245    //     shifts: Vec<Shift>,
246    // }
247    //
248    // // Count shifts per employee, including employees with 0 shifts
249    // let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
250    //     .for_each(|s: &Schedule| &s.shifts)
251    //     .group_by(|shift: &Shift| shift.employee_id, count())
252    //     .complement(
253    //         |s: &Schedule| s.employees.as_slice(),
254    //         |emp: &Employee| emp.id,
255    //         |_emp: &Employee| 0usize,
256    //     )
257    //     .penalize_with(|count: &usize| SoftScore::of(*count as i64))
258    //     .as_constraint("Shift count");
259    //
260    // let schedule = Schedule {
261    //     employees: vec![Employee { id: 0 }, Employee { id: 1 }],
262    //     shifts: vec![
263    //         Shift { employee_id: 0 },
264    //         Shift { employee_id: 0 },
265    //     ],
266    // };
267    //
268    // // Employee 0: 2, Employee 1: 0 → Total: -2
269    // assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
270    // ```
271    pub fn complement<B, EB, KB, D>(
272        self,
273        extractor_b: EB,
274        key_b: KB,
275        default_fn: D,
276    ) -> ComplementedConstraintStream<
277        S,
278        A,
279        B,
280        K,
281        E,
282        EB,
283        impl Fn(&A) -> Option<K> + Send + Sync,
284        KB,
285        C,
286        D,
287        Sc,
288    >
289    where
290        B: Clone + Send + Sync + 'static,
291        EB: Fn(&S) -> &[B] + Send + Sync,
292        KB: Fn(&B) -> K + Send + Sync,
293        D: Fn(&B) -> C::Result + Send + Sync,
294    {
295        let key_fn = self.key_fn;
296        let wrapped_key_fn = move |a: &A| Some((key_fn)(a));
297        ComplementedConstraintStream::new(
298            self.extractor,
299            extractor_b,
300            wrapped_key_fn,
301            key_b,
302            self.collector,
303            default_fn,
304        )
305    }
306
307    // Adds complement entities with a custom key function for filtering.
308    //
309    // Like `complement`, but allows providing a custom key function for A entities
310    // that returns `Option<K>`. Entities returning `None` are skipped.
311    //
312    // # Example
313    //
314    // ```
315    // use solverforge_scoring::stream::ConstraintFactory;
316    // use solverforge_scoring::stream::collector::count;
317    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
318    // use solverforge_core::score::SoftScore;
319    //
320    // #[derive(Clone, Hash, PartialEq, Eq)]
321    // struct Employee { id: usize }
322    //
323    // #[derive(Clone, Hash, PartialEq, Eq)]
324    // struct Shift { employee_id: Option<usize> }
325    //
326    // #[derive(Clone)]
327    // struct Schedule {
328    //     employees: Vec<Employee>,
329    //     shifts: Vec<Shift>,
330    // }
331    //
332    // // Count shifts per employee, skipping unassigned shifts
333    // // The group_by key is ignored; complement_with_key provides its own
334    // let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
335    //     .for_each(|s: &Schedule| &s.shifts)
336    //     .group_by(|_shift: &Shift| 0usize, count())  // Placeholder key, will be overridden
337    //     .complement_with_key(
338    //         |s: &Schedule| s.employees.as_slice(),
339    //         |shift: &Shift| shift.employee_id,  // Option<usize>
340    //         |emp: &Employee| emp.id,            // usize
341    //         |_emp: &Employee| 0usize,
342    //     )
343    //     .penalize_with(|count: &usize| SoftScore::of(*count as i64))
344    //     .as_constraint("Shift count");
345    //
346    // let schedule = Schedule {
347    //     employees: vec![Employee { id: 0 }, Employee { id: 1 }],
348    //     shifts: vec![
349    //         Shift { employee_id: Some(0) },
350    //         Shift { employee_id: Some(0) },
351    //         Shift { employee_id: None },  // Skipped
352    //     ],
353    // };
354    //
355    // // Employee 0: 2, Employee 1: 0 → Total: -2
356    // assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
357    // ```
358    pub fn complement_with_key<B, EB, KA2, KB, D>(
359        self,
360        extractor_b: EB,
361        key_a: KA2,
362        key_b: KB,
363        default_fn: D,
364    ) -> ComplementedConstraintStream<S, A, B, K, E, EB, KA2, KB, C, D, Sc>
365    where
366        B: Clone + Send + Sync + 'static,
367        EB: Fn(&S) -> &[B] + Send + Sync,
368        KA2: Fn(&A) -> Option<K> + Send + Sync,
369        KB: Fn(&B) -> K + Send + Sync,
370        D: Fn(&B) -> C::Result + Send + Sync,
371    {
372        ComplementedConstraintStream::new(
373            self.extractor,
374            extractor_b,
375            key_a,
376            key_b,
377            self.collector,
378            default_fn,
379        )
380    }
381}
382
383impl<S, A, K, E, Fi, KF, C, Sc: Score> std::fmt::Debug
384    for GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
385{
386    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387        f.debug_struct("GroupedConstraintStream").finish()
388    }
389}
390
391// Zero-erasure builder for finalizing a grouped constraint.
392pub struct GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
393where
394    Sc: Score,
395{
396    extractor: E,
397    filter: Fi,
398    key_fn: KF,
399    collector: C,
400    impact_type: ImpactType,
401    weight_fn: W,
402    is_hard: bool,
403    expected_descriptor: Option<usize>,
404    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
405}
406
407impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
408where
409    S: Send + Sync + 'static,
410    A: Clone + Send + Sync + 'static,
411    K: Clone + Eq + Hash + Send + Sync + 'static,
412    E: Fn(&S) -> &[A] + Send + Sync,
413    Fi: UniFilter<S, A>,
414    KF: Fn(&A) -> K + Send + Sync,
415    C: UniCollector<A> + Send + Sync + 'static,
416    C::Accumulator: Send + Sync,
417    C::Result: Clone + Send + Sync,
418    W: Fn(&C::Result) -> Sc + Send + Sync,
419    Sc: Score + 'static,
420{
421    // Finalizes the builder into a zero-erasure `GroupedUniConstraint`.
422    //
423    // # Example
424    //
425    // ```
426    // use solverforge_scoring::stream::ConstraintFactory;
427    // use solverforge_scoring::stream::collector::count;
428    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
429    // use solverforge_core::score::SoftScore;
430    //
431    // #[derive(Clone, Hash, PartialEq, Eq)]
432    // struct Item { category: u32 }
433    //
434    // #[derive(Clone)]
435    // struct Solution { items: Vec<Item> }
436    //
437    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
438    //     .for_each(|s: &Solution| &s.items)
439    //     .group_by(|i: &Item| i.category, count())
440    //     .penalize_with(|n: &usize| SoftScore::of(*n as i64))
441    //     .as_constraint("Category penalty");
442    //
443    // assert_eq!(constraint.name(), "Category penalty");
444    // ```
445    pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
446        self.expected_descriptor = Some(descriptor_index);
447        self
448    }
449
450    pub fn as_constraint(self, name: &str) -> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc> {
451        let mut constraint = GroupedUniConstraint::new(
452            ConstraintRef::new("", name),
453            self.impact_type,
454            self.extractor,
455            self.filter,
456            self.key_fn,
457            self.collector,
458            self.weight_fn,
459            self.is_hard,
460        );
461        if let Some(d) = self.expected_descriptor {
462            constraint = constraint.with_descriptor(d);
463        }
464        constraint
465    }
466}
467
468impl<S, A, K, E, Fi, KF, C, W, Sc: Score> std::fmt::Debug
469    for GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
470{
471    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472        f.debug_struct("GroupedConstraintBuilder")
473            .field("impact_type", &self.impact_type)
474            .finish()
475    }
476}