// solverforge_scoring/constraint/complemented/state.rs
/* Zero-erasure complemented group constraint.

Evaluates grouped results together with complement entities, falling back to
default values for complement entities that have no group. Provides true
incremental scoring by tracking per-key accumulators.
*/
6
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::stream::collection_extract::ChangeSource;
15use crate::stream::collector::{Accumulator, UniCollector};
16
17type CollectorRetraction<C, A> = <<C as UniCollector<A>>::Accumulator as Accumulator<
18 <C as UniCollector<A>>::Value,
19 <C as UniCollector<A>>::Result,
20>>::Retraction;
21
22/* Zero-erasure constraint for complemented grouped results.
23
24Groups A entities by key, then iterates over B entities (complement source),
25using grouped values where they exist and default values otherwise.
26
27The key function for A returns `Option<K>`, allowing entities to be skipped
28when they don't have a valid key (e.g., unassigned shifts).
29
30# Type Parameters
31
32- `S` - Solution type
33- `A` - Entity type being grouped (e.g., Shift)
34- `B` - Complement entity type (e.g., Employee)
35- `K` - Group key type
36- `EA` - Extractor for A entities
37- `EB` - Extractor for B entities
38- `KA` - Key function for A (returns `Option<K>` to allow skipping)
39- `KB` - Key function for B
40- `C` - Collector type
41- `D` - Default value function
42- `W` - Weight function
43- `Sc` - Score type
44
45# Example
46
47```
48use solverforge_scoring::constraint::complemented::ComplementedGroupConstraint;
49use solverforge_scoring::stream::collector::count;
50use solverforge_scoring::api::constraint_set::IncrementalConstraint;
51use solverforge_core::{ConstraintRef, ImpactType};
52use solverforge_core::score::SoftScore;
53
54#[derive(Clone, Hash, PartialEq, Eq)]
55struct Employee { id: usize }
56
57#[derive(Clone)]
58struct Shift { employee_id: Option<usize> }
59
60#[derive(Clone)]
61struct Schedule {
62employees: Vec<Employee>,
63shifts: Vec<Shift>,
64}
65
66let constraint = ComplementedGroupConstraint::new(
67ConstraintRef::new("", "Shift count"),
68ImpactType::Penalty,
69|s: &Schedule| s.shifts.as_slice(),
70|s: &Schedule| s.employees.as_slice(),
71|shift: &Shift| shift.employee_id, // Returns Option<usize>
72|emp: &Employee| emp.id,
73count(),
74|_emp: &Employee| 0usize,
75|_employee_id: &usize, count: &usize| SoftScore::of(*count as i64),
76false,
77);
78
79let schedule = Schedule {
80employees: vec![Employee { id: 0 }, Employee { id: 1 }],
81shifts: vec![
82Shift { employee_id: Some(0) },
83Shift { employee_id: Some(0) },
84Shift { employee_id: None }, // Skipped - no key
85],
86};
87
88// Employee 0: 2 shifts, Employee 1: 0 shifts → Total: -2
89// Unassigned shift is skipped
90assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
91```
92*/
93pub struct ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
94where
95 C: UniCollector<A>,
96 Sc: Score,
97{
98 pub(super) constraint_ref: ConstraintRef,
99 pub(super) impact_type: ImpactType,
100 pub(super) extractor_a: EA,
101 pub(super) extractor_b: EB,
102 pub(super) key_a: KA,
103 pub(super) key_b: KB,
104 pub(super) collector: C,
105 pub(super) default_fn: D,
106 pub(super) weight_fn: W,
107 pub(super) is_hard: bool,
108 pub(super) a_source: ChangeSource,
109 pub(super) b_source: ChangeSource,
110 // Group key -> accumulator for incremental scoring
111 pub(super) groups: HashMap<K, C::Accumulator>,
112 // A entity index -> group key (for tracking which group each entity belongs to)
113 pub(super) entity_groups: HashMap<usize, K>,
114 // A entity index -> accumulator retraction token
115 pub(super) entity_retractions: HashMap<usize, CollectorRetraction<C, A>>,
116 // B key -> B entity index (for looking up B entities by key)
117 pub(super) b_by_key: HashMap<K, usize>,
118 // B entity index -> B key (for localized B retraction)
119 pub(super) b_index_to_key: HashMap<usize, K>,
120 pub(super) _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> Sc)>,
121}
122
123impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
124 ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
125where
126 S: 'static,
127 A: Clone + 'static,
128 B: Clone + 'static,
129 K: Clone + Eq + Hash,
130 EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
131 EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
132 KA: Fn(&A) -> Option<K>,
133 KB: Fn(&B) -> K,
134 C: UniCollector<A>,
135 D: Fn(&B) -> C::Result,
136 W: Fn(&K, &C::Result) -> Sc,
137 Sc: Score,
138{
139 // Creates a new complemented group constraint.
140 #[allow(clippy::too_many_arguments)]
141 pub fn new(
142 constraint_ref: ConstraintRef,
143 impact_type: ImpactType,
144 extractor_a: EA,
145 extractor_b: EB,
146 key_a: KA,
147 key_b: KB,
148 collector: C,
149 default_fn: D,
150 weight_fn: W,
151 is_hard: bool,
152 ) -> Self {
153 let a_source = extractor_a.change_source();
154 let b_source = extractor_b.change_source();
155 Self {
156 constraint_ref,
157 impact_type,
158 extractor_a,
159 extractor_b,
160 key_a,
161 key_b,
162 collector,
163 default_fn,
164 weight_fn,
165 is_hard,
166 a_source,
167 b_source,
168 groups: HashMap::new(),
169 entity_groups: HashMap::new(),
170 entity_retractions: HashMap::new(),
171 b_by_key: HashMap::new(),
172 b_index_to_key: HashMap::new(),
173 _phantom: PhantomData,
174 }
175 }
176
177 #[inline]
178 pub(super) fn compute_score(&self, key: &K, result: &C::Result) -> Sc {
179 let base = (self.weight_fn)(key, result);
180 match self.impact_type {
181 ImpactType::Penalty => -base,
182 ImpactType::Reward => base,
183 }
184 }
185
186 // Build grouped results from A entities.
187 pub(super) fn build_groups(&self, entities_a: &[A]) -> HashMap<K, C::Accumulator> {
188 let mut accumulators: HashMap<K, C::Accumulator> = HashMap::new();
189
190 for a in entities_a {
191 // Skip entities with no key (e.g., unassigned shifts)
192 let Some(key) = (self.key_a)(a) else {
193 continue;
194 };
195 let value = self.collector.extract(a);
196 accumulators
197 .entry(key)
198 .or_insert_with(|| self.collector.create_accumulator())
199 .accumulate(value);
200 }
201
202 accumulators
203 }
204}