Skip to main content

solverforge_scoring/constraint/
grouped.rs

1/* Zero-erasure grouped constraint for group-by operations.
2
3Provides incremental scoring for constraints that group entities and
4apply collectors to compute aggregate scores.
5All type information is preserved at compile time - no Arc, no dyn.
6*/
7
8use std::collections::HashMap;
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use solverforge_core::score::Score;
13use solverforge_core::{ConstraintRef, ImpactType};
14
15use crate::api::constraint_set::IncrementalConstraint;
16use crate::stream::collector::{Accumulator, UniCollector};
17use crate::stream::filter::UniFilter;
18
// Per-group bookkeeping stored as the value type of the constraint's `groups` map.
struct GroupState<Acc> {
    // Collector accumulator holding this group's running aggregate.
    accumulator: Acc,
    // Number of entities currently folded into this group. The group is dropped
    // from the map when a retraction brings this back to zero.
    count: usize,
}
23
24type CollectorRetraction<C, A> = <<C as UniCollector<A>>::Accumulator as Accumulator<
25    <C as UniCollector<A>>::Value,
26    <C as UniCollector<A>>::Result,
27>>::Retraction;
28
29/* Zero-erasure constraint that groups entities by key and scores based on collector results.
30
31This enables incremental scoring for group-by operations:
32- Tracks which entities belong to which group
33- Maintains collector state per group
34- Computes score deltas when entities are added/removed
35
36All type parameters are concrete - no trait objects, no Arc allocations.
37
38# Type Parameters
39
40- `S` - Solution type
41- `A` - Entity type
42- `K` - Group key type
43- `E` - Extractor function for entities
44- `Fi` - Filter type (applied before grouping)
45- `KF` - Key function
46- `C` - Collector type
47- `W` - Weight function
48- `Sc` - Score type
49
50# Example
51
52```
53use solverforge_scoring::constraint::grouped::GroupedUniConstraint;
54use solverforge_scoring::stream::collector::count;
55use solverforge_scoring::stream::filter::TrueFilter;
56use solverforge_scoring::api::constraint_set::IncrementalConstraint;
57use solverforge_core::{ConstraintRef, ImpactType};
58use solverforge_core::score::SoftScore;
59
60#[derive(Clone, Hash, PartialEq, Eq)]
61struct Shift { employee_id: usize }
62
63#[derive(Clone)]
64struct Solution { shifts: Vec<Shift> }
65
66// Penalize based on squared workload per employee
67let constraint = GroupedUniConstraint::new(
68ConstraintRef::new("", "Balanced workload"),
69ImpactType::Penalty,
70|s: &Solution| &s.shifts,
71TrueFilter,
72|shift: &Shift| shift.employee_id,
73count::<Shift>(),
74|_employee_id: &usize, count: &usize| SoftScore::of((*count * *count) as i64),
75false,
76);
77
78let solution = Solution {
79shifts: vec![
80Shift { employee_id: 1 },
81Shift { employee_id: 1 },
82Shift { employee_id: 1 },
83Shift { employee_id: 2 },
84],
85};
86
87// Employee 1: 3 shifts -> 9 penalty
88// Employee 2: 1 shift -> 1 penalty
89// Total: -10
90assert_eq!(constraint.evaluate(&solution), SoftScore::of(-10));
91```
92*/
93pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
94where
95    C: UniCollector<A>,
96    Sc: Score,
97{
98    constraint_ref: ConstraintRef,
99    impact_type: ImpactType,
100    extractor: E,
101    filter: Fi,
102    key_fn: KF,
103    collector: C,
104    weight_fn: W,
105    is_hard: bool,
106    change_source: crate::stream::collection_extract::ChangeSource,
107    // Group key -> accumulator plus count (scores computed on-the-fly, no cloning)
108    groups: HashMap<K, GroupState<C::Accumulator>>,
109    // Entity index -> group key (for tracking which group an entity belongs to)
110    entity_groups: HashMap<usize, K>,
111    // Entity index -> accumulator retraction token
112    entity_retractions: HashMap<usize, CollectorRetraction<C, A>>,
113    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
114}
115
116impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
117where
118    S: Send + Sync + 'static,
119    A: Clone + Send + Sync + 'static,
120    K: Clone + Eq + Hash + Send + Sync + 'static,
121    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
122    Fi: UniFilter<S, A>,
123    KF: Fn(&A) -> K + Send + Sync,
124    C: UniCollector<A> + Send + Sync + 'static,
125    C::Accumulator: Send + Sync,
126    C::Result: Send + Sync,
127    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
128    Sc: Score + 'static,
129{
130    /* Creates a new zero-erasure grouped constraint.
131
132    # Arguments
133
134    * `constraint_ref` - Identifier for this constraint
135    * `impact_type` - Whether to penalize or reward
136    * `extractor` - Function to get entity slice from solution
137    * `filter` - Filter applied to entities before grouping
138    * `key_fn` - Function to extract group key from entity
139    * `collector` - Collector to aggregate entities per group
140    * `weight_fn` - Function to compute score from collector result
141    * `is_hard` - Whether this is a hard constraint
142    */
143    #[allow(clippy::too_many_arguments)]
144    pub fn new(
145        constraint_ref: ConstraintRef,
146        impact_type: ImpactType,
147        extractor: E,
148        filter: Fi,
149        key_fn: KF,
150        collector: C,
151        weight_fn: W,
152        is_hard: bool,
153    ) -> Self {
154        let change_source = extractor.change_source();
155        Self {
156            constraint_ref,
157            impact_type,
158            extractor,
159            filter,
160            key_fn,
161            collector,
162            weight_fn,
163            is_hard,
164            change_source,
165            groups: HashMap::new(),
166            entity_groups: HashMap::new(),
167            entity_retractions: HashMap::new(),
168            _phantom: PhantomData,
169        }
170    }
171
172    // Computes the score contribution for a group's result.
173    fn compute_score(&self, key: &K, result: &C::Result) -> Sc {
174        let base = (self.weight_fn)(key, result);
175        match self.impact_type {
176            ImpactType::Penalty => -base,
177            ImpactType::Reward => base,
178        }
179    }
180}
181
182impl<S, A, K, E, Fi, KF, C, W, Sc> IncrementalConstraint<S, Sc>
183    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
184where
185    S: Send + Sync + 'static,
186    A: Clone + Send + Sync + 'static,
187    K: Clone + Eq + Hash + Send + Sync + 'static,
188    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
189    Fi: UniFilter<S, A>,
190    KF: Fn(&A) -> K + Send + Sync,
191    C: UniCollector<A> + Send + Sync + 'static,
192    C::Accumulator: Send + Sync,
193    C::Result: Send + Sync,
194    C::Value: Send + Sync,
195    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
196    Sc: Score + 'static,
197{
198    fn evaluate(&self, solution: &S) -> Sc {
199        let entities = self.extractor.extract(solution);
200
201        // Group entities by key, applying filter
202        let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
203
204        for entity in entities {
205            if !self.filter.test(solution, entity) {
206                continue;
207            }
208            let key = (self.key_fn)(entity);
209            let value = self.collector.extract(entity);
210            let acc = groups
211                .entry(key)
212                .or_insert_with(|| self.collector.create_accumulator());
213            acc.accumulate(value);
214        }
215
216        // Sum scores for all groups
217        let mut total = Sc::zero();
218        for (key, acc) in &groups {
219            total = total + acc.with_result(|result| self.compute_score(key, result));
220        }
221
222        total
223    }
224
225    fn match_count(&self, solution: &S) -> usize {
226        let entities = self.extractor.extract(solution);
227
228        // Count unique groups (filtered)
229        let mut groups: HashMap<K, ()> = HashMap::new();
230        for entity in entities {
231            if !self.filter.test(solution, entity) {
232                continue;
233            }
234            let key = (self.key_fn)(entity);
235            groups.insert(key, ());
236        }
237
238        groups.len()
239    }
240
241    fn initialize(&mut self, solution: &S) -> Sc {
242        self.reset();
243
244        let entities = self.extractor.extract(solution);
245        let mut total = Sc::zero();
246
247        for (idx, entity) in entities.iter().enumerate() {
248            if !self.filter.test(solution, entity) {
249                continue;
250            }
251            total = total + self.insert_entity(entities, idx, entity);
252        }
253
254        total
255    }
256
257    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
258        if !self
259            .change_source
260            .assert_localizes(descriptor_index, &self.constraint_ref.name)
261        {
262            return Sc::zero();
263        }
264        let entities = self.extractor.extract(solution);
265        if entity_index >= entities.len() {
266            return Sc::zero();
267        }
268
269        let entity = &entities[entity_index];
270        if !self.filter.test(solution, entity) {
271            return Sc::zero();
272        }
273        self.insert_entity(entities, entity_index, entity)
274    }
275
276    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
277        if !self
278            .change_source
279            .assert_localizes(descriptor_index, &self.constraint_ref.name)
280        {
281            return Sc::zero();
282        }
283        let entities = self.extractor.extract(solution);
284        self.retract_entity(entities, entity_index)
285    }
286
287    fn reset(&mut self) {
288        self.groups.clear();
289        self.entity_groups.clear();
290        self.entity_retractions.clear();
291    }
292
293    fn name(&self) -> &str {
294        &self.constraint_ref.name
295    }
296
297    fn is_hard(&self) -> bool {
298        self.is_hard
299    }
300
301    fn constraint_ref(&self) -> &ConstraintRef {
302        &self.constraint_ref
303    }
304}
305
306impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
307where
308    S: Send + Sync + 'static,
309    A: Clone + Send + Sync + 'static,
310    K: Clone + Eq + Hash + Send + Sync + 'static,
311    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
312    Fi: UniFilter<S, A>,
313    KF: Fn(&A) -> K + Send + Sync,
314    C: UniCollector<A> + Send + Sync + 'static,
315    C::Accumulator: Send + Sync,
316    C::Result: Send + Sync,
317    C::Value: Send + Sync,
318    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
319    Sc: Score + 'static,
320{
321    fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
322        let key = (self.key_fn)(entity);
323        let entity_key = key.clone();
324        let value = self.collector.extract(entity);
325        let impact = self.impact_type;
326
327        let weight_fn = &self.weight_fn;
328        let (old, new_score) = match self.groups.entry(key) {
329            std::collections::hash_map::Entry::Occupied(mut entry) => {
330                let old_base = entry
331                    .get()
332                    .accumulator
333                    .with_result(|result| weight_fn(entry.key(), result));
334                let old = match impact {
335                    ImpactType::Penalty => -old_base,
336                    ImpactType::Reward => old_base,
337                };
338                let group = entry.get_mut();
339                let retraction = group.accumulator.accumulate(value);
340                group.count += 1;
341                let new_base = entry
342                    .get()
343                    .accumulator
344                    .with_result(|result| weight_fn(entry.key(), result));
345                let new_score = match impact {
346                    ImpactType::Penalty => -new_base,
347                    ImpactType::Reward => new_base,
348                };
349                self.entity_retractions.insert(entity_index, retraction);
350                (old, new_score)
351            }
352            std::collections::hash_map::Entry::Vacant(entry) => {
353                let mut entry = entry.insert_entry(GroupState {
354                    accumulator: self.collector.create_accumulator(),
355                    count: 0,
356                });
357                let group = entry.get_mut();
358                let retraction = group.accumulator.accumulate(value);
359                group.count += 1;
360                let new_base = entry
361                    .get()
362                    .accumulator
363                    .with_result(|result| weight_fn(entry.key(), result));
364                let new_score = match impact {
365                    ImpactType::Penalty => -new_base,
366                    ImpactType::Reward => new_base,
367                };
368                self.entity_retractions.insert(entity_index, retraction);
369                (Sc::zero(), new_score)
370            }
371        };
372
373        // Track entity -> group mapping and accumulator token for correct retraction.
374        self.entity_groups.insert(entity_index, entity_key);
375
376        // Return delta (both scores computed fresh, no cloning)
377        new_score - old
378    }
379
380    fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
381        // Find which group this entity belonged to
382        let Some(key) = self.entity_groups.remove(&entity_index) else {
383            return Sc::zero();
384        };
385
386        // Use cached retraction token (entity may have been mutated since insert)
387        let Some(retraction) = self.entity_retractions.remove(&entity_index) else {
388            return Sc::zero();
389        };
390        let impact = self.impact_type;
391
392        let weight_fn = &self.weight_fn;
393        let std::collections::hash_map::Entry::Occupied(mut entry) = self.groups.entry(key) else {
394            return Sc::zero();
395        };
396
397        let old_base = entry
398            .get()
399            .accumulator
400            .with_result(|result| weight_fn(entry.key(), result));
401        let old = match impact {
402            ImpactType::Penalty => -old_base,
403            ImpactType::Reward => old_base,
404        };
405
406        let group = entry.get_mut();
407        group.accumulator.retract(retraction);
408        group.count = group.count.saturating_sub(1);
409        let is_empty = group.count == 0;
410        let new_score = if is_empty {
411            entry.remove();
412            Sc::zero()
413        } else {
414            let new_base = entry
415                .get()
416                .accumulator
417                .with_result(|result| weight_fn(entry.key(), result));
418            match impact {
419                ImpactType::Penalty => -new_base,
420                ImpactType::Reward => new_base,
421            }
422        };
423
424        // Return delta (both scores computed fresh, no cloning)
425        new_score - old
426    }
427}
428
429impl<S, A, K, E, Fi, KF, C, W, Sc> std::fmt::Debug
430    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
431where
432    C: UniCollector<A>,
433    Sc: Score,
434{
435    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
436        f.debug_struct("GroupedUniConstraint")
437            .field("name", &self.constraint_ref.name)
438            .field("impact_type", &self.impact_type)
439            .field("groups", &self.groups.len())
440            .finish()
441    }
442}