solverforge_scoring/stream/
uni_stream.rs

//! Zero-erasure uni-constraint stream for single-entity constraint patterns.
//!
//! A `UniConstraintStream` operates on a single entity type and supports
//! filtering, weighting, and constraint finalization. All type information
//! is preserved at compile time: no `Arc`, no `dyn`, fully monomorphized.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::constraint::incremental::IncrementalUniConstraint;
14
15use crate::constraint::if_exists::ExistenceMode;
16
17use super::balance_stream::BalanceConstraintStream;
18use super::bi_stream::BiConstraintStream;
19use super::collector::UniCollector;
20use super::cross_bi_stream::CrossBiConstraintStream;
21use super::filter::{AndUniFilter, FnUniFilter, TrueFilter, UniFilter, UniLeftBiFilter};
22use super::grouped_stream::GroupedConstraintStream;
23use super::if_exists_stream::IfExistsStream;
24use super::joiner::EqualJoiner;
25
26/// Zero-erasure constraint stream over a single entity type.
27///
28/// `UniConstraintStream` accumulates filters and can be finalized into
29/// an `IncrementalUniConstraint` via `penalize()` or `reward()`.
30///
31/// All type parameters are concrete - no trait objects, no Arc allocations
32/// in the hot path.
33///
34/// # Type Parameters
35///
36/// - `S` - Solution type
37/// - `A` - Entity type
38/// - `E` - Extractor function type
39/// - `F` - Combined filter type
40/// - `Sc` - Score type
41pub struct UniConstraintStream<S, A, E, F, Sc>
42where
43    Sc: Score,
44{
45    extractor: E,
46    filter: F,
47    _phantom: PhantomData<(S, A, Sc)>,
48}
49
50impl<S, A, E, Sc> UniConstraintStream<S, A, E, TrueFilter, Sc>
51where
52    S: Send + Sync + 'static,
53    A: Clone + Send + Sync + 'static,
54    E: Fn(&S) -> &[A] + Send + Sync,
55    Sc: Score + 'static,
56{
57    /// Creates a new uni-constraint stream with the given extractor.
58    pub fn new(extractor: E) -> Self {
59        Self {
60            extractor,
61            filter: TrueFilter,
62            _phantom: PhantomData,
63        }
64    }
65}
66
67impl<S, A, E, F, Sc> UniConstraintStream<S, A, E, F, Sc>
68where
69    S: Send + Sync + 'static,
70    A: Clone + Send + Sync + 'static,
71    E: Fn(&S) -> &[A] + Send + Sync,
72    F: UniFilter<S, A>,
73    Sc: Score + 'static,
74{
75    /// Adds a filter predicate to the stream.
76    ///
77    /// Multiple filters are combined with AND semantics at compile time.
78    /// Each filter adds a new type layer, preserving zero-erasure.
79    ///
80    /// To access related entities, use shadow variables on your entity type
81    /// (e.g., `#[inverse_relation_shadow_variable]`) rather than solution traversal.
82    pub fn filter<P>(
83        self,
84        predicate: P,
85    ) -> UniConstraintStream<
86        S,
87        A,
88        E,
89        AndUniFilter<F, FnUniFilter<impl Fn(&S, &A) -> bool + Send + Sync>>,
90        Sc,
91    >
92    where
93        P: Fn(&A) -> bool + Send + Sync + 'static,
94    {
95        UniConstraintStream {
96            extractor: self.extractor,
97            filter: AndUniFilter::new(
98                self.filter,
99                FnUniFilter::new(move |_s: &S, a: &A| predicate(a)),
100            ),
101            _phantom: PhantomData,
102        }
103    }
104
105    /// Joins this stream with itself to create pairs (zero-erasure).
106    ///
107    /// Requires an `EqualJoiner` to enable key-based indexing for O(k) lookups.
108    /// For self-joins, pairs are ordered (i < j) to avoid duplicates.
109    ///
110    /// Any filters accumulated on this stream are applied to both entities
111    /// individually before the join.
112    pub fn join_self<K, KA, KB>(
113        self,
114        joiner: EqualJoiner<KA, KB, K>,
115    ) -> BiConstraintStream<S, A, K, E, KA, UniLeftBiFilter<F, A>, Sc>
116    where
117        A: Hash + PartialEq,
118        K: Eq + Hash + Clone + Send + Sync,
119        KA: Fn(&A) -> K + Send + Sync,
120        KB: Fn(&A) -> K + Send + Sync,
121    {
122        let (key_extractor, _) = joiner.into_keys();
123
124        // Convert uni-filter to bi-filter that applies to left entity
125        let bi_filter = UniLeftBiFilter::new(self.filter);
126
127        BiConstraintStream::new_self_join_with_filter(self.extractor, key_extractor, bi_filter)
128    }
129
130    /// Joins this stream with another collection to create cross-entity pairs (zero-erasure).
131    ///
132    /// Requires an `EqualJoiner` to enable key-based indexing for O(1) lookups.
133    /// Unlike `join_self` which pairs entities within the same collection,
134    /// `join` creates pairs from two different collections (e.g., Shift joined
135    /// with Employee).
136    ///
137    /// Any filters accumulated on this stream are applied to the A entity
138    /// before the join.
139    pub fn join<B, EB, K, KA, KB>(
140        self,
141        extractor_b: EB,
142        joiner: EqualJoiner<KA, KB, K>,
143    ) -> CrossBiConstraintStream<S, A, B, K, E, EB, KA, KB, UniLeftBiFilter<F, B>, Sc>
144    where
145        B: Clone + Send + Sync + 'static,
146        EB: Fn(&S) -> &[B] + Send + Sync,
147        K: Eq + Hash + Clone + Send + Sync,
148        KA: Fn(&A) -> K + Send + Sync,
149        KB: Fn(&B) -> K + Send + Sync,
150    {
151        let (key_a, key_b) = joiner.into_keys();
152
153        // Convert uni-filter to bi-filter that applies to left entity only
154        let bi_filter = UniLeftBiFilter::new(self.filter);
155
156        CrossBiConstraintStream::new_with_filter(
157            self.extractor,
158            extractor_b,
159            key_a,
160            key_b,
161            bi_filter,
162        )
163    }
164
165    /// Groups entities by key and aggregates with a collector.
166    ///
167    /// Returns a zero-erasure `GroupedConstraintStream` that can be penalized
168    /// or rewarded based on the aggregated result for each group.
169    pub fn group_by<K, KF, C>(
170        self,
171        key_fn: KF,
172        collector: C,
173    ) -> GroupedConstraintStream<S, A, K, E, KF, C, Sc>
174    where
175        K: Clone + Eq + Hash + Send + Sync + 'static,
176        KF: Fn(&A) -> K + Send + Sync,
177        C: UniCollector<A> + Send + Sync + 'static,
178        C::Accumulator: Send + Sync,
179        C::Result: Clone + Send + Sync,
180    {
181        GroupedConstraintStream::new(self.extractor, key_fn, collector)
182    }
183
184    /// Creates a balance constraint that penalizes uneven distribution across groups.
185    ///
186    /// Unlike `group_by` which scores each group independently, `balance` computes
187    /// a GLOBAL standard deviation across all group counts and produces a single score.
188    ///
189    /// The `key_fn` returns `Option<K>` to allow skipping entities (e.g., unassigned shifts).
190    /// Any filters accumulated on this stream are also applied.
191    ///
192    /// # Example
193    ///
194    /// ```
195    /// use solverforge_scoring::stream::ConstraintFactory;
196    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
197    /// use solverforge_core::score::SimpleScore;
198    ///
199    /// #[derive(Clone)]
200    /// struct Shift { employee_id: Option<usize> }
201    ///
202    /// #[derive(Clone)]
203    /// struct Solution { shifts: Vec<Shift> }
204    ///
205    /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
206    ///     .for_each(|s: &Solution| &s.shifts)
207    ///     .balance(|shift: &Shift| shift.employee_id)
208    ///     .penalize(SimpleScore::of(1000))
209    ///     .as_constraint("Balance workload");
210    ///
211    /// let solution = Solution {
212    ///     shifts: vec![
213    ///         Shift { employee_id: Some(0) },
214    ///         Shift { employee_id: Some(0) },
215    ///         Shift { employee_id: Some(0) },
216    ///         Shift { employee_id: Some(1) },
217    ///     ],
218    /// };
219    ///
220    /// // Employee 0: 3 shifts, Employee 1: 1 shift
221    /// // std_dev = 1.0, penalty = -1000
222    /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1000));
223    /// ```
224    pub fn balance<K, KF>(self, key_fn: KF) -> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
225    where
226        K: Clone + Eq + Hash + Send + Sync + 'static,
227        KF: Fn(&A) -> Option<K> + Send + Sync,
228    {
229        BalanceConstraintStream::new(self.extractor, self.filter, key_fn)
230    }
231
232    /// Filters A entities based on whether a matching B entity exists.
233    ///
234    /// Use this when the B collection needs filtering (e.g., only vacationing employees).
235    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
236    ///
237    /// Any filters accumulated on this stream are applied to A entities.
238    ///
239    /// # Example
240    ///
241    /// ```
242    /// use solverforge_scoring::stream::ConstraintFactory;
243    /// use solverforge_scoring::stream::joiner::equal_bi;
244    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
245    /// use solverforge_core::score::SimpleScore;
246    ///
247    /// #[derive(Clone)]
248    /// struct Shift { id: usize, employee_idx: Option<usize> }
249    ///
250    /// #[derive(Clone)]
251    /// struct Employee { id: usize, on_vacation: bool }
252    ///
253    /// #[derive(Clone)]
254    /// struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
255    ///
256    /// // Penalize shifts assigned to employees who are on vacation
257    /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
258    ///     .for_each(|s: &Schedule| s.shifts.as_slice())
259    ///     .filter(|shift: &Shift| shift.employee_idx.is_some())
260    ///     .if_exists_filtered(
261    ///         |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect(),
262    ///         equal_bi(
263    ///             |shift: &Shift| shift.employee_idx,
264    ///             |emp: &Employee| Some(emp.id),
265    ///         ),
266    ///     )
267    ///     .penalize(SimpleScore::of(1))
268    ///     .as_constraint("Vacation conflict");
269    ///
270    /// let schedule = Schedule {
271    ///     shifts: vec![
272    ///         Shift { id: 0, employee_idx: Some(0) },  // assigned to vacationing emp
273    ///         Shift { id: 1, employee_idx: Some(1) },  // assigned to working emp
274    ///         Shift { id: 2, employee_idx: None },     // unassigned (filtered out)
275    ///     ],
276    ///     employees: vec![
277    ///         Employee { id: 0, on_vacation: true },
278    ///         Employee { id: 1, on_vacation: false },
279    ///     ],
280    /// };
281    ///
282    /// // Only shift 0 matches (assigned to employee 0 who is on vacation)
283    /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
284    /// ```
285    pub fn if_exists_filtered<B, EB, K, KA, KB>(
286        self,
287        extractor_b: EB,
288        joiner: EqualJoiner<KA, KB, K>,
289    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
290    where
291        B: Clone + Send + Sync + 'static,
292        EB: Fn(&S) -> Vec<B> + Send + Sync,
293        K: Eq + Hash + Clone + Send + Sync,
294        KA: Fn(&A) -> K + Send + Sync,
295        KB: Fn(&B) -> K + Send + Sync,
296    {
297        let (key_a, key_b) = joiner.into_keys();
298        IfExistsStream::new(
299            ExistenceMode::Exists,
300            self.extractor,
301            extractor_b,
302            key_a,
303            key_b,
304            self.filter,
305        )
306    }
307
308    /// Filters A entities based on whether NO matching B entity exists.
309    ///
310    /// Use this when the B collection needs filtering.
311    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
312    ///
313    /// Any filters accumulated on this stream are applied to A entities.
314    ///
315    /// # Example
316    ///
317    /// ```
318    /// use solverforge_scoring::stream::ConstraintFactory;
319    /// use solverforge_scoring::stream::joiner::equal_bi;
320    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
321    /// use solverforge_core::score::SimpleScore;
322    ///
323    /// #[derive(Clone)]
324    /// struct Task { id: usize, assignee: Option<usize> }
325    ///
326    /// #[derive(Clone)]
327    /// struct Worker { id: usize, available: bool }
328    ///
329    /// #[derive(Clone)]
330    /// struct Schedule { tasks: Vec<Task>, workers: Vec<Worker> }
331    ///
332    /// // Penalize tasks assigned to workers who are not available
333    /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
334    ///     .for_each(|s: &Schedule| s.tasks.as_slice())
335    ///     .filter(|task: &Task| task.assignee.is_some())
336    ///     .if_not_exists_filtered(
337    ///         |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
338    ///         equal_bi(
339    ///             |task: &Task| task.assignee,
340    ///             |worker: &Worker| Some(worker.id),
341    ///         ),
342    ///     )
343    ///     .penalize(SimpleScore::of(1))
344    ///     .as_constraint("Unavailable worker");
345    ///
346    /// let schedule = Schedule {
347    ///     tasks: vec![
348    ///         Task { id: 0, assignee: Some(0) },  // worker 0 is unavailable
349    ///         Task { id: 1, assignee: Some(1) },  // worker 1 is available
350    ///         Task { id: 2, assignee: None },     // unassigned (filtered out)
351    ///     ],
352    ///     workers: vec![
353    ///         Worker { id: 0, available: false },
354    ///         Worker { id: 1, available: true },
355    ///     ],
356    /// };
357    ///
358    /// // Task 0's worker (id=0) is NOT in the available workers list
359    /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
360    /// ```
361    pub fn if_not_exists_filtered<B, EB, K, KA, KB>(
362        self,
363        extractor_b: EB,
364        joiner: EqualJoiner<KA, KB, K>,
365    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
366    where
367        B: Clone + Send + Sync + 'static,
368        EB: Fn(&S) -> Vec<B> + Send + Sync,
369        K: Eq + Hash + Clone + Send + Sync,
370        KA: Fn(&A) -> K + Send + Sync,
371        KB: Fn(&B) -> K + Send + Sync,
372    {
373        let (key_a, key_b) = joiner.into_keys();
374        IfExistsStream::new(
375            ExistenceMode::NotExists,
376            self.extractor,
377            extractor_b,
378            key_a,
379            key_b,
380            self.filter,
381        )
382    }
383
384    /// Penalizes each matching entity with a fixed weight.
385    pub fn penalize(
386        self,
387        weight: Sc,
388    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
389    where
390        Sc: Copy,
391    {
392        // Detect if this is a hard constraint by checking if hard level is non-zero
393        let is_hard = weight
394            .to_level_numbers()
395            .first()
396            .map(|&h| h != 0)
397            .unwrap_or(false);
398        UniConstraintBuilder {
399            extractor: self.extractor,
400            filter: self.filter,
401            impact_type: ImpactType::Penalty,
402            weight: move |_: &A| weight,
403            is_hard,
404            _phantom: PhantomData,
405        }
406    }
407
408    /// Penalizes each matching entity with a dynamic weight.
409    ///
410    /// Note: For dynamic weights, use `penalize_hard_with` to explicitly mark as a hard constraint,
411    /// since the weight function cannot be evaluated at build time.
412    pub fn penalize_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
413    where
414        W: Fn(&A) -> Sc + Send + Sync,
415    {
416        UniConstraintBuilder {
417            extractor: self.extractor,
418            filter: self.filter,
419            impact_type: ImpactType::Penalty,
420            weight: weight_fn,
421            is_hard: false, // Can't detect at build time; use penalize_hard_with for hard constraints
422            _phantom: PhantomData,
423        }
424    }
425
426    /// Penalizes each matching entity with a dynamic weight, explicitly marked as a hard constraint.
427    pub fn penalize_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
428    where
429        W: Fn(&A) -> Sc + Send + Sync,
430    {
431        UniConstraintBuilder {
432            extractor: self.extractor,
433            filter: self.filter,
434            impact_type: ImpactType::Penalty,
435            weight: weight_fn,
436            is_hard: true,
437            _phantom: PhantomData,
438        }
439    }
440
441    /// Rewards each matching entity with a fixed weight.
442    pub fn reward(
443        self,
444        weight: Sc,
445    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
446    where
447        Sc: Copy,
448    {
449        // Detect if this is a hard constraint by checking if hard level is non-zero
450        let is_hard = weight
451            .to_level_numbers()
452            .first()
453            .map(|&h| h != 0)
454            .unwrap_or(false);
455        UniConstraintBuilder {
456            extractor: self.extractor,
457            filter: self.filter,
458            impact_type: ImpactType::Reward,
459            weight: move |_: &A| weight,
460            is_hard,
461            _phantom: PhantomData,
462        }
463    }
464
465    /// Rewards each matching entity with a dynamic weight.
466    ///
467    /// Note: For dynamic weights, use `reward_hard_with` to explicitly mark as a hard constraint,
468    /// since the weight function cannot be evaluated at build time.
469    pub fn reward_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
470    where
471        W: Fn(&A) -> Sc + Send + Sync,
472    {
473        UniConstraintBuilder {
474            extractor: self.extractor,
475            filter: self.filter,
476            impact_type: ImpactType::Reward,
477            weight: weight_fn,
478            is_hard: false, // Can't detect at build time; use reward_hard_with for hard constraints
479            _phantom: PhantomData,
480        }
481    }
482
483    /// Rewards each matching entity with a dynamic weight, explicitly marked as a hard constraint.
484    pub fn reward_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
485    where
486        W: Fn(&A) -> Sc + Send + Sync,
487    {
488        UniConstraintBuilder {
489            extractor: self.extractor,
490            filter: self.filter,
491            impact_type: ImpactType::Reward,
492            weight: weight_fn,
493            is_hard: true,
494            _phantom: PhantomData,
495        }
496    }
497}
498
499impl<S, A, E, F, Sc: Score> std::fmt::Debug for UniConstraintStream<S, A, E, F, Sc> {
500    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
501        f.debug_struct("UniConstraintStream").finish()
502    }
503}
504
505/// Zero-erasure builder for finalizing a uni-constraint.
506pub struct UniConstraintBuilder<S, A, E, F, W, Sc>
507where
508    Sc: Score,
509{
510    extractor: E,
511    filter: F,
512    impact_type: ImpactType,
513    weight: W,
514    is_hard: bool,
515    _phantom: PhantomData<(S, A, Sc)>,
516}
517
518impl<S, A, E, F, W, Sc> UniConstraintBuilder<S, A, E, F, W, Sc>
519where
520    S: Send + Sync + 'static,
521    A: Clone + Send + Sync + 'static,
522    E: Fn(&S) -> &[A] + Send + Sync,
523    F: UniFilter<S, A>,
524    W: Fn(&A) -> Sc + Send + Sync,
525    Sc: Score + 'static,
526{
527    /// Finalizes the builder into a zero-erasure `IncrementalUniConstraint`.
528    pub fn as_constraint(
529        self,
530        name: &str,
531    ) -> IncrementalUniConstraint<S, A, E, impl Fn(&S, &A) -> bool + Send + Sync, W, Sc> {
532        let filter = self.filter;
533        let combined_filter = move |s: &S, a: &A| filter.test(s, a);
534
535        IncrementalUniConstraint::new(
536            ConstraintRef::new("", name),
537            self.impact_type,
538            self.extractor,
539            combined_filter,
540            self.weight,
541            self.is_hard,
542        )
543    }
544}
545
546impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for UniConstraintBuilder<S, A, E, F, W, Sc> {
547    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
548        f.debug_struct("UniConstraintBuilder")
549            .field("impact_type", &self.impact_type)
550            .finish()
551    }
552}