Skip to main content

solverforge_scoring/constraint/
if_exists.rs

// Zero-erasure if_exists/if_not_exists uni-constraint.
//
// Filters A entities by whether a B entity with a matching key exists in another collection.
// The result is still a uni-constraint over A, not a bi-constraint.
5
6use std::collections::HashSet;
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::api::constraint_set::IncrementalConstraint;
14
// Selects which `A` entities an if-exists constraint keeps, based on whether
// a `B` entity with the same key is present.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExistenceMode {
    // Keep `A` when one or more `B` entities share its key.
    Exists,
    // Keep `A` only when no `B` entity shares its key.
    NotExists,
}
23
24// Zero-erasure uni-constraint with existence filtering.
25//
26// Scores A entities based on whether a matching B entity exists (or doesn't exist)
27// in another collection. Unlike join, this produces a uni-constraint over A.
28//
29// # Type Parameters
30//
31// - `S` - Solution type
32// - `A` - Primary entity type (scored)
33// - `B` - Secondary entity type (checked for existence)
34// - `K` - Join key type
35// - `EA` - Extractor for A entities
36// - `EB` - Extractor for B entities
37// - `KA` - Key extractor for A
38// - `KB` - Key extractor for B
39// - `FA` - Filter on A entities
40// - `W` - Weight function for A entities
41// - `Sc` - Score type
42//
43// # Example
44//
45// ```
46// use solverforge_scoring::constraint::if_exists::{IfExistsUniConstraint, ExistenceMode};
47// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
48// use solverforge_core::{ConstraintRef, ImpactType};
49// use solverforge_core::score::SoftScore;
50//
51// #[derive(Clone)]
52// struct Shift { id: usize, employee_idx: Option<usize> }
53//
54// #[derive(Clone)]
55// struct Employee { id: usize, on_vacation: bool }
56//
57// #[derive(Clone)]
58// struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
59//
60// // Penalize shifts assigned to employees who are on vacation
61// let constraint = IfExistsUniConstraint::new(
62//     ConstraintRef::new("", "Vacation conflict"),
63//     ImpactType::Penalty,
64//     ExistenceMode::Exists,
65//     |s: &Schedule| s.shifts.as_slice(),
66//     |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect::<Vec<_>>(),
67//     |shift: &Shift| shift.employee_idx,
68//     |emp: &Employee| Some(emp.id),
69//     |_s: &Schedule, shift: &Shift| shift.employee_idx.is_some(),
70//     |_shift: &Shift| SoftScore::of(1),
71//     false,
72// );
73//
74// let schedule = Schedule {
75//     shifts: vec![
76//         Shift { id: 0, employee_idx: Some(0) },  // assigned to vacationing emp
77//         Shift { id: 1, employee_idx: Some(1) },  // assigned to working emp
78//         Shift { id: 2, employee_idx: None },     // unassigned
79//     ],
80//     employees: vec![
81//         Employee { id: 0, on_vacation: true },
82//         Employee { id: 1, on_vacation: false },
83//     ],
84// };
85//
86// // Only shift 0 matches (assigned to employee on vacation)
87// assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-1));
88// ```
89pub struct IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
90where
91    Sc: Score,
92{
93    constraint_ref: ConstraintRef,
94    impact_type: ImpactType,
95    mode: ExistenceMode,
96    extractor_a: EA,
97    extractor_b: EB,
98    key_a: KA,
99    key_b: KB,
100    filter_a: FA,
101    weight: W,
102    is_hard: bool,
103    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
104}
105
106impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
107    IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
108where
109    S: 'static,
110    A: Clone + 'static,
111    B: Clone + 'static,
112    K: Eq + Hash + Clone,
113    EA: Fn(&S) -> &[A],
114    EB: Fn(&S) -> Vec<B>,
115    KA: Fn(&A) -> K,
116    KB: Fn(&B) -> K,
117    FA: Fn(&S, &A) -> bool,
118    W: Fn(&A) -> Sc,
119    Sc: Score,
120{
121    // Creates a new if_exists/if_not_exists constraint.
122    #[allow(clippy::too_many_arguments)]
123    pub fn new(
124        constraint_ref: ConstraintRef,
125        impact_type: ImpactType,
126        mode: ExistenceMode,
127        extractor_a: EA,
128        extractor_b: EB,
129        key_a: KA,
130        key_b: KB,
131        filter_a: FA,
132        weight: W,
133        is_hard: bool,
134    ) -> Self {
135        Self {
136            constraint_ref,
137            impact_type,
138            mode,
139            extractor_a,
140            extractor_b,
141            key_a,
142            key_b,
143            filter_a,
144            weight,
145            is_hard,
146            _phantom: PhantomData,
147        }
148    }
149
150    #[inline]
151    fn compute_score(&self, a: &A) -> Sc {
152        let base = (self.weight)(a);
153        match self.impact_type {
154            ImpactType::Penalty => -base,
155            ImpactType::Reward => base,
156        }
157    }
158
159    fn build_b_keys(&self, solution: &S) -> HashSet<K> {
160        let entities_b = (self.extractor_b)(solution);
161        entities_b.iter().map(|b| (self.key_b)(b)).collect()
162    }
163
164    fn matches_existence(&self, a: &A, b_keys: &HashSet<K>) -> bool {
165        let key = (self.key_a)(a);
166        let exists = b_keys.contains(&key);
167        match self.mode {
168            ExistenceMode::Exists => exists,
169            ExistenceMode::NotExists => !exists,
170        }
171    }
172}
173
174impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc> IncrementalConstraint<S, Sc>
175    for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
176where
177    S: Send + Sync + 'static,
178    A: Clone + Send + Sync + 'static,
179    B: Clone + Send + Sync + 'static,
180    K: Eq + Hash + Clone + Send + Sync,
181    EA: Fn(&S) -> &[A] + Send + Sync,
182    EB: Fn(&S) -> Vec<B> + Send + Sync,
183    KA: Fn(&A) -> K + Send + Sync,
184    KB: Fn(&B) -> K + Send + Sync,
185    FA: Fn(&S, &A) -> bool + Send + Sync,
186    W: Fn(&A) -> Sc + Send + Sync,
187    Sc: Score,
188{
189    fn evaluate(&self, solution: &S) -> Sc {
190        let entities_a = (self.extractor_a)(solution);
191        let b_keys = self.build_b_keys(solution);
192
193        let mut total = Sc::zero();
194        for a in entities_a {
195            if (self.filter_a)(solution, a) && self.matches_existence(a, &b_keys) {
196                total = total + self.compute_score(a);
197            }
198        }
199        total
200    }
201
202    fn match_count(&self, solution: &S) -> usize {
203        let entities_a = (self.extractor_a)(solution);
204        let b_keys = self.build_b_keys(solution);
205
206        entities_a
207            .iter()
208            .filter(|a| (self.filter_a)(solution, a) && self.matches_existence(a, &b_keys))
209            .count()
210    }
211
212    fn initialize(&mut self, solution: &S) -> Sc {
213        self.evaluate(solution)
214    }
215
216    fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
217        let entities_a = (self.extractor_a)(solution);
218        if entity_index >= entities_a.len() {
219            return Sc::zero();
220        }
221
222        let a = &entities_a[entity_index];
223        if !(self.filter_a)(solution, a) {
224            return Sc::zero();
225        }
226
227        let b_keys = self.build_b_keys(solution);
228        if self.matches_existence(a, &b_keys) {
229            self.compute_score(a)
230        } else {
231            Sc::zero()
232        }
233    }
234
235    fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
236        let entities_a = (self.extractor_a)(solution);
237        if entity_index >= entities_a.len() {
238            return Sc::zero();
239        }
240
241        let a = &entities_a[entity_index];
242        if !(self.filter_a)(solution, a) {
243            return Sc::zero();
244        }
245
246        let b_keys = self.build_b_keys(solution);
247        if self.matches_existence(a, &b_keys) {
248            -self.compute_score(a)
249        } else {
250            Sc::zero()
251        }
252    }
253
254    fn reset(&mut self) {
255        // No cached state to clear - we rebuild b_keys on each evaluation
256    }
257
258    fn name(&self) -> &str {
259        &self.constraint_ref.name
260    }
261
262    fn is_hard(&self) -> bool {
263        self.is_hard
264    }
265
266    fn constraint_ref(&self) -> ConstraintRef {
267        self.constraint_ref.clone()
268    }
269}
270
271impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc: Score> std::fmt::Debug
272    for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
273{
274    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275        f.debug_struct("IfExistsUniConstraint")
276            .field("name", &self.constraint_ref.name)
277            .field("impact_type", &self.impact_type)
278            .field("mode", &self.mode)
279            .finish()
280    }
281}