1use std::collections::HashMap;
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::filter::{BiFilter, UniFilter};
10use crate::stream::{ProjectedRowCoordinate, ProjectedSource};
11
12struct ProjectedJoinRow<Out, K> {
13 key: K,
14 output: Out,
15 order: ProjectedRowCoordinate,
16}
17
18pub struct ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
19where
20 Sc: Score,
21{
22 constraint_ref: ConstraintRef,
23 impact_type: ImpactType,
24 source: Src,
25 filter: F,
26 key_fn: KF,
27 pair_filter: PF,
28 weight: W,
29 is_hard: bool,
30 rows: Vec<Option<ProjectedJoinRow<Out, K>>>,
31 free_row_ids: Vec<usize>,
32 rows_by_entity: HashMap<(usize, usize), Vec<usize>>,
33 rows_by_key: HashMap<K, Vec<usize>>,
34 _phantom: PhantomData<(fn() -> S, fn() -> Out, fn() -> Sc)>,
35}
36
37impl<S, Out, K, Src, F, KF, PF, W, Sc> ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
38where
39 S: Send + Sync + 'static,
40 Out: Clone + Send + Sync + 'static,
41 K: Clone + Eq + Hash + Send + Sync + 'static,
42 Src: ProjectedSource<S, Out>,
43 F: UniFilter<S, Out>,
44 KF: Fn(&Out) -> K + Send + Sync,
45 PF: BiFilter<S, Out, Out>,
46 W: Fn(&Out, &Out) -> Sc + Send + Sync,
47 Sc: Score + 'static,
48{
49 #[allow(clippy::too_many_arguments)]
50 pub fn new(
51 constraint_ref: ConstraintRef,
52 impact_type: ImpactType,
53 source: Src,
54 filter: F,
55 key_fn: KF,
56 pair_filter: PF,
57 weight: W,
58 is_hard: bool,
59 ) -> Self {
60 Self {
61 constraint_ref,
62 impact_type,
63 source,
64 filter,
65 key_fn,
66 pair_filter,
67 weight,
68 is_hard,
69 rows: Vec::new(),
70 free_row_ids: Vec::new(),
71 rows_by_entity: HashMap::new(),
72 rows_by_key: HashMap::new(),
73 _phantom: PhantomData,
74 }
75 }
76
77 fn compute_score(&self, left: &Out, right: &Out) -> Sc {
78 let base = (self.weight)(left, right);
79 match self.impact_type {
80 ImpactType::Penalty => -base,
81 ImpactType::Reward => base,
82 }
83 }
84
85 fn score_ordered_rows(
86 &self,
87 solution: &S,
88 first: &ProjectedJoinRow<Out, K>,
89 second: &ProjectedJoinRow<Out, K>,
90 ) -> Sc {
91 let (left, right) = if first.order <= second.order {
92 (first, second)
93 } else {
94 (second, first)
95 };
96 if !self
97 .pair_filter
98 .test(solution, &left.output, &right.output, 0, 1)
99 {
100 return Sc::zero();
101 }
102 self.compute_score(&left.output, &right.output)
103 }
104
105 fn score_pair(&self, solution: &S, first_id: usize, second_id: usize) -> Sc {
106 let Some(first) = self.rows.get(first_id).and_then(Option::as_ref) else {
107 return Sc::zero();
108 };
109 let Some(second) = self.rows.get(second_id).and_then(Option::as_ref) else {
110 return Sc::zero();
111 };
112 self.score_ordered_rows(solution, first, second)
113 }
114
115 fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
116 let key = (self.key_fn)(&output);
117 let existing = self.rows_by_key.get(&key).cloned().unwrap_or_default();
118 let row = Some(ProjectedJoinRow {
119 key: key.clone(),
120 output,
121 order: coordinate,
122 });
123 let row_id = if let Some(row_id) = self.free_row_ids.pop() {
124 debug_assert!(self.rows[row_id].is_none());
125 self.rows[row_id] = row;
126 row_id
127 } else {
128 let row_id = self.rows.len();
129 self.rows.push(row);
130 row_id
131 };
132 self.rows_by_entity
133 .entry((coordinate.source_slot, coordinate.entity_index))
134 .or_default()
135 .push(row_id);
136
137 let mut total = Sc::zero();
138 for other_id in existing {
139 total = total + self.score_pair(solution, row_id, other_id);
140 }
141 self.rows_by_key.entry(key).or_default().push(row_id);
142 total
143 }
144
145 fn retract_row(&mut self, solution: &S, row_id: usize) -> Sc {
146 let Some(row) = self.rows.get(row_id).and_then(Option::as_ref) else {
147 return Sc::zero();
148 };
149 let key = row.key.clone();
150 let candidates = self.rows_by_key.get(&key).cloned().unwrap_or_default();
151 let mut total = Sc::zero();
152 for other_id in candidates {
153 if other_id == row_id {
154 continue;
155 }
156 total = total - self.score_pair(solution, row_id, other_id);
157 }
158
159 if let Some(ids) = self.rows_by_key.get_mut(&key) {
160 ids.retain(|&id| id != row_id);
161 if ids.is_empty() {
162 self.rows_by_key.remove(&key);
163 }
164 }
165 self.rows[row_id] = None;
166 self.free_row_ids.push(row_id);
167 total
168 }
169
170 fn insert_entity_outputs(&mut self, solution: &S, slot: usize, entity_index: usize) -> Sc {
171 let mut outputs = Vec::new();
172 self.source
173 .collect_entity(solution, slot, entity_index, |coordinate, output| {
174 if self.filter.test(solution, &output) {
175 outputs.push((coordinate, output));
176 }
177 });
178
179 outputs
180 .into_iter()
181 .fold(Sc::zero(), |total, (coordinate, output)| {
182 total + self.insert_row(solution, coordinate, output)
183 })
184 }
185
186 fn retract_entity_outputs(&mut self, solution: &S, slot: usize, entity_index: usize) -> Sc {
187 let Some(row_ids) = self.rows_by_entity.remove(&(slot, entity_index)) else {
188 return Sc::zero();
189 };
190 row_ids.into_iter().fold(Sc::zero(), |total, row_id| {
191 total + self.retract_row(solution, row_id)
192 })
193 }
194
195 fn evaluate_rows(&self, solution: &S) -> Vec<ProjectedJoinRow<Out, K>> {
196 let mut rows = Vec::new();
197 self.source.collect_all(solution, |coordinate, output| {
198 if self.filter.test(solution, &output) {
199 rows.push(ProjectedJoinRow {
200 key: (self.key_fn)(&output),
201 output,
202 order: coordinate,
203 });
204 }
205 });
206 rows
207 }
208
209 fn score_evaluation_pair(
210 &self,
211 solution: &S,
212 first: &ProjectedJoinRow<Out, K>,
213 second: &ProjectedJoinRow<Out, K>,
214 ) -> Sc {
215 if first.key == second.key {
216 self.score_ordered_rows(solution, first, second)
217 } else {
218 Sc::zero()
219 }
220 }
221
222 fn evaluation_pair_matches(
223 &self,
224 solution: &S,
225 first: &ProjectedJoinRow<Out, K>,
226 second: &ProjectedJoinRow<Out, K>,
227 ) -> bool {
228 if first.key != second.key {
229 return false;
230 }
231 let (left, right) = if first.order <= second.order {
232 (first, second)
233 } else {
234 (second, first)
235 };
236 self.pair_filter
237 .test(solution, &left.output, &right.output, 0, 1)
238 }
239
240 fn localized_slots(&self, descriptor_index: usize) -> Vec<usize> {
241 let mut slots = Vec::new();
242 for slot in 0..self.source.source_count() {
243 if self
244 .source
245 .change_source(slot)
246 .assert_localizes(descriptor_index, &self.constraint_ref.name)
247 {
248 slots.push(slot);
249 }
250 }
251 slots
252 }
253
254 #[cfg(test)]
255 pub(crate) fn debug_row_storage_len(&self) -> usize {
256 self.rows.len()
257 }
258
259 #[cfg(test)]
260 pub(crate) fn debug_free_row_count(&self) -> usize {
261 self.free_row_ids.len()
262 }
263}
264
265impl<S, Out, K, Src, F, KF, PF, W, Sc> IncrementalConstraint<S, Sc>
266 for ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
267where
268 S: Send + Sync + 'static,
269 Out: Clone + Send + Sync + 'static,
270 K: Clone + Eq + Hash + Send + Sync + 'static,
271 Src: ProjectedSource<S, Out>,
272 F: UniFilter<S, Out>,
273 KF: Fn(&Out) -> K + Send + Sync,
274 PF: BiFilter<S, Out, Out>,
275 W: Fn(&Out, &Out) -> Sc + Send + Sync,
276 Sc: Score + 'static,
277{
278 fn evaluate(&self, solution: &S) -> Sc {
279 let rows = self.evaluate_rows(solution);
280
281 let mut total = Sc::zero();
282 for left_index in 0..rows.len() {
283 for right_index in (left_index + 1)..rows.len() {
284 total = total
285 + self.score_evaluation_pair(solution, &rows[left_index], &rows[right_index]);
286 }
287 }
288 total
289 }
290
291 fn match_count(&self, solution: &S) -> usize {
292 let rows = self.evaluate_rows(solution);
293
294 let mut count = 0;
295 for left_index in 0..rows.len() {
296 for right_index in (left_index + 1)..rows.len() {
297 if self.evaluation_pair_matches(solution, &rows[left_index], &rows[right_index]) {
298 count += 1;
299 }
300 }
301 }
302 count
303 }
304
305 fn initialize(&mut self, solution: &S) -> Sc {
306 self.reset();
307 let mut rows = Vec::new();
308 self.source.collect_all(solution, |coordinate, output| {
309 if self.filter.test(solution, &output) {
310 rows.push((coordinate, output));
311 }
312 });
313
314 rows.into_iter()
315 .fold(Sc::zero(), |total, (coordinate, output)| {
316 total + self.insert_row(solution, coordinate, output)
317 })
318 }
319
320 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
321 let mut total = Sc::zero();
322 for slot in self.localized_slots(descriptor_index) {
323 total = total + self.insert_entity_outputs(solution, slot, entity_index);
324 }
325 total
326 }
327
328 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
329 let mut total = Sc::zero();
330 for slot in self.localized_slots(descriptor_index) {
331 total = total + self.retract_entity_outputs(solution, slot, entity_index);
332 }
333 total
334 }
335
336 fn reset(&mut self) {
337 self.rows.clear();
338 self.free_row_ids.clear();
339 self.rows_by_entity.clear();
340 self.rows_by_key.clear();
341 }
342
343 fn name(&self) -> &str {
344 &self.constraint_ref.name
345 }
346
347 fn is_hard(&self) -> bool {
348 self.is_hard
349 }
350
351 fn constraint_ref(&self) -> ConstraintRef {
352 self.constraint_ref.clone()
353 }
354}