1use std::hash::Hash;
2use std::marker::PhantomData;
3use std::slice;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collection_extract::{ChangeSource, TrackedCollectionExtract};
10use crate::stream::filter::UniFilter;
11use crate::stream::{ExistenceMode, FlattenExtract};
12
13mod key_state;
14
15use key_state::ExistsKeyState;
16#[cfg(test)]
17pub(crate) use key_state::ExistsStorageKind;
18
19#[derive(Debug, Clone)]
20struct ASlot<K, Sc>
21where
22 Sc: Score,
23{
24 key: Option<K>,
25 bucket_pos: usize,
26 contribution: Sc,
27}
28
29impl<K, Sc> Default for ASlot<K, Sc>
30where
31 Sc: Score,
32{
33 fn default() -> Self {
34 Self {
35 key: None,
36 bucket_pos: 0,
37 contribution: Sc::zero(),
38 }
39 }
40}
41
42pub struct IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
43where
44 Sc: Score,
45{
46 constraint_ref: ConstraintRef,
47 impact_type: ImpactType,
48 mode: ExistenceMode,
49 extractor_a: EA,
50 extractor_parent: EP,
51 key_a: KA,
52 key_b: KB,
53 filter_a: FA,
54 filter_parent: FP,
55 flatten: Flatten,
56 weight: W,
57 is_hard: bool,
58 a_source: ChangeSource,
59 parent_source: ChangeSource,
60 a_slots: Vec<ASlot<K, Sc>>,
61 key_state: ExistsKeyState<K>,
62 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> P, fn() -> B)>,
63}
64
65impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
66 IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
67where
68 S: 'static,
69 A: Clone + 'static,
70 P: Clone + 'static,
71 B: Clone + 'static,
72 K: Eq + Hash + Clone + 'static,
73 EA: TrackedCollectionExtract<S, Item = A>,
74 EP: TrackedCollectionExtract<S, Item = P>,
75 KA: Fn(&A) -> K,
76 KB: Fn(&B) -> K,
77 FA: UniFilter<S, A>,
78 FP: UniFilter<S, P>,
79 Flatten: FlattenExtract<P, Item = B>,
80 W: Fn(&A) -> Sc,
81 Sc: Score,
82{
83 #[allow(clippy::too_many_arguments)]
84 pub fn new(
85 constraint_ref: ConstraintRef,
86 impact_type: ImpactType,
87 mode: ExistenceMode,
88 extractor_a: EA,
89 extractor_parent: EP,
90 key_a: KA,
91 key_b: KB,
92 filter_a: FA,
93 filter_parent: FP,
94 flatten: Flatten,
95 weight: W,
96 is_hard: bool,
97 ) -> Self {
98 let a_source = extractor_a.change_source();
99 let parent_source = extractor_parent.change_source();
100 Self {
101 constraint_ref,
102 impact_type,
103 mode,
104 extractor_a,
105 extractor_parent,
106 key_a,
107 key_b,
108 filter_a,
109 filter_parent,
110 flatten,
111 weight,
112 is_hard,
113 a_source,
114 parent_source,
115 a_slots: Vec::new(),
116 key_state: ExistsKeyState::new(),
117 _phantom: PhantomData,
118 }
119 }
120
121 #[cfg(test)]
122 pub(crate) fn storage_kind(&self) -> ExistsStorageKind {
123 self.key_state.storage_kind()
124 }
125
126 #[inline]
127 fn compute_score(&self, a: &A) -> Sc {
128 let base = (self.weight)(a);
129 match self.impact_type {
130 ImpactType::Penalty => -base,
131 ImpactType::Reward => base,
132 }
133 }
134
135 #[inline]
136 fn matches_existence(&self, key: &K) -> bool {
137 self.matches_count(self.key_state.b_count(key))
138 }
139
140 #[inline]
141 fn matches_count(&self, count: usize) -> bool {
142 match self.mode {
143 ExistenceMode::Exists => count > 0,
144 ExistenceMode::NotExists => count == 0,
145 }
146 }
147
148 fn rebuild_b_counts(&mut self, solution: &S) {
149 self.key_state.clear_b_counts();
150 for parent in self.extractor_parent.extract(solution) {
151 if !self.filter_parent.test(solution, parent) {
152 continue;
153 }
154 for item in self.flatten.extract(parent) {
155 let key = (self.key_b)(item);
156 self.key_state.increment_b_count(&key, 1);
157 }
158 }
159 }
160
161 fn remove_a_from_bucket(&mut self, idx: usize, key: &K, bucket_pos: usize) {
162 if let Some(moved) = self.key_state.remove_a_index(key, idx, bucket_pos) {
163 self.a_slots[moved.idx].bucket_pos = moved.bucket_pos;
164 }
165 }
166
167 fn retract_a(&mut self, idx: usize) -> Sc {
168 if idx >= self.a_slots.len() {
169 return Sc::zero();
170 }
171 let slot = self.a_slots[idx].clone();
172 let Some(key) = slot.key.clone() else {
173 return Sc::zero();
174 };
175 self.remove_a_from_bucket(idx, &key, slot.bucket_pos);
176 self.a_slots[idx] = ASlot::default();
177 -slot.contribution
178 }
179
180 fn insert_a(&mut self, solution: &S, idx: usize) -> Sc {
181 let entities_a = self.extractor_a.extract(solution);
182 if idx >= entities_a.len() {
183 return Sc::zero();
184 }
185 if self.a_slots.len() < entities_a.len() {
186 self.a_slots.resize(entities_a.len(), ASlot::default());
187 }
188
189 let a = &entities_a[idx];
190 if !self.filter_a.test(solution, a) {
191 self.a_slots[idx] = ASlot::default();
192 return Sc::zero();
193 }
194
195 let key = (self.key_a)(a);
196 let bucket_pos = self.key_state.insert_a_index(key.clone(), idx);
197
198 let contribution = if self.matches_existence(&key) {
199 self.compute_score(a)
200 } else {
201 Sc::zero()
202 };
203
204 self.a_slots[idx] = ASlot {
205 key: Some(key),
206 bucket_pos,
207 contribution,
208 };
209 contribution
210 }
211
212 fn reevaluate_key(&mut self, solution: &S, key: &K) -> Sc {
213 let indices = self.key_state.a_indices(key);
214 if indices.is_empty() {
215 return Sc::zero();
216 };
217 let entities_a = self.extractor_a.extract(solution);
218 let mut total = Sc::zero();
219 let exists = self.matches_existence(key);
220
221 for idx in indices {
222 let a = &entities_a[idx];
223 let new_contribution = if exists {
224 self.compute_score(a)
225 } else {
226 Sc::zero()
227 };
228 let old_contribution = self.a_slots[idx].contribution;
229 self.a_slots[idx].contribution = new_contribution;
230 total = total + (new_contribution - old_contribution);
231 }
232
233 total
234 }
235
236 fn update_key_counts(&mut self, solution: &S, key_counts: &[(K, usize)], insert: bool) -> Sc {
237 let mut total = Sc::zero();
238
239 for (key, count) in key_counts {
240 if insert {
241 self.key_state.increment_b_count(key, *count);
242 } else {
243 self.key_state.decrement_b_count(key, *count);
244 }
245 }
246
247 for (key, _) in key_counts {
248 total = total + self.reevaluate_key(solution, key);
249 }
250
251 total
252 }
253
254 fn parent_key_counts(&self, solution: &S, idx: usize) -> Vec<(K, usize)> {
255 let parents = self.extractor_parent.extract(solution);
256 if idx >= parents.len() {
257 return Vec::new();
258 }
259 let parent = &parents[idx];
260 if !self.filter_parent.test(solution, parent) {
261 return Vec::new();
262 }
263
264 let mut key_counts = Vec::<(K, usize)>::new();
265 for item in self.flatten.extract(parent) {
266 let key = (self.key_b)(item);
267 if let Some((_, count)) = key_counts
268 .iter_mut()
269 .find(|(existing_key, _)| existing_key == &key)
270 {
271 *count += 1;
272 } else {
273 key_counts.push((key, 1));
274 }
275 }
276 key_counts
277 }
278
279 fn initialize_a_state(&mut self, solution: &S) -> Sc {
280 self.a_slots.clear();
281 self.key_state.clear_a_buckets();
282
283 let len = self.extractor_a.extract(solution).len();
284 self.a_slots.resize(len, ASlot::default());
285
286 let mut total = Sc::zero();
287 for idx in 0..len {
288 total = total + self.insert_a(solution, idx);
289 }
290 total
291 }
292
293 fn build_b_counts(&self, solution: &S) -> ExistsKeyState<K> {
294 let mut key_state = ExistsKeyState::new();
295 for parent in self.extractor_parent.extract(solution) {
296 if !self.filter_parent.test(solution, parent) {
297 continue;
298 }
299 for item in self.flatten.extract(parent) {
300 let key = (self.key_b)(item);
301 key_state.increment_b_count(&key, 1);
302 }
303 }
304 key_state
305 }
306
307 fn full_match_count(&self, solution: &S) -> usize {
308 let key_state = self.build_b_counts(solution);
309
310 self.extractor_a
311 .extract(solution)
312 .iter()
313 .filter(|a| {
314 if !self.filter_a.test(solution, a) {
315 return false;
316 }
317 let key = (self.key_a)(a);
318 self.matches_count(key_state.b_count(&key))
319 })
320 .count()
321 }
322}
323
324impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc> IncrementalConstraint<S, Sc>
325 for IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
326where
327 S: Send + Sync + 'static,
328 A: Clone + Send + Sync + 'static,
329 P: Clone + Send + Sync + 'static,
330 B: Clone + Send + Sync + 'static,
331 K: Eq + Hash + Clone + Send + Sync + 'static,
332 EA: TrackedCollectionExtract<S, Item = A> + Send + Sync,
333 EP: TrackedCollectionExtract<S, Item = P> + Send + Sync,
334 KA: Fn(&A) -> K + Send + Sync,
335 KB: Fn(&B) -> K + Send + Sync,
336 FA: UniFilter<S, A> + Send + Sync,
337 FP: UniFilter<S, P> + Send + Sync,
338 Flatten: FlattenExtract<P, Item = B> + Send + Sync,
339 W: Fn(&A) -> Sc + Send + Sync,
340 Sc: Score,
341{
342 fn evaluate(&self, solution: &S) -> Sc {
343 let key_state = self.build_b_counts(solution);
344
345 let mut total = Sc::zero();
346 for a in self.extractor_a.extract(solution) {
347 if !self.filter_a.test(solution, a) {
348 continue;
349 }
350 let key = (self.key_a)(a);
351 if self.matches_count(key_state.b_count(&key)) {
352 total = total + self.compute_score(a);
353 }
354 }
355 total
356 }
357
358 fn match_count(&self, solution: &S) -> usize {
359 self.full_match_count(solution)
360 }
361
362 fn initialize(&mut self, solution: &S) -> Sc {
363 self.reset();
364 self.rebuild_b_counts(solution);
365 self.initialize_a_state(solution)
366 }
367
368 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
369 let a_changed =
370 matches!(self.a_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
371 let parent_changed =
372 matches!(self.parent_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
373 let same_source = self.a_source == self.parent_source && a_changed && parent_changed;
374
375 let mut total = Sc::zero();
376 if same_source {
377 let keys = self.parent_key_counts(solution, entity_index);
378 total = total + self.update_key_counts(solution, &keys, true);
379 total = total + self.insert_a(solution, entity_index);
380 return total;
381 }
382
383 if parent_changed {
384 let keys = self.parent_key_counts(solution, entity_index);
385 total = total + self.update_key_counts(solution, &keys, true);
386 }
387 if a_changed {
388 total = total + self.insert_a(solution, entity_index);
389 }
390 total
391 }
392
393 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
394 let a_changed =
395 matches!(self.a_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
396 let parent_changed =
397 matches!(self.parent_source, ChangeSource::Descriptor(idx) if idx == descriptor_index);
398 let same_source = self.a_source == self.parent_source && a_changed && parent_changed;
399
400 let mut total = Sc::zero();
401 if same_source {
402 let keys = self.parent_key_counts(solution, entity_index);
403 total = total + self.retract_a(entity_index);
404 total = total + self.update_key_counts(solution, &keys, false);
405 return total;
406 }
407
408 if a_changed {
409 total = total + self.retract_a(entity_index);
410 }
411 if parent_changed {
412 let keys = self.parent_key_counts(solution, entity_index);
413 total = total + self.update_key_counts(solution, &keys, false);
414 }
415 total
416 }
417
418 fn reset(&mut self) {
419 self.a_slots.clear();
420 self.key_state.clear_a_buckets();
421 self.key_state.clear_b_counts();
422 }
423
424 fn name(&self) -> &str {
425 &self.constraint_ref.name
426 }
427
428 fn is_hard(&self) -> bool {
429 self.is_hard
430 }
431
432 fn constraint_ref(&self) -> ConstraintRef {
433 self.constraint_ref.clone()
434 }
435}
436
/// Stateless marker type; its `FlattenExtract` implementation exposes a parent as its
/// own single item.
#[derive(Clone, Copy, Debug, Default)]
pub struct SelfFlatten;
439
440impl<T> FlattenExtract<T> for SelfFlatten
441where
442 T: Send + Sync,
443{
444 type Item = T;
445
446 fn extract<'a>(&self, parent: &'a T) -> &'a [T] {
447 slice::from_ref(parent)
448 }
449}