Skip to main content

solverforge_scoring/constraint/
grouped.rs

1// Zero-erasure grouped constraint for group-by operations.
2//
3// Provides incremental scoring for constraints that group entities and
4// apply collectors to compute aggregate scores.
5// All type information is preserved at compile time - no Arc, no dyn.
6
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::constraint_set::IncrementalConstraint;
15use crate::stream::collector::{Accumulator, UniCollector};
16use crate::stream::filter::UniFilter;
17
18// Zero-erasure constraint that groups entities by key and scores based on collector results.
19//
20// This enables incremental scoring for group-by operations:
21// - Tracks which entities belong to which group
22// - Maintains collector state per group
23// - Computes score deltas when entities are added/removed
24//
25// All type parameters are concrete - no trait objects, no Arc allocations.
26//
27// # Type Parameters
28//
29// - `S` - Solution type
30// - `A` - Entity type
31// - `K` - Group key type
32// - `E` - Extractor function for entities
33// - `Fi` - Filter type (applied before grouping)
34// - `KF` - Key function
35// - `C` - Collector type
36// - `W` - Weight function
37// - `Sc` - Score type
38//
39// # Example
40//
41// ```
42// use solverforge_scoring::constraint::grouped::GroupedUniConstraint;
43// use solverforge_scoring::stream::collector::count;
44// use solverforge_scoring::stream::filter::TrueFilter;
45// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
46// use solverforge_core::{ConstraintRef, ImpactType};
47// use solverforge_core::score::SoftScore;
48//
49// #[derive(Clone, Hash, PartialEq, Eq)]
50// struct Shift { employee_id: usize }
51//
52// #[derive(Clone)]
53// struct Solution { shifts: Vec<Shift> }
54//
55// // Penalize based on squared workload per employee
56// let constraint = GroupedUniConstraint::new(
57//     ConstraintRef::new("", "Balanced workload"),
58//     ImpactType::Penalty,
59//     |s: &Solution| &s.shifts,
60//     TrueFilter,
61//     |shift: &Shift| shift.employee_id,
62//     count::<Shift>(),
63//     |count: &usize| SoftScore::of((*count * *count) as i64),
64//     false,
65// );
66//
67// let solution = Solution {
68//     shifts: vec![
69//         Shift { employee_id: 1 },
70//         Shift { employee_id: 1 },
71//         Shift { employee_id: 1 },
72//         Shift { employee_id: 2 },
73//     ],
74// };
75//
76// // Employee 1: 3 shifts -> 9 penalty
77// // Employee 2: 1 shift -> 1 penalty
78// // Total: -10
79// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-10));
80// ```
// Constraint that groups filtered entities by key and scores each group from
// its collector result, keeping per-group state for incremental updates.
pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
where
    C: UniCollector<A>,
    Sc: Score,
{
    // Identifier of this constraint; its `name` is what `name()` returns.
    constraint_ref: ConstraintRef,
    // Penalty negates the weight function's result, Reward keeps it as-is.
    impact_type: ImpactType,
    // Returns the entity slice of a solution.
    extractor: E,
    // Entities failing this filter are never grouped or scored.
    filter: Fi,
    // Maps an entity to the key of the group it belongs to.
    key_fn: KF,
    // Extracts a value per entity and creates the per-group accumulators.
    collector: C,
    // Converts a group's finished collector result into a (unsigned) score.
    weight_fn: W,
    // Value reported by `is_hard()`.
    is_hard: bool,
    // When set (via `with_descriptor`), `on_insert`/`on_retract` return zero
    // for notifications carrying a different descriptor index.
    expected_descriptor: Option<usize>,
    // Group key -> accumulator (scores computed on-the-fly, no cloning)
    groups: HashMap<K, C::Accumulator>,
    // Group key -> number of entities in the group (for empty-group detection)
    group_counts: HashMap<K, usize>,
    // Entity index -> group key (for tracking which group an entity belongs to)
    entity_groups: HashMap<usize, K>,
    // Entity index -> extracted value (for correct retraction after entity mutation)
    entity_values: HashMap<usize, C::Value>,
    // Marks S, A and Sc as used without storing any value of those types.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
}
105
106impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
107where
108    S: Send + Sync + 'static,
109    A: Clone + Send + Sync + 'static,
110    K: Clone + Eq + Hash + Send + Sync + 'static,
111    E: Fn(&S) -> &[A] + Send + Sync,
112    Fi: UniFilter<S, A>,
113    KF: Fn(&A) -> K + Send + Sync,
114    C: UniCollector<A> + Send + Sync + 'static,
115    C::Accumulator: Send + Sync,
116    C::Result: Send + Sync,
117    W: Fn(&C::Result) -> Sc + Send + Sync,
118    Sc: Score + 'static,
119{
120    // Creates a new zero-erasure grouped constraint.
121    //
122    // # Arguments
123    //
124    // * `constraint_ref` - Identifier for this constraint
125    // * `impact_type` - Whether to penalize or reward
126    // * `extractor` - Function to get entity slice from solution
127    // * `filter` - Filter applied to entities before grouping
128    // * `key_fn` - Function to extract group key from entity
129    // * `collector` - Collector to aggregate entities per group
130    // * `weight_fn` - Function to compute score from collector result
131    // * `is_hard` - Whether this is a hard constraint
132    #[allow(clippy::too_many_arguments)]
133    pub fn new(
134        constraint_ref: ConstraintRef,
135        impact_type: ImpactType,
136        extractor: E,
137        filter: Fi,
138        key_fn: KF,
139        collector: C,
140        weight_fn: W,
141        is_hard: bool,
142    ) -> Self {
143        Self {
144            constraint_ref,
145            impact_type,
146            extractor,
147            filter,
148            key_fn,
149            collector,
150            weight_fn,
151            is_hard,
152            expected_descriptor: None,
153            groups: HashMap::new(),
154            group_counts: HashMap::new(),
155            entity_groups: HashMap::new(),
156            entity_values: HashMap::new(),
157            _phantom: PhantomData,
158        }
159    }
160
161    pub fn with_descriptor(mut self, descriptor_index: usize) -> Self {
162        self.expected_descriptor = Some(descriptor_index);
163        self
164    }
165
166    // Computes the score contribution for a group's result.
167    fn compute_score(&self, result: &C::Result) -> Sc {
168        let base = (self.weight_fn)(result);
169        match self.impact_type {
170            ImpactType::Penalty => -base,
171            ImpactType::Reward => base,
172        }
173    }
174}
175
176impl<S, A, K, E, Fi, KF, C, W, Sc> IncrementalConstraint<S, Sc>
177    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
178where
179    S: Send + Sync + 'static,
180    A: Clone + Send + Sync + 'static,
181    K: Clone + Eq + Hash + Send + Sync + 'static,
182    E: Fn(&S) -> &[A] + Send + Sync,
183    Fi: UniFilter<S, A>,
184    KF: Fn(&A) -> K + Send + Sync,
185    C: UniCollector<A> + Send + Sync + 'static,
186    C::Accumulator: Send + Sync,
187    C::Result: Send + Sync,
188    C::Value: Send + Sync,
189    W: Fn(&C::Result) -> Sc + Send + Sync,
190    Sc: Score + 'static,
191{
192    fn evaluate(&self, solution: &S) -> Sc {
193        let entities = (self.extractor)(solution);
194
195        // Group entities by key, applying filter
196        let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
197
198        for entity in entities {
199            if !self.filter.test(solution, entity) {
200                continue;
201            }
202            let key = (self.key_fn)(entity);
203            let value = self.collector.extract(entity);
204            let acc = groups
205                .entry(key)
206                .or_insert_with(|| self.collector.create_accumulator());
207            acc.accumulate(&value);
208        }
209
210        // Sum scores for all groups
211        let mut total = Sc::zero();
212        for acc in groups.values() {
213            let result = acc.finish();
214            total = total + self.compute_score(&result);
215        }
216
217        total
218    }
219
220    fn match_count(&self, solution: &S) -> usize {
221        let entities = (self.extractor)(solution);
222
223        // Count unique groups (filtered)
224        let mut groups: HashMap<K, ()> = HashMap::new();
225        for entity in entities {
226            if !self.filter.test(solution, entity) {
227                continue;
228            }
229            let key = (self.key_fn)(entity);
230            groups.insert(key, ());
231        }
232
233        groups.len()
234    }
235
236    fn initialize(&mut self, solution: &S) -> Sc {
237        self.reset();
238
239        let entities = (self.extractor)(solution);
240        let mut total = Sc::zero();
241
242        for (idx, entity) in entities.iter().enumerate() {
243            if !self.filter.test(solution, entity) {
244                continue;
245            }
246            total = total + self.insert_entity(entities, idx, entity);
247        }
248
249        total
250    }
251
252    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
253        if let Some(expected) = self.expected_descriptor {
254            if descriptor_index != expected {
255                return Sc::zero();
256            }
257        }
258        let entities = (self.extractor)(solution);
259        if entity_index >= entities.len() {
260            return Sc::zero();
261        }
262
263        let entity = &entities[entity_index];
264        if !self.filter.test(solution, entity) {
265            return Sc::zero();
266        }
267        self.insert_entity(entities, entity_index, entity)
268    }
269
270    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
271        if let Some(expected) = self.expected_descriptor {
272            if descriptor_index != expected {
273                return Sc::zero();
274            }
275        }
276        let entities = (self.extractor)(solution);
277        self.retract_entity(entities, entity_index)
278    }
279
280    fn reset(&mut self) {
281        self.groups.clear();
282        self.group_counts.clear();
283        self.entity_groups.clear();
284        self.entity_values.clear();
285    }
286
287    fn name(&self) -> &str {
288        &self.constraint_ref.name
289    }
290
291    fn is_hard(&self) -> bool {
292        self.is_hard
293    }
294
295    fn constraint_ref(&self) -> ConstraintRef {
296        self.constraint_ref.clone()
297    }
298}
299
300impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
301where
302    S: Send + Sync + 'static,
303    A: Clone + Send + Sync + 'static,
304    K: Clone + Eq + Hash + Send + Sync + 'static,
305    E: Fn(&S) -> &[A] + Send + Sync,
306    Fi: UniFilter<S, A>,
307    KF: Fn(&A) -> K + Send + Sync,
308    C: UniCollector<A> + Send + Sync + 'static,
309    C::Accumulator: Send + Sync,
310    C::Result: Send + Sync,
311    C::Value: Send + Sync,
312    W: Fn(&C::Result) -> Sc + Send + Sync,
313    Sc: Score + 'static,
314{
315    fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
316        let key = (self.key_fn)(entity);
317        let value = self.collector.extract(entity);
318        let impact = self.impact_type;
319
320        // Get or create group accumulator
321        let is_new = !self.groups.contains_key(&key);
322        let acc = self
323            .groups
324            .entry(key.clone())
325            .or_insert_with(|| self.collector.create_accumulator());
326
327        // Old score is zero for new groups (they didn't exist before, contributing nothing)
328        let old = if is_new {
329            Sc::zero()
330        } else {
331            let old_base = (self.weight_fn)(&acc.finish());
332            match impact {
333                ImpactType::Penalty => -old_base,
334                ImpactType::Reward => old_base,
335            }
336        };
337
338        // Accumulate and compute new score
339        acc.accumulate(&value);
340        let new_base = (self.weight_fn)(&acc.finish());
341        let new_score = match impact {
342            ImpactType::Penalty => -new_base,
343            ImpactType::Reward => new_base,
344        };
345
346        // Track entity -> group mapping and cache value for correct retraction
347        self.entity_groups.insert(entity_index, key.clone());
348        self.entity_values.insert(entity_index, value);
349        *self.group_counts.entry(key).or_insert(0) += 1;
350
351        // Return delta (both scores computed fresh, no cloning)
352        new_score - old
353    }
354
355    fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
356        // Find which group this entity belonged to
357        let Some(key) = self.entity_groups.remove(&entity_index) else {
358            return Sc::zero();
359        };
360
361        // Use cached value (entity may have been mutated since insert)
362        let Some(value) = self.entity_values.remove(&entity_index) else {
363            return Sc::zero();
364        };
365        let impact = self.impact_type;
366
367        // Get the group accumulator
368        let Some(acc) = self.groups.get_mut(&key) else {
369            return Sc::zero();
370        };
371
372        // Compute old score from current state (inlined to avoid borrow conflict)
373        let old_base = (self.weight_fn)(&acc.finish());
374        let old = match impact {
375            ImpactType::Penalty => -old_base,
376            ImpactType::Reward => old_base,
377        };
378
379        // Decrement group count; remove group if now empty
380        let is_empty = {
381            let cnt = self.group_counts.entry(key.clone()).or_insert(0);
382            *cnt = cnt.saturating_sub(1);
383            *cnt == 0
384        };
385        if is_empty {
386            self.group_counts.remove(&key);
387        }
388
389        // Retract and compute new score
390        acc.retract(&value);
391        let new_score = if is_empty {
392            // Group is now empty; remove it and treat its contribution as zero
393            self.groups.remove(&key);
394            Sc::zero()
395        } else {
396            let new_base = (self.weight_fn)(&acc.finish());
397            match impact {
398                ImpactType::Penalty => -new_base,
399                ImpactType::Reward => new_base,
400            }
401        };
402
403        // Return delta (both scores computed fresh, no cloning)
404        new_score - old
405    }
406}
407
408impl<S, A, K, E, Fi, KF, C, W, Sc> std::fmt::Debug
409    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
410where
411    C: UniCollector<A>,
412    Sc: Score,
413{
414    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415        f.debug_struct("GroupedUniConstraint")
416            .field("name", &self.constraint_ref.name)
417            .field("impact_type", &self.impact_type)
418            .field("groups", &self.groups.len())
419            .finish()
420    }
421}