1use std::collections::{HashMap, HashSet};
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::filter::{BiFilter, UniFilter};
10use crate::stream::{ProjectedRowCoordinate, ProjectedRowOwner, ProjectedSource};
11
/// One projected row held by a `ProjectedBiConstraint`: the projected value
/// together with the coordinate that identifies it.
///
/// The coordinate is used to de-duplicate rows, to look up the owners a row is
/// indexed under, and to put the two sides of a pair into a canonical order
/// before the pair filter and weight are applied.
struct ProjectedJoinRow<Out> {
    /// Projected value; only values that passed the constraint's uni-filter
    /// are ever wrapped in a row.
    output: Out,
    /// Identity of the row; at most one live row is stored per coordinate.
    coordinate: ProjectedRowCoordinate,
}
16
/// Incremental bi-constraint over rows produced by a `ProjectedSource`.
///
/// Rows that pass `filter` are bucketed by `key_fn`. Every unordered pair of
/// rows sharing a key is ordered by coordinate, tested with `pair_filter`, and
/// scored with `weight`, which is negated for `ImpactType::Penalty`.
pub struct ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
where
    Src: ProjectedSource<S, Out>,
    Sc: Score,
{
    /// Identity of this constraint. Its `name` is also passed to
    /// `assert_localizes` when mapping entity events to source slots.
    constraint_ref: ConstraintRef,
    /// Whether matched pairs penalize (weight negated) or reward.
    impact_type: ImpactType,
    /// Producer of projected rows and of the per-slot change sources.
    source: Src,
    /// Per-row filter; rows failing it are never stored or paired.
    filter: F,
    /// Join key; only rows with equal keys are paired.
    key_fn: KF,
    /// Filter applied to each coordinate-ordered pair before weighting.
    pair_filter: PF,
    /// Raw weight of a matched pair, before the impact sign is applied.
    weight: W,
    /// Value reported through `IncrementalConstraint::is_hard`.
    is_hard: bool,
    /// Incremental state of `source`. It is `None` until `initialize` or the
    /// first `on_insert` builds it, and `reset` clears it again.
    source_state: Option<Src::State>,
    /// Slab of rows addressed by row id; `None` marks a vacated slot.
    rows: Vec<Option<ProjectedJoinRow<Out>>>,
    /// Vacated slots in `rows`, reused before the slab grows.
    free_row_ids: Vec<usize>,
    /// Owner -> ids of live rows indexed under that owner. Used to find the
    /// rows to retract when an entity is retracted.
    rows_by_owner: HashMap<ProjectedRowOwner, Vec<usize>>,
    /// Coordinate -> id of the live row with that coordinate (de-duplication).
    row_ids_by_coordinate: HashMap<ProjectedRowCoordinate, usize>,
    /// Join key -> ids of live rows with that key.
    rows_by_key: HashMap<K, Vec<usize>>,
    /// Phantom marker for the `S`, `Out` and `Sc` type parameters.
    _phantom: PhantomData<(fn() -> S, fn() -> Out, fn() -> Sc)>,
}
38
39impl<S, Out, K, Src, F, KF, PF, W, Sc> ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
40where
41 S: Send + Sync + 'static,
42 Out: Send + Sync + 'static,
43 K: Eq + Hash + Send + Sync + 'static,
44 Src: ProjectedSource<S, Out>,
45 F: UniFilter<S, Out>,
46 KF: Fn(&Out) -> K + Send + Sync,
47 PF: BiFilter<S, Out, Out>,
48 W: Fn(&Out, &Out) -> Sc + Send + Sync,
49 Sc: Score + 'static,
50{
51 #[allow(clippy::too_many_arguments)]
52 pub fn new(
53 constraint_ref: ConstraintRef,
54 impact_type: ImpactType,
55 source: Src,
56 filter: F,
57 key_fn: KF,
58 pair_filter: PF,
59 weight: W,
60 is_hard: bool,
61 ) -> Self {
62 Self {
63 constraint_ref,
64 impact_type,
65 source,
66 filter,
67 key_fn,
68 pair_filter,
69 weight,
70 is_hard,
71 source_state: None,
72 rows: Vec::new(),
73 free_row_ids: Vec::new(),
74 rows_by_owner: HashMap::new(),
75 row_ids_by_coordinate: HashMap::new(),
76 rows_by_key: HashMap::new(),
77 _phantom: PhantomData,
78 }
79 }
80
81 fn compute_score(&self, left: &Out, right: &Out) -> Sc {
82 let base = (self.weight)(left, right);
83 match self.impact_type {
84 ImpactType::Penalty => -base,
85 ImpactType::Reward => base,
86 }
87 }
88
89 fn score_ordered_rows(
90 &self,
91 solution: &S,
92 first: &ProjectedJoinRow<Out>,
93 second: &ProjectedJoinRow<Out>,
94 ) -> Sc {
95 let (left, right) = if first.coordinate <= second.coordinate {
96 (first, second)
97 } else {
98 (second, first)
99 };
100 if !self
101 .pair_filter
102 .test(solution, &left.output, &right.output, 0, 1)
103 {
104 return Sc::zero();
105 }
106 self.compute_score(&left.output, &right.output)
107 }
108
109 fn score_candidate_row(
110 &self,
111 solution: &S,
112 candidate_output: &Out,
113 candidate_coordinate: ProjectedRowCoordinate,
114 other: &ProjectedJoinRow<Out>,
115 ) -> Sc {
116 let (left, right) = if candidate_coordinate <= other.coordinate {
117 (candidate_output, &other.output)
118 } else {
119 (&other.output, candidate_output)
120 };
121 if !self.pair_filter.test(solution, left, right, 0, 1) {
122 return Sc::zero();
123 }
124 self.compute_score(left, right)
125 }
126
127 fn score_pair(&self, solution: &S, first_id: usize, second_id: usize) -> Sc {
128 let Some(first) = self.rows.get(first_id).and_then(Option::as_ref) else {
129 return Sc::zero();
130 };
131 let Some(second) = self.rows.get(second_id).and_then(Option::as_ref) else {
132 return Sc::zero();
133 };
134 self.score_ordered_rows(solution, first, second)
135 }
136
137 fn ensure_source_state(&mut self, solution: &S) {
138 if self.source_state.is_none() {
139 self.source_state = Some(self.source.build_state(solution));
140 }
141 }
142
143 fn index_row_owners(&mut self, coordinate: ProjectedRowCoordinate, row_id: usize) {
144 coordinate.for_each_owner(|owner| {
145 self.rows_by_owner.entry(owner).or_default().push(row_id);
146 });
147 }
148
149 fn unindex_row_owners(&mut self, coordinate: ProjectedRowCoordinate, row_id: usize) {
150 coordinate.for_each_owner(|owner| {
151 let mut remove_bucket = false;
152 if let Some(ids) = self.rows_by_owner.get_mut(&owner) {
153 ids.retain(|candidate| *candidate != row_id);
154 remove_bucket = ids.is_empty();
155 }
156 if remove_bucket {
157 self.rows_by_owner.remove(&owner);
158 }
159 });
160 }
161
162 fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
163 if self.row_ids_by_coordinate.contains_key(&coordinate) {
164 return Sc::zero();
165 }
166 let key = (self.key_fn)(&output);
167 let mut total = Sc::zero();
168 if let Some(existing) = self.rows_by_key.get(&key) {
169 for &other_id in existing {
170 if let Some(other) = self.rows.get(other_id).and_then(Option::as_ref) {
171 total = total + self.score_candidate_row(solution, &output, coordinate, other);
172 }
173 }
174 }
175 let row = Some(ProjectedJoinRow { output, coordinate });
176 let row_id = if let Some(row_id) = self.free_row_ids.pop() {
177 debug_assert!(self.rows[row_id].is_none());
178 self.rows[row_id] = row;
179 row_id
180 } else {
181 let row_id = self.rows.len();
182 self.rows.push(row);
183 row_id
184 };
185 self.row_ids_by_coordinate.insert(coordinate, row_id);
186 self.index_row_owners(coordinate, row_id);
187 self.rows_by_key.entry(key).or_default().push(row_id);
188 total
189 }
190
191 fn retract_row(&mut self, solution: &S, row_id: usize) -> Sc {
192 let Some((key, coordinate)) = self
193 .rows
194 .get(row_id)
195 .and_then(Option::as_ref)
196 .map(|row| ((self.key_fn)(&row.output), row.coordinate))
197 else {
198 return Sc::zero();
199 };
200 let mut total = Sc::zero();
201 if let Some(candidates) = self.rows_by_key.get(&key) {
202 for &other_id in candidates {
203 if other_id == row_id {
204 continue;
205 }
206 total = total - self.score_pair(solution, row_id, other_id);
207 }
208 }
209
210 if let Some(ids) = self.rows_by_key.get_mut(&key) {
211 ids.retain(|&id| id != row_id);
212 if ids.is_empty() {
213 self.rows_by_key.remove(&key);
214 }
215 }
216 self.row_ids_by_coordinate.remove(&coordinate);
217 self.unindex_row_owners(coordinate, row_id);
218 self.rows[row_id] = None;
219 self.free_row_ids.push(row_id);
220 total
221 }
222
223 fn evaluate_rows(&self, solution: &S) -> Vec<ProjectedJoinRow<Out>> {
224 let state = self.source.build_state(solution);
225 let mut rows = Vec::new();
226 self.source
227 .collect_all(solution, &state, |coordinate, output| {
228 if self.filter.test(solution, &output) {
229 rows.push(ProjectedJoinRow { output, coordinate });
230 }
231 });
232 rows
233 }
234
235 fn score_evaluation_pair(
236 &self,
237 solution: &S,
238 first: &ProjectedJoinRow<Out>,
239 second: &ProjectedJoinRow<Out>,
240 ) -> Sc {
241 if (self.key_fn)(&first.output) == (self.key_fn)(&second.output) {
242 self.score_ordered_rows(solution, first, second)
243 } else {
244 Sc::zero()
245 }
246 }
247
248 fn evaluation_pair_matches(
249 &self,
250 solution: &S,
251 first: &ProjectedJoinRow<Out>,
252 second: &ProjectedJoinRow<Out>,
253 ) -> bool {
254 if (self.key_fn)(&first.output) != (self.key_fn)(&second.output) {
255 return false;
256 }
257 let (left, right) = if first.coordinate <= second.coordinate {
258 (first, second)
259 } else {
260 (second, first)
261 };
262 self.pair_filter
263 .test(solution, &left.output, &right.output, 0, 1)
264 }
265
266 fn localized_owners(
267 &self,
268 descriptor_index: usize,
269 entity_index: usize,
270 ) -> Vec<ProjectedRowOwner> {
271 let mut owners = Vec::new();
272 for slot in 0..self.source.source_count() {
273 if self
274 .source
275 .change_source(slot)
276 .assert_localizes(descriptor_index, &self.constraint_ref.name)
277 {
278 owners.push(ProjectedRowOwner {
279 source_slot: slot,
280 entity_index,
281 });
282 }
283 }
284 owners
285 }
286
287 fn row_ids_for_owners(&self, owners: &[ProjectedRowOwner]) -> Vec<usize> {
288 let mut seen = HashSet::new();
289 let mut row_ids = Vec::new();
290 for owner in owners {
291 let Some(ids) = self.rows_by_owner.get(owner) else {
292 continue;
293 };
294 for &row_id in ids {
295 if seen.insert(row_id) {
296 row_ids.push(row_id);
297 }
298 }
299 }
300 row_ids
301 }
302
303 #[cfg(test)]
304 pub(crate) fn debug_row_storage_len(&self) -> usize {
305 self.rows.len()
306 }
307
308 #[cfg(test)]
309 pub(crate) fn debug_free_row_count(&self) -> usize {
310 self.free_row_ids.len()
311 }
312}
313
314impl<S, Out, K, Src, F, KF, PF, W, Sc> IncrementalConstraint<S, Sc>
315 for ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
316where
317 S: Send + Sync + 'static,
318 Out: Send + Sync + 'static,
319 K: Eq + Hash + Send + Sync + 'static,
320 Src: ProjectedSource<S, Out>,
321 F: UniFilter<S, Out>,
322 KF: Fn(&Out) -> K + Send + Sync,
323 PF: BiFilter<S, Out, Out>,
324 W: Fn(&Out, &Out) -> Sc + Send + Sync,
325 Sc: Score + 'static,
326{
327 fn evaluate(&self, solution: &S) -> Sc {
328 let rows = self.evaluate_rows(solution);
329
330 let mut total = Sc::zero();
331 for left_index in 0..rows.len() {
332 for right_index in (left_index + 1)..rows.len() {
333 total = total
334 + self.score_evaluation_pair(solution, &rows[left_index], &rows[right_index]);
335 }
336 }
337 total
338 }
339
340 fn match_count(&self, solution: &S) -> usize {
341 let rows = self.evaluate_rows(solution);
342
343 let mut count = 0;
344 for left_index in 0..rows.len() {
345 for right_index in (left_index + 1)..rows.len() {
346 if self.evaluation_pair_matches(solution, &rows[left_index], &rows[right_index]) {
347 count += 1;
348 }
349 }
350 }
351 count
352 }
353
354 fn initialize(&mut self, solution: &S) -> Sc {
355 self.reset();
356 let state = self.source.build_state(solution);
357 let mut rows = Vec::new();
358 self.source
359 .collect_all(solution, &state, |coordinate, output| {
360 if self.filter.test(solution, &output) {
361 rows.push((coordinate, output));
362 }
363 });
364 self.source_state = Some(state);
365
366 rows.into_iter()
367 .fold(Sc::zero(), |total, (coordinate, output)| {
368 total + self.insert_row(solution, coordinate, output)
369 })
370 }
371
372 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
373 let owners = self.localized_owners(descriptor_index, entity_index);
374 self.ensure_source_state(solution);
375 {
376 let state = self.source_state.as_mut().expect("projected source state");
377 for owner in &owners {
378 self.source.insert_entity_state(
379 solution,
380 state,
381 owner.source_slot,
382 owner.entity_index,
383 );
384 }
385 }
386 let mut rows = Vec::new();
387 let state = self.source_state.as_ref().expect("projected source state");
388 for owner in &owners {
389 self.source.collect_entity(
390 solution,
391 state,
392 owner.source_slot,
393 owner.entity_index,
394 |coordinate, output| {
395 if self.filter.test(solution, &output) {
396 rows.push((coordinate, output));
397 }
398 },
399 );
400 }
401 let mut total = Sc::zero();
402 for (coordinate, output) in rows {
403 total = total + self.insert_row(solution, coordinate, output);
404 }
405 total
406 }
407
408 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
409 let owners = self.localized_owners(descriptor_index, entity_index);
410 let mut total = Sc::zero();
411 for row_id in self.row_ids_for_owners(&owners) {
412 total = total + self.retract_row(solution, row_id);
413 }
414 if let Some(state) = self.source_state.as_mut() {
415 for owner in &owners {
416 self.source.retract_entity_state(
417 solution,
418 state,
419 owner.source_slot,
420 owner.entity_index,
421 );
422 }
423 }
424 total
425 }
426
427 fn reset(&mut self) {
428 self.source_state = None;
429 self.rows.clear();
430 self.free_row_ids.clear();
431 self.rows_by_owner.clear();
432 self.row_ids_by_coordinate.clear();
433 self.rows_by_key.clear();
434 }
435
436 fn name(&self) -> &str {
437 &self.constraint_ref.name
438 }
439
440 fn constraint_ref(&self) -> &ConstraintRef {
441 &self.constraint_ref
442 }
443
444 fn is_hard(&self) -> bool {
445 self.is_hard
446 }
447}