solverforge_scoring/stream/
uni_stream.rs

1//! Zero-erasure uni-constraint stream for single-entity constraint patterns.
2//!
3//! A `UniConstraintStream` operates on a single entity type and supports
4//! filtering, weighting, and constraint finalization. All type information
5//! is preserved at compile time - no Arc, no dyn, fully monomorphized.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::constraint::incremental::IncrementalUniConstraint;
14
15use crate::constraint::if_exists::ExistenceMode;
16
17use super::balance_stream::BalanceConstraintStream;
18use super::bi_stream::BiConstraintStream;
19use super::collector::UniCollector;
20use super::cross_bi_stream::CrossBiConstraintStream;
21use super::filter::{AndUniFilter, FnUniFilter, TrueFilter, UniFilter, UniLeftBiFilter};
22use super::grouped_stream::GroupedConstraintStream;
23use super::if_exists_stream::IfExistsStream;
24use super::joiner::EqualJoiner;
25
/// Zero-erasure constraint stream over a single entity type.
///
/// `UniConstraintStream` accumulates filters and can be finalized into
/// an `IncrementalUniConstraint` via `penalize()` or `reward()`.
///
/// All type parameters are concrete - no trait objects, no Arc allocations
/// in the hot path.
///
/// # Type Parameters
///
/// - `S` - Solution type
/// - `A` - Entity type
/// - `E` - Extractor function type
/// - `F` - Combined filter type
/// - `Sc` - Score type
///
/// Trait bounds are declared on the `impl` blocks rather than on the struct
/// itself (Rust API guideline C-STRUCT-BOUNDS), so the type can be named
/// without repeating `Sc: Score` at every use site.
pub struct UniConstraintStream<S, A, E, F, Sc> {
    /// Extracts the slice of entities to stream over from a solution.
    extractor: E,
    /// Compile-time composed filter applied to each entity.
    filter: F,
    /// Carries the `S`, `A`, `Sc` type parameters without storing values.
    _phantom: PhantomData<(S, A, Sc)>,
}
49
50impl<S, A, E, Sc> UniConstraintStream<S, A, E, TrueFilter, Sc>
51where
52    S: Send + Sync + 'static,
53    A: Clone + Send + Sync + 'static,
54    E: Fn(&S) -> &[A] + Send + Sync,
55    Sc: Score + 'static,
56{
57    /// Creates a new uni-constraint stream with the given extractor.
58    pub fn new(extractor: E) -> Self {
59        Self {
60            extractor,
61            filter: TrueFilter,
62            _phantom: PhantomData,
63        }
64    }
65}
66
67impl<S, A, E, F, Sc> UniConstraintStream<S, A, E, F, Sc>
68where
69    S: Send + Sync + 'static,
70    A: Clone + Send + Sync + 'static,
71    E: Fn(&S) -> &[A] + Send + Sync,
72    F: UniFilter<S, A>,
73    Sc: Score + 'static,
74{
75    /// Adds a filter predicate to the stream.
76    ///
77    /// Multiple filters are combined with AND semantics at compile time.
78    /// Each filter adds a new type layer, preserving zero-erasure.
79    ///
80    /// For filters that need access to solution state (shadow variables),
81    /// use [`filter_with_solution`](Self::filter_with_solution) instead.
82    pub fn filter<P>(
83        self,
84        predicate: P,
85    ) -> UniConstraintStream<
86        S,
87        A,
88        E,
89        AndUniFilter<F, FnUniFilter<impl Fn(&S, &A) -> bool + Send + Sync>>,
90        Sc,
91    >
92    where
93        P: Fn(&A) -> bool + Send + Sync + 'static,
94    {
95        UniConstraintStream {
96            extractor: self.extractor,
97            filter: AndUniFilter::new(
98                self.filter,
99                FnUniFilter::new(move |_s: &S, a: &A| predicate(a)),
100            ),
101            _phantom: PhantomData,
102        }
103    }
104
105    /// Adds a solution-aware filter predicate to the stream.
106    ///
107    /// Unlike [`filter`](Self::filter), this method receives the solution
108    /// reference alongside the entity, enabling access to shadow variables
109    /// and computed solution state.
110    ///
111    /// # Example
112    ///
113    /// ```
114    /// use solverforge_scoring::stream::ConstraintFactory;
115    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
116    /// use solverforge_core::score::SimpleScore;
117    ///
118    /// #[derive(Clone, Debug)]
119    /// struct Shift { id: usize, employee_idx: Option<usize> }
120    ///
121    /// #[derive(Clone, Debug)]
122    /// struct Schedule {
123    ///     shifts: Vec<Shift>,
124    ///     hours_by_employee: Vec<i32>, // shadow variable
125    /// }
126    ///
127    /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
128    ///     .for_each(|s: &Schedule| &s.shifts)
129    ///     .filter(|shift: &Shift| shift.employee_idx.is_some())
130    ///     .filter_with_solution(|schedule: &Schedule, shift: &Shift| {
131    ///         // Access shadow variable from solution
132    ///         let emp_idx = shift.employee_idx.unwrap();
133    ///         schedule.hours_by_employee[emp_idx] > 40
134    ///     })
135    ///     .penalize(SimpleScore::of(1))
136    ///     .as_constraint("Overtime");
137    ///
138    /// let schedule = Schedule {
139    ///     shifts: vec![
140    ///         Shift { id: 0, employee_idx: Some(0) },
141    ///         Shift { id: 1, employee_idx: Some(1) },
142    ///     ],
143    ///     hours_by_employee: vec![45, 35], // emp 0 has overtime
144    /// };
145    ///
146    /// // Only shift 0 matches (employee 0 has > 40 hours)
147    /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
148    /// ```
149    pub fn filter_with_solution<P>(
150        self,
151        predicate: P,
152    ) -> UniConstraintStream<S, A, E, AndUniFilter<F, FnUniFilter<P>>, Sc>
153    where
154        P: Fn(&S, &A) -> bool + Send + Sync,
155    {
156        UniConstraintStream {
157            extractor: self.extractor,
158            filter: AndUniFilter::new(self.filter, FnUniFilter::new(predicate)),
159            _phantom: PhantomData,
160        }
161    }
162
163    /// Joins this stream with itself to create pairs (zero-erasure).
164    ///
165    /// Requires an `EqualJoiner` to enable key-based indexing for O(k) lookups.
166    /// For self-joins, pairs are ordered (i < j) to avoid duplicates.
167    ///
168    /// Any filters accumulated on this stream are applied to both entities
169    /// individually before the join.
170    pub fn join_self<K, KA, KB>(
171        self,
172        joiner: EqualJoiner<KA, KB, K>,
173    ) -> BiConstraintStream<S, A, K, E, KA, UniLeftBiFilter<F, A>, Sc>
174    where
175        A: Hash + PartialEq,
176        K: Eq + Hash + Clone + Send + Sync,
177        KA: Fn(&A) -> K + Send + Sync,
178        KB: Fn(&A) -> K + Send + Sync,
179    {
180        let (key_extractor, _) = joiner.into_keys();
181
182        // Convert uni-filter to bi-filter that applies to left entity
183        let bi_filter = UniLeftBiFilter::new(self.filter);
184
185        BiConstraintStream::new_self_join_with_filter(self.extractor, key_extractor, bi_filter)
186    }
187
188    /// Joins this stream with another collection to create cross-entity pairs (zero-erasure).
189    ///
190    /// Requires an `EqualJoiner` to enable key-based indexing for O(1) lookups.
191    /// Unlike `join_self` which pairs entities within the same collection,
192    /// `join` creates pairs from two different collections (e.g., Shift joined
193    /// with Employee).
194    ///
195    /// Any filters accumulated on this stream are applied to the A entity
196    /// before the join.
197    pub fn join<B, EB, K, KA, KB>(
198        self,
199        extractor_b: EB,
200        joiner: EqualJoiner<KA, KB, K>,
201    ) -> CrossBiConstraintStream<S, A, B, K, E, EB, KA, KB, UniLeftBiFilter<F, B>, Sc>
202    where
203        B: Clone + Send + Sync + 'static,
204        EB: Fn(&S) -> &[B] + Send + Sync,
205        K: Eq + Hash + Clone + Send + Sync,
206        KA: Fn(&A) -> K + Send + Sync,
207        KB: Fn(&B) -> K + Send + Sync,
208    {
209        let (key_a, key_b) = joiner.into_keys();
210
211        // Convert uni-filter to bi-filter that applies to left entity only
212        let bi_filter = UniLeftBiFilter::new(self.filter);
213
214        CrossBiConstraintStream::new_with_filter(
215            self.extractor,
216            extractor_b,
217            key_a,
218            key_b,
219            bi_filter,
220        )
221    }
222
223    /// Groups entities by key and aggregates with a collector.
224    ///
225    /// Returns a zero-erasure `GroupedConstraintStream` that can be penalized
226    /// or rewarded based on the aggregated result for each group.
227    pub fn group_by<K, KF, C>(
228        self,
229        key_fn: KF,
230        collector: C,
231    ) -> GroupedConstraintStream<S, A, K, E, KF, C, Sc>
232    where
233        K: Clone + Eq + Hash + Send + Sync + 'static,
234        KF: Fn(&A) -> K + Send + Sync,
235        C: UniCollector<A> + Send + Sync + 'static,
236        C::Accumulator: Send + Sync,
237        C::Result: Clone + Send + Sync,
238    {
239        GroupedConstraintStream::new(self.extractor, key_fn, collector)
240    }
241
242    /// Creates a balance constraint that penalizes uneven distribution across groups.
243    ///
244    /// Unlike `group_by` which scores each group independently, `balance` computes
245    /// a GLOBAL standard deviation across all group counts and produces a single score.
246    ///
247    /// The `key_fn` returns `Option<K>` to allow skipping entities (e.g., unassigned shifts).
248    /// Any filters accumulated on this stream are also applied.
249    ///
250    /// # Example
251    ///
252    /// ```
253    /// use solverforge_scoring::stream::ConstraintFactory;
254    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
255    /// use solverforge_core::score::SimpleScore;
256    ///
257    /// #[derive(Clone)]
258    /// struct Shift { employee_id: Option<usize> }
259    ///
260    /// #[derive(Clone)]
261    /// struct Solution { shifts: Vec<Shift> }
262    ///
263    /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
264    ///     .for_each(|s: &Solution| &s.shifts)
265    ///     .balance(|shift: &Shift| shift.employee_id)
266    ///     .penalize(SimpleScore::of(1000))
267    ///     .as_constraint("Balance workload");
268    ///
269    /// let solution = Solution {
270    ///     shifts: vec![
271    ///         Shift { employee_id: Some(0) },
272    ///         Shift { employee_id: Some(0) },
273    ///         Shift { employee_id: Some(0) },
274    ///         Shift { employee_id: Some(1) },
275    ///     ],
276    /// };
277    ///
278    /// // Employee 0: 3 shifts, Employee 1: 1 shift
279    /// // std_dev = 1.0, penalty = -1000
280    /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1000));
281    /// ```
282    pub fn balance<K, KF>(self, key_fn: KF) -> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
283    where
284        K: Clone + Eq + Hash + Send + Sync + 'static,
285        KF: Fn(&A) -> Option<K> + Send + Sync,
286    {
287        BalanceConstraintStream::new(self.extractor, self.filter, key_fn)
288    }
289
290    /// Filters A entities based on whether a matching B entity exists.
291    ///
292    /// Use this when the B collection needs filtering (e.g., only vacationing employees).
293    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
294    ///
295    /// Any filters accumulated on this stream are applied to A entities.
296    ///
297    /// # Example
298    ///
299    /// ```
300    /// use solverforge_scoring::stream::ConstraintFactory;
301    /// use solverforge_scoring::stream::joiner::equal_bi;
302    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
303    /// use solverforge_core::score::SimpleScore;
304    ///
305    /// #[derive(Clone)]
306    /// struct Shift { id: usize, employee_idx: Option<usize> }
307    ///
308    /// #[derive(Clone)]
309    /// struct Employee { id: usize, on_vacation: bool }
310    ///
311    /// #[derive(Clone)]
312    /// struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
313    ///
314    /// // Penalize shifts assigned to employees who are on vacation
315    /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
316    ///     .for_each(|s: &Schedule| s.shifts.as_slice())
317    ///     .filter(|shift: &Shift| shift.employee_idx.is_some())
318    ///     .if_exists_filtered(
319    ///         |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect(),
320    ///         equal_bi(
321    ///             |shift: &Shift| shift.employee_idx,
322    ///             |emp: &Employee| Some(emp.id),
323    ///         ),
324    ///     )
325    ///     .penalize(SimpleScore::of(1))
326    ///     .as_constraint("Vacation conflict");
327    ///
328    /// let schedule = Schedule {
329    ///     shifts: vec![
330    ///         Shift { id: 0, employee_idx: Some(0) },  // assigned to vacationing emp
331    ///         Shift { id: 1, employee_idx: Some(1) },  // assigned to working emp
332    ///         Shift { id: 2, employee_idx: None },     // unassigned (filtered out)
333    ///     ],
334    ///     employees: vec![
335    ///         Employee { id: 0, on_vacation: true },
336    ///         Employee { id: 1, on_vacation: false },
337    ///     ],
338    /// };
339    ///
340    /// // Only shift 0 matches (assigned to employee 0 who is on vacation)
341    /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
342    /// ```
343    pub fn if_exists_filtered<B, EB, K, KA, KB>(
344        self,
345        extractor_b: EB,
346        joiner: EqualJoiner<KA, KB, K>,
347    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
348    where
349        B: Clone + Send + Sync + 'static,
350        EB: Fn(&S) -> Vec<B> + Send + Sync,
351        K: Eq + Hash + Clone + Send + Sync,
352        KA: Fn(&A) -> K + Send + Sync,
353        KB: Fn(&B) -> K + Send + Sync,
354    {
355        let (key_a, key_b) = joiner.into_keys();
356        IfExistsStream::new(
357            ExistenceMode::Exists,
358            self.extractor,
359            extractor_b,
360            key_a,
361            key_b,
362            self.filter,
363        )
364    }
365
366    /// Filters A entities based on whether NO matching B entity exists.
367    ///
368    /// Use this when the B collection needs filtering.
369    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
370    ///
371    /// Any filters accumulated on this stream are applied to A entities.
372    ///
373    /// # Example
374    ///
375    /// ```
376    /// use solverforge_scoring::stream::ConstraintFactory;
377    /// use solverforge_scoring::stream::joiner::equal_bi;
378    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
379    /// use solverforge_core::score::SimpleScore;
380    ///
381    /// #[derive(Clone)]
382    /// struct Task { id: usize, assignee: Option<usize> }
383    ///
384    /// #[derive(Clone)]
385    /// struct Worker { id: usize, available: bool }
386    ///
387    /// #[derive(Clone)]
388    /// struct Schedule { tasks: Vec<Task>, workers: Vec<Worker> }
389    ///
390    /// // Penalize tasks assigned to workers who are not available
391    /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
392    ///     .for_each(|s: &Schedule| s.tasks.as_slice())
393    ///     .filter(|task: &Task| task.assignee.is_some())
394    ///     .if_not_exists_filtered(
395    ///         |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
396    ///         equal_bi(
397    ///             |task: &Task| task.assignee,
398    ///             |worker: &Worker| Some(worker.id),
399    ///         ),
400    ///     )
401    ///     .penalize(SimpleScore::of(1))
402    ///     .as_constraint("Unavailable worker");
403    ///
404    /// let schedule = Schedule {
405    ///     tasks: vec![
406    ///         Task { id: 0, assignee: Some(0) },  // worker 0 is unavailable
407    ///         Task { id: 1, assignee: Some(1) },  // worker 1 is available
408    ///         Task { id: 2, assignee: None },     // unassigned (filtered out)
409    ///     ],
410    ///     workers: vec![
411    ///         Worker { id: 0, available: false },
412    ///         Worker { id: 1, available: true },
413    ///     ],
414    /// };
415    ///
416    /// // Task 0's worker (id=0) is NOT in the available workers list
417    /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
418    /// ```
419    pub fn if_not_exists_filtered<B, EB, K, KA, KB>(
420        self,
421        extractor_b: EB,
422        joiner: EqualJoiner<KA, KB, K>,
423    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
424    where
425        B: Clone + Send + Sync + 'static,
426        EB: Fn(&S) -> Vec<B> + Send + Sync,
427        K: Eq + Hash + Clone + Send + Sync,
428        KA: Fn(&A) -> K + Send + Sync,
429        KB: Fn(&B) -> K + Send + Sync,
430    {
431        let (key_a, key_b) = joiner.into_keys();
432        IfExistsStream::new(
433            ExistenceMode::NotExists,
434            self.extractor,
435            extractor_b,
436            key_a,
437            key_b,
438            self.filter,
439        )
440    }
441
442    /// Penalizes each matching entity with a fixed weight.
443    pub fn penalize(
444        self,
445        weight: Sc,
446    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
447    where
448        Sc: Copy,
449    {
450        // Detect if this is a hard constraint by checking if hard level is non-zero
451        let is_hard = weight
452            .to_level_numbers()
453            .first()
454            .map(|&h| h != 0)
455            .unwrap_or(false);
456        UniConstraintBuilder {
457            extractor: self.extractor,
458            filter: self.filter,
459            impact_type: ImpactType::Penalty,
460            weight: move |_: &A| weight,
461            is_hard,
462            _phantom: PhantomData,
463        }
464    }
465
466    /// Penalizes each matching entity with a dynamic weight.
467    ///
468    /// Note: For dynamic weights, use `penalize_hard_with` to explicitly mark as a hard constraint,
469    /// since the weight function cannot be evaluated at build time.
470    pub fn penalize_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
471    where
472        W: Fn(&A) -> Sc + Send + Sync,
473    {
474        UniConstraintBuilder {
475            extractor: self.extractor,
476            filter: self.filter,
477            impact_type: ImpactType::Penalty,
478            weight: weight_fn,
479            is_hard: false, // Can't detect at build time; use penalize_hard_with for hard constraints
480            _phantom: PhantomData,
481        }
482    }
483
484    /// Penalizes each matching entity with a dynamic weight, explicitly marked as a hard constraint.
485    pub fn penalize_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
486    where
487        W: Fn(&A) -> Sc + Send + Sync,
488    {
489        UniConstraintBuilder {
490            extractor: self.extractor,
491            filter: self.filter,
492            impact_type: ImpactType::Penalty,
493            weight: weight_fn,
494            is_hard: true,
495            _phantom: PhantomData,
496        }
497    }
498
499    /// Rewards each matching entity with a fixed weight.
500    pub fn reward(
501        self,
502        weight: Sc,
503    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
504    where
505        Sc: Copy,
506    {
507        // Detect if this is a hard constraint by checking if hard level is non-zero
508        let is_hard = weight
509            .to_level_numbers()
510            .first()
511            .map(|&h| h != 0)
512            .unwrap_or(false);
513        UniConstraintBuilder {
514            extractor: self.extractor,
515            filter: self.filter,
516            impact_type: ImpactType::Reward,
517            weight: move |_: &A| weight,
518            is_hard,
519            _phantom: PhantomData,
520        }
521    }
522
523    /// Rewards each matching entity with a dynamic weight.
524    ///
525    /// Note: For dynamic weights, use `reward_hard_with` to explicitly mark as a hard constraint,
526    /// since the weight function cannot be evaluated at build time.
527    pub fn reward_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
528    where
529        W: Fn(&A) -> Sc + Send + Sync,
530    {
531        UniConstraintBuilder {
532            extractor: self.extractor,
533            filter: self.filter,
534            impact_type: ImpactType::Reward,
535            weight: weight_fn,
536            is_hard: false, // Can't detect at build time; use reward_hard_with for hard constraints
537            _phantom: PhantomData,
538        }
539    }
540
541    /// Rewards each matching entity with a dynamic weight, explicitly marked as a hard constraint.
542    pub fn reward_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
543    where
544        W: Fn(&A) -> Sc + Send + Sync,
545    {
546        UniConstraintBuilder {
547            extractor: self.extractor,
548            filter: self.filter,
549            impact_type: ImpactType::Reward,
550            weight: weight_fn,
551            is_hard: true,
552            _phantom: PhantomData,
553        }
554    }
555}
556
// Manual `Debug` impl: the closure-typed fields (`extractor`, `filter`) do
// not implement `Debug`, so only the type name is printed.
impl<S, A, E, F, Sc: Score> std::fmt::Debug for UniConstraintStream<S, A, E, F, Sc> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UniConstraintStream").finish()
    }
}
562
563/// Zero-erasure builder for finalizing a uni-constraint.
564pub struct UniConstraintBuilder<S, A, E, F, W, Sc>
565where
566    Sc: Score,
567{
568    extractor: E,
569    filter: F,
570    impact_type: ImpactType,
571    weight: W,
572    is_hard: bool,
573    _phantom: PhantomData<(S, A, Sc)>,
574}
575
576impl<S, A, E, F, W, Sc> UniConstraintBuilder<S, A, E, F, W, Sc>
577where
578    S: Send + Sync + 'static,
579    A: Clone + Send + Sync + 'static,
580    E: Fn(&S) -> &[A] + Send + Sync,
581    F: UniFilter<S, A>,
582    W: Fn(&A) -> Sc + Send + Sync,
583    Sc: Score + 'static,
584{
585    /// Finalizes the builder into a zero-erasure `IncrementalUniConstraint`.
586    pub fn as_constraint(
587        self,
588        name: &str,
589    ) -> IncrementalUniConstraint<S, A, E, impl Fn(&S, &A) -> bool + Send + Sync, W, Sc> {
590        let filter = self.filter;
591        let combined_filter = move |s: &S, a: &A| filter.test(s, a);
592
593        IncrementalUniConstraint::new(
594            ConstraintRef::new("", name),
595            self.impact_type,
596            self.extractor,
597            combined_filter,
598            self.weight,
599            self.is_hard,
600        )
601    }
602}
603
// Manual `Debug` impl: the closure-typed fields (`extractor`, `filter`,
// `weight`) do not implement `Debug`; only the impact type is shown.
impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for UniConstraintBuilder<S, A, E, F, W, Sc> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UniConstraintBuilder")
            .field("impact_type", &self.impact_type)
            .finish()
    }
}