// solverforge_scoring/constraint/grouped.rs

//! Zero-erasure grouped constraint for group-by operations.
//!
//! Provides incremental scoring for constraints that group entities by key and
//! apply collectors to compute an aggregate score per group.
//! All type information is preserved at compile time - no Arc, no dyn.
6
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::marker::PhantomData;

use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};

use crate::api::constraint_set::IncrementalConstraint;
use crate::stream::collector::{Accumulator, UniCollector};
16
17/// Zero-erasure constraint that groups entities by key and scores based on collector results.
18///
19/// This enables incremental scoring for group-by operations:
20/// - Tracks which entities belong to which group
21/// - Maintains collector state per group
22/// - Computes score deltas when entities are added/removed
23///
24/// All type parameters are concrete - no trait objects, no Arc allocations.
25///
26/// # Type Parameters
27///
28/// - `S` - Solution type
29/// - `A` - Entity type
30/// - `K` - Group key type
31/// - `E` - Extractor function for entities
32/// - `KF` - Key function
33/// - `C` - Collector type
34/// - `W` - Weight function
35/// - `Sc` - Score type
36///
37/// # Example
38///
39/// ```
40/// use solverforge_scoring::constraint::grouped::GroupedUniConstraint;
41/// use solverforge_scoring::stream::collector::count;
42/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
43/// use solverforge_core::{ConstraintRef, ImpactType};
44/// use solverforge_core::score::SimpleScore;
45///
46/// #[derive(Clone, Hash, PartialEq, Eq)]
47/// struct Shift { employee_id: usize }
48///
49/// #[derive(Clone)]
50/// struct Solution { shifts: Vec<Shift> }
51///
52/// // Penalize based on squared workload per employee
53/// let constraint = GroupedUniConstraint::new(
54///     ConstraintRef::new("", "Balanced workload"),
55///     ImpactType::Penalty,
56///     |s: &Solution| &s.shifts,
57///     |shift: &Shift| shift.employee_id,
58///     count::<Shift>(),
59///     |count: &usize| SimpleScore::of((*count * *count) as i64),
60///     false,
61/// );
62///
63/// let solution = Solution {
64///     shifts: vec![
65///         Shift { employee_id: 1 },
66///         Shift { employee_id: 1 },
67///         Shift { employee_id: 1 },
68///         Shift { employee_id: 2 },
69///     ],
70/// };
71///
72/// // Employee 1: 3 shifts -> 9 penalty
73/// // Employee 2: 1 shift -> 1 penalty
74/// // Total: -10
75/// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-10));
76/// ```
77pub struct GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
78where
79    C: UniCollector<A>,
80    Sc: Score,
81{
82    constraint_ref: ConstraintRef,
83    impact_type: ImpactType,
84    extractor: E,
85    key_fn: KF,
86    collector: C,
87    weight_fn: W,
88    is_hard: bool,
89    /// Group key -> accumulator (scores computed on-the-fly, no cloning)
90    groups: HashMap<K, C::Accumulator>,
91    /// Entity index -> group key (for tracking which group an entity belongs to)
92    entity_groups: HashMap<usize, K>,
93    /// Entity index -> extracted value (for correct retraction after entity mutation)
94    entity_values: HashMap<usize, C::Value>,
95    _phantom: PhantomData<(S, A, Sc)>,
96}
97
98impl<S, A, K, E, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
99where
100    S: Send + Sync + 'static,
101    A: Clone + Send + Sync + 'static,
102    K: Clone + Eq + Hash + Send + Sync + 'static,
103    E: Fn(&S) -> &[A] + Send + Sync,
104    KF: Fn(&A) -> K + Send + Sync,
105    C: UniCollector<A> + Send + Sync + 'static,
106    C::Accumulator: Send + Sync,
107    C::Result: Send + Sync,
108    W: Fn(&C::Result) -> Sc + Send + Sync,
109    Sc: Score + 'static,
110{
111    /// Creates a new zero-erasure grouped constraint.
112    ///
113    /// # Arguments
114    ///
115    /// * `constraint_ref` - Identifier for this constraint
116    /// * `impact_type` - Whether to penalize or reward
117    /// * `extractor` - Function to get entity slice from solution
118    /// * `key_fn` - Function to extract group key from entity
119    /// * `collector` - Collector to aggregate entities per group
120    /// * `weight_fn` - Function to compute score from collector result
121    /// * `is_hard` - Whether this is a hard constraint
122    pub fn new(
123        constraint_ref: ConstraintRef,
124        impact_type: ImpactType,
125        extractor: E,
126        key_fn: KF,
127        collector: C,
128        weight_fn: W,
129        is_hard: bool,
130    ) -> Self {
131        Self {
132            constraint_ref,
133            impact_type,
134            extractor,
135            key_fn,
136            collector,
137            weight_fn,
138            is_hard,
139            groups: HashMap::new(),
140            entity_groups: HashMap::new(),
141            entity_values: HashMap::new(),
142            _phantom: PhantomData,
143        }
144    }
145
146    /// Computes the score contribution for a group's result.
147    fn compute_score(&self, result: &C::Result) -> Sc {
148        let base = (self.weight_fn)(result);
149        match self.impact_type {
150            ImpactType::Penalty => -base,
151            ImpactType::Reward => base,
152        }
153    }
154}
155
156impl<S, A, K, E, KF, C, W, Sc> IncrementalConstraint<S, Sc>
157    for GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
158where
159    S: Send + Sync + 'static,
160    A: Clone + Send + Sync + 'static,
161    K: Clone + Eq + Hash + Send + Sync + 'static,
162    E: Fn(&S) -> &[A] + Send + Sync,
163    KF: Fn(&A) -> K + Send + Sync,
164    C: UniCollector<A> + Send + Sync + 'static,
165    C::Accumulator: Send + Sync,
166    C::Result: Send + Sync,
167    C::Value: Send + Sync,
168    W: Fn(&C::Result) -> Sc + Send + Sync,
169    Sc: Score + 'static,
170{
171    fn evaluate(&self, solution: &S) -> Sc {
172        let entities = (self.extractor)(solution);
173
174        // Group entities by key
175        let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
176
177        for entity in entities {
178            let key = (self.key_fn)(entity);
179            let value = self.collector.extract(entity);
180            let acc = groups
181                .entry(key)
182                .or_insert_with(|| self.collector.create_accumulator());
183            acc.accumulate(&value);
184        }
185
186        // Sum scores for all groups
187        let mut total = Sc::zero();
188        for acc in groups.values() {
189            let result = acc.finish();
190            total = total + self.compute_score(&result);
191        }
192
193        total
194    }
195
196    fn match_count(&self, solution: &S) -> usize {
197        let entities = (self.extractor)(solution);
198
199        // Count unique groups
200        let mut groups: HashMap<K, ()> = HashMap::new();
201        for entity in entities {
202            let key = (self.key_fn)(entity);
203            groups.insert(key, ());
204        }
205
206        groups.len()
207    }
208
209    fn initialize(&mut self, solution: &S) -> Sc {
210        self.reset();
211
212        let entities = (self.extractor)(solution);
213        let mut total = Sc::zero();
214
215        for (idx, entity) in entities.iter().enumerate() {
216            total = total + self.insert_entity(entities, idx, entity);
217        }
218
219        total
220    }
221
222    fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
223        let entities = (self.extractor)(solution);
224        if entity_index >= entities.len() {
225            return Sc::zero();
226        }
227
228        let entity = &entities[entity_index];
229        self.insert_entity(entities, entity_index, entity)
230    }
231
232    fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
233        let entities = (self.extractor)(solution);
234        self.retract_entity(entities, entity_index)
235    }
236
237    fn reset(&mut self) {
238        self.groups.clear();
239        self.entity_groups.clear();
240        self.entity_values.clear();
241    }
242
243    fn name(&self) -> &str {
244        &self.constraint_ref.name
245    }
246
247    fn is_hard(&self) -> bool {
248        self.is_hard
249    }
250
251    fn constraint_ref(&self) -> ConstraintRef {
252        self.constraint_ref.clone()
253    }
254}
255
256impl<S, A, K, E, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
257where
258    S: Send + Sync + 'static,
259    A: Clone + Send + Sync + 'static,
260    K: Clone + Eq + Hash + Send + Sync + 'static,
261    E: Fn(&S) -> &[A] + Send + Sync,
262    KF: Fn(&A) -> K + Send + Sync,
263    C: UniCollector<A> + Send + Sync + 'static,
264    C::Accumulator: Send + Sync,
265    C::Result: Send + Sync,
266    C::Value: Send + Sync,
267    W: Fn(&C::Result) -> Sc + Send + Sync,
268    Sc: Score + 'static,
269{
270    fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
271        let key = (self.key_fn)(entity);
272        let value = self.collector.extract(entity);
273        let impact = self.impact_type;
274
275        // Get or create group accumulator
276        let acc = self
277            .groups
278            .entry(key.clone())
279            .or_insert_with(|| self.collector.create_accumulator());
280
281        // Compute old score from current state (inlined to avoid borrow conflict)
282        let old_base = (self.weight_fn)(&acc.finish());
283        let old = match impact {
284            ImpactType::Penalty => -old_base,
285            ImpactType::Reward => old_base,
286        };
287
288        // Accumulate and compute new score
289        acc.accumulate(&value);
290        let new_base = (self.weight_fn)(&acc.finish());
291        let new_score = match impact {
292            ImpactType::Penalty => -new_base,
293            ImpactType::Reward => new_base,
294        };
295
296        // Track entity -> group mapping and cache value for correct retraction
297        self.entity_groups.insert(entity_index, key);
298        self.entity_values.insert(entity_index, value);
299
300        // Return delta (both scores computed fresh, no cloning)
301        new_score - old
302    }
303
304    fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
305        // Find which group this entity belonged to
306        let Some(key) = self.entity_groups.remove(&entity_index) else {
307            return Sc::zero();
308        };
309
310        // Use cached value (entity may have been mutated since insert)
311        let Some(value) = self.entity_values.remove(&entity_index) else {
312            return Sc::zero();
313        };
314        let impact = self.impact_type;
315
316        // Get the group accumulator
317        let Some(acc) = self.groups.get_mut(&key) else {
318            return Sc::zero();
319        };
320
321        // Compute old score from current state (inlined to avoid borrow conflict)
322        let old_base = (self.weight_fn)(&acc.finish());
323        let old = match impact {
324            ImpactType::Penalty => -old_base,
325            ImpactType::Reward => old_base,
326        };
327
328        // Retract and compute new score
329        acc.retract(&value);
330        let new_base = (self.weight_fn)(&acc.finish());
331        let new_score = match impact {
332            ImpactType::Penalty => -new_base,
333            ImpactType::Reward => new_base,
334        };
335
336        // Return delta (both scores computed fresh, no cloning)
337        new_score - old
338    }
339}
340
341impl<S, A, K, E, KF, C, W, Sc> std::fmt::Debug for GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
342where
343    C: UniCollector<A>,
344    Sc: Score,
345{
346    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347        f.debug_struct("GroupedUniConstraint")
348            .field("name", &self.constraint_ref.name)
349            .field("impact_type", &self.impact_type)
350            .field("groups", &self.groups.len())
351            .finish()
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::collector::count;
    use solverforge_core::score::SimpleScore;

    #[derive(Clone, Hash, PartialEq, Eq)]
    struct Shift {
        employee_id: usize,
    }

    #[derive(Clone)]
    struct Solution {
        shifts: Vec<Shift>,
    }

    /// Builds a solution holding one shift per listed employee id, in order.
    fn roster(employee_ids: &[usize]) -> Solution {
        let shifts = employee_ids
            .iter()
            .map(|&employee_id| Shift { employee_id })
            .collect();
        Solution { shifts }
    }

    #[test]
    fn test_grouped_constraint_evaluate() {
        let constraint = GroupedUniConstraint::new(
            ConstraintRef::new("", "Workload"),
            ImpactType::Penalty,
            |s: &Solution| &s.shifts,
            |shift: &Shift| shift.employee_id,
            count::<Shift>(),
            |n: &usize| SimpleScore::of((*n * *n) as i64),
            false,
        );

        // Squared penalty: employee 1 has 3 shifts (9), employee 2 has 1 shift (1).
        let solution = roster(&[1, 1, 1, 2]);
        assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-10));
    }

    #[test]
    fn test_grouped_constraint_incremental() {
        let mut constraint = GroupedUniConstraint::new(
            ConstraintRef::new("", "Workload"),
            ImpactType::Penalty,
            |s: &Solution| &s.shifts,
            |shift: &Shift| shift.employee_id,
            count::<Shift>(),
            |n: &usize| SimpleScore::of(*n as i64),
            false,
        );
        let solution = roster(&[1, 1, 2]);

        // Linear penalty: employee 1 -> -2, employee 2 -> -1.
        assert_eq!(constraint.initialize(&solution), SimpleScore::of(-3));

        // Dropping one of employee 1's shifts lifts its penalty from -2 to -1.
        assert_eq!(constraint.on_retract(&solution, 0), SimpleScore::of(1));

        // Putting that shift back restores -2, i.e. a delta of -1.
        assert_eq!(constraint.on_insert(&solution, 0), SimpleScore::of(-1));
    }

    #[test]
    fn test_grouped_constraint_reward() {
        let constraint = GroupedUniConstraint::new(
            ConstraintRef::new("", "Collaboration"),
            ImpactType::Reward,
            |s: &Solution| &s.shifts,
            |shift: &Shift| shift.employee_id,
            count::<Shift>(),
            |n: &usize| SimpleScore::of(*n as i64),
            false,
        );

        // A single group of two shifts is rewarded with +2.
        let solution = roster(&[1, 1]);
        assert_eq!(constraint.evaluate(&solution), SimpleScore::of(2));
    }
}
455}