1use std::collections::HashMap;
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::api::constraint_set::IncrementalConstraint;
14
15pub struct FlattenedBiConstraint<
97 S,
98 A,
99 B,
100 C,
101 K,
102 CK,
103 EA,
104 EB,
105 KA,
106 KB,
107 Flatten,
108 CKeyFn,
109 ALookup,
110 F,
111 W,
112 Sc,
113> where
114 Sc: Score,
115{
116 constraint_ref: ConstraintRef,
117 impact_type: ImpactType,
118 extractor_a: EA,
119 extractor_b: EB,
120 key_a: KA,
121 key_b: KB,
122 flatten: Flatten,
123 c_key_fn: CKeyFn,
124 a_lookup_fn: ALookup,
125 filter: F,
126 weight: W,
127 is_hard: bool,
128 c_index: HashMap<(K, CK), Vec<(usize, C)>>,
130 a_scores: HashMap<usize, Sc>,
132 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B)>,
133}
134
135impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
136 FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
137where
138 S: 'static,
139 A: Clone + 'static,
140 B: Clone + 'static,
141 C: Clone + 'static,
142 K: Eq + Hash + Clone,
143 CK: Eq + Hash + Clone,
144 EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
145 EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
146 KA: Fn(&A) -> K,
147 KB: Fn(&B) -> K,
148 Flatten: Fn(&B) -> &[C],
149 CKeyFn: Fn(&C) -> CK,
150 ALookup: Fn(&A) -> CK,
151 F: Fn(&S, &A, &C) -> bool,
152 W: Fn(&A, &C) -> Sc,
153 Sc: Score,
154{
155 #[allow(clippy::too_many_arguments)]
157 pub fn new(
158 constraint_ref: ConstraintRef,
159 impact_type: ImpactType,
160 extractor_a: EA,
161 extractor_b: EB,
162 key_a: KA,
163 key_b: KB,
164 flatten: Flatten,
165 c_key_fn: CKeyFn,
166 a_lookup_fn: ALookup,
167 filter: F,
168 weight: W,
169 is_hard: bool,
170 ) -> Self {
171 Self {
172 constraint_ref,
173 impact_type,
174 extractor_a,
175 extractor_b,
176 key_a,
177 key_b,
178 flatten,
179 c_key_fn,
180 a_lookup_fn,
181 filter,
182 weight,
183 is_hard,
184 c_index: HashMap::new(),
185 a_scores: HashMap::new(),
186 _phantom: PhantomData,
187 }
188 }
189
190 #[inline]
191 fn compute_score(&self, a: &A, c: &C) -> Sc {
192 let base = (self.weight)(a, c);
193 match self.impact_type {
194 ImpactType::Penalty => -base,
195 ImpactType::Reward => base,
196 }
197 }
198
199 fn build_c_index(&mut self, entities_b: &[B]) {
201 self.c_index.clear();
202 for (b_idx, b) in entities_b.iter().enumerate() {
203 let join_key = (self.key_b)(b);
204 for c in (self.flatten)(b) {
205 let c_key = (self.c_key_fn)(c);
206 self.c_index
207 .entry((join_key.clone(), c_key))
208 .or_default()
209 .push((b_idx, c.clone()));
210 }
211 }
212 }
213
214 fn compute_a_score(&self, solution: &S, a: &A) -> Sc {
216 let join_key = (self.key_a)(a);
217 let lookup_key = (self.a_lookup_fn)(a);
218
219 let matches = match self.c_index.get(&(join_key, lookup_key)) {
221 Some(v) => v.as_slice(),
222 None => return Sc::zero(),
223 };
224
225 let mut total = Sc::zero();
226 for (_b_idx, c) in matches {
227 if (self.filter)(solution, a, c) {
228 total = total + self.compute_score(a, c);
229 }
230 }
231 total
232 }
233
234 fn insert_a(&mut self, solution: &S, entities_a: &[A], a_idx: usize) -> Sc {
235 if a_idx >= entities_a.len() {
236 return Sc::zero();
237 }
238
239 let a = &entities_a[a_idx];
240 let score = self.compute_a_score(solution, a);
241
242 if score != Sc::zero() {
243 self.a_scores.insert(a_idx, score);
244 }
245 score
246 }
247
248 fn retract_a(&mut self, a_idx: usize) -> Sc {
249 match self.a_scores.remove(&a_idx) {
250 Some(score) => -score,
251 None => Sc::zero(),
252 }
253 }
254}
255
256impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
257 IncrementalConstraint<S, Sc>
258 for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
259where
260 S: Send + Sync + 'static,
261 A: Clone + Send + Sync + 'static,
262 B: Clone + Send + Sync + 'static,
263 C: Clone + Send + Sync + 'static,
264 K: Eq + Hash + Clone + Send + Sync,
265 CK: Eq + Hash + Clone + Send + Sync,
266 EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
267 EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
268 KA: Fn(&A) -> K + Send + Sync,
269 KB: Fn(&B) -> K + Send + Sync,
270 Flatten: Fn(&B) -> &[C] + Send + Sync,
271 CKeyFn: Fn(&C) -> CK + Send + Sync,
272 ALookup: Fn(&A) -> CK + Send + Sync,
273 F: Fn(&S, &A, &C) -> bool + Send + Sync,
274 W: Fn(&A, &C) -> Sc + Send + Sync,
275 Sc: Score,
276{
277 fn evaluate(&self, solution: &S) -> Sc {
278 let entities_a = self.extractor_a.extract(solution);
279 let entities_b = self.extractor_b.extract(solution);
280 let mut total = Sc::zero();
281
282 let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
284 for (b_idx, b) in entities_b.iter().enumerate() {
285 let join_key = (self.key_b)(b);
286 for c in (self.flatten)(b) {
287 let c_key = (self.c_key_fn)(c);
288 temp_index
289 .entry((join_key.clone(), c_key))
290 .or_default()
291 .push((b_idx, c.clone()));
292 }
293 }
294
295 for a in entities_a {
296 let join_key = (self.key_a)(a);
297 let lookup_key = (self.a_lookup_fn)(a);
298
299 if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
300 for (_b_idx, c) in matches {
301 if (self.filter)(solution, a, c) {
302 total = total + self.compute_score(a, c);
303 }
304 }
305 }
306 }
307
308 total
309 }
310
311 fn match_count(&self, solution: &S) -> usize {
312 let entities_a = self.extractor_a.extract(solution);
313 let entities_b = self.extractor_b.extract(solution);
314 let mut count = 0;
315
316 let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
318 for (b_idx, b) in entities_b.iter().enumerate() {
319 let join_key = (self.key_b)(b);
320 for c in (self.flatten)(b) {
321 let c_key = (self.c_key_fn)(c);
322 temp_index
323 .entry((join_key.clone(), c_key))
324 .or_default()
325 .push((b_idx, c.clone()));
326 }
327 }
328
329 for a in entities_a {
330 let join_key = (self.key_a)(a);
331 let lookup_key = (self.a_lookup_fn)(a);
332
333 if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
334 for (_b_idx, c) in matches {
335 if (self.filter)(solution, a, c) {
336 count += 1;
337 }
338 }
339 }
340 }
341
342 count
343 }
344
345 fn initialize(&mut self, solution: &S) -> Sc {
346 self.reset();
347
348 let entities_a = self.extractor_a.extract(solution);
349 let entities_b = self.extractor_b.extract(solution);
350
351 self.build_c_index(entities_b);
353
354 let mut total = Sc::zero();
356 for a_idx in 0..entities_a.len() {
357 total = total + self.insert_a(solution, entities_a, a_idx);
358 }
359
360 total
361 }
362
363 fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
364 let entities_a = self.extractor_a.extract(solution);
365 self.insert_a(solution, entities_a, entity_index)
366 }
367
368 fn on_retract(&mut self, _solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
369 self.retract_a(entity_index)
370 }
371
372 fn reset(&mut self) {
373 self.c_index.clear();
374 self.a_scores.clear();
375 }
376
377 fn name(&self) -> &str {
378 &self.constraint_ref.name
379 }
380
381 fn is_hard(&self) -> bool {
382 self.is_hard
383 }
384
385 fn constraint_ref(&self) -> ConstraintRef {
386 self.constraint_ref.clone()
387 }
388}
389
390impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc: Score> std::fmt::Debug
391 for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
392{
393 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394 f.debug_struct("FlattenedBiConstraint")
395 .field("name", &self.constraint_ref.name)
396 .field("impact_type", &self.impact_type)
397 .field("c_index_size", &self.c_index.len())
398 .finish()
399 }
400}