1use std::collections::HashMap;
2use std::hash::Hash;
3use std::marker::PhantomData;
4use std::slice;
5
6use solverforge_core::score::Score;
7use solverforge_core::{ConstraintRef, ImpactType};
8
9use crate::api::constraint_set::IncrementalConstraint;
10use crate::stream::collection_extract::{ChangeSource, TrackedCollectionExtract};
11use crate::stream::filter::UniFilter;
12use crate::stream::{ExistenceMode, FlattenExtract};
13
14#[derive(Debug, Clone)]
15struct ASlot<K, Sc>
16where
17 Sc: Score,
18{
19 key: Option<K>,
20 bucket_pos: usize,
21 contribution: Sc,
22}
23
24impl<K, Sc> Default for ASlot<K, Sc>
25where
26 Sc: Score,
27{
28 fn default() -> Self {
29 Self {
30 key: None,
31 bucket_pos: 0,
32 contribution: Sc::zero(),
33 }
34 }
35}
36
37pub struct IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
38where
39 Sc: Score,
40{
41 constraint_ref: ConstraintRef,
42 impact_type: ImpactType,
43 mode: ExistenceMode,
44 extractor_a: EA,
45 extractor_parent: EP,
46 key_a: KA,
47 key_b: KB,
48 filter_a: FA,
49 filter_parent: FP,
50 flatten: Flatten,
51 weight: W,
52 is_hard: bool,
53 a_source: ChangeSource,
54 parent_source: ChangeSource,
55 a_slots: Vec<ASlot<K, Sc>>,
56 a_indices_by_key: HashMap<K, Vec<usize>>,
57 b_key_counts: HashMap<K, usize>,
58 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> P, fn() -> B)>,
59}
60
61impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
62 IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
63where
64 S: 'static,
65 A: Clone + 'static,
66 P: Clone + 'static,
67 B: Clone + 'static,
68 K: Eq + Hash + Clone,
69 EA: TrackedCollectionExtract<S, Item = A>,
70 EP: TrackedCollectionExtract<S, Item = P>,
71 KA: Fn(&A) -> K,
72 KB: Fn(&B) -> K,
73 FA: UniFilter<S, A>,
74 FP: UniFilter<S, P>,
75 Flatten: FlattenExtract<P, Item = B>,
76 W: Fn(&A) -> Sc,
77 Sc: Score,
78{
79 #[allow(clippy::too_many_arguments)]
80 pub fn new(
81 constraint_ref: ConstraintRef,
82 impact_type: ImpactType,
83 mode: ExistenceMode,
84 extractor_a: EA,
85 extractor_parent: EP,
86 key_a: KA,
87 key_b: KB,
88 filter_a: FA,
89 filter_parent: FP,
90 flatten: Flatten,
91 weight: W,
92 is_hard: bool,
93 ) -> Self {
94 let a_source = extractor_a.change_source();
95 let parent_source = extractor_parent.change_source();
96 Self {
97 constraint_ref,
98 impact_type,
99 mode,
100 extractor_a,
101 extractor_parent,
102 key_a,
103 key_b,
104 filter_a,
105 filter_parent,
106 flatten,
107 weight,
108 is_hard,
109 a_source,
110 parent_source,
111 a_slots: Vec::new(),
112 a_indices_by_key: HashMap::new(),
113 b_key_counts: HashMap::new(),
114 _phantom: PhantomData,
115 }
116 }
117
118 #[inline]
119 fn compute_score(&self, a: &A) -> Sc {
120 let base = (self.weight)(a);
121 match self.impact_type {
122 ImpactType::Penalty => -base,
123 ImpactType::Reward => base,
124 }
125 }
126
127 #[inline]
128 fn matches_existence(&self, key: &K) -> bool {
129 let count = self.b_key_counts.get(key).copied().unwrap_or(0);
130 match self.mode {
131 ExistenceMode::Exists => count > 0,
132 ExistenceMode::NotExists => count == 0,
133 }
134 }
135
136 fn rebuild_b_counts(&mut self, solution: &S) {
137 self.b_key_counts.clear();
138 for parent in self.extractor_parent.extract(solution) {
139 if !self.filter_parent.test(solution, parent) {
140 continue;
141 }
142 for item in self.flatten.extract(parent) {
143 *self.b_key_counts.entry((self.key_b)(item)).or_insert(0) += 1;
144 }
145 }
146 }
147
148 fn remove_a_from_bucket(&mut self, idx: usize, key: &K, bucket_pos: usize) {
149 let mut remove_key = false;
150 if let Some(bucket) = self.a_indices_by_key.get_mut(key) {
151 let removed = bucket.swap_remove(bucket_pos);
152 debug_assert_eq!(removed, idx);
153 if bucket_pos < bucket.len() {
154 let moved_idx = bucket[bucket_pos];
155 self.a_slots[moved_idx].bucket_pos = bucket_pos;
156 }
157 remove_key = bucket.is_empty();
158 }
159 if remove_key {
160 self.a_indices_by_key.remove(key);
161 }
162 }
163
164 fn retract_a(&mut self, idx: usize) -> Sc {
165 if idx >= self.a_slots.len() {
166 return Sc::zero();
167 }
168 let slot = self.a_slots[idx].clone();
169 let Some(key) = slot.key.clone() else {
170 return Sc::zero();
171 };
172 self.remove_a_from_bucket(idx, &key, slot.bucket_pos);
173 self.a_slots[idx] = ASlot::default();
174 -slot.contribution
175 }
176
177 fn insert_a(&mut self, solution: &S, idx: usize) -> Sc {
178 let entities_a = self.extractor_a.extract(solution);
179 if idx >= entities_a.len() {
180 return Sc::zero();
181 }
182 if self.a_slots.len() < entities_a.len() {
183 self.a_slots.resize(entities_a.len(), ASlot::default());
184 }
185
186 let a = &entities_a[idx];
187 if !self.filter_a.test(solution, a) {
188 self.a_slots[idx] = ASlot::default();
189 return Sc::zero();
190 }
191
192 let key = (self.key_a)(a);
193 let bucket = self.a_indices_by_key.entry(key.clone()).or_default();
194 let bucket_pos = bucket.len();
195 bucket.push(idx);
196
197 let contribution = if self.matches_existence(&key) {
198 self.compute_score(a)
199 } else {
200 Sc::zero()
201 };
202
203 self.a_slots[idx] = ASlot {
204 key: Some(key),
205 bucket_pos,
206 contribution,
207 };
208 contribution
209 }
210
211 fn reevaluate_key(&mut self, solution: &S, key: &K) -> Sc {
212 let Some(indices) = self.a_indices_by_key.get(key).cloned() else {
213 return Sc::zero();
214 };
215 let entities_a = self.extractor_a.extract(solution);
216 let mut total = Sc::zero();
217 let exists = self.matches_existence(key);
218
219 for idx in indices {
220 let a = &entities_a[idx];
221 let new_contribution = if exists {
222 self.compute_score(a)
223 } else {
224 Sc::zero()
225 };
226 let old_contribution = self.a_slots[idx].contribution;
227 self.a_slots[idx].contribution = new_contribution;
228 total = total + (new_contribution - old_contribution);
229 }
230
231 total
232 }
233
234 fn update_key_counts(
235 &mut self,
236 solution: &S,
237 key_multiset: &HashMap<K, usize>,
238 insert: bool,
239 ) -> Sc {
240 let mut total = Sc::zero();
241
242 for (key, count) in key_multiset {
243 if insert {
244 *self.b_key_counts.entry(key.clone()).or_insert(0) += *count;
245 } else {
246 let mut remove_key = false;
247 if let Some(entry) = self.b_key_counts.get_mut(key) {
248 *entry = entry.saturating_sub(*count);
249 remove_key = *entry == 0;
250 }
251 if remove_key {
252 self.b_key_counts.remove(key);
253 }
254 }
255 }
256
257 for key in key_multiset.keys() {
258 total = total + self.reevaluate_key(solution, key);
259 }
260
261 total
262 }
263
264 fn parent_key_multiset(&self, solution: &S, idx: usize) -> HashMap<K, usize> {
265 let parents = self.extractor_parent.extract(solution);
266 if idx >= parents.len() {
267 return HashMap::new();
268 }
269 let parent = &parents[idx];
270 if !self.filter_parent.test(solution, parent) {
271 return HashMap::new();
272 }
273
274 let mut multiset = HashMap::new();
275 for item in self.flatten.extract(parent) {
276 *multiset.entry((self.key_b)(item)).or_insert(0) += 1;
277 }
278 multiset
279 }
280
281 fn initialize_a_state(&mut self, solution: &S) -> Sc {
282 self.a_slots.clear();
283 self.a_indices_by_key.clear();
284
285 let len = self.extractor_a.extract(solution).len();
286 self.a_slots.resize(len, ASlot::default());
287
288 let mut total = Sc::zero();
289 for idx in 0..len {
290 total = total + self.insert_a(solution, idx);
291 }
292 total
293 }
294
295 fn full_match_count(&self, solution: &S) -> usize {
296 let mut key_counts = HashMap::<K, usize>::new();
297 for parent in self.extractor_parent.extract(solution) {
298 if !self.filter_parent.test(solution, parent) {
299 continue;
300 }
301 for item in self.flatten.extract(parent) {
302 *key_counts.entry((self.key_b)(item)).or_insert(0) += 1;
303 }
304 }
305
306 self.extractor_a
307 .extract(solution)
308 .iter()
309 .filter(|a| {
310 self.filter_a.test(solution, a)
311 && match self.mode {
312 ExistenceMode::Exists => {
313 key_counts.get(&(self.key_a)(a)).copied().unwrap_or(0) > 0
314 }
315 ExistenceMode::NotExists => {
316 key_counts.get(&(self.key_a)(a)).copied().unwrap_or(0) == 0
317 }
318 }
319 })
320 .count()
321 }
322}
323
324impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc> IncrementalConstraint<S, Sc>
325 for IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
326where
327 S: Send + Sync + 'static,
328 A: Clone + Send + Sync + 'static,
329 P: Clone + Send + Sync + 'static,
330 B: Clone + Send + Sync + 'static,
331 K: Eq + Hash + Clone + Send + Sync,
332 EA: TrackedCollectionExtract<S, Item = A> + Send + Sync,
333 EP: TrackedCollectionExtract<S, Item = P> + Send + Sync,
334 KA: Fn(&A) -> K + Send + Sync,
335 KB: Fn(&B) -> K + Send + Sync,
336 FA: UniFilter<S, A> + Send + Sync,
337 FP: UniFilter<S, P> + Send + Sync,
338 Flatten: FlattenExtract<P, Item = B> + Send + Sync,
339 W: Fn(&A) -> Sc + Send + Sync,
340 Sc: Score,
341{
342 fn evaluate(&self, solution: &S) -> Sc {
343 let mut counts = HashMap::<K, usize>::new();
344 for parent in self.extractor_parent.extract(solution) {
345 if !self.filter_parent.test(solution, parent) {
346 continue;
347 }
348 for item in self.flatten.extract(parent) {
349 *counts.entry((self.key_b)(item)).or_insert(0) += 1;
350 }
351 }
352
353 let mut total = Sc::zero();
354 for a in self.extractor_a.extract(solution) {
355 if !self.filter_a.test(solution, a) {
356 continue;
357 }
358 let key = (self.key_a)(a);
359 let matches = match self.mode {
360 ExistenceMode::Exists => counts.get(&key).copied().unwrap_or(0) > 0,
361 ExistenceMode::NotExists => counts.get(&key).copied().unwrap_or(0) == 0,
362 };
363 if matches {
364 total = total + self.compute_score(a);
365 }
366 }
367 total
368 }
369
370 fn match_count(&self, solution: &S) -> usize {
371 self.full_match_count(solution)
372 }
373
374 fn initialize(&mut self, solution: &S) -> Sc {
375 self.reset();
376 self.rebuild_b_counts(solution);
377 self.initialize_a_state(solution)
378 }
379
380 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
381 let a_changed =
382 matches!(self.a_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
383 let parent_changed =
384 matches!(self.parent_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
385 let same_source = self.a_source == self.parent_source && a_changed && parent_changed;
386
387 let mut total = Sc::zero();
388 if same_source {
389 let keys = self.parent_key_multiset(solution, entity_index);
390 total = total + self.update_key_counts(solution, &keys, true);
391 total = total + self.insert_a(solution, entity_index);
392 return total;
393 }
394
395 if parent_changed {
396 let keys = self.parent_key_multiset(solution, entity_index);
397 total = total + self.update_key_counts(solution, &keys, true);
398 }
399 if a_changed {
400 total = total + self.insert_a(solution, entity_index);
401 }
402 total
403 }
404
405 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
406 let a_changed =
407 matches!(self.a_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
408 let parent_changed =
409 matches!(self.parent_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
410 let same_source = self.a_source == self.parent_source && a_changed && parent_changed;
411
412 let mut total = Sc::zero();
413 if same_source {
414 let keys = self.parent_key_multiset(solution, entity_index);
415 total = total + self.retract_a(entity_index);
416 total = total + self.update_key_counts(solution, &keys, false);
417 return total;
418 }
419
420 if a_changed {
421 total = total + self.retract_a(entity_index);
422 }
423 if parent_changed {
424 let keys = self.parent_key_multiset(solution, entity_index);
425 total = total + self.update_key_counts(solution, &keys, false);
426 }
427 total
428 }
429
430 fn reset(&mut self) {
431 self.a_slots.clear();
432 self.a_indices_by_key.clear();
433 self.b_key_counts.clear();
434 }
435
436 fn name(&self) -> &str {
437 &self.constraint_ref.name
438 }
439
440 fn is_hard(&self) -> bool {
441 self.is_hard
442 }
443
444 fn constraint_ref(&self) -> ConstraintRef {
445 self.constraint_ref.clone()
446 }
447}
448
/// Flattening strategy that presents each parent as a one-element slice containing itself,
/// so the flattened item type equals the parent type.
#[derive(Clone, Copy, Debug, Default)]
pub struct SelfFlatten;
451
452impl<T> FlattenExtract<T> for SelfFlatten
453where
454 T: Send + Sync,
455{
456 type Item = T;
457
458 fn extract<'a>(&self, parent: &'a T) -> &'a [T] {
459 slice::from_ref(parent)
460 }
461}