solverforge_scoring/constraint/
flattened_bi.rs

1//! O(1) flattened bi-constraint for cross-entity joins.
2//!
3//! Pre-indexes C items by key for O(1) lookup on entity changes.
4
5use std::collections::HashMap;
6use std::hash::Hash;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use crate::api::constraint_set::IncrementalConstraint;
13
14/// O(1) flattened bi-constraint.
15///
16/// Given a join between A and B entities by key, this constraint:
17/// 1. Expands each B into C items via a flatten function
18/// 2. Pre-indexes C items by (join_key, c_key) for O(1) lookup
19/// 3. On A entity change, looks up matching C items in O(1) instead of O(|C|)
20///
21/// # Type Parameters
22///
23/// - `S` - Solution type
24/// - `A` - Entity type A (the planning entity, e.g., Shift)
25/// - `B` - Entity type B (the joined entity, e.g., Employee)
26/// - `C` - Flattened item type (e.g., NaiveDate from unavailable dates)
27/// - `K` - Join key type (e.g., Option<usize> for employee_idx)
28/// - `CK` - C item key type for indexing (e.g., NaiveDate)
29/// - `EA` - Extractor for A entities
30/// - `EB` - Extractor for B entities
31/// - `KA` - Key extractor for A (join key)
32/// - `KB` - Key extractor for B (join key)
33/// - `Flatten` - Function extracting &[C] from &B
34/// - `CKeyFn` - Function extracting index key from &C
35/// - `ALookup` - Function extracting lookup key from &A
36/// - `F` - Filter on (A, C) pairs
37/// - `W` - Weight function on (A, C) pairs
38/// - `Sc` - Score type
39///
40/// # Example
41///
42/// ```
43/// use solverforge_scoring::constraint::flattened_bi::FlattenedBiConstraint;
44/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
45/// use solverforge_core::{ConstraintRef, ImpactType};
46/// use solverforge_core::score::SimpleScore;
47///
48/// #[derive(Clone)]
49/// struct Employee {
50///     id: usize,
51///     unavailable_days: Vec<u32>,
52/// }
53///
54/// #[derive(Clone)]
55/// struct Shift {
56///     employee_id: Option<usize>,
57///     day: u32,
58/// }
59///
60/// #[derive(Clone)]
61/// struct Schedule {
62///     shifts: Vec<Shift>,
63///     employees: Vec<Employee>,
64/// }
65///
66/// let constraint = FlattenedBiConstraint::new(
67///     ConstraintRef::new("", "Unavailable employee"),
68///     ImpactType::Penalty,
69///     |s: &Schedule| s.shifts.as_slice(),
70///     |s: &Schedule| s.employees.as_slice(),
71///     |shift: &Shift| shift.employee_id,
72///     |emp: &Employee| Some(emp.id),
73///     |emp: &Employee| emp.unavailable_days.as_slice(),
74///     |day: &u32| *day,           // C → index key
75///     |shift: &Shift| shift.day,  // A → lookup key
76///     |_s: &Schedule, shift: &Shift, day: &u32| shift.day == *day,
77///     |_shift: &Shift, _day: &u32| SimpleScore::of(1),
78///     false,
79/// );
80///
81/// let schedule = Schedule {
82///     shifts: vec![
83///         Shift { employee_id: Some(0), day: 5 },
84///         Shift { employee_id: Some(0), day: 10 },
85///     ],
86///     employees: vec![
87///         Employee { id: 0, unavailable_days: vec![5, 15] },
88///     ],
89/// };
90///
91/// // Day 5 shift conflicts with employee's unavailable day 5 → O(1) lookup!
92/// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
93/// ```
94pub struct FlattenedBiConstraint<
95    S,
96    A,
97    B,
98    C,
99    K,
100    CK,
101    EA,
102    EB,
103    KA,
104    KB,
105    Flatten,
106    CKeyFn,
107    ALookup,
108    F,
109    W,
110    Sc,
111> where
112    Sc: Score,
113{
114    constraint_ref: ConstraintRef,
115    impact_type: ImpactType,
116    extractor_a: EA,
117    extractor_b: EB,
118    key_a: KA,
119    key_b: KB,
120    flatten: Flatten,
121    c_key_fn: CKeyFn,
122    a_lookup_fn: ALookup,
123    filter: F,
124    weight: W,
125    is_hard: bool,
126    /// (join_key, c_key) → list of (b_idx, c_value) for O(1) lookup
127    c_index: HashMap<(K, CK), Vec<(usize, C)>>,
128    /// A index → cached score for this entity's matches
129    a_scores: HashMap<usize, Sc>,
130    _phantom: PhantomData<(S, A, B)>,
131}
132
133impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
134    FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
135where
136    S: 'static,
137    A: Clone + 'static,
138    B: Clone + 'static,
139    C: Clone + 'static,
140    K: Eq + Hash + Clone,
141    CK: Eq + Hash + Clone,
142    EA: Fn(&S) -> &[A],
143    EB: Fn(&S) -> &[B],
144    KA: Fn(&A) -> K,
145    KB: Fn(&B) -> K,
146    Flatten: Fn(&B) -> &[C],
147    CKeyFn: Fn(&C) -> CK,
148    ALookup: Fn(&A) -> CK,
149    F: Fn(&S, &A, &C) -> bool,
150    W: Fn(&A, &C) -> Sc,
151    Sc: Score,
152{
153    /// Creates a new O(1) flattened bi-constraint.
154    #[allow(clippy::too_many_arguments)]
155    pub fn new(
156        constraint_ref: ConstraintRef,
157        impact_type: ImpactType,
158        extractor_a: EA,
159        extractor_b: EB,
160        key_a: KA,
161        key_b: KB,
162        flatten: Flatten,
163        c_key_fn: CKeyFn,
164        a_lookup_fn: ALookup,
165        filter: F,
166        weight: W,
167        is_hard: bool,
168    ) -> Self {
169        Self {
170            constraint_ref,
171            impact_type,
172            extractor_a,
173            extractor_b,
174            key_a,
175            key_b,
176            flatten,
177            c_key_fn,
178            a_lookup_fn,
179            filter,
180            weight,
181            is_hard,
182            c_index: HashMap::new(),
183            a_scores: HashMap::new(),
184            _phantom: PhantomData,
185        }
186    }
187
188    #[inline]
189    fn compute_score(&self, a: &A, c: &C) -> Sc {
190        let base = (self.weight)(a, c);
191        match self.impact_type {
192            ImpactType::Penalty => -base,
193            ImpactType::Reward => base,
194        }
195    }
196
197    /// Build C index: (join_key, c_key) → list of (b_idx, c_value)
198    fn build_c_index(&mut self, entities_b: &[B]) {
199        self.c_index.clear();
200        for (b_idx, b) in entities_b.iter().enumerate() {
201            let join_key = (self.key_b)(b);
202            for c in (self.flatten)(b) {
203                let c_key = (self.c_key_fn)(c);
204                self.c_index
205                    .entry((join_key.clone(), c_key))
206                    .or_default()
207                    .push((b_idx, c.clone()));
208            }
209        }
210    }
211
212    /// Compute score for entity A using O(1) index lookup.
213    fn compute_a_score(&self, solution: &S, a: &A) -> Sc {
214        let join_key = (self.key_a)(a);
215        let lookup_key = (self.a_lookup_fn)(a);
216
217        // O(1) HashMap lookup instead of O(|C|) iteration!
218        let matches = match self.c_index.get(&(join_key, lookup_key)) {
219            Some(v) => v.as_slice(),
220            None => return Sc::zero(),
221        };
222
223        let mut total = Sc::zero();
224        for (_b_idx, c) in matches {
225            if (self.filter)(solution, a, c) {
226                total = total + self.compute_score(a, c);
227            }
228        }
229        total
230    }
231
232    fn insert_a(&mut self, solution: &S, entities_a: &[A], a_idx: usize) -> Sc {
233        if a_idx >= entities_a.len() {
234            return Sc::zero();
235        }
236
237        let a = &entities_a[a_idx];
238        let score = self.compute_a_score(solution, a);
239
240        if score != Sc::zero() {
241            self.a_scores.insert(a_idx, score);
242        }
243        score
244    }
245
246    fn retract_a(&mut self, a_idx: usize) -> Sc {
247        match self.a_scores.remove(&a_idx) {
248            Some(score) => -score,
249            None => Sc::zero(),
250        }
251    }
252}
253
254impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
255    IncrementalConstraint<S, Sc>
256    for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
257where
258    S: Send + Sync + 'static,
259    A: Clone + Send + Sync + 'static,
260    B: Clone + Send + Sync + 'static,
261    C: Clone + Send + Sync + 'static,
262    K: Eq + Hash + Clone + Send + Sync,
263    CK: Eq + Hash + Clone + Send + Sync,
264    EA: Fn(&S) -> &[A] + Send + Sync,
265    EB: Fn(&S) -> &[B] + Send + Sync,
266    KA: Fn(&A) -> K + Send + Sync,
267    KB: Fn(&B) -> K + Send + Sync,
268    Flatten: Fn(&B) -> &[C] + Send + Sync,
269    CKeyFn: Fn(&C) -> CK + Send + Sync,
270    ALookup: Fn(&A) -> CK + Send + Sync,
271    F: Fn(&S, &A, &C) -> bool + Send + Sync,
272    W: Fn(&A, &C) -> Sc + Send + Sync,
273    Sc: Score,
274{
275    fn evaluate(&self, solution: &S) -> Sc {
276        let entities_a = (self.extractor_a)(solution);
277        let entities_b = (self.extractor_b)(solution);
278        let mut total = Sc::zero();
279
280        // Build temporary index for standalone evaluation
281        let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
282        for (b_idx, b) in entities_b.iter().enumerate() {
283            let join_key = (self.key_b)(b);
284            for c in (self.flatten)(b) {
285                let c_key = (self.c_key_fn)(c);
286                temp_index
287                    .entry((join_key.clone(), c_key))
288                    .or_default()
289                    .push((b_idx, c.clone()));
290            }
291        }
292
293        for a in entities_a {
294            let join_key = (self.key_a)(a);
295            let lookup_key = (self.a_lookup_fn)(a);
296
297            if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
298                for (_b_idx, c) in matches {
299                    if (self.filter)(solution, a, c) {
300                        total = total + self.compute_score(a, c);
301                    }
302                }
303            }
304        }
305
306        total
307    }
308
309    fn match_count(&self, solution: &S) -> usize {
310        let entities_a = (self.extractor_a)(solution);
311        let entities_b = (self.extractor_b)(solution);
312        let mut count = 0;
313
314        // Build temporary index
315        let mut temp_index: HashMap<(K, CK), Vec<(usize, C)>> = HashMap::new();
316        for (b_idx, b) in entities_b.iter().enumerate() {
317            let join_key = (self.key_b)(b);
318            for c in (self.flatten)(b) {
319                let c_key = (self.c_key_fn)(c);
320                temp_index
321                    .entry((join_key.clone(), c_key))
322                    .or_default()
323                    .push((b_idx, c.clone()));
324            }
325        }
326
327        for a in entities_a {
328            let join_key = (self.key_a)(a);
329            let lookup_key = (self.a_lookup_fn)(a);
330
331            if let Some(matches) = temp_index.get(&(join_key, lookup_key)) {
332                for (_b_idx, c) in matches {
333                    if (self.filter)(solution, a, c) {
334                        count += 1;
335                    }
336                }
337            }
338        }
339
340        count
341    }
342
343    fn initialize(&mut self, solution: &S) -> Sc {
344        self.reset();
345
346        let entities_a = (self.extractor_a)(solution);
347        let entities_b = (self.extractor_b)(solution);
348
349        // Build C index once: O(B × C)
350        self.build_c_index(entities_b);
351
352        // Insert all A entities: O(A) with O(1) lookups each
353        let mut total = Sc::zero();
354        for a_idx in 0..entities_a.len() {
355            total = total + self.insert_a(solution, entities_a, a_idx);
356        }
357
358        total
359    }
360
361    fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
362        let entities_a = (self.extractor_a)(solution);
363        self.insert_a(solution, entities_a, entity_index)
364    }
365
366    fn on_retract(&mut self, _solution: &S, entity_index: usize) -> Sc {
367        self.retract_a(entity_index)
368    }
369
370    fn reset(&mut self) {
371        self.c_index.clear();
372        self.a_scores.clear();
373    }
374
375    fn name(&self) -> &str {
376        &self.constraint_ref.name
377    }
378
379    fn is_hard(&self) -> bool {
380        self.is_hard
381    }
382
383    fn constraint_ref(&self) -> ConstraintRef {
384        self.constraint_ref.clone()
385    }
386}
387
388impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc: Score> std::fmt::Debug
389    for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
390{
391    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392        f.debug_struct("FlattenedBiConstraint")
393            .field("name", &self.constraint_ref.name)
394            .field("impact_type", &self.impact_type)
395            .field("c_index_size", &self.c_index.len())
396            .finish()
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use solverforge_core::score::SimpleScore;

    #[derive(Clone)]
    struct Employee {
        id: usize,
        unavailable_days: Vec<u32>,
    }

    #[derive(Clone)]
    struct Shift {
        employee_id: Option<usize>,
        day: u32,
    }

    #[derive(Clone)]
    struct Schedule {
        shifts: Vec<Shift>,
        employees: Vec<Employee>,
    }

    /// Penalises (by 1, soft) every shift scheduled on a day its employee is unavailable.
    fn create_test_constraint() -> FlattenedBiConstraint<
        Schedule,
        Shift,
        Employee,
        u32,
        Option<usize>,
        u32,
        impl Fn(&Schedule) -> &[Shift],
        impl Fn(&Schedule) -> &[Employee],
        impl Fn(&Shift) -> Option<usize>,
        impl Fn(&Employee) -> Option<usize>,
        impl Fn(&Employee) -> &[u32],
        impl Fn(&u32) -> u32,
        impl Fn(&Shift) -> u32,
        impl Fn(&Schedule, &Shift, &u32) -> bool,
        impl Fn(&Shift, &u32) -> SimpleScore,
        SimpleScore,
    > {
        FlattenedBiConstraint::new(
            ConstraintRef::new("", "Unavailable employee"),
            ImpactType::Penalty,
            |s: &Schedule| s.shifts.as_slice(),
            |s: &Schedule| s.employees.as_slice(),
            |shift: &Shift| shift.employee_id,
            |emp: &Employee| Some(emp.id),
            |emp: &Employee| emp.unavailable_days.as_slice(),
            |day: &u32| *day,
            |shift: &Shift| shift.day,
            |_s: &Schedule, shift: &Shift, day: &u32| {
                shift.employee_id.is_some() && shift.day == *day
            },
            |_shift: &Shift, _day: &u32| SimpleScore::of(1),
            false,
        )
    }

    /// Shorthand for a shift on `day`, optionally assigned to an employee id.
    fn shift(employee_id: Option<usize>, day: u32) -> Shift {
        Shift { employee_id, day }
    }

    /// Schedule with the given shifts and a single employee (id 0) who is
    /// unavailable on `unavailable_days`.
    fn schedule_with_employee_zero(shifts: Vec<Shift>, unavailable_days: Vec<u32>) -> Schedule {
        let employees = vec![Employee {
            id: 0,
            unavailable_days,
        }];
        Schedule { shifts, employees }
    }

    #[test]
    fn test_evaluate_single_match() {
        let constraint = create_test_constraint();
        let schedule =
            schedule_with_employee_zero(vec![shift(Some(0), 5), shift(Some(0), 10)], vec![5, 15]);

        // Only the day-5 shift falls on an unavailable day.
        assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
    }

    #[test]
    fn test_evaluate_no_match() {
        let constraint = create_test_constraint();
        let schedule = schedule_with_employee_zero(vec![shift(Some(0), 10)], vec![5, 15]);

        // Day 10 is not among the unavailable days.
        assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(0));
    }

    #[test]
    fn test_incremental() {
        let mut constraint = create_test_constraint();
        // Shift 0 conflicts (day 5); shift 1 does not (day 10).
        let schedule =
            schedule_with_employee_zero(vec![shift(Some(0), 5), shift(Some(0), 10)], vec![5, 15]);

        // Full initialisation sees the single conflict.
        let initial = constraint.initialize(&schedule);
        assert_eq!(initial, SimpleScore::of(-1));

        // Retracting the conflicting shift gives the penalty back.
        let retract_delta = constraint.on_retract(&schedule, 0);
        assert_eq!(retract_delta, SimpleScore::of(1));

        // Re-inserting it applies the penalty again.
        let insert_delta = constraint.on_insert(&schedule, 0);
        assert_eq!(insert_delta, SimpleScore::of(-1));
    }

    #[test]
    fn test_unassigned_shift() {
        let constraint = create_test_constraint();
        let schedule = schedule_with_employee_zero(vec![shift(None, 5)], vec![5]);

        // An unassigned shift has join key `None`, which no employee produces.
        assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(0));
    }
}