Skip to main content

solverforge_scoring/constraint/
complemented.rs

/* Zero-erasure complemented group constraint.

Groups entities by key and scores every complement entity, using its group's
collected result when the group has members and a default value otherwise.
Provides true incremental scoring by tracking a per-key accumulator.
*/
6
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::constraint_set::IncrementalConstraint;
15use crate::stream::collector::{Accumulator, UniCollector};
16
17/* Zero-erasure constraint for complemented grouped results.
18
19Groups A entities by key, then iterates over B entities (complement source),
20using grouped values where they exist and default values otherwise.
21
22The key function for A returns `Option<K>`, allowing entities to be skipped
23when they don't have a valid key (e.g., unassigned shifts).
24
25# Type Parameters
26
27- `S` - Solution type
28- `A` - Entity type being grouped (e.g., Shift)
29- `B` - Complement entity type (e.g., Employee)
30- `K` - Group key type
31- `EA` - Extractor for A entities
32- `EB` - Extractor for B entities
33- `KA` - Key function for A (returns `Option<K>` to allow skipping)
34- `KB` - Key function for B
35- `C` - Collector type
36- `D` - Default value function
37- `W` - Weight function
38- `Sc` - Score type
39
40# Example
41
42```
43use solverforge_scoring::constraint::complemented::ComplementedGroupConstraint;
44use solverforge_scoring::stream::collector::count;
45use solverforge_scoring::api::constraint_set::IncrementalConstraint;
46use solverforge_core::{ConstraintRef, ImpactType};
47use solverforge_core::score::SoftScore;
48
49#[derive(Clone, Hash, PartialEq, Eq)]
50struct Employee { id: usize }
51
52#[derive(Clone)]
53struct Shift { employee_id: Option<usize> }
54
55#[derive(Clone)]
56struct Schedule {
57employees: Vec<Employee>,
58shifts: Vec<Shift>,
59}
60
61let constraint = ComplementedGroupConstraint::new(
62ConstraintRef::new("", "Shift count"),
63ImpactType::Penalty,
64|s: &Schedule| s.shifts.as_slice(),
65|s: &Schedule| s.employees.as_slice(),
66|shift: &Shift| shift.employee_id,  // Returns Option<usize>
67|emp: &Employee| emp.id,
68count(),
69|_emp: &Employee| 0usize,
70|count: &usize| SoftScore::of(*count as i64),
71false,
72);
73
74let schedule = Schedule {
75employees: vec![Employee { id: 0 }, Employee { id: 1 }],
76shifts: vec![
77Shift { employee_id: Some(0) },
78Shift { employee_id: Some(0) },
79Shift { employee_id: None },  // Skipped - no key
80],
81};
82
83// Employee 0: 2 shifts, Employee 1: 0 shifts → Total: -2
84// Unassigned shift is skipped
85assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
86```
87*/
88pub struct ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
89where
90    C: UniCollector<A>,
91    Sc: Score,
92{
93    constraint_ref: ConstraintRef,
94    impact_type: ImpactType,
95    extractor_a: EA,
96    extractor_b: EB,
97    key_a: KA,
98    key_b: KB,
99    collector: C,
100    default_fn: D,
101    weight_fn: W,
102    is_hard: bool,
103    // Group key -> accumulator for incremental scoring
104    groups: HashMap<K, C::Accumulator>,
105    // A entity index -> group key (for tracking which group each entity belongs to)
106    entity_groups: HashMap<usize, K>,
107    // A entity index -> extracted value (for correct retraction after entity mutation)
108    entity_values: HashMap<usize, C::Value>,
109    // B key -> B entity index (for looking up B entities by key)
110    b_by_key: HashMap<K, usize>,
111    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> Sc)>,
112}
113
114impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
115    ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
116where
117    S: 'static,
118    A: Clone + 'static,
119    B: Clone + 'static,
120    K: Clone + Eq + Hash,
121    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
122    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
123    KA: Fn(&A) -> Option<K>,
124    KB: Fn(&B) -> K,
125    C: UniCollector<A>,
126    C::Result: Clone,
127    D: Fn(&B) -> C::Result,
128    W: Fn(&C::Result) -> Sc,
129    Sc: Score,
130{
131    // Creates a new complemented group constraint.
132    #[allow(clippy::too_many_arguments)]
133    pub fn new(
134        constraint_ref: ConstraintRef,
135        impact_type: ImpactType,
136        extractor_a: EA,
137        extractor_b: EB,
138        key_a: KA,
139        key_b: KB,
140        collector: C,
141        default_fn: D,
142        weight_fn: W,
143        is_hard: bool,
144    ) -> Self {
145        Self {
146            constraint_ref,
147            impact_type,
148            extractor_a,
149            extractor_b,
150            key_a,
151            key_b,
152            collector,
153            default_fn,
154            weight_fn,
155            is_hard,
156            groups: HashMap::new(),
157            entity_groups: HashMap::new(),
158            entity_values: HashMap::new(),
159            b_by_key: HashMap::new(),
160            _phantom: PhantomData,
161        }
162    }
163
164    #[inline]
165    fn compute_score(&self, result: &C::Result) -> Sc {
166        let base = (self.weight_fn)(result);
167        match self.impact_type {
168            ImpactType::Penalty => -base,
169            ImpactType::Reward => base,
170        }
171    }
172
173    // Build grouped results from A entities.
174    fn build_groups(&self, entities_a: &[A]) -> HashMap<K, C::Result> {
175        let mut accumulators: HashMap<K, C::Accumulator> = HashMap::new();
176
177        for a in entities_a {
178            // Skip entities with no key (e.g., unassigned shifts)
179            let Some(key) = (self.key_a)(a) else {
180                continue;
181            };
182            let value = self.collector.extract(a);
183            accumulators
184                .entry(key)
185                .or_insert_with(|| self.collector.create_accumulator())
186                .accumulate(&value);
187        }
188
189        accumulators
190            .into_iter()
191            .map(|(k, acc)| (k, acc.finish()))
192            .collect()
193    }
194}
195
196impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> IncrementalConstraint<S, Sc>
197    for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
198where
199    S: Send + Sync + 'static,
200    A: Clone + Send + Sync + 'static,
201    B: Clone + Send + Sync + 'static,
202    K: Clone + Eq + Hash + Send + Sync,
203    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
204    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
205    KA: Fn(&A) -> Option<K> + Send + Sync,
206    KB: Fn(&B) -> K + Send + Sync,
207    C: UniCollector<A> + Send + Sync,
208    C::Accumulator: Send + Sync,
209    C::Result: Clone + Send + Sync,
210    C::Value: Send + Sync,
211    D: Fn(&B) -> C::Result + Send + Sync,
212    W: Fn(&C::Result) -> Sc + Send + Sync,
213    Sc: Score,
214{
215    fn evaluate(&self, solution: &S) -> Sc {
216        let entities_a = self.extractor_a.extract(solution);
217        let entities_b = self.extractor_b.extract(solution);
218
219        let groups = self.build_groups(entities_a);
220
221        let mut total = Sc::zero();
222        for b in entities_b {
223            let key = (self.key_b)(b);
224            let result = groups
225                .get(&key)
226                .cloned()
227                .unwrap_or_else(|| (self.default_fn)(b));
228            total = total + self.compute_score(&result);
229        }
230
231        total
232    }
233
234    fn match_count(&self, solution: &S) -> usize {
235        let entities_b = self.extractor_b.extract(solution);
236        entities_b.len()
237    }
238
239    fn initialize(&mut self, solution: &S) -> Sc {
240        self.reset();
241
242        let entities_a = self.extractor_a.extract(solution);
243        let entities_b = self.extractor_b.extract(solution);
244
245        // Build B key -> index mapping
246        for (idx, b) in entities_b.iter().enumerate() {
247            let key = (self.key_b)(b);
248            self.b_by_key.insert(key, idx);
249        }
250
251        // Initialize all B entities with default scores
252        let mut total = Sc::zero();
253        for b in entities_b {
254            let default_result = (self.default_fn)(b);
255            total = total + self.compute_score(&default_result);
256        }
257
258        // Now insert all A entities incrementally
259        for (idx, a) in entities_a.iter().enumerate() {
260            total = total + self.insert_entity(entities_b, idx, a);
261        }
262
263        total
264    }
265
266    fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
267        let entities_a = self.extractor_a.extract(solution);
268        let entities_b = self.extractor_b.extract(solution);
269
270        if entity_index >= entities_a.len() {
271            return Sc::zero();
272        }
273
274        let entity = &entities_a[entity_index];
275        self.insert_entity(entities_b, entity_index, entity)
276    }
277
278    fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
279        let entities_a = self.extractor_a.extract(solution);
280        let entities_b = self.extractor_b.extract(solution);
281
282        self.retract_entity(entities_a, entities_b, entity_index)
283    }
284
285    fn reset(&mut self) {
286        self.groups.clear();
287        self.entity_groups.clear();
288        self.entity_values.clear();
289        self.b_by_key.clear();
290    }
291
292    fn name(&self) -> &str {
293        &self.constraint_ref.name
294    }
295
296    fn is_hard(&self) -> bool {
297        self.is_hard
298    }
299
300    fn constraint_ref(&self) -> ConstraintRef {
301        self.constraint_ref.clone()
302    }
303}
304
305impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
306    ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
307where
308    S: Send + Sync + 'static,
309    A: Clone + Send + Sync + 'static,
310    B: Clone + Send + Sync + 'static,
311    K: Clone + Eq + Hash + Send + Sync,
312    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
313    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
314    KA: Fn(&A) -> Option<K> + Send + Sync,
315    KB: Fn(&B) -> K + Send + Sync,
316    C: UniCollector<A> + Send + Sync,
317    C::Accumulator: Send + Sync,
318    C::Result: Clone + Send + Sync,
319    C::Value: Send + Sync,
320    D: Fn(&B) -> C::Result + Send + Sync,
321    W: Fn(&C::Result) -> Sc + Send + Sync,
322    Sc: Score,
323{
324    // Insert an A entity and return the score delta.
325    fn insert_entity(&mut self, entities_b: &[B], entity_index: usize, entity: &A) -> Sc {
326        // Skip entities with no key (e.g., unassigned shifts)
327        let Some(key) = (self.key_a)(entity) else {
328            return Sc::zero();
329        };
330        let value = self.collector.extract(entity);
331        let impact = self.impact_type;
332
333        // Check if there's a B entity for this key
334        let b_idx = self.b_by_key.get(&key).copied();
335        let Some(b_idx) = b_idx else {
336            // No B entity for this key - A entity doesn't affect score
337            // Still track it for retraction
338            let acc = self
339                .groups
340                .entry(key.clone())
341                .or_insert_with(|| self.collector.create_accumulator());
342            acc.accumulate(&value);
343            self.entity_groups.insert(entity_index, key);
344            self.entity_values.insert(entity_index, value);
345            return Sc::zero();
346        };
347
348        let b = &entities_b[b_idx];
349
350        // Compute old score for this B entity
351        let old_result = self
352            .groups
353            .get(&key)
354            .map(|acc| acc.finish())
355            .unwrap_or_else(|| (self.default_fn)(b));
356        let old_base = (self.weight_fn)(&old_result);
357        let old = match impact {
358            ImpactType::Penalty => -old_base,
359            ImpactType::Reward => old_base,
360        };
361
362        // Get or create accumulator and add value
363        let acc = self
364            .groups
365            .entry(key.clone())
366            .or_insert_with(|| self.collector.create_accumulator());
367        acc.accumulate(&value);
368
369        // Compute new score
370        let new_result = acc.finish();
371        let new_base = (self.weight_fn)(&new_result);
372        let new_score = match impact {
373            ImpactType::Penalty => -new_base,
374            ImpactType::Reward => new_base,
375        };
376
377        // Track entity -> key mapping and cache value for correct retraction
378        self.entity_groups.insert(entity_index, key);
379        self.entity_values.insert(entity_index, value);
380
381        // Return delta
382        new_score - old
383    }
384
385    // Retract an A entity and return the score delta.
386    fn retract_entity(&mut self, _entities_a: &[A], _entities_b: &[B], entity_index: usize) -> Sc {
387        // Find which group this entity belonged to
388        let Some(key) = self.entity_groups.remove(&entity_index) else {
389            return Sc::zero();
390        };
391
392        // Use cached value (entity may have been mutated since insert)
393        let Some(value) = self.entity_values.remove(&entity_index) else {
394            return Sc::zero();
395        };
396        let impact = self.impact_type;
397
398        // Check if there's a B entity for this key
399        let b_idx = self.b_by_key.get(&key).copied();
400        if b_idx.is_none() {
401            // No B entity for this key - just update accumulator, no score delta
402            if let Some(acc) = self.groups.get_mut(&key) {
403                acc.retract(&value);
404            }
405            return Sc::zero();
406        }
407
408        // Get accumulator
409        let Some(acc) = self.groups.get_mut(&key) else {
410            return Sc::zero();
411        };
412
413        // Compute old score
414        let old_result = acc.finish();
415        let old_base = (self.weight_fn)(&old_result);
416        let old = match impact {
417            ImpactType::Penalty => -old_base,
418            ImpactType::Reward => old_base,
419        };
420
421        // Retract value
422        acc.retract(&value);
423
424        // Compute new score
425        let new_result = acc.finish();
426        let new_base = (self.weight_fn)(&new_result);
427        let new_score = match impact {
428            ImpactType::Penalty => -new_base,
429            ImpactType::Reward => new_base,
430        };
431
432        // Return delta
433        new_score - old
434    }
435}
436
437impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> std::fmt::Debug
438    for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
439where
440    C: UniCollector<A>,
441    Sc: Score,
442{
443    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444        f.debug_struct("ComplementedGroupConstraint")
445            .field("name", &self.constraint_ref.name)
446            .field("impact_type", &self.impact_type)
447            .field("groups", &self.groups.len())
448            .finish()
449    }
450}