Skip to main content

solverforge_scoring/constraint/
flattened_bi.rs

1/* O(1) flattened bi-constraint for cross-entity joins.
2
3Pre-indexes C items by key for O(1) lookup on entity changes.
4*/
5
6use std::collections::HashMap;
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::api::constraint_set::IncrementalConstraint;
14
/* O(1) flattened bi-constraint.

Given a join between A and B entities by key, this constraint:
1. Expands each B into C items via a flatten function
2. Pre-indexes C items by (join_key, c_key) for O(1) lookup
3. On A entity change, looks up matching C items in O(1) instead of O(|C|)

"O(1)" refers to the hash-map lookup of the (join_key, c_key) bucket; every
C item stored in that bucket is still visited and passed to the filter.

# Type Parameters

- `S` - Solution type
- `A` - Entity type A (the planning entity, e.g., Shift)
- `B` - Entity type B (the joined entity, e.g., Employee)
- `C` - Flattened item type (e.g., NaiveDate from unavailable dates)
- `K` - Join key type (e.g., Option<usize> for employee_idx)
- `CK` - C item key type for indexing (e.g., NaiveDate)
- `EA` - Extractor for A entities
- `EB` - Extractor for B entities
- `KA` - Key extractor for A (join key)
- `KB` - Key extractor for B (join key)
- `Flatten` - Function extracting &[C] from &B
- `CKeyFn` - Function extracting index key from &C
- `ALookup` - Function extracting lookup key from &A
- `F` - Filter on (solution, A, C) triples; only pairs it accepts are scored
- `W` - Weight function on (A, C) pairs
- `Sc` - Score type

# Example

```
use solverforge_scoring::constraint::flattened_bi::FlattenedBiConstraint;
use solverforge_scoring::api::constraint_set::IncrementalConstraint;
use solverforge_core::{ConstraintRef, ImpactType};
use solverforge_core::score::SoftScore;

#[derive(Clone)]
struct Employee {
    id: usize,
    unavailable_days: Vec<u32>,
}

#[derive(Clone)]
struct Shift {
    employee_id: Option<usize>,
    day: u32,
}

#[derive(Clone)]
struct Schedule {
    shifts: Vec<Shift>,
    employees: Vec<Employee>,
}

let constraint = FlattenedBiConstraint::new(
    ConstraintRef::new("", "Unavailable employee"),
    ImpactType::Penalty,
    |s: &Schedule| s.shifts.as_slice(),
    |s: &Schedule| s.employees.as_slice(),
    |shift: &Shift| shift.employee_id,
    |emp: &Employee| Some(emp.id),
    |emp: &Employee| emp.unavailable_days.as_slice(),
    |day: &u32| *day,           // C → index key
    |shift: &Shift| shift.day,  // A → lookup key
    |_s: &Schedule, shift: &Shift, day: &u32| shift.day == *day,
    |_shift: &Shift, _day: &u32| SoftScore::of(1),
    false,
);

let schedule = Schedule {
    shifts: vec![
        Shift { employee_id: Some(0), day: 5 },
        Shift { employee_id: Some(0), day: 10 },
    ],
    employees: vec![
        Employee { id: 0, unavailable_days: vec![5, 15] },
    ],
};

// Day 5 shift conflicts with employee's unavailable day 5 → O(1) lookup!
// The impact type is Penalty, so the weight of 1 is negated.
assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-1));
```
*/
pub struct FlattenedBiConstraint<
    S,
    A,
    B,
    C,
    K,
    CK,
    EA,
    EB,
    KA,
    KB,
    Flatten,
    CKeyFn,
    ALookup,
    F,
    W,
    Sc,
