Skip to main content

solverforge_scoring/constraint/
grouped.rs

1/* Zero-erasure grouped constraint for group-by operations.
2
3Provides incremental scoring for constraints that group entities and
4apply collectors to compute aggregate scores.
5All type information is preserved at compile time - no Arc, no dyn.
6*/
7
8use std::collections::HashMap;
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use solverforge_core::score::Score;
13use solverforge_core::{ConstraintRef, ImpactType};
14
15use crate::api::constraint_set::IncrementalConstraint;
16use crate::stream::collector::{Accumulator, UniCollector};
17use crate::stream::filter::UniFilter;
18
19/* Zero-erasure constraint that groups entities by key and scores based on collector results.
20
21This enables incremental scoring for group-by operations:
22- Tracks which entities belong to which group
23- Maintains collector state per group
24- Computes score deltas when entities are added/removed
25
26All type parameters are concrete - no trait objects, no Arc allocations.
27
28# Type Parameters
29
30- `S` - Solution type
31- `A` - Entity type
32- `K` - Group key type
33- `E` - Extractor function for entities
34- `Fi` - Filter type (applied before grouping)
35- `KF` - Key function
36- `C` - Collector type
37- `W` - Weight function
38- `Sc` - Score type
39
40# Example
41
42```
43use solverforge_scoring::constraint::grouped::GroupedUniConstraint;
44use solverforge_scoring::stream::collector::count;
45use solverforge_scoring::stream::filter::TrueFilter;
46use solverforge_scoring::api::constraint_set::IncrementalConstraint;
47use solverforge_core::{ConstraintRef, ImpactType};
48use solverforge_core::score::SoftScore;
49
50#[derive(Clone, Hash, PartialEq, Eq)]
51struct Shift { employee_id: usize }
52
53#[derive(Clone)]
54struct Solution { shifts: Vec<Shift> }
55
56// Penalize based on squared workload per employee
57let constraint = GroupedUniConstraint::new(
58ConstraintRef::new("", "Balanced workload"),
59ImpactType::Penalty,
60|s: &Solution| &s.shifts,
61TrueFilter,
62|shift: &Shift| shift.employee_id,
63count::<Shift>(),
64|count: &usize| SoftScore::of((*count * *count) as i64),
65false,
66);
67
68let solution = Solution {
69shifts: vec![
70Shift { employee_id: 1 },
71Shift { employee_id: 1 },
72Shift { employee_id: 1 },
73Shift { employee_id: 2 },
74],
75};
76
77// Employee 1: 3 shifts -> 9 penalty
78// Employee 2: 1 shift -> 1 penalty
79// Total: -10
80assert_eq!(constraint.evaluate(&solution), SoftScore::of(-10));
81```
82*/
83pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
84where
85    C: UniCollector<A>,
86    Sc: Score,
87{
88    constraint_ref: ConstraintRef,
89    impact_type: ImpactType,
90    extractor: E,
91    filter: Fi,
92    key_fn: KF,
93    collector: C,
94    weight_fn: W,
95    is_hard: bool,
96    change_source: crate::stream::collection_extract::ChangeSource,
97    // Group key -> accumulator (scores computed on-the-fly, no cloning)
98    groups: HashMap<K, C::Accumulator>,
99    // Group key -> number of entities in the group (for empty-group detection)
100    group_counts: HashMap<K, usize>,
101    // Entity index -> group key (for tracking which group an entity belongs to)
102    entity_groups: HashMap<usize, K>,
103    // Entity index -> extracted value (for correct retraction after entity mutation)
104    entity_values: HashMap<usize, C::Value>,
105    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
106}
107
108impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
109where
110    S: Send + Sync + 'static,
111    A: Clone + Send + Sync + 'static,
112    K: Clone + Eq + Hash + Send + Sync + 'static,
113    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
114    Fi: UniFilter<S, A>,
115    KF: Fn(&A) -> K + Send + Sync,
116    C: UniCollector<A> + Send + Sync + 'static,
117    C::Accumulator: Send + Sync,
118    C::Result: Send + Sync,
119    W: Fn(&C::Result) -> Sc + Send + Sync,
120    Sc: Score + 'static,
121{
122    /* Creates a new zero-erasure grouped constraint.
123
124    # Arguments
125
126    * `constraint_ref` - Identifier for this constraint
127    * `impact_type` - Whether to penalize or reward
128    * `extractor` - Function to get entity slice from solution
129    * `filter` - Filter applied to entities before grouping
130    * `key_fn` - Function to extract group key from entity
131    * `collector` - Collector to aggregate entities per group
132    * `weight_fn` - Function to compute score from collector result
133    * `is_hard` - Whether this is a hard constraint
134    */
135    #[allow(clippy::too_many_arguments)]
136    pub fn new(
137        constraint_ref: ConstraintRef,
138        impact_type: ImpactType,
139        extractor: E,
140        filter: Fi,
141        key_fn: KF,
142        collector: C,
143        weight_fn: W,
144        is_hard: bool,
145    ) -> Self {
146        let change_source = extractor.change_source();
147        Self {
148            constraint_ref,
149            impact_type,
150            extractor,
151            filter,
152            key_fn,
153            collector,
154            weight_fn,
155            is_hard,
156            change_source,
157            groups: HashMap::new(),
158            group_counts: HashMap::new(),
159            entity_groups: HashMap::new(),
160            entity_values: HashMap::new(),
161            _phantom: PhantomData,
162        }
163    }
164
165    // Computes the score contribution for a group's result.
166    fn compute_score(&self, result: &C::Result) -> Sc {
167        let base = (self.weight_fn)(result);
168        match self.impact_type {
169            ImpactType::Penalty => -base,
170            ImpactType::Reward => base,
171        }
172    }
173}
174
175impl<S, A, K, E, Fi, KF, C, W, Sc> IncrementalConstraint<S, Sc>
176    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
177where
178    S: Send + Sync + 'static,
179    A: Clone + Send + Sync + 'static,
180    K: Clone + Eq + Hash + Send + Sync + 'static,
181    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
182    Fi: UniFilter<S, A>,
183    KF: Fn(&A) -> K + Send + Sync,
184    C: UniCollector<A> + Send + Sync + 'static,
185    C::Accumulator: Send + Sync,
186    C::Result: Send + Sync,
187    C::Value: Send + Sync,
188    W: Fn(&C::Result) -> Sc + Send + Sync,
189    Sc: Score + 'static,
190{
191    fn evaluate(&self, solution: &S) -> Sc {
192        let entities = self.extractor.extract(solution);
193
194        // Group entities by key, applying filter
195        let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
196
197        for entity in entities {
198            if !self.filter.test(solution, entity) {
199                continue;
200            }
201            let key = (self.key_fn)(entity);
202            let value = self.collector.extract(entity);
203            let acc = groups
204                .entry(key)
205                .or_insert_with(|| self.collector.create_accumulator());
206            acc.accumulate(&value);
207        }
208
209        // Sum scores for all groups
210        let mut total = Sc::zero();
211        for acc in groups.values() {
212            let result = acc.finish();
213            total = total + self.compute_score(&result);
214        }
215
216        total
217    }
218
219    fn match_count(&self, solution: &S) -> usize {
220        let entities = self.extractor.extract(solution);
221
222        // Count unique groups (filtered)
223        let mut groups: HashMap<K, ()> = HashMap::new();
224        for entity in entities {
225            if !self.filter.test(solution, entity) {
226                continue;
227            }
228            let key = (self.key_fn)(entity);
229            groups.insert(key, ());
230        }
231
232        groups.len()
233    }
234
235    fn initialize(&mut self, solution: &S) -> Sc {
236        self.reset();
237
238        let entities = self.extractor.extract(solution);
239        let mut total = Sc::zero();
240
241        for (idx, entity) in entities.iter().enumerate() {
242            if !self.filter.test(solution, entity) {
243                continue;
244            }
245            total = total + self.insert_entity(entities, idx, entity);
246        }
247
248        total
249    }
250
251    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
252        if !self
253            .change_source
254            .assert_localizes(descriptor_index, &self.constraint_ref.name)
255        {
256            return Sc::zero();
257        }
258        let entities = self.extractor.extract(solution);
259        if entity_index >= entities.len() {
260            return Sc::zero();
261        }
262
263        let entity = &entities[entity_index];
264        if !self.filter.test(solution, entity) {
265            return Sc::zero();
266        }
267        self.insert_entity(entities, entity_index, entity)
268    }
269
270    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
271        if !self
272            .change_source
273            .assert_localizes(descriptor_index, &self.constraint_ref.name)
274        {
275            return Sc::zero();
276        }
277        let entities = self.extractor.extract(solution);
278        self.retract_entity(entities, entity_index)
279    }
280
281    fn reset(&mut self) {
282        self.groups.clear();
283        self.group_counts.clear();
284        self.entity_groups.clear();
285        self.entity_values.clear();
286    }
287
288    fn name(&self) -> &str {
289        &self.constraint_ref.name
290    }
291
292    fn is_hard(&self) -> bool {
293        self.is_hard
294    }
295
296    fn constraint_ref(&self) -> &ConstraintRef {
297        &self.constraint_ref
298    }
299}
300
301impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
302where
303    S: Send + Sync + 'static,
304    A: Clone + Send + Sync + 'static,
305    K: Clone + Eq + Hash + Send + Sync + 'static,
306    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
307    Fi: UniFilter<S, A>,
308    KF: Fn(&A) -> K + Send + Sync,
309    C: UniCollector<A> + Send + Sync + 'static,
310    C::Accumulator: Send + Sync,
311    C::Result: Send + Sync,
312    C::Value: Send + Sync,
313    W: Fn(&C::Result) -> Sc + Send + Sync,
314    Sc: Score + 'static,
315{
316    fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
317        let key = (self.key_fn)(entity);
318        let value = self.collector.extract(entity);
319        let impact = self.impact_type;
320
321        // Get or create group accumulator
322        let is_new = !self.groups.contains_key(&key);
323        let acc = self
324            .groups
325            .entry(key.clone())
326            .or_insert_with(|| self.collector.create_accumulator());
327
328        // Old score is zero for new groups (they didn't exist before, contributing nothing)
329        let old = if is_new {
330            Sc::zero()
331        } else {
332            let old_base = (self.weight_fn)(&acc.finish());
333            match impact {
334                ImpactType::Penalty => -old_base,
335                ImpactType::Reward => old_base,
336            }
337        };
338
339        // Accumulate and compute new score
340        acc.accumulate(&value);
341        let new_base = (self.weight_fn)(&acc.finish());
342        let new_score = match impact {
343            ImpactType::Penalty => -new_base,
344            ImpactType::Reward => new_base,
345        };
346
347        // Track entity -> group mapping and cache value for correct retraction
348        self.entity_groups.insert(entity_index, key.clone());
349        self.entity_values.insert(entity_index, value);
350        *self.group_counts.entry(key).or_insert(0) += 1;
351
352        // Return delta (both scores computed fresh, no cloning)
353        new_score - old
354    }
355
356    fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
357        // Find which group this entity belonged to
358        let Some(key) = self.entity_groups.remove(&entity_index) else {
359            return Sc::zero();
360        };
361
362        // Use cached value (entity may have been mutated since insert)
363        let Some(value) = self.entity_values.remove(&entity_index) else {
364            return Sc::zero();
365        };
366        let impact = self.impact_type;
367
368        // Get the group accumulator
369        let Some(acc) = self.groups.get_mut(&key) else {
370            return Sc::zero();
371        };
372
373        // Compute old score from current state (inlined to avoid borrow conflict)
374        let old_base = (self.weight_fn)(&acc.finish());
375        let old = match impact {
376            ImpactType::Penalty => -old_base,
377            ImpactType::Reward => old_base,
378        };
379
380        // Decrement group count; remove group if now empty
381        let is_empty = {
382            let cnt = self.group_counts.entry(key.clone()).or_insert(0);
383            *cnt = cnt.saturating_sub(1);
384            *cnt == 0
385        };
386        if is_empty {
387            self.group_counts.remove(&key);
388        }
389
390        // Retract and compute new score
391        acc.retract(&value);
392        let new_score = if is_empty {
393            // Group is now empty; remove it and treat its contribution as zero
394            self.groups.remove(&key);
395            Sc::zero()
396        } else {
397            let new_base = (self.weight_fn)(&acc.finish());
398            match impact {
399                ImpactType::Penalty => -new_base,
400                ImpactType::Reward => new_base,
401            }
402        };
403
404        // Return delta (both scores computed fresh, no cloning)
405        new_score - old
406    }
407}
408
409impl<S, A, K, E, Fi, KF, C, W, Sc> std::fmt::Debug
410    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
411where
412    C: UniCollector<A>,
413    Sc: Score,
414{
415    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416        f.debug_struct("GroupedUniConstraint")
417            .field("name", &self.constraint_ref.name)
418            .field("impact_type", &self.impact_type)
419            .field("groups", &self.groups.len())
420            .finish()
421    }
422}