1use std::collections::HashMap;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::constraint_set::IncrementalConstraint;
15use crate::stream::collector::{Accumulator, UniCollector};
16
17pub struct GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
78where
79 C: UniCollector<A>,
80 Sc: Score,
81{
82 constraint_ref: ConstraintRef,
83 impact_type: ImpactType,
84 extractor: E,
85 key_fn: KF,
86 collector: C,
87 weight_fn: W,
88 is_hard: bool,
89 expected_descriptor: Option<usize>,
90 groups: HashMap<K, C::Accumulator>,
92 entity_groups: HashMap<usize, K>,
94 entity_values: HashMap<usize, C::Value>,
96 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
97}
98
99impl<S, A, K, E, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
100where
101 S: Send + Sync + 'static,
102 A: Clone + Send + Sync + 'static,
103 K: Clone + Eq + Hash + Send + Sync + 'static,
104 E: Fn(&S) -> &[A] + Send + Sync,
105 KF: Fn(&A) -> K + Send + Sync,
106 C: UniCollector<A> + Send + Sync + 'static,
107 C::Accumulator: Send + Sync,
108 C::Result: Send + Sync,
109 W: Fn(&C::Result) -> Sc + Send + Sync,
110 Sc: Score + 'static,
111{
112 pub fn new(
124 constraint_ref: ConstraintRef,
125 impact_type: ImpactType,
126 extractor: E,
127 key_fn: KF,
128 collector: C,
129 weight_fn: W,
130 is_hard: bool,
131 ) -> Self {
132 Self {
133 constraint_ref,
134 impact_type,
135 extractor,
136 key_fn,
137 collector,
138 weight_fn,
139 is_hard,
140 expected_descriptor: None,
141 groups: HashMap::new(),
142 entity_groups: HashMap::new(),
143 entity_values: HashMap::new(),
144 _phantom: PhantomData,
145 }
146 }
147
148 pub fn with_descriptor(mut self, descriptor_index: usize) -> Self {
149 self.expected_descriptor = Some(descriptor_index);
150 self
151 }
152
153 fn compute_score(&self, result: &C::Result) -> Sc {
155 let base = (self.weight_fn)(result);
156 match self.impact_type {
157 ImpactType::Penalty => -base,
158 ImpactType::Reward => base,
159 }
160 }
161}
162
163impl<S, A, K, E, KF, C, W, Sc> IncrementalConstraint<S, Sc>
164 for GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
165where
166 S: Send + Sync + 'static,
167 A: Clone + Send + Sync + 'static,
168 K: Clone + Eq + Hash + Send + Sync + 'static,
169 E: Fn(&S) -> &[A] + Send + Sync,
170 KF: Fn(&A) -> K + Send + Sync,
171 C: UniCollector<A> + Send + Sync + 'static,
172 C::Accumulator: Send + Sync,
173 C::Result: Send + Sync,
174 C::Value: Send + Sync,
175 W: Fn(&C::Result) -> Sc + Send + Sync,
176 Sc: Score + 'static,
177{
178 fn evaluate(&self, solution: &S) -> Sc {
179 let entities = (self.extractor)(solution);
180
181 let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
183
184 for entity in entities {
185 let key = (self.key_fn)(entity);
186 let value = self.collector.extract(entity);
187 let acc = groups
188 .entry(key)
189 .or_insert_with(|| self.collector.create_accumulator());
190 acc.accumulate(&value);
191 }
192
193 let mut total = Sc::zero();
195 for acc in groups.values() {
196 let result = acc.finish();
197 total = total + self.compute_score(&result);
198 }
199
200 total
201 }
202
203 fn match_count(&self, solution: &S) -> usize {
204 let entities = (self.extractor)(solution);
205
206 let mut groups: HashMap<K, ()> = HashMap::new();
208 for entity in entities {
209 let key = (self.key_fn)(entity);
210 groups.insert(key, ());
211 }
212
213 groups.len()
214 }
215
216 fn initialize(&mut self, solution: &S) -> Sc {
217 self.reset();
218
219 let entities = (self.extractor)(solution);
220 let mut total = Sc::zero();
221
222 for (idx, entity) in entities.iter().enumerate() {
223 total = total + self.insert_entity(entities, idx, entity);
224 }
225
226 total
227 }
228
229 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
230 if let Some(expected) = self.expected_descriptor {
231 if descriptor_index != expected {
232 return Sc::zero();
233 }
234 }
235 let entities = (self.extractor)(solution);
236 if entity_index >= entities.len() {
237 return Sc::zero();
238 }
239
240 let entity = &entities[entity_index];
241 self.insert_entity(entities, entity_index, entity)
242 }
243
244 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
245 if let Some(expected) = self.expected_descriptor {
246 if descriptor_index != expected {
247 return Sc::zero();
248 }
249 }
250 let entities = (self.extractor)(solution);
251 self.retract_entity(entities, entity_index)
252 }
253
254 fn reset(&mut self) {
255 self.groups.clear();
256 self.entity_groups.clear();
257 self.entity_values.clear();
258 }
259
260 fn name(&self) -> &str {
261 &self.constraint_ref.name
262 }
263
264 fn is_hard(&self) -> bool {
265 self.is_hard
266 }
267
268 fn constraint_ref(&self) -> ConstraintRef {
269 self.constraint_ref.clone()
270 }
271}
272
273impl<S, A, K, E, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
274where
275 S: Send + Sync + 'static,
276 A: Clone + Send + Sync + 'static,
277 K: Clone + Eq + Hash + Send + Sync + 'static,
278 E: Fn(&S) -> &[A] + Send + Sync,
279 KF: Fn(&A) -> K + Send + Sync,
280 C: UniCollector<A> + Send + Sync + 'static,
281 C::Accumulator: Send + Sync,
282 C::Result: Send + Sync,
283 C::Value: Send + Sync,
284 W: Fn(&C::Result) -> Sc + Send + Sync,
285 Sc: Score + 'static,
286{
287 fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
288 let key = (self.key_fn)(entity);
289 let value = self.collector.extract(entity);
290 let impact = self.impact_type;
291
292 let acc = self
294 .groups
295 .entry(key.clone())
296 .or_insert_with(|| self.collector.create_accumulator());
297
298 let old_base = (self.weight_fn)(&acc.finish());
300 let old = match impact {
301 ImpactType::Penalty => -old_base,
302 ImpactType::Reward => old_base,
303 };
304
305 acc.accumulate(&value);
307 let new_base = (self.weight_fn)(&acc.finish());
308 let new_score = match impact {
309 ImpactType::Penalty => -new_base,
310 ImpactType::Reward => new_base,
311 };
312
313 self.entity_groups.insert(entity_index, key);
315 self.entity_values.insert(entity_index, value);
316
317 new_score - old
319 }
320
321 fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
322 let Some(key) = self.entity_groups.remove(&entity_index) else {
324 return Sc::zero();
325 };
326
327 let Some(value) = self.entity_values.remove(&entity_index) else {
329 return Sc::zero();
330 };
331 let impact = self.impact_type;
332
333 let Some(acc) = self.groups.get_mut(&key) else {
335 return Sc::zero();
336 };
337
338 let old_base = (self.weight_fn)(&acc.finish());
340 let old = match impact {
341 ImpactType::Penalty => -old_base,
342 ImpactType::Reward => old_base,
343 };
344
345 acc.retract(&value);
347 let new_base = (self.weight_fn)(&acc.finish());
348 let new_score = match impact {
349 ImpactType::Penalty => -new_base,
350 ImpactType::Reward => new_base,
351 };
352
353 new_score - old
355 }
356}
357
358impl<S, A, K, E, KF, C, W, Sc> std::fmt::Debug for GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
359where
360 C: UniCollector<A>,
361 Sc: Score,
362{
363 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 f.debug_struct("GroupedUniConstraint")
365 .field("name", &self.constraint_ref.name)
366 .field("impact_type", &self.impact_type)
367 .field("groups", &self.groups.len())
368 .finish()
369 }
370}