Skip to main content

solverforge_scoring/constraint/complemented/
state.rs

/* Zero-erasure complemented group constraint.

Evaluates grouped results together with complement entities, substituting a
default value for every complement entity whose key has no group.
Provides true incremental scoring by tracking per-key accumulators.
*/
6
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::stream::collection_extract::ChangeSource;
15use crate::stream::collector::{Accumulator, Collector};
16
// Shorthand for the accumulator's associated `Retraction` type; used as the
// per-A-entity token type stored in `entity_retractions`.
type CollectorRetraction<Acc, V, R> = <Acc as Accumulator<V, R>>::Retraction;
18
// Incremental state kept for a single group key (stored in
// `ComplementedGroupConstraint::groups`).
pub(super) struct GroupState<Acc> {
    // Collector accumulator holding the values gathered for this group.
    pub(super) accumulator: Acc,
    // NOTE(review): presumably the number of A entities currently contributing
    // to this group; it is maintained outside this file — confirm.
    pub(super) count: usize,
}
23
/* Zero-erasure constraint for complemented grouped results.

Groups A entities by key, then iterates over B entities (the complement source),
using the grouped value where one exists and the default value otherwise.

The key function for A returns `Option<K>`, allowing entities to be skipped
when they don't have a valid key (e.g., unassigned shifts).

# Type Parameters

- `S` - Solution type
- `A` - Entity type being grouped (e.g., Shift)
- `B` - Complement entity type (e.g., Employee)
- `K` - Group key type
- `EA` - Extractor for A entities
- `EB` - Extractor for B entities
- `KA` - Key function for A (returns `Option<K>` to allow skipping)
- `KB` - Key function for B
- `C` - Collector type
- `V` - Value the collector extracts from each A entity
- `R` - Result produced by the accumulator (also the type returned by the default function)
- `Acc` - Accumulator type created by the collector
- `D` - Default value function
- `W` - Weight function
- `Sc` - Score type

# Example

```
use solverforge_scoring::constraint::complemented::ComplementedGroupConstraint;
use solverforge_scoring::stream::collector::count;
use solverforge_scoring::api::constraint_set::IncrementalConstraint;
use solverforge_core::{ConstraintRef, ImpactType};
use solverforge_core::score::SoftScore;

#[derive(Clone, Hash, PartialEq, Eq)]
struct Employee { id: usize }

#[derive(Clone)]
struct Shift { employee_id: Option<usize> }

#[derive(Clone)]
struct Schedule {
    employees: Vec<Employee>,
    shifts: Vec<Shift>,
}

let constraint = ComplementedGroupConstraint::new(
    ConstraintRef::new("", "Shift count"),
    ImpactType::Penalty,
    |s: &Schedule| s.shifts.as_slice(),
    |s: &Schedule| s.employees.as_slice(),
    |shift: &Shift| shift.employee_id,  // Returns Option<usize>
    |emp: &Employee| emp.id,
    count(),
    |_emp: &Employee| 0usize,
    |_employee_id: &usize, count: &usize| SoftScore::of(*count as i64),
    false,
);

let schedule = Schedule {
    employees: vec![Employee { id: 0 }, Employee { id: 1 }],
    shifts: vec![
        Shift { employee_id: Some(0) },
        Shift { employee_id: Some(0) },
        Shift { employee_id: None },  // Skipped - no key
    ],
};

// Employee 0: 2 shifts, Employee 1: 0 shifts → Total: -2
// Unassigned shift is skipped
assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
```
*/
// The `Acc: Accumulator<V, R>` bound is required on the struct itself because
// `entity_retractions` names the projected `Retraction` type.
pub struct ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc>
where
    Acc: Accumulator<V, R>,
    Sc: Score,
{
    // Constraint reference supplied to `new`.
    pub(super) constraint_ref: ConstraintRef,
    // Penalty negates the weight, Reward keeps it (applied in `compute_score`).
    pub(super) impact_type: ImpactType,
    // Extracts the A entities (the grouped side) from the solution.
    pub(super) extractor_a: EA,
    // Extracts the B entities (the complement side) from the solution.
    pub(super) extractor_b: EB,
    // Maps an A entity to its group key; `None` means the entity is skipped.
    pub(super) key_a: KA,
    // Maps a B entity to the key whose group it reads.
    pub(super) key_b: KB,
    // Creates accumulators and extracts the per-A value fed into them.
    pub(super) collector: C,
    // Produces the result used for a B entity whose key has no group.
    pub(super) default_fn: D,
    // Turns a key and its result into a score before the impact sign is applied.
    pub(super) weight_fn: W,
    // NOTE(review): presumably marks this as a hard constraint; not read in this file.
    pub(super) is_hard: bool,
    // Change source reported by `extractor_a` when the constraint was built.
    pub(super) a_source: ChangeSource,
    // Change source reported by `extractor_b` when the constraint was built.
    pub(super) b_source: ChangeSource,
    // Group key -> accumulator for incremental scoring
    pub(super) groups: HashMap<K, GroupState<Acc>>,
    // A entity index -> group key (for tracking which group each entity belongs to)
    pub(super) entity_groups: HashMap<usize, K>,
    // A entity index -> accumulator retraction token
    pub(super) entity_retractions: HashMap<usize, CollectorRetraction<Acc, V, R>>,
    // B key -> B entity indexes (for looking up all complement rows by key)
    pub(super) b_by_key: HashMap<K, Vec<usize>>,
    // B entity index -> B key (for localized B retraction)
    pub(super) b_index_to_key: HashMap<usize, K>,
    // Mentions type parameters that are not stored in any field; the
    // `fn() -> T` form avoids declaring ownership of values of those types.
    pub(super) _phantom: PhantomData<(
        fn() -> S,
        fn() -> A,
        fn() -> B,
        fn() -> V,
        fn() -> R,
        fn() -> Acc,
        fn() -> Sc,
    )>,
}
132
133impl<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc>
134    ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc>
135where
136    S: 'static,
137    A: Clone + 'static,
138    B: Clone + 'static,
139    K: Clone + Eq + Hash,
140    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
141    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
142    KA: Fn(&A) -> Option<K>,
143    KB: Fn(&B) -> K,
144    C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc>,
145    Acc: Accumulator<V, R>,
146    D: Fn(&B) -> R,
147    W: Fn(&K, &R) -> Sc,
148    Sc: Score,
149{
150    // Creates a new complemented group constraint.
151    #[allow(clippy::too_many_arguments)]
152    pub fn new(
153        constraint_ref: ConstraintRef,
154        impact_type: ImpactType,
155        extractor_a: EA,
156        extractor_b: EB,
157        key_a: KA,
158        key_b: KB,
159        collector: C,
160        default_fn: D,
161        weight_fn: W,
162        is_hard: bool,
163    ) -> Self {
164        let a_source = extractor_a.change_source();
165        let b_source = extractor_b.change_source();
166        Self {
167            constraint_ref,
168            impact_type,
169            extractor_a,
170            extractor_b,
171            key_a,
172            key_b,
173            collector,
174            default_fn,
175            weight_fn,
176            is_hard,
177            a_source,
178            b_source,
179            groups: HashMap::new(),
180            entity_groups: HashMap::new(),
181            entity_retractions: HashMap::new(),
182            b_by_key: HashMap::new(),
183            b_index_to_key: HashMap::new(),
184            _phantom: PhantomData,
185        }
186    }
187
188    #[inline]
189    pub(super) fn compute_score(&self, key: &K, result: &R) -> Sc {
190        let base = (self.weight_fn)(key, result);
191        match self.impact_type {
192            ImpactType::Penalty => -base,
193            ImpactType::Reward => base,
194        }
195    }
196
197    // Build grouped results from A entities.
198    pub(super) fn build_groups(&self, entities_a: &[A]) -> HashMap<K, Acc> {
199        let mut accumulators: HashMap<K, Acc> = HashMap::new();
200
201        for a in entities_a {
202            // Skip entities with no key (e.g., unassigned shifts)
203            let Some(key) = (self.key_a)(a) else {
204                continue;
205            };
206            let value = self.collector.extract(a);
207            accumulators
208                .entry(key)
209                .or_insert_with(|| self.collector.create_accumulator())
210                .accumulate(value);
211        }
212
213        accumulators
214    }
215}