Skip to main content

solverforge_scoring/stream/
grouped_stream.rs

1// Zero-erasure grouped constraint stream for group-by constraint patterns.
2//
3// A `GroupedConstraintStream` operates on groups of entities and supports
4// filtering, weighting, and constraint finalization.
5// All type information is preserved at compile time - no Arc, no dyn.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collector::UniCollector;
14use super::complemented_stream::ComplementedConstraintStream;
15use crate::constraint::grouped::GroupedUniConstraint;
16
17// Zero-erasure constraint stream over grouped entities.
18//
19// `GroupedConstraintStream` is created by `UniConstraintStream::group_by()`
20// and operates on (key, collector_result) tuples.
21//
22// All type parameters are concrete - no trait objects, no Arc allocations.
23//
24// # Type Parameters
25//
26// - `S` - Solution type
27// - `A` - Entity type
28// - `K` - Group key type
29// - `E` - Extractor function for entities
30// - `KF` - Key function
31// - `C` - Collector type
32// - `Sc` - Score type
33//
34// # Example
35//
36// ```
37// use solverforge_scoring::stream::ConstraintFactory;
38// use solverforge_scoring::stream::collector::count;
39// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
40// use solverforge_core::score::SimpleScore;
41//
42// #[derive(Clone, Hash, PartialEq, Eq)]
43// struct Shift { employee_id: usize }
44//
45// #[derive(Clone)]
46// struct Solution { shifts: Vec<Shift> }
47//
48// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
49//     .for_each(|s: &Solution| &s.shifts)
50//     .group_by(|shift: &Shift| shift.employee_id, count())
51//     .penalize_with(|count: &usize| SimpleScore::of((*count * *count) as i64))
52//     .as_constraint("Balanced workload");
53//
54// let solution = Solution {
55//     shifts: vec![
56//         Shift { employee_id: 1 },
57//         Shift { employee_id: 1 },
58//         Shift { employee_id: 1 },
59//         Shift { employee_id: 2 },
60//     ],
61// };
62//
63// // Employee 1: 3² = 9, Employee 2: 1² = 1, Total: -10
64// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-10));
65// ```
66pub struct GroupedConstraintStream<S, A, K, E, KF, C, Sc>
67where
68    Sc: Score,
69{
70    extractor: E,
71    key_fn: KF,
72    collector: C,
73    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
74}
75
76impl<S, A, K, E, KF, C, Sc> GroupedConstraintStream<S, A, K, E, KF, C, Sc>
77where
78    S: Send + Sync + 'static,
79    A: Clone + Send + Sync + 'static,
80    K: Clone + Eq + Hash + Send + Sync + 'static,
81    E: Fn(&S) -> &[A] + Send + Sync,
82    KF: Fn(&A) -> K + Send + Sync,
83    C: UniCollector<A> + Send + Sync + 'static,
84    C::Accumulator: Send + Sync,
85    C::Result: Clone + Send + Sync,
86    Sc: Score + 'static,
87{
88    // Creates a new zero-erasure grouped constraint stream.
89    pub(crate) fn new(extractor: E, key_fn: KF, collector: C) -> Self {
90        Self {
91            extractor,
92            key_fn,
93            collector,
94            _phantom: PhantomData,
95        }
96    }
97
98    // Penalizes each group with a weight based on the collector result.
99    //
100    // # Example
101    //
102    // ```
103    // use solverforge_scoring::stream::ConstraintFactory;
104    // use solverforge_scoring::stream::collector::count;
105    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
106    // use solverforge_core::score::SimpleScore;
107    //
108    // #[derive(Clone, Hash, PartialEq, Eq)]
109    // struct Task { priority: u32 }
110    //
111    // #[derive(Clone)]
112    // struct Solution { tasks: Vec<Task> }
113    //
114    // let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
115    //     .for_each(|s: &Solution| &s.tasks)
116    //     .group_by(|t: &Task| t.priority, count())
117    //     .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
118    //     .as_constraint("Priority distribution");
119    //
120    // let solution = Solution {
121    //     tasks: vec![
122    //         Task { priority: 1 },
123    //         Task { priority: 1 },
124    //         Task { priority: 2 },
125    //     ],
126    // };
127    //
128    // // Priority 1: 2 tasks, Priority 2: 1 task, Total: -3
129    // assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-3));
130    // ```
131    pub fn penalize_with<W>(
132        self,
133        weight_fn: W,
134    ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
135    where
136        W: Fn(&C::Result) -> Sc + Send + Sync,
137    {
138        GroupedConstraintBuilder {
139            extractor: self.extractor,
140            key_fn: self.key_fn,
141            collector: self.collector,
142            impact_type: ImpactType::Penalty,
143            weight_fn,
144            is_hard: false,
145            expected_descriptor: None,
146            _phantom: PhantomData,
147        }
148    }
149
150    // Penalizes each group with a weight, explicitly marked as hard constraint.
151    pub fn penalize_hard_with<W>(
152        self,
153        weight_fn: W,
154    ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
155    where
156        W: Fn(&C::Result) -> Sc + Send + Sync,
157    {
158        GroupedConstraintBuilder {
159            extractor: self.extractor,
160            key_fn: self.key_fn,
161            collector: self.collector,
162            impact_type: ImpactType::Penalty,
163            weight_fn,
164            is_hard: true,
165            expected_descriptor: None,
166            _phantom: PhantomData,
167        }
168    }
169
170    // Rewards each group with a weight based on the collector result.
171    pub fn reward_with<W>(self, weight_fn: W) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
172    where
173        W: Fn(&C::Result) -> Sc + Send + Sync,
174    {
175        GroupedConstraintBuilder {
176            extractor: self.extractor,
177            key_fn: self.key_fn,
178            collector: self.collector,
179            impact_type: ImpactType::Reward,
180            weight_fn,
181            is_hard: false,
182            expected_descriptor: None,
183            _phantom: PhantomData,
184        }
185    }
186
187    // Rewards each group with a weight, explicitly marked as hard constraint.
188    pub fn reward_hard_with<W>(
189        self,
190        weight_fn: W,
191    ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
192    where
193        W: Fn(&C::Result) -> Sc + Send + Sync,
194    {
195        GroupedConstraintBuilder {
196            extractor: self.extractor,
197            key_fn: self.key_fn,
198            collector: self.collector,
199            impact_type: ImpactType::Reward,
200            weight_fn,
201            is_hard: true,
202            expected_descriptor: None,
203            _phantom: PhantomData,
204        }
205    }
206
207    // Adds complement entities with default values for missing keys.
208    //
209    // This ensures all keys from the complement source are represented,
210    // using the grouped value if present, or the default value otherwise.
211    //
212    // **Note:** The key function for A entities wraps the original key to
213    // return `Some(K)`. For filtering (skipping entities without valid keys),
214    // use `complement_filtered` instead.
215    //
216    // # Example
217    //
218    // ```
219    // use solverforge_scoring::stream::ConstraintFactory;
220    // use solverforge_scoring::stream::collector::count;
221    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
222    // use solverforge_core::score::SimpleScore;
223    //
224    // #[derive(Clone, Hash, PartialEq, Eq)]
225    // struct Employee { id: usize }
226    //
227    // #[derive(Clone, Hash, PartialEq, Eq)]
228    // struct Shift { employee_id: usize }
229    //
230    // #[derive(Clone)]
231    // struct Schedule {
232    //     employees: Vec<Employee>,
233    //     shifts: Vec<Shift>,
234    // }
235    //
236    // // Count shifts per employee, including employees with 0 shifts
237    // let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
238    //     .for_each(|s: &Schedule| &s.shifts)
239    //     .group_by(|shift: &Shift| shift.employee_id, count())
240    //     .complement(
241    //         |s: &Schedule| s.employees.as_slice(),
242    //         |emp: &Employee| emp.id,
243    //         |_emp: &Employee| 0usize,
244    //     )
245    //     .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
246    //     .as_constraint("Shift count");
247    //
248    // let schedule = Schedule {
249    //     employees: vec![Employee { id: 0 }, Employee { id: 1 }],
250    //     shifts: vec![
251    //         Shift { employee_id: 0 },
252    //         Shift { employee_id: 0 },
253    //     ],
254    // };
255    //
256    // // Employee 0: 2, Employee 1: 0 → Total: -2
257    // assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
258    // ```
259    pub fn complement<B, EB, KB, D>(
260        self,
261        extractor_b: EB,
262        key_b: KB,
263        default_fn: D,
264    ) -> ComplementedConstraintStream<
265        S,
266        A,
267        B,
268        K,
269        E,
270        EB,
271        impl Fn(&A) -> Option<K> + Send + Sync,
272        KB,
273        C,
274        D,
275        Sc,
276    >
277    where
278        B: Clone + Send + Sync + 'static,
279        EB: Fn(&S) -> &[B] + Send + Sync,
280        KB: Fn(&B) -> K + Send + Sync,
281        D: Fn(&B) -> C::Result + Send + Sync,
282    {
283        let key_fn = self.key_fn;
284        let wrapped_key_fn = move |a: &A| Some((key_fn)(a));
285        ComplementedConstraintStream::new(
286            self.extractor,
287            extractor_b,
288            wrapped_key_fn,
289            key_b,
290            self.collector,
291            default_fn,
292        )
293    }
294
295    // Adds complement entities with a custom key function for filtering.
296    //
297    // Like `complement`, but allows providing a custom key function for A entities
298    // that returns `Option<K>`. Entities returning `None` are skipped.
299    //
300    // # Example
301    //
302    // ```
303    // use solverforge_scoring::stream::ConstraintFactory;
304    // use solverforge_scoring::stream::collector::count;
305    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
306    // use solverforge_core::score::SimpleScore;
307    //
308    // #[derive(Clone, Hash, PartialEq, Eq)]
309    // struct Employee { id: usize }
310    //
311    // #[derive(Clone, Hash, PartialEq, Eq)]
312    // struct Shift { employee_id: Option<usize> }
313    //
314    // #[derive(Clone)]
315    // struct Schedule {
316    //     employees: Vec<Employee>,
317    //     shifts: Vec<Shift>,
318    // }
319    //
320    // // Count shifts per employee, skipping unassigned shifts
321    // // The group_by key is ignored; complement_with_key provides its own
322    // let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
323    //     .for_each(|s: &Schedule| &s.shifts)
324    //     .group_by(|_shift: &Shift| 0usize, count())  // Placeholder key, will be overridden
325    //     .complement_with_key(
326    //         |s: &Schedule| s.employees.as_slice(),
327    //         |shift: &Shift| shift.employee_id,  // Option<usize>
328    //         |emp: &Employee| emp.id,            // usize
329    //         |_emp: &Employee| 0usize,
330    //     )
331    //     .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
332    //     .as_constraint("Shift count");
333    //
334    // let schedule = Schedule {
335    //     employees: vec![Employee { id: 0 }, Employee { id: 1 }],
336    //     shifts: vec![
337    //         Shift { employee_id: Some(0) },
338    //         Shift { employee_id: Some(0) },
339    //         Shift { employee_id: None },  // Skipped
340    //     ],
341    // };
342    //
343    // // Employee 0: 2, Employee 1: 0 → Total: -2
344    // assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
345    // ```
346    pub fn complement_with_key<B, EB, KA2, KB, D>(
347        self,
348        extractor_b: EB,
349        key_a: KA2,
350        key_b: KB,
351        default_fn: D,
352    ) -> ComplementedConstraintStream<S, A, B, K, E, EB, KA2, KB, C, D, Sc>
353    where
354        B: Clone + Send + Sync + 'static,
355        EB: Fn(&S) -> &[B] + Send + Sync,
356        KA2: Fn(&A) -> Option<K> + Send + Sync,
357        KB: Fn(&B) -> K + Send + Sync,
358        D: Fn(&B) -> C::Result + Send + Sync,
359    {
360        ComplementedConstraintStream::new(
361            self.extractor,
362            extractor_b,
363            key_a,
364            key_b,
365            self.collector,
366            default_fn,
367        )
368    }
369}
370
371impl<S, A, K, E, KF, C, Sc: Score> std::fmt::Debug
372    for GroupedConstraintStream<S, A, K, E, KF, C, Sc>
373{
374    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375        f.debug_struct("GroupedConstraintStream").finish()
376    }
377}
378
379// Zero-erasure builder for finalizing a grouped constraint.
380pub struct GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
381where
382    Sc: Score,
383{
384    extractor: E,
385    key_fn: KF,
386    collector: C,
387    impact_type: ImpactType,
388    weight_fn: W,
389    is_hard: bool,
390    expected_descriptor: Option<usize>,
391    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
392}
393
394impl<S, A, K, E, KF, C, W, Sc> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
395where
396    S: Send + Sync + 'static,
397    A: Clone + Send + Sync + 'static,
398    K: Clone + Eq + Hash + Send + Sync + 'static,
399    E: Fn(&S) -> &[A] + Send + Sync,
400    KF: Fn(&A) -> K + Send + Sync,
401    C: UniCollector<A> + Send + Sync + 'static,
402    C::Accumulator: Send + Sync,
403    C::Result: Clone + Send + Sync,
404    W: Fn(&C::Result) -> Sc + Send + Sync,
405    Sc: Score + 'static,
406{
407    // Finalizes the builder into a zero-erasure `GroupedUniConstraint`.
408    //
409    // # Example
410    //
411    // ```
412    // use solverforge_scoring::stream::ConstraintFactory;
413    // use solverforge_scoring::stream::collector::count;
414    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
415    // use solverforge_core::score::SimpleScore;
416    //
417    // #[derive(Clone, Hash, PartialEq, Eq)]
418    // struct Item { category: u32 }
419    //
420    // #[derive(Clone)]
421    // struct Solution { items: Vec<Item> }
422    //
423    // let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
424    //     .for_each(|s: &Solution| &s.items)
425    //     .group_by(|i: &Item| i.category, count())
426    //     .penalize_with(|n: &usize| SimpleScore::of(*n as i64))
427    //     .as_constraint("Category penalty");
428    //
429    // assert_eq!(constraint.name(), "Category penalty");
430    // ```
431    pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
432        self.expected_descriptor = Some(descriptor_index);
433        self
434    }
435
436    pub fn as_constraint(self, name: &str) -> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc> {
437        let mut constraint = GroupedUniConstraint::new(
438            ConstraintRef::new("", name),
439            self.impact_type,
440            self.extractor,
441            self.key_fn,
442            self.collector,
443            self.weight_fn,
444            self.is_hard,
445        );
446        if let Some(d) = self.expected_descriptor {
447            constraint = constraint.with_descriptor(d);
448        }
449        constraint
450    }
451}
452
453impl<S, A, K, E, KF, C, W, Sc: Score> std::fmt::Debug
454    for GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
455{
456    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457        f.debug_struct("GroupedConstraintBuilder")
458            .field("impact_type", &self.impact_type)
459            .finish()
460    }
461}