// solverforge_scoring/stream/uni_stream.rs
//! Zero-erasure uni-constraint stream for single-entity constraint patterns.
//!
//! A `UniConstraintStream` operates on a single entity type and supports
//! filtering, weighting, and constraint finalization. All type information
//! is preserved at compile time - no `Arc`, no `dyn`, fully monomorphized.
7
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::constraint::incremental::IncrementalUniConstraint;
15
16use crate::constraint::if_exists::ExistenceMode;
17
18use super::balance_stream::BalanceConstraintStream;
19use super::collection_extract::CollectionExtract;
20use super::collector::UniCollector;
21use super::filter::{AndUniFilter, FnUniFilter, TrueFilter, UniFilter};
22use super::grouped_stream::GroupedConstraintStream;
23use super::if_exists_stream::IfExistsStream;
24use super::joiner::EqualJoiner;
25
/// Zero-erasure constraint stream over a single entity type.
///
/// `UniConstraintStream` accumulates filters and can be finalized into
/// an `IncrementalUniConstraint` via `penalize()` or `reward()`.
///
/// All type parameters are concrete - no trait objects, no Arc allocations
/// in the hot path.
///
/// # Type Parameters
///
/// - `S` - Solution type
/// - `A` - Entity type
/// - `E` - Extractor function type
/// - `F` - Combined filter type
/// - `Sc` - Score type
pub struct UniConstraintStream<S, A, E, F, Sc>
where
    Sc: Score,
{
    // Extracts the entity collection out of the solution.
    extractor: E,
    // Compile-time composed filter chain (AND semantics; starts as `TrueFilter`).
    filter: F,
    // `fn() -> T` markers keep the stream `Send + Sync` independently of S/A/Sc
    // while recording the type parameters that only appear in method signatures.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
}
50
51impl<S, A, E, Sc> UniConstraintStream<S, A, E, TrueFilter, Sc>
52where
53    S: Send + Sync + 'static,
54    A: Clone + Send + Sync + 'static,
55    E: CollectionExtract<S, Item = A>,
56    Sc: Score + 'static,
57{
58    // Creates a new uni-constraint stream with the given extractor.
59    pub fn new(extractor: E) -> Self {
60        Self {
61            extractor,
62            filter: TrueFilter,
63            _phantom: PhantomData,
64        }
65    }
66}
67
impl<S, A, E, F, Sc> UniConstraintStream<S, A, E, F, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    E: CollectionExtract<S, Item = A>,
    F: UniFilter<S, A>,
    Sc: Score + 'static,
{
    /// Adds a filter predicate to the stream.
    ///
    /// Multiple filters are combined with AND semantics at compile time.
    /// Each filter adds a new type layer, preserving zero-erasure.
    ///
    /// To access related entities, use shadow variables on your entity type
    /// (e.g., `#[inverse_relation_shadow_variable]`) rather than solution traversal.
    pub fn filter<P>(
        self,
        predicate: P,
    ) -> UniConstraintStream<
        S,
        A,
        E,
        AndUniFilter<F, FnUniFilter<impl Fn(&S, &A) -> bool + Send + Sync>>,
        Sc,
    >
    where
        P: Fn(&A) -> bool + Send + Sync + 'static,
    {
        UniConstraintStream {
            extractor: self.extractor,
            // Adapt the entity-only predicate to the (solution, entity) filter
            // shape, then AND it with the filters accumulated so far.
            filter: AndUniFilter::new(
                self.filter,
                FnUniFilter::new(move |_s: &S, a: &A| predicate(a)),
            ),
            _phantom: PhantomData,
        }
    }

    /// Joins this stream using the provided join target.
    ///
    /// Dispatch depends on argument type:
    /// - `equal(|a: &A| key_fn(a))` — self-join, returns `BiConstraintStream`
    /// - `(extractor_b, equal_bi(ka, kb))` — keyed cross-join, returns `CrossBiConstraintStream`
    /// - `(other_stream, |a, b| predicate)` — predicate cross-join O(n*m)
    pub fn join<J>(self, target: J) -> J::Output
    where
        J: super::join_target::JoinTarget<S, A, E, F, Sc>,
    {
        target.apply(self.extractor, self.filter)
    }

    /// Groups entities by key and aggregates with a collector.
    ///
    /// Returns a zero-erasure `GroupedConstraintStream` that can be penalized
    /// or rewarded based on the aggregated result for each group.
    pub fn group_by<K, KF, C>(
        self,
        key_fn: KF,
        collector: C,
    ) -> GroupedConstraintStream<S, A, K, E, F, KF, C, Sc>
    where
        K: Clone + Eq + Hash + Send + Sync + 'static,
        KF: Fn(&A) -> K + Send + Sync,
        C: UniCollector<A> + Send + Sync + 'static,
        C::Accumulator: Send + Sync,
        C::Result: Clone + Send + Sync,
    {
        GroupedConstraintStream::new(self.extractor, self.filter, key_fn, collector)
    }

    /// Creates a balance constraint that penalizes uneven distribution across groups.
    ///
    /// Unlike `group_by` which scores each group independently, `balance` computes
    /// a GLOBAL standard deviation across all group counts and produces a single score.
    ///
    /// The `key_fn` returns `Option<K>` to allow skipping entities (e.g., unassigned shifts).
    /// Any filters accumulated on this stream are also applied.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SoftScore;
    ///
    /// #[derive(Clone)]
    /// struct Shift { employee_id: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Solution { shifts: Vec<Shift> }
    ///
    /// let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    ///     .for_each(|s: &Solution| &s.shifts)
    ///     .balance(|shift: &Shift| shift.employee_id)
    ///     .penalize(SoftScore::of(1000))
    ///     .named("Balance workload");
    ///
    /// let solution = Solution {
    ///     shifts: vec![
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(1) },
    ///     ],
    /// };
    ///
    /// // Employee 0: 3 shifts, Employee 1: 1 shift
    /// // std_dev = 1.0, penalty = -1000
    /// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1000));
    /// ```
    pub fn balance<K, KF>(self, key_fn: KF) -> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
    where
        K: Clone + Eq + Hash + Send + Sync + 'static,
        KF: Fn(&A) -> Option<K> + Send + Sync,
    {
        BalanceConstraintStream::new(self.extractor, self.filter, key_fn)
    }

    /// Filters A entities based on whether a matching B entity exists.
    ///
    /// Use this when the B collection needs filtering (e.g., only vacationing employees).
    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
    ///
    /// Any filters accumulated on this stream are applied to A entities.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::joiner::equal_bi;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SoftScore;
    ///
    /// #[derive(Clone)]
    /// struct Shift { id: usize, employee_idx: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Employee { id: usize, on_vacation: bool }
    ///
    /// #[derive(Clone)]
    /// struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
    ///
    /// // Penalize shifts assigned to employees who are on vacation
    /// let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
    ///     .for_each(|s: &Schedule| s.shifts.as_slice())
    ///     .filter(|shift: &Shift| shift.employee_idx.is_some())
    ///     .if_exists_filtered(
    ///         |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect(),
    ///         equal_bi(
    ///             |shift: &Shift| shift.employee_idx,
    ///             |emp: &Employee| Some(emp.id),
    ///         ),
    ///     )
    ///     .penalize(SoftScore::of(1))
    ///     .named("Vacation conflict");
    ///
    /// let schedule = Schedule {
    ///     shifts: vec![
    ///         Shift { id: 0, employee_idx: Some(0) },  // assigned to vacationing emp
    ///         Shift { id: 1, employee_idx: Some(1) },  // assigned to working emp
    ///         Shift { id: 2, employee_idx: None },     // unassigned (filtered out)
    ///     ],
    ///     employees: vec![
    ///         Employee { id: 0, on_vacation: true },
    ///         Employee { id: 1, on_vacation: false },
    ///     ],
    /// };
    ///
    /// // Only shift 0 matches (assigned to employee 0 who is on vacation)
    /// assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-1));
    /// ```
    pub fn if_exists_filtered<B, EB, K, KA, KB>(
        self,
        extractor_b: EB,
        joiner: EqualJoiner<KA, KB, K>,
    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
    where
        B: Clone + Send + Sync + 'static,
        EB: Fn(&S) -> Vec<B> + Send + Sync,
        K: Eq + Hash + Clone + Send + Sync,
        KA: Fn(&A) -> K + Send + Sync,
        KB: Fn(&B) -> K + Send + Sync,
    {
        let (key_a, key_b) = joiner.into_keys();
        IfExistsStream::new(
            ExistenceMode::Exists,
            self.extractor,
            extractor_b,
            key_a,
            key_b,
            self.filter,
        )
    }

    /// Filters A entities based on whether NO matching B entity exists.
    ///
    /// Use this when the B collection needs filtering.
    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
    ///
    /// Any filters accumulated on this stream are applied to A entities.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::joiner::equal_bi;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SoftScore;
    ///
    /// #[derive(Clone)]
    /// struct Task { id: usize, assignee: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Worker { id: usize, available: bool }
    ///
    /// #[derive(Clone)]
    /// struct Schedule { tasks: Vec<Task>, workers: Vec<Worker> }
    ///
    /// // Penalize tasks assigned to workers who are not available
    /// let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
    ///     .for_each(|s: &Schedule| s.tasks.as_slice())
    ///     .filter(|task: &Task| task.assignee.is_some())
    ///     .if_not_exists_filtered(
    ///         |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
    ///         equal_bi(
    ///             |task: &Task| task.assignee,
    ///             |worker: &Worker| Some(worker.id),
    ///         ),
    ///     )
    ///     .penalize(SoftScore::of(1))
    ///     .named("Unavailable worker");
    ///
    /// let schedule = Schedule {
    ///     tasks: vec![
    ///         Task { id: 0, assignee: Some(0) },  // worker 0 is unavailable
    ///         Task { id: 1, assignee: Some(1) },  // worker 1 is available
    ///         Task { id: 2, assignee: None },     // unassigned (filtered out)
    ///     ],
    ///     workers: vec![
    ///         Worker { id: 0, available: false },
    ///         Worker { id: 1, available: true },
    ///     ],
    /// };
    ///
    /// // Task 0's worker (id=0) is NOT in the available workers list
    /// assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-1));
    /// ```
    pub fn if_not_exists_filtered<B, EB, K, KA, KB>(
        self,
        extractor_b: EB,
        joiner: EqualJoiner<KA, KB, K>,
    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
    where
        B: Clone + Send + Sync + 'static,
        EB: Fn(&S) -> Vec<B> + Send + Sync,
        K: Eq + Hash + Clone + Send + Sync,
        KA: Fn(&A) -> K + Send + Sync,
        KB: Fn(&B) -> K + Send + Sync,
    {
        let (key_a, key_b) = joiner.into_keys();
        IfExistsStream::new(
            ExistenceMode::NotExists,
            self.extractor,
            extractor_b,
            key_a,
            key_b,
            self.filter,
        )
    }

    /// Penalizes each matching entity with a fixed weight.
    ///
    /// The constraint is marked hard when the weight's first score level
    /// (the hard level) is non-zero.
    pub fn penalize(
        self,
        weight: Sc,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        // Detect if this is a hard constraint by checking if hard level is non-zero
        let is_hard = weight
            .to_level_numbers()
            .first()
            .map(|&h| h != 0)
            .unwrap_or(false);
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Penalty,
            weight: move |_: &A| weight,
            is_hard,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Penalizes each matching entity with a dynamic weight.
    ///
    /// Note: For dynamic weights, use `penalize_hard_with` to explicitly mark as a hard constraint,
    /// since the weight function cannot be evaluated at build time.
    pub fn penalize_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Penalty,
            weight: weight_fn,
            is_hard: false, // Can't detect at build time; use penalize_hard_with for hard constraints
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Penalizes each matching entity with a dynamic weight, explicitly marked as a hard constraint.
    pub fn penalize_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Penalty,
            weight: weight_fn,
            is_hard: true,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Rewards each matching entity with a fixed weight.
    ///
    /// The constraint is marked hard when the weight's first score level
    /// (the hard level) is non-zero.
    pub fn reward(
        self,
        weight: Sc,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        // Detect if this is a hard constraint by checking if hard level is non-zero
        let is_hard = weight
            .to_level_numbers()
            .first()
            .map(|&h| h != 0)
            .unwrap_or(false);
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Reward,
            weight: move |_: &A| weight,
            is_hard,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Rewards each matching entity with a dynamic weight.
    ///
    /// Note: For dynamic weights, use `reward_hard_with` to explicitly mark as a hard constraint,
    /// since the weight function cannot be evaluated at build time.
    pub fn reward_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Reward,
            weight: weight_fn,
            is_hard: false, // Can't detect at build time; use reward_hard_with for hard constraints
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Rewards each matching entity with a dynamic weight, explicitly marked as a hard constraint.
    pub fn reward_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Reward,
            weight: weight_fn,
            is_hard: true,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Penalizes each matching entity with one hard score unit.
    pub fn penalize_hard(
        self,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        self.penalize(Sc::one_hard())
    }

    /// Penalizes each matching entity with one soft score unit.
    pub fn penalize_soft(
        self,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        self.penalize(Sc::one_soft())
    }

    /// Rewards each matching entity with one hard score unit.
    pub fn reward_hard(
        self,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        self.reward(Sc::one_hard())
    }

    /// Rewards each matching entity with one soft score unit.
    pub fn reward_soft(
        self,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        self.reward(Sc::one_soft())
    }
}
506
507impl<S, A, E, F, Sc: Score> UniConstraintStream<S, A, E, F, Sc> {
508    #[doc(hidden)]
509    pub fn extractor(&self) -> &E {
510        &self.extractor
511    }
512
513    #[doc(hidden)]
514    pub fn into_parts(self) -> (E, F) {
515        (self.extractor, self.filter)
516    }
517
518    #[doc(hidden)]
519    pub fn from_parts(extractor: E, filter: F) -> Self {
520        Self {
521            extractor,
522            filter,
523            _phantom: PhantomData,
524        }
525    }
526}
527
528impl<S, A, E, F, Sc: Score> std::fmt::Debug for UniConstraintStream<S, A, E, F, Sc> {
529    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
530        f.debug_struct("UniConstraintStream").finish()
531    }
532}
533
/// Zero-erasure builder for finalizing a uni-constraint.
pub struct UniConstraintBuilder<S, A, E, F, W, Sc>
where
    Sc: Score,
{
    // Collection extractor carried over from the stream.
    extractor: E,
    // Accumulated compile-time filter chain.
    filter: F,
    // Whether matches penalize or reward the score.
    impact_type: ImpactType,
    // Per-entity weight function.
    weight: W,
    // True when this constraint impacts the hard score level.
    is_hard: bool,
    // Optional descriptor-index restriction (set via `for_descriptor`).
    expected_descriptor: Option<usize>,
    // See the stream struct: fn() -> T markers for unused type parameters.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
}
547
548impl<S, A, E, F, W, Sc> UniConstraintBuilder<S, A, E, F, W, Sc>
549where
550    S: Send + Sync + 'static,
551    A: Clone + Send + Sync + 'static,
552    E: CollectionExtract<S, Item = A>,
553    F: UniFilter<S, A>,
554    W: Fn(&A) -> Sc + Send + Sync,
555    Sc: Score + 'static,
556{
557    /* Restricts this constraint to only fire for the given descriptor index.
558
559    Required when multiple entity classes exist (e.g., FurnaceAssignment at 0,
560    ShiftAssignment at 1). Without this, on_insert/on_retract fire for all entity
561    classes using the constraint's entity_index, which indexes into the wrong slice.
562    */
563    pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
564        self.expected_descriptor = Some(descriptor_index);
565        self
566    }
567
568    // Finalizes the builder into a zero-erasure `IncrementalUniConstraint`.
569    pub fn named(
570        self,
571        name: &str,
572    ) -> IncrementalUniConstraint<S, A, E, impl Fn(&S, &A) -> bool + Send + Sync, W, Sc> {
573        let filter = self.filter;
574        let combined_filter = move |s: &S, a: &A| filter.test(s, a);
575
576        let mut constraint = IncrementalUniConstraint::new(
577            ConstraintRef::new("", name),
578            self.impact_type,
579            self.extractor,
580            combined_filter,
581            self.weight,
582            self.is_hard,
583        );
584        if let Some(d) = self.expected_descriptor {
585            constraint = constraint.with_descriptor(d);
586        }
587        constraint
588    }
589}
590
591impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for UniConstraintBuilder<S, A, E, F, W, Sc> {
592    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593        f.debug_struct("UniConstraintBuilder")
594            .field("impact_type", &self.impact_type)
595            .finish()
596    }
597}