1use std::collections::HashMap;
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use solverforge_core::score::Score;
13use solverforge_core::{ConstraintRef, ImpactType};
14
15use crate::api::constraint_set::IncrementalConstraint;
16use crate::stream::collector::{Accumulator, UniCollector};
17use crate::stream::filter::UniFilter;
18
/// Incremental constraint that groups the entities of one collection by a key,
/// folds each group's values through a [`UniCollector`], and scores every group
/// with `weight_fn`.
///
/// Each group contributes `weight_fn(result)`, negated for [`ImpactType::Penalty`]
/// and unchanged for [`ImpactType::Reward`]; the constraint score is the sum over
/// all groups that currently hold at least one entity.
pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
where
    C: UniCollector<A>,
    Sc: Score,
{
    /// Identity of the constraint; its `name` is what `name()` and `Debug` report.
    constraint_ref: ConstraintRef,
    /// Whether group weights are subtracted (penalty) or added (reward).
    impact_type: ImpactType,
    /// Pulls the entity collection out of a solution.
    extractor: E,
    /// Entities that fail this filter are never grouped or scored.
    filter: Fi,
    /// Maps an entity to the key of the group it belongs to.
    key_fn: KF,
    /// Supplies fresh per-group accumulators and the value each entity feeds into them.
    collector: C,
    /// Turns a group's collected result into an unsigned score; the sign comes from `impact_type`.
    weight_fn: W,
    /// Value returned by `is_hard()`.
    is_hard: bool,
    /// When set, insert/retract notifications carrying any other descriptor index are ignored.
    expected_descriptor: Option<usize>,
    /// Live accumulator for every non-empty group, keyed by group key.
    groups: HashMap<K, C::Accumulator>,
    /// Number of tracked entities per group key; retraction uses it to detect when a
    /// group becomes empty so the group can be dropped.
    group_counts: HashMap<K, usize>,
    /// Entity index -> key it was inserted under, so retraction does not re-run `key_fn`.
    entity_groups: HashMap<usize, K>,
    /// Entity index -> value it contributed, so retraction removes exactly that value.
    entity_values: HashMap<usize, C::Value>,
    /// Marks the type parameters that are not stored as fields.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
}
107
108impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
109where
110 S: Send + Sync + 'static,
111 A: Clone + Send + Sync + 'static,
112 K: Clone + Eq + Hash + Send + Sync + 'static,
113 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
114 Fi: UniFilter<S, A>,
115 KF: Fn(&A) -> K + Send + Sync,
116 C: UniCollector<A> + Send + Sync + 'static,
117 C::Accumulator: Send + Sync,
118 C::Result: Send + Sync,
119 W: Fn(&C::Result) -> Sc + Send + Sync,
120 Sc: Score + 'static,
121{
122 #[allow(clippy::too_many_arguments)]
136 pub fn new(
137 constraint_ref: ConstraintRef,
138 impact_type: ImpactType,
139 extractor: E,
140 filter: Fi,
141 key_fn: KF,
142 collector: C,
143 weight_fn: W,
144 is_hard: bool,
145 ) -> Self {
146 Self {
147 constraint_ref,
148 impact_type,
149 extractor,
150 filter,
151 key_fn,
152 collector,
153 weight_fn,
154 is_hard,
155 expected_descriptor: None,
156 groups: HashMap::new(),
157 group_counts: HashMap::new(),
158 entity_groups: HashMap::new(),
159 entity_values: HashMap::new(),
160 _phantom: PhantomData,
161 }
162 }
163
164 pub fn with_descriptor(mut self, descriptor_index: usize) -> Self {
165 self.expected_descriptor = Some(descriptor_index);
166 self
167 }
168
169 fn compute_score(&self, result: &C::Result) -> Sc {
171 let base = (self.weight_fn)(result);
172 match self.impact_type {
173 ImpactType::Penalty => -base,
174 ImpactType::Reward => base,
175 }
176 }
177}
178
179impl<S, A, K, E, Fi, KF, C, W, Sc> IncrementalConstraint<S, Sc>
180 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
181where
182 S: Send + Sync + 'static,
183 A: Clone + Send + Sync + 'static,
184 K: Clone + Eq + Hash + Send + Sync + 'static,
185 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
186 Fi: UniFilter<S, A>,
187 KF: Fn(&A) -> K + Send + Sync,
188 C: UniCollector<A> + Send + Sync + 'static,
189 C::Accumulator: Send + Sync,
190 C::Result: Send + Sync,
191 C::Value: Send + Sync,
192 W: Fn(&C::Result) -> Sc + Send + Sync,
193 Sc: Score + 'static,
194{
195 fn evaluate(&self, solution: &S) -> Sc {
196 let entities = self.extractor.extract(solution);
197
198 let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
200
201 for entity in entities {
202 if !self.filter.test(solution, entity) {
203 continue;
204 }
205 let key = (self.key_fn)(entity);
206 let value = self.collector.extract(entity);
207 let acc = groups
208 .entry(key)
209 .or_insert_with(|| self.collector.create_accumulator());
210 acc.accumulate(&value);
211 }
212
213 let mut total = Sc::zero();
215 for acc in groups.values() {
216 let result = acc.finish();
217 total = total + self.compute_score(&result);
218 }
219
220 total
221 }
222
223 fn match_count(&self, solution: &S) -> usize {
224 let entities = self.extractor.extract(solution);
225
226 let mut groups: HashMap<K, ()> = HashMap::new();
228 for entity in entities {
229 if !self.filter.test(solution, entity) {
230 continue;
231 }
232 let key = (self.key_fn)(entity);
233 groups.insert(key, ());
234 }
235
236 groups.len()
237 }
238
239 fn initialize(&mut self, solution: &S) -> Sc {
240 self.reset();
241
242 let entities = self.extractor.extract(solution);
243 let mut total = Sc::zero();
244
245 for (idx, entity) in entities.iter().enumerate() {
246 if !self.filter.test(solution, entity) {
247 continue;
248 }
249 total = total + self.insert_entity(entities, idx, entity);
250 }
251
252 total
253 }
254
255 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
256 if let Some(expected) = self.expected_descriptor {
257 if descriptor_index != expected {
258 return Sc::zero();
259 }
260 }
261 let entities = self.extractor.extract(solution);
262 if entity_index >= entities.len() {
263 return Sc::zero();
264 }
265
266 let entity = &entities[entity_index];
267 if !self.filter.test(solution, entity) {
268 return Sc::zero();
269 }
270 self.insert_entity(entities, entity_index, entity)
271 }
272
273 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
274 if let Some(expected) = self.expected_descriptor {
275 if descriptor_index != expected {
276 return Sc::zero();
277 }
278 }
279 let entities = self.extractor.extract(solution);
280 self.retract_entity(entities, entity_index)
281 }
282
283 fn reset(&mut self) {
284 self.groups.clear();
285 self.group_counts.clear();
286 self.entity_groups.clear();
287 self.entity_values.clear();
288 }
289
290 fn name(&self) -> &str {
291 &self.constraint_ref.name
292 }
293
294 fn is_hard(&self) -> bool {
295 self.is_hard
296 }
297
298 fn constraint_ref(&self) -> ConstraintRef {
299 self.constraint_ref.clone()
300 }
301}
302
303impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
304where
305 S: Send + Sync + 'static,
306 A: Clone + Send + Sync + 'static,
307 K: Clone + Eq + Hash + Send + Sync + 'static,
308 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
309 Fi: UniFilter<S, A>,
310 KF: Fn(&A) -> K + Send + Sync,
311 C: UniCollector<A> + Send + Sync + 'static,
312 C::Accumulator: Send + Sync,
313 C::Result: Send + Sync,
314 C::Value: Send + Sync,
315 W: Fn(&C::Result) -> Sc + Send + Sync,
316 Sc: Score + 'static,
317{
318 fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
319 let key = (self.key_fn)(entity);
320 let value = self.collector.extract(entity);
321 let impact = self.impact_type;
322
323 let is_new = !self.groups.contains_key(&key);
325 let acc = self
326 .groups
327 .entry(key.clone())
328 .or_insert_with(|| self.collector.create_accumulator());
329
330 let old = if is_new {
332 Sc::zero()
333 } else {
334 let old_base = (self.weight_fn)(&acc.finish());
335 match impact {
336 ImpactType::Penalty => -old_base,
337 ImpactType::Reward => old_base,
338 }
339 };
340
341 acc.accumulate(&value);
343 let new_base = (self.weight_fn)(&acc.finish());
344 let new_score = match impact {
345 ImpactType::Penalty => -new_base,
346 ImpactType::Reward => new_base,
347 };
348
349 self.entity_groups.insert(entity_index, key.clone());
351 self.entity_values.insert(entity_index, value);
352 *self.group_counts.entry(key).or_insert(0) += 1;
353
354 new_score - old
356 }
357
358 fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
359 let Some(key) = self.entity_groups.remove(&entity_index) else {
361 return Sc::zero();
362 };
363
364 let Some(value) = self.entity_values.remove(&entity_index) else {
366 return Sc::zero();
367 };
368 let impact = self.impact_type;
369
370 let Some(acc) = self.groups.get_mut(&key) else {
372 return Sc::zero();
373 };
374
375 let old_base = (self.weight_fn)(&acc.finish());
377 let old = match impact {
378 ImpactType::Penalty => -old_base,
379 ImpactType::Reward => old_base,
380 };
381
382 let is_empty = {
384 let cnt = self.group_counts.entry(key.clone()).or_insert(0);
385 *cnt = cnt.saturating_sub(1);
386 *cnt == 0
387 };
388 if is_empty {
389 self.group_counts.remove(&key);
390 }
391
392 acc.retract(&value);
394 let new_score = if is_empty {
395 self.groups.remove(&key);
397 Sc::zero()
398 } else {
399 let new_base = (self.weight_fn)(&acc.finish());
400 match impact {
401 ImpactType::Penalty => -new_base,
402 ImpactType::Reward => new_base,
403 }
404 };
405
406 new_score - old
408 }
409}
410
411impl<S, A, K, E, Fi, KF, C, W, Sc> std::fmt::Debug
412 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
413where
414 C: UniCollector<A>,
415 Sc: Score,
416{
417 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
418 f.debug_struct("GroupedUniConstraint")
419 .field("name", &self.constraint_ref.name)
420 .field("impact_type", &self.impact_type)
421 .field("groups", &self.groups.len())
422 .finish()
423 }
424}