1use std::collections::HashMap;
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::api::constraint_set::IncrementalConstraint;
14use crate::stream::collector::{Accumulator, UniCollector};
15
16pub struct ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
87where
88 C: UniCollector<A>,
89 Sc: Score,
90{
91 constraint_ref: ConstraintRef,
92 impact_type: ImpactType,
93 extractor_a: EA,
94 extractor_b: EB,
95 key_a: KA,
96 key_b: KB,
97 collector: C,
98 default_fn: D,
99 weight_fn: W,
100 is_hard: bool,
101 groups: HashMap<K, C::Accumulator>,
103 entity_groups: HashMap<usize, K>,
105 entity_values: HashMap<usize, C::Value>,
107 b_by_key: HashMap<K, usize>,
109 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> Sc)>,
110}
111
112impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
113 ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
114where
115 S: 'static,
116 A: Clone + 'static,
117 B: Clone + 'static,
118 K: Clone + Eq + Hash,
119 EA: Fn(&S) -> &[A],
120 EB: Fn(&S) -> &[B],
121 KA: Fn(&A) -> Option<K>,
122 KB: Fn(&B) -> K,
123 C: UniCollector<A>,
124 C::Result: Clone,
125 D: Fn(&B) -> C::Result,
126 W: Fn(&C::Result) -> Sc,
127 Sc: Score,
128{
129 #[allow(clippy::too_many_arguments)]
131 pub fn new(
132 constraint_ref: ConstraintRef,
133 impact_type: ImpactType,
134 extractor_a: EA,
135 extractor_b: EB,
136 key_a: KA,
137 key_b: KB,
138 collector: C,
139 default_fn: D,
140 weight_fn: W,
141 is_hard: bool,
142 ) -> Self {
143 Self {
144 constraint_ref,
145 impact_type,
146 extractor_a,
147 extractor_b,
148 key_a,
149 key_b,
150 collector,
151 default_fn,
152 weight_fn,
153 is_hard,
154 groups: HashMap::new(),
155 entity_groups: HashMap::new(),
156 entity_values: HashMap::new(),
157 b_by_key: HashMap::new(),
158 _phantom: PhantomData,
159 }
160 }
161
162 #[inline]
163 fn compute_score(&self, result: &C::Result) -> Sc {
164 let base = (self.weight_fn)(result);
165 match self.impact_type {
166 ImpactType::Penalty => -base,
167 ImpactType::Reward => base,
168 }
169 }
170
171 fn build_groups(&self, entities_a: &[A]) -> HashMap<K, C::Result> {
173 let mut accumulators: HashMap<K, C::Accumulator> = HashMap::new();
174
175 for a in entities_a {
176 let Some(key) = (self.key_a)(a) else {
178 continue;
179 };
180 let value = self.collector.extract(a);
181 accumulators
182 .entry(key)
183 .or_insert_with(|| self.collector.create_accumulator())
184 .accumulate(&value);
185 }
186
187 accumulators
188 .into_iter()
189 .map(|(k, acc)| (k, acc.finish()))
190 .collect()
191 }
192}
193
194impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> IncrementalConstraint<S, Sc>
195 for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
196where
197 S: Send + Sync + 'static,
198 A: Clone + Send + Sync + 'static,
199 B: Clone + Send + Sync + 'static,
200 K: Clone + Eq + Hash + Send + Sync,
201 EA: Fn(&S) -> &[A] + Send + Sync,
202 EB: Fn(&S) -> &[B] + Send + Sync,
203 KA: Fn(&A) -> Option<K> + Send + Sync,
204 KB: Fn(&B) -> K + Send + Sync,
205 C: UniCollector<A> + Send + Sync,
206 C::Accumulator: Send + Sync,
207 C::Result: Clone + Send + Sync,
208 C::Value: Send + Sync,
209 D: Fn(&B) -> C::Result + Send + Sync,
210 W: Fn(&C::Result) -> Sc + Send + Sync,
211 Sc: Score,
212{
213 fn evaluate(&self, solution: &S) -> Sc {
214 let entities_a = (self.extractor_a)(solution);
215 let entities_b = (self.extractor_b)(solution);
216
217 let groups = self.build_groups(entities_a);
218
219 let mut total = Sc::zero();
220 for b in entities_b {
221 let key = (self.key_b)(b);
222 let result = groups
223 .get(&key)
224 .cloned()
225 .unwrap_or_else(|| (self.default_fn)(b));
226 total = total + self.compute_score(&result);
227 }
228
229 total
230 }
231
232 fn match_count(&self, solution: &S) -> usize {
233 let entities_b = (self.extractor_b)(solution);
234 entities_b.len()
235 }
236
237 fn initialize(&mut self, solution: &S) -> Sc {
238 self.reset();
239
240 let entities_a = (self.extractor_a)(solution);
241 let entities_b = (self.extractor_b)(solution);
242
243 for (idx, b) in entities_b.iter().enumerate() {
245 let key = (self.key_b)(b);
246 self.b_by_key.insert(key, idx);
247 }
248
249 let mut total = Sc::zero();
251 for b in entities_b {
252 let default_result = (self.default_fn)(b);
253 total = total + self.compute_score(&default_result);
254 }
255
256 for (idx, a) in entities_a.iter().enumerate() {
258 total = total + self.insert_entity(entities_b, idx, a);
259 }
260
261 total
262 }
263
264 fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
265 let entities_a = (self.extractor_a)(solution);
266 let entities_b = (self.extractor_b)(solution);
267
268 if entity_index >= entities_a.len() {
269 return Sc::zero();
270 }
271
272 let entity = &entities_a[entity_index];
273 self.insert_entity(entities_b, entity_index, entity)
274 }
275
276 fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
277 let entities_a = (self.extractor_a)(solution);
278 let entities_b = (self.extractor_b)(solution);
279
280 self.retract_entity(entities_a, entities_b, entity_index)
281 }
282
283 fn reset(&mut self) {
284 self.groups.clear();
285 self.entity_groups.clear();
286 self.entity_values.clear();
287 self.b_by_key.clear();
288 }
289
290 fn name(&self) -> &str {
291 &self.constraint_ref.name
292 }
293
294 fn is_hard(&self) -> bool {
295 self.is_hard
296 }
297
298 fn constraint_ref(&self) -> ConstraintRef {
299 self.constraint_ref.clone()
300 }
301}
302
303impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
304 ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
305where
306 S: Send + Sync + 'static,
307 A: Clone + Send + Sync + 'static,
308 B: Clone + Send + Sync + 'static,
309 K: Clone + Eq + Hash + Send + Sync,
310 EA: Fn(&S) -> &[A] + Send + Sync,
311 EB: Fn(&S) -> &[B] + Send + Sync,
312 KA: Fn(&A) -> Option<K> + Send + Sync,
313 KB: Fn(&B) -> K + Send + Sync,
314 C: UniCollector<A> + Send + Sync,
315 C::Accumulator: Send + Sync,
316 C::Result: Clone + Send + Sync,
317 C::Value: Send + Sync,
318 D: Fn(&B) -> C::Result + Send + Sync,
319 W: Fn(&C::Result) -> Sc + Send + Sync,
320 Sc: Score,
321{
322 fn insert_entity(&mut self, entities_b: &[B], entity_index: usize, entity: &A) -> Sc {
324 let Some(key) = (self.key_a)(entity) else {
326 return Sc::zero();
327 };
328 let value = self.collector.extract(entity);
329 let impact = self.impact_type;
330
331 let b_idx = self.b_by_key.get(&key).copied();
333 let Some(b_idx) = b_idx else {
334 let acc = self
337 .groups
338 .entry(key.clone())
339 .or_insert_with(|| self.collector.create_accumulator());
340 acc.accumulate(&value);
341 self.entity_groups.insert(entity_index, key);
342 self.entity_values.insert(entity_index, value);
343 return Sc::zero();
344 };
345
346 let b = &entities_b[b_idx];
347
348 let old_result = self
350 .groups
351 .get(&key)
352 .map(|acc| acc.finish())
353 .unwrap_or_else(|| (self.default_fn)(b));
354 let old_base = (self.weight_fn)(&old_result);
355 let old = match impact {
356 ImpactType::Penalty => -old_base,
357 ImpactType::Reward => old_base,
358 };
359
360 let acc = self
362 .groups
363 .entry(key.clone())
364 .or_insert_with(|| self.collector.create_accumulator());
365 acc.accumulate(&value);
366
367 let new_result = acc.finish();
369 let new_base = (self.weight_fn)(&new_result);
370 let new_score = match impact {
371 ImpactType::Penalty => -new_base,
372 ImpactType::Reward => new_base,
373 };
374
375 self.entity_groups.insert(entity_index, key);
377 self.entity_values.insert(entity_index, value);
378
379 new_score - old
381 }
382
383 fn retract_entity(&mut self, _entities_a: &[A], _entities_b: &[B], entity_index: usize) -> Sc {
385 let Some(key) = self.entity_groups.remove(&entity_index) else {
387 return Sc::zero();
388 };
389
390 let Some(value) = self.entity_values.remove(&entity_index) else {
392 return Sc::zero();
393 };
394 let impact = self.impact_type;
395
396 let b_idx = self.b_by_key.get(&key).copied();
398 if b_idx.is_none() {
399 if let Some(acc) = self.groups.get_mut(&key) {
401 acc.retract(&value);
402 }
403 return Sc::zero();
404 }
405
406 let Some(acc) = self.groups.get_mut(&key) else {
408 return Sc::zero();
409 };
410
411 let old_result = acc.finish();
413 let old_base = (self.weight_fn)(&old_result);
414 let old = match impact {
415 ImpactType::Penalty => -old_base,
416 ImpactType::Reward => old_base,
417 };
418
419 acc.retract(&value);
421
422 let new_result = acc.finish();
424 let new_base = (self.weight_fn)(&new_result);
425 let new_score = match impact {
426 ImpactType::Penalty => -new_base,
427 ImpactType::Reward => new_base,
428 };
429
430 new_score - old
432 }
433}
434
435impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> std::fmt::Debug
436 for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
437where
438 C: UniCollector<A>,
439 Sc: Score,
440{
441 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
442 f.debug_struct("ComplementedGroupConstraint")
443 .field("name", &self.constraint_ref.name)
444 .field("impact_type", &self.impact_type)
445 .field("groups", &self.groups.len())
446 .finish()
447 }
448}