Skip to main content

solverforge_scoring/stream/
uni_stream.rs

//! Zero-erasure uni-constraint stream for single-entity constraint patterns.
//!
//! A `UniConstraintStream` operates on a single entity type and supports
//! filtering, weighting, and constraint finalization. All type information
//! is preserved at compile time - no Arc, no dyn, fully monomorphized.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::constraint::incremental::IncrementalUniConstraint;
14
15use crate::constraint::if_exists::ExistenceMode;
16
17use super::balance_stream::BalanceConstraintStream;
18use super::bi_stream::BiConstraintStream;
19use super::collector::UniCollector;
20use super::cross_bi_stream::CrossBiConstraintStream;
21use super::filter::{AndUniFilter, FnUniFilter, TrueFilter, UniFilter, UniLeftBiFilter};
22use super::grouped_stream::GroupedConstraintStream;
23use super::if_exists_stream::IfExistsStream;
24use super::joiner::EqualJoiner;
25
/// Zero-erasure constraint stream over a single entity type.
///
/// `UniConstraintStream` accumulates filters and can be finalized into
/// an `IncrementalUniConstraint` via `penalize()` or `reward()`.
///
/// All type parameters are concrete - no trait objects, no Arc allocations
/// in the hot path.
///
/// # Type Parameters
///
/// - `S` - Solution type
/// - `A` - Entity type
/// - `E` - Extractor function type
/// - `F` - Combined filter type
/// - `Sc` - Score type
pub struct UniConstraintStream<S, A, E, F, Sc>
where
    Sc: Score,
{
    extractor: E,
    filter: F,
    // `fn() -> T` markers keep the stream Send + Sync regardless of S/A/Sc,
    // since the stream stores no values of those types.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
}
49
impl<S, A, E, Sc> UniConstraintStream<S, A, E, TrueFilter, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    E: Fn(&S) -> &[A] + Send + Sync,
    Sc: Score + 'static,
{
    /// Creates a new uni-constraint stream with the given extractor.
    ///
    /// The stream starts with `TrueFilter` (accepts every entity); use
    /// `filter` to narrow the matched set.
    pub fn new(extractor: E) -> Self {
        Self {
            extractor,
            filter: TrueFilter,
            _phantom: PhantomData,
        }
    }
}
66
67impl<S, A, E, F, Sc> UniConstraintStream<S, A, E, F, Sc>
68where
69    S: Send + Sync + 'static,
70    A: Clone + Send + Sync + 'static,
71    E: Fn(&S) -> &[A] + Send + Sync,
72    F: UniFilter<S, A>,
73    Sc: Score + 'static,
74{
75    // Adds a filter predicate to the stream.
76    //
77    // Multiple filters are combined with AND semantics at compile time.
78    // Each filter adds a new type layer, preserving zero-erasure.
79    //
80    // To access related entities, use shadow variables on your entity type
81    // (e.g., `#[inverse_relation_shadow_variable]`) rather than solution traversal.
82    pub fn filter<P>(
83        self,
84        predicate: P,
85    ) -> UniConstraintStream<
86        S,
87        A,
88        E,
89        AndUniFilter<F, FnUniFilter<impl Fn(&S, &A) -> bool + Send + Sync>>,
90        Sc,
91    >
92    where
93        P: Fn(&A) -> bool + Send + Sync + 'static,
94    {
95        UniConstraintStream {
96            extractor: self.extractor,
97            filter: AndUniFilter::new(
98                self.filter,
99                FnUniFilter::new(move |_s: &S, a: &A| predicate(a)),
100            ),
101            _phantom: PhantomData,
102        }
103    }
104
105    // Joins this stream with itself to create pairs (zero-erasure).
106    //
107    // Requires an `EqualJoiner` to enable key-based indexing for O(k) lookups.
108    // For self-joins, pairs are ordered (i < j) to avoid duplicates.
109    //
110    // Any filters accumulated on this stream are applied to both entities
111    // individually before the join.
112    pub fn join_self<K, KA, KB>(
113        self,
114        joiner: EqualJoiner<KA, KB, K>,
115    ) -> BiConstraintStream<
116        S,
117        A,
118        K,
119        E,
120        impl Fn(&S, &A, usize) -> K + Send + Sync,
121        UniLeftBiFilter<F, A>,
122        Sc,
123    >
124    where
125        A: Hash + PartialEq,
126        K: Eq + Hash + Clone + Send + Sync,
127        KA: Fn(&A) -> K + Send + Sync,
128        KB: Fn(&A) -> K + Send + Sync,
129    {
130        let (key_extractor, _) = joiner.into_keys();
131
132        // Wrap key_extractor to match the new KE: Fn(&S, &A, usize) -> K signature.
133        // The static stream API doesn't need solution/index, so ignore them.
134        let wrapped_ke = move |_s: &S, a: &A, _idx: usize| key_extractor(a);
135
136        // Convert uni-filter to bi-filter that applies to left entity
137        let bi_filter = UniLeftBiFilter::new(self.filter);
138
139        BiConstraintStream::new_self_join_with_filter(self.extractor, wrapped_ke, bi_filter)
140    }
141
142    // Joins this stream with another collection to create cross-entity pairs (zero-erasure).
143    //
144    // Requires an `EqualJoiner` to enable key-based indexing for O(1) lookups.
145    // Unlike `join_self` which pairs entities within the same collection,
146    // `join` creates pairs from two different collections (e.g., Shift joined
147    // with Employee).
148    //
149    // Any filters accumulated on this stream are applied to the A entity
150    // before the join.
151    pub fn join<B, EB, K, KA, KB>(
152        self,
153        extractor_b: EB,
154        joiner: EqualJoiner<KA, KB, K>,
155    ) -> CrossBiConstraintStream<S, A, B, K, E, EB, KA, KB, UniLeftBiFilter<F, B>, Sc>
156    where
157        B: Clone + Send + Sync + 'static,
158        EB: Fn(&S) -> &[B] + Send + Sync,
159        K: Eq + Hash + Clone + Send + Sync,
160        KA: Fn(&A) -> K + Send + Sync,
161        KB: Fn(&B) -> K + Send + Sync,
162    {
163        let (key_a, key_b) = joiner.into_keys();
164
165        // Convert uni-filter to bi-filter that applies to left entity only
166        let bi_filter = UniLeftBiFilter::new(self.filter);
167
168        CrossBiConstraintStream::new_with_filter(
169            self.extractor,
170            extractor_b,
171            key_a,
172            key_b,
173            bi_filter,
174        )
175    }
176
    /// Groups entities by key and aggregates with a collector.
    ///
    /// Returns a zero-erasure `GroupedConstraintStream` that can be penalized
    /// or rewarded based on the aggregated result for each group.
    pub fn group_by<K, KF, C>(
        self,
        key_fn: KF,
        collector: C,
    ) -> GroupedConstraintStream<S, A, K, E, KF, C, Sc>
    where
        K: Clone + Eq + Hash + Send + Sync + 'static,
        KF: Fn(&A) -> K + Send + Sync,
        C: UniCollector<A> + Send + Sync + 'static,
        C::Accumulator: Send + Sync,
        C::Result: Clone + Send + Sync,
    {
        // NOTE(review): unlike `balance`, `join*`, and `if_*exists_filtered`,
        // `self.filter` is NOT forwarded here — the grouped stream's type has no
        // filter parameter, so filters accumulated before `group_by` are silently
        // dropped. Confirm this is intentional.
        GroupedConstraintStream::new(self.extractor, key_fn, collector)
    }
195
    /// Creates a balance constraint that penalizes uneven distribution across groups.
    ///
    /// Unlike `group_by` which scores each group independently, `balance` computes
    /// a GLOBAL standard deviation across all group counts and produces a single score.
    ///
    /// The `key_fn` returns `Option<K>` to allow skipping entities (e.g., unassigned shifts).
    /// Any filters accumulated on this stream are also applied.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SimpleScore;
    ///
    /// #[derive(Clone)]
    /// struct Shift { employee_id: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Solution { shifts: Vec<Shift> }
    ///
    /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    ///     .for_each(|s: &Solution| &s.shifts)
    ///     .balance(|shift: &Shift| shift.employee_id)
    ///     .penalize(SimpleScore::of(1000))
    ///     .as_constraint("Balance workload");
    ///
    /// let solution = Solution {
    ///     shifts: vec![
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(1) },
    ///     ],
    /// };
    ///
    /// // Employee 0: 3 shifts, Employee 1: 1 shift
    /// // std_dev = 1.0, penalty = -1000
    /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1000));
    /// ```
    pub fn balance<K, KF>(self, key_fn: KF) -> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
    where
        K: Clone + Eq + Hash + Send + Sync + 'static,
        KF: Fn(&A) -> Option<K> + Send + Sync,
    {
        BalanceConstraintStream::new(self.extractor, self.filter, key_fn)
    }
