1use std::collections::HashMap;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::constraint_set::IncrementalConstraint;
15use crate::stream::collector::{Accumulator, UniCollector};
16use crate::stream::filter::UniFilter;
17
18pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
82where
83 C: UniCollector<A>,
84 Sc: Score,
85{
86 constraint_ref: ConstraintRef,
87 impact_type: ImpactType,
88 extractor: E,
89 filter: Fi,
90 key_fn: KF,
91 collector: C,
92 weight_fn: W,
93 is_hard: bool,
94 expected_descriptor: Option<usize>,
95 groups: HashMap<K, C::Accumulator>,
97 group_counts: HashMap<K, usize>,
99 entity_groups: HashMap<usize, K>,
101 entity_values: HashMap<usize, C::Value>,
103 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
104}
105
106impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
107where
108 S: Send + Sync + 'static,
109 A: Clone + Send + Sync + 'static,
110 K: Clone + Eq + Hash + Send + Sync + 'static,
111 E: Fn(&S) -> &[A] + Send + Sync,
112 Fi: UniFilter<S, A>,
113 KF: Fn(&A) -> K + Send + Sync,
114 C: UniCollector<A> + Send + Sync + 'static,
115 C::Accumulator: Send + Sync,
116 C::Result: Send + Sync,
117 W: Fn(&C::Result) -> Sc + Send + Sync,
118 Sc: Score + 'static,
119{
120 #[allow(clippy::too_many_arguments)]
133 pub fn new(
134 constraint_ref: ConstraintRef,
135 impact_type: ImpactType,
136 extractor: E,
137 filter: Fi,
138 key_fn: KF,
139 collector: C,
140 weight_fn: W,
141 is_hard: bool,
142 ) -> Self {
143 Self {
144 constraint_ref,
145 impact_type,
146 extractor,
147 filter,
148 key_fn,
149 collector,
150 weight_fn,
151 is_hard,
152 expected_descriptor: None,
153 groups: HashMap::new(),
154 group_counts: HashMap::new(),
155 entity_groups: HashMap::new(),
156 entity_values: HashMap::new(),
157 _phantom: PhantomData,
158 }
159 }
160
161 pub fn with_descriptor(mut self, descriptor_index: usize) -> Self {
162 self.expected_descriptor = Some(descriptor_index);
163 self
164 }
165
166 fn compute_score(&self, result: &C::Result) -> Sc {
168 let base = (self.weight_fn)(result);
169 match self.impact_type {
170 ImpactType::Penalty => -base,
171 ImpactType::Reward => base,
172 }
173 }
174}
175
176impl<S, A, K, E, Fi, KF, C, W, Sc> IncrementalConstraint<S, Sc>
177 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
178where
179 S: Send + Sync + 'static,
180 A: Clone + Send + Sync + 'static,
181 K: Clone + Eq + Hash + Send + Sync + 'static,
182 E: Fn(&S) -> &[A] + Send + Sync,
183 Fi: UniFilter<S, A>,
184 KF: Fn(&A) -> K + Send + Sync,
185 C: UniCollector<A> + Send + Sync + 'static,
186 C::Accumulator: Send + Sync,
187 C::Result: Send + Sync,
188 C::Value: Send + Sync,
189 W: Fn(&C::Result) -> Sc + Send + Sync,
190 Sc: Score + 'static,
191{
192 fn evaluate(&self, solution: &S) -> Sc {
193 let entities = (self.extractor)(solution);
194
195 let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
197
198 for entity in entities {
199 if !self.filter.test(solution, entity) {
200 continue;
201 }
202 let key = (self.key_fn)(entity);
203 let value = self.collector.extract(entity);
204 let acc = groups
205 .entry(key)
206 .or_insert_with(|| self.collector.create_accumulator());
207 acc.accumulate(&value);
208 }
209
210 let mut total = Sc::zero();
212 for acc in groups.values() {
213 let result = acc.finish();
214 total = total + self.compute_score(&result);
215 }
216
217 total
218 }
219
220 fn match_count(&self, solution: &S) -> usize {
221 let entities = (self.extractor)(solution);
222
223 let mut groups: HashMap<K, ()> = HashMap::new();
225 for entity in entities {
226 if !self.filter.test(solution, entity) {
227 continue;
228 }
229 let key = (self.key_fn)(entity);
230 groups.insert(key, ());
231 }
232
233 groups.len()
234 }
235
236 fn initialize(&mut self, solution: &S) -> Sc {
237 self.reset();
238
239 let entities = (self.extractor)(solution);
240 let mut total = Sc::zero();
241
242 for (idx, entity) in entities.iter().enumerate() {
243 if !self.filter.test(solution, entity) {
244 continue;
245 }
246 total = total + self.insert_entity(entities, idx, entity);
247 }
248
249 total
250 }
251
252 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
253 if let Some(expected) = self.expected_descriptor {
254 if descriptor_index != expected {
255 return Sc::zero();
256 }
257 }
258 let entities = (self.extractor)(solution);
259 if entity_index >= entities.len() {
260 return Sc::zero();
261 }
262
263 let entity = &entities[entity_index];
264 if !self.filter.test(solution, entity) {
265 return Sc::zero();
266 }
267 self.insert_entity(entities, entity_index, entity)
268 }
269
270 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
271 if let Some(expected) = self.expected_descriptor {
272 if descriptor_index != expected {
273 return Sc::zero();
274 }
275 }
276 let entities = (self.extractor)(solution);
277 self.retract_entity(entities, entity_index)
278 }
279
280 fn reset(&mut self) {
281 self.groups.clear();
282 self.group_counts.clear();
283 self.entity_groups.clear();
284 self.entity_values.clear();
285 }
286
287 fn name(&self) -> &str {
288 &self.constraint_ref.name
289 }
290
291 fn is_hard(&self) -> bool {
292 self.is_hard
293 }
294
295 fn constraint_ref(&self) -> ConstraintRef {
296 self.constraint_ref.clone()
297 }
298}
299
300impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
301where
302 S: Send + Sync + 'static,
303 A: Clone + Send + Sync + 'static,
304 K: Clone + Eq + Hash + Send + Sync + 'static,
305 E: Fn(&S) -> &[A] + Send + Sync,
306 Fi: UniFilter<S, A>,
307 KF: Fn(&A) -> K + Send + Sync,
308 C: UniCollector<A> + Send + Sync + 'static,
309 C::Accumulator: Send + Sync,
310 C::Result: Send + Sync,
311 C::Value: Send + Sync,
312 W: Fn(&C::Result) -> Sc + Send + Sync,
313 Sc: Score + 'static,
314{
315 fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
316 let key = (self.key_fn)(entity);
317 let value = self.collector.extract(entity);
318 let impact = self.impact_type;
319
320 let is_new = !self.groups.contains_key(&key);
322 let acc = self
323 .groups
324 .entry(key.clone())
325 .or_insert_with(|| self.collector.create_accumulator());
326
327 let old = if is_new {
329 Sc::zero()
330 } else {
331 let old_base = (self.weight_fn)(&acc.finish());
332 match impact {
333 ImpactType::Penalty => -old_base,
334 ImpactType::Reward => old_base,
335 }
336 };
337
338 acc.accumulate(&value);
340 let new_base = (self.weight_fn)(&acc.finish());
341 let new_score = match impact {
342 ImpactType::Penalty => -new_base,
343 ImpactType::Reward => new_base,
344 };
345
346 self.entity_groups.insert(entity_index, key.clone());
348 self.entity_values.insert(entity_index, value);
349 *self.group_counts.entry(key).or_insert(0) += 1;
350
351 new_score - old
353 }
354
355 fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
356 let Some(key) = self.entity_groups.remove(&entity_index) else {
358 return Sc::zero();
359 };
360
361 let Some(value) = self.entity_values.remove(&entity_index) else {
363 return Sc::zero();
364 };
365 let impact = self.impact_type;
366
367 let Some(acc) = self.groups.get_mut(&key) else {
369 return Sc::zero();
370 };
371
372 let old_base = (self.weight_fn)(&acc.finish());
374 let old = match impact {
375 ImpactType::Penalty => -old_base,
376 ImpactType::Reward => old_base,
377 };
378
379 let is_empty = {
381 let cnt = self.group_counts.entry(key.clone()).or_insert(0);
382 *cnt = cnt.saturating_sub(1);
383 *cnt == 0
384 };
385 if is_empty {
386 self.group_counts.remove(&key);
387 }
388
389 acc.retract(&value);
391 let new_score = if is_empty {
392 self.groups.remove(&key);
394 Sc::zero()
395 } else {
396 let new_base = (self.weight_fn)(&acc.finish());
397 match impact {
398 ImpactType::Penalty => -new_base,
399 ImpactType::Reward => new_base,
400 }
401 };
402
403 new_score - old
405 }
406}
407
408impl<S, A, K, E, Fi, KF, C, W, Sc> std::fmt::Debug
409 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
410where
411 C: UniCollector<A>,
412 Sc: Score,
413{
414 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415 f.debug_struct("GroupedUniConstraint")
416 .field("name", &self.constraint_ref.name)
417 .field("impact_type", &self.impact_type)
418 .field("groups", &self.groups.len())
419 .finish()
420 }
421}