Skip to main content

solverforge_scoring/constraint/
grouped.rs

/* Zero-erasure grouped constraint for group-by operations.

Provides incremental scoring for constraints that group entities by key and
apply collectors to compute an aggregate score per group.
All type information is preserved at compile time - no Arc, no dyn.
*/
7
8use std::collections::HashMap;
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use solverforge_core::score::Score;
13use solverforge_core::{ConstraintRef, ImpactType};
14
15use crate::api::constraint_set::IncrementalConstraint;
16use crate::stream::collector::{Accumulator, UniCollector};
17use crate::stream::filter::UniFilter;
18
19struct GroupState<Acc> {
20    accumulator: Acc,
21    count: usize,
22}
23
24/* Zero-erasure constraint that groups entities by key and scores based on collector results.
25
26This enables incremental scoring for group-by operations:
27- Tracks which entities belong to which group
28- Maintains collector state per group
29- Computes score deltas when entities are added/removed
30
31All type parameters are concrete - no trait objects, no Arc allocations.
32
33# Type Parameters
34
35- `S` - Solution type
36- `A` - Entity type
37- `K` - Group key type
38- `E` - Extractor function for entities
39- `Fi` - Filter type (applied before grouping)
40- `KF` - Key function
41- `C` - Collector type
42- `W` - Weight function
43- `Sc` - Score type
44
45# Example
46
47```
48use solverforge_scoring::constraint::grouped::GroupedUniConstraint;
49use solverforge_scoring::stream::collector::count;
50use solverforge_scoring::stream::filter::TrueFilter;
51use solverforge_scoring::api::constraint_set::IncrementalConstraint;
52use solverforge_core::{ConstraintRef, ImpactType};
53use solverforge_core::score::SoftScore;
54
55#[derive(Clone, Hash, PartialEq, Eq)]
56struct Shift { employee_id: usize }
57
58#[derive(Clone)]
59struct Solution { shifts: Vec<Shift> }
60
61// Penalize based on squared workload per employee
62let constraint = GroupedUniConstraint::new(
63ConstraintRef::new("", "Balanced workload"),
64ImpactType::Penalty,
65|s: &Solution| &s.shifts,
66TrueFilter,
67|shift: &Shift| shift.employee_id,
68count::<Shift>(),
69|_employee_id: &usize, count: &usize| SoftScore::of((*count * *count) as i64),
70false,
71);
72
73let solution = Solution {
74shifts: vec![
75Shift { employee_id: 1 },
76Shift { employee_id: 1 },
77Shift { employee_id: 1 },
78Shift { employee_id: 2 },
79],
80};
81
82// Employee 1: 3 shifts -> 9 penalty
83// Employee 2: 1 shift -> 1 penalty
84// Total: -10
85assert_eq!(constraint.evaluate(&solution), SoftScore::of(-10));
86```
87*/
88pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
89where
90    C: UniCollector<A>,
91    Sc: Score,
92{
93    constraint_ref: ConstraintRef,
94    impact_type: ImpactType,
95    extractor: E,
96    filter: Fi,
97    key_fn: KF,
98    collector: C,
99    weight_fn: W,
100    is_hard: bool,
101    change_source: crate::stream::collection_extract::ChangeSource,
102    // Group key -> accumulator plus count (scores computed on-the-fly, no cloning)
103    groups: HashMap<K, GroupState<C::Accumulator>>,
104    // Entity index -> group key (for tracking which group an entity belongs to)
105    entity_groups: HashMap<usize, K>,
106    // Entity index -> extracted value (for correct retraction after entity mutation)
107    entity_values: HashMap<usize, C::Value>,
108    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
109}
110
111impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
112where
113    S: Send + Sync + 'static,
114    A: Clone + Send + Sync + 'static,
115    K: Clone + Eq + Hash + Send + Sync + 'static,
116    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
117    Fi: UniFilter<S, A>,
118    KF: Fn(&A) -> K + Send + Sync,
119    C: UniCollector<A> + Send + Sync + 'static,
120    C::Accumulator: Send + Sync,
121    C::Result: Send + Sync,
122    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
123    Sc: Score + 'static,
124{
125    /* Creates a new zero-erasure grouped constraint.
126
127    # Arguments
128
129    * `constraint_ref` - Identifier for this constraint
130    * `impact_type` - Whether to penalize or reward
131    * `extractor` - Function to get entity slice from solution
132    * `filter` - Filter applied to entities before grouping
133    * `key_fn` - Function to extract group key from entity
134    * `collector` - Collector to aggregate entities per group
135    * `weight_fn` - Function to compute score from collector result
136    * `is_hard` - Whether this is a hard constraint
137    */
138    #[allow(clippy::too_many_arguments)]
139    pub fn new(
140        constraint_ref: ConstraintRef,
141        impact_type: ImpactType,
142        extractor: E,
143        filter: Fi,
144        key_fn: KF,
145        collector: C,
146        weight_fn: W,
147        is_hard: bool,
148    ) -> Self {
149        let change_source = extractor.change_source();
150        Self {
151            constraint_ref,
152            impact_type,
153            extractor,
154            filter,
155            key_fn,
156            collector,
157            weight_fn,
158            is_hard,
159            change_source,
160            groups: HashMap::new(),
161            entity_groups: HashMap::new(),
162            entity_values: HashMap::new(),
163            _phantom: PhantomData,
164        }
165    }
166
167    // Computes the score contribution for a group's result.
168    fn compute_score(&self, key: &K, result: &C::Result) -> Sc {
169        let base = (self.weight_fn)(key, result);
170        match self.impact_type {
171            ImpactType::Penalty => -base,
172            ImpactType::Reward => base,
173        }
174    }
175}
176
177impl<S, A, K, E, Fi, KF, C, W, Sc> IncrementalConstraint<S, Sc>
178    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
179where
180    S: Send + Sync + 'static,
181    A: Clone + Send + Sync + 'static,
182    K: Clone + Eq + Hash + Send + Sync + 'static,
183    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
184    Fi: UniFilter<S, A>,
185    KF: Fn(&A) -> K + Send + Sync,
186    C: UniCollector<A> + Send + Sync + 'static,
187    C::Accumulator: Send + Sync,
188    C::Result: Send + Sync,
189    C::Value: Send + Sync,
190    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
191    Sc: Score + 'static,
192{
193    fn evaluate(&self, solution: &S) -> Sc {
194        let entities = self.extractor.extract(solution);
195
196        // Group entities by key, applying filter
197        let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
198
199        for entity in entities {
200            if !self.filter.test(solution, entity) {
201                continue;
202            }
203            let key = (self.key_fn)(entity);
204            let value = self.collector.extract(entity);
205            let acc = groups
206                .entry(key)
207                .or_insert_with(|| self.collector.create_accumulator());
208            acc.accumulate(&value);
209        }
210
211        // Sum scores for all groups
212        let mut total = Sc::zero();
213        for (key, acc) in &groups {
214            let result = acc.finish();
215            total = total + self.compute_score(key, &result);
216        }
217
218        total
219    }
220
221    fn match_count(&self, solution: &S) -> usize {
222        let entities = self.extractor.extract(solution);
223
224        // Count unique groups (filtered)
225        let mut groups: HashMap<K, ()> = HashMap::new();
226        for entity in entities {
227            if !self.filter.test(solution, entity) {
228                continue;
229            }
230            let key = (self.key_fn)(entity);
231            groups.insert(key, ());
232        }
233
234        groups.len()
235    }
236
237    fn initialize(&mut self, solution: &S) -> Sc {
238        self.reset();
239
240        let entities = self.extractor.extract(solution);
241        let mut total = Sc::zero();
242
243        for (idx, entity) in entities.iter().enumerate() {
244            if !self.filter.test(solution, entity) {
245                continue;
246            }
247            total = total + self.insert_entity(entities, idx, entity);
248        }
249
250        total
251    }
252
253    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
254        if !self
255            .change_source
256            .assert_localizes(descriptor_index, &self.constraint_ref.name)
257        {
258            return Sc::zero();
259        }
260        let entities = self.extractor.extract(solution);
261        if entity_index >= entities.len() {
262            return Sc::zero();
263        }
264
265        let entity = &entities[entity_index];
266        if !self.filter.test(solution, entity) {
267            return Sc::zero();
268        }
269        self.insert_entity(entities, entity_index, entity)
270    }
271
272    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
273        if !self
274            .change_source
275            .assert_localizes(descriptor_index, &self.constraint_ref.name)
276        {
277            return Sc::zero();
278        }
279        let entities = self.extractor.extract(solution);
280        self.retract_entity(entities, entity_index)
281    }
282
283    fn reset(&mut self) {
284        self.groups.clear();
285        self.entity_groups.clear();
286        self.entity_values.clear();
287    }
288
289    fn name(&self) -> &str {
290        &self.constraint_ref.name
291    }
292
293    fn is_hard(&self) -> bool {
294        self.is_hard
295    }
296
297    fn constraint_ref(&self) -> &ConstraintRef {
298        &self.constraint_ref
299    }
300}
301
302impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
303where
304    S: Send + Sync + 'static,
305    A: Clone + Send + Sync + 'static,
306    K: Clone + Eq + Hash + Send + Sync + 'static,
307    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
308    Fi: UniFilter<S, A>,
309    KF: Fn(&A) -> K + Send + Sync,
310    C: UniCollector<A> + Send + Sync + 'static,
311    C::Accumulator: Send + Sync,
312    C::Result: Send + Sync,
313    C::Value: Send + Sync,
314    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
315    Sc: Score + 'static,
316{
317    fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
318        let key = (self.key_fn)(entity);
319        let entity_key = (self.key_fn)(entity);
320        let value = self.collector.extract(entity);
321        let impact = self.impact_type;
322
323        let weight_fn = &self.weight_fn;
324        let (old, new_score) = match self.groups.entry(key) {
325            std::collections::hash_map::Entry::Occupied(mut entry) => {
326                let old_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
327                let old = match impact {
328                    ImpactType::Penalty => -old_base,
329                    ImpactType::Reward => old_base,
330                };
331                let group = entry.get_mut();
332                group.accumulator.accumulate(&value);
333                group.count += 1;
334                let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
335                let new_score = match impact {
336                    ImpactType::Penalty => -new_base,
337                    ImpactType::Reward => new_base,
338                };
339                (old, new_score)
340            }
341            std::collections::hash_map::Entry::Vacant(entry) => {
342                let mut entry = entry.insert_entry(GroupState {
343                    accumulator: self.collector.create_accumulator(),
344                    count: 0,
345                });
346                let group = entry.get_mut();
347                group.accumulator.accumulate(&value);
348                group.count += 1;
349                let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
350                let new_score = match impact {
351                    ImpactType::Penalty => -new_base,
352                    ImpactType::Reward => new_base,
353                };
354                (Sc::zero(), new_score)
355            }
356        };
357
358        // Track entity -> group mapping and cache value for correct retraction
359        self.entity_groups.insert(entity_index, entity_key);
360        self.entity_values.insert(entity_index, value);
361
362        // Return delta (both scores computed fresh, no cloning)
363        new_score - old
364    }
365
366    fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
367        // Find which group this entity belonged to
368        let Some(key) = self.entity_groups.remove(&entity_index) else {
369            return Sc::zero();
370        };
371
372        // Use cached value (entity may have been mutated since insert)
373        let Some(value) = self.entity_values.remove(&entity_index) else {
374            return Sc::zero();
375        };
376        let impact = self.impact_type;
377
378        let weight_fn = &self.weight_fn;
379        let std::collections::hash_map::Entry::Occupied(mut entry) = self.groups.entry(key) else {
380            return Sc::zero();
381        };
382
383        let old_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
384        let old = match impact {
385            ImpactType::Penalty => -old_base,
386            ImpactType::Reward => old_base,
387        };
388
389        let group = entry.get_mut();
390        group.accumulator.retract(&value);
391        group.count = group.count.saturating_sub(1);
392        let is_empty = group.count == 0;
393        let new_score = if is_empty {
394            entry.remove();
395            Sc::zero()
396        } else {
397            let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
398            match impact {
399                ImpactType::Penalty => -new_base,
400                ImpactType::Reward => new_base,
401            }
402        };
403
404        // Return delta (both scores computed fresh, no cloning)
405        new_score - old
406    }
407}
408
409impl<S, A, K, E, Fi, KF, C, W, Sc> std::fmt::Debug
410    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
411where
412    C: UniCollector<A>,
413    Sc: Score,
414{
415    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416        f.debug_struct("GroupedUniConstraint")
417            .field("name", &self.constraint_ref.name)
418            .field("impact_type", &self.impact_type)
419            .field("groups", &self.groups.len())
420            .finish()
421    }
422}