Skip to main content

solverforge_scoring/constraint/complemented/
state.rs

/* Zero-erasure complemented group constraint.

Evaluates grouped results plus complement entities with default values.
Provides true incremental scoring by tracking per-key accumulators.
*/
6
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::stream::collection_extract::ChangeSource;
15use crate::stream::collector::{Accumulator, UniCollector};
16
17type CollectorRetraction<C, A> = <<C as UniCollector<A>>::Accumulator as Accumulator<
18    <C as UniCollector<A>>::Value,
19    <C as UniCollector<A>>::Result,
20>>::Retraction;
21
22/* Zero-erasure constraint for complemented grouped results.
23
24Groups A entities by key, then iterates over B entities (complement source),
25using grouped values where they exist and default values otherwise.
26
27The key function for A returns `Option<K>`, allowing entities to be skipped
28when they don't have a valid key (e.g., unassigned shifts).
29
30# Type Parameters
31
32- `S` - Solution type
33- `A` - Entity type being grouped (e.g., Shift)
34- `B` - Complement entity type (e.g., Employee)
35- `K` - Group key type
36- `EA` - Extractor for A entities
37- `EB` - Extractor for B entities
38- `KA` - Key function for A (returns `Option<K>` to allow skipping)
39- `KB` - Key function for B
40- `C` - Collector type
41- `D` - Default value function
42- `W` - Weight function
43- `Sc` - Score type
44
45# Example
46
47```
48use solverforge_scoring::constraint::complemented::ComplementedGroupConstraint;
49use solverforge_scoring::stream::collector::count;
50use solverforge_scoring::api::constraint_set::IncrementalConstraint;
51use solverforge_core::{ConstraintRef, ImpactType};
52use solverforge_core::score::SoftScore;
53
54#[derive(Clone, Hash, PartialEq, Eq)]
55struct Employee { id: usize }
56
57#[derive(Clone)]
58struct Shift { employee_id: Option<usize> }
59
60#[derive(Clone)]
61struct Schedule {
62employees: Vec<Employee>,
63shifts: Vec<Shift>,
64}
65
66let constraint = ComplementedGroupConstraint::new(
67ConstraintRef::new("", "Shift count"),
68ImpactType::Penalty,
69|s: &Schedule| s.shifts.as_slice(),
70|s: &Schedule| s.employees.as_slice(),
71|shift: &Shift| shift.employee_id,  // Returns Option<usize>
72|emp: &Employee| emp.id,
73count(),
74|_emp: &Employee| 0usize,
75|_employee_id: &usize, count: &usize| SoftScore::of(*count as i64),
76false,
77);
78
79let schedule = Schedule {
80employees: vec![Employee { id: 0 }, Employee { id: 1 }],
81shifts: vec![
82Shift { employee_id: Some(0) },
83Shift { employee_id: Some(0) },
84Shift { employee_id: None },  // Skipped - no key
85],
86};
87
88// Employee 0: 2 shifts, Employee 1: 0 shifts → Total: -2
89// Unassigned shift is skipped
90assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
91```
92*/
93pub struct ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
94where
95    C: UniCollector<A>,
96    Sc: Score,
97{
98    pub(super) constraint_ref: ConstraintRef,
99    pub(super) impact_type: ImpactType,
100    pub(super) extractor_a: EA,
101    pub(super) extractor_b: EB,
102    pub(super) key_a: KA,
103    pub(super) key_b: KB,
104    pub(super) collector: C,
105    pub(super) default_fn: D,
106    pub(super) weight_fn: W,
107    pub(super) is_hard: bool,
108    pub(super) a_source: ChangeSource,
109    pub(super) b_source: ChangeSource,
110    // Group key -> accumulator for incremental scoring
111    pub(super) groups: HashMap<K, C::Accumulator>,
112    // A entity index -> group key (for tracking which group each entity belongs to)
113    pub(super) entity_groups: HashMap<usize, K>,
114    // A entity index -> accumulator retraction token
115    pub(super) entity_retractions: HashMap<usize, CollectorRetraction<C, A>>,
116    // B key -> B entity index (for looking up B entities by key)
117    pub(super) b_by_key: HashMap<K, usize>,
118    // B entity index -> B key (for localized B retraction)
119    pub(super) b_index_to_key: HashMap<usize, K>,
120    pub(super) _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> Sc)>,
121}
122
123impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
124    ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
125where
126    S: 'static,
127    A: Clone + 'static,
128    B: Clone + 'static,
129    K: Clone + Eq + Hash,
130    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
131    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
132    KA: Fn(&A) -> Option<K>,
133    KB: Fn(&B) -> K,
134    C: UniCollector<A>,
135    D: Fn(&B) -> C::Result,
136    W: Fn(&K, &C::Result) -> Sc,
137    Sc: Score,
138{
139    // Creates a new complemented group constraint.
140    #[allow(clippy::too_many_arguments)]
141    pub fn new(
142        constraint_ref: ConstraintRef,
143        impact_type: ImpactType,
144        extractor_a: EA,
145        extractor_b: EB,
146        key_a: KA,
147        key_b: KB,
148        collector: C,
149        default_fn: D,
150        weight_fn: W,
151        is_hard: bool,
152    ) -> Self {
153        let a_source = extractor_a.change_source();
154        let b_source = extractor_b.change_source();
155        Self {
156            constraint_ref,
157            impact_type,
158            extractor_a,
159            extractor_b,
160            key_a,
161            key_b,
162            collector,
163            default_fn,
164            weight_fn,
165            is_hard,
166            a_source,
167            b_source,
168            groups: HashMap::new(),
169            entity_groups: HashMap::new(),
170            entity_retractions: HashMap::new(),
171            b_by_key: HashMap::new(),
172            b_index_to_key: HashMap::new(),
173            _phantom: PhantomData,
174        }
175    }
176
177    #[inline]
178    pub(super) fn compute_score(&self, key: &K, result: &C::Result) -> Sc {
179        let base = (self.weight_fn)(key, result);
180        match self.impact_type {
181            ImpactType::Penalty => -base,
182            ImpactType::Reward => base,
183        }
184    }
185
186    // Build grouped results from A entities.
187    pub(super) fn build_groups(&self, entities_a: &[A]) -> HashMap<K, C::Accumulator> {
188        let mut accumulators: HashMap<K, C::Accumulator> = HashMap::new();
189
190        for a in entities_a {
191            // Skip entities with no key (e.g., unassigned shifts)
192            let Some(key) = (self.key_a)(a) else {
193                continue;
194            };
195            let value = self.collector.extract(a);
196            accumulators
197                .entry(key)
198                .or_insert_with(|| self.collector.create_accumulator())
199                .accumulate(value);
200        }
201
202        accumulators
203    }
204}