Skip to main content

solverforge_scoring/constraint/
flattened_bi.rs

1// O(1) flattened bi-constraint for cross-entity joins.
2//
3// Pre-indexes C items by key for O(1) lookup on entity changes.
4
5use std::collections::HashMap;
6use std::hash::Hash;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use crate::api::constraint_set::IncrementalConstraint;
13
14// O(1) flattened bi-constraint.
15//
16// Given a join between A and B entities by key, this constraint:
17// 1. Expands each B into C items via a flatten function
18// 2. Pre-indexes C items by (join_key, c_key) for O(1) lookup
19// 3. On A entity change, looks up matching C items in O(1) instead of O(|C|)
20//
21// # Type Parameters
22//
23// - `S` - Solution type
24// - `A` - Entity type A (the planning entity, e.g., Shift)
25// - `B` - Entity type B (the joined entity, e.g., Employee)
26// - `C` - Flattened item type (e.g., NaiveDate from unavailable dates)
27// - `K` - Join key type (e.g., Option<usize> for employee_idx)
28// - `CK` - C item key type for indexing (e.g., NaiveDate)
29// - `EA` - Extractor for A entities
30// - `EB` - Extractor for B entities
31// - `KA` - Key extractor for A (join key)
32// - `KB` - Key extractor for B (join key)
33// - `Flatten` - Function extracting &[C] from &B
34// - `CKeyFn` - Function extracting index key from &C
35// - `ALookup` - Function extracting lookup key from &A
36// - `F` - Filter on (A, C) pairs
37// - `W` - Weight function on (A, C) pairs
38// - `Sc` - Score type
39//
40// # Example
41//
42// ```
43// use solverforge_scoring::constraint::flattened_bi::FlattenedBiConstraint;
44// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
45// use solverforge_core::{ConstraintRef, ImpactType};
46// use solverforge_core::score::SoftScore;
47//
48// #[derive(Clone)]
49// struct Employee {
50//     id: usize,
51//     unavailable_days: Vec<u32>,
52// }
53//
54// #[derive(Clone)]
55// struct Shift {
56//     employee_id: Option<usize>,
57//     day: u32,
58// }
59//
60// #[derive(Clone)]
61// struct Schedule {
62//     shifts: Vec<Shift>,
63//     employees: Vec<Employee>,
64// }
65//
66// let constraint = FlattenedBiConstraint::new(
67//     ConstraintRef::new("", "Unavailable employee"),
68//     ImpactType::Penalty,
69//     |s: &Schedule| s.shifts.as_slice(),
70//     |s: &Schedule| s.employees.as_slice(),
71//     |shift: &Shift| shift.employee_id,
72//     |emp: &Employee| Some(emp.id),
73//     |emp: &Employee| emp.unavailable_days.as_slice(),
74//     |day: &u32| *day,           // C → index key
75//     |shift: &Shift| shift.day,  // A → lookup key
76//     |_s: &Schedule, shift: &Shift, day: &u32| shift.day == *day,
77//     |_shift: &Shift, _day: &u32| SoftScore::of(1),
78//     false,
79// );
80//
81// let schedule = Schedule {
82//     shifts: vec![
83//         Shift { employee_id: Some(0), day: 5 },
84//         Shift { employee_id: Some(0), day: 10 },
85//     ],
86//     employees: vec![
87//         Employee { id: 0, unavailable_days: vec![5, 15] },
88//     ],
89// };
90//
91// // Day 5 shift conflicts with employee's unavailable day 5 → O(1) lookup!
92// assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-1));
93// ```
94pub struct FlattenedBiConstraint<
95    S,
96    A,
97    B,
98    C,
99    K,
100    CK,
101    EA,
102    EB,
103    KA,
104    KB,
105    Flatten,
106    CKeyFn,
107    ALookup,
108    F,
109    W,
110    Sc,
111> where
112    Sc: Score,
113{
114    constraint_ref: ConstraintRef,
115    impact_type: ImpactType,
116    extractor_a: EA,
117    extractor_b: EB,
118    key_a: KA,
119    key_b: KB,
120    flatten: Flatten,
121    c_key_fn: CKeyFn,
122    a_lookup_fn: ALookup,
123    filter: F,
124    weight: W,
125    is_hard: bool,
126    // (join_key, c_key) → list of (b_idx, c_value) for O(1) lookup
127    c_index: HashMap<(K, CK), Vec<(usize, C)>>,
128    // A index → cached score for this entity's matches
129    a_scores: HashMap<usize, Sc>,
130    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B)>,
131}
132
133impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
134    FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
135where
136    S: 'static,
137    A: Clone + 'static,
138    B: Clone + 'static,
139    C: Clone + 'static,
140    K: Eq + Hash + Clone,
141    CK: Eq + Hash + Clone,
142    EA: Fn(&S) -> &[A],
143    EB: Fn(&S) -> &[B],
144    KA: Fn(&A) -> K,
145    KB: Fn(&B) -> K,
146    Flatten: Fn(&B) -> &[C],
147    CKeyFn: Fn(&C) -> CK,
148    ALookup: Fn(&A) -> CK,
149    F: Fn(&S, &A, &C) -> bool,
150    W: Fn(&A, &C) -> Sc,
151    Sc: Score,
152{
153    // Creates a new O(1) flattened bi-constraint.
154    #[allow(clippy::too_many_arguments)]
155    pub fn new(
156        constraint_ref: ConstraintRef,
157        impact_type: ImpactType,
158        extractor_a: EA,
159        extractor_b: EB,
160        key_a: KA,
161        key_b: KB,
162        flatten: Flatten,
163        c_key_fn: CKeyFn,
164        a_lookup_fn: ALookup,
165        filter: F,
166        weight: W,
167        is_hard: bool,
168    ) -> Self {
169        Self {
170            constraint_ref,
171            impact_type,
172            extractor_a,
173            extractor_b,
174            key_a,
175            key_b,
176            flatten,
177            c_key_fn,
178            a_lookup_fn,
179            filter,
180            weight,
181            is_hard,
182            c_index: HashMap::new(),
183            a_scores: HashMap::new(),
184            _phantom: PhantomData,
185        }
186    }
187
188    #[inline]
189    fn compute_score(&self, a: &A, c: &C) -> Sc {
190        let base = (self.weight)(a, c);
191        match self.impact_type {
192            ImpactType::Penalty => -base,
193            ImpactType::Reward => base,
194        }
195    }
196
197    // Build C index: (join_key, c_key) → list of (b_idx, c_value)
198    fn build_c_index(&mut self, entities_b: &[B]) {
199        self.c_index.clear();
200        for (b_idx, b) in entities_b.iter().enumerate() {
201            let join_key = (self.key_b)(b);
202            for c in (self.flatten)(b) {
203                let c_key = (self.c_key_fn)(c);
204                self.c_index
205                    .entry((join_key.clone(), c_key))
206                    .or_default()
207                    .push((b_idx, c.clone()));
208            }
209        }
210    }
211
212    // Compute score for entity A using O(1) index lookup.
213    fn compute_a_score(&self, solution: &S, a: &A) -> Sc {
214        let join_key = (self.key_a)(a);
215        let lookup_key = (self.a_lookup_fn)(a);
216
217        // O(1) HashMap lookup instead of O(|C|) iteration!
218        let matches = match self.c_index.get(&(join_key, lookup_key)) {
219            Some(v) => v.as_slice(),
220            None => return Sc::zero(),
221        };
222
223        let mut total = Sc::zero();
224        for (_b_idx, c) in matches {
225            if (self.filter)(solution, a, c) {
226                total = total + self.compute_score(a, c);
227            }
228        }
229        total
230    }
231
232    fn insert_a(&mut self, solution: &S, entities_a: &[A], a_idx: usize) -> Sc {
233        if a_idx >= entities_a.len() {
234            return Sc::zero();
235        }
236
237        let a = &entities_a[a_idx];
238        let score = self.compute_a_score(solution, a);
239
240        if score != Sc::zero() {
241            self.a_scores.insert(a_idx, score);
242        }
243        score
244    }
245
246    fn retract_a(&mut self, a_idx: usize) -> Sc {
247        match self.a_scores.remove(&a_idx) {
248            Some(score) => -score,
249            None => Sc::zero(),
250        }
251    }
252}
253
254impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
255    IncrementalConstraint<S, Sc>
256    for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
257where
258    S: Send + Sync + 'static,
259    A: Clone + Send + Sync + 'static,
260    B: Clone + Send + Sync + 'static,
261    C: Clone + Send + Sync + 'static,
262    K: Eq + Hash + Clone + Send + Sync,
263    CK: Eq + Hash + Clone + Send + Sync,
264    EA: Fn(&S) -> &[A] + Send + Sync,
265    EB: Fn(&S) -> &[B] + Send + Sync,
266    KA: Fn(&A) -> K + Send + Sync,
267    KB: Fn(&B) -> K + Send + Sync,
268    Flatten: Fn(&B) -> &[C] + Send + Sync,
269    CKeyFn: Fn(&C) -> CK + Send + Sync,
270    ALookup: Fn(&A) -> CK + Send + Sync,
271    F: Fn(&S, &A, &C) -> bool + Send + Sync,
272    W: Fn(&A, &C) -> Sc + Send + Sync,
273    Sc: Score,
274{
275    fn evaluate(&self, solution: &S) -> Sc {
276        let entities_a = (self.extractor_a)(solution);
277        let entities_b = (self.extractor_b)(solution);
278        let mut total = Sc::zero();
279
280        // Build temporary index for standalone evaluation
281        let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
282        for (b_idx, b) in entities_b.iter().enumerate() {
283            let join_key = (self.key_b)(b);
284            for c in (self.flatten)(b) {
285                let c_key = (self.c_key_fn)(c);
286                temp_index
287                    .entry((join_key.clone(), c_key))
288                    .or_default()
289                    .push((b_idx, c.clone()));
290            }
291        }
292
293        for a in entities_a {
294            let join_key = (self.key_a)(a);
295            let lookup_key = (self.a_lookup_fn)(a);
296
297            if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
298                for (_b_idx, c) in matches {
299                    if (self.filter)(solution, a, c) {
300                        total = total + self.compute_score(a, c);
301                    }
302                }
303            }
304        }
305
306        total
307    }
308
309    fn match_count(&self, solution: &S) -> usize {
310        let entities_a = (self.extractor_a)(solution);
311        let entities_b = (self.extractor_b)(solution);
312        let mut count = 0;
313
314        // Build temporary index
315        let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
316        for (b_idx, b) in entities_b.iter().enumerate() {
317            let join_key = (self.key_b)(b);
318            for c in (self.flatten)(b) {
319                let c_key = (self.c_key_fn)(c);
320                temp_index
321                    .entry((join_key.clone(), c_key))
322                    .or_default()
323                    .push((b_idx, c.clone()));
324            }
325        }
326
327        for a in entities_a {
328            let join_key = (self.key_a)(a);
329            let lookup_key = (self.a_lookup_fn)(a);
330
331            if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
332                for (_b_idx, c) in matches {
333                    if (self.filter)(solution, a, c) {
334                        count += 1;
335                    }
336                }
337            }
338        }
339
340        count
341    }
342
343    fn initialize(&mut self, solution: &S) -> Sc {
344        self.reset();
345
346        let entities_a = (self.extractor_a)(solution);
347        let entities_b = (self.extractor_b)(solution);
348
349        // Build C index once: O(B × C)
350        self.build_c_index(entities_b);
351
352        // Insert all A entities: O(A) with O(1) lookups each
353        let mut total = Sc::zero();
354        for a_idx in 0..entities_a.len() {
355            total = total + self.insert_a(solution, entities_a, a_idx);
356        }
357
358        total
359    }
360
361    fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
362        let entities_a = (self.extractor_a)(solution);
363        self.insert_a(solution, entities_a, entity_index)
364    }
365
366    fn on_retract(&mut self, _solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
367        self.retract_a(entity_index)
368    }
369
370    fn reset(&mut self) {
371        self.c_index.clear();
372        self.a_scores.clear();
373    }
374
375    fn name(&self) -> &str {
376        &self.constraint_ref.name
377    }
378
379    fn is_hard(&self) -> bool {
380        self.is_hard
381    }
382
383    fn constraint_ref(&self) -> ConstraintRef {
384        self.constraint_ref.clone()
385    }
386}
387
388impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc: Score> std::fmt::Debug
389    for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
390{
391    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392        f.debug_struct("FlattenedBiConstraint")
393            .field("name", &self.constraint_ref.name)
394            .field("impact_type", &self.impact_type)
395            .field("c_index_size", &self.c_index.len())
396            .finish()
397    }
398}