Skip to main content

solverforge_scoring/constraint/
grouped.rs

// Zero-erasure grouped constraint for group-by operations.
//
// Provides incremental scoring for constraints that group entities and
// apply collectors to compute aggregate scores.
// All type information is preserved at compile time: no `Arc`, no `dyn`.
6
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::constraint_set::IncrementalConstraint;
15use crate::stream::collector::{Accumulator, UniCollector};
16
17// Zero-erasure constraint that groups entities by key and scores based on collector results.
18//
19// This enables incremental scoring for group-by operations:
20// - Tracks which entities belong to which group
21// - Maintains collector state per group
22// - Computes score deltas when entities are added/removed
23//
24// All type parameters are concrete - no trait objects, no Arc allocations.
25//
26// # Type Parameters
27//
28// - `S` - Solution type
29// - `A` - Entity type
30// - `K` - Group key type
31// - `E` - Extractor function for entities
32// - `KF` - Key function
33// - `C` - Collector type
34// - `W` - Weight function
35// - `Sc` - Score type
36//
37// # Example
38//
39// ```
40// use solverforge_scoring::constraint::grouped::GroupedUniConstraint;
41// use solverforge_scoring::stream::collector::count;
42// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
43// use solverforge_core::{ConstraintRef, ImpactType};
44// use solverforge_core::score::SimpleScore;
45//
46// #[derive(Clone, Hash, PartialEq, Eq)]
47// struct Shift { employee_id: usize }
48//
49// #[derive(Clone)]
50// struct Solution { shifts: Vec<Shift> }
51//
52// // Penalize based on squared workload per employee
53// let constraint = GroupedUniConstraint::new(
54//     ConstraintRef::new("", "Balanced workload"),
55//     ImpactType::Penalty,
56//     |s: &Solution| &s.shifts,
57//     |shift: &Shift| shift.employee_id,
58//     count::<Shift>(),
59//     |count: &usize| SimpleScore::of((*count * *count) as i64),
60//     false,
61// );
62//
63// let solution = Solution {
64//     shifts: vec![
65//         Shift { employee_id: 1 },
66//         Shift { employee_id: 1 },
67//         Shift { employee_id: 1 },
68//         Shift { employee_id: 2 },
69//     ],
70// };
71//
72// // Employee 1: 3 shifts -> 9 penalty
73// // Employee 2: 1 shift -> 1 penalty
74// // Total: -10
75// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-10));
76// ```
77pub struct GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
78where
79    C: UniCollector<A>,
80    Sc: Score,
81{
82    constraint_ref: ConstraintRef,
83    impact_type: ImpactType,
84    extractor: E,
85    key_fn: KF,
86    collector: C,
87    weight_fn: W,
88    is_hard: bool,
89    expected_descriptor: Option<usize>,
90    // Group key -> accumulator (scores computed on-the-fly, no cloning)
91    groups: HashMap<K, C::Accumulator>,
92    // Entity index -> group key (for tracking which group an entity belongs to)
93    entity_groups: HashMap<usize, K>,
94    // Entity index -> extracted value (for correct retraction after entity mutation)
95    entity_values: HashMap<usize, C::Value>,
96    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
97}
98
99impl<S, A, K, E, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
100where
101    S: Send + Sync + 'static,
102    A: Clone + Send + Sync + 'static,
103    K: Clone + Eq + Hash + Send + Sync + 'static,
104    E: Fn(&S) -> &[A] + Send + Sync,
105    KF: Fn(&A) -> K + Send + Sync,
106    C: UniCollector<A> + Send + Sync + 'static,
107    C::Accumulator: Send + Sync,
108    C::Result: Send + Sync,
109    W: Fn(&C::Result) -> Sc + Send + Sync,
110    Sc: Score + 'static,
111{
112    // Creates a new zero-erasure grouped constraint.
113    //
114    // # Arguments
115    //
116    // * `constraint_ref` - Identifier for this constraint
117    // * `impact_type` - Whether to penalize or reward
118    // * `extractor` - Function to get entity slice from solution
119    // * `key_fn` - Function to extract group key from entity
120    // * `collector` - Collector to aggregate entities per group
121    // * `weight_fn` - Function to compute score from collector result
122    // * `is_hard` - Whether this is a hard constraint
123    pub fn new(
124        constraint_ref: ConstraintRef,
125        impact_type: ImpactType,
126        extractor: E,
127        key_fn: KF,
128        collector: C,
129        weight_fn: W,
130        is_hard: bool,
131    ) -> Self {
132        Self {
133            constraint_ref,
134            impact_type,
135            extractor,
136            key_fn,
137            collector,
138            weight_fn,
139            is_hard,
140            expected_descriptor: None,
141            groups: HashMap::new(),
142            entity_groups: HashMap::new(),
143            entity_values: HashMap::new(),
144            _phantom: PhantomData,
145        }
146    }
147
148    pub fn with_descriptor(mut self, descriptor_index: usize) -> Self {
149        self.expected_descriptor = Some(descriptor_index);
150        self
151    }
152
153    // Computes the score contribution for a group's result.
154    fn compute_score(&self, result: &C::Result) -> Sc {
155        let base = (self.weight_fn)(result);
156        match self.impact_type {
157            ImpactType::Penalty => -base,
158            ImpactType::Reward => base,
159        }
160    }
161}
162
163impl<S, A, K, E, KF, C, W, Sc> IncrementalConstraint<S, Sc>
164    for GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
165where
166    S: Send + Sync + 'static,
167    A: Clone + Send + Sync + 'static,
168    K: Clone + Eq + Hash + Send + Sync + 'static,
169    E: Fn(&S) -> &[A] + Send + Sync,
170    KF: Fn(&A) -> K + Send + Sync,
171    C: UniCollector<A> + Send + Sync + 'static,
172    C::Accumulator: Send + Sync,
173    C::Result: Send + Sync,
174    C::Value: Send + Sync,
175    W: Fn(&C::Result) -> Sc + Send + Sync,
176    Sc: Score + 'static,
177{
178    fn evaluate(&self, solution: &S) -> Sc {
179        let entities = (self.extractor)(solution);
180
181        // Group entities by key
182        let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
183
184        for entity in entities {
185            let key = (self.key_fn)(entity);
186            let value = self.collector.extract(entity);
187            let acc = groups
188                .entry(key)
189                .or_insert_with(|| self.collector.create_accumulator());
190            acc.accumulate(&value);
191        }
192
193        // Sum scores for all groups
194        let mut total = Sc::zero();
195        for acc in groups.values() {
196            let result = acc.finish();
197            total = total + self.compute_score(&result);
198        }
199
200        total
201    }
202
203    fn match_count(&self, solution: &S) -> usize {
204        let entities = (self.extractor)(solution);
205
206        // Count unique groups
207        let mut groups: HashMap<K, ()> = HashMap::new();
208        for entity in entities {
209            let key = (self.key_fn)(entity);
210            groups.insert(key, ());
211        }
212
213        groups.len()
214    }
215
216    fn initialize(&mut self, solution: &S) -> Sc {
217        self.reset();
218
219        let entities = (self.extractor)(solution);
220        let mut total = Sc::zero();
221
222        for (idx, entity) in entities.iter().enumerate() {
223            total = total + self.insert_entity(entities, idx, entity);
224        }
225
226        total
227    }
228
229    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
230        if let Some(expected) = self.expected_descriptor {
231            if descriptor_index != expected {
232                return Sc::zero();
233            }
234        }
235        let entities = (self.extractor)(solution);
236        if entity_index >= entities.len() {
237            return Sc::zero();
238        }
239
240        let entity = &entities[entity_index];
241        self.insert_entity(entities, entity_index, entity)
242    }
243
244    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
245        if let Some(expected) = self.expected_descriptor {
246            if descriptor_index != expected {
247                return Sc::zero();
248            }
249        }
250        let entities = (self.extractor)(solution);
251        self.retract_entity(entities, entity_index)
252    }
253
254    fn reset(&mut self) {
255        self.groups.clear();
256        self.entity_groups.clear();
257        self.entity_values.clear();
258    }
259
260    fn name(&self) -> &str {
261        &self.constraint_ref.name
262    }
263
264    fn is_hard(&self) -> bool {
265        self.is_hard
266    }
267
268    fn constraint_ref(&self) -> ConstraintRef {
269        self.constraint_ref.clone()
270    }
271}
272
273impl<S, A, K, E, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
274where
275    S: Send + Sync + 'static,
276    A: Clone + Send + Sync + 'static,
277    K: Clone + Eq + Hash + Send + Sync + 'static,
278    E: Fn(&S) -> &[A] + Send + Sync,
279    KF: Fn(&A) -> K + Send + Sync,
280    C: UniCollector<A> + Send + Sync + 'static,
281    C::Accumulator: Send + Sync,
282    C::Result: Send + Sync,
283    C::Value: Send + Sync,
284    W: Fn(&C::Result) -> Sc + Send + Sync,
285    Sc: Score + 'static,
286{
287    fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
288        let key = (self.key_fn)(entity);
289        let value = self.collector.extract(entity);
290        let impact = self.impact_type;
291
292        // Get or create group accumulator
293        let acc = self
294            .groups
295            .entry(key.clone())
296            .or_insert_with(|| self.collector.create_accumulator());
297
298        // Compute old score from current state (inlined to avoid borrow conflict)
299        let old_base = (self.weight_fn)(&acc.finish());
300        let old = match impact {
301            ImpactType::Penalty => -old_base,
302            ImpactType::Reward => old_base,
303        };
304
305        // Accumulate and compute new score
306        acc.accumulate(&value);
307        let new_base = (self.weight_fn)(&acc.finish());
308        let new_score = match impact {
309            ImpactType::Penalty => -new_base,
310            ImpactType::Reward => new_base,
311        };
312
313        // Track entity -> group mapping and cache value for correct retraction
314        self.entity_groups.insert(entity_index, key);
315        self.entity_values.insert(entity_index, value);
316
317        // Return delta (both scores computed fresh, no cloning)
318        new_score - old
319    }
320
321    fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
322        // Find which group this entity belonged to
323        let Some(key) = self.entity_groups.remove(&entity_index) else {
324            return Sc::zero();
325        };
326
327        // Use cached value (entity may have been mutated since insert)
328        let Some(value) = self.entity_values.remove(&entity_index) else {
329            return Sc::zero();
330        };
331        let impact = self.impact_type;
332
333        // Get the group accumulator
334        let Some(acc) = self.groups.get_mut(&key) else {
335            return Sc::zero();
336        };
337
338        // Compute old score from current state (inlined to avoid borrow conflict)
339        let old_base = (self.weight_fn)(&acc.finish());
340        let old = match impact {
341            ImpactType::Penalty => -old_base,
342            ImpactType::Reward => old_base,
343        };
344
345        // Retract and compute new score
346        acc.retract(&value);
347        let new_base = (self.weight_fn)(&acc.finish());
348        let new_score = match impact {
349            ImpactType::Penalty => -new_base,
350            ImpactType::Reward => new_base,
351        };
352
353        // Return delta (both scores computed fresh, no cloning)
354        new_score - old
355    }
356}
357
358impl<S, A, K, E, KF, C, W, Sc> std::fmt::Debug for GroupedUniConstraint<S, A, K, E, KF, C, W, Sc>
359where
360    C: UniCollector<A>,
361    Sc: Score,
362{
363    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364        f.debug_struct("GroupedUniConstraint")
365            .field("name", &self.constraint_ref.name)
366            .field("impact_type", &self.impact_type)
367            .field("groups", &self.groups.len())
368            .finish()
369    }
370}