// solverforge_scoring/constraint/complemented/state.rs

/* Zero-erasure complemented group constraint.

Evaluates grouped results together with complement entities, falling back to
default values for complement entities that have no group.
Provides true incremental scoring by tracking per-key accumulators.
*/
6
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::stream::collection_extract::ChangeSource;
15use crate::stream::collector::{Accumulator, UniCollector};
16
17/* Zero-erasure constraint for complemented grouped results.
18
19Groups A entities by key, then iterates over B entities (complement source),
20using grouped values where they exist and default values otherwise.
21
22The key function for A returns `Option<K>`, allowing entities to be skipped
23when they don't have a valid key (e.g., unassigned shifts).
24
25# Type Parameters
26
27- `S` - Solution type
28- `A` - Entity type being grouped (e.g., Shift)
29- `B` - Complement entity type (e.g., Employee)
30- `K` - Group key type
31- `EA` - Extractor for A entities
32- `EB` - Extractor for B entities
33- `KA` - Key function for A (returns `Option<K>` to allow skipping)
34- `KB` - Key function for B
35- `C` - Collector type
36- `D` - Default value function
37- `W` - Weight function
38- `Sc` - Score type
39
40# Example
41
42```
43use solverforge_scoring::constraint::complemented::ComplementedGroupConstraint;
44use solverforge_scoring::stream::collector::count;
45use solverforge_scoring::api::constraint_set::IncrementalConstraint;
46use solverforge_core::{ConstraintRef, ImpactType};
47use solverforge_core::score::SoftScore;
48
49#[derive(Clone, Hash, PartialEq, Eq)]
50struct Employee { id: usize }
51
52#[derive(Clone)]
53struct Shift { employee_id: Option<usize> }
54
55#[derive(Clone)]
56struct Schedule {
57employees: Vec<Employee>,
58shifts: Vec<Shift>,
59}
60
61let constraint = ComplementedGroupConstraint::new(
62ConstraintRef::new("", "Shift count"),
63ImpactType::Penalty,
64|s: &Schedule| s.shifts.as_slice(),
65|s: &Schedule| s.employees.as_slice(),
66|shift: &Shift| shift.employee_id,  // Returns Option<usize>
67|emp: &Employee| emp.id,
68count(),
69|_emp: &Employee| 0usize,
70|count: &usize| SoftScore::of(*count as i64),
71false,
72);
73
74let schedule = Schedule {
75employees: vec![Employee { id: 0 }, Employee { id: 1 }],
76shifts: vec![
77Shift { employee_id: Some(0) },
78Shift { employee_id: Some(0) },
79Shift { employee_id: None },  // Skipped - no key
80],
81};
82
83// Employee 0: 2 shifts, Employee 1: 0 shifts → Total: -2
84// Unassigned shift is skipped
85assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
86```
87*/
88pub struct ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
89where
90    C: UniCollector<A>,
91    Sc: Score,
92{
93    pub(super) constraint_ref: ConstraintRef,
94    pub(super) impact_type: ImpactType,
95    pub(super) extractor_a: EA,
96    pub(super) extractor_b: EB,
97    pub(super) key_a: KA,
98    pub(super) key_b: KB,
99    pub(super) collector: C,
100    pub(super) default_fn: D,
101    pub(super) weight_fn: W,
102    pub(super) is_hard: bool,
103    pub(super) a_source: ChangeSource,
104    pub(super) b_source: ChangeSource,
105    // Group key -> accumulator for incremental scoring
106    pub(super) groups: HashMap<K, C::Accumulator>,
107    // A entity index -> group key (for tracking which group each entity belongs to)
108    pub(super) entity_groups: HashMap<usize, K>,
109    // A entity index -> extracted value (for correct retraction after entity mutation)
110    pub(super) entity_values: HashMap<usize, C::Value>,
111    // B key -> B entity index (for looking up B entities by key)
112    pub(super) b_by_key: HashMap<K, usize>,
113    // B entity index -> B key (for localized B retraction)
114    pub(super) b_index_to_key: HashMap<usize, K>,
115    pub(super) _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> Sc)>,
116}
117
118impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
119    ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
120where
121    S: 'static,
122    A: Clone + 'static,
123    B: Clone + 'static,
124    K: Clone + Eq + Hash,
125    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
126    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
127    KA: Fn(&A) -> Option<K>,
128    KB: Fn(&B) -> K,
129    C: UniCollector<A>,
130    C::Result: Clone,
131    D: Fn(&B) -> C::Result,
132    W: Fn(&C::Result) -> Sc,
133    Sc: Score,
134{
135    // Creates a new complemented group constraint.
136    #[allow(clippy::too_many_arguments)]
137    pub fn new(
138        constraint_ref: ConstraintRef,
139        impact_type: ImpactType,
140        extractor_a: EA,
141        extractor_b: EB,
142        key_a: KA,
143        key_b: KB,
144        collector: C,
145        default_fn: D,
146        weight_fn: W,
147        is_hard: bool,
148    ) -> Self {
149        let a_source = extractor_a.change_source();
150        let b_source = extractor_b.change_source();
151        Self {
152            constraint_ref,
153            impact_type,
154            extractor_a,
155            extractor_b,
156            key_a,
157            key_b,
158            collector,
159            default_fn,
160            weight_fn,
161            is_hard,
162            a_source,
163            b_source,
164            groups: HashMap::new(),
165            entity_groups: HashMap::new(),
166            entity_values: HashMap::new(),
167            b_by_key: HashMap::new(),
168            b_index_to_key: HashMap::new(),
169            _phantom: PhantomData,
170        }
171    }
172
173    #[inline]
174    pub(super) fn compute_score(&self, result: &C::Result) -> Sc {
175        let base = (self.weight_fn)(result);
176        match self.impact_type {
177            ImpactType::Penalty => -base,
178            ImpactType::Reward => base,
179        }
180    }
181
182    // Build grouped results from A entities.
183    pub(super) fn build_groups(&self, entities_a: &[A]) -> HashMap<K, C::Result> {
184        let mut accumulators: HashMap<K, C::Accumulator> = HashMap::new();
185
186        for a in entities_a {
187            // Skip entities with no key (e.g., unassigned shifts)
188            let Some(key) = (self.key_a)(a) else {
189                continue;
190            };
191            let value = self.collector.extract(a);
192            accumulators
193                .entry(key)
194                .or_insert_with(|| self.collector.create_accumulator())
195                .accumulate(&value);
196        }
197
198        accumulators
199            .into_iter()
200            .map(|(k, acc)| (k, acc.finish()))
201            .collect()
202    }
203}