// solverforge_scoring/stream/uni_stream.rs
//! Zero-erasure uni-constraint stream for single-entity constraint patterns.
//!
//! A `UniConstraintStream` operates on a single entity type and supports
//! filtering, weighting, and constraint finalization. All type information
//! is preserved at compile time - no Arc, no dyn, fully monomorphized.

7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::constraint::incremental::IncrementalUniConstraint;
14
15use crate::constraint::if_exists::ExistenceMode;
16
17use super::balance_stream::BalanceConstraintStream;
18use super::collector::UniCollector;
19use super::filter::{AndUniFilter, FnUniFilter, TrueFilter, UniFilter};
20use super::grouped_stream::GroupedConstraintStream;
21use super::if_exists_stream::IfExistsStream;
22use super::joiner::EqualJoiner;
23
/// Zero-erasure constraint stream over a single entity type.
///
/// `UniConstraintStream` accumulates filters and can be finalized into
/// an `IncrementalUniConstraint` via `penalize()` or `reward()`.
///
/// All type parameters are concrete - no trait objects, no Arc allocations
/// in the hot path.
///
/// # Type Parameters
///
/// - `S` - Solution type
/// - `A` - Entity type
/// - `E` - Extractor function type
/// - `F` - Combined filter type
/// - `Sc` - Score type
pub struct UniConstraintStream<S, A, E, F, Sc>
where
    Sc: Score,
{
    extractor: E,
    filter: F,
    // `fn() -> T` phantoms keep the stream `Send + Sync` and covariant
    // without requiring S/A/Sc values to actually be stored.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
}
47
48impl<S, A, E, Sc> UniConstraintStream<S, A, E, TrueFilter, Sc>
49where
50    S: Send + Sync + 'static,
51    A: Clone + Send + Sync + 'static,
52    E: Fn(&S) -> &[A] + Send + Sync,
53    Sc: Score + 'static,
54{
55    // Creates a new uni-constraint stream with the given extractor.
56    pub fn new(extractor: E) -> Self {
57        Self {
58            extractor,
59            filter: TrueFilter,
60            _phantom: PhantomData,
61        }
62    }
63}
64
impl<S, A, E, F, Sc> UniConstraintStream<S, A, E, F, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    E: Fn(&S) -> &[A] + Send + Sync,
    F: UniFilter<S, A>,
    Sc: Score + 'static,
{
    /// Adds a filter predicate to the stream.
    ///
    /// Multiple filters are combined with AND semantics at compile time.
    /// Each filter adds a new type layer, preserving zero-erasure.
    ///
    /// To access related entities, use shadow variables on your entity type
    /// (e.g., `#[inverse_relation_shadow_variable]`) rather than solution traversal.
    pub fn filter<P>(
        self,
        predicate: P,
    ) -> UniConstraintStream<
        S,
        A,
        E,
        AndUniFilter<F, FnUniFilter<impl Fn(&S, &A) -> bool + Send + Sync>>,
        Sc,
    >
    where
        P: Fn(&A) -> bool + Send + Sync + 'static,
    {
        UniConstraintStream {
            extractor: self.extractor,
            filter: AndUniFilter::new(
                self.filter,
                // Lift the entity-only predicate into the (solution, entity)
                // shape required by `UniFilter`; the solution is ignored.
                FnUniFilter::new(move |_s: &S, a: &A| predicate(a)),
            ),
            _phantom: PhantomData,
        }
    }

    /// Joins this stream using the provided join target.
    ///
    /// Dispatch depends on argument type:
    /// - `equal(|a: &A| key_fn(a))` — self-join, returns `BiConstraintStream`
    /// - `(extractor_b, equal_bi(ka, kb))` — keyed cross-join, returns `CrossBiConstraintStream`
    /// - `(other_stream, |a, b| predicate)` — predicate cross-join O(n*m)
    pub fn join<J>(self, target: J) -> J::Output
    where
        J: super::join_target::JoinTarget<S, A, E, F, Sc>,
    {
        target.apply(self.extractor, self.filter)
    }

    /// Groups entities by key and aggregates with a collector.
    ///
    /// Returns a zero-erasure `GroupedConstraintStream` that can be penalized
    /// or rewarded based on the aggregated result for each group.
    pub fn group_by<K, KF, C>(
        self,
        key_fn: KF,
        collector: C,
    ) -> GroupedConstraintStream<S, A, K, E, F, KF, C, Sc>
    where
        K: Clone + Eq + Hash + Send + Sync + 'static,
        KF: Fn(&A) -> K + Send + Sync,
        C: UniCollector<A> + Send + Sync + 'static,
        C::Accumulator: Send + Sync,
        C::Result: Clone + Send + Sync,
    {
        GroupedConstraintStream::new(self.extractor, self.filter, key_fn, collector)
    }

    /// Creates a balance constraint that penalizes uneven distribution across groups.
    ///
    /// Unlike `group_by` which scores each group independently, `balance` computes
    /// a GLOBAL standard deviation across all group counts and produces a single score.
    ///
    /// The `key_fn` returns `Option<K>` to allow skipping entities (e.g., unassigned shifts).
    /// Any filters accumulated on this stream are also applied.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SoftScore;
    ///
    /// #[derive(Clone)]
    /// struct Shift { employee_id: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Solution { shifts: Vec<Shift> }
    ///
    /// let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    ///     .for_each(|s: &Solution| &s.shifts)
    ///     .balance(|shift: &Shift| shift.employee_id)
    ///     .penalize(SoftScore::of(1000))
    ///     .named("Balance workload");
    ///
    /// let solution = Solution {
    ///     shifts: vec![
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(0) },
    ///         Shift { employee_id: Some(1) },
    ///     ],
    /// };
    ///
    /// // Employee 0: 3 shifts, Employee 1: 1 shift
    /// // std_dev = 1.0, penalty = -1000
    /// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1000));
    /// ```
    pub fn balance<K, KF>(self, key_fn: KF) -> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
    where
        K: Clone + Eq + Hash + Send + Sync + 'static,
        KF: Fn(&A) -> Option<K> + Send + Sync,
    {
        BalanceConstraintStream::new(self.extractor, self.filter, key_fn)
    }

    /// Filters A entities based on whether a matching B entity exists.
    ///
    /// Use this when the B collection needs filtering (e.g., only vacationing employees).
    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
    ///
    /// Any filters accumulated on this stream are applied to A entities.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::joiner::equal_bi;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SoftScore;
    ///
    /// #[derive(Clone)]
    /// struct Shift { id: usize, employee_idx: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Employee { id: usize, on_vacation: bool }
    ///
    /// #[derive(Clone)]
    /// struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
    ///
    /// // Penalize shifts assigned to employees who are on vacation
    /// let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
    ///     .for_each(|s: &Schedule| s.shifts.as_slice())
    ///     .filter(|shift: &Shift| shift.employee_idx.is_some())
    ///     .if_exists_filtered(
    ///         |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect(),
    ///         equal_bi(
    ///             |shift: &Shift| shift.employee_idx,
    ///             |emp: &Employee| Some(emp.id),
    ///         ),
    ///     )
    ///     .penalize(SoftScore::of(1))
    ///     .named("Vacation conflict");
    ///
    /// let schedule = Schedule {
    ///     shifts: vec![
    ///         Shift { id: 0, employee_idx: Some(0) },  // assigned to vacationing emp
    ///         Shift { id: 1, employee_idx: Some(1) },  // assigned to working emp
    ///         Shift { id: 2, employee_idx: None },     // unassigned (filtered out)
    ///     ],
    ///     employees: vec![
    ///         Employee { id: 0, on_vacation: true },
    ///         Employee { id: 1, on_vacation: false },
    ///     ],
    /// };
    ///
    /// // Only shift 0 matches (assigned to employee 0 who is on vacation)
    /// assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-1));
    /// ```
    pub fn if_exists_filtered<B, EB, K, KA, KB>(
        self,
        extractor_b: EB,
        joiner: EqualJoiner<KA, KB, K>,
    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
    where
        B: Clone + Send + Sync + 'static,
        EB: Fn(&S) -> Vec<B> + Send + Sync,
        K: Eq + Hash + Clone + Send + Sync,
        KA: Fn(&A) -> K + Send + Sync,
        KB: Fn(&B) -> K + Send + Sync,
    {
        let (key_a, key_b) = joiner.into_keys();
        IfExistsStream::new(
            ExistenceMode::Exists,
            self.extractor,
            extractor_b,
            key_a,
            key_b,
            self.filter,
        )
    }

    /// Filters A entities based on whether NO matching B entity exists.
    ///
    /// Use this when the B collection needs filtering.
    /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
    ///
    /// Any filters accumulated on this stream are applied to A entities.
    ///
    /// # Example
    ///
    /// ```
    /// use solverforge_scoring::stream::ConstraintFactory;
    /// use solverforge_scoring::stream::joiner::equal_bi;
    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    /// use solverforge_core::score::SoftScore;
    ///
    /// #[derive(Clone)]
    /// struct Task { id: usize, assignee: Option<usize> }
    ///
    /// #[derive(Clone)]
    /// struct Worker { id: usize, available: bool }
    ///
    /// #[derive(Clone)]
    /// struct Schedule { tasks: Vec<Task>, workers: Vec<Worker> }
    ///
    /// // Penalize tasks assigned to workers who are not available
    /// let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
    ///     .for_each(|s: &Schedule| s.tasks.as_slice())
    ///     .filter(|task: &Task| task.assignee.is_some())
    ///     .if_not_exists_filtered(
    ///         |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
    ///         equal_bi(
    ///             |task: &Task| task.assignee,
    ///             |worker: &Worker| Some(worker.id),
    ///         ),
    ///     )
    ///     .penalize(SoftScore::of(1))
    ///     .named("Unavailable worker");
    ///
    /// let schedule = Schedule {
    ///     tasks: vec![
    ///         Task { id: 0, assignee: Some(0) },  // worker 0 is unavailable
    ///         Task { id: 1, assignee: Some(1) },  // worker 1 is available
    ///         Task { id: 2, assignee: None },     // unassigned (filtered out)
    ///     ],
    ///     workers: vec![
    ///         Worker { id: 0, available: false },
    ///         Worker { id: 1, available: true },
    ///     ],
    /// };
    ///
    /// // Task 0's worker (id=0) is NOT in the available workers list
    /// assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-1));
    /// ```
    pub fn if_not_exists_filtered<B, EB, K, KA, KB>(
        self,
        extractor_b: EB,
        joiner: EqualJoiner<KA, KB, K>,
    ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
    where
        B: Clone + Send + Sync + 'static,
        EB: Fn(&S) -> Vec<B> + Send + Sync,
        K: Eq + Hash + Clone + Send + Sync,
        KA: Fn(&A) -> K + Send + Sync,
        KB: Fn(&B) -> K + Send + Sync,
    {
        let (key_a, key_b) = joiner.into_keys();
        IfExistsStream::new(
            ExistenceMode::NotExists,
            self.extractor,
            extractor_b,
            key_a,
            key_b,
            self.filter,
        )
    }

    /// Penalizes each matching entity with a fixed weight.
    ///
    /// The constraint is marked hard when the weight's first score level
    /// (assumed to be the hard level — see `to_level_numbers`) is non-zero.
    pub fn penalize(
        self,
        weight: Sc,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        // Detect if this is a hard constraint by checking if hard level is non-zero
        let is_hard = weight
            .to_level_numbers()
            .first()
            .map(|&h| h != 0)
            .unwrap_or(false);
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Penalty,
            weight: move |_: &A| weight,
            is_hard,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Penalizes each matching entity with a dynamic weight.
    ///
    /// Note: For dynamic weights, use `penalize_hard_with` to explicitly mark as a hard constraint,
    /// since the weight function cannot be evaluated at build time.
    pub fn penalize_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Penalty,
            weight: weight_fn,
            is_hard: false, // Can't detect at build time; use penalize_hard_with for hard constraints
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Penalizes each matching entity with a dynamic weight, explicitly marked as a hard constraint.
    pub fn penalize_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Penalty,
            weight: weight_fn,
            is_hard: true,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Rewards each matching entity with a fixed weight.
    ///
    /// The constraint is marked hard when the weight's first score level
    /// (assumed to be the hard level — see `to_level_numbers`) is non-zero.
    pub fn reward(
        self,
        weight: Sc,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        // Detect if this is a hard constraint by checking if hard level is non-zero
        let is_hard = weight
            .to_level_numbers()
            .first()
            .map(|&h| h != 0)
            .unwrap_or(false);
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Reward,
            weight: move |_: &A| weight,
            is_hard,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Rewards each matching entity with a dynamic weight.
    ///
    /// Note: For dynamic weights, use `reward_hard_with` to explicitly mark as a hard constraint,
    /// since the weight function cannot be evaluated at build time.
    pub fn reward_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Reward,
            weight: weight_fn,
            is_hard: false, // Can't detect at build time; use reward_hard_with for hard constraints
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Rewards each matching entity with a dynamic weight, explicitly marked as a hard constraint.
    pub fn reward_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
    where
        W: Fn(&A) -> Sc + Send + Sync,
    {
        UniConstraintBuilder {
            extractor: self.extractor,
            filter: self.filter,
            impact_type: ImpactType::Reward,
            weight: weight_fn,
            is_hard: true,
            expected_descriptor: None,
            _phantom: PhantomData,
        }
    }

    /// Penalizes each matching entity with one hard score unit.
    pub fn penalize_hard(
        self,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        self.penalize(Sc::one_hard())
    }

    /// Penalizes each matching entity with one soft score unit.
    pub fn penalize_soft(
        self,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        self.penalize(Sc::one_soft())
    }

    /// Rewards each matching entity with one hard score unit.
    pub fn reward_hard(
        self,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        self.reward(Sc::one_hard())
    }

    /// Rewards each matching entity with one soft score unit.
    pub fn reward_soft(
        self,
    ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
    where
        Sc: Copy,
    {
        self.reward(Sc::one_soft())
    }
}
495
496impl<S, A, E, F, Sc: Score> UniConstraintStream<S, A, E, F, Sc> {
497    #[doc(hidden)]
498    pub fn extractor(&self) -> &E {
499        &self.extractor
500    }
501
502    #[doc(hidden)]
503    pub fn into_parts(self) -> (E, F) {
504        (self.extractor, self.filter)
505    }
506
507    #[doc(hidden)]
508    pub fn from_parts(extractor: E, filter: F) -> Self {
509        Self {
510            extractor,
511            filter,
512            _phantom: PhantomData,
513        }
514    }
515}
516
517impl<S, A, E, F, Sc: Score> std::fmt::Debug for UniConstraintStream<S, A, E, F, Sc> {
518    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
519        f.debug_struct("UniConstraintStream").finish()
520    }
521}
522
/// Zero-erasure builder for finalizing a uni-constraint.
pub struct UniConstraintBuilder<S, A, E, F, W, Sc>
where
    Sc: Score,
{
    extractor: E,
    filter: F,
    // Whether matching entities penalize or reward the score.
    impact_type: ImpactType,
    // Per-entity weight function applied to each match.
    weight: W,
    // True when the finalized constraint contributes to a hard score level.
    is_hard: bool,
    // When set, the constraint only fires for this entity-descriptor index.
    expected_descriptor: Option<usize>,
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
}
536
impl<S, A, E, F, W, Sc> UniConstraintBuilder<S, A, E, F, W, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    E: Fn(&S) -> &[A] + Send + Sync,
    F: UniFilter<S, A>,
    W: Fn(&A) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    /// Restricts this constraint to only fire for the given descriptor index.
    ///
    /// Required when multiple entity classes exist (e.g., FurnaceAssignment at 0,
    /// ShiftAssignment at 1). Without this, on_insert/on_retract fire for all entity
    /// classes using the constraint's entity_index, which indexes into the wrong slice.
    pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
        self.expected_descriptor = Some(descriptor_index);
        self
    }

    /// Finalizes the builder into a zero-erasure `IncrementalUniConstraint`.
    ///
    /// `name` becomes the constraint's identifier; the package component of
    /// the `ConstraintRef` is left empty.
    pub fn named(
        self,
        name: &str,
    ) -> IncrementalUniConstraint<S, A, E, impl Fn(&S, &A) -> bool + Send + Sync, W, Sc> {
        // Fold the accumulated zero-erasure filter into a plain closure.
        let filter = self.filter;
        let combined_filter = move |s: &S, a: &A| filter.test(s, a);

        let mut constraint = IncrementalUniConstraint::new(
            ConstraintRef::new("", name),
            self.impact_type,
            self.extractor,
            combined_filter,
            self.weight,
            self.is_hard,
        );
        // Descriptor restriction is opt-in; see `for_descriptor`.
        if let Some(d) = self.expected_descriptor {
            constraint = constraint.with_descriptor(d);
        }
        constraint
    }
}
578
579impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for UniConstraintBuilder<S, A, E, F, W, Sc> {
580    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581        f.debug_struct("UniConstraintBuilder")
582            .field("impact_type", &self.impact_type)
583            .finish()
584    }
585}