// solverforge_scoring/constraint/if_exists.rs
//! Zero-erasure `if_exists` / `if_not_exists` uni-constraint.
//!
//! Filters A entities based on whether a matching B entity exists in another collection.
//! The result is still a uni-constraint over A, not a bi-constraint.

6use std::collections::HashSet;
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::api::constraint_set::IncrementalConstraint;
14
/// Selects which `A` entities survive the existence check.
///
/// The check compares the key of each `A` entity against the set of keys
/// produced from the `B` collection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExistenceMode {
    /// Keep an `A` when at least one `B` shares its key.
    Exists,
    /// Keep an `A` when no `B` shares its key.
    NotExists,
}

24/// Zero-erasure uni-constraint with existence filtering.
25///
26/// Scores A entities based on whether a matching B entity exists (or doesn't exist)
27/// in another collection. Unlike join, this produces a uni-constraint over A.
28///
29/// # Type Parameters
30///
31/// - `S` - Solution type
32/// - `A` - Primary entity type (scored)
33/// - `B` - Secondary entity type (checked for existence)
34/// - `K` - Join key type
35/// - `EA` - Extractor for A entities
36/// - `EB` - Extractor for B entities
37/// - `KA` - Key extractor for A
38/// - `KB` - Key extractor for B
39/// - `FA` - Filter on A entities
40/// - `W` - Weight function for A entities
41/// - `Sc` - Score type
42///
43/// # Example
44///
45/// ```
46/// use solverforge_scoring::constraint::if_exists::{IfExistsUniConstraint, ExistenceMode};
47/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
48/// use solverforge_core::{ConstraintRef, ImpactType};
49/// use solverforge_core::score::SimpleScore;
50///
51/// #[derive(Clone)]
52/// struct Shift { id: usize, employee_idx: Option<usize> }
53///
54/// #[derive(Clone)]
55/// struct Employee { id: usize, on_vacation: bool }
56///
57/// #[derive(Clone)]
58/// struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
59///
60/// // Penalize shifts assigned to employees who are on vacation
61/// let constraint = IfExistsUniConstraint::new(
62///     ConstraintRef::new("", "Vacation conflict"),
63///     ImpactType::Penalty,
64///     ExistenceMode::Exists,
65///     |s: &Schedule| s.shifts.as_slice(),
66///     |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect::<Vec<_>>(),
67///     |shift: &Shift| shift.employee_idx,
68///     |emp: &Employee| Some(emp.id),
69///     |shift: &Shift| shift.employee_idx.is_some(),
70///     |_shift: &Shift| SimpleScore::of(1),
71///     false,
72/// );
73///
74/// let schedule = Schedule {
75///     shifts: vec![
76///         Shift { id: 0, employee_idx: Some(0) },  // assigned to vacationing emp
77///         Shift { id: 1, employee_idx: Some(1) },  // assigned to working emp
78///         Shift { id: 2, employee_idx: None },     // unassigned
79///     ],
80///     employees: vec![
81///         Employee { id: 0, on_vacation: true },
82///         Employee { id: 1, on_vacation: false },
83///     ],
84/// };
85///
86/// // Only shift 0 matches (assigned to employee on vacation)
87/// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
88/// ```
89pub struct IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
90where
91    Sc: Score,
92{
93    constraint_ref: ConstraintRef,
94    impact_type: ImpactType,
95    mode: ExistenceMode,
96    extractor_a: EA,
97    extractor_b: EB,
98    key_a: KA,
99    key_b: KB,
100    filter_a: FA,
101    weight: W,
102    is_hard: bool,
103    _phantom: PhantomData<(S, A, B, K, Sc)>,
104}
105
106impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
107    IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
108where
109    S: 'static,
110    A: Clone + 'static,
111    B: Clone + 'static,
112    K: Eq + Hash + Clone,
113    EA: Fn(&S) -> &[A],
114    EB: Fn(&S) -> Vec<B>,
115    KA: Fn(&A) -> K,
116    KB: Fn(&B) -> K,
117    FA: Fn(&A) -> bool,
118    W: Fn(&A) -> Sc,
119    Sc: Score,
120{
121    /// Creates a new if_exists/if_not_exists constraint.
122    #[allow(clippy::too_many_arguments)]
123    pub fn new(
124        constraint_ref: ConstraintRef,
125        impact_type: ImpactType,
126        mode: ExistenceMode,
127        extractor_a: EA,
128        extractor_b: EB,
129        key_a: KA,
130        key_b: KB,
131        filter_a: FA,
132        weight: W,
133        is_hard: bool,
134    ) -> Self {
135        Self {
136            constraint_ref,
137            impact_type,
138            mode,
139            extractor_a,
140            extractor_b,
141            key_a,
142            key_b,
143            filter_a,
144            weight,
145            is_hard,
146            _phantom: PhantomData,
147        }
148    }
149
150    #[inline]
151    fn compute_score(&self, a: &A) -> Sc {
152        let base = (self.weight)(a);
153        match self.impact_type {
154            ImpactType::Penalty => -base,
155            ImpactType::Reward => base,
156        }
157    }
158
159    fn build_b_keys(&self, solution: &S) -> HashSet<K> {
160        let entities_b = (self.extractor_b)(solution);
161        entities_b.iter().map(|b| (self.key_b)(b)).collect()
162    }
163
164    fn matches_existence(&self, a: &A, b_keys: &HashSet<K>) -> bool {
165        let key = (self.key_a)(a);
166        let exists = b_keys.contains(&key);
167        match self.mode {
168            ExistenceMode::Exists => exists,
169            ExistenceMode::NotExists => !exists,
170        }
171    }
172}
173
174impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc> IncrementalConstraint<S, Sc>
175    for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
176where
177    S: Send + Sync + 'static,
178    A: Clone + Send + Sync + 'static,
179    B: Clone + Send + Sync + 'static,
180    K: Eq + Hash + Clone + Send + Sync,
181    EA: Fn(&S) -> &[A] + Send + Sync,
182    EB: Fn(&S) -> Vec<B> + Send + Sync,
183    KA: Fn(&A) -> K + Send + Sync,
184    KB: Fn(&B) -> K + Send + Sync,
185    FA: Fn(&A) -> bool + Send + Sync,
186    W: Fn(&A) -> Sc + Send + Sync,
187    Sc: Score,
188{
189    fn evaluate(&self, solution: &S) -> Sc {
190        let entities_a = (self.extractor_a)(solution);
191        let b_keys = self.build_b_keys(solution);
192
193        let mut total = Sc::zero();
194        for a in entities_a {
195            if (self.filter_a)(a) && self.matches_existence(a, &b_keys) {
196                total = total + self.compute_score(a);
197            }
198        }
199        total
200    }
201
202    fn match_count(&self, solution: &S) -> usize {
203        let entities_a = (self.extractor_a)(solution);
204        let b_keys = self.build_b_keys(solution);
205
206        entities_a
207            .iter()
208            .filter(|a| (self.filter_a)(a) && self.matches_existence(a, &b_keys))
209            .count()
210    }
211
212    fn initialize(&mut self, solution: &S) -> Sc {
213        self.evaluate(solution)
214    }
215
216    fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
217        let entities_a = (self.extractor_a)(solution);
218        if entity_index >= entities_a.len() {
219            return Sc::zero();
220        }
221
222        let a = &entities_a[entity_index];
223        if !(self.filter_a)(a) {
224            return Sc::zero();
225        }
226
227        let b_keys = self.build_b_keys(solution);
228        if self.matches_existence(a, &b_keys) {
229            self.compute_score(a)
230        } else {
231            Sc::zero()
232        }
233    }
234
235    fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
236        let entities_a = (self.extractor_a)(solution);
237        if entity_index >= entities_a.len() {
238            return Sc::zero();
239        }
240
241        let a = &entities_a[entity_index];
242        if !(self.filter_a)(a) {
243            return Sc::zero();
244        }
245
246        let b_keys = self.build_b_keys(solution);
247        if self.matches_existence(a, &b_keys) {
248            -self.compute_score(a)
249        } else {
250            Sc::zero()
251        }
252    }
253
254    fn reset(&mut self) {
255        // No cached state to clear - we rebuild b_keys on each evaluation
256    }
257
258    fn name(&self) -> &str {
259        &self.constraint_ref.name
260    }
261
262    fn is_hard(&self) -> bool {
263        self.is_hard
264    }
265
266    fn constraint_ref(&self) -> ConstraintRef {
267        self.constraint_ref.clone()
268    }
269}
270
271impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc: Score> std::fmt::Debug
272    for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
273{
274    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275        f.debug_struct("IfExistsUniConstraint")
276            .field("name", &self.constraint_ref.name)
277            .field("impact_type", &self.impact_type)
278            .field("mode", &self.mode)
279            .finish()
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use solverforge_core::score::SimpleScore;

    #[derive(Clone)]
    struct Task {
        _id: usize,
        assignee: Option<usize>,
    }

    #[derive(Clone)]
    struct Worker {
        id: usize,
        available: bool,
    }

    #[derive(Clone)]
    struct Schedule {
        tasks: Vec<Task>,
        workers: Vec<Worker>,
    }

    /// Three tasks and two workers:
    /// - task 0 is assigned to unavailable worker 0;
    /// - task 1 is assigned to available worker 1;
    /// - task 2 is unassigned.
    fn sample_schedule() -> Schedule {
        Schedule {
            tasks: vec![
                Task {
                    _id: 0,
                    assignee: Some(0),
                },
                Task {
                    _id: 1,
                    assignee: Some(1),
                },
                Task {
                    _id: 2,
                    assignee: None,
                },
            ],
            workers: vec![
                Worker {
                    id: 0,
                    available: false,
                },
                Worker {
                    id: 1,
                    available: true,
                },
            ],
        }
    }

    /// Constraint over assigned tasks, keyed by assignee.
    ///
    /// The B collection is the workers whose `available` flag equals
    /// `worker_available`. Each match weighs 1.
    fn task_constraint(
        impact_type: ImpactType,
        mode: ExistenceMode,
        worker_available: bool,
    ) -> impl IncrementalConstraint<Schedule, SimpleScore> {
        IfExistsUniConstraint::new(
            ConstraintRef::new("", "Task availability"),
            impact_type,
            mode,
            |s: &Schedule| s.tasks.as_slice(),
            move |s: &Schedule| {
                s.workers
                    .iter()
                    .filter(|w| w.available == worker_available)
                    .cloned()
                    .collect()
            },
            |t: &Task| t.assignee,
            |w: &Worker| Some(w.id),
            |t: &Task| t.assignee.is_some(),
            |_t: &Task| SimpleScore::of(1),
            false,
        )
    }

    #[test]
    fn test_if_exists_penalizes_assigned_to_unavailable() {
        // B = unavailable workers; penalize tasks whose assignee is among them.
        let constraint = task_constraint(ImpactType::Penalty, ExistenceMode::Exists, false);
        let schedule = sample_schedule();

        // Only task 0 matches.
        assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
        assert_eq!(constraint.match_count(&schedule), 1);
    }

    #[test]
    fn test_if_not_exists_penalizes_unassigned() {
        // B = available workers; penalize assigned tasks whose assignee is missing from B.
        let constraint = task_constraint(ImpactType::Penalty, ExistenceMode::NotExists, true);
        let schedule = sample_schedule();

        // Task 0 is assigned but worker 0 is not available; task 2 is filtered out.
        assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
        assert_eq!(constraint.match_count(&schedule), 1);
    }

    #[test]
    fn test_reward_impact_is_positive() {
        // B = available workers; reward tasks assigned to one of them.
        let constraint = task_constraint(ImpactType::Reward, ExistenceMode::Exists, true);
        let schedule = sample_schedule();

        // Only task 1 is assigned to an available worker.
        assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(1));
        assert_eq!(constraint.match_count(&schedule), 1);
    }

    #[test]
    fn test_incremental_hooks_agree_with_full_evaluation() {
        let mut constraint = task_constraint(ImpactType::Penalty, ExistenceMode::Exists, false);
        let schedule = sample_schedule();

        assert_eq!(constraint.initialize(&schedule), SimpleScore::of(-1));

        // The matching entity: insert adds its penalty, retract undoes it.
        assert_eq!(constraint.on_insert(&schedule, 0), SimpleScore::of(-1));
        assert_eq!(constraint.on_retract(&schedule, 0), SimpleScore::of(1));

        // Existence check fails for task 1; task 2 is rejected by filter_a.
        assert_eq!(constraint.on_insert(&schedule, 1), SimpleScore::of(0));
        assert_eq!(constraint.on_retract(&schedule, 1), SimpleScore::of(0));
        assert_eq!(constraint.on_insert(&schedule, 2), SimpleScore::of(0));
        assert_eq!(constraint.on_retract(&schedule, 2), SimpleScore::of(0));

        // Out-of-range indices are ignored rather than panicking.
        assert_eq!(constraint.on_insert(&schedule, 3), SimpleScore::of(0));
        assert_eq!(constraint.on_retract(&schedule, 3), SimpleScore::of(0));

        // reset holds no state, so a subsequent evaluation is unchanged.
        constraint.reset();
        assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
    }
}