1use std::collections::HashSet;
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::api::constraint_set::IncrementalConstraint;
14
/// Selects how an [`IfExistsUniConstraint`] interprets the join between its
/// `A` entities and its `B` entities.
///
/// `Hash` is derived alongside `Eq` so the mode can be used as a key in hashed
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExistenceMode {
    /// An `A` entity matches when at least one `B` entity has an equal key.
    Exists,
    /// An `A` entity matches when no `B` entity has an equal key.
    NotExists,
}
23
24pub struct IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
90where
91 Sc: Score,
92{
93 constraint_ref: ConstraintRef,
94 impact_type: ImpactType,
95 mode: ExistenceMode,
96 extractor_a: EA,
97 extractor_b: EB,
98 key_a: KA,
99 key_b: KB,
100 filter_a: FA,
101 weight: W,
102 is_hard: bool,
103 _phantom: PhantomData<(S, A, B, K, Sc)>,
104}
105
106impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
107 IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
108where
109 S: 'static,
110 A: Clone + 'static,
111 B: Clone + 'static,
112 K: Eq + Hash + Clone,
113 EA: Fn(&S) -> &[A],
114 EB: Fn(&S) -> Vec<B>,
115 KA: Fn(&A) -> K,
116 KB: Fn(&B) -> K,
117 FA: Fn(&S, &A) -> bool,
118 W: Fn(&A) -> Sc,
119 Sc: Score,
120{
121 #[allow(clippy::too_many_arguments)]
123 pub fn new(
124 constraint_ref: ConstraintRef,
125 impact_type: ImpactType,
126 mode: ExistenceMode,
127 extractor_a: EA,
128 extractor_b: EB,
129 key_a: KA,
130 key_b: KB,
131 filter_a: FA,
132 weight: W,
133 is_hard: bool,
134 ) -> Self {
135 Self {
136 constraint_ref,
137 impact_type,
138 mode,
139 extractor_a,
140 extractor_b,
141 key_a,
142 key_b,
143 filter_a,
144 weight,
145 is_hard,
146 _phantom: PhantomData,
147 }
148 }
149
150 #[inline]
151 fn compute_score(&self, a: &A) -> Sc {
152 let base = (self.weight)(a);
153 match self.impact_type {
154 ImpactType::Penalty => -base,
155 ImpactType::Reward => base,
156 }
157 }
158
159 fn build_b_keys(&self, solution: &S) -> HashSet<K> {
160 let entities_b = (self.extractor_b)(solution);
161 entities_b.iter().map(|b| (self.key_b)(b)).collect()
162 }
163
164 fn matches_existence(&self, a: &A, b_keys: &HashSet<K>) -> bool {
165 let key = (self.key_a)(a);
166 let exists = b_keys.contains(&key);
167 match self.mode {
168 ExistenceMode::Exists => exists,
169 ExistenceMode::NotExists => !exists,
170 }
171 }
172}
173
174impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc> IncrementalConstraint<S, Sc>
175 for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
176where
177 S: Send + Sync + 'static,
178 A: Clone + Send + Sync + 'static,
179 B: Clone + Send + Sync + 'static,
180 K: Eq + Hash + Clone + Send + Sync,
181 EA: Fn(&S) -> &[A] + Send + Sync,
182 EB: Fn(&S) -> Vec<B> + Send + Sync,
183 KA: Fn(&A) -> K + Send + Sync,
184 KB: Fn(&B) -> K + Send + Sync,
185 FA: Fn(&S, &A) -> bool + Send + Sync,
186 W: Fn(&A) -> Sc + Send + Sync,
187 Sc: Score,
188{
189 fn evaluate(&self, solution: &S) -> Sc {
190 let entities_a = (self.extractor_a)(solution);
191 let b_keys = self.build_b_keys(solution);
192
193 let mut total = Sc::zero();
194 for a in entities_a {
195 if (self.filter_a)(solution, a) && self.matches_existence(a, &b_keys) {
196 total = total + self.compute_score(a);
197 }
198 }
199 total
200 }
201
202 fn match_count(&self, solution: &S) -> usize {
203 let entities_a = (self.extractor_a)(solution);
204 let b_keys = self.build_b_keys(solution);
205
206 entities_a
207 .iter()
208 .filter(|a| (self.filter_a)(solution, a) && self.matches_existence(a, &b_keys))
209 .count()
210 }
211
212 fn initialize(&mut self, solution: &S) -> Sc {
213 self.evaluate(solution)
214 }
215
216 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
217 let entities_a = (self.extractor_a)(solution);
218 if entity_index >= entities_a.len() {
219 return Sc::zero();
220 }
221
222 let a = &entities_a[entity_index];
223 if !(self.filter_a)(solution, a) {
224 return Sc::zero();
225 }
226
227 let b_keys = self.build_b_keys(solution);
228 if self.matches_existence(a, &b_keys) {
229 self.compute_score(a)
230 } else {
231 Sc::zero()
232 }
233 }
234
235 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
236 let entities_a = (self.extractor_a)(solution);
237 if entity_index >= entities_a.len() {
238 return Sc::zero();
239 }
240
241 let a = &entities_a[entity_index];
242 if !(self.filter_a)(solution, a) {
243 return Sc::zero();
244 }
245
246 let b_keys = self.build_b_keys(solution);
247 if self.matches_existence(a, &b_keys) {
248 -self.compute_score(a)
249 } else {
250 Sc::zero()
251 }
252 }
253
254 fn reset(&mut self) {
255 }
257
258 fn name(&self) -> &str {
259 &self.constraint_ref.name
260 }
261
262 fn is_hard(&self) -> bool {
263 self.is_hard
264 }
265
266 fn constraint_ref(&self) -> ConstraintRef {
267 self.constraint_ref.clone()
268 }
269}
270
271impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc: Score> std::fmt::Debug
272 for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
273{
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 f.debug_struct("IfExistsUniConstraint")
276 .field("name", &self.constraint_ref.name)
277 .field("impact_type", &self.impact_type)
278 .field("mode", &self.mode)
279 .finish()
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use solverforge_core::score::SimpleScore;

    #[derive(Clone)]
    struct Task {
        _id: usize,
        assignee: Option<usize>,
    }

    #[derive(Clone)]
    struct Worker {
        id: usize,
        available: bool,
    }

    #[derive(Clone)]
    struct Schedule {
        tasks: Vec<Task>,
        workers: Vec<Worker>,
    }

    /// Fixture shared by all tests. Task 0 is assigned to unavailable worker 0,
    /// task 1 to available worker 1, and task 2 is unassigned.
    fn sample_schedule() -> Schedule {
        Schedule {
            tasks: vec![
                Task {
                    _id: 0,
                    assignee: Some(0),
                },
                Task {
                    _id: 1,
                    assignee: Some(1),
                },
                Task {
                    _id: 2,
                    assignee: None,
                },
            ],
            workers: vec![
                Worker {
                    id: 0,
                    available: false,
                },
                Worker {
                    id: 1,
                    available: true,
                },
            ],
        }
    }

    #[test]
    fn test_if_exists_penalizes_assigned_to_unavailable() {
        let constraint = IfExistsUniConstraint::new(
            ConstraintRef::new("", "Unavailable worker"),
            ImpactType::Penalty,
            ExistenceMode::Exists,
            |s: &Schedule| s.tasks.as_slice(),
            |s: &Schedule| s.workers.iter().filter(|w| !w.available).cloned().collect(),
            |t: &Task| t.assignee,
            |w: &Worker| Some(w.id),
            |_s: &Schedule, t: &Task| t.assignee.is_some(),
            |_t: &Task| SimpleScore::of(1),
            false,
        );
        let schedule = sample_schedule();

        // Only task 0 points at a worker in the unavailable set.
        assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
        assert_eq!(constraint.match_count(&schedule), 1);
    }

    #[test]
    fn test_if_not_exists_penalizes_missing_available_worker() {
        let constraint = IfExistsUniConstraint::new(
            ConstraintRef::new("", "No available worker"),
            ImpactType::Penalty,
            ExistenceMode::NotExists,
            |s: &Schedule| s.tasks.as_slice(),
            |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
            |t: &Task| t.assignee,
            |w: &Worker| Some(w.id),
            |_s: &Schedule, t: &Task| t.assignee.is_some(),
            |_t: &Task| SimpleScore::of(1),
            false,
        );
        let schedule = sample_schedule();

        // Task 0's worker is not in the available set; task 2 is filtered out.
        assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
        assert_eq!(constraint.match_count(&schedule), 1);
    }

    #[test]
    fn test_reward_applies_weight_without_negation() {
        let constraint = IfExistsUniConstraint::new(
            ConstraintRef::new("", "Rewarded"),
            ImpactType::Reward,
            ExistenceMode::NotExists,
            |s: &Schedule| s.tasks.as_slice(),
            |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
            |t: &Task| t.assignee,
            |w: &Worker| Some(w.id),
            |_s: &Schedule, t: &Task| t.assignee.is_some(),
            |_t: &Task| SimpleScore::of(3),
            false,
        );
        let schedule = sample_schedule();

        assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(3));
        assert_eq!(constraint.match_count(&schedule), 1);
    }

    #[test]
    fn test_incremental_hooks_are_symmetric_and_bounds_checked() {
        let mut constraint = IfExistsUniConstraint::new(
            ConstraintRef::new("", "Unavailable worker"),
            ImpactType::Penalty,
            ExistenceMode::Exists,
            |s: &Schedule| s.tasks.as_slice(),
            |s: &Schedule| s.workers.iter().filter(|w| !w.available).cloned().collect(),
            |t: &Task| t.assignee,
            |w: &Worker| Some(w.id),
            |_s: &Schedule, t: &Task| t.assignee.is_some(),
            |_t: &Task| SimpleScore::of(1),
            false,
        );
        let schedule = sample_schedule();

        assert_eq!(constraint.initialize(&schedule), SimpleScore::of(-1));

        // The matched task contributes its penalty on insert and undoes it on retract.
        assert_eq!(constraint.on_insert(&schedule, 0), SimpleScore::of(-1));
        assert_eq!(constraint.on_retract(&schedule, 0), SimpleScore::of(1));

        // A non-matching task and a filtered-out task contribute nothing.
        assert_eq!(constraint.on_insert(&schedule, 1), SimpleScore::of(0));
        assert_eq!(constraint.on_retract(&schedule, 1), SimpleScore::of(0));
        assert_eq!(constraint.on_insert(&schedule, 2), SimpleScore::of(0));
        assert_eq!(constraint.on_retract(&schedule, 2), SimpleScore::of(0));

        // An index just past the end is ignored rather than panicking.
        assert_eq!(constraint.on_insert(&schedule, 3), SimpleScore::of(0));
        assert_eq!(constraint.on_retract(&schedule, 3), SimpleScore::of(0));
    }
}