solverforge_scoring/stream/
uni_stream.rs

1//! Zero-erasure uni-constraint stream for single-entity constraint patterns.
2//!
3//! A `UniConstraintStream` operates on a single entity type and supports
4//! filtering, weighting, and constraint finalization. All type information
5//! is preserved at compile time - no Arc, no dyn, fully monomorphized.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::constraint::incremental::IncrementalUniConstraint;
14
15use crate::constraint::if_exists::ExistenceMode;
16
17use super::balance_stream::BalanceConstraintStream;
18use super::bi_stream::BiConstraintStream;
19use super::collector::UniCollector;
20use super::cross_bi_stream::CrossBiConstraintStream;
21use super::filter::{AndUniFilter, FnUniFilter, TrueFilter, UniFilter, UniLeftBiFilter};
22use super::grouped_stream::GroupedConstraintStream;
23use super::if_exists_stream::IfExistsStream;
24use super::joiner::EqualJoiner;
25
26/// Zero-erasure constraint stream over a single entity type.
27///
28/// `UniConstraintStream` accumulates filters and can be finalized into
29/// an `IncrementalUniConstraint` via `penalize()` or `reward()`.
30///
31/// All type parameters are concrete - no trait objects, no Arc allocations
32/// in the hot path.
33///
34/// # Type Parameters
35///
36/// - `S` - Solution type
37/// - `A` - Entity type
38/// - `E` - Extractor function type
39/// - `F` - Combined filter type
40/// - `Sc` - Score type
pub struct UniConstraintStream<S, A, E, F, Sc>
where
    Sc: Score,
{
    /// Extracts the entity slice `&[A]` from a solution `&S`.
    extractor: E,
    /// Compile-time composed filter; grows one type layer per `filter()` call.
    filter: F,
    /// Carries the `S`, `A`, and `Sc` type parameters without storing values.
    _phantom: PhantomData<(S, A, Sc)>,
}
49
50impl<S, A, E, Sc> UniConstraintStream<S, A, E, TrueFilter, Sc>
51where
52    S: Send + Sync + 'static,
53    A: Clone + Send + Sync + 'static,
54    E: Fn(&S) -> &[A] + Send + Sync,
55    Sc: Score + 'static,
56{
57    /// Creates a new uni-constraint stream with the given extractor.
58    pub fn new(extractor: E) -> Self {
59        Self {
60            extractor,
61            filter: TrueFilter,
62            _phantom: PhantomData,
63        }
64    }
65}
66
67impl<S, A, E, F, Sc> UniConstraintStream<S, A, E, F, Sc>
68where
69    S: Send + Sync + 'static,
70    A: Clone + Send + Sync + 'static,
71    E: Fn(&S) -> &[A] + Send + Sync,
72    F: UniFilter<A>,
73    Sc: Score + 'static,
74{
75    /// Adds a filter predicate to the stream.
76    ///
77    /// Multiple filters are combined with AND semantics at compile time.
78    /// Each filter adds a new type layer, preserving zero-erasure.
79    pub fn filter<P>(
80        self,
81        predicate: P,
82    ) -> UniConstraintStream<S, A, E, AndUniFilter<F, FnUniFilter<P>>, Sc>
83    where
84        P: Fn(&A) -> bool + Send + Sync,
85    {
86        UniConstraintStream {
87            extractor: self.extractor,
88            filter: AndUniFilter::new(self.filter, FnUniFilter::new(predicate)),
89            _phantom: PhantomData,
90        }
91    }
92
93    /// Joins this stream with itself to create pairs (zero-erasure).
94    ///
95    /// Requires an `EqualJoiner` to enable key-based indexing for O(k) lookups.
96    /// For self-joins, pairs are ordered (i < j) to avoid duplicates.
97    ///
98    /// Any filters accumulated on this stream are applied to both entities
99    /// individually before the join.
100    pub fn join_self<K, KA, KB>(
101        self,
102        joiner: EqualJoiner<KA, KB, K>,
103    ) -> BiConstraintStream<S, A, K, E, KA, UniLeftBiFilter<F, A>, Sc>
104    where
105        A: Hash + PartialEq,
106        K: Eq + Hash + Clone + Send + Sync,
107        KA: Fn(&A) -> K + Send + Sync,
108        KB: Fn(&A) -> K + Send + Sync,
109    {
110        let (key_extractor, _) = joiner.into_keys();
111
112        // Convert uni-filter to bi-filter that applies to left entity
113        let bi_filter = UniLeftBiFilter::new(self.filter);
114
115        BiConstraintStream::new_self_join_with_filter(self.extractor, key_extractor, bi_filter)
116    }
117
118    /// Joins this stream with another collection to create cross-entity pairs (zero-erasure).
119    ///
120    /// Requires an `EqualJoiner` to enable key-based indexing for O(1) lookups.
121    /// Unlike `join_self` which pairs entities within the same collection,
122    /// `join` creates pairs from two different collections (e.g., Shift joined
123    /// with Employee).
124    ///
125    /// Any filters accumulated on this stream are applied to the A entity
126    /// before the join.
127    pub fn join<B, EB, K, KA, KB>(
128        self,
129        extractor_b: EB,
130        joiner: EqualJoiner<KA, KB, K>,
131    ) -> CrossBiConstraintStream<S, A, B, K, E, EB, KA, KB, UniLeftBiFilter<F, B>, Sc>
132    where
133        B: Clone + Send + Sync + 'static,
134        EB: Fn(&S) -> &[B] + Send + Sync,
135        K: Eq + Hash + Clone + Send + Sync,
136        KA: Fn(&A) -> K + Send + Sync,
137        KB: Fn(&B) -> K + Send + Sync,
138    {
139        let (key_a, key_b) = joiner.into_keys();
140
141        // Convert uni-filter to bi-filter that applies to left entity only
142        let bi_filter = UniLeftBiFilter::new(self.filter);
143
144        CrossBiConstraintStream::new_with_filter(
145            self.extractor,
146            extractor_b,
147            key_a,
148            key_b,
149            bi_filter,
150        )
151    }
152
    /// Groups entities by key and aggregates with a collector.
    ///
    /// Returns a zero-erasure `GroupedConstraintStream` that can be penalized
    /// or rewarded based on the aggregated result for each group.
    ///
    /// NOTE(review): unlike `balance`, `join`, and the `if_exists_*` adapters,
    /// this method does NOT forward the accumulated `self.filter` — only the
    /// extractor is passed below, so any `filter()` calls before `group_by`
    /// are silently dropped. Confirm whether `GroupedConstraintStream` should
    /// receive the filter.
    pub fn group_by<K, KF, C>(
        self,
        key_fn: KF,
        collector: C,
    ) -> GroupedConstraintStream<S, A, K, E, KF, C, Sc>
    where
        K: Clone + Eq + Hash + Send + Sync + 'static,
        KF: Fn(&A) -> K + Send + Sync,
        C: UniCollector<A> + Send + Sync + 'static,
        C::Accumulator: Send + Sync,
        C::Result: Clone + Send + Sync,
    {
        GroupedConstraintStream::new(self.extractor, key_fn, collector)
    }
171
    /// Creates a balance constraint that penalizes uneven distribution across groups.
    ///
    /// Unlike `group_by` which scores each group independently, `balance` computes
    /// a GLOBAL standard deviation across all group counts and produces a single score.
    ///
    /// The `key_fn` returns `Option<K>` to allow skipping entities (e.g., unassigned shifts).
    /// Any filters accumulated on this stream are also applied.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SimpleScore;
    ///
    /// #[derive(Clone)]
    /// struct Shift { employee_id: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Solution { shifts: Vec<Shift> }
    ///
    /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    ///     .for_each(|s: &Solution| &s.shifts)
    ///     .balance(|shift: &Shift| shift.employee_id)
    ///     .penalize(SimpleScore::of(1000))
    ///     .as_constraint("Balance workload");
    ///
    /// let solution = Solution {
    ///     shifts: vec![
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(1) },
    ///     ],
    /// };
    ///
    /// // Employee 0: 3 shifts, Employee 1: 1 shift
    /// // std_dev = 1.0, penalty = -1000
    /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1000));
    /// ```
    pub fn balance<K, KF>(self, key_fn: KF) -> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
    where
        K: Clone + Eq + Hash + Send + Sync + 'static,
        KF: Fn(&A) -> Option<K> + Send + Sync,
    {
        // Unlike `group_by`, the accumulated filter IS forwarded here.
        BalanceConstraintStream::new(self.extractor, self.filter, key_fn)
    }
