1use std::collections::HashMap;
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::api::constraint_set::IncrementalConstraint;
14pub struct FlattenedBiConstraint<
96 S,
97 A,
98 B,
99 C,
100 K,
101 CK,
102 EA,
103 EB,
104 KA,
105 KB,
106 Flatten,
107 CKeyFn,
108 ALookup,
109 F,
110 W,
111 Sc,
112> where
113 Sc: Score,
114{
115 constraint_ref: ConstraintRef,
116 impact_type: ImpactType,
117 extractor_a: EA,
118 extractor_b: EB,
119 key_a: KA,
120 key_b: KB,
121 flatten: Flatten,
122 c_key_fn: CKeyFn,
123 a_lookup_fn: ALookup,
124 filter: F,
125 weight: W,
126 is_hard: bool,
127 c_index: HashMap<(K, CK), Vec<(usize, C)>>,
129 a_scores: HashMap<usize, Sc>,
131 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B)>,
132}
133
134impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
135 FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
136where
137 S: 'static,
138 A: Clone + 'static,
139 B: Clone + 'static,
140 C: Clone + 'static,
141 K: Eq + Hash + Clone,
142 CK: Eq + Hash + Clone,
143 EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
144 EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
145 KA: Fn(&A) -> K,
146 KB: Fn(&B) -> K,
147 Flatten: Fn(&B) -> &[C],
148 CKeyFn: Fn(&C) -> CK,
149 ALookup: Fn(&A) -> CK,
150 F: Fn(&S, &A, &C) -> bool,
151 W: Fn(&A, &C) -> Sc,
152 Sc: Score,
153{
154 #[allow(clippy::too_many_arguments)]
156 pub fn new(
157 constraint_ref: ConstraintRef,
158 impact_type: ImpactType,
159 extractor_a: EA,
160 extractor_b: EB,
161 key_a: KA,
162 key_b: KB,
163 flatten: Flatten,
164 c_key_fn: CKeyFn,
165 a_lookup_fn: ALookup,
166 filter: F,
167 weight: W,
168 is_hard: bool,
169 ) -> Self {
170 Self {
171 constraint_ref,
172 impact_type,
173 extractor_a,
174 extractor_b,
175 key_a,
176 key_b,
177 flatten,
178 c_key_fn,
179 a_lookup_fn,
180 filter,
181 weight,
182 is_hard,
183 c_index: HashMap::new(),
184 a_scores: HashMap::new(),
185 _phantom: PhantomData,
186 }
187 }
188
189 #[inline]
190 fn compute_score(&self, a: &A, c: &C) -> Sc {
191 let base = (self.weight)(a, c);
192 match self.impact_type {
193 ImpactType::Penalty => -base,
194 ImpactType::Reward => base,
195 }
196 }
197
198 fn build_c_index(&mut self, entities_b: &[B]) {
200 self.c_index.clear();
201 for (b_idx, b) in entities_b.iter().enumerate() {
202 let join_key = (self.key_b)(b);
203 for c in (self.flatten)(b) {
204 let c_key = (self.c_key_fn)(c);
205 self.c_index
206 .entry((join_key.clone(), c_key))
207 .or_default()
208 .push((b_idx, c.clone()));
209 }
210 }
211 }
212
213 fn compute_a_score(&self, solution: &S, a: &A) -> Sc {
215 let join_key = (self.key_a)(a);
216 let lookup_key = (self.a_lookup_fn)(a);
217
218 let matches = match self.c_index.get(&(join_key, lookup_key)) {
220 Some(v) => v.as_slice(),
221 None => return Sc::zero(),
222 };
223
224 let mut total = Sc::zero();
225 for (_b_idx, c) in matches {
226 if (self.filter)(solution, a, c) {
227 total = total + self.compute_score(a, c);
228 }
229 }
230 total
231 }
232
233 fn insert_a(&mut self, solution: &S, entities_a: &[A], a_idx: usize) -> Sc {
234 if a_idx >= entities_a.len() {
235 return Sc::zero();
236 }
237
238 let a = &entities_a[a_idx];
239 let score = self.compute_a_score(solution, a);
240
241 if score != Sc::zero() {
242 self.a_scores.insert(a_idx, score);
243 }
244 score
245 }
246
247 fn retract_a(&mut self, a_idx: usize) -> Sc {
248 match self.a_scores.remove(&a_idx) {
249 Some(score) => -score,
250 None => Sc::zero(),
251 }
252 }
253}
254
255impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
256 IncrementalConstraint<S, Sc>
257 for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
258where
259 S: Send + Sync + 'static,
260 A: Clone + Send + Sync + 'static,
261 B: Clone + Send + Sync + 'static,
262 C: Clone + Send + Sync + 'static,
263 K: Eq + Hash + Clone + Send + Sync,
264 CK: Eq + Hash + Clone + Send + Sync,
265 EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
266 EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
267 KA: Fn(&A) -> K + Send + Sync,
268 KB: Fn(&B) -> K + Send + Sync,
269 Flatten: Fn(&B) -> &[C] + Send + Sync,
270 CKeyFn: Fn(&C) -> CK + Send + Sync,
271 ALookup: Fn(&A) -> CK + Send + Sync,
272 F: Fn(&S, &A, &C) -> bool + Send + Sync,
273 W: Fn(&A, &C) -> Sc + Send + Sync,
274 Sc: Score,
275{
276 fn evaluate(&self, solution: &S) -> Sc {
277 let entities_a = self.extractor_a.extract(solution);
278 let entities_b = self.extractor_b.extract(solution);
279 let mut total = Sc::zero();
280
281 let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
283 for (b_idx, b) in entities_b.iter().enumerate() {
284 let join_key = (self.key_b)(b);
285 for c in (self.flatten)(b) {
286 let c_key = (self.c_key_fn)(c);
287 temp_index
288 .entry((join_key.clone(), c_key))
289 .or_default()
290 .push((b_idx, c.clone()));
291 }
292 }
293
294 for a in entities_a {
295 let join_key = (self.key_a)(a);
296 let lookup_key = (self.a_lookup_fn)(a);
297
298 if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
299 for (_b_idx, c) in matches {
300 if (self.filter)(solution, a, c) {
301 total = total + self.compute_score(a, c);
302 }
303 }
304 }
305 }
306
307 total
308 }
309
310 fn match_count(&self, solution: &S) -> usize {
311 let entities_a = self.extractor_a.extract(solution);
312 let entities_b = self.extractor_b.extract(solution);
313 let mut count = 0;
314
315 let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
317 for (b_idx, b) in entities_b.iter().enumerate() {
318 let join_key = (self.key_b)(b);
319 for c in (self.flatten)(b) {
320 let c_key = (self.c_key_fn)(c);
321 temp_index
322 .entry((join_key.clone(), c_key))
323 .or_default()
324 .push((b_idx, c.clone()));
325 }
326 }
327
328 for a in entities_a {
329 let join_key = (self.key_a)(a);
330 let lookup_key = (self.a_lookup_fn)(a);
331
332 if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
333 for (_b_idx, c) in matches {
334 if (self.filter)(solution, a, c) {
335 count += 1;
336 }
337 }
338 }
339 }
340
341 count
342 }
343
344 fn initialize(&mut self, solution: &S) -> Sc {
345 self.reset();
346
347 let entities_a = self.extractor_a.extract(solution);
348 let entities_b = self.extractor_b.extract(solution);
349
350 self.build_c_index(entities_b);
352
353 let mut total = Sc::zero();
355 for a_idx in 0..entities_a.len() {
356 total = total + self.insert_a(solution, entities_a, a_idx);
357 }
358
359 total
360 }
361
362 fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
363 let entities_a = self.extractor_a.extract(solution);
364 self.insert_a(solution, entities_a, entity_index)
365 }
366
367 fn on_retract(&mut self, _solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
368 self.retract_a(entity_index)
369 }
370
371 fn reset(&mut self) {
372 self.c_index.clear();
373 self.a_scores.clear();
374 }
375
376 fn name(&self) -> &str {
377 &self.constraint_ref.name
378 }
379
380 fn is_hard(&self) -> bool {
381 self.is_hard
382 }
383
384 fn constraint_ref(&self) -> ConstraintRef {
385 self.constraint_ref.clone()
386 }
387}
388
389impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc: Score> std::fmt::Debug
390 for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
391{
392 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
393 f.debug_struct("FlattenedBiConstraint")
394 .field("name", &self.constraint_ref.name)
395 .field("impact_type", &self.impact_type)
396 .field("c_index_size", &self.c_index.len())
397 .finish()
398 }
399}