1use std::collections::HashMap;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::constraint_set::IncrementalConstraint;
15use crate::stream::collector::{Accumulator, UniCollector};
16
17pub struct GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
78where
79 C: UniCollector<A>,
80 Sc: Score,
81{
82 constraint_ref: ConstraintRef,
83 impact_type: ImpactType,
84 extractor: E,
85 key_fn: KF,
86 collector: C,
87 weight_fn: W,
88 is_hard: bool,
89 groups: HashMap<K, C::Accumulator>,
91 entity_groups: HashMap<usize, K>,
93 entity_values: HashMap<usize, C::Value>,
95 _phantom: PhantomData<(S, A, Sc)>,
96}
97
98impl<S, A, K, E, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
99where
100 S: Send + Sync + 'static,
101 A: Clone + Send + Sync + 'static,
102 K: Clone + Eq + Hash + Send + Sync + 'static,
103 E: Fn(&S) -> &[A] + Send + Sync,
104 KF: Fn(&A) -> K + Send + Sync,
105 C: UniCollector<A> + Send + Sync + 'static,
106 C::Accumulator: Send + Sync,
107 C::Result: Send + Sync,
108 W: Fn(&C::Result) -> Sc + Send + Sync,
109 Sc: Score + 'static,
110{
111 pub fn new(
123 constraint_ref: ConstraintRef,
124 impact_type: ImpactType,
125 extractor: E,
126 key_fn: KF,
127 collector: C,
128 weight_fn: W,
129 is_hard: bool,
130 ) -> Self {
131 Self {
132 constraint_ref,
133 impact_type,
134 extractor,
135 key_fn,
136 collector,
137 weight_fn,
138 is_hard,
139 groups: HashMap::new(),
140 entity_groups: HashMap::new(),
141 entity_values: HashMap::new(),
142 _phantom: PhantomData,
143 }
144 }
145
146 fn compute_score(&self, result: &C::Result) -> Sc {
148 let base = (self.weight_fn)(result);
149 match self.impact_type {
150 ImpactType::Penalty => -base,
151 ImpactType::Reward => base,
152 }
153 }
154}
155
156impl<S, A, K, E, KF, C, W, Sc> IncrementalConstraint<S, Sc>
157 for GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
158where
159 S: Send + Sync + 'static,
160 A: Clone + Send + Sync + 'static,
161 K: Clone + Eq + Hash + Send + Sync + 'static,
162 E: Fn(&S) -> &[A] + Send + Sync,
163 KF: Fn(&A) -> K + Send + Sync,
164 C: UniCollector<A> + Send + Sync + 'static,
165 C::Accumulator: Send + Sync,
166 C::Result: Send + Sync,
167 C::Value: Send + Sync,
168 W: Fn(&C::Result) -> Sc + Send + Sync,
169 Sc: Score + 'static,
170{
171 fn evaluate(&self, solution: &S) -> Sc {
172 let entities = (self.extractor)(solution);
173
174 let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
176
177 for entity in entities {
178 let key = (self.key_fn)(entity);
179 let value = self.collector.extract(entity);
180 let acc = groups
181 .entry(key)
182 .or_insert_with(|| self.collector.create_accumulator());
183 acc.accumulate(&value);
184 }
185
186 let mut total = Sc::zero();
188 for acc in groups.values() {
189 let result = acc.finish();
190 total = total + self.compute_score(&result);
191 }
192
193 total
194 }
195
196 fn match_count(&self, solution: &S) -> usize {
197 let entities = (self.extractor)(solution);
198
199 let mut groups: HashMap<K, ()> = HashMap::new();
201 for entity in entities {
202 let key = (self.key_fn)(entity);
203 groups.insert(key, ());
204 }
205
206 groups.len()
207 }
208
209 fn initialize(&mut self, solution: &S) -> Sc {
210 self.reset();
211
212 let entities = (self.extractor)(solution);
213 let mut total = Sc::zero();
214
215 for (idx, entity) in entities.iter().enumerate() {
216 total = total + self.insert_entity(entities, idx, entity);
217 }
218
219 total
220 }
221
222 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
223 let entities = (self.extractor)(solution);
224 if entity_index >= entities.len() {
225 return Sc::zero();
226 }
227
228 let entity = &entities[entity_index];
229 self.insert_entity(entities, entity_index, entity)
230 }
231
232 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
233 let entities = (self.extractor)(solution);
234 self.retract_entity(entities, entity_index)
235 }
236
237 fn reset(&mut self) {
238 self.groups.clear();
239 self.entity_groups.clear();
240 self.entity_values.clear();
241 }
242
243 fn name(&self) -> &str {
244 &self.constraint_ref.name
245 }
246
247 fn is_hard(&self) -> bool {
248 self.is_hard
249 }
250
251 fn constraint_ref(&self) -> ConstraintRef {
252 self.constraint_ref.clone()
253 }
254}
255
256impl<S, A, K, E, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
257where
258 S: Send + Sync + 'static,
259 A: Clone + Send + Sync + 'static,
260 K: Clone + Eq + Hash + Send + Sync + 'static,
261 E: Fn(&S) -> &[A] + Send + Sync,
262 KF: Fn(&A) -> K + Send + Sync,
263 C: UniCollector<A> + Send + Sync + 'static,
264 C::Accumulator: Send + Sync,
265 C::Result: Send + Sync,
266 C::Value: Send + Sync,
267 W: Fn(&C::Result) -> Sc + Send + Sync,
268 Sc: Score + 'static,
269{
270 fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
271 let key = (self.key_fn)(entity);
272 let value = self.collector.extract(entity);
273 let impact = self.impact_type;
274
275 let acc = self
277 .groups
278 .entry(key.clone())
279 .or_insert_with(|| self.collector.create_accumulator());
280
281 let old_base = (self.weight_fn)(&acc.finish());
283 let old = match impact {
284 ImpactType::Penalty => -old_base,
285 ImpactType::Reward => old_base,
286 };
287
288 acc.accumulate(&value);
290 let new_base = (self.weight_fn)(&acc.finish());
291 let new_score = match impact {
292 ImpactType::Penalty => -new_base,
293 ImpactType::Reward => new_base,
294 };
295
296 self.entity_groups.insert(entity_index, key);
298 self.entity_values.insert(entity_index, value);
299
300 new_score - old
302 }
303
304 fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
305 let Some(key) = self.entity_groups.remove(&entity_index) else {
307 return Sc::zero();
308 };
309
310 let Some(value) = self.entity_values.remove(&entity_index) else {
312 return Sc::zero();
313 };
314 let impact = self.impact_type;
315
316 let Some(acc) = self.groups.get_mut(&key) else {
318 return Sc::zero();
319 };
320
321 let old_base = (self.weight_fn)(&acc.finish());
323 let old = match impact {
324 ImpactType::Penalty => -old_base,
325 ImpactType::Reward => old_base,
326 };
327
328 acc.retract(&value);
330 let new_base = (self.weight_fn)(&acc.finish());
331 let new_score = match impact {
332 ImpactType::Penalty => -new_base,
333 ImpactType::Reward => new_base,
334 };
335
336 new_score - old
338 }
339}
340
341impl<S, A, K, E, KF, C, W, Sc> std::fmt::Debug for GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
342where
343 C: UniCollector<A>,
344 Sc: Score,
345{
346 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347 f.debug_struct("GroupedUniConstraint")
348 .field("name", &self.constraint_ref.name)
349 .field("impact_type", &self.impact_type)
350 .field("groups", &self.groups.len())
351 .finish()
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::collector::count;
    use solverforge_core::score::SimpleScore;

    /// Minimal entity: a shift assigned to one employee.
    #[derive(Clone, Hash, PartialEq, Eq)]
    struct Shift {
        employee_id: usize,
    }

    /// Minimal solution: just the list of shifts.
    #[derive(Clone)]
    struct Solution {
        shifts: Vec<Shift>,
    }

    /// Builds a solution with one shift per given employee id, in order.
    fn roster(employee_ids: &[usize]) -> Solution {
        Solution {
            shifts: employee_ids
                .iter()
                .map(|&employee_id| Shift { employee_id })
                .collect(),
        }
    }

    #[test]
    fn test_grouped_constraint_evaluate() {
        let workload = GroupedUniConstraint::new(
            ConstraintRef::new("", "Workload"),
            ImpactType::Penalty,
            |s: &Solution| s.shifts.as_slice(),
            |shift: &Shift| shift.employee_id,
            count::<Shift>(),
            |size: &usize| SimpleScore::of((*size * *size) as i64),
            false,
        );

        // Employee 1 has three shifts, employee 2 has one.
        let plan = roster(&[1, 1, 1, 2]);

        // Squared counts, penalized: -(3*3 + 1*1) = -10.
        let expected = SimpleScore::of(-10);
        assert_eq!(workload.evaluate(&plan), expected);
    }

    #[test]
    fn test_grouped_constraint_incremental() {
        let mut workload = GroupedUniConstraint::new(
            ConstraintRef::new("", "Workload"),
            ImpactType::Penalty,
            |s: &Solution| s.shifts.as_slice(),
            |shift: &Shift| shift.employee_id,
            count::<Shift>(),
            |size: &usize| SimpleScore::of(*size as i64),
            false,
        );

        let plan = roster(&[1, 1, 2]);

        // Linear counts, penalized: -(2 + 1) = -3.
        let initial = workload.initialize(&plan);
        assert_eq!(initial, SimpleScore::of(-3));

        // Taking one shift away from employee 1 (2 -> 1) improves the score by 1.
        let after_retract = workload.on_retract(&plan, 0);
        assert_eq!(after_retract, SimpleScore::of(1));

        // Putting it back (1 -> 2) costs 1 again.
        let after_insert = workload.on_insert(&plan, 0);
        assert_eq!(after_insert, SimpleScore::of(-1));
    }

    #[test]
    fn test_grouped_constraint_reward() {
        let collaboration = GroupedUniConstraint::new(
            ConstraintRef::new("", "Collaboration"),
            ImpactType::Reward,
            |s: &Solution| s.shifts.as_slice(),
            |shift: &Shift| shift.employee_id,
            count::<Shift>(),
            |size: &usize| SimpleScore::of(*size as i64),
            false,
        );

        let plan = roster(&[1, 1]);

        // One group of two shifts, rewarded: +2.
        let expected = SimpleScore::of(2);
        assert_eq!(collaboration.evaluate(&plan), expected);
    }
}