1use std::hash::Hash;
2use std::marker::PhantomData;
3use std::slice;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collection_extract::{ChangeSource, CollectionExtract};
10use crate::stream::filter::UniFilter;
11use crate::stream::{ExistenceMode, FlattenExtract};
12
13mod key_state;
14
15use key_state::ExistsKeyState;
16#[cfg(test)]
17pub(crate) use key_state::ExistsStorageKind;
18
/// Per-A-entity bookkeeping slot for the incremental exists constraint.
///
/// One slot is kept per index of the A collection; `key: None` marks a
/// vacant slot (the entity was filtered out or has been retracted).
#[derive(Debug, Clone)]
struct ASlot<K, Sc>
where
    Sc: Score,
{
    // Join key of the A entity, or `None` when the slot is unoccupied.
    key: Option<K>,
    // Position of this entity's index inside the key's A bucket, kept so the
    // bucket entry can be located for removal (presumably swap-remove
    // semantics — `remove_a_index` reports a moved entry whose position is
    // patched up here; confirm against `key_state`).
    bucket_pos: usize,
    // Cached, impact-sign-adjusted score contribution of this A entity
    // (the value produced by `compute_score`).
    score: Sc,
}
28
29impl<K, Sc> Default for ASlot<K, Sc>
30where
31 Sc: Score,
32{
33 fn default() -> Self {
34 Self {
35 key: None,
36 bucket_pos: 0,
37 score: Sc::zero(),
38 }
39 }
40}
41
/// Incrementally maintained `exists` / `not exists` constraint.
///
/// For every entity `A` (keyed by `KA`) the constraint checks whether at
/// least one — or, for [`ExistenceMode::NotExists`], no — `B` item with the
/// same key is reachable by flattening the filtered parent collection `P`.
/// Score contributions are cached per A slot and aggregated per key so that
/// inserts and retracts only touch the affected keys.
pub struct IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
where
    Sc: Score,
{
    constraint_ref: ConstraintRef,
    impact_type: ImpactType,
    // Exists vs NotExists semantics (see `matches_count`).
    mode: ExistenceMode,
    // Extracts the A collection from the solution.
    extractor_a: EA,
    // Extracts the parent (P) collection from the solution.
    extractor_parent: EP,
    // Join-key functions for A entities and flattened B items.
    key_a: KA,
    key_b: KB,
    // Pre-join filters on A entities and parent entities.
    filter_a: FA,
    filter_parent: FP,
    // Flattens one parent P into its B items.
    flatten: Flatten,
    // Raw (unsigned) weight of an A entity; sign is applied per `impact_type`.
    weight: W,
    is_hard: bool,
    // Which descriptor indices localize changes to A / parent collections.
    a_source: ChangeSource,
    parent_source: ChangeSource,
    // One slot per A index; see `ASlot`.
    a_slots: Vec<ASlot<K, Sc>>,
    // Per-key state: B counts, A index buckets, and A score totals.
    key_state: ExistsKeyState<K, Sc>,
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> P, fn() -> B)>,
}
64
impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
    IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
where
    S: 'static,
    A: Clone + 'static,
    P: Clone + 'static,
    B: Clone + 'static,
    K: Eq + Hash + Clone + 'static,
    EA: CollectionExtract<S, Item = A>,
    EP: CollectionExtract<S, Item = P>,
    KA: Fn(&A) -> K,
    KB: Fn(&B) -> K,
    FA: UniFilter<S, A>,
    FP: UniFilter<S, P>,
    Flatten: FlattenExtract<P, Item = B>,
    W: Fn(&A) -> Sc,
    Sc: Score,
{
    /// Creates an empty constraint; incremental state is built lazily by
    /// `initialize` (see the `IncrementalConstraint` impl below).
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        constraint_ref: ConstraintRef,
        impact_type: ImpactType,
        mode: ExistenceMode,
        extractor_a: EA,
        extractor_parent: EP,
        key_a: KA,
        key_b: KB,
        filter_a: FA,
        filter_parent: FP,
        flatten: Flatten,
        weight: W,
        is_hard: bool,
    ) -> Self {
        // Change sources are captured once so deltas can be routed by
        // descriptor index without re-querying the extractors.
        let a_source = extractor_a.change_source();
        let parent_source = extractor_parent.change_source();
        Self {
            constraint_ref,
            impact_type,
            mode,
            extractor_a,
            extractor_parent,
            key_a,
            key_b,
            filter_a,
            filter_parent,
            flatten,
            weight,
            is_hard,
            a_source,
            parent_source,
            a_slots: Vec::new(),
            key_state: ExistsKeyState::new(),
            _phantom: PhantomData,
        }
    }

    /// Test-only peek at the key-state storage representation.
    #[cfg(test)]
    pub(crate) fn storage_kind(&self) -> ExistsStorageKind {
        self.key_state.storage_kind()
    }

    /// Applies the impact sign to the raw weight: penalties contribute the
    /// negated weight, rewards the weight as-is.
    #[inline]
    fn compute_score(&self, a: &A) -> Sc {
        let base = (self.weight)(a);
        match self.impact_type {
            ImpactType::Penalty => -base,
            ImpactType::Reward => base,
        }
    }

    /// Whether an A entity with `key` currently matches, given the cached
    /// B count for that key.
    #[inline]
    fn matches_existence(&self, key: &K) -> bool {
        self.matches_count(self.key_state.b_count(key))
    }

    /// Exists matches when at least one B shares the key; NotExists when none do.
    #[inline]
    fn matches_count(&self, count: usize) -> bool {
        match self.mode {
            ExistenceMode::Exists => count > 0,
            ExistenceMode::NotExists => count == 0,
        }
    }

    /// Recomputes all per-key B counts in `key_state` from scratch
    /// (filtered parents, flattened into B items).
    fn rebuild_b_counts(&mut self, solution: &S) {
        self.key_state.clear_b_counts();
        for parent in self.extractor_parent.extract(solution) {
            if !self.filter_parent.test(solution, parent) {
                continue;
            }
            for item in self.flatten.extract(parent) {
                let key = (self.key_b)(item);
                self.key_state.increment_b_count(&key, 1);
            }
        }
    }

    /// Removes A index `idx` from its key bucket. If another index was moved
    /// into the vacated bucket position, its slot's cached `bucket_pos` is
    /// patched to stay consistent.
    fn remove_a_from_bucket(&mut self, idx: usize, key: &K, bucket_pos: usize) {
        if let Some(moved) = self.key_state.remove_a_index(key, idx, bucket_pos) {
            self.a_slots[moved.idx].bucket_pos = moved.bucket_pos;
        }
    }

    /// Retracts the A entity at `idx`, returning the score delta
    /// (i.e. minus its current contribution; zero if the slot is out of
    /// range or already vacant).
    fn retract_a(&mut self, idx: usize) -> Sc {
        if idx >= self.a_slots.len() {
            return Sc::zero();
        }
        let bucket_pos = self.a_slots[idx].bucket_pos;
        let score = self.a_slots[idx].score;
        // A vacant slot (filtered-out entity) contributes nothing.
        let Some(key) = self.a_slots[idx].key.take() else {
            return Sc::zero();
        };
        // Contribution is evaluated against the *current* B counts, which
        // must match the counts in effect when the slot last contributed.
        let contribution = if self.matches_existence(&key) {
            score
        } else {
            Sc::zero()
        };
        self.remove_a_from_bucket(idx, &key, bucket_pos);
        self.key_state.subtract_a_score(&key, score);
        self.a_slots[idx] = ASlot::default();
        -contribution
    }

    /// Inserts the A entity at `idx`, returning the score delta (its new
    /// contribution, or zero if filtered out / out of range). Grows the
    /// slot table as needed.
    fn insert_a(&mut self, solution: &S, idx: usize) -> Sc {
        let entities_a = self.extractor_a.extract(solution);
        if idx >= entities_a.len() {
            return Sc::zero();
        }
        if self.a_slots.len() < entities_a.len() {
            self.a_slots.resize(entities_a.len(), ASlot::default());
        }

        let a = &entities_a[idx];
        if !self.filter_a.test(solution, a) {
            // Filtered entities occupy a vacant slot so indices stay aligned.
            self.a_slots[idx] = ASlot::default();
            return Sc::zero();
        }

        let key = (self.key_a)(a);
        let bucket_pos = self.key_state.insert_a_index(key.clone(), idx);
        let score = self.compute_score(a);
        // Per-key score total feeds `key_existence_delta` when B counts flip.
        self.key_state.add_a_score(&key, score);

        let contribution = if self.matches_existence(&key) {
            score
        } else {
            Sc::zero()
        };

        self.a_slots[idx] = ASlot {
            key: Some(key),
            bucket_pos,
            score,
        };
        contribution
    }

    /// Score delta for `key` when its B count changes from `old_count` to
    /// `new_count`: the key's total A score is added (or removed) only when
    /// the existence predicate flips.
    fn key_existence_delta(&self, key: &K, old_count: usize, new_count: usize) -> Sc {
        let old_matches = self.matches_count(old_count);
        let new_matches = self.matches_count(new_count);
        if old_matches == new_matches {
            Sc::zero()
        } else if new_matches {
            self.key_state.a_score_total(key)
        } else {
            -self.key_state.a_score_total(key)
        }
    }

    /// Applies a batch of per-key B-count increments (`insert == true`) or
    /// decrements, returning the accumulated existence-flip score delta.
    fn update_key_counts(&mut self, key_counts: &[(K, usize)], insert: bool) -> Sc {
        let mut total = Sc::zero();

        for (key, count) in key_counts {
            let old_count = self.key_state.b_count(key);
            if insert {
                self.key_state.increment_b_count(key, *count);
            } else {
                self.key_state.decrement_b_count(key, *count);
            }
            total = total + self.key_existence_delta(key, old_count, self.key_state.b_count(key));
        }

        total
    }

    /// Collects the (key, occurrence-count) pairs contributed by the parent
    /// at `idx`; empty if the index is out of range or the parent is
    /// filtered out.
    fn parent_key_counts(&self, solution: &S, idx: usize) -> Vec<(K, usize)> {
        let parents = self.extractor_parent.extract(solution);
        if idx >= parents.len() {
            return Vec::new();
        }
        let parent = &parents[idx];
        if !self.filter_parent.test(solution, parent) {
            return Vec::new();
        }

        // Linear scan is fine here: one parent typically yields few distinct keys.
        let mut key_counts = Vec::<(K, usize)>::new();
        for item in self.flatten.extract(parent) {
            let key = (self.key_b)(item);
            if let Some((_, count)) = key_counts
                .iter_mut()
                .find(|(existing_key, _)| existing_key == &key)
            {
                *count += 1;
            } else {
                key_counts.push((key, 1));
            }
        }
        key_counts
    }

    /// Rebuilds all A-side state (slots, buckets, score totals) from the
    /// solution, returning the total contribution. Assumes B counts are
    /// already up to date (see `initialize`).
    fn initialize_a_state(&mut self, solution: &S) -> Sc {
        self.a_slots.clear();
        self.key_state.clear_a_buckets();

        let len = self.extractor_a.extract(solution).len();
        self.a_slots.resize(len, ASlot::default());

        let mut total = Sc::zero();
        for idx in 0..len {
            total = total + self.insert_a(solution, idx);
        }
        total
    }

    /// Builds a fresh, standalone key state with B counts only — used by the
    /// non-incremental `evaluate` / `match_count` paths so the incremental
    /// state is left untouched.
    fn build_b_counts(&self, solution: &S) -> ExistsKeyState<K, Sc> {
        let mut key_state = ExistsKeyState::new();
        for parent in self.extractor_parent.extract(solution) {
            if !self.filter_parent.test(solution, parent) {
                continue;
            }
            for item in self.flatten.extract(parent) {
                let key = (self.key_b)(item);
                key_state.increment_b_count(&key, 1);
            }
        }
        key_state
    }

    /// Counts matching A entities by full re-evaluation (no incremental state).
    fn full_match_count(&self, solution: &S) -> usize {
        let key_state = self.build_b_counts(solution);

        self.extractor_a
            .extract(solution)
            .iter()
            .filter(|a| {
                if !self.filter_a.test(solution, a) {
                    return false;
                }
                let key = (self.key_a)(a);
                self.matches_count(key_state.b_count(&key))
            })
            .count()
    }
}
318
impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc> IncrementalConstraint<S, Sc>
    for IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    P: Clone + Send + Sync + 'static,
    B: Clone + Send + Sync + 'static,
    K: Eq + Hash + Clone + Send + Sync + 'static,
    EA: CollectionExtract<S, Item = A> + Send + Sync,
    EP: CollectionExtract<S, Item = P> + Send + Sync,
    KA: Fn(&A) -> K + Send + Sync,
    KB: Fn(&B) -> K + Send + Sync,
    FA: UniFilter<S, A> + Send + Sync,
    FP: UniFilter<S, P> + Send + Sync,
    Flatten: FlattenExtract<P, Item = B> + Send + Sync,
    W: Fn(&A) -> Sc + Send + Sync,
    Sc: Score,
{
    /// Full (non-incremental) evaluation: recomputes B counts and sums the
    /// contributions of all matching, unfiltered A entities. Does not touch
    /// the incremental state.
    fn evaluate(&self, solution: &S) -> Sc {
        let key_state = self.build_b_counts(solution);

        let mut total = Sc::zero();
        for a in self.extractor_a.extract(solution) {
            if !self.filter_a.test(solution, a) {
                continue;
            }
            let key = (self.key_a)(a);
            if self.matches_count(key_state.b_count(&key)) {
                total = total + self.compute_score(a);
            }
        }
        total
    }

    fn match_count(&self, solution: &S) -> usize {
        self.full_match_count(solution)
    }

    /// Resets and rebuilds all incremental state from the solution,
    /// returning the total score. B counts are built first so that
    /// `insert_a` sees correct existence information.
    fn initialize(&mut self, solution: &S) -> Sc {
        self.reset();
        self.rebuild_b_counts(solution);
        self.initialize_a_state(solution)
    }

    /// Processes an entity insertion, returning the score delta. The entity
    /// may act as an A, as a parent (B source), or as both when the two
    /// streams share the same index domain.
    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
        let a_changed = self
            .a_source
            .assert_localizes(descriptor_index, &self.constraint_ref.name);
        let parent_changed = self
            .parent_source
            .assert_localizes(descriptor_index, &self.constraint_ref.name);
        let same_source =
            self.a_source.same_index_domain(self.parent_source) && a_changed && parent_changed;

        let mut total = Sc::zero();
        if same_source {
            // B counts first, so the subsequent insert_a sees the entity's
            // own B items when computing its contribution.
            let keys = self.parent_key_counts(solution, entity_index);
            total = total + self.update_key_counts(&keys, true);
            total = total + self.insert_a(solution, entity_index);
            return total;
        }

        if parent_changed {
            let keys = self.parent_key_counts(solution, entity_index);
            total = total + self.update_key_counts(&keys, true);
        }
        if a_changed {
            total = total + self.insert_a(solution, entity_index);
        }
        total
    }

    /// Processes an entity retraction, returning the score delta. Mirror of
    /// `on_insert` with the same-source ordering reversed: the A side is
    /// retracted while its B items are still counted, then the counts drop.
    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
        let a_changed = self
            .a_source
            .assert_localizes(descriptor_index, &self.constraint_ref.name);
        let parent_changed = self
            .parent_source
            .assert_localizes(descriptor_index, &self.constraint_ref.name);
        let same_source =
            self.a_source.same_index_domain(self.parent_source) && a_changed && parent_changed;

        let mut total = Sc::zero();
        if same_source {
            let keys = self.parent_key_counts(solution, entity_index);
            total = total + self.retract_a(entity_index);
            total = total + self.update_key_counts(&keys, false);
            return total;
        }

        if a_changed {
            total = total + self.retract_a(entity_index);
        }
        if parent_changed {
            let keys = self.parent_key_counts(solution, entity_index);
            total = total + self.update_key_counts(&keys, false);
        }
        total
    }

    /// Drops all incremental state (slots, A buckets, B counts).
    fn reset(&mut self) {
        self.a_slots.clear();
        self.key_state.clear_a_buckets();
        self.key_state.clear_b_counts();
    }

    fn name(&self) -> &str {
        &self.constraint_ref.name
    }

    fn is_hard(&self) -> bool {
        self.is_hard
    }

    fn constraint_ref(&self) -> &ConstraintRef {
        &self.constraint_ref
    }
}
437
/// A [`FlattenExtract`] that treats the parent itself as its single item
/// (i.e. `P == B`), for exists constraints with no child collection.
#[derive(Debug, Clone, Copy, Default)]
pub struct SelfFlatten;
440
441impl<T> FlattenExtract<T> for SelfFlatten
442where
443 T: Send + Sync,
444{
445 type Item = T;
446
447 fn extract<'a>(&self, parent: &'a T) -> &'a [T] {
448 slice::from_ref(parent)
449 }
450}