1use std::collections::HashMap;
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use solverforge_core::score::Score;
13use solverforge_core::{ConstraintRef, ImpactType};
14
15use crate::api::constraint_set::IncrementalConstraint;
16use crate::stream::collector::{Accumulator, UniCollector};
17use crate::stream::filter::UniFilter;
18
/// Per-key bookkeeping for one group of the grouped constraint.
struct GroupState<Acc> {
    /// Number of entities currently folded into `accumulator`; the group is
    /// dropped from the map once this reaches zero on retraction.
    count: usize,
    /// Collector accumulator holding the running aggregate for this group.
    accumulator: Acc,
}
23
24pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
89where
90 C: UniCollector<A>,
91 Sc: Score,
92{
93 constraint_ref: ConstraintRef,
94 impact_type: ImpactType,
95 extractor: E,
96 filter: Fi,
97 key_fn: KF,
98 collector: C,
99 weight_fn: W,
100 is_hard: bool,
101 change_source: crate::stream::collection_extract::ChangeSource,
102 groups: HashMap<K, GroupState<C::Accumulator>>,
104 entity_groups: HashMap<usize, K>,
106 entity_values: HashMap<usize, C::Value>,
108 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
109}
110
111impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
112where
113 S: Send + Sync + 'static,
114 A: Clone + Send + Sync + 'static,
115 K: Clone + Eq + Hash + Send + Sync + 'static,
116 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
117 Fi: UniFilter<S, A>,
118 KF: Fn(&A) -> K + Send + Sync,
119 C: UniCollector<A> + Send + Sync + 'static,
120 C::Accumulator: Send + Sync,
121 C::Result: Send + Sync,
122 W: Fn(&K, &C::Result) -> Sc + Send + Sync,
123 Sc: Score + 'static,
124{
125 #[allow(clippy::too_many_arguments)]
139 pub fn new(
140 constraint_ref: ConstraintRef,
141 impact_type: ImpactType,
142 extractor: E,
143 filter: Fi,
144 key_fn: KF,
145 collector: C,
146 weight_fn: W,
147 is_hard: bool,
148 ) -> Self {
149 let change_source = extractor.change_source();
150 Self {
151 constraint_ref,
152 impact_type,
153 extractor,
154 filter,
155 key_fn,
156 collector,
157 weight_fn,
158 is_hard,
159 change_source,
160 groups: HashMap::new(),
161 entity_groups: HashMap::new(),
162 entity_values: HashMap::new(),
163 _phantom: PhantomData,
164 }
165 }
166
167 fn compute_score(&self, key: &K, result: &C::Result) -> Sc {
169 let base = (self.weight_fn)(key, result);
170 match self.impact_type {
171 ImpactType::Penalty => -base,
172 ImpactType::Reward => base,
173 }
174 }
175}
176
177impl<S, A, K, E, Fi, KF, C, W, Sc> IncrementalConstraint<S, Sc>
178 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
179where
180 S: Send + Sync + 'static,
181 A: Clone + Send + Sync + 'static,
182 K: Clone + Eq + Hash + Send + Sync + 'static,
183 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
184 Fi: UniFilter<S, A>,
185 KF: Fn(&A) -> K + Send + Sync,
186 C: UniCollector<A> + Send + Sync + 'static,
187 C::Accumulator: Send + Sync,
188 C::Result: Send + Sync,
189 C::Value: Send + Sync,
190 W: Fn(&K, &C::Result) -> Sc + Send + Sync,
191 Sc: Score + 'static,
192{
193 fn evaluate(&self, solution: &S) -> Sc {
194 let entities = self.extractor.extract(solution);
195
196 let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
198
199 for entity in entities {
200 if !self.filter.test(solution, entity) {
201 continue;
202 }
203 let key = (self.key_fn)(entity);
204 let value = self.collector.extract(entity);
205 let acc = groups
206 .entry(key)
207 .or_insert_with(|| self.collector.create_accumulator());
208 acc.accumulate(&value);
209 }
210
211 let mut total = Sc::zero();
213 for (key, acc) in &groups {
214 let result = acc.finish();
215 total = total + self.compute_score(key, &result);
216 }
217
218 total
219 }
220
221 fn match_count(&self, solution: &S) -> usize {
222 let entities = self.extractor.extract(solution);
223
224 let mut groups: HashMap<K, ()> = HashMap::new();
226 for entity in entities {
227 if !self.filter.test(solution, entity) {
228 continue;
229 }
230 let key = (self.key_fn)(entity);
231 groups.insert(key, ());
232 }
233
234 groups.len()
235 }
236
237 fn initialize(&mut self, solution: &S) -> Sc {
238 self.reset();
239
240 let entities = self.extractor.extract(solution);
241 let mut total = Sc::zero();
242
243 for (idx, entity) in entities.iter().enumerate() {
244 if !self.filter.test(solution, entity) {
245 continue;
246 }
247 total = total + self.insert_entity(entities, idx, entity);
248 }
249
250 total
251 }
252
253 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
254 if !self
255 .change_source
256 .assert_localizes(descriptor_index, &self.constraint_ref.name)
257 {
258 return Sc::zero();
259 }
260 let entities = self.extractor.extract(solution);
261 if entity_index >= entities.len() {
262 return Sc::zero();
263 }
264
265 let entity = &entities[entity_index];
266 if !self.filter.test(solution, entity) {
267 return Sc::zero();
268 }
269 self.insert_entity(entities, entity_index, entity)
270 }
271
272 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
273 if !self
274 .change_source
275 .assert_localizes(descriptor_index, &self.constraint_ref.name)
276 {
277 return Sc::zero();
278 }
279 let entities = self.extractor.extract(solution);
280 self.retract_entity(entities, entity_index)
281 }
282
283 fn reset(&mut self) {
284 self.groups.clear();
285 self.entity_groups.clear();
286 self.entity_values.clear();
287 }
288
289 fn name(&self) -> &str {
290 &self.constraint_ref.name
291 }
292
293 fn is_hard(&self) -> bool {
294 self.is_hard
295 }
296
297 fn constraint_ref(&self) -> &ConstraintRef {
298 &self.constraint_ref
299 }
300}
301
302impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
303where
304 S: Send + Sync + 'static,
305 A: Clone + Send + Sync + 'static,
306 K: Clone + Eq + Hash + Send + Sync + 'static,
307 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
308 Fi: UniFilter<S, A>,
309 KF: Fn(&A) -> K + Send + Sync,
310 C: UniCollector<A> + Send + Sync + 'static,
311 C::Accumulator: Send + Sync,
312 C::Result: Send + Sync,
313 C::Value: Send + Sync,
314 W: Fn(&K, &C::Result) -> Sc + Send + Sync,
315 Sc: Score + 'static,
316{
317 fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
318 let key = (self.key_fn)(entity);
319 let entity_key = (self.key_fn)(entity);
320 let value = self.collector.extract(entity);
321 let impact = self.impact_type;
322
323 let weight_fn = &self.weight_fn;
324 let (old, new_score) = match self.groups.entry(key) {
325 std::collections::hash_map::Entry::Occupied(mut entry) => {
326 let old_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
327 let old = match impact {
328 ImpactType::Penalty => -old_base,
329 ImpactType::Reward => old_base,
330 };
331 let group = entry.get_mut();
332 group.accumulator.accumulate(&value);
333 group.count += 1;
334 let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
335 let new_score = match impact {
336 ImpactType::Penalty => -new_base,
337 ImpactType::Reward => new_base,
338 };
339 (old, new_score)
340 }
341 std::collections::hash_map::Entry::Vacant(entry) => {
342 let mut entry = entry.insert_entry(GroupState {
343 accumulator: self.collector.create_accumulator(),
344 count: 0,
345 });
346 let group = entry.get_mut();
347 group.accumulator.accumulate(&value);
348 group.count += 1;
349 let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
350 let new_score = match impact {
351 ImpactType::Penalty => -new_base,
352 ImpactType::Reward => new_base,
353 };
354 (Sc::zero(), new_score)
355 }
356 };
357
358 self.entity_groups.insert(entity_index, entity_key);
360 self.entity_values.insert(entity_index, value);
361
362 new_score - old
364 }
365
366 fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
367 let Some(key) = self.entity_groups.remove(&entity_index) else {
369 return Sc::zero();
370 };
371
372 let Some(value) = self.entity_values.remove(&entity_index) else {
374 return Sc::zero();
375 };
376 let impact = self.impact_type;
377
378 let weight_fn = &self.weight_fn;
379 let std::collections::hash_map::Entry::Occupied(mut entry) = self.groups.entry(key) else {
380 return Sc::zero();
381 };
382
383 let old_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
384 let old = match impact {
385 ImpactType::Penalty => -old_base,
386 ImpactType::Reward => old_base,
387 };
388
389 let group = entry.get_mut();
390 group.accumulator.retract(&value);
391 group.count = group.count.saturating_sub(1);
392 let is_empty = group.count == 0;
393 let new_score = if is_empty {
394 entry.remove();
395 Sc::zero()
396 } else {
397 let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
398 match impact {
399 ImpactType::Penalty => -new_base,
400 ImpactType::Reward => new_base,
401 }
402 };
403
404 new_score - old
406 }
407}
408
409impl<S, A, K, E, Fi, KF, C, W, Sc> std::fmt::Debug
410 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
411where
412 C: UniCollector<A>,
413 Sc: Score,
414{
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 f.debug_struct("GroupedUniConstraint")
417 .field("name", &self.constraint_ref.name)
418 .field("impact_type", &self.impact_type)
419 .field("groups", &self.groups.len())
420 .finish()
421 }
422}