Skip to main content

solverforge_scoring/constraint/
flattened_bi.rs

1/* O(1) flattened bi-constraint for cross-entity joins.
2
3Pre-indexes C items by key for O(1) lookup on entity changes.
4*/
5
6use std::collections::HashMap;
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::api::constraint_set::IncrementalConstraint;
14/* O(1) flattened bi-constraint.
15
16Given a join between A and B entities by key, this constraint:
171. Expands each B into C items via a flatten function
182. Pre-indexes C items by (join_key, c_key) for O(1) lookup
193. On A entity change, looks up matching C items in O(1) instead of O(|C|)
20
21# Type Parameters
22
23- `S` - Solution type
24- `A` - Entity type A (the planning entity, e.g., Shift)
25- `B` - Entity type B (the joined entity, e.g., Employee)
26- `C` - Flattened item type (e.g., NaiveDate from unavailable dates)
27- `K` - Join key type (e.g., Option<usize> for employee_idx)
28- `CK` - C item key type for indexing (e.g., NaiveDate)
29- `EA` - Extractor for A entities
30- `EB` - Extractor for B entities
31- `KA` - Key extractor for A (join key)
32- `KB` - Key extractor for B (join key)
33- `Flatten` - Function extracting `&[C]` from `&B`
34- `CKeyFn` - Function extracting index key from &C
35- `ALookup` - Function extracting lookup key from &A
- `F` - Filter on (S, A, C): receives the solution, the A entity, and a candidate C item
37- `W` - Weight function on (A, C) pairs
38- `Sc` - Score type
39
40# Example
41
42```
43use solverforge_scoring::constraint::flattened_bi::FlattenedBiConstraint;
44use solverforge_scoring::api::constraint_set::IncrementalConstraint;
45use solverforge_core::{ConstraintRef, ImpactType};
46use solverforge_core::score::SoftScore;
47
48#[derive(Clone)]
49struct Employee {
50id: usize,
51unavailable_days: Vec<u32>,
52}
53
54#[derive(Clone)]
55struct Shift {
56employee_id: Option<usize>,
57day: u32,
58}
59
60#[derive(Clone)]
61struct Schedule {
62shifts: Vec<Shift>,
63employees: Vec<Employee>,
64}
65
66let constraint = FlattenedBiConstraint::new(
67ConstraintRef::new("", "Unavailable employee"),
68ImpactType::Penalty,
69|s: &Schedule| s.shifts.as_slice(),
70|s: &Schedule| s.employees.as_slice(),
71|shift: &Shift| shift.employee_id,
72|emp: &Employee| Some(emp.id),
73|emp: &Employee| emp.unavailable_days.as_slice(),
74|day: &u32| *day,           // C → index key
75|shift: &Shift| shift.day,  // A → lookup key
76|_s: &Schedule, shift: &Shift, day: &u32| shift.day == *day,
77|_shift: &Shift, _day: &u32| SoftScore::of(1),
78false,
79);
80
81let schedule = Schedule {
82shifts: vec![
83Shift { employee_id: Some(0), day: 5 },
84Shift { employee_id: Some(0), day: 10 },
85],
86employees: vec![
87Employee { id: 0, unavailable_days: vec![5, 15] },
88],
89};
90
91// Day 5 shift conflicts with employee's unavailable day 5 → O(1) lookup!
92assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-1));
93```
94*/
95pub struct FlattenedBiConstraint<
96    S,
97    A,
98    B,
99    C,
100    K,
101    CK,
102    EA,
103    EB,
104    KA,
105    KB,
106    Flatten,
107    CKeyFn,
108    ALookup,
109    F,
110    W,
111    Sc,
112> where
113    Sc: Score,
114{
115    constraint_ref: ConstraintRef,
116    impact_type: ImpactType,
117    extractor_a: EA,
118    extractor_b: EB,
119    key_a: KA,
120    key_b: KB,
121    flatten: Flatten,
122    c_key_fn: CKeyFn,
123    a_lookup_fn: ALookup,
124    filter: F,
125    weight: W,
126    is_hard: bool,
127    // (join_key, c_key) → list of (b_idx, c_value) for O(1) lookup
128    c_index: HashMap<(K, CK), Vec<(usize, C)>>,
129    // A index → cached score for this entity's matches
130    a_scores: HashMap<usize, Sc>,
131    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B)>,
132}
133
134impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
135    FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
136where
137    S: 'static,
138    A: Clone + 'static,
139    B: Clone + 'static,
140    C: Clone + 'static,
141    K: Eq + Hash + Clone,
142    CK: Eq + Hash + Clone,
143    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
144    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
145    KA: Fn(&A) -> K,
146    KB: Fn(&B) -> K,
147    Flatten: Fn(&B) -> &[C],
148    CKeyFn: Fn(&C) -> CK,
149    ALookup: Fn(&A) -> CK,
150    F: Fn(&S, &A, &C) -> bool,
151    W: Fn(&A, &C) -> Sc,
152    Sc: Score,
153{
154    // Creates a new O(1) flattened bi-constraint.
155    #[allow(clippy::too_many_arguments)]
156    pub fn new(
157        constraint_ref: ConstraintRef,
158        impact_type: ImpactType,
159        extractor_a: EA,
160        extractor_b: EB,
161        key_a: KA,
162        key_b: KB,
163        flatten: Flatten,
164        c_key_fn: CKeyFn,
165        a_lookup_fn: ALookup,
166        filter: F,
167        weight: W,
168        is_hard: bool,
169    ) -> Self {
170        Self {
171            constraint_ref,
172            impact_type,
173            extractor_a,
174            extractor_b,
175            key_a,
176            key_b,
177            flatten,
178            c_key_fn,
179            a_lookup_fn,
180            filter,
181            weight,
182            is_hard,
183            c_index: HashMap::new(),
184            a_scores: HashMap::new(),
185            _phantom: PhantomData,
186        }
187    }
188
189    #[inline]
190    fn compute_score(&self, a: &A, c: &C) -> Sc {
191        let base = (self.weight)(a, c);
192        match self.impact_type {
193            ImpactType::Penalty => -base,
194            ImpactType::Reward => base,
195        }
196    }
197
198    // Build C index: (join_key, c_key) → list of (b_idx, c_value)
199    fn build_c_index(&mut self, entities_b: &[B]) {
200        self.c_index.clear();
201        for (b_idx, b) in entities_b.iter().enumerate() {
202            let join_key = (self.key_b)(b);
203            for c in (self.flatten)(b) {
204                let c_key = (self.c_key_fn)(c);
205                self.c_index
206                    .entry((join_key.clone(), c_key))
207                    .or_default()
208                    .push((b_idx, c.clone()));
209            }
210        }
211    }
212
213    // Compute score for entity A using O(1) index lookup.
214    fn compute_a_score(&self, solution: &S, a: &A) -> Sc {
215        let join_key = (self.key_a)(a);
216        let lookup_key = (self.a_lookup_fn)(a);
217
218        // O(1) HashMap lookup instead of O(|C|) iteration!
219        let matches = match self.c_index.get(&(join_key, lookup_key)) {
220            Some(v) => v.as_slice(),
221            None => return Sc::zero(),
222        };
223
224        let mut total = Sc::zero();
225        for (_b_idx, c) in matches {
226            if (self.filter)(solution, a, c) {
227                total = total + self.compute_score(a, c);
228            }
229        }
230        total
231    }
232
233    fn insert_a(&mut self, solution: &S, entities_a: &[A], a_idx: usize) -> Sc {
234        if a_idx >= entities_a.len() {
235            return Sc::zero();
236        }
237
238        let a = &entities_a[a_idx];
239        let score = self.compute_a_score(solution, a);
240
241        if score != Sc::zero() {
242            self.a_scores.insert(a_idx, score);
243        }
244        score
245    }
246
247    fn retract_a(&mut self, a_idx: usize) -> Sc {
248        match self.a_scores.remove(&a_idx) {
249            Some(score) => -score,
250            None => Sc::zero(),
251        }
252    }
253}
254
255impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
256    IncrementalConstraint<S, Sc>
257    for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
258where
259    S: Send + Sync + 'static,
260    A: Clone + Send + Sync + 'static,
261    B: Clone + Send + Sync + 'static,
262    C: Clone + Send + Sync + 'static,
263    K: Eq + Hash + Clone + Send + Sync,
264    CK: Eq + Hash + Clone + Send + Sync,
265    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
266    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
267    KA: Fn(&A) -> K + Send + Sync,
268    KB: Fn(&B) -> K + Send + Sync,
269    Flatten: Fn(&B) -> &[C] + Send + Sync,
270    CKeyFn: Fn(&C) -> CK + Send + Sync,
271    ALookup: Fn(&A) -> CK + Send + Sync,
272    F: Fn(&S, &A, &C) -> bool + Send + Sync,
273    W: Fn(&A, &C) -> Sc + Send + Sync,
274    Sc: Score,
275{
276    fn evaluate(&self, solution: &S) -> Sc {
277        let entities_a = self.extractor_a.extract(solution);
278        let entities_b = self.extractor_b.extract(solution);
279        let mut total = Sc::zero();
280
281        // Build temporary index for standalone evaluation
282        let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
283        for (b_idx, b) in entities_b.iter().enumerate() {
284            let join_key = (self.key_b)(b);
285            for c in (self.flatten)(b) {
286                let c_key = (self.c_key_fn)(c);
287                temp_index
288                    .entry((join_key.clone(), c_key))
289                    .or_default()
290                    .push((b_idx, c.clone()));
291            }
292        }
293
294        for a in entities_a {
295            let join_key = (self.key_a)(a);
296            let lookup_key = (self.a_lookup_fn)(a);
297
298            if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
299                for (_b_idx, c) in matches {
300                    if (self.filter)(solution, a, c) {
301                        total = total + self.compute_score(a, c);
302                    }
303                }
304            }
305        }
306
307        total
308    }
309
310    fn match_count(&self, solution: &S) -> usize {
311        let entities_a = self.extractor_a.extract(solution);
312        let entities_b = self.extractor_b.extract(solution);
313        let mut count = 0;
314
315        // Build temporary index
316        let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
317        for (b_idx, b) in entities_b.iter().enumerate() {
318            let join_key = (self.key_b)(b);
319            for c in (self.flatten)(b) {
320                let c_key = (self.c_key_fn)(c);
321                temp_index
322                    .entry((join_key.clone(), c_key))
323                    .or_default()
324                    .push((b_idx, c.clone()));
325            }
326        }
327
328        for a in entities_a {
329            let join_key = (self.key_a)(a);
330            let lookup_key = (self.a_lookup_fn)(a);
331
332            if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
333                for (_b_idx, c) in matches {
334                    if (self.filter)(solution, a, c) {
335                        count += 1;
336                    }
337                }
338            }
339        }
340
341        count
342    }
343
344    fn initialize(&mut self, solution: &S) -> Sc {
345        self.reset();
346
347        let entities_a = self.extractor_a.extract(solution);
348        let entities_b = self.extractor_b.extract(solution);
349
350        // Build C index once: O(B × C)
351        self.build_c_index(entities_b);
352
353        // Insert all A entities: O(A) with O(1) lookups each
354        let mut total = Sc::zero();
355        for a_idx in 0..entities_a.len() {
356            total = total + self.insert_a(solution, entities_a, a_idx);
357        }
358
359        total
360    }
361
362    fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
363        let entities_a = self.extractor_a.extract(solution);
364        self.insert_a(solution, entities_a, entity_index)
365    }
366
367    fn on_retract(&mut self, _solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
368        self.retract_a(entity_index)
369    }
370
371    fn reset(&mut self) {
372        self.c_index.clear();
373        self.a_scores.clear();
374    }
375
376    fn name(&self) -> &str {
377        &self.constraint_ref.name
378    }
379
380    fn is_hard(&self) -> bool {
381        self.is_hard
382    }
383
384    fn constraint_ref(&self) -> ConstraintRef {
385        self.constraint_ref.clone()
386    }
387}
388
389impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc: Score> std::fmt::Debug
390    for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
391{
392    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
393        f.debug_struct("FlattenedBiConstraint")
394            .field("name", &self.constraint_ref.name)
395            .field("impact_type", &self.impact_type)
396            .field("c_index_size", &self.c_index.len())
397            .finish()
398    }
399}