solverforge_scoring/constraint/complemented/state.rs
1/* Zero-erasure complemented group constraint.
2
Evaluates grouped results over a complement entity set, using a default value for each complement entity whose key has no group.
4Provides true incremental scoring by tracking per-key accumulators.
5*/
6
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::stream::collection_extract::ChangeSource;
15use crate::stream::collector::{Accumulator, UniCollector};
16
17/* Zero-erasure constraint for complemented grouped results.
18
19Groups A entities by key, then iterates over B entities (complement source),
20using grouped values where they exist and default values otherwise.
21
22The key function for A returns `Option<K>`, allowing entities to be skipped
23when they don't have a valid key (e.g., unassigned shifts).
24
25# Type Parameters
26
27- `S` - Solution type
28- `A` - Entity type being grouped (e.g., Shift)
29- `B` - Complement entity type (e.g., Employee)
30- `K` - Group key type
31- `EA` - Extractor for A entities
32- `EB` - Extractor for B entities
33- `KA` - Key function for A (returns `Option<K>` to allow skipping)
34- `KB` - Key function for B
35- `C` - Collector type
36- `D` - Default value function
37- `W` - Weight function
38- `Sc` - Score type
39
40# Example
41
42```
43use solverforge_scoring::constraint::complemented::ComplementedGroupConstraint;
44use solverforge_scoring::stream::collector::count;
45use solverforge_scoring::api::constraint_set::IncrementalConstraint;
46use solverforge_core::{ConstraintRef, ImpactType};
47use solverforge_core::score::SoftScore;
48
49#[derive(Clone, Hash, PartialEq, Eq)]
50struct Employee { id: usize }
51
52#[derive(Clone)]
53struct Shift { employee_id: Option<usize> }
54
55#[derive(Clone)]
56struct Schedule {
57employees: Vec<Employee>,
58shifts: Vec<Shift>,
59}
60
61let constraint = ComplementedGroupConstraint::new(
62ConstraintRef::new("", "Shift count"),
63ImpactType::Penalty,
64|s: &Schedule| s.shifts.as_slice(),
65|s: &Schedule| s.employees.as_slice(),
66|shift: &Shift| shift.employee_id, // Returns Option<usize>
67|emp: &Employee| emp.id,
68count(),
69|_emp: &Employee| 0usize,
70|count: &usize| SoftScore::of(*count as i64),
71false,
72);
73
74let schedule = Schedule {
75employees: vec![Employee { id: 0 }, Employee { id: 1 }],
76shifts: vec![
77Shift { employee_id: Some(0) },
78Shift { employee_id: Some(0) },
79Shift { employee_id: None }, // Skipped - no key
80],
81};
82
83// Employee 0: 2 shifts, Employee 1: 0 shifts → Total: -2
84// Unassigned shift is skipped
85assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
86```
87*/
88pub struct ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
89where
90 C: UniCollector<A>,
91 Sc: Score,
92{
93 pub(super) constraint_ref: ConstraintRef,
94 pub(super) impact_type: ImpactType,
95 pub(super) extractor_a: EA,
96 pub(super) extractor_b: EB,
97 pub(super) key_a: KA,
98 pub(super) key_b: KB,
99 pub(super) collector: C,
100 pub(super) default_fn: D,
101 pub(super) weight_fn: W,
102 pub(super) is_hard: bool,
103 pub(super) a_source: ChangeSource,
104 pub(super) b_source: ChangeSource,
105 // Group key -> accumulator for incremental scoring
106 pub(super) groups: HashMap<K, C::Accumulator>,
107 // A entity index -> group key (for tracking which group each entity belongs to)
108 pub(super) entity_groups: HashMap<usize, K>,
109 // A entity index -> extracted value (for correct retraction after entity mutation)
110 pub(super) entity_values: HashMap<usize, C::Value>,
111 // B key -> B entity index (for looking up B entities by key)
112 pub(super) b_by_key: HashMap<K, usize>,
113 // B entity index -> B key (for localized B retraction)
114 pub(super) b_index_to_key: HashMap<usize, K>,
115 pub(super) _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> Sc)>,
116}
117
118impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
119 ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
120where
121 S: 'static,
122 A: Clone + 'static,
123 B: Clone + 'static,
124 K: Clone + Eq + Hash,
125 EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
126 EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
127 KA: Fn(&A) -> Option<K>,
128 KB: Fn(&B) -> K,
129 C: UniCollector<A>,
130 C::Result: Clone,
131 D: Fn(&B) -> C::Result,
132 W: Fn(&C::Result) -> Sc,
133 Sc: Score,
134{
135 // Creates a new complemented group constraint.
136 #[allow(clippy::too_many_arguments)]
137 pub fn new(
138 constraint_ref: ConstraintRef,
139 impact_type: ImpactType,
140 extractor_a: EA,
141 extractor_b: EB,
142 key_a: KA,
143 key_b: KB,
144 collector: C,
145 default_fn: D,
146 weight_fn: W,
147 is_hard: bool,
148 ) -> Self {
149 let a_source = extractor_a.change_source();
150 let b_source = extractor_b.change_source();
151 Self {
152 constraint_ref,
153 impact_type,
154 extractor_a,
155 extractor_b,
156 key_a,
157 key_b,
158 collector,
159 default_fn,
160 weight_fn,
161 is_hard,
162 a_source,
163 b_source,
164 groups: HashMap::new(),
165 entity_groups: HashMap::new(),
166 entity_values: HashMap::new(),
167 b_by_key: HashMap::new(),
168 b_index_to_key: HashMap::new(),
169 _phantom: PhantomData,
170 }
171 }
172
173 #[inline]
174 pub(super) fn compute_score(&self, result: &C::Result) -> Sc {
175 let base = (self.weight_fn)(result);
176 match self.impact_type {
177 ImpactType::Penalty => -base,
178 ImpactType::Reward => base,
179 }
180 }
181
182 // Build grouped results from A entities.
183 pub(super) fn build_groups(&self, entities_a: &[A]) -> HashMap<K, C::Result> {
184 let mut accumulators: HashMap<K, C::Accumulator> = HashMap::new();
185
186 for a in entities_a {
187 // Skip entities with no key (e.g., unassigned shifts)
188 let Some(key) = (self.key_a)(a) else {
189 continue;
190 };
191 let value = self.collector.extract(a);
192 accumulators
193 .entry(key)
194 .or_insert_with(|| self.collector.create_accumulator())
195 .accumulate(&value);
196 }
197
198 accumulators
199 .into_iter()
200 .map(|(k, acc)| (k, acc.finish()))
201 .collect()
202 }
203}