> where
    Sc: Score,
{
    // Identity of the constraint; its `name` is what `name()` returns.
    constraint_ref: ConstraintRef,
    // Penalty negates the weight, Reward uses it as-is.
    impact_type: ImpactType,
    // Pulls the A entities out of the solution.
    extractor_a: EA,
    // Pulls the B entities out of the solution.
    extractor_b: EB,
    // Join key of an A entity.
    key_a: KA,
    // Join key of a B entity.
    key_b: KB,
    // Expands one B entity into its C items.
    flatten: Flatten,
    // Index key of a C item (second half of the c_index key).
    c_key_fn: CKeyFn,
    // Lookup key of an A entity; compared against the C index key.
    a_lookup_fn: ALookup,
    // Predicate applied to every (A, C) pair found through the index.
    filter: F,
    // Unsigned weight of an accepted (A, C) pair.
    weight: W,
    // Value reported by `is_hard()`.
    is_hard: bool,
    // (join_key, c_key) → list of (b_idx, c_value) for O(1) lookup.
    // Empty until `initialize` builds it; cleared again by `reset`.
    c_index: HashMap<(K, CK), Vec<(usize, C)>>,
    // A index → cached score for this entity's matches.
    // Only non-zero scores are stored; a missing entry means zero.
    a_scores: HashMap<usize, Sc>,
    // S, A and B appear only in the closure signatures, not in any field.
    // The `fn() -> T` form names them without implying ownership of a T.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B)>,
}
134
135impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
136    FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
137where
138    S: 'static,
139    A: Clone + 'static,
140    B: Clone + 'static,
141    C: Clone + 'static,
142    K: Eq + Hash + Clone,
143    CK: Eq + Hash + Clone,
144    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
145    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
146    KA: Fn(&A) -> K,
147    KB: Fn(&B) -> K,
148    Flatten: Fn(&B) -> &[C],
149    CKeyFn: Fn(&C) -> CK,
150    ALookup: Fn(&A) -> CK,
151    F: Fn(&S, &A, &C) -> bool,
152    W: Fn(&A, &C) -> Sc,
153    Sc: Score,
154{
155    // Creates a new O(1) flattened bi-constraint.
156    #[allow(clippy::too_many_arguments)]
157    pub fn new(
158        constraint_ref: ConstraintRef,
159        impact_type: ImpactType,
160        extractor_a: EA,
161        extractor_b: EB,
162        key_a: KA,
163        key_b: KB,
164        flatten: Flatten,
165        c_key_fn: CKeyFn,
166        a_lookup_fn: ALookup,
167        filter: F,
168        weight: W,
169        is_hard: bool,
170    ) -> Self {
171        Self {
172            constraint_ref,
173            impact_type,
174            extractor_a,
175            extractor_b,
176            key_a,
177            key_b,
178            flatten,
179            c_key_fn,
180            a_lookup_fn,
181            filter,
182            weight,
183            is_hard,
184            c_index: HashMap::new(),
185            a_scores: HashMap::new(),
186            _phantom: PhantomData,
187        }
188    }
189
190    #[inline]
191    fn compute_score(&self, a: &A, c: &C) -> Sc {
192        let base = (self.weight)(a, c);
193        match self.impact_type {
194            ImpactType::Penalty => -base,
195            ImpactType::Reward => base,
196        }
197    }
198
199    // Build C index: (join_key, c_key) → list of (b_idx, c_value)
200    fn build_c_index(&mut self, entities_b: &[B]) {
201        self.c_index.clear();
202        for (b_idx, b) in entities_b.iter().enumerate() {
203            let join_key = (self.key_b)(b);
204            for c in (self.flatten)(b) {
205                let c_key = (self.c_key_fn)(c);
206                self.c_index
207                    .entry((join_key.clone(), c_key))
208                    .or_default()
209                    .push((b_idx, c.clone()));
210            }
211        }
212    }
213
214    // Compute score for entity A using O(1) index lookup.
215    fn compute_a_score(&self, solution: &S, a: &A) -> Sc {
216        let join_key = (self.key_a)(a);
217        let lookup_key = (self.a_lookup_fn)(a);
218
219        // O(1) HashMap lookup instead of O(|C|) iteration!
220        let matches = match self.c_index.get(&(join_key, lookup_key)) {
221            Some(v) => v.as_slice(),
222            None => return Sc::zero(),
223        };
224
225        let mut total = Sc::zero();
226        for (_b_idx, c) in matches {
227            if (self.filter)(solution, a, c) {
228                total = total + self.compute_score(a, c);
229            }
230        }
231        total
232    }
233
234    fn insert_a(&mut self, solution: &S, entities_a: &[A], a_idx: usize) -> Sc {
235        if a_idx >= entities_a.len() {
236            return Sc::zero();
237        }
238
239        let a = &entities_a[a_idx];
240        let score = self.compute_a_score(solution, a);
241
242        if score != Sc::zero() {
243            self.a_scores.insert(a_idx, score);
244        }
245        score
246    }
247
248    fn retract_a(&mut self, a_idx: usize) -> Sc {
249        match self.a_scores.remove(&a_idx) {
250            Some(score) => -score,
251            None => Sc::zero(),
252        }
253    }
254}
255
256impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
257    IncrementalConstraint<S, Sc>
258    for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
259where
260    S: Send + Sync + 'static,
261    A: Clone + Send + Sync + 'static,
262    B: Clone + Send + Sync + 'static,
263    C: Clone + Send + Sync + 'static,
264    K: Eq + Hash + Clone + Send + Sync,
265    CK: Eq + Hash + Clone + Send + Sync,
266    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
267    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
268    KA: Fn(&A) -> K + Send + Sync,
269    KB: Fn(&B) -> K + Send + Sync,
270    Flatten: Fn(&B) -> &[C] + Send + Sync,
271    CKeyFn: Fn(&C) -> CK + Send + Sync,
272    ALookup: Fn(&A) -> CK + Send + Sync,
273    F: Fn(&S, &A, &C) -> bool + Send + Sync,
274    W: Fn(&A, &C) -> Sc + Send + Sync,
275    Sc: Score,
276{
277    fn evaluate(&self, solution: &S) -> Sc {
278        let entities_a = self.extractor_a.extract(solution);
279        let entities_b = self.extractor_b.extract(solution);
280        let mut total = Sc::zero();
281
282        // Build temporary index for standalone evaluation
283        let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
284        for (b_idx, b) in entities_b.iter().enumerate() {
285            let join_key = (self.key_b)(b);
286            for c in (self.flatten)(b) {
287                let c_key = (self.c_key_fn)(c);
288                temp_index
289                    .entry((join_key.clone(), c_key))
290                    .or_default()
291                    .push((b_idx, c.clone()));
292            }
293        }
294
295        for a in entities_a {
296            let join_key = (self.key_a)(a);
297            let lookup_key = (self.a_lookup_fn)(a);
298
299            if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
300                for (_b_idx, c) in matches {
301                    if (self.filter)(solution, a, c) {
302                        total = total + self.compute_score(a, c);
303                    }
304                }
305            }
306        }
307
308        total
309    }
310
311    fn match_count(&self, solution: &S) -> usize {
312        let entities_a = self.extractor_a.extract(solution);
313        let entities_b = self.extractor_b.extract(solution);
314        let mut count = 0;
315
316        // Build temporary index
317        let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
318        for (b_idx, b) in entities_b.iter().enumerate() {
319            let join_key = (self.key_b)(b);
320            for c in (self.flatten)(b) {
321                let c_key = (self.c_key_fn)(c);
322                temp_index
323                    .entry((join_key.clone(), c_key))
324                    .or_default()
325                    .push((b_idx, c.clone()));
326            }
327        }
328
329        for a in entities_a {
330            let join_key = (self.key_a)(a);
331            let lookup_key = (self.a_lookup_fn)(a);
332
333            if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
334                for (_b_idx, c) in matches {
335                    if (self.filter)(solution, a, c) {
336                        count += 1;
337                    }
338                }
339            }
340        }
341
342        count
343    }
344
345    fn initialize(&mut self, solution: &S) -> Sc {
346        self.reset();
347
348        let entities_a = self.extractor_a.extract(solution);
349        let entities_b = self.extractor_b.extract(solution);
350
351        // Build C index once: O(B × C)
352        self.build_c_index(entities_b);
353
354        // Insert all A entities: O(A) with O(1) lookups each
355        let mut total = Sc::zero();
356        for a_idx in 0..entities_a.len() {
357            total = total + self.insert_a(solution, entities_a, a_idx);
358        }
359
360        total
361    }
362
363    fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
364        let entities_a = self.extractor_a.extract(solution);
365        self.insert_a(solution, entities_a, entity_index)
366    }
367
368    fn on_retract(&mut self, _solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
369        self.retract_a(entity_index)
370    }
371
372    fn reset(&mut self) {
373        self.c_index.clear();
374        self.a_scores.clear();
375    }
376
377    fn name(&self) -> &str {
378        &self.constraint_ref.name
379    }
380
381    fn is_hard(&self) -> bool {
382        self.is_hard
383    }
384
385    fn constraint_ref(&self) -> ConstraintRef {
386        self.constraint_ref.clone()
387    }
388}
389
390impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc: Score> std::fmt::Debug
391    for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
392{
393    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394        f.debug_struct("FlattenedBiConstraint")
395            .field("name", &self.constraint_ref.name)
396            .field("impact_type", &self.impact_type)
397            .field("c_index_size", &self.c_index.len())
398            .finish()
399    }
400}