243
    /// Filters A entities based on whether a matching B entity exists.
    ///
    /// Use this when the B collection needs filtering (e.g., only vacationing employees).
    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
    ///
    /// Any filters accumulated on this stream are applied to A entities.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::joiner::equal_bi;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SimpleScore;
    ///
    /// #[derive(Clone)]
    /// struct Shift { id: usize, employee_idx: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Employee { id: usize, on_vacation: bool }
    ///
    /// #[derive(Clone)]
    /// struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
    ///
    /// // Penalize shifts assigned to employees who are on vacation
    /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
    ///     .for_each(|s: &Schedule| s.shifts.as_slice())
    ///     .filter(|shift: &Shift| shift.employee_idx.is_some())
    ///     .if_exists_filtered(
    ///         |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect(),
    ///         equal_bi(
    ///             |shift: &Shift| shift.employee_idx,
    ///             |emp: &Employee| Some(emp.id),
    ///         ),
    ///     )
    ///     .penalize(SimpleScore::of(1))
    ///     .as_constraint("Vacation conflict");
    ///
    /// let schedule = Schedule {
    ///     shifts: vec![
    ///         Shift { id: 0, employee_idx: Some(0) },  // assigned to vacationing emp
    ///         Shift { id: 1, employee_idx: Some(1) },  // assigned to working emp
    ///         Shift { id: 2, employee_idx: None },     // unassigned (filtered out)
    ///     ],
    ///     employees: vec![
    ///         Employee { id: 0, on_vacation: true },
    ///         Employee { id: 1, on_vacation: false },
    ///     ],
    /// };
    ///
    /// // Only shift 0 matches (assigned to employee 0 who is on vacation)
    /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
    /// ```
    pub fn if_exists_filtered<B, EB, K, KA, KB>(
        self,
        extractor_b: EB,
        joiner: EqualJoiner<KA, KB, K>,
    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
    where
        B: Clone + Send + Sync + 'static,
        EB: Fn(&S) -> Vec<B> + Send + Sync,
        K: Eq + Hash + Clone + Send + Sync,
        KA: Fn(&A) -> K + Send + Sync,
        KB: Fn(&B) -> K + Send + Sync,
    {
        let (key_a, key_b) = joiner.into_keys();
        IfExistsStream::new(
            ExistenceMode::Exists,
            self.extractor,
            extractor_b,
            key_a,
            key_b,
            self.filter,
        )
    }
319
    /// Filters A entities based on whether NO matching B entity exists.
    ///
    /// Use this when the B collection needs filtering.
    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
    ///
    /// Any filters accumulated on this stream are applied to A entities.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::joiner::equal_bi;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SimpleScore;
    ///
    /// #[derive(Clone)]
    /// struct Task { id: usize, assignee: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Worker { id: usize, available: bool }
    ///
    /// #[derive(Clone)]
    /// struct Schedule { tasks: Vec<Task>, workers: Vec<Worker> }
    ///
    /// // Penalize tasks assigned to workers who are not available
    /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
    ///     .for_each(|s: &Schedule| s.tasks.as_slice())
    ///     .filter(|task: &Task| task.assignee.is_some())
    ///     .if_not_exists_filtered(
    ///         |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
    ///         equal_bi(
    ///             |task: &Task| task.assignee,
    ///             |worker: &Worker| Some(worker.id),
    ///         ),
    ///     )
    ///     .penalize(SimpleScore::of(1))
    ///     .as_constraint("Unavailable worker");
    ///
    /// let schedule = Schedule {
    ///     tasks: vec![
    ///         Task { id: 0, assignee: Some(0) },  // worker 0 is unavailable
    ///         Task { id: 1, assignee: Some(1) },  // worker 1 is available
    ///         Task { id: 2, assignee: None },     // unassigned (filtered out)
    ///     ],
    ///     workers: vec![
    ///         Worker { id: 0, available: false },
    ///         Worker { id: 1, available: true },
    ///     ],
    /// };
    ///
    /// // Task 0's worker (id=0) is NOT in the available workers list
    /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
    /// ```
    pub fn if_not_exists_filtered<B, EB, K, KA, KB>(
        self,
        extractor_b: EB,
        joiner: EqualJoiner<KA, KB, K>,
    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
    where
        B: Clone + Send + Sync + 'static,
        EB: Fn(&S) -> Vec<B> + Send + Sync,
        K: Eq + Hash + Clone + Send + Sync,
        KA: Fn(&A) -> K + Send + Sync,
        KB: Fn(&B) -> K + Send + Sync,
    {
        let (key_a, key_b) = joiner.into_keys();
        IfExistsStream::new(
            ExistenceMode::NotExists,
            self.extractor,
            extractor_b,
            key_a,
            key_b,
            self.filter,
        )
    }
395
396    // Penalizes each matching entity with a fixed weight.
397    pub fn penalize(
398        self,
399        weight: Sc,
400    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
401    where
402        Sc: Copy,
403    {
404        // Detect if this is a hard constraint by checking if hard level is non-zero
405        let is_hard = weight
406            .to_level_numbers()
407            .first()
408            .map(|&h| h != 0)
409            .unwrap_or(false);
410        UniConstraintBuilder {
411            extractor: self.extractor,
412            filter: self.filter,
413            impact_type: ImpactType::Penalty,
414            weight: move |_: &A| weight,
415            is_hard,
416            _phantom: PhantomData,
417        }
418    }
419
    /// Penalizes each matching entity with a dynamic weight.
    ///
    /// Note: For dynamic weights, use `penalize_hard_with` to explicitly mark as a hard constraint,
    /// since the weight function cannot be evaluated at build time.
    pub fn penalize_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Penalty,
            weight: weight_fn,
            is_hard: false, // Can't detect at build time; use penalize_hard_with for hard constraints
            _phantom: PhantomData,
        }
    }
437
    /// Penalizes each matching entity with a dynamic weight, explicitly marked as a hard constraint.
    pub fn penalize_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Penalty,
            weight: weight_fn,
            is_hard: true,
            _phantom: PhantomData,
        }
    }
