1use std::collections::HashMap;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::constraint_set::IncrementalConstraint;
15use crate::stream::collector::{Accumulator, UniCollector};
16
17pub struct ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
89where
90 C: UniCollector<A>,
91 Sc: Score,
92{
93 constraint_ref: ConstraintRef,
94 impact_type: ImpactType,
95 extractor_a: EA,
96 extractor_b: EB,
97 key_a: KA,
98 key_b: KB,
99 collector: C,
100 default_fn: D,
101 weight_fn: W,
102 is_hard: bool,
103 groups: HashMap<K, C::Accumulator>,
105 entity_groups: HashMap<usize, K>,
107 entity_values: HashMap<usize, C::Value>,
109 b_by_key: HashMap<K, usize>,
111 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> Sc)>,
112}
113
114impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
115 ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
116where
117 S: 'static,
118 A: Clone + 'static,
119 B: Clone + 'static,
120 K: Clone + Eq + Hash,
121 EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
122 EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
123 KA: Fn(&A) -> Option<K>,
124 KB: Fn(&B) -> K,
125 C: UniCollector<A>,
126 C::Result: Clone,
127 D: Fn(&B) -> C::Result,
128 W: Fn(&C::Result) -> Sc,
129 Sc: Score,
130{
131 #[allow(clippy::too_many_arguments)]
133 pub fn new(
134 constraint_ref: ConstraintRef,
135 impact_type: ImpactType,
136 extractor_a: EA,
137 extractor_b: EB,
138 key_a: KA,
139 key_b: KB,
140 collector: C,
141 default_fn: D,
142 weight_fn: W,
143 is_hard: bool,
144 ) -> Self {
145 Self {
146 constraint_ref,
147 impact_type,
148 extractor_a,
149 extractor_b,
150 key_a,
151 key_b,
152 collector,
153 default_fn,
154 weight_fn,
155 is_hard,
156 groups: HashMap::new(),
157 entity_groups: HashMap::new(),
158 entity_values: HashMap::new(),
159 b_by_key: HashMap::new(),
160 _phantom: PhantomData,
161 }
162 }
163
164 #[inline]
165 fn compute_score(&self, result: &C::Result) -> Sc {
166 let base = (self.weight_fn)(result);
167 match self.impact_type {
168 ImpactType::Penalty => -base,
169 ImpactType::Reward => base,
170 }
171 }
172
173 fn build_groups(&self, entities_a: &[A]) -> HashMap<K, C::Result> {
175 let mut accumulators: HashMap<K, C::Accumulator> = HashMap::new();
176
177 for a in entities_a {
178 let Some(key) = (self.key_a)(a) else {
180 continue;
181 };
182 let value = self.collector.extract(a);
183 accumulators
184 .entry(key)
185 .or_insert_with(|| self.collector.create_accumulator())
186 .accumulate(&value);
187 }
188
189 accumulators
190 .into_iter()
191 .map(|(k, acc)| (k, acc.finish()))
192 .collect()
193 }
194}
195
196impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> IncrementalConstraint<S, Sc>
197 for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
198where
199 S: Send + Sync + 'static,
200 A: Clone + Send + Sync + 'static,
201 B: Clone + Send + Sync + 'static,
202 K: Clone + Eq + Hash + Send + Sync,
203 EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
204 EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
205 KA: Fn(&A) -> Option<K> + Send + Sync,
206 KB: Fn(&B) -> K + Send + Sync,
207 C: UniCollector<A> + Send + Sync,
208 C::Accumulator: Send + Sync,
209 C::Result: Clone + Send + Sync,
210 C::Value: Send + Sync,
211 D: Fn(&B) -> C::Result + Send + Sync,
212 W: Fn(&C::Result) -> Sc + Send + Sync,
213 Sc: Score,
214{
215 fn evaluate(&self, solution: &S) -> Sc {
216 let entities_a = self.extractor_a.extract(solution);
217 let entities_b = self.extractor_b.extract(solution);
218
219 let groups = self.build_groups(entities_a);
220
221 let mut total = Sc::zero();
222 for b in entities_b {
223 let key = (self.key_b)(b);
224 let result = groups
225 .get(&key)
226 .cloned()
227 .unwrap_or_else(|| (self.default_fn)(b));
228 total = total + self.compute_score(&result);
229 }
230
231 total
232 }
233
234 fn match_count(&self, solution: &S) -> usize {
235 let entities_b = self.extractor_b.extract(solution);
236 entities_b.len()
237 }
238
239 fn initialize(&mut self, solution: &S) -> Sc {
240 self.reset();
241
242 let entities_a = self.extractor_a.extract(solution);
243 let entities_b = self.extractor_b.extract(solution);
244
245 for (idx, b) in entities_b.iter().enumerate() {
247 let key = (self.key_b)(b);
248 self.b_by_key.insert(key, idx);
249 }
250
251 let mut total = Sc::zero();
253 for b in entities_b {
254 let default_result = (self.default_fn)(b);
255 total = total + self.compute_score(&default_result);
256 }
257
258 for (idx, a) in entities_a.iter().enumerate() {
260 total = total + self.insert_entity(entities_b, idx, a);
261 }
262
263 total
264 }
265
266 fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
267 let entities_a = self.extractor_a.extract(solution);
268 let entities_b = self.extractor_b.extract(solution);
269
270 if entity_index >= entities_a.len() {
271 return Sc::zero();
272 }
273
274 let entity = &entities_a[entity_index];
275 self.insert_entity(entities_b, entity_index, entity)
276 }
277
278 fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
279 let entities_a = self.extractor_a.extract(solution);
280 let entities_b = self.extractor_b.extract(solution);
281
282 self.retract_entity(entities_a, entities_b, entity_index)
283 }
284
285 fn reset(&mut self) {
286 self.groups.clear();
287 self.entity_groups.clear();
288 self.entity_values.clear();
289 self.b_by_key.clear();
290 }
291
292 fn name(&self) -> &str {
293 &self.constraint_ref.name
294 }
295
296 fn is_hard(&self) -> bool {
297 self.is_hard
298 }
299
300 fn constraint_ref(&self) -> ConstraintRef {
301 self.constraint_ref.clone()
302 }
303}
304
305impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
306 ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
307where
308 S: Send + Sync + 'static,
309 A: Clone + Send + Sync + 'static,
310 B: Clone + Send + Sync + 'static,
311 K: Clone + Eq + Hash + Send + Sync,
312 EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
313 EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
314 KA: Fn(&A) -> Option<K> + Send + Sync,
315 KB: Fn(&B) -> K + Send + Sync,
316 C: UniCollector<A> + Send + Sync,
317 C::Accumulator: Send + Sync,
318 C::Result: Clone + Send + Sync,
319 C::Value: Send + Sync,
320 D: Fn(&B) -> C::Result + Send + Sync,
321 W: Fn(&C::Result) -> Sc + Send + Sync,
322 Sc: Score,
323{
324 fn insert_entity(&mut self, entities_b: &[B], entity_index: usize, entity: &A) -> Sc {
326 let Some(key) = (self.key_a)(entity) else {
328 return Sc::zero();
329 };
330 let value = self.collector.extract(entity);
331 let impact = self.impact_type;
332
333 let b_idx = self.b_by_key.get(&key).copied();
335 let Some(b_idx) = b_idx else {
336 let acc = self
339 .groups
340 .entry(key.clone())
341 .or_insert_with(|| self.collector.create_accumulator());
342 acc.accumulate(&value);
343 self.entity_groups.insert(entity_index, key);
344 self.entity_values.insert(entity_index, value);
345 return Sc::zero();
346 };
347
348 let b = &entities_b[b_idx];
349
350 let old_result = self
352 .groups
353 .get(&key)
354 .map(|acc| acc.finish())
355 .unwrap_or_else(|| (self.default_fn)(b));
356 let old_base = (self.weight_fn)(&old_result);
357 let old = match impact {
358 ImpactType::Penalty => -old_base,
359 ImpactType::Reward => old_base,
360 };
361
362 let acc = self
364 .groups
365 .entry(key.clone())
366 .or_insert_with(|| self.collector.create_accumulator());
367 acc.accumulate(&value);
368
369 let new_result = acc.finish();
371 let new_base = (self.weight_fn)(&new_result);
372 let new_score = match impact {
373 ImpactType::Penalty => -new_base,
374 ImpactType::Reward => new_base,
375 };
376
377 self.entity_groups.insert(entity_index, key);
379 self.entity_values.insert(entity_index, value);
380
381 new_score - old
383 }
384
385 fn retract_entity(&mut self, _entities_a: &[A], _entities_b: &[B], entity_index: usize) -> Sc {
387 let Some(key) = self.entity_groups.remove(&entity_index) else {
389 return Sc::zero();
390 };
391
392 let Some(value) = self.entity_values.remove(&entity_index) else {
394 return Sc::zero();
395 };
396 let impact = self.impact_type;
397
398 let b_idx = self.b_by_key.get(&key).copied();
400 if b_idx.is_none() {
401 if let Some(acc) = self.groups.get_mut(&key) {
403 acc.retract(&value);
404 }
405 return Sc::zero();
406 }
407
408 let Some(acc) = self.groups.get_mut(&key) else {
410 return Sc::zero();
411 };
412
413 let old_result = acc.finish();
415 let old_base = (self.weight_fn)(&old_result);
416 let old = match impact {
417 ImpactType::Penalty => -old_base,
418 ImpactType::Reward => old_base,
419 };
420
421 acc.retract(&value);
423
424 let new_result = acc.finish();
426 let new_base = (self.weight_fn)(&new_result);
427 let new_score = match impact {
428 ImpactType::Penalty => -new_base,
429 ImpactType::Reward => new_base,
430 };
431
432 new_score - old
434 }
435}
436
437impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> std::fmt::Debug
438 for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
439where
440 C: UniCollector<A>,
441 Sc: Score,
442{
443 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444 f.debug_struct("ComplementedGroupConstraint")
445 .field("name", &self.constraint_ref.name)
446 .field("impact_type", &self.impact_type)
447 .field("groups", &self.groups.len())
448 .finish()
449 }
450}