1use std::collections::HashMap;
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use solverforge_core::score::Score;
13use solverforge_core::{ConstraintRef, ImpactType};
14
15use crate::api::constraint_set::IncrementalConstraint;
16use crate::stream::collector::{Accumulator, UniCollector};
17use crate::stream::filter::UniFilter;
18
19pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
84where
85 C: UniCollector<A>,
86 Sc: Score,
87{
88 constraint_ref: ConstraintRef,
89 impact_type: ImpactType,
90 extractor: E,
91 filter: Fi,
92 key_fn: KF,
93 collector: C,
94 weight_fn: W,
95 is_hard: bool,
96 change_source: crate::stream::collection_extract::ChangeSource,
97 groups: HashMap<K, C::Accumulator>,
99 group_counts: HashMap<K, usize>,
101 entity_groups: HashMap<usize, K>,
103 entity_values: HashMap<usize, C::Value>,
105 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
106}
107
108impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
109where
110 S: Send + Sync + 'static,
111 A: Clone + Send + Sync + 'static,
112 K: Clone + Eq + Hash + Send + Sync + 'static,
113 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
114 Fi: UniFilter<S, A>,
115 KF: Fn(&A) -> K + Send + Sync,
116 C: UniCollector<A> + Send + Sync + 'static,
117 C::Accumulator: Send + Sync,
118 C::Result: Send + Sync,
119 W: Fn(&C::Result) -> Sc + Send + Sync,
120 Sc: Score + 'static,
121{
122 #[allow(clippy::too_many_arguments)]
136 pub fn new(
137 constraint_ref: ConstraintRef,
138 impact_type: ImpactType,
139 extractor: E,
140 filter: Fi,
141 key_fn: KF,
142 collector: C,
143 weight_fn: W,
144 is_hard: bool,
145 ) -> Self {
146 let change_source = extractor.change_source();
147 Self {
148 constraint_ref,
149 impact_type,
150 extractor,
151 filter,
152 key_fn,
153 collector,
154 weight_fn,
155 is_hard,
156 change_source,
157 groups: HashMap::new(),
158 group_counts: HashMap::new(),
159 entity_groups: HashMap::new(),
160 entity_values: HashMap::new(),
161 _phantom: PhantomData,
162 }
163 }
164
165 fn compute_score(&self, result: &C::Result) -> Sc {
167 let base = (self.weight_fn)(result);
168 match self.impact_type {
169 ImpactType::Penalty => -base,
170 ImpactType::Reward => base,
171 }
172 }
173}
174
175impl<S, A, K, E, Fi, KF, C, W, Sc> IncrementalConstraint<S, Sc>
176 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
177where
178 S: Send + Sync + 'static,
179 A: Clone + Send + Sync + 'static,
180 K: Clone + Eq + Hash + Send + Sync + 'static,
181 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
182 Fi: UniFilter<S, A>,
183 KF: Fn(&A) -> K + Send + Sync,
184 C: UniCollector<A> + Send + Sync + 'static,
185 C::Accumulator: Send + Sync,
186 C::Result: Send + Sync,
187 C::Value: Send + Sync,
188 W: Fn(&C::Result) -> Sc + Send + Sync,
189 Sc: Score + 'static,
190{
191 fn evaluate(&self, solution: &S) -> Sc {
192 let entities = self.extractor.extract(solution);
193
194 let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
196
197 for entity in entities {
198 if !self.filter.test(solution, entity) {
199 continue;
200 }
201 let key = (self.key_fn)(entity);
202 let value = self.collector.extract(entity);
203 let acc = groups
204 .entry(key)
205 .or_insert_with(|| self.collector.create_accumulator());
206 acc.accumulate(&value);
207 }
208
209 let mut total = Sc::zero();
211 for acc in groups.values() {
212 let result = acc.finish();
213 total = total + self.compute_score(&result);
214 }
215
216 total
217 }
218
219 fn match_count(&self, solution: &S) -> usize {
220 let entities = self.extractor.extract(solution);
221
222 let mut groups: HashMap<K, ()> = HashMap::new();
224 for entity in entities {
225 if !self.filter.test(solution, entity) {
226 continue;
227 }
228 let key = (self.key_fn)(entity);
229 groups.insert(key, ());
230 }
231
232 groups.len()
233 }
234
235 fn initialize(&mut self, solution: &S) -> Sc {
236 self.reset();
237
238 let entities = self.extractor.extract(solution);
239 let mut total = Sc::zero();
240
241 for (idx, entity) in entities.iter().enumerate() {
242 if !self.filter.test(solution, entity) {
243 continue;
244 }
245 total = total + self.insert_entity(entities, idx, entity);
246 }
247
248 total
249 }
250
251 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
252 if !self
253 .change_source
254 .assert_localizes(descriptor_index, &self.constraint_ref.name)
255 {
256 return Sc::zero();
257 }
258 let entities = self.extractor.extract(solution);
259 if entity_index >= entities.len() {
260 return Sc::zero();
261 }
262
263 let entity = &entities[entity_index];
264 if !self.filter.test(solution, entity) {
265 return Sc::zero();
266 }
267 self.insert_entity(entities, entity_index, entity)
268 }
269
270 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
271 if !self
272 .change_source
273 .assert_localizes(descriptor_index, &self.constraint_ref.name)
274 {
275 return Sc::zero();
276 }
277 let entities = self.extractor.extract(solution);
278 self.retract_entity(entities, entity_index)
279 }
280
281 fn reset(&mut self) {
282 self.groups.clear();
283 self.group_counts.clear();
284 self.entity_groups.clear();
285 self.entity_values.clear();
286 }
287
288 fn name(&self) -> &str {
289 &self.constraint_ref.name
290 }
291
292 fn is_hard(&self) -> bool {
293 self.is_hard
294 }
295
296 fn constraint_ref(&self) -> &ConstraintRef {
297 &self.constraint_ref
298 }
299}
300
301impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
302where
303 S: Send + Sync + 'static,
304 A: Clone + Send + Sync + 'static,
305 K: Clone + Eq + Hash + Send + Sync + 'static,
306 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
307 Fi: UniFilter<S, A>,
308 KF: Fn(&A) -> K + Send + Sync,
309 C: UniCollector<A> + Send + Sync + 'static,
310 C::Accumulator: Send + Sync,
311 C::Result: Send + Sync,
312 C::Value: Send + Sync,
313 W: Fn(&C::Result) -> Sc + Send + Sync,
314 Sc: Score + 'static,
315{
316 fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
317 let key = (self.key_fn)(entity);
318 let value = self.collector.extract(entity);
319 let impact = self.impact_type;
320
321 let is_new = !self.groups.contains_key(&key);
323 let acc = self
324 .groups
325 .entry(key.clone())
326 .or_insert_with(|| self.collector.create_accumulator());
327
328 let old = if is_new {
330 Sc::zero()
331 } else {
332 let old_base = (self.weight_fn)(&acc.finish());
333 match impact {
334 ImpactType::Penalty => -old_base,
335 ImpactType::Reward => old_base,
336 }
337 };
338
339 acc.accumulate(&value);
341 let new_base = (self.weight_fn)(&acc.finish());
342 let new_score = match impact {
343 ImpactType::Penalty => -new_base,
344 ImpactType::Reward => new_base,
345 };
346
347 self.entity_groups.insert(entity_index, key.clone());
349 self.entity_values.insert(entity_index, value);
350 *self.group_counts.entry(key).or_insert(0) += 1;
351
352 new_score - old
354 }
355
356 fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
357 let Some(key) = self.entity_groups.remove(&entity_index) else {
359 return Sc::zero();
360 };
361
362 let Some(value) = self.entity_values.remove(&entity_index) else {
364 return Sc::zero();
365 };
366 let impact = self.impact_type;
367
368 let Some(acc) = self.groups.get_mut(&key) else {
370 return Sc::zero();
371 };
372
373 let old_base = (self.weight_fn)(&acc.finish());
375 let old = match impact {
376 ImpactType::Penalty => -old_base,
377 ImpactType::Reward => old_base,
378 };
379
380 let is_empty = {
382 let cnt = self.group_counts.entry(key.clone()).or_insert(0);
383 *cnt = cnt.saturating_sub(1);
384 *cnt == 0
385 };
386 if is_empty {
387 self.group_counts.remove(&key);
388 }
389
390 acc.retract(&value);
392 let new_score = if is_empty {
393 self.groups.remove(&key);
395 Sc::zero()
396 } else {
397 let new_base = (self.weight_fn)(&acc.finish());
398 match impact {
399 ImpactType::Penalty => -new_base,
400 ImpactType::Reward => new_base,
401 }
402 };
403
404 new_score - old
406 }
407}
408
409impl<S, A, K, E, Fi, KF, C, W, Sc> std::fmt::Debug
410 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
411where
412 C: UniCollector<A>,
413 Sc: Score,
414{
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 f.debug_struct("GroupedUniConstraint")
417 .field("name", &self.constraint_ref.name)
418 .field("impact_type", &self.impact_type)
419 .field("groups", &self.groups.len())
420 .finish()
421 }
422}