452
453    // Rewards each matching entity with a fixed weight.
454    pub fn reward(
455        self,
456        weight: Sc,
457    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
458    where
459        Sc: Copy,
460    {
461        // Detect if this is a hard constraint by checking if hard level is non-zero
462        let is_hard = weight
463            .to_level_numbers()
464            .first()
465            .map(|&h| h != 0)
466            .unwrap_or(false);
467        UniConstraintBuilder {
468            extractor: self.extractor,
469            filter: self.filter,
470            impact_type: ImpactType::Reward,
471            weight: move |_: &A| weight,
472            is_hard,
473            _phantom: PhantomData,
474        }
475    }
476
    /// Rewards each matching entity with a dynamic weight.
    ///
    /// Note: For dynamic weights, use `reward_hard_with` to explicitly mark as a hard constraint,
    /// since the weight function cannot be evaluated at build time.
    pub fn reward_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Reward,
            weight: weight_fn,
            is_hard: false, // Can't detect at build time; use reward_hard_with for hard constraints
            _phantom: PhantomData,
        }
    }
494
    /// Rewards each matching entity with a dynamic weight, explicitly marked as a hard constraint.
    pub fn reward_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Reward,
            weight: weight_fn,
            is_hard: true,
            _phantom: PhantomData,
        }
    }
