Skip to main content

solverforge_scoring/constraint/
grouped.rs

1/* Zero-erasure grouped constraint for group-by operations.
2
3Provides incremental scoring for constraints that group entities and
4apply collectors to compute aggregate scores.
5All type information is preserved at compile time - no Arc, no dyn.
6*/
7
8use std::collections::HashMap;
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use solverforge_core::score::Score;
13use solverforge_core::{ConstraintRef, ImpactType};
14
15use crate::api::constraint_set::IncrementalConstraint;
16use crate::stream::collector::{Accumulator, Collector};
17use crate::stream::filter::UniFilter;
18
// Per-group bookkeeping kept by the incremental scoring path.
struct GroupState<Acc> {
    // Collector accumulator holding the aggregated values of this group.
    accumulator: Acc,
    // Number of entities currently accumulated into this group; incremented on
    // insert, decremented on retract, and the group is dropped when it reaches zero.
    count: usize,
}
23
// Shorthand for the `Retraction` associated type of accumulator `Acc` (the token
// returned by `accumulate` and later handed back to `retract`).
type CollectorRetraction<Acc, V, R> = <Acc as Accumulator<V, R>>::Retraction;
25
/* Zero-erasure constraint that groups entities by key and scores based on collector results.

This enables incremental scoring for group-by operations:
- Tracks which entities belong to which group
- Maintains collector state per group
- Computes score deltas when entities are added/removed

All type parameters are concrete - no trait objects, no Arc allocations.

# Type Parameters

- `S` - Solution type
- `A` - Entity type
- `K` - Group key type
- `E` - Extractor function for entities
- `Fi` - Filter type (applied before grouping)
- `KF` - Key function
- `C` - Collector type
- `V` - Value type the collector extracts from an entity
- `R` - Result type of the collector, passed to the weight function
- `Acc` - Accumulator type created by the collector
- `W` - Weight function
- `Sc` - Score type

# Example

```
use solverforge_scoring::constraint::grouped::GroupedUniConstraint;
use solverforge_scoring::stream::collector::count;
use solverforge_scoring::stream::filter::TrueFilter;
use solverforge_scoring::api::constraint_set::IncrementalConstraint;
use solverforge_core::{ConstraintRef, ImpactType};
use solverforge_core::score::SoftScore;

#[derive(Clone, Hash, PartialEq, Eq)]
struct Shift { employee_id: usize }

#[derive(Clone)]
struct Solution { shifts: Vec<Shift> }

// Penalize based on squared workload per employee
let constraint = GroupedUniConstraint::new(
    ConstraintRef::new("", "Balanced workload"),
    ImpactType::Penalty,
    |s: &Solution| &s.shifts,
    TrueFilter,
    |shift: &Shift| shift.employee_id,
    count(),
    |_employee_id: &usize, count: &usize| SoftScore::of((*count * *count) as i64),
    false,
);

let solution = Solution {
    shifts: vec![
        Shift { employee_id: 1 },
        Shift { employee_id: 1 },
        Shift { employee_id: 1 },
        Shift { employee_id: 2 },
    ],
};

// Employee 1: 3 shifts -> 9 penalty
// Employee 2: 1 shift -> 1 penalty
// Total: -10
assert_eq!(constraint.evaluate(&solution), SoftScore::of(-10));
```
*/
pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
where
    Acc: Accumulator<V, R>,
    Sc: Score,
{
    // Identifier of this constraint; its `name` is what `name()` reports.
    constraint_ref: ConstraintRef,
    // Penalty negates the weight function's result, Reward uses it as-is.
    impact_type: ImpactType,
    // Yields the entity collection from a solution.
    extractor: E,
    // Entities failing this filter are never grouped.
    filter: Fi,
    // Maps an entity to its group key.
    key_fn: KF,
    // Extracts a value per entity and creates per-group accumulators.
    collector: C,
    // Maps (group key, collector result) to an unsigned score contribution.
    weight_fn: W,
    // Value reported by `is_hard()`.
    is_hard: bool,
    // Captured from `extractor.change_source()` at construction; consulted by
    // `on_insert` / `on_retract` before touching any state.
    change_source: crate::stream::collection_extract::ChangeSource,
    // Group key -> accumulator plus count (scores computed on-the-fly, no cloning)
    groups: HashMap<K, GroupState<Acc>>,
    // Entity index -> group key (for tracking which group an entity belongs to)
    entity_groups: HashMap<usize, K>,
    // Entity index -> accumulator retraction token
    entity_retractions: HashMap<usize, CollectorRetraction<Acc, V, R>>,
    // Ties the otherwise-unused type parameters to the struct without storing them.
    _phantom: PhantomData<(
        fn() -> S,
        fn() -> A,
        fn() -> V,
        fn() -> R,
        fn() -> Acc,
        fn() -> Sc,
    )>,
}
119
120impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
121    GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
122where
123    S: Send + Sync + 'static,
124    A: Clone + Send + Sync + 'static,
125    K: Clone + Eq + Hash + Send + Sync + 'static,
126    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
127    Fi: UniFilter<S, A>,
128    KF: Fn(&A) -> K + Send + Sync,
129    C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
130    V: Send + Sync + 'static,
131    R: Send + Sync + 'static,
132    Acc: Accumulator<V, R> + Send + Sync + 'static,
133    W: Fn(&K, &R) -> Sc + Send + Sync,
134    Sc: Score + 'static,
135{
136    /* Creates a new zero-erasure grouped constraint.
137
138    # Arguments
139
140    * `constraint_ref` - Identifier for this constraint
141    * `impact_type` - Whether to penalize or reward
142    * `extractor` - Function to get entity slice from solution
143    * `filter` - Filter applied to entities before grouping
144    * `key_fn` - Function to extract group key from entity
145    * `collector` - Collector to aggregate entities per group
146    * `weight_fn` - Function to compute score from collector result
147    * `is_hard` - Whether this is a hard constraint
148    */
149    #[allow(clippy::too_many_arguments)]
150    pub fn new(
151        constraint_ref: ConstraintRef,
152        impact_type: ImpactType,
153        extractor: E,
154        filter: Fi,
155        key_fn: KF,
156        collector: C,
157        weight_fn: W,
158        is_hard: bool,
159    ) -> Self {
160        let change_source = extractor.change_source();
161        Self {
162            constraint_ref,
163            impact_type,
164            extractor,
165            filter,
166            key_fn,
167            collector,
168            weight_fn,
169            is_hard,
170            change_source,
171            groups: HashMap::new(),
172            entity_groups: HashMap::new(),
173            entity_retractions: HashMap::new(),
174            _phantom: PhantomData,
175        }
176    }
177
178    // Computes the score contribution for a group's result.
179    fn compute_score(&self, key: &K, result: &R) -> Sc {
180        let base = (self.weight_fn)(key, result);
181        match self.impact_type {
182            ImpactType::Penalty => -base,
183            ImpactType::Reward => base,
184        }
185    }
186}
187
188impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc> IncrementalConstraint<S, Sc>
189    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
190where
191    S: Send + Sync + 'static,
192    A: Clone + Send + Sync + 'static,
193    K: Clone + Eq + Hash + Send + Sync + 'static,
194    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
195    Fi: UniFilter<S, A>,
196    KF: Fn(&A) -> K + Send + Sync,
197    C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
198    V: Send + Sync + 'static,
199    R: Send + Sync + 'static,
200    Acc: Accumulator<V, R> + Send + Sync + 'static,
201    W: Fn(&K, &R) -> Sc + Send + Sync,
202    Sc: Score + 'static,
203{
204    fn evaluate(&self, solution: &S) -> Sc {
205        let entities = self.extractor.extract(solution);
206
207        // Group entities by key, applying filter
208        let mut groups: HashMap<K, Acc> = HashMap::new();
209
210        for entity in entities {
211            if !self.filter.test(solution, entity) {
212                continue;
213            }
214            let key = (self.key_fn)(entity);
215            let value = self.collector.extract(entity);
216            let acc = groups
217                .entry(key)
218                .or_insert_with(|| self.collector.create_accumulator());
219            acc.accumulate(value);
220        }
221
222        // Sum scores for all groups
223        let mut total = Sc::zero();
224        for (key, acc) in &groups {
225            total = total + acc.with_result(|result| self.compute_score(key, result));
226        }
227
228        total
229    }
230
231    fn match_count(&self, solution: &S) -> usize {
232        let entities = self.extractor.extract(solution);
233
234        // Count unique groups (filtered)
235        let mut groups: HashMap<K, ()> = HashMap::new();
236        for entity in entities {
237            if !self.filter.test(solution, entity) {
238                continue;
239            }
240            let key = (self.key_fn)(entity);
241            groups.insert(key, ());
242        }
243
244        groups.len()
245    }
246
247    fn initialize(&mut self, solution: &S) -> Sc {
248        self.reset();
249
250        let entities = self.extractor.extract(solution);
251        let mut total = Sc::zero();
252
253        for (idx, entity) in entities.iter().enumerate() {
254            if !self.filter.test(solution, entity) {
255                continue;
256            }
257            total = total + self.insert_entity(entities, idx, entity);
258        }
259
260        total
261    }
262
263    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
264        if !self
265            .change_source
266            .assert_localizes(descriptor_index, &self.constraint_ref.name)
267        {
268            return Sc::zero();
269        }
270        let entities = self.extractor.extract(solution);
271        if entity_index >= entities.len() {
272            return Sc::zero();
273        }
274
275        let entity = &entities[entity_index];
276        if !self.filter.test(solution, entity) {
277            return Sc::zero();
278        }
279        self.insert_entity(entities, entity_index, entity)
280    }
281
282    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
283        if !self
284            .change_source
285            .assert_localizes(descriptor_index, &self.constraint_ref.name)
286        {
287            return Sc::zero();
288        }
289        let entities = self.extractor.extract(solution);
290        self.retract_entity(entities, entity_index)
291    }
292
293    fn reset(&mut self) {
294        self.groups.clear();
295        self.entity_groups.clear();
296        self.entity_retractions.clear();
297    }
298
299    fn name(&self) -> &str {
300        &self.constraint_ref.name
301    }
302
303    fn is_hard(&self) -> bool {
304        self.is_hard
305    }
306
307    fn constraint_ref(&self) -> &ConstraintRef {
308        &self.constraint_ref
309    }
310}
311
312impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
313    GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
314where
315    S: Send + Sync + 'static,
316    A: Clone + Send + Sync + 'static,
317    K: Clone + Eq + Hash + Send + Sync + 'static,
318    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
319    Fi: UniFilter<S, A>,
320    KF: Fn(&A) -> K + Send + Sync,
321    C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
322    V: Send + Sync + 'static,
323    R: Send + Sync + 'static,
324    Acc: Accumulator<V, R> + Send + Sync + 'static,
325    W: Fn(&K, &R) -> Sc + Send + Sync,
326    Sc: Score + 'static,
327{
328    fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
329        let key = (self.key_fn)(entity);
330        let entity_key = key.clone();
331        let value = self.collector.extract(entity);
332        let impact = self.impact_type;
333
334        let weight_fn = &self.weight_fn;
335        let (old, new_score) = match self.groups.entry(key) {
336            std::collections::hash_map::Entry::Occupied(mut entry) => {
337                let old_base = entry
338                    .get()
339                    .accumulator
340                    .with_result(|result| weight_fn(entry.key(), result));
341                let old = match impact {
342                    ImpactType::Penalty => -old_base,
343                    ImpactType::Reward => old_base,
344                };
345                let group = entry.get_mut();
346                let retraction = group.accumulator.accumulate(value);
347                group.count += 1;
348                let new_base = entry
349                    .get()
350                    .accumulator
351                    .with_result(|result| weight_fn(entry.key(), result));
352                let new_score = match impact {
353                    ImpactType::Penalty => -new_base,
354                    ImpactType::Reward => new_base,
355                };
356                self.entity_retractions.insert(entity_index, retraction);
357                (old, new_score)
358            }
359            std::collections::hash_map::Entry::Vacant(entry) => {
360                let mut entry = entry.insert_entry(GroupState {
361                    accumulator: self.collector.create_accumulator(),
362                    count: 0,
363                });
364                let group = entry.get_mut();
365                let retraction = group.accumulator.accumulate(value);
366                group.count += 1;
367                let new_base = entry
368                    .get()
369                    .accumulator
370                    .with_result(|result| weight_fn(entry.key(), result));
371                let new_score = match impact {
372                    ImpactType::Penalty => -new_base,
373                    ImpactType::Reward => new_base,
374                };
375                self.entity_retractions.insert(entity_index, retraction);
376                (Sc::zero(), new_score)
377            }
378        };
379
380        // Track entity -> group mapping and accumulator token for correct retraction.
381        self.entity_groups.insert(entity_index, entity_key);
382
383        // Return delta (both scores computed fresh, no cloning)
384        new_score - old
385    }
386
387    fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
388        // Find which group this entity belonged to
389        let Some(key) = self.entity_groups.remove(&entity_index) else {
390            return Sc::zero();
391        };
392
393        // Use cached retraction token (entity may have been mutated since insert)
394        let Some(retraction) = self.entity_retractions.remove(&entity_index) else {
395            return Sc::zero();
396        };
397        let impact = self.impact_type;
398
399        let weight_fn = &self.weight_fn;
400        let std::collections::hash_map::Entry::Occupied(mut entry) = self.groups.entry(key) else {
401            return Sc::zero();
402        };
403
404        let old_base = entry
405            .get()
406            .accumulator
407            .with_result(|result| weight_fn(entry.key(), result));
408        let old = match impact {
409            ImpactType::Penalty => -old_base,
410            ImpactType::Reward => old_base,
411        };
412
413        let group = entry.get_mut();
414        group.accumulator.retract(retraction);
415        group.count = group.count.saturating_sub(1);
416        let is_empty = group.count == 0;
417        let new_score = if is_empty {
418            entry.remove();
419            Sc::zero()
420        } else {
421            let new_base = entry
422                .get()
423                .accumulator
424                .with_result(|result| weight_fn(entry.key(), result));
425            match impact {
426                ImpactType::Penalty => -new_base,
427                ImpactType::Reward => new_base,
428            }
429        };
430
431        // Return delta (both scores computed fresh, no cloning)
432        new_score - old
433    }
434}
435
436impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc> std::fmt::Debug
437    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
438where
439    Acc: Accumulator<V, R>,
440    Sc: Score,
441{
442    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
443        f.debug_struct("GroupedUniConstraint")
444            .field("name", &self.constraint_ref.name)
445            .field("impact_type", &self.impact_type)
446            .field("groups", &self.groups.len())
447            .finish()
448    }
449}