1use std::collections::{HashMap, HashSet};
7use std::fmt::Debug;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::analysis::DetailedConstraintMatch;
15use crate::api::constraint_set::IncrementalConstraint;
16
17pub struct IncrementalQuadConstraint<S, A, K, E, KE, F, W, Sc>
75where
76 Sc: Score,
77{
78 constraint_ref: ConstraintRef,
79 impact_type: ImpactType,
80 extractor: E,
81 key_extractor: KE,
82 filter: F,
83 weight: W,
84 is_hard: bool,
85 entity_to_matches: HashMap<usize, HashSet<(usize, usize, usize, usize)>>,
87 matches: HashSet<(usize, usize, usize, usize)>,
89 key_to_indices: HashMap<K, HashSet<usize>>,
91 index_to_key: HashMap<usize, K>,
93 _phantom: PhantomData<(S, A, Sc)>,
94}
95
96impl<S, A, K, E, KE, F, W, Sc> IncrementalQuadConstraint<S, A, K, E, KE, F, W, Sc>
97where
98 S: 'static,
99 A: Clone + 'static,
100 K: Eq + Hash + Clone,
101 E: Fn(&S) -> &[A],
102 KE: Fn(&A) -> K,
103 F: Fn(&A, &A, &A, &A) -> bool,
104 W: Fn(&A, &A, &A, &A) -> Sc,
105 Sc: Score,
106{
107 pub fn new(
109 constraint_ref: ConstraintRef,
110 impact_type: ImpactType,
111 extractor: E,
112 key_extractor: KE,
113 filter: F,
114 weight: W,
115 is_hard: bool,
116 ) -> Self {
117 Self {
118 constraint_ref,
119 impact_type,
120 extractor,
121 key_extractor,
122 filter,
123 weight,
124 is_hard,
125 entity_to_matches: HashMap::new(),
126 matches: HashSet::new(),
127 key_to_indices: HashMap::new(),
128 index_to_key: HashMap::new(),
129 _phantom: PhantomData,
130 }
131 }
132
133 #[inline]
134 fn compute_score(&self, a: &A, b: &A, c: &A, d: &A) -> Sc {
135 let base = (self.weight)(a, b, c, d);
136 match self.impact_type {
137 ImpactType::Penalty => -base,
138 ImpactType::Reward => base,
139 }
140 }
141
142 fn insert_entity(&mut self, entities: &[A], index: usize) -> Sc {
144 if index >= entities.len() {
145 return Sc::zero();
146 }
147
148 let entity = &entities[index];
149 let key = (self.key_extractor)(entity);
150
151 self.index_to_key.insert(index, key.clone());
153
154 self.key_to_indices
156 .entry(key.clone())
157 .or_default()
158 .insert(index);
159
160 let key_to_indices = &self.key_to_indices;
162 let matches = &mut self.matches;
163 let entity_to_matches = &mut self.entity_to_matches;
164 let filter = &self.filter;
165 let weight = &self.weight;
166 let impact_type = self.impact_type;
167
168 let mut total = Sc::zero();
170 if let Some(others) = key_to_indices.get(&key) {
171 for &i in others {
173 if i == index {
174 continue;
175 }
176 for &j in others {
177 if j <= i || j == index {
178 continue;
179 }
180 for &k in others {
181 if k <= j || k == index {
182 continue;
183 }
184
185 let mut indices = [index, i, j, k];
187 indices.sort();
188 let [a_idx, b_idx, c_idx, d_idx] = indices;
189
190 let quad = (a_idx, b_idx, c_idx, d_idx);
191
192 if matches.contains(&quad) {
194 continue;
195 }
196
197 let a = &entities[a_idx];
198 let b = &entities[b_idx];
199 let c = &entities[c_idx];
200 let d = &entities[d_idx];
201
202 if filter(a, b, c, d) && matches.insert(quad) {
203 entity_to_matches.entry(a_idx).or_default().insert(quad);
204 entity_to_matches.entry(b_idx).or_default().insert(quad);
205 entity_to_matches.entry(c_idx).or_default().insert(quad);
206 entity_to_matches.entry(d_idx).or_default().insert(quad);
207 let base = weight(a, b, c, d);
208 let score = match impact_type {
209 ImpactType::Penalty => -base,
210 ImpactType::Reward => base,
211 };
212 total = total + score;
213 }
214 }
215 }
216 }
217 }
218
219 total
220 }
221
222 fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
224 if let Some(key) = self.index_to_key.remove(&index) {
226 if let Some(indices) = self.key_to_indices.get_mut(&key) {
227 indices.remove(&index);
228 if indices.is_empty() {
229 self.key_to_indices.remove(&key);
230 }
231 }
232 }
233
234 let Some(quads) = self.entity_to_matches.remove(&index) else {
236 return Sc::zero();
237 };
238
239 let mut total = Sc::zero();
240 for quad in quads {
241 self.matches.remove(&quad);
242
243 let (i, j, k, l) = quad;
245 for &other in &[i, j, k, l] {
246 if other != index {
247 if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
248 other_set.remove(&quad);
249 if other_set.is_empty() {
250 self.entity_to_matches.remove(&other);
251 }
252 }
253 }
254 }
255
256 if i < entities.len() && j < entities.len() && k < entities.len() && l < entities.len()
258 {
259 let score =
260 self.compute_score(&entities[i], &entities[j], &entities[k], &entities[l]);
261 total = total - score;
262 }
263 }
264
265 total
266 }
267}
268
269impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
270 for IncrementalQuadConstraint<S, A, K, E, KE, F, W, Sc>
271where
272 S: Send + Sync + 'static,
273 A: Clone + Debug + Send + Sync + 'static,
274 K: Eq + Hash + Clone + Send + Sync,
275 E: Fn(&S) -> &[A] + Send + Sync,
276 KE: Fn(&A) -> K + Send + Sync,
277 F: Fn(&A, &A, &A, &A) -> bool + Send + Sync,
278 W: Fn(&A, &A, &A, &A) -> Sc + Send + Sync,
279 Sc: Score,
280{
281 fn evaluate(&self, solution: &S) -> Sc {
282 let entities = (self.extractor)(solution);
283 let mut total = Sc::zero();
284
285 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
287 for (i, entity) in entities.iter().enumerate() {
288 let key = (self.key_extractor)(entity);
289 temp_index.entry(key).or_default().push(i);
290 }
291
292 for indices in temp_index.values() {
294 for pos_i in 0..indices.len() {
295 for pos_j in (pos_i + 1)..indices.len() {
296 for pos_k in (pos_j + 1)..indices.len() {
297 for pos_l in (pos_k + 1)..indices.len() {
298 let i = indices[pos_i];
299 let j = indices[pos_j];
300 let k = indices[pos_k];
301 let l = indices[pos_l];
302 let a = &entities[i];
303 let b = &entities[j];
304 let c = &entities[k];
305 let d = &entities[l];
306 if (self.filter)(a, b, c, d) {
307 total = total + self.compute_score(a, b, c, d);
308 }
309 }
310 }
311 }
312 }
313 }
314
315 total
316 }
317
318 fn match_count(&self, solution: &S) -> usize {
319 let entities = (self.extractor)(solution);
320 let mut count = 0;
321
322 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
324 for (i, entity) in entities.iter().enumerate() {
325 let key = (self.key_extractor)(entity);
326 temp_index.entry(key).or_default().push(i);
327 }
328
329 for indices in temp_index.values() {
331 for pos_i in 0..indices.len() {
332 for pos_j in (pos_i + 1)..indices.len() {
333 for pos_k in (pos_j + 1)..indices.len() {
334 for pos_l in (pos_k + 1)..indices.len() {
335 let i = indices[pos_i];
336 let j = indices[pos_j];
337 let k = indices[pos_k];
338 let l = indices[pos_l];
339 if (self.filter)(&entities[i], &entities[j], &entities[k], &entities[l])
340 {
341 count += 1;
342 }
343 }
344 }
345 }
346 }
347 }
348
349 count
350 }
351
352 fn initialize(&mut self, solution: &S) -> Sc {
353 self.reset();
354
355 let entities = (self.extractor)(solution);
356 let mut total = Sc::zero();
357 for i in 0..entities.len() {
358 total = total + self.insert_entity(entities, i);
359 }
360 total
361 }
362
363 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
364 let entities = (self.extractor)(solution);
365 self.insert_entity(entities, entity_index)
366 }
367
368 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
369 let entities = (self.extractor)(solution);
370 self.retract_entity(entities, entity_index)
371 }
372
373 fn reset(&mut self) {
374 self.entity_to_matches.clear();
375 self.matches.clear();
376 self.key_to_indices.clear();
377 self.index_to_key.clear();
378 }
379
380 fn name(&self) -> &str {
381 &self.constraint_ref.name
382 }
383
384 fn is_hard(&self) -> bool {
385 self.is_hard
386 }
387
388 fn constraint_ref(&self) -> ConstraintRef {
389 self.constraint_ref.clone()
390 }
391
392 fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
393 impl_get_matches_nary!(quad: self, solution)
394 }
395}
396
397impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
398 for IncrementalQuadConstraint<S, A, K, E, KE, F, W, Sc>
399{
400 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401 f.debug_struct("IncrementalQuadConstraint")
402 .field("name", &self.constraint_ref.name)
403 .field("impact_type", &self.impact_type)
404 .field("match_count", &self.matches.len())
405 .finish()
406 }
407}