1use std::collections::{HashMap, HashSet};
7use std::fmt::Debug;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::analysis::DetailedConstraintMatch;
15use crate::api::constraint_set::IncrementalConstraint;
16
17pub struct IncrementalTriConstraint<S, A, K, E, KE, F, W, Sc>
74where
75 Sc: Score,
76{
77 constraint_ref: ConstraintRef,
78 impact_type: ImpactType,
79 extractor: E,
80 key_extractor: KE,
81 filter: F,
82 weight: W,
83 is_hard: bool,
84 entity_to_matches: HashMap<usize, HashSet<(usize, usize, usize)>>,
86 matches: HashSet<(usize, usize, usize)>,
88 key_to_indices: HashMap<K, HashSet<usize>>,
90 index_to_key: HashMap<usize, K>,
92 _phantom: PhantomData<(S, A, Sc)>,
93}
94
95impl<S, A, K, E, KE, F, W, Sc> IncrementalTriConstraint<S, A, K, E, KE, F, W, Sc>
96where
97 S: 'static,
98 A: Clone + 'static,
99 K: Eq + Hash + Clone,
100 E: Fn(&S) -> &[A],
101 KE: Fn(&A) -> K,
102 F: Fn(&A, &A, &A) -> bool,
103 W: Fn(&A, &A, &A) -> Sc,
104 Sc: Score,
105{
106 pub fn new(
108 constraint_ref: ConstraintRef,
109 impact_type: ImpactType,
110 extractor: E,
111 key_extractor: KE,
112 filter: F,
113 weight: W,
114 is_hard: bool,
115 ) -> Self {
116 Self {
117 constraint_ref,
118 impact_type,
119 extractor,
120 key_extractor,
121 filter,
122 weight,
123 is_hard,
124 entity_to_matches: HashMap::new(),
125 matches: HashSet::new(),
126 key_to_indices: HashMap::new(),
127 index_to_key: HashMap::new(),
128 _phantom: PhantomData,
129 }
130 }
131
132 #[inline]
133 fn compute_score(&self, a: &A, b: &A, c: &A) -> Sc {
134 let base = (self.weight)(a, b, c);
135 match self.impact_type {
136 ImpactType::Penalty => -base,
137 ImpactType::Reward => base,
138 }
139 }
140
141 fn insert_entity(&mut self, entities: &[A], index: usize) -> Sc {
143 if index >= entities.len() {
144 return Sc::zero();
145 }
146
147 let entity = &entities[index];
148 let key = (self.key_extractor)(entity);
149
150 self.index_to_key.insert(index, key.clone());
152
153 self.key_to_indices
155 .entry(key.clone())
156 .or_default()
157 .insert(index);
158
159 let key_to_indices = &self.key_to_indices;
161 let matches = &mut self.matches;
162 let entity_to_matches = &mut self.entity_to_matches;
163 let filter = &self.filter;
164 let weight = &self.weight;
165 let impact_type = self.impact_type;
166
167 let mut total = Sc::zero();
169 if let Some(others) = key_to_indices.get(&key) {
170 for &i in others {
172 if i == index {
173 continue;
174 }
175 for &j in others {
176 if j <= i || j == index {
178 continue;
179 }
180
181 let mut indices = [index, i, j];
183 indices.sort();
184 let [a_idx, b_idx, c_idx] = indices;
185
186 let triple = (a_idx, b_idx, c_idx);
187
188 if matches.contains(&triple) {
190 continue;
191 }
192
193 let a = &entities[a_idx];
194 let b = &entities[b_idx];
195 let c = &entities[c_idx];
196
197 if filter(a, b, c) && matches.insert(triple) {
198 entity_to_matches.entry(a_idx).or_default().insert(triple);
199 entity_to_matches.entry(b_idx).or_default().insert(triple);
200 entity_to_matches.entry(c_idx).or_default().insert(triple);
201 let base = weight(a, b, c);
202 let score = match impact_type {
203 ImpactType::Penalty => -base,
204 ImpactType::Reward => base,
205 };
206 total = total + score;
207 }
208 }
209 }
210 }
211
212 total
213 }
214
215 fn retract_entity(&mut self, entities: &[A], index: usize) -> Sc {
217 if let Some(key) = self.index_to_key.remove(&index) {
219 if let Some(indices) = self.key_to_indices.get_mut(&key) {
220 indices.remove(&index);
221 if indices.is_empty() {
222 self.key_to_indices.remove(&key);
223 }
224 }
225 }
226
227 let Some(triples) = self.entity_to_matches.remove(&index) else {
229 return Sc::zero();
230 };
231
232 let mut total = Sc::zero();
233 for triple in triples {
234 self.matches.remove(&triple);
235
236 let (i, j, k) = triple;
238 for &other in &[i, j, k] {
239 if other != index {
240 if let Some(other_set) = self.entity_to_matches.get_mut(&other) {
241 other_set.remove(&triple);
242 if other_set.is_empty() {
243 self.entity_to_matches.remove(&other);
244 }
245 }
246 }
247 }
248
249 if i < entities.len() && j < entities.len() && k < entities.len() {
251 let score = self.compute_score(&entities[i], &entities[j], &entities[k]);
252 total = total - score;
253 }
254 }
255
256 total
257 }
258}
259
260impl<S, A, K, E, KE, F, W, Sc> IncrementalConstraint<S, Sc>
261 for IncrementalTriConstraint<S, A, K, E, KE, F, W, Sc>
262where
263 S: Send + Sync + 'static,
264 A: Clone + Debug + Send + Sync + 'static,
265 K: Eq + Hash + Clone + Send + Sync,
266 E: Fn(&S) -> &[A] + Send + Sync,
267 KE: Fn(&A) -> K + Send + Sync,
268 F: Fn(&A, &A, &A) -> bool + Send + Sync,
269 W: Fn(&A, &A, &A) -> Sc + Send + Sync,
270 Sc: Score,
271{
272 fn evaluate(&self, solution: &S) -> Sc {
273 let entities = (self.extractor)(solution);
274 let mut total = Sc::zero();
275
276 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
278 for (i, entity) in entities.iter().enumerate() {
279 let key = (self.key_extractor)(entity);
280 temp_index.entry(key).or_default().push(i);
281 }
282
283 for indices in temp_index.values() {
285 for pos_i in 0..indices.len() {
286 for pos_j in (pos_i + 1)..indices.len() {
287 for pos_k in (pos_j + 1)..indices.len() {
288 let i = indices[pos_i];
289 let j = indices[pos_j];
290 let k = indices[pos_k];
291 let a = &entities[i];
292 let b = &entities[j];
293 let c = &entities[k];
294 if (self.filter)(a, b, c) {
295 total = total + self.compute_score(a, b, c);
296 }
297 }
298 }
299 }
300 }
301
302 total
303 }
304
305 fn match_count(&self, solution: &S) -> usize {
306 let entities = (self.extractor)(solution);
307 let mut count = 0;
308
309 let mut temp_index: HashMap<K, Vec<usize>> = HashMap::new();
311 for (i, entity) in entities.iter().enumerate() {
312 let key = (self.key_extractor)(entity);
313 temp_index.entry(key).or_default().push(i);
314 }
315
316 for indices in temp_index.values() {
318 for pos_i in 0..indices.len() {
319 for pos_j in (pos_i + 1)..indices.len() {
320 for pos_k in (pos_j + 1)..indices.len() {
321 let i = indices[pos_i];
322 let j = indices[pos_j];
323 let k = indices[pos_k];
324 if (self.filter)(&entities[i], &entities[j], &entities[k]) {
325 count += 1;
326 }
327 }
328 }
329 }
330 }
331
332 count
333 }
334
335 fn initialize(&mut self, solution: &S) -> Sc {
336 self.reset();
337
338 let entities = (self.extractor)(solution);
339 let mut total = Sc::zero();
340 for i in 0..entities.len() {
341 total = total + self.insert_entity(entities, i);
342 }
343 total
344 }
345
346 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
347 let entities = (self.extractor)(solution);
348 self.insert_entity(entities, entity_index)
349 }
350
351 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
352 let entities = (self.extractor)(solution);
353 self.retract_entity(entities, entity_index)
354 }
355
356 fn reset(&mut self) {
357 self.entity_to_matches.clear();
358 self.matches.clear();
359 self.key_to_indices.clear();
360 self.index_to_key.clear();
361 }
362
363 fn name(&self) -> &str {
364 &self.constraint_ref.name
365 }
366
367 fn is_hard(&self) -> bool {
368 self.is_hard
369 }
370
371 fn constraint_ref(&self) -> ConstraintRef {
372 self.constraint_ref.clone()
373 }
374
375 fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
376 impl_get_matches_nary!(tri: self, solution)
377 }
378}
379
380impl<S, A, K, E, KE, F, W, Sc: Score> std::fmt::Debug
381 for IncrementalTriConstraint<S, A, K, E, KE, F, W, Sc>
382{
383 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384 f.debug_struct("IncrementalTriConstraint")
385 .field("name", &self.constraint_ref.name)
386 .field("impact_type", &self.impact_type)
387 .field("match_count", &self.matches.len())
388 .finish()
389 }
390}