1use std::collections::HashSet;
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::api::constraint_set::IncrementalConstraint;
15
/// Chooses which side of the existence test causes an `A` entity to be scored.
///
/// The check compares the join key of an `A` entity against the join keys of
/// all `B` entities in the solution.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExistenceMode {
    /// Score `A` only when at least one `B` shares its join key.
    Exists,
    /// Score `A` only when no `B` shares its join key.
    NotExists,
}
24
25pub struct IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
92where
93 Sc: Score,
94{
95 constraint_ref: ConstraintRef,
96 impact_type: ImpactType,
97 mode: ExistenceMode,
98 extractor_a: EA,
99 extractor_b: EB,
100 key_a: KA,
101 key_b: KB,
102 filter_a: FA,
103 weight: W,
104 is_hard: bool,
105 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
106}
107
108impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
109 IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
110where
111 S: 'static,
112 A: Clone + 'static,
113 B: Clone + 'static,
114 K: Eq + Hash + Clone,
115 EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
116 EB: Fn(&S) -> Vec<B>,
117 KA: Fn(&A) -> K,
118 KB: Fn(&B) -> K,
119 FA: Fn(&S, &A) -> bool,
120 W: Fn(&A) -> Sc,
121 Sc: Score,
122{
123 #[allow(clippy::too_many_arguments)]
125 pub fn new(
126 constraint_ref: ConstraintRef,
127 impact_type: ImpactType,
128 mode: ExistenceMode,
129 extractor_a: EA,
130 extractor_b: EB,
131 key_a: KA,
132 key_b: KB,
133 filter_a: FA,
134 weight: W,
135 is_hard: bool,
136 ) -> Self {
137 Self {
138 constraint_ref,
139 impact_type,
140 mode,
141 extractor_a,
142 extractor_b,
143 key_a,
144 key_b,
145 filter_a,
146 weight,
147 is_hard,
148 _phantom: PhantomData,
149 }
150 }
151
152 #[inline]
153 fn compute_score(&self, a: &A) -> Sc {
154 let base = (self.weight)(a);
155 match self.impact_type {
156 ImpactType::Penalty => -base,
157 ImpactType::Reward => base,
158 }
159 }
160
161 fn build_b_keys(&self, solution: &S) -> HashSet<K> {
162 let entities_b = (self.extractor_b)(solution);
163 entities_b.iter().map(|b| (self.key_b)(b)).collect()
164 }
165
166 fn matches_existence(&self, a: &A, b_keys: &HashSet<K>) -> bool {
167 let key = (self.key_a)(a);
168 let exists = b_keys.contains(&key);
169 match self.mode {
170 ExistenceMode::Exists => exists,
171 ExistenceMode::NotExists => !exists,
172 }
173 }
174}
175
176impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc> IncrementalConstraint<S, Sc>
177 for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
178where
179 S: Send + Sync + 'static,
180 A: Clone + Send + Sync + 'static,
181 B: Clone + Send + Sync + 'static,
182 K: Eq + Hash + Clone + Send + Sync,
183 EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
184 EB: Fn(&S) -> Vec<B> + Send + Sync,
185 KA: Fn(&A) -> K + Send + Sync,
186 KB: Fn(&B) -> K + Send + Sync,
187 FA: Fn(&S, &A) -> bool + Send + Sync,
188 W: Fn(&A) -> Sc + Send + Sync,
189 Sc: Score,
190{
191 fn evaluate(&self, solution: &S) -> Sc {
192 let entities_a = self.extractor_a.extract(solution);
193 let b_keys = self.build_b_keys(solution);
194
195 let mut total = Sc::zero();
196 for a in entities_a {
197 if (self.filter_a)(solution, a) && self.matches_existence(a, &b_keys) {
198 total = total + self.compute_score(a);
199 }
200 }
201 total
202 }
203
204 fn match_count(&self, solution: &S) -> usize {
205 let entities_a = self.extractor_a.extract(solution);
206 let b_keys = self.build_b_keys(solution);
207
208 entities_a
209 .iter()
210 .filter(|a| (self.filter_a)(solution, a) && self.matches_existence(a, &b_keys))
211 .count()
212 }
213
214 fn initialize(&mut self, solution: &S) -> Sc {
215 self.evaluate(solution)
216 }
217
218 fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
219 let entities_a = self.extractor_a.extract(solution);
220 if entity_index >= entities_a.len() {
221 return Sc::zero();
222 }
223
224 let a = &entities_a[entity_index];
225 if !(self.filter_a)(solution, a) {
226 return Sc::zero();
227 }
228
229 let b_keys = self.build_b_keys(solution);
230 if self.matches_existence(a, &b_keys) {
231 self.compute_score(a)
232 } else {
233 Sc::zero()
234 }
235 }
236
237 fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
238 let entities_a = self.extractor_a.extract(solution);
239 if entity_index >= entities_a.len() {
240 return Sc::zero();
241 }
242
243 let a = &entities_a[entity_index];
244 if !(self.filter_a)(solution, a) {
245 return Sc::zero();
246 }
247
248 let b_keys = self.build_b_keys(solution);
249 if self.matches_existence(a, &b_keys) {
250 -self.compute_score(a)
251 } else {
252 Sc::zero()
253 }
254 }
255
256 fn reset(&mut self) {
257 }
259
260 fn name(&self) -> &str {
261 &self.constraint_ref.name
262 }
263
264 fn is_hard(&self) -> bool {
265 self.is_hard
266 }
267
268 fn constraint_ref(&self) -> ConstraintRef {
269 self.constraint_ref.clone()
270 }
271}
272
273impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc: Score> std::fmt::Debug
274 for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
275{
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 f.debug_struct("IfExistsUniConstraint")
278 .field("name", &self.constraint_ref.name)
279 .field("impact_type", &self.impact_type)
280 .field("mode", &self.mode)
281 .finish()
282 }
283}