Skip to main content

solverforge_scoring/constraint/
grouped.rs

/* Zero-erasure grouped constraint for group-by operations.

Provides incremental scoring for constraints that group entities by a key and
apply a collector to each group to compute aggregate scores.
All type information is preserved at compile time: no `Arc`, no `dyn`.
*/
7
8use std::collections::HashMap;
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use solverforge_core::score::Score;
13use solverforge_core::{ConstraintRef, ImpactType};
14
15use crate::api::constraint_set::IncrementalConstraint;
16use crate::stream::collector::{Accumulator, UniCollector};
17use crate::stream::filter::UniFilter;
18
/* Zero-erasure constraint that groups entities by key and scores each group from
its collector result.

Supports incremental scoring for group-by operations:
- remembers which group each tracked entity index was inserted into,
- keeps one collector accumulator per non-empty group,
- returns score deltas when entities are inserted or retracted.

All type parameters are concrete: no trait objects, no `Arc` allocations.

# Type Parameters

- `S` - Solution type
- `A` - Entity type
- `K` - Group key type (produced by `KF`)
- `E` - Extractor that yields the entity collection from a solution
- `Fi` - Filter applied to each entity before grouping
- `KF` - Key function mapping an entity to its group key
- `C` - Collector that aggregates the entities of one group
- `W` - Weight function mapping a collector result to a score
- `Sc` - Score type

# Scoring

Each non-empty group contributes `weight_fn(result)`. The contribution is negated
for `ImpactType::Penalty` and used unchanged for `ImpactType::Reward`. Groups with
no members contribute nothing.

# Example

```
use solverforge_scoring::constraint::grouped::GroupedUniConstraint;
use solverforge_scoring::stream::collector::count;
use solverforge_scoring::stream::filter::TrueFilter;
use solverforge_scoring::api::constraint_set::IncrementalConstraint;
use solverforge_core::{ConstraintRef, ImpactType};
use solverforge_core::score::SoftScore;

#[derive(Clone, Hash, PartialEq, Eq)]
struct Shift { employee_id: usize }

#[derive(Clone)]
struct Solution { shifts: Vec<Shift> }

// Penalize based on squared workload per employee
let constraint = GroupedUniConstraint::new(
    ConstraintRef::new("", "Balanced workload"),
    ImpactType::Penalty,
    |s: &Solution| &s.shifts,
    TrueFilter,
    |shift: &Shift| shift.employee_id,
    count::<Shift>(),
    |count: &usize| SoftScore::of((*count * *count) as i64),
    false,
);

let solution = Solution {
    shifts: vec![
        Shift { employee_id: 1 },
        Shift { employee_id: 1 },
        Shift { employee_id: 1 },
        Shift { employee_id: 2 },
    ],
};

// Employee 1: 3 shifts -> 9 penalty
// Employee 2: 1 shift -> 1 penalty
// Total: -10
assert_eq!(constraint.evaluate(&solution), SoftScore::of(-10));
```
*/
pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
where
    C: UniCollector<A>,
    Sc: Score,
{
    // Identifier reported through `name()` / `constraint_ref()`.
    constraint_ref: ConstraintRef,
    // Sign applied to every group weight (Penalty negates, Reward keeps it).
    impact_type: ImpactType,
    extractor: E,
    filter: Fi,
    key_fn: KF,
    collector: C,
    weight_fn: W,
    is_hard: bool,
    // When `Some`, `on_insert` / `on_retract` ignore events for any other descriptor index.
    expected_descriptor: Option<usize>,
    // Group key -> accumulator for every non-empty group. Scores are derived on demand
    // via `finish()` instead of being cached.
    groups: HashMap<K, C::Accumulator>,
    // Group key -> number of tracked entities in the group; a group is dropped when
    // this reaches zero.
    group_counts: HashMap<K, usize>,
    // Entity index -> group key computed when the entity was inserted.
    entity_groups: HashMap<usize, K>,
    // Entity index -> collector value extracted at insert time, so a retraction undoes
    // exactly what was accumulated even if the entity has been mutated since.
    entity_values: HashMap<usize, C::Value>,
    // `fn() -> T` markers are Send + Sync for any T, so S, A and Sc do not affect the
    // struct's auto traits.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
}
107
108impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
109where
110    S: Send + Sync + 'static,
111    A: Clone + Send + Sync + 'static,
112    K: Clone + Eq + Hash + Send + Sync + 'static,
113    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
114    Fi: UniFilter<S, A>,
115    KF: Fn(&A) -> K + Send + Sync,
116    C: UniCollector<A> + Send + Sync + 'static,
117    C::Accumulator: Send + Sync,
118    C::Result: Send + Sync,
119    W: Fn(&C::Result) -> Sc + Send + Sync,
120    Sc: Score + 'static,
121{
122    /* Creates a new zero-erasure grouped constraint.
123
124    # Arguments
125
126    * `constraint_ref` - Identifier for this constraint
127    * `impact_type` - Whether to penalize or reward
128    * `extractor` - Function to get entity slice from solution
129    * `filter` - Filter applied to entities before grouping
130    * `key_fn` - Function to extract group key from entity
131    * `collector` - Collector to aggregate entities per group
132    * `weight_fn` - Function to compute score from collector result
133    * `is_hard` - Whether this is a hard constraint
134    */
135    #[allow(clippy::too_many_arguments)]
136    pub fn new(
137        constraint_ref: ConstraintRef,
138        impact_type: ImpactType,
139        extractor: E,
140        filter: Fi,
141        key_fn: KF,
142        collector: C,
143        weight_fn: W,
144        is_hard: bool,
145    ) -> Self {
146        Self {
147            constraint_ref,
148            impact_type,
149            extractor,
150            filter,
151            key_fn,
152            collector,
153            weight_fn,
154            is_hard,
155            expected_descriptor: None,
156            groups: HashMap::new(),
157            group_counts: HashMap::new(),
158            entity_groups: HashMap::new(),
159            entity_values: HashMap::new(),
160            _phantom: PhantomData,
161        }
162    }
163
164    pub fn with_descriptor(mut self, descriptor_index: usize) -> Self {
165        self.expected_descriptor = Some(descriptor_index);
166        self
167    }
168
169    // Computes the score contribution for a group's result.
170    fn compute_score(&self, result: &C::Result) -> Sc {
171        let base = (self.weight_fn)(result);
172        match self.impact_type {
173            ImpactType::Penalty => -base,
174            ImpactType::Reward => base,
175        }
176    }
177}
178
179impl<S, A, K, E, Fi, KF, C, W, Sc> IncrementalConstraint<S, Sc>
180    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
181where
182    S: Send + Sync + 'static,
183    A: Clone + Send + Sync + 'static,
184    K: Clone + Eq + Hash + Send + Sync + 'static,
185    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
186    Fi: UniFilter<S, A>,
187    KF: Fn(&A) -> K + Send + Sync,
188    C: UniCollector<A> + Send + Sync + 'static,
189    C::Accumulator: Send + Sync,
190    C::Result: Send + Sync,
191    C::Value: Send + Sync,
192    W: Fn(&C::Result) -> Sc + Send + Sync,
193    Sc: Score + 'static,
194{
195    fn evaluate(&self, solution: &S) -> Sc {
196        let entities = self.extractor.extract(solution);
197
198        // Group entities by key, applying filter
199        let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
200
201        for entity in entities {
202            if !self.filter.test(solution, entity) {
203                continue;
204            }
205            let key = (self.key_fn)(entity);
206            let value = self.collector.extract(entity);
207            let acc = groups
208                .entry(key)
209                .or_insert_with(|| self.collector.create_accumulator());
210            acc.accumulate(&value);
211        }
212
213        // Sum scores for all groups
214        let mut total = Sc::zero();
215        for acc in groups.values() {
216            let result = acc.finish();
217            total = total + self.compute_score(&result);
218        }
219
220        total
221    }
222
223    fn match_count(&self, solution: &S) -> usize {
224        let entities = self.extractor.extract(solution);
225
226        // Count unique groups (filtered)
227        let mut groups: HashMap<K, ()> = HashMap::new();
228        for entity in entities {
229            if !self.filter.test(solution, entity) {
230                continue;
231            }
232            let key = (self.key_fn)(entity);
233            groups.insert(key, ());
234        }
235
236        groups.len()
237    }
238
239    fn initialize(&mut self, solution: &S) -> Sc {
240        self.reset();
241
242        let entities = self.extractor.extract(solution);
243        let mut total = Sc::zero();
244
245        for (idx, entity) in entities.iter().enumerate() {
246            if !self.filter.test(solution, entity) {
247                continue;
248            }
249            total = total + self.insert_entity(entities, idx, entity);
250        }
251
252        total
253    }
254
255    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
256        if let Some(expected) = self.expected_descriptor {
257            if descriptor_index != expected {
258                return Sc::zero();
259            }
260        }
261        let entities = self.extractor.extract(solution);
262        if entity_index >= entities.len() {
263            return Sc::zero();
264        }
265
266        let entity = &entities[entity_index];
267        if !self.filter.test(solution, entity) {
268            return Sc::zero();
269        }
270        self.insert_entity(entities, entity_index, entity)
271    }
272
273    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
274        if let Some(expected) = self.expected_descriptor {
275            if descriptor_index != expected {
276                return Sc::zero();
277            }
278        }
279        let entities = self.extractor.extract(solution);
280        self.retract_entity(entities, entity_index)
281    }
282
283    fn reset(&mut self) {
284        self.groups.clear();
285        self.group_counts.clear();
286        self.entity_groups.clear();
287        self.entity_values.clear();
288    }
289
290    fn name(&self) -> &str {
291        &self.constraint_ref.name
292    }
293
294    fn is_hard(&self) -> bool {
295        self.is_hard
296    }
297
298    fn constraint_ref(&self) -> ConstraintRef {
299        self.constraint_ref.clone()
300    }
301}
302
303impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
304where
305    S: Send + Sync + 'static,
306    A: Clone + Send + Sync + 'static,
307    K: Clone + Eq + Hash + Send + Sync + 'static,
308    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
309    Fi: UniFilter<S, A>,
310    KF: Fn(&A) -> K + Send + Sync,
311    C: UniCollector<A> + Send + Sync + 'static,
312    C::Accumulator: Send + Sync,
313    C::Result: Send + Sync,
314    C::Value: Send + Sync,
315    W: Fn(&C::Result) -> Sc + Send + Sync,
316    Sc: Score + 'static,
317{
318    fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
319        let key = (self.key_fn)(entity);
320        let value = self.collector.extract(entity);
321        let impact = self.impact_type;
322
323        // Get or create group accumulator
324        let is_new = !self.groups.contains_key(&key);
325        let acc = self
326            .groups
327            .entry(key.clone())
328            .or_insert_with(|| self.collector.create_accumulator());
329
330        // Old score is zero for new groups (they didn't exist before, contributing nothing)
331        let old = if is_new {
332            Sc::zero()
333        } else {
334            let old_base = (self.weight_fn)(&acc.finish());
335            match impact {
336                ImpactType::Penalty => -old_base,
337                ImpactType::Reward => old_base,
338            }
339        };
340
341        // Accumulate and compute new score
342        acc.accumulate(&value);
343        let new_base = (self.weight_fn)(&acc.finish());
344        let new_score = match impact {
345            ImpactType::Penalty => -new_base,
346            ImpactType::Reward => new_base,
347        };
348
349        // Track entity -> group mapping and cache value for correct retraction
350        self.entity_groups.insert(entity_index, key.clone());
351        self.entity_values.insert(entity_index, value);
352        *self.group_counts.entry(key).or_insert(0) += 1;
353
354        // Return delta (both scores computed fresh, no cloning)
355        new_score - old
356    }
357
358    fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
359        // Find which group this entity belonged to
360        let Some(key) = self.entity_groups.remove(&entity_index) else {
361            return Sc::zero();
362        };
363
364        // Use cached value (entity may have been mutated since insert)
365        let Some(value) = self.entity_values.remove(&entity_index) else {
366            return Sc::zero();
367        };
368        let impact = self.impact_type;
369
370        // Get the group accumulator
371        let Some(acc) = self.groups.get_mut(&key) else {
372            return Sc::zero();
373        };
374
375        // Compute old score from current state (inlined to avoid borrow conflict)
376        let old_base = (self.weight_fn)(&acc.finish());
377        let old = match impact {
378            ImpactType::Penalty => -old_base,
379            ImpactType::Reward => old_base,
380        };
381
382        // Decrement group count; remove group if now empty
383        let is_empty = {
384            let cnt = self.group_counts.entry(key.clone()).or_insert(0);
385            *cnt = cnt.saturating_sub(1);
386            *cnt == 0
387        };
388        if is_empty {
389            self.group_counts.remove(&key);
390        }
391
392        // Retract and compute new score
393        acc.retract(&value);
394        let new_score = if is_empty {
395            // Group is now empty; remove it and treat its contribution as zero
396            self.groups.remove(&key);
397            Sc::zero()
398        } else {
399            let new_base = (self.weight_fn)(&acc.finish());
400            match impact {
401                ImpactType::Penalty => -new_base,
402                ImpactType::Reward => new_base,
403            }
404        };
405
406        // Return delta (both scores computed fresh, no cloning)
407        new_score - old
408    }
409}
410
411impl<S, A, K, E, Fi, KF, C, W, Sc> std::fmt::Debug
412    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
413where
414    C: UniCollector<A>,
415    Sc: Score,
416{
417    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
418        f.debug_struct("GroupedUniConstraint")
419            .field("name", &self.constraint_ref.name)
420            .field("impact_type", &self.impact_type)
421            .field("groups", &self.groups.len())
422            .finish()
423    }
424}