1use std::collections::HashSet;
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::api::constraint_set::IncrementalConstraint;
14
/// Chooses which outcome of the key lookup makes an `A` entity match:
/// finding a `B` entity with an equal key, or finding none.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExistenceMode {
    /// The entity matches when at least one `B` entity shares its key.
    Exists,
    /// The entity matches when no `B` entity shares its key.
    NotExists,
}
23
24pub struct IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
90where
91 Sc: Score,
92{
93 constraint_ref: ConstraintRef,
94 impact_type: ImpactType,
95 mode: ExistenceMode,
96 extractor_a: EA,
97 extractor_b: EB,
98 key_a: KA,
99 key_b: KB,
100 filter_a: FA,
101 weight: W,
102 is_hard: bool,
103 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
104}
105
106impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
107 IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
108where
109 S: 'static,
110 A: Clone + 'static,
111 B: Clone + 'static,
112 K: Eq + Hash + Clone,
113 EA: Fn(&S) -> &[A],
114 EB: Fn(&S) -> Vec<B>,
115 KA: Fn(&A) -> K,
116 KB: Fn(&B) -> K,
117 FA: Fn(&S, &A) -> bool,
118 W: Fn(&A) -> Sc,
119 Sc: Score,
120{
121 #[allow(clippy::too_many_arguments)]
123 pub fn new(
124 constraint_ref: ConstraintRef,
125 impact_type: ImpactType,
126 mode: ExistenceMode,
127 extractor_a: EA,
128 extractor_b: EB,
129 key_a: KA,
130 key_b: KB,
131 filter_a: FA,
132 weight: W,
133 is_hard: bool,
134 ) -> Self {
135 Self {
136 constraint_ref,
137 impact_type,
138 mode,
139 extractor_a,
140 extractor_b,
141 key_a,
142 key_b,
143 filter_a,
144 weight,
145 is_hard,
146 _phantom: PhantomData,
147 }
148 }
149
150 #[inline]
151 fn compute_score(&self, a: &A) -> Sc {
152 let base = (self.weight)(a);
153 match self.impact_type {
154 ImpactType::Penalty => -base,
155 ImpactType::Reward => base,
156 }
157 }
158
159 fn build_b_keys(&self, solution: &S) -> HashSet<K> {
160 let entities_b = (self.extractor_b)(solution);
161 entities_b.iter().map(|b| (self.key_b)(b)).collect()
162 }
163
164 fn matches_existence(&self, a: &A, b_keys: &HashSet<K>) -> bool {
165 let key = (self.key_a)(a);
166 let exists = b_keys.contains(&key);
167 match self.mode {
168 ExistenceMode::Exists => exists,
169 ExistenceMode::NotExists => !exists,
170 }
171 }
172}
173
174impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc> IncrementalConstraint<S, Sc>
175 for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
176where
177 S: Send + Sync + 'static,
178 A: Clone + Send + Sync + 'static,
179 B: Clone + Send + Sync + 'static,
180 K: Eq + Hash + Clone + Send + Sync,
181 EA: Fn(&S) -> &[A] + Send + Sync,
182 EB: Fn(&S) -> Vec<B> + Send + Sync,
183 KA: Fn(&A) -> K + Send + Sync,
184 KB: Fn(&B) -> K + Send + Sync,
185 FA: Fn(&S, &A) -> bool + Send + Sync,
186 W: Fn(&A) -> Sc + Send + Sync,
187 Sc: Score,
188{
189 fn evaluate(&self, solution: &S) -> Sc {
190 let entities_a = (self.extractor_a)(solution);
191 let b_keys = self.build_b_keys(solution);
192
193 let mut total = Sc::zero();
194 for a in entities_a {
195 if (self.filter_a)(solution, a) && self.matches_existence(a, &b_keys) {
196 total = total + self.compute_score(a);
197 }
198 }
199 total
200 }
201
202 fn match_count(&self, solution: &S) -> usize {
203 let entities_a = (self.extractor_a)(solution);
204 let b_keys = self.build_b_keys(solution);
205
206 entities_a
207 .iter()
208 .filter(|a| (self.filter_a)(solution, a) && self.matches_existence(a, &b_keys))
209 .count()
210 }
211
212 fn initialize(&mut self, solution: &S) -> Sc {
213 self.evaluate(solution)
214 }
215
216 fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
217 let entities_a = (self.extractor_a)(solution);
218 if entity_index >= entities_a.len() {
219 return Sc::zero();
220 }
221
222 let a = &entities_a[entity_index];
223 if !(self.filter_a)(solution, a) {
224 return Sc::zero();
225 }
226
227 let b_keys = self.build_b_keys(solution);
228 if self.matches_existence(a, &b_keys) {
229 self.compute_score(a)
230 } else {
231 Sc::zero()
232 }
233 }
234
235 fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
236 let entities_a = (self.extractor_a)(solution);
237 if entity_index >= entities_a.len() {
238 return Sc::zero();
239 }
240
241 let a = &entities_a[entity_index];
242 if !(self.filter_a)(solution, a) {
243 return Sc::zero();
244 }
245
246 let b_keys = self.build_b_keys(solution);
247 if self.matches_existence(a, &b_keys) {
248 -self.compute_score(a)
249 } else {
250 Sc::zero()
251 }
252 }
253
254 fn reset(&mut self) {
255 }
257
258 fn name(&self) -> &str {
259 &self.constraint_ref.name
260 }
261
262 fn is_hard(&self) -> bool {
263 self.is_hard
264 }
265
266 fn constraint_ref(&self) -> ConstraintRef {
267 self.constraint_ref.clone()
268 }
269}
270
271impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc: Score> std::fmt::Debug
272 for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
273{
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 f.debug_struct("IfExistsUniConstraint")
276 .field("name", &self.constraint_ref.name)
277 .field("impact_type", &self.impact_type)
278 .field("mode", &self.mode)
279 .finish()
280 }
281}