Skip to main content

solverforge_scoring/constraint/
complemented.rs

// Zero-erasure complemented group constraint.
//
// Scores every complement entity using the grouped result for its key, or a
// default value when no grouped entity maps to that key.
// Provides true incremental scoring by tracking per-key accumulators.
5
6use std::collections::HashMap;
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::api::constraint_set::IncrementalConstraint;
14use crate::stream::collector::{Accumulator, UniCollector};
15
16// Zero-erasure constraint for complemented grouped results.
17//
18// Groups A entities by key, then iterates over B entities (complement source),
19// using grouped values where they exist and default values otherwise.
20//
21// The key function for A returns `Option<K>`, allowing entities to be skipped
22// when they don't have a valid key (e.g., unassigned shifts).
23//
24// # Type Parameters
25//
26// - `S` - Solution type
27// - `A` - Entity type being grouped (e.g., Shift)
28// - `B` - Complement entity type (e.g., Employee)
29// - `K` - Group key type
30// - `EA` - Extractor for A entities
31// - `EB` - Extractor for B entities
32// - `KA` - Key function for A (returns `Option<K>` to allow skipping)
33// - `KB` - Key function for B
34// - `C` - Collector type
35// - `D` - Default value function
36// - `W` - Weight function
37// - `Sc` - Score type
38//
39// # Example
40//
41// ```
42// use solverforge_scoring::constraint::complemented::ComplementedGroupConstraint;
43// use solverforge_scoring::stream::collector::count;
44// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
45// use solverforge_core::{ConstraintRef, ImpactType};
46// use solverforge_core::score::SimpleScore;
47//
48// #[derive(Clone, Hash, PartialEq, Eq)]
49// struct Employee { id: usize }
50//
51// #[derive(Clone)]
52// struct Shift { employee_id: Option<usize> }
53//
54// #[derive(Clone)]
55// struct Schedule {
56//     employees: Vec<Employee>,
57//     shifts: Vec<Shift>,
58// }
59//
60// let constraint = ComplementedGroupConstraint::new(
61//     ConstraintRef::new("", "Shift count"),
62//     ImpactType::Penalty,
63//     |s: &Schedule| s.shifts.as_slice(),
64//     |s: &Schedule| s.employees.as_slice(),
65//     |shift: &Shift| shift.employee_id,  // Returns Option<usize>
66//     |emp: &Employee| emp.id,
67//     count(),
68//     |_emp: &Employee| 0usize,
69//     |count: &usize| SimpleScore::of(*count as i64),
70//     false,
71// );
72//
73// let schedule = Schedule {
74//     employees: vec![Employee { id: 0 }, Employee { id: 1 }],
75//     shifts: vec![
76//         Shift { employee_id: Some(0) },
77//         Shift { employee_id: Some(0) },
78//         Shift { employee_id: None },  // Skipped - no key
79//     ],
80// };
81//
82// // Employee 0: 2 shifts, Employee 1: 0 shifts → Total: -2
83// // Unassigned shift is skipped
84// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
85// ```
86pub struct ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
87where
88    C: UniCollector<A>,
89    Sc: Score,
90{
91    constraint_ref: ConstraintRef,
92    impact_type: ImpactType,
93    extractor_a: EA,
94    extractor_b: EB,
95    key_a: KA,
96    key_b: KB,
97    collector: C,
98    default_fn: D,
99    weight_fn: W,
100    is_hard: bool,
101    // Group key -> accumulator for incremental scoring
102    groups: HashMap<K, C::Accumulator>,
103    // A entity index -> group key (for tracking which group each entity belongs to)
104    entity_groups: HashMap<usize, K>,
105    // A entity index -> extracted value (for correct retraction after entity mutation)
106    entity_values: HashMap<usize, C::Value>,
107    // B key -> B entity index (for looking up B entities by key)
108    b_by_key: HashMap<K, usize>,
109    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> Sc)>,
110}
111
112impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
113    ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
114where
115    S: 'static,
116    A: Clone + 'static,
117    B: Clone + 'static,
118    K: Clone + Eq + Hash,
119    EA: Fn(&S) -> &[A],
120    EB: Fn(&S) -> &[B],
121    KA: Fn(&A) -> Option<K>,
122    KB: Fn(&B) -> K,
123    C: UniCollector<A>,
124    C::Result: Clone,
125    D: Fn(&B) -> C::Result,
126    W: Fn(&C::Result) -> Sc,
127    Sc: Score,
128{
129    // Creates a new complemented group constraint.
130    #[allow(clippy::too_many_arguments)]
131    pub fn new(
132        constraint_ref: ConstraintRef,
133        impact_type: ImpactType,
134        extractor_a: EA,
135        extractor_b: EB,
136        key_a: KA,
137        key_b: KB,
138        collector: C,
139        default_fn: D,
140        weight_fn: W,
141        is_hard: bool,
142    ) -> Self {
143        Self {
144            constraint_ref,
145            impact_type,
146            extractor_a,
147            extractor_b,
148            key_a,
149            key_b,
150            collector,
151            default_fn,
152            weight_fn,
153            is_hard,
154            groups: HashMap::new(),
155            entity_groups: HashMap::new(),
156            entity_values: HashMap::new(),
157            b_by_key: HashMap::new(),
158            _phantom: PhantomData,
159        }
160    }
161
162    #[inline]
163    fn compute_score(&self, result: &C::Result) -> Sc {
164        let base = (self.weight_fn)(result);
165        match self.impact_type {
166            ImpactType::Penalty => -base,
167            ImpactType::Reward => base,
168        }
169    }
170
171    // Build grouped results from A entities.
172    fn build_groups(&self, entities_a: &[A]) -> HashMap<K, C::Result> {
173        let mut accumulators: HashMap<K, C::Accumulator> = HashMap::new();
174
175        for a in entities_a {
176            // Skip entities with no key (e.g., unassigned shifts)
177            let Some(key) = (self.key_a)(a) else {
178                continue;
179            };
180            let value = self.collector.extract(a);
181            accumulators
182                .entry(key)
183                .or_insert_with(|| self.collector.create_accumulator())
184                .accumulate(&value);
185        }
186
187        accumulators
188            .into_iter()
189            .map(|(k, acc)| (k, acc.finish()))
190            .collect()
191    }
192}
193
194impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> IncrementalConstraint<S, Sc>
195    for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
196where
197    S: Send + Sync + 'static,
198    A: Clone + Send + Sync + 'static,
199    B: Clone + Send + Sync + 'static,
200    K: Clone + Eq + Hash + Send + Sync,
201    EA: Fn(&S) -> &[A] + Send + Sync,
202    EB: Fn(&S) -> &[B] + Send + Sync,
203    KA: Fn(&A) -> Option<K> + Send + Sync,
204    KB: Fn(&B) -> K + Send + Sync,
205    C: UniCollector<A> + Send + Sync,
206    C::Accumulator: Send + Sync,
207    C::Result: Clone + Send + Sync,
208    C::Value: Send + Sync,
209    D: Fn(&B) -> C::Result + Send + Sync,
210    W: Fn(&C::Result) -> Sc + Send + Sync,
211    Sc: Score,
212{
213    fn evaluate(&self, solution: &S) -> Sc {
214        let entities_a = (self.extractor_a)(solution);
215        let entities_b = (self.extractor_b)(solution);
216
217        let groups = self.build_groups(entities_a);
218
219        let mut total = Sc::zero();
220        for b in entities_b {
221            let key = (self.key_b)(b);
222            let result = groups
223                .get(&key)
224                .cloned()
225                .unwrap_or_else(|| (self.default_fn)(b));
226            total = total + self.compute_score(&result);
227        }
228
229        total
230    }
231
232    fn match_count(&self, solution: &S) -> usize {
233        let entities_b = (self.extractor_b)(solution);
234        entities_b.len()
235    }
236
237    fn initialize(&mut self, solution: &S) -> Sc {
238        self.reset();
239
240        let entities_a = (self.extractor_a)(solution);
241        let entities_b = (self.extractor_b)(solution);
242
243        // Build B key -> index mapping
244        for (idx, b) in entities_b.iter().enumerate() {
245            let key = (self.key_b)(b);
246            self.b_by_key.insert(key, idx);
247        }
248
249        // Initialize all B entities with default scores
250        let mut total = Sc::zero();
251        for b in entities_b {
252            let default_result = (self.default_fn)(b);
253            total = total + self.compute_score(&default_result);
254        }
255
256        // Now insert all A entities incrementally
257        for (idx, a) in entities_a.iter().enumerate() {
258            total = total + self.insert_entity(entities_b, idx, a);
259        }
260
261        total
262    }
263
264    fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
265        let entities_a = (self.extractor_a)(solution);
266        let entities_b = (self.extractor_b)(solution);
267
268        if entity_index >= entities_a.len() {
269            return Sc::zero();
270        }
271
272        let entity = &entities_a[entity_index];
273        self.insert_entity(entities_b, entity_index, entity)
274    }
275
276    fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
277        let entities_a = (self.extractor_a)(solution);
278        let entities_b = (self.extractor_b)(solution);
279
280        self.retract_entity(entities_a, entities_b, entity_index)
281    }
282
283    fn reset(&mut self) {
284        self.groups.clear();
285        self.entity_groups.clear();
286        self.entity_values.clear();
287        self.b_by_key.clear();
288    }
289
290    fn name(&self) -> &str {
291        &self.constraint_ref.name
292    }
293
294    fn is_hard(&self) -> bool {
295        self.is_hard
296    }
297
298    fn constraint_ref(&self) -> ConstraintRef {
299        self.constraint_ref.clone()
300    }
301}
302
303impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
304    ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
305where
306    S: Send + Sync + 'static,
307    A: Clone + Send + Sync + 'static,
308    B: Clone + Send + Sync + 'static,
309    K: Clone + Eq + Hash + Send + Sync,
310    EA: Fn(&S) -> &[A] + Send + Sync,
311    EB: Fn(&S) -> &[B] + Send + Sync,
312    KA: Fn(&A) -> Option<K> + Send + Sync,
313    KB: Fn(&B) -> K + Send + Sync,
314    C: UniCollector<A> + Send + Sync,
315    C::Accumulator: Send + Sync,
316    C::Result: Clone + Send + Sync,
317    C::Value: Send + Sync,
318    D: Fn(&B) -> C::Result + Send + Sync,
319    W: Fn(&C::Result) -> Sc + Send + Sync,
320    Sc: Score,
321{
322    // Insert an A entity and return the score delta.
323    fn insert_entity(&mut self, entities_b: &[B], entity_index: usize, entity: &A) -> Sc {
324        // Skip entities with no key (e.g., unassigned shifts)
325        let Some(key) = (self.key_a)(entity) else {
326            return Sc::zero();
327        };
328        let value = self.collector.extract(entity);
329        let impact = self.impact_type;
330
331        // Check if there's a B entity for this key
332        let b_idx = self.b_by_key.get(&key).copied();
333        let Some(b_idx) = b_idx else {
334            // No B entity for this key - A entity doesn't affect score
335            // Still track it for retraction
336            let acc = self
337                .groups
338                .entry(key.clone())
339                .or_insert_with(|| self.collector.create_accumulator());
340            acc.accumulate(&value);
341            self.entity_groups.insert(entity_index, key);
342            self.entity_values.insert(entity_index, value);
343            return Sc::zero();
344        };
345
346        let b = &entities_b[b_idx];
347
348        // Compute old score for this B entity
349        let old_result = self
350            .groups
351            .get(&key)
352            .map(|acc| acc.finish())
353            .unwrap_or_else(|| (self.default_fn)(b));
354        let old_base = (self.weight_fn)(&old_result);
355        let old = match impact {
356            ImpactType::Penalty => -old_base,
357            ImpactType::Reward => old_base,
358        };
359
360        // Get or create accumulator and add value
361        let acc = self
362            .groups
363            .entry(key.clone())
364            .or_insert_with(|| self.collector.create_accumulator());
365        acc.accumulate(&value);
366
367        // Compute new score
368        let new_result = acc.finish();
369        let new_base = (self.weight_fn)(&new_result);
370        let new_score = match impact {
371            ImpactType::Penalty => -new_base,
372            ImpactType::Reward => new_base,
373        };
374
375        // Track entity -> key mapping and cache value for correct retraction
376        self.entity_groups.insert(entity_index, key);
377        self.entity_values.insert(entity_index, value);
378
379        // Return delta
380        new_score - old
381    }
382
383    // Retract an A entity and return the score delta.
384    fn retract_entity(&mut self, _entities_a: &[A], _entities_b: &[B], entity_index: usize) -> Sc {
385        // Find which group this entity belonged to
386        let Some(key) = self.entity_groups.remove(&entity_index) else {
387            return Sc::zero();
388        };
389
390        // Use cached value (entity may have been mutated since insert)
391        let Some(value) = self.entity_values.remove(&entity_index) else {
392            return Sc::zero();
393        };
394        let impact = self.impact_type;
395
396        // Check if there's a B entity for this key
397        let b_idx = self.b_by_key.get(&key).copied();
398        if b_idx.is_none() {
399            // No B entity for this key - just update accumulator, no score delta
400            if let Some(acc) = self.groups.get_mut(&key) {
401                acc.retract(&value);
402            }
403            return Sc::zero();
404        }
405
406        // Get accumulator
407        let Some(acc) = self.groups.get_mut(&key) else {
408            return Sc::zero();
409        };
410
411        // Compute old score
412        let old_result = acc.finish();
413        let old_base = (self.weight_fn)(&old_result);
414        let old = match impact {
415            ImpactType::Penalty => -old_base,
416            ImpactType::Reward => old_base,
417        };
418
419        // Retract value
420        acc.retract(&value);
421
422        // Compute new score
423        let new_result = acc.finish();
424        let new_base = (self.weight_fn)(&new_result);
425        let new_score = match impact {
426            ImpactType::Penalty => -new_base,
427            ImpactType::Reward => new_base,
428        };
429
430        // Return delta
431        new_score - old
432    }
433}
434
435impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> std::fmt::Debug
436    for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
437where
438    C: UniCollector<A>,
439    Sc: Score,
440{
441    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
442        f.debug_struct("ComplementedGroupConstraint")
443            .field("name", &self.constraint_ref.name)
444            .field("impact_type", &self.impact_type)
445            .field("groups", &self.groups.len())
446            .finish()
447    }
448}