219
    /// Filters A entities based on whether a matching B entity exists.
    ///
    /// Use this when the B collection needs filtering (e.g., only vacationing employees).
    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
    ///
    /// Any filters accumulated on this stream are applied to A entities.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::joiner::equal_bi;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SimpleScore;
    ///
    /// #[derive(Clone)]
    /// struct Shift { id: usize, employee_idx: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Employee { id: usize, on_vacation: bool }
    ///
    /// #[derive(Clone)]
    /// struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
    ///
    /// // Penalize shifts assigned to employees who are on vacation
    /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
    ///     .for_each(|s: &Schedule| s.shifts.as_slice())
    ///     .filter(|shift: &Shift| shift.employee_idx.is_some())
    ///     .if_exists_filtered(
    ///         |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect(),
    ///         equal_bi(
    ///             |shift: &Shift| shift.employee_idx,
    ///             |emp: &Employee| Some(emp.id),
    ///         ),
    ///     )
    ///     .penalize(SimpleScore::of(1))
    ///     .as_constraint("Vacation conflict");
    ///
    /// let schedule = Schedule {
    ///     shifts: vec![
    ///         Shift { id: 0, employee_idx: Some(0) },  // assigned to vacationing emp
    ///         Shift { id: 1, employee_idx: Some(1) },  // assigned to working emp
    ///         Shift { id: 2, employee_idx: None },     // unassigned (filtered out)
    ///     ],
    ///     employees: vec![
    ///         Employee { id: 0, on_vacation: true },
    ///         Employee { id: 1, on_vacation: false },
    ///     ],
    /// };
    ///
    /// // Only shift 0 matches (assigned to employee 0 who is on vacation)
    /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
    /// ```
    pub fn if_exists_filtered<B, EB, K, KA, KB>(
        self,
        extractor_b: EB,
        joiner: EqualJoiner<KA, KB, K>,
    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
    where
        B: Clone + Send + Sync + 'static,
        EB: Fn(&S) -> Vec<B> + Send + Sync,
        K: Eq + Hash + Clone + Send + Sync,
        KA: Fn(&A) -> K + Send + Sync,
        KB: Fn(&B) -> K + Send + Sync,
    {
        let (key_a, key_b) = joiner.into_keys();
        IfExistsStream::new(
            // Exists mode: keep A entities whose key matches at least one B.
            ExistenceMode::Exists,
            self.extractor,
            extractor_b,
            key_a,
            key_b,
            self.filter,
        )
    }
295
    /// Filters A entities based on whether NO matching B entity exists.
    ///
    /// Use this when the B collection needs filtering.
    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
    ///
    /// Any filters accumulated on this stream are applied to A entities.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::joiner::equal_bi;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SimpleScore;
    ///
    /// #[derive(Clone)]
    /// struct Task { id: usize, assignee: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Worker { id: usize, available: bool }
    ///
    /// #[derive(Clone)]
    /// struct Schedule { tasks: Vec<Task>, workers: Vec<Worker> }
    ///
    /// // Penalize tasks assigned to workers who are not available
    /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
    ///     .for_each(|s: &Schedule| s.tasks.as_slice())
    ///     .filter(|task: &Task| task.assignee.is_some())
    ///     .if_not_exists_filtered(
    ///         |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
    ///         equal_bi(
    ///             |task: &Task| task.assignee,
    ///             |worker: &Worker| Some(worker.id),
    ///         ),
    ///     )
    ///     .penalize(SimpleScore::of(1))
    ///     .as_constraint("Unavailable worker");
    ///
    /// let schedule = Schedule {
    ///     tasks: vec![
    ///         Task { id: 0, assignee: Some(0) },  // worker 0 is unavailable
    ///         Task { id: 1, assignee: Some(1) },  // worker 1 is available
    ///         Task { id: 2, assignee: None },     // unassigned (filtered out)
    ///     ],
    ///     workers: vec![
    ///         Worker { id: 0, available: false },
    ///         Worker { id: 1, available: true },
    ///     ],
    /// };
    ///
    /// // Task 0's worker (id=0) is NOT in the available workers list
    /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
    /// ```
    pub fn if_not_exists_filtered<B, EB, K, KA, KB>(
        self,
        extractor_b: EB,
        joiner: EqualJoiner<KA, KB, K>,
    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
    where
        B: Clone + Send + Sync + 'static,
        EB: Fn(&S) -> Vec<B> + Send + Sync,
        K: Eq + Hash + Clone + Send + Sync,
        KA: Fn(&A) -> K + Send + Sync,
        KB: Fn(&B) -> K + Send + Sync,
    {
        let (key_a, key_b) = joiner.into_keys();
        IfExistsStream::new(
            // NotExists mode: keep A entities whose key matches NO B entity.
            ExistenceMode::NotExists,
            self.extractor,
            extractor_b,
            key_a,
            key_b,
            self.filter,
        )
    }
371
372    /// Penalizes each matching entity with a fixed weight.
373    pub fn penalize(
374        self,
375        weight: Sc,
376    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
377    where
378        Sc: Clone,
379    {
380        // Detect if this is a hard constraint by checking if hard level is non-zero
381        let is_hard = weight
382            .to_level_numbers()
383            .first()
384            .map(|&h| h != 0)
385            .unwrap_or(false);
386        UniConstraintBuilder {
387            extractor: self.extractor,
388            filter: self.filter,
389            impact_type: ImpactType::Penalty,
390            weight: move |_: &A| weight.clone(),
391            is_hard,
392            _phantom: PhantomData,
393        }
394    }
395
396    /// Penalizes each matching entity with a dynamic weight.
397    ///
398    /// Note: For dynamic weights, use `penalize_hard_with` to explicitly mark as a hard constraint,
399    /// since the weight function cannot be evaluated at build time.
400    pub fn penalize_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
401    where
402        W: Fn(&A) -> Sc + Send + Sync,
403    {
404        UniConstraintBuilder {
405            extractor: self.extractor,
406            filter: self.filter,
407            impact_type: ImpactType::Penalty,
408            weight: weight_fn,
409            is_hard: false, // Can't detect at build time; use penalize_hard_with for hard constraints
410            _phantom: PhantomData,
411        }
412    }
413
414    /// Penalizes each matching entity with a dynamic weight, explicitly marked as a hard constraint.
415    pub fn penalize_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
416    where
417        W: Fn(&A) -> Sc + Send + Sync,
418    {
419        UniConstraintBuilder {
420            extractor: self.extractor,
421            filter: self.filter,
422            impact_type: ImpactType::Penalty,
423            weight: weight_fn,
424            is_hard: true,
425            _phantom: PhantomData,
426        }
427    }
428
429    /// Rewards each matching entity with a fixed weight.
430    pub fn reward(
431        self,
432        weight: Sc,
433    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
434    where
435        Sc: Clone,
436    {
437        // Detect if this is a hard constraint by checking if hard level is non-zero
438        let is_hard = weight
439            .to_level_numbers()
440            .first()
441            .map(|&h| h != 0)
442            .unwrap_or(false);
443        UniConstraintBuilder {
444            extractor: self.extractor,
445            filter: self.filter,
446            impact_type: ImpactType::Reward,
447            weight: move |_: &A| weight.clone(),
448            is_hard,
449            _phantom: PhantomData,
450        }
451    }
452
453    /// Rewards each matching entity with a dynamic weight.
454    ///
455    /// Note: For dynamic weights, use `reward_hard_with` to explicitly mark as a hard constraint,
456    /// since the weight function cannot be evaluated at build time.
457    pub fn reward_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
458    where
459        W: Fn(&A) -> Sc + Send + Sync,
460    {
461        UniConstraintBuilder {
462            extractor: self.extractor,
463            filter: self.filter,
464            impact_type: ImpactType::Reward,
465            weight: weight_fn,
466            is_hard: false, // Can't detect at build time; use reward_hard_with for hard constraints
467            _phantom: PhantomData,
468        }
469    }
470
471    /// Rewards each matching entity with a dynamic weight, explicitly marked as a hard constraint.
472    pub fn reward_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
473    where
474        W: Fn(&A) -> Sc + Send + Sync,
475    {
476        UniConstraintBuilder {
477            extractor: self.extractor,
478            filter: self.filter,
479            impact_type: ImpactType::Reward,
480            weight: weight_fn,
481            is_hard: true,
482            _phantom: PhantomData,
483        }
484    }
485}
486
487impl<S, A, E, F, Sc: Score> std::fmt::Debug for UniConstraintStream<S, A, E, F, Sc> {
488    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
489        f.debug_struct("UniConstraintStream").finish()
490    }
491}
492
493/// Zero-erasure builder for finalizing a uni-constraint.
/// Zero-erasure builder for finalizing a uni-constraint.
pub struct UniConstraintBuilder<S, A, E, F, W, Sc>
where
    Sc: Score,
{
    /// Extracts the entity slice `&[A]` from a solution `&S`.
    extractor: E,
    /// Compile-time composed filter carried over from the stream.
    filter: F,
    /// Whether matches are penalized or rewarded.
    impact_type: ImpactType,
    /// Per-entity weight function.
    weight: W,
    /// Whether this counts as a hard constraint.
    is_hard: bool,
    /// Carries the `S`, `A`, and `Sc` type parameters without storing values.
    _phantom: PhantomData<(S, A, Sc)>,
}
505
506impl<S, A, E, F, W, Sc> UniConstraintBuilder<S, A, E, F, W, Sc>
507where
508    S: Send + Sync + 'static,
509    A: Clone + Send + Sync + 'static,
510    E: Fn(&S) -> &[A] + Send + Sync,
511    F: UniFilter<A>,
512    W: Fn(&A) -> Sc + Send + Sync,
513    Sc: Score + 'static,
514{
515    /// Finalizes the builder into a zero-erasure `IncrementalUniConstraint`.
516    pub fn as_constraint(
517        self,
518        name: &str,
519    ) -> IncrementalUniConstraint<S, A, E, impl Fn(&A) -> bool + Send + Sync, W, Sc> {
520        let filter = self.filter;
521        let combined_filter = move |a: &A| filter.test(a);
522
523        IncrementalUniConstraint::new(
524            ConstraintRef::new("", name),
525            self.impact_type,
526            self.extractor,
527            combined_filter,
528            self.weight,
529            self.is_hard,
530        )
531    }
532}
533
534impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for UniConstraintBuilder<S, A, E, F, W, Sc> {
535    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
536        f.debug_struct("UniConstraintBuilder")
537            .field("impact_type", &self.impact_type)
538            .finish()
539    }
540}