509}
510
511impl<S, A, E, F, Sc: Score> std::fmt::Debug for UniConstraintStream<S, A, E, F, Sc> {
512    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
513        f.debug_struct("UniConstraintStream").finish()
514    }
515}
516
/// Zero-erasure builder for finalizing a uni-constraint.
pub struct UniConstraintBuilder<S, A, E, F, W, Sc>
where
    Sc: Score,
{
    extractor: E,
    filter: F,
    // Penalty or Reward, fixed by whichever finalizer created this builder.
    impact_type: ImpactType,
    weight: W,
    // True when the fixed weight's first score level was non-zero, or when a
    // `*_hard_with` finalizer was used.
    is_hard: bool,
    // `fn() -> T` markers keep the builder Send + Sync without storing S/A/Sc.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
}
529
530impl<S, A, E, F, W, Sc> UniConstraintBuilder<S, A, E, F, W, Sc>
531where
532    S: Send + Sync + 'static,
533    A: Clone + Send + Sync + 'static,
534    E: Fn(&S) -> &[A] + Send + Sync,
535    F: UniFilter<S, A>,
536    W: Fn(&A) -> Sc + Send + Sync,
537    Sc: Score + 'static,
538{
539    // Finalizes the builder into a zero-erasure `IncrementalUniConstraint`.
540    pub fn as_constraint(
541        self,
542        name: &str,
543    ) -> IncrementalUniConstraint<S, A, E, impl Fn(&S, &A) -> bool + Send + Sync, W, Sc> {
544        let filter = self.filter;
545        let combined_filter = move |s: &S, a: &A| filter.test(s, a);
546
547        IncrementalUniConstraint::new(
548            ConstraintRef::new("", name),
549            self.impact_type,
550            self.extractor,
551            combined_filter,
552            self.weight,
553            self.is_hard,
554        )
555    }
556}
557
558impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for UniConstraintBuilder<S, A, E, F, W, Sc> {
559    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
560        f.debug_struct("UniConstraintBuilder")
561            .field("impact_type", &self.impact_type)
562            .finish()
563    }
564}