1use std::collections::HashMap;
6use std::hash::Hash;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use crate::api::constraint_set::IncrementalConstraint;
13
14pub struct FlattenedBiConstraint<
95 S,
96 A,
97 B,
98 C,
99 K,
100 CK,
101 EA,
102 EB,
103 KA,
104 KB,
105 Flatten,
106 CKeyFn,
107 ALookup,
108 F,
109 W,
110 Sc,
111> where
112 Sc: Score,
113{
114 constraint_ref: ConstraintRef,
115 impact_type: ImpactType,
116 extractor_a: EA,
117 extractor_b: EB,
118 key_a: KA,
119 key_b: KB,
120 flatten: Flatten,
121 c_key_fn: CKeyFn,
122 a_lookup_fn: ALookup,
123 filter: F,
124 weight: W,
125 is_hard: bool,
126 c_index: HashMap<(K, CK), Vec<(usize, C)>>,
128 a_scores: HashMap<usize, Sc>,
130 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B)>,
131}
132
133impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
134 FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
135where
136 S: 'static,
137 A: Clone + 'static,
138 B: Clone + 'static,
139 C: Clone + 'static,
140 K: Eq + Hash + Clone,
141 CK: Eq + Hash + Clone,
142 EA: Fn(&S) -> &[A],
143 EB: Fn(&S) -> &[B],
144 KA: Fn(&A) -> K,
145 KB: Fn(&B) -> K,
146 Flatten: Fn(&B) -> &[C],
147 CKeyFn: Fn(&C) -> CK,
148 ALookup: Fn(&A) -> CK,
149 F: Fn(&S, &A, &C) -> bool,
150 W: Fn(&A, &C) -> Sc,
151 Sc: Score,
152{
153 #[allow(clippy::too_many_arguments)]
155 pub fn new(
156 constraint_ref: ConstraintRef,
157 impact_type: ImpactType,
158 extractor_a: EA,
159 extractor_b: EB,
160 key_a: KA,
161 key_b: KB,
162 flatten: Flatten,
163 c_key_fn: CKeyFn,
164 a_lookup_fn: ALookup,
165 filter: F,
166 weight: W,
167 is_hard: bool,
168 ) -> Self {
169 Self {
170 constraint_ref,
171 impact_type,
172 extractor_a,
173 extractor_b,
174 key_a,
175 key_b,
176 flatten,
177 c_key_fn,
178 a_lookup_fn,
179 filter,
180 weight,
181 is_hard,
182 c_index: HashMap::new(),
183 a_scores: HashMap::new(),
184 _phantom: PhantomData,
185 }
186 }
187
188 #[inline]
189 fn compute_score(&self, a: &A, c: &C) -> Sc {
190 let base = (self.weight)(a, c);
191 match self.impact_type {
192 ImpactType::Penalty => -base,
193 ImpactType::Reward => base,
194 }
195 }
196
197 fn build_c_index(&mut self, entities_b: &[B]) {
199 self.c_index.clear();
200 for (b_idx, b) in entities_b.iter().enumerate() {
201 let join_key = (self.key_b)(b);
202 for c in (self.flatten)(b) {
203 let c_key = (self.c_key_fn)(c);
204 self.c_index
205 .entry((join_key.clone(), c_key))
206 .or_default()
207 .push((b_idx, c.clone()));
208 }
209 }
210 }
211
212 fn compute_a_score(&self, solution: &S, a: &A) -> Sc {
214 let join_key = (self.key_a)(a);
215 let lookup_key = (self.a_lookup_fn)(a);
216
217 let matches = match self.c_index.get(&(join_key, lookup_key)) {
219 Some(v) => v.as_slice(),
220 None => return Sc::zero(),
221 };
222
223 let mut total = Sc::zero();
224 for (_b_idx, c) in matches {
225 if (self.filter)(solution, a, c) {
226 total = total + self.compute_score(a, c);
227 }
228 }
229 total
230 }
231
232 fn insert_a(&mut self, solution: &S, entities_a: &[A], a_idx: usize) -> Sc {
233 if a_idx >= entities_a.len() {
234 return Sc::zero();
235 }
236
237 let a = &entities_a[a_idx];
238 let score = self.compute_a_score(solution, a);
239
240 if score != Sc::zero() {
241 self.a_scores.insert(a_idx, score);
242 }
243 score
244 }
245
246 fn retract_a(&mut self, a_idx: usize) -> Sc {
247 match self.a_scores.remove(&a_idx) {
248 Some(score) => -score,
249 None => Sc::zero(),
250 }
251 }
252}
253
254impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
255 IncrementalConstraint<S, Sc>
256 for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
257where
258 S: Send + Sync + 'static,
259 A: Clone + Send + Sync + 'static,
260 B: Clone + Send + Sync + 'static,
261 C: Clone + Send + Sync + 'static,
262 K: Eq + Hash + Clone + Send + Sync,
263 CK: Eq + Hash + Clone + Send + Sync,
264 EA: Fn(&S) -> &[A] + Send + Sync,
265 EB: Fn(&S) -> &[B] + Send + Sync,
266 KA: Fn(&A) -> K + Send + Sync,
267 KB: Fn(&B) -> K + Send + Sync,
268 Flatten: Fn(&B) -> &[C] + Send + Sync,
269 CKeyFn: Fn(&C) -> CK + Send + Sync,
270 ALookup: Fn(&A) -> CK + Send + Sync,
271 F: Fn(&S, &A, &C) -> bool + Send + Sync,
272 W: Fn(&A, &C) -> Sc + Send + Sync,
273 Sc: Score,
274{
275 fn evaluate(&self, solution: &S) -> Sc {
276 let entities_a = (self.extractor_a)(solution);
277 let entities_b = (self.extractor_b)(solution);
278 let mut total = Sc::zero();
279
280 let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
282 for (b_idx, b) in entities_b.iter().enumerate() {
283 let join_key = (self.key_b)(b);
284 for c in (self.flatten)(b) {
285 let c_key = (self.c_key_fn)(c);
286 temp_index
287 .entry((join_key.clone(), c_key))
288 .or_default()
289 .push((b_idx, c.clone()));
290 }
291 }
292
293 for a in entities_a {
294 let join_key = (self.key_a)(a);
295 let lookup_key = (self.a_lookup_fn)(a);
296
297 if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
298 for (_b_idx, c) in matches {
299 if (self.filter)(solution, a, c) {
300 total = total + self.compute_score(a, c);
301 }
302 }
303 }
304 }
305
306 total
307 }
308
309 fn match_count(&self, solution: &S) -> usize {
310 let entities_a = (self.extractor_a)(solution);
311 let entities_b = (self.extractor_b)(solution);
312 let mut count = 0;
313
314 let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
316 for (b_idx, b) in entities_b.iter().enumerate() {
317 let join_key = (self.key_b)(b);
318 for c in (self.flatten)(b) {
319 let c_key = (self.c_key_fn)(c);
320 temp_index
321 .entry((join_key.clone(), c_key))
322 .or_default()
323 .push((b_idx, c.clone()));
324 }
325 }
326
327 for a in entities_a {
328 let join_key = (self.key_a)(a);
329 let lookup_key = (self.a_lookup_fn)(a);
330
331 if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
332 for (_b_idx, c) in matches {
333 if (self.filter)(solution, a, c) {
334 count += 1;
335 }
336 }
337 }
338 }
339
340 count
341 }
342
343 fn initialize(&mut self, solution: &S) -> Sc {
344 self.reset();
345
346 let entities_a = (self.extractor_a)(solution);
347 let entities_b = (self.extractor_b)(solution);
348
349 self.build_c_index(entities_b);
351
352 let mut total = Sc::zero();
354 for a_idx in 0..entities_a.len() {
355 total = total + self.insert_a(solution, entities_a, a_idx);
356 }
357
358 total
359 }
360
361 fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
362 let entities_a = (self.extractor_a)(solution);
363 self.insert_a(solution, entities_a, entity_index)
364 }
365
366 fn on_retract(&mut self, _solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
367 self.retract_a(entity_index)
368 }
369
370 fn reset(&mut self) {
371 self.c_index.clear();
372 self.a_scores.clear();
373 }
374
375 fn name(&self) -> &str {
376 &self.constraint_ref.name
377 }
378
379 fn is_hard(&self) -> bool {
380 self.is_hard
381 }
382
383 fn constraint_ref(&self) -> ConstraintRef {
384 self.constraint_ref.clone()
385 }
386}
387
388impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc: Score> std::fmt::Debug
389 for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
390{
391 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392 f.debug_struct("FlattenedBiConstraint")
393 .field("name", &self.constraint_ref.name)
394 .field("impact_type", &self.impact_type)
395 .field("c_index_size", &self.c_index.len())
396 .finish()
397 }
398}