1use std::collections::{HashMap, HashSet};
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::filter::{BiFilter, UniFilter};
10use crate::stream::{ProjectedRowCoordinate, ProjectedRowOwner, ProjectedSource};
11
12struct ProjectedJoinRow<Out> {
13 output: Out,
14 coordinate: ProjectedRowCoordinate,
15}
16
/// Incremental constraint that pairs rows of a [`ProjectedSource`] by join key.
///
/// Rows accepted by `filter` are retained together with their
/// [`ProjectedRowCoordinate`]. Rows whose `key_fn` values are equal form
/// candidate pairs; a pair accepted by `pair_filter` contributes `weight`,
/// negated when `impact_type` is [`ImpactType::Penalty`].
///
/// Type parameters: `S` solution, `Out` projected row output, `K` join key,
/// `Src` projected source, `F` row filter, `KF` key function, `PF` pair
/// filter, `W` weigher, `Sc` score.
pub struct ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
where
    Src: ProjectedSource<S, Out>,
    Sc: Score,
{
    /// Identity of the constraint; its `name` is also passed to `assert_localizes`.
    constraint_ref: ConstraintRef,
    /// Whether matches are penalized (score negated) or rewarded.
    impact_type: ImpactType,
    /// Produces the projected rows and the per-entity change sources.
    source: Src,
    /// Per-row filter applied before a row is retained.
    filter: F,
    /// Join key; only rows with equal keys are ever paired.
    key_fn: KF,
    /// Per-pair filter applied to same-key pairs.
    pair_filter: PF,
    /// Raw score of a matching pair, before the impact sign is applied.
    weight: W,
    /// Reported through `is_hard`.
    is_hard: bool,
    /// Source state; `None` until `initialize` or the first `on_insert`, and cleared by `reset`.
    source_state: Option<Src::State>,
    /// Row slab indexed by row id; `None` marks a freed slot.
    rows: Vec<Option<ProjectedJoinRow<Out>>>,
    /// Freed slots in `rows`, reused before the slab grows.
    free_row_ids: Vec<usize>,
    /// Row ids indexed under each owner, used to find the rows to retract for an entity.
    rows_by_owner: HashMap<ProjectedRowOwner, Vec<usize>>,
    /// Coordinate to row id; also prevents retaining the same coordinate twice.
    row_ids_by_coordinate: HashMap<ProjectedRowCoordinate, usize>,
    /// Join-key buckets: a row is only scored against rows in its own bucket.
    rows_by_key: HashMap<K, Vec<usize>>,
    /// Marks `S`, `Out` and `Sc` as used without the struct owning values of those types.
    _phantom: PhantomData<(fn() -> S, fn() -> Out, fn() -> Sc)>,
}
38
39impl<S, Out, K, Src, F, KF, PF, W, Sc> ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
40where
41 S: Send + Sync + 'static,
42 Out: Send + Sync + 'static,
43 K: Eq + Hash + Send + Sync + 'static,
44 Src: ProjectedSource<S, Out>,
45 F: UniFilter<S, Out>,
46 KF: Fn(&Out) -> K + Send + Sync,
47 PF: BiFilter<S, Out, Out>,
48 W: Fn(&Out, &Out) -> Sc + Send + Sync,
49 Sc: Score + 'static,
50{
51 #[allow(clippy::too_many_arguments)]
52 pub fn new(
53 constraint_ref: ConstraintRef,
54 impact_type: ImpactType,
55 source: Src,
56 filter: F,
57 key_fn: KF,
58 pair_filter: PF,
59 weight: W,
60 is_hard: bool,
61 ) -> Self {
62 Self {
63 constraint_ref,
64 impact_type,
65 source,
66 filter,
67 key_fn,
68 pair_filter,
69 weight,
70 is_hard,
71 source_state: None,
72 rows: Vec::new(),
73 free_row_ids: Vec::new(),
74 rows_by_owner: HashMap::new(),
75 row_ids_by_coordinate: HashMap::new(),
76 rows_by_key: HashMap::new(),
77 _phantom: PhantomData,
78 }
79 }
80
81 fn compute_score(&self, left: &Out, right: &Out) -> Sc {
82 let base = (self.weight)(left, right);
83 match self.impact_type {
84 ImpactType::Penalty => -base,
85 ImpactType::Reward => base,
86 }
87 }
88
89 fn score_outputs(
90 &self,
91 solution: &S,
92 left: &Out,
93 right: &Out,
94 left_idx: usize,
95 right_idx: usize,
96 ) -> Sc {
97 if !self
98 .pair_filter
99 .test(solution, left, right, left_idx, right_idx)
100 {
101 return Sc::zero();
102 }
103 self.compute_score(left, right)
104 }
105
106 fn filter_index(coordinate: ProjectedRowCoordinate) -> usize {
107 coordinate.primary_owner.entity_index
108 }
109
110 fn score_retained_rows(
111 &self,
112 solution: &S,
113 first: &ProjectedJoinRow<Out>,
114 second: &ProjectedJoinRow<Out>,
115 ) -> Sc {
116 let (left, right) = if first.coordinate <= second.coordinate {
117 (first, second)
118 } else {
119 (second, first)
120 };
121 self.score_outputs(
122 solution,
123 &left.output,
124 &right.output,
125 Self::filter_index(left.coordinate),
126 Self::filter_index(right.coordinate),
127 )
128 }
129
130 fn score_candidate_row(
131 &self,
132 solution: &S,
133 candidate_output: &Out,
134 candidate_coordinate: ProjectedRowCoordinate,
135 other: &ProjectedJoinRow<Out>,
136 ) -> Sc {
137 let (left, right, left_idx, right_idx) = if candidate_coordinate <= other.coordinate {
138 (
139 candidate_output,
140 &other.output,
141 Self::filter_index(candidate_coordinate),
142 Self::filter_index(other.coordinate),
143 )
144 } else {
145 (
146 &other.output,
147 candidate_output,
148 Self::filter_index(other.coordinate),
149 Self::filter_index(candidate_coordinate),
150 )
151 };
152 self.score_outputs(solution, left, right, left_idx, right_idx)
153 }
154
155 fn score_pair(&self, solution: &S, first_id: usize, second_id: usize) -> Sc {
156 let Some(first) = self.rows.get(first_id).and_then(Option::as_ref) else {
157 return Sc::zero();
158 };
159 let Some(second) = self.rows.get(second_id).and_then(Option::as_ref) else {
160 return Sc::zero();
161 };
162 self.score_retained_rows(solution, first, second)
163 }
164
165 fn ensure_source_state(&mut self, solution: &S) {
166 if self.source_state.is_none() {
167 self.source_state = Some(self.source.build_state(solution));
168 }
169 }
170
171 fn index_row_owners(&mut self, coordinate: ProjectedRowCoordinate, row_id: usize) {
172 coordinate.for_each_owner(|owner| {
173 self.rows_by_owner.entry(owner).or_default().push(row_id);
174 });
175 }
176
177 fn unindex_row_owners(&mut self, coordinate: ProjectedRowCoordinate, row_id: usize) {
178 coordinate.for_each_owner(|owner| {
179 let mut remove_bucket = false;
180 if let Some(ids) = self.rows_by_owner.get_mut(&owner) {
181 ids.retain(|candidate| *candidate != row_id);
182 remove_bucket = ids.is_empty();
183 }
184 if remove_bucket {
185 self.rows_by_owner.remove(&owner);
186 }
187 });
188 }
189
190 fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
191 if self.row_ids_by_coordinate.contains_key(&coordinate) {
192 return Sc::zero();
193 }
194 let key = (self.key_fn)(&output);
195 let mut total = Sc::zero();
196 if let Some(existing) = self.rows_by_key.get(&key) {
197 for &other_id in existing {
198 if let Some(other) = self.rows.get(other_id).and_then(Option::as_ref) {
199 total = total + self.score_candidate_row(solution, &output, coordinate, other);
200 }
201 }
202 }
203 let row = Some(ProjectedJoinRow { output, coordinate });
204 let row_id = if let Some(row_id) = self.free_row_ids.pop() {
205 debug_assert!(self.rows[row_id].is_none());
206 self.rows[row_id] = row;
207 row_id
208 } else {
209 let row_id = self.rows.len();
210 self.rows.push(row);
211 row_id
212 };
213 self.row_ids_by_coordinate.insert(coordinate, row_id);
214 self.index_row_owners(coordinate, row_id);
215 self.rows_by_key.entry(key).or_default().push(row_id);
216 total
217 }
218
219 fn retract_row(&mut self, solution: &S, row_id: usize) -> Sc {
220 let Some((key, coordinate)) = self
221 .rows
222 .get(row_id)
223 .and_then(Option::as_ref)
224 .map(|row| ((self.key_fn)(&row.output), row.coordinate))
225 else {
226 return Sc::zero();
227 };
228 let mut total = Sc::zero();
229 if let Some(candidates) = self.rows_by_key.get(&key) {
230 for &other_id in candidates {
231 if other_id == row_id {
232 continue;
233 }
234 total = total - self.score_pair(solution, row_id, other_id);
235 }
236 }
237
238 if let Some(ids) = self.rows_by_key.get_mut(&key) {
239 ids.retain(|&id| id != row_id);
240 if ids.is_empty() {
241 self.rows_by_key.remove(&key);
242 }
243 }
244 self.row_ids_by_coordinate.remove(&coordinate);
245 self.unindex_row_owners(coordinate, row_id);
246 self.rows[row_id] = None;
247 self.free_row_ids.push(row_id);
248 total
249 }
250
251 fn evaluate_rows(&self, solution: &S) -> Vec<ProjectedJoinRow<Out>> {
252 let state = self.source.build_state(solution);
253 let mut rows = Vec::new();
254 self.source
255 .collect_all(solution, &state, |coordinate, output| {
256 if self.filter.test(solution, &output) {
257 rows.push(ProjectedJoinRow { output, coordinate });
258 }
259 });
260 rows
261 }
262
263 fn score_evaluation_pair(
264 &self,
265 solution: &S,
266 first: &ProjectedJoinRow<Out>,
267 second: &ProjectedJoinRow<Out>,
268 ) -> Sc {
269 if (self.key_fn)(&first.output) == (self.key_fn)(&second.output) {
270 let (left, right) = if first.coordinate <= second.coordinate {
271 (first, second)
272 } else {
273 (second, first)
274 };
275 self.score_outputs(
276 solution,
277 &left.output,
278 &right.output,
279 Self::filter_index(left.coordinate),
280 Self::filter_index(right.coordinate),
281 )
282 } else {
283 Sc::zero()
284 }
285 }
286
287 fn evaluation_pair_matches(
288 &self,
289 solution: &S,
290 first: &ProjectedJoinRow<Out>,
291 second: &ProjectedJoinRow<Out>,
292 ) -> bool {
293 if (self.key_fn)(&first.output) != (self.key_fn)(&second.output) {
294 return false;
295 }
296 let (left, right) = if first.coordinate <= second.coordinate {
297 (first, second)
298 } else {
299 (second, first)
300 };
301 self.pair_filter.test(
302 solution,
303 &left.output,
304 &right.output,
305 Self::filter_index(left.coordinate),
306 Self::filter_index(right.coordinate),
307 )
308 }
309
310 fn localized_owners(
311 &self,
312 descriptor_index: usize,
313 entity_index: usize,
314 ) -> Vec<ProjectedRowOwner> {
315 let mut owners = Vec::new();
316 for slot in 0..self.source.source_count() {
317 if self
318 .source
319 .change_source(slot)
320 .assert_localizes(descriptor_index, &self.constraint_ref.name)
321 {
322 owners.push(ProjectedRowOwner {
323 source_slot: slot,
324 entity_index,
325 });
326 }
327 }
328 owners
329 }
330
331 fn row_ids_for_owners(&self, owners: &[ProjectedRowOwner]) -> Vec<usize> {
332 let mut seen = HashSet::new();
333 let mut row_ids = Vec::new();
334 for owner in owners {
335 let Some(ids) = self.rows_by_owner.get(owner) else {
336 continue;
337 };
338 for &row_id in ids {
339 if seen.insert(row_id) {
340 row_ids.push(row_id);
341 }
342 }
343 }
344 row_ids
345 }
346
347 #[cfg(test)]
348 pub(crate) fn debug_row_storage_len(&self) -> usize {
349 self.rows.len()
350 }
351
352 #[cfg(test)]
353 pub(crate) fn debug_free_row_count(&self) -> usize {
354 self.free_row_ids.len()
355 }
356}
357
358impl<S, Out, K, Src, F, KF, PF, W, Sc> IncrementalConstraint<S, Sc>
359 for ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
360where
361 S: Send + Sync + 'static,
362 Out: Send + Sync + 'static,
363 K: Eq + Hash + Send + Sync + 'static,
364 Src: ProjectedSource<S, Out>,
365 F: UniFilter<S, Out>,
366 KF: Fn(&Out) -> K + Send + Sync,
367 PF: BiFilter<S, Out, Out>,
368 W: Fn(&Out, &Out) -> Sc + Send + Sync,
369 Sc: Score + 'static,
370{
371 fn evaluate(&self, solution: &S) -> Sc {
372 let rows = self.evaluate_rows(solution);
373
374 let mut total = Sc::zero();
375 for left_index in 0..rows.len() {
376 for right_index in (left_index + 1)..rows.len() {
377 total = total
378 + self.score_evaluation_pair(solution, &rows[left_index], &rows[right_index]);
379 }
380 }
381 total
382 }
383
384 fn match_count(&self, solution: &S) -> usize {
385 let rows = self.evaluate_rows(solution);
386
387 let mut count = 0;
388 for left_index in 0..rows.len() {
389 for right_index in (left_index + 1)..rows.len() {
390 if self.evaluation_pair_matches(solution, &rows[left_index], &rows[right_index]) {
391 count += 1;
392 }
393 }
394 }
395 count
396 }
397
398 fn initialize(&mut self, solution: &S) -> Sc {
399 self.reset();
400 let state = self.source.build_state(solution);
401 let mut rows = Vec::new();
402 self.source
403 .collect_all(solution, &state, |coordinate, output| {
404 if self.filter.test(solution, &output) {
405 rows.push((coordinate, output));
406 }
407 });
408 self.source_state = Some(state);
409
410 rows.into_iter()
411 .fold(Sc::zero(), |total, (coordinate, output)| {
412 total + self.insert_row(solution, coordinate, output)
413 })
414 }
415
416 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
417 let owners = self.localized_owners(descriptor_index, entity_index);
418 self.ensure_source_state(solution);
419 {
420 let state = self.source_state.as_mut().expect("projected source state");
421 for owner in &owners {
422 self.source.insert_entity_state(
423 solution,
424 state,
425 owner.source_slot,
426 owner.entity_index,
427 );
428 }
429 }
430 let mut rows = Vec::new();
431 let state = self.source_state.as_ref().expect("projected source state");
432 for owner in &owners {
433 self.source.collect_entity(
434 solution,
435 state,
436 owner.source_slot,
437 owner.entity_index,
438 |coordinate, output| {
439 if self.filter.test(solution, &output) {
440 rows.push((coordinate, output));
441 }
442 },
443 );
444 }
445 let mut total = Sc::zero();
446 for (coordinate, output) in rows {
447 total = total + self.insert_row(solution, coordinate, output);
448 }
449 total
450 }
451
452 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
453 let owners = self.localized_owners(descriptor_index, entity_index);
454 let mut total = Sc::zero();
455 for row_id in self.row_ids_for_owners(&owners) {
456 total = total + self.retract_row(solution, row_id);
457 }
458 if let Some(state) = self.source_state.as_mut() {
459 for owner in &owners {
460 self.source.retract_entity_state(
461 solution,
462 state,
463 owner.source_slot,
464 owner.entity_index,
465 );
466 }
467 }
468 total
469 }
470
471 fn reset(&mut self) {
472 self.source_state = None;
473 self.rows.clear();
474 self.free_row_ids.clear();
475 self.rows_by_owner.clear();
476 self.row_ids_by_coordinate.clear();
477 self.rows_by_key.clear();
478 }
479
480 fn name(&self) -> &str {
481 &self.constraint_ref.name
482 }
483
484 fn constraint_ref(&self) -> &ConstraintRef {
485 &self.constraint_ref
486 }
487
488 fn is_hard(&self) -> bool {
489 self.is_hard
490 }
491}