Skip to main content

solverforge_scoring/constraint/
if_exists.rs

1/* Zero-erasure if_exists/if_not_exists uni-constraint.
2
3Filters A entities based on whether a matching B entity exists in another collection.
4The result is still a uni-constraint over A, not a bi-constraint.
5*/
6
7use std::collections::HashSet;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::constraint_set::IncrementalConstraint;
15
// Selects which `A` entities an `IfExistsUniConstraint` keeps, based on whether a
// `B` entity with the same key is present.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExistenceMode {
    // Keep `A` only when at least one `B` shares its key.
    Exists,
    // Keep `A` only when no `B` shares its key.
    NotExists,
}
24
25/* Zero-erasure uni-constraint with existence filtering.
26
27Scores A entities based on whether a matching B entity exists (or doesn't exist)
28in another collection. Unlike join, this produces a uni-constraint over A.
29
30# Type Parameters
31
32- `S` - Solution type
33- `A` - Primary entity type (scored)
34- `B` - Secondary entity type (checked for existence)
35- `K` - Join key type
36- `EA` - Extractor for A entities
37- `EB` - Extractor for B entities
38- `KA` - Key extractor for A
39- `KB` - Key extractor for B
40- `FA` - Filter on A entities
41- `W` - Weight function for A entities
42- `Sc` - Score type
43
44# Example
45
46```
47use solverforge_scoring::constraint::if_exists::{IfExistsUniConstraint, ExistenceMode};
48use solverforge_scoring::api::constraint_set::IncrementalConstraint;
49use solverforge_core::{ConstraintRef, ImpactType};
50use solverforge_core::score::SoftScore;
51
52#[derive(Clone)]
53struct Shift { id: usize, employee_idx: Option<usize> }
54
55#[derive(Clone)]
56struct Employee { id: usize, on_vacation: bool }
57
58#[derive(Clone)]
59struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
60
61// Penalize shifts assigned to employees who are on vacation
62let constraint = IfExistsUniConstraint::new(
63ConstraintRef::new("", "Vacation conflict"),
64ImpactType::Penalty,
65ExistenceMode::Exists,
66|s: &Schedule| s.shifts.as_slice(),
67|s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect::<Vec<_>>(),
68|shift: &Shift| shift.employee_idx,
69|emp: &Employee| Some(emp.id),
70|_s: &Schedule, shift: &Shift| shift.employee_idx.is_some(),
71|_shift: &Shift| SoftScore::of(1),
72false,
73);
74
75let schedule = Schedule {
76shifts: vec![
77Shift { id: 0, employee_idx: Some(0) },  // assigned to vacationing emp
78Shift { id: 1, employee_idx: Some(1) },  // assigned to working emp
79Shift { id: 2, employee_idx: None },     // unassigned
80],
81employees: vec![
82Employee { id: 0, on_vacation: true },
83Employee { id: 1, on_vacation: false },
84],
85};
86
87// Only shift 0 matches (assigned to employee on vacation)
88assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-1));
89```
90*/
91pub struct IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
92where
93    Sc: Score,
94{
95    constraint_ref: ConstraintRef,
96    impact_type: ImpactType,
97    mode: ExistenceMode,
98    extractor_a: EA,
99    extractor_b: EB,
100    key_a: KA,
101    key_b: KB,
102    filter_a: FA,
103    weight: W,
104    is_hard: bool,
105    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
106}
107
108impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
109    IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
110where
111    S: 'static,
112    A: Clone + 'static,
113    B: Clone + 'static,
114    K: Eq + Hash + Clone,
115    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
116    EB: Fn(&S) -> Vec<B>,
117    KA: Fn(&A) -> K,
118    KB: Fn(&B) -> K,
119    FA: Fn(&S, &A) -> bool,
120    W: Fn(&A) -> Sc,
121    Sc: Score,
122{
123    // Creates a new if_exists/if_not_exists constraint.
124    #[allow(clippy::too_many_arguments)]
125    pub fn new(
126        constraint_ref: ConstraintRef,
127        impact_type: ImpactType,
128        mode: ExistenceMode,
129        extractor_a: EA,
130        extractor_b: EB,
131        key_a: KA,
132        key_b: KB,
133        filter_a: FA,
134        weight: W,
135        is_hard: bool,
136    ) -> Self {
137        Self {
138            constraint_ref,
139            impact_type,
140            mode,
141            extractor_a,
142            extractor_b,
143            key_a,
144            key_b,
145            filter_a,
146            weight,
147            is_hard,
148            _phantom: PhantomData,
149        }
150    }
151
152    #[inline]
153    fn compute_score(&self, a: &A) -> Sc {
154        let base = (self.weight)(a);
155        match self.impact_type {
156            ImpactType::Penalty => -base,
157            ImpactType::Reward => base,
158        }
159    }
160
161    fn build_b_keys(&self, solution: &S) -> HashSet<K> {
162        let entities_b = (self.extractor_b)(solution);
163        entities_b.iter().map(|b| (self.key_b)(b)).collect()
164    }
165
166    fn matches_existence(&self, a: &A, b_keys: &HashSet<K>) -> bool {
167        let key = (self.key_a)(a);
168        let exists = b_keys.contains(&key);
169        match self.mode {
170            ExistenceMode::Exists => exists,
171            ExistenceMode::NotExists => !exists,
172        }
173    }
174}
175
176impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc> IncrementalConstraint<S, Sc>
177    for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
178where
179    S: Send + Sync + 'static,
180    A: Clone + Send + Sync + 'static,
181    B: Clone + Send + Sync + 'static,
182    K: Eq + Hash + Clone + Send + Sync,
183    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
184    EB: Fn(&S) -> Vec<B> + Send + Sync,
185    KA: Fn(&A) -> K + Send + Sync,
186    KB: Fn(&B) -> K + Send + Sync,
187    FA: Fn(&S, &A) -> bool + Send + Sync,
188    W: Fn(&A) -> Sc + Send + Sync,
189    Sc: Score,
190{
191    fn evaluate(&self, solution: &S) -> Sc {
192        let entities_a = self.extractor_a.extract(solution);
193        let b_keys = self.build_b_keys(solution);
194
195        let mut total = Sc::zero();
196        for a in entities_a {
197            if (self.filter_a)(solution, a) && self.matches_existence(a, &b_keys) {
198                total = total + self.compute_score(a);
199            }
200        }
201        total
202    }
203
204    fn match_count(&self, solution: &S) -> usize {
205        let entities_a = self.extractor_a.extract(solution);
206        let b_keys = self.build_b_keys(solution);
207
208        entities_a
209            .iter()
210            .filter(|a| (self.filter_a)(solution, a) && self.matches_existence(a, &b_keys))
211            .count()
212    }
213
214    fn initialize(&mut self, solution: &S) -> Sc {
215        self.evaluate(solution)
216    }
217
218    fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
219        let entities_a = self.extractor_a.extract(solution);
220        if entity_index >= entities_a.len() {
221            return Sc::zero();
222        }
223
224        let a = &entities_a[entity_index];
225        if !(self.filter_a)(solution, a) {
226            return Sc::zero();
227        }
228
229        let b_keys = self.build_b_keys(solution);
230        if self.matches_existence(a, &b_keys) {
231            self.compute_score(a)
232        } else {
233            Sc::zero()
234        }
235    }
236
237    fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
238        let entities_a = self.extractor_a.extract(solution);
239        if entity_index >= entities_a.len() {
240            return Sc::zero();
241        }
242
243        let a = &entities_a[entity_index];
244        if !(self.filter_a)(solution, a) {
245            return Sc::zero();
246        }
247
248        let b_keys = self.build_b_keys(solution);
249        if self.matches_existence(a, &b_keys) {
250            -self.compute_score(a)
251        } else {
252            Sc::zero()
253        }
254    }
255
256    fn reset(&mut self) {
257        // No cached state to clear - we rebuild b_keys on each evaluation
258    }
259
260    fn name(&self) -> &str {
261        &self.constraint_ref.name
262    }
263
264    fn is_hard(&self) -> bool {
265        self.is_hard
266    }
267
268    fn constraint_ref(&self) -> ConstraintRef {
269        self.constraint_ref.clone()
270    }
271}
272
273impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc: Score> std::fmt::Debug
274    for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
275{
276    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        f.debug_struct("IfExistsUniConstraint")
278            .field("name", &self.constraint_ref.name)
279            .field("impact_type", &self.impact_type)
280            .field("mode", &self.mode)
281            .finish()
282    }
283}