Skip to main content

solverforge_scoring/stream/
uni_stream.rs

//! Zero-erasure uni-constraint stream for single-entity constraint patterns.
//!
//! A `UniConstraintStream` operates on a single entity type and supports
//! filtering, weighting, and constraint finalization. All type information
//! is preserved at compile time - no Arc, no dyn, fully monomorphized.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::constraint::incremental::IncrementalUniConstraint;
14
15use crate::constraint::if_exists::ExistenceMode;
16
17use super::balance_stream::BalanceConstraintStream;
18use super::bi_stream::BiConstraintStream;
19use super::collector::UniCollector;
20use super::cross_bi_stream::CrossBiConstraintStream;
21use super::filter::{AndUniFilter, FnUniFilter, TrueFilter, UniFilter, UniLeftBiFilter};
22use super::grouped_stream::GroupedConstraintStream;
23use super::if_exists_stream::IfExistsStream;
24use super::joiner::EqualJoiner;
25
/// Zero-erasure constraint stream over a single entity type.
///
/// `UniConstraintStream` accumulates filters and can be finalized into
/// an `IncrementalUniConstraint` via `penalize()` or `reward()`.
///
/// All type parameters are concrete - no trait objects, no Arc allocations
/// in the hot path.
///
/// # Type Parameters
///
/// - `S` - Solution type
/// - `A` - Entity type
/// - `E` - Extractor function type
/// - `F` - Combined filter type
/// - `Sc` - Score type
pub struct UniConstraintStream<S, A, E, F, Sc>
where
    Sc: Score,
{
    /// Extracts the `&[A]` entity slice from a solution.
    extractor: E,
    /// Compile-time composed filter; each `filter()` call wraps it in another layer.
    filter: F,
    /// Records `S`, `A`, `Sc` without storing values; the `fn() -> T` form
    /// avoids imposing auto-trait or drop obligations from those parameters.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
}
49
impl<S, A, E, Sc> UniConstraintStream<S, A, E, TrueFilter, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    E: Fn(&S) -> &[A] + Send + Sync,
    Sc: Score + 'static,
{
    /// Creates a new uni-constraint stream with the given extractor.
    ///
    /// The stream starts with `TrueFilter` (accepts every entity); narrow it
    /// with `filter()` before finalizing.
    pub fn new(extractor: E) -> Self {
        Self {
            extractor,
            filter: TrueFilter,
            _phantom: PhantomData,
        }
    }
}
66
impl<S, A, E, F, Sc> UniConstraintStream<S, A, E, F, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    E: Fn(&S) -> &[A] + Send + Sync,
    F: UniFilter<S, A>,
    Sc: Score + 'static,
{
    /// Adds a filter predicate to the stream.
    ///
    /// Multiple filters are combined with AND semantics at compile time.
    /// Each filter adds a new type layer, preserving zero-erasure.
    ///
    /// To access related entities, use shadow variables on your entity type
    /// (e.g., `#[inverse_relation_shadow_variable]`) rather than solution traversal.
    pub fn filter<P>(
        self,
        predicate: P,
    ) -> UniConstraintStream<
        S,
        A,
        E,
        AndUniFilter<F, FnUniFilter<impl Fn(&S, &A) -> bool + Send + Sync>>,
        Sc,
    >
    where
        P: Fn(&A) -> bool + Send + Sync + 'static,
    {
        UniConstraintStream {
            extractor: self.extractor,
            filter: AndUniFilter::new(
                self.filter,
                // Lift the entity-only predicate into the (solution, entity)
                // shape expected by the filter stack; the solution is ignored.
                FnUniFilter::new(move |_s: &S, a: &A| predicate(a)),
            ),
            _phantom: PhantomData,
        }
    }

    /// Joins this stream with itself to create pairs (zero-erasure).
    ///
    /// Requires an `EqualJoiner` to enable key-based indexing for O(k) lookups.
    /// For self-joins, pairs are ordered (i < j) to avoid duplicates.
    ///
    /// Any filters accumulated on this stream are applied to both entities
    /// individually before the join.
    pub fn join_self<K, KA, KB>(
        self,
        joiner: EqualJoiner<KA, KB, K>,
    ) -> BiConstraintStream<
        S,
        A,
        K,
        E,
        impl Fn(&S, &A, usize) -> K + Send + Sync,
        UniLeftBiFilter<F, A>,
        Sc,
    >
    where
        A: Hash + PartialEq,
        K: Eq + Hash + Clone + Send + Sync,
        KA: Fn(&A) -> K + Send + Sync,
        KB: Fn(&A) -> K + Send + Sync,
    {
        // A self-join needs only one key function; the joiner's second key
        // function is discarded.
        let (key_extractor, _) = joiner.into_keys();

        // Wrap key_extractor to match the new KE: Fn(&S, &A, usize) -> K signature.
        // The static stream API doesn't need solution/index, so ignore them.
        let wrapped_ke = move |_s: &S, a: &A, _idx: usize| key_extractor(a);

        // Convert uni-filter to bi-filter that applies to left entity
        let bi_filter = UniLeftBiFilter::new(self.filter);

        BiConstraintStream::new_self_join_with_filter(self.extractor, wrapped_ke, bi_filter)
    }

    /// Joins this stream with another collection to create cross-entity pairs (zero-erasure).
    ///
    /// Requires an `EqualJoiner` to enable key-based indexing for O(1) lookups.
    /// Unlike `join_self` which pairs entities within the same collection,
    /// `join` creates pairs from two different collections (e.g., Shift joined
    /// with Employee).
    ///
    /// Any filters accumulated on this stream are applied to the A entity
    /// before the join.
    pub fn join<B, EB, K, KA, KB>(
        self,
        extractor_b: EB,
        joiner: EqualJoiner<KA, KB, K>,
    ) -> CrossBiConstraintStream<S, A, B, K, E, EB, KA, KB, UniLeftBiFilter<F, B>, Sc>
    where
        B: Clone + Send + Sync + 'static,
        EB: Fn(&S) -> &[B] + Send + Sync,
        K: Eq + Hash + Clone + Send + Sync,
        KA: Fn(&A) -> K + Send + Sync,
        KB: Fn(&B) -> K + Send + Sync,
    {
        let (key_a, key_b) = joiner.into_keys();

        // Convert uni-filter to bi-filter that applies to left entity only
        let bi_filter = UniLeftBiFilter::new(self.filter);

        CrossBiConstraintStream::new_with_filter(
            self.extractor,
            extractor_b,
            key_a,
            key_b,
            bi_filter,
        )
    }

    /// Groups entities by key and aggregates with a collector.
    ///
    /// Returns a zero-erasure `GroupedConstraintStream` that can be penalized
    /// or rewarded based on the aggregated result for each group.
    pub fn group_by<K, KF, C>(
        self,
        key_fn: KF,
        collector: C,
    ) -> GroupedConstraintStream<S, A, K, E, F, KF, C, Sc>
    where
        K: Clone + Eq + Hash + Send + Sync + 'static,
        KF: Fn(&A) -> K + Send + Sync,
        C: UniCollector<A> + Send + Sync + 'static,
        C::Accumulator: Send + Sync,
        C::Result: Clone + Send + Sync,
    {
        GroupedConstraintStream::new(self.extractor, self.filter, key_fn, collector)
    }

    /// Creates a balance constraint that penalizes uneven distribution across groups.
    ///
    /// Unlike `group_by` which scores each group independently, `balance` computes
    /// a GLOBAL standard deviation across all group counts and produces a single score.
    ///
    /// The `key_fn` returns `Option<K>` to allow skipping entities (e.g., unassigned shifts).
    /// Any filters accumulated on this stream are also applied.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SimpleScore;
    ///
    /// #[derive(Clone)]
    /// struct Shift { employee_id: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Solution { shifts: Vec<Shift> }
    ///
    /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    ///     .for_each(|s: &Solution| &s.shifts)
    ///     .balance(|shift: &Shift| shift.employee_id)
    ///     .penalize(SimpleScore::of(1000))
    ///     .as_constraint("Balance workload");
    ///
    /// let solution = Solution {
    ///     shifts: vec![
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(1) },
    ///     ],
    /// };
    ///
    /// // Employee 0: 3 shifts, Employee 1: 1 shift
    /// // std_dev = 1.0, penalty = -1000
    /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1000));
    /// ```
    pub fn balance<K, KF>(self, key_fn: KF) -> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
    where
        K: Clone + Eq + Hash + Send + Sync + 'static,
        KF: Fn(&A) -> Option<K> + Send + Sync,
    {
        BalanceConstraintStream::new(self.extractor, self.filter, key_fn)
    }

    /// Filters A entities based on whether a matching B entity exists.
    ///
    /// Use this when the B collection needs filtering (e.g., only vacationing employees).
    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
    ///
    /// Any filters accumulated on this stream are applied to A entities.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::joiner::equal_bi;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SimpleScore;
    ///
    /// #[derive(Clone)]
    /// struct Shift { id: usize, employee_idx: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Employee { id: usize, on_vacation: bool }
    ///
    /// #[derive(Clone)]
    /// struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
    ///
    /// // Penalize shifts assigned to employees who are on vacation
    /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
    ///     .for_each(|s: &Schedule| s.shifts.as_slice())
    ///     .filter(|shift: &Shift| shift.employee_idx.is_some())
    ///     .if_exists_filtered(
    ///         |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect(),
    ///         equal_bi(
    ///             |shift: &Shift| shift.employee_idx,
    ///             |emp: &Employee| Some(emp.id),
    ///         ),
    ///     )
    ///     .penalize(SimpleScore::of(1))
    ///     .as_constraint("Vacation conflict");
    ///
    /// let schedule = Schedule {
    ///     shifts: vec![
    ///         Shift { id: 0, employee_idx: Some(0) },  // assigned to vacationing emp
    ///         Shift { id: 1, employee_idx: Some(1) },  // assigned to working emp
    ///         Shift { id: 2, employee_idx: None },     // unassigned (filtered out)
    ///     ],
    ///     employees: vec![
    ///         Employee { id: 0, on_vacation: true },
    ///         Employee { id: 1, on_vacation: false },
    ///     ],
    /// };
    ///
    /// // Only shift 0 matches (assigned to employee 0 who is on vacation)
    /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
    /// ```
    pub fn if_exists_filtered<B, EB, K, KA, KB>(
        self,
        extractor_b: EB,
        joiner: EqualJoiner<KA, KB, K>,
    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
    where
        B: Clone + Send + Sync + 'static,
        EB: Fn(&S) -> Vec<B> + Send + Sync,
        K: Eq + Hash + Clone + Send + Sync,
        KA: Fn(&A) -> K + Send + Sync,
        KB: Fn(&B) -> K + Send + Sync,
    {
        let (key_a, key_b) = joiner.into_keys();
        IfExistsStream::new(
            ExistenceMode::Exists,
            self.extractor,
            extractor_b,
            key_a,
            key_b,
            self.filter,
        )
    }

    /// Filters A entities based on whether NO matching B entity exists.
    ///
    /// Use this when the B collection needs filtering.
    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
    ///
    /// Any filters accumulated on this stream are applied to A entities.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::joiner::equal_bi;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SimpleScore;
    ///
    /// #[derive(Clone)]
    /// struct Task { id: usize, assignee: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Worker { id: usize, available: bool }
    ///
    /// #[derive(Clone)]
    /// struct Schedule { tasks: Vec<Task>, workers: Vec<Worker> }
    ///
    /// // Penalize tasks assigned to workers who are not available
    /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
    ///     .for_each(|s: &Schedule| s.tasks.as_slice())
    ///     .filter(|task: &Task| task.assignee.is_some())
    ///     .if_not_exists_filtered(
    ///         |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
    ///         equal_bi(
    ///             |task: &Task| task.assignee,
    ///             |worker: &Worker| Some(worker.id),
    ///         ),
    ///     )
    ///     .penalize(SimpleScore::of(1))
    ///     .as_constraint("Unavailable worker");
    ///
    /// let schedule = Schedule {
    ///     tasks: vec![
    ///         Task { id: 0, assignee: Some(0) },  // worker 0 is unavailable
    ///         Task { id: 1, assignee: Some(1) },  // worker 1 is available
    ///         Task { id: 2, assignee: None },     // unassigned (filtered out)
    ///     ],
    ///     workers: vec![
    ///         Worker { id: 0, available: false },
    ///         Worker { id: 1, available: true },
    ///     ],
    /// };
    ///
    /// // Task 0's worker (id=0) is NOT in the available workers list
    /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
    /// ```
    pub fn if_not_exists_filtered<B, EB, K, KA, KB>(
        self,
        extractor_b: EB,
        joiner: EqualJoiner<KA, KB, K>,
    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
    where
        B: Clone + Send + Sync + 'static,
        EB: Fn(&S) -> Vec<B> + Send + Sync,
        K: Eq + Hash + Clone + Send + Sync,
        KA: Fn(&A) -> K + Send + Sync,
        KB: Fn(&B) -> K + Send + Sync,
    {
        let (key_a, key_b) = joiner.into_keys();
        IfExistsStream::new(
            ExistenceMode::NotExists,
            self.extractor,
            extractor_b,
            key_a,
            key_b,
            self.filter,
        )
    }

    /// Penalizes each matching entity with a fixed weight.
    pub fn penalize(
        self,
        weight: Sc,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        // Detect if this is a hard constraint by checking if hard level is non-zero.
        // NOTE(review): assumes the first level number is the hard level — confirm
        // against `Score::to_level_numbers` ordering.
        let is_hard = weight
            .to_level_numbers()
            .first()
            .map(|&h| h != 0)
            .unwrap_or(false);
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Penalty,
            weight: move |_: &A| weight,
            is_hard,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Penalizes each matching entity with a dynamic weight.
    ///
    /// Note: For dynamic weights, use `penalize_hard_with` to explicitly mark as a hard constraint,
    /// since the weight function cannot be evaluated at build time.
    pub fn penalize_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Penalty,
            weight: weight_fn,
            is_hard: false, // Can't detect at build time; use penalize_hard_with for hard constraints
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Penalizes each matching entity with a dynamic weight, explicitly marked as a hard constraint.
    pub fn penalize_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Penalty,
            weight: weight_fn,
            is_hard: true,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Rewards each matching entity with a fixed weight.
    pub fn reward(
        self,
        weight: Sc,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        // Detect if this is a hard constraint by checking if hard level is non-zero.
        // NOTE(review): same first-level assumption as in `penalize`.
        let is_hard = weight
            .to_level_numbers()
            .first()
            .map(|&h| h != 0)
            .unwrap_or(false);
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Reward,
            weight: move |_: &A| weight,
            is_hard,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Rewards each matching entity with a dynamic weight.
    ///
    /// Note: For dynamic weights, use `reward_hard_with` to explicitly mark as a hard constraint,
    /// since the weight function cannot be evaluated at build time.
    pub fn reward_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Reward,
            weight: weight_fn,
            is_hard: false, // Can't detect at build time; use reward_hard_with for hard constraints
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Rewards each matching entity with a dynamic weight, explicitly marked as a hard constraint.
    pub fn reward_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Reward,
            weight: weight_fn,
            is_hard: true,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }
}
516
517impl<S, A, E, F, Sc: Score> std::fmt::Debug for UniConstraintStream<S, A, E, F, Sc> {
518    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
519        f.debug_struct("UniConstraintStream").finish()
520    }
521}
522
/// Zero-erasure builder for finalizing a uni-constraint.
pub struct UniConstraintBuilder<S, A, E, F, W, Sc>
where
    Sc: Score,
{
    /// Extracts the `&[A]` entity slice from a solution.
    extractor: E,
    /// Filter composed on the originating stream.
    filter: F,
    /// Whether matches penalize or reward the score.
    impact_type: ImpactType,
    /// Per-entity weight function.
    weight: W,
    /// True when the fixed weight's first score level is non-zero, or when
    /// explicitly set via the `*_hard_with` constructors.
    is_hard: bool,
    /// Optional descriptor-index restriction; see `for_descriptor`.
    expected_descriptor: Option<usize>,
    /// Records `S`, `A`, `Sc` without storing values.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
}
536
537impl<S, A, E, F, W, Sc> UniConstraintBuilder<S, A, E, F, W, Sc>
538where
539    S: Send + Sync + 'static,
540    A: Clone + Send + Sync + 'static,
541    E: Fn(&S) -> &[A] + Send + Sync,
542    F: UniFilter<S, A>,
543    W: Fn(&A) -> Sc + Send + Sync,
544    Sc: Score + 'static,
545{
546    // Restricts this constraint to only fire for the given descriptor index.
547    //
548    // Required when multiple entity classes exist (e.g., FurnaceAssignment at 0,
549    // ShiftAssignment at 1). Without this, on_insert/on_retract fire for all entity
550    // classes using the constraint's entity_index, which indexes into the wrong slice.
551    pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
552        self.expected_descriptor = Some(descriptor_index);
553        self
554    }
555
556    // Finalizes the builder into a zero-erasure `IncrementalUniConstraint`.
557    pub fn as_constraint(
558        self,
559        name: &str,
560    ) -> IncrementalUniConstraint<S, A, E, impl Fn(&S, &A) -> bool + Send + Sync, W, Sc> {
561        let filter = self.filter;
562        let combined_filter = move |s: &S, a: &A| filter.test(s, a);
563
564        let mut constraint = IncrementalUniConstraint::new(
565            ConstraintRef::new("", name),
566            self.impact_type,
567            self.extractor,
568            combined_filter,
569            self.weight,
570            self.is_hard,
571        );
572        if let Some(d) = self.expected_descriptor {
573            constraint = constraint.with_descriptor(d);
574        }
575        constraint
576    }
577}
578
579impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for UniConstraintBuilder<S, A, E, F, W, Sc> {
580    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581        f.debug_struct("UniConstraintBuilder")
582            .field("impact_type", &self.impact_type)
583            .finish()
584    }
585}