1use std::hash::Hash;
2use std::marker::PhantomData;
3
4use solverforge_core::score::Score;
5use solverforge_core::{ConstraintRef, ImpactType};
6
7use crate::constraint::exists::{IncrementalExistsConstraint, SelfFlatten};
8use crate::stream::collection_extract::{FlattenExtract, TrackedCollectionExtract};
9use crate::stream::filter::UniFilter;
10
/// Which existence check an `ExistsConstraintStream` applies.
///
/// The mode is forwarded unchanged to `IncrementalExistsConstraint::new`,
/// which implements the actual matching.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExistenceMode {
    /// Presumably: match `A` items for which at least one joined `B` exists.
    Exists,
    /// Presumably: match `A` items for which no joined `B` exists.
    NotExists,
}
16
17pub struct ExistsConstraintStream<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, Sc>
18where
19 Sc: Score,
20{
21 pub(super) mode: ExistenceMode,
22 pub(super) extractor_a: EA,
23 pub(super) extractor_parent: EP,
24 pub(super) key_a: KA,
25 pub(super) key_b: KB,
26 pub(super) filter_a: FA,
27 pub(super) filter_parent: FP,
28 pub(super) flatten: Flatten,
29 pub(super) _phantom: PhantomData<(
30 fn() -> S,
31 fn() -> A,
32 fn() -> P,
33 fn() -> B,
34 fn() -> K,
35 fn() -> Sc,
36 )>,
37}
38
39impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, Sc>
40 ExistsConstraintStream<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, Sc>
41where
42 S: Send + Sync + 'static,
43 A: Clone + Send + Sync + 'static,
44 P: Clone + Send + Sync + 'static,
45 B: Clone + Send + Sync + 'static,
46 K: Eq + Hash + Clone + Send + Sync,
47 EA: TrackedCollectionExtract<S, Item = A>,
48 EP: TrackedCollectionExtract<S, Item = P>,
49 KA: Fn(&A) -> K + Send + Sync,
50 KB: Fn(&B) -> K + Send + Sync,
51 FA: UniFilter<S, A>,
52 FP: UniFilter<S, P>,
53 Flatten: FlattenExtract<P, Item = B>,
54 Sc: Score + 'static,
55{
56 pub fn new(
57 mode: ExistenceMode,
58 extractor_a: EA,
59 extractor_parent: EP,
60 keys: (KA, KB),
61 filter_a: FA,
62 filter_parent: FP,
63 flatten: Flatten,
64 ) -> Self {
65 let (key_a, key_b) = keys;
66 Self {
67 mode,
68 extractor_a,
69 extractor_parent,
70 key_a,
71 key_b,
72 filter_a,
73 filter_parent,
74 flatten,
75 _phantom: PhantomData,
76 }
77 }
78
79 pub fn penalize(
80 self,
81 weight: Sc,
82 ) -> ExistsConstraintBuilder<
83 S,
84 A,
85 P,
86 B,
87 K,
88 EA,
89 EP,
90 KA,
91 KB,
92 FA,
93 FP,
94 Flatten,
95 impl Fn(&A) -> Sc + Send + Sync,
96 Sc,
97 >
98 where
99 Sc: Copy,
100 {
101 let is_hard = weight
102 .to_level_numbers()
103 .first()
104 .map(|&h| h != 0)
105 .unwrap_or(false);
106 ExistsConstraintBuilder {
107 mode: self.mode,
108 extractor_a: self.extractor_a,
109 extractor_parent: self.extractor_parent,
110 key_a: self.key_a,
111 key_b: self.key_b,
112 filter_a: self.filter_a,
113 filter_parent: self.filter_parent,
114 flatten: self.flatten,
115 impact_type: ImpactType::Penalty,
116 weight: move |_: &A| weight,
117 is_hard,
118 _phantom: PhantomData,
119 }
120 }
121
122 pub fn penalize_with<W>(
123 self,
124 weight_fn: W,
125 ) -> ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
126 where
127 W: Fn(&A) -> Sc + Send + Sync,
128 {
129 ExistsConstraintBuilder {
130 mode: self.mode,
131 extractor_a: self.extractor_a,
132 extractor_parent: self.extractor_parent,
133 key_a: self.key_a,
134 key_b: self.key_b,
135 filter_a: self.filter_a,
136 filter_parent: self.filter_parent,
137 flatten: self.flatten,
138 impact_type: ImpactType::Penalty,
139 weight: weight_fn,
140 is_hard: false,
141 _phantom: PhantomData,
142 }
143 }
144
145 pub fn penalize_hard_with<W>(
146 self,
147 weight_fn: W,
148 ) -> ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
149 where
150 W: Fn(&A) -> Sc + Send + Sync,
151 {
152 ExistsConstraintBuilder {
153 mode: self.mode,
154 extractor_a: self.extractor_a,
155 extractor_parent: self.extractor_parent,
156 key_a: self.key_a,
157 key_b: self.key_b,
158 filter_a: self.filter_a,
159 filter_parent: self.filter_parent,
160 flatten: self.flatten,
161 impact_type: ImpactType::Penalty,
162 weight: weight_fn,
163 is_hard: true,
164 _phantom: PhantomData,
165 }
166 }
167
168 pub fn penalize_hard(
169 self,
170 ) -> ExistsConstraintBuilder<
171 S,
172 A,
173 P,
174 B,
175 K,
176 EA,
177 EP,
178 KA,
179 KB,
180 FA,
181 FP,
182 Flatten,
183 impl Fn(&A) -> Sc + Send + Sync,
184 Sc,
185 >
186 where
187 Sc: Copy,
188 {
189 self.penalize(Sc::one_hard())
190 }
191
192 pub fn penalize_soft(
193 self,
194 ) -> ExistsConstraintBuilder<
195 S,
196 A,
197 P,
198 B,
199 K,
200 EA,
201 EP,
202 KA,
203 KB,
204 FA,
205 FP,
206 Flatten,
207 impl Fn(&A) -> Sc + Send + Sync,
208 Sc,
209 >
210 where
211 Sc: Copy,
212 {
213 self.penalize(Sc::one_soft())
214 }
215
216 pub fn reward(
217 self,
218 weight: Sc,
219 ) -> ExistsConstraintBuilder<
220 S,
221 A,
222 P,
223 B,
224 K,
225 EA,
226 EP,
227 KA,
228 KB,
229 FA,
230 FP,
231 Flatten,
232 impl Fn(&A) -> Sc + Send + Sync,
233 Sc,
234 >
235 where
236 Sc: Copy,
237 {
238 let is_hard = weight
239 .to_level_numbers()
240 .first()
241 .map(|&h| h != 0)
242 .unwrap_or(false);
243 ExistsConstraintBuilder {
244 mode: self.mode,
245 extractor_a: self.extractor_a,
246 extractor_parent: self.extractor_parent,
247 key_a: self.key_a,
248 key_b: self.key_b,
249 filter_a: self.filter_a,
250 filter_parent: self.filter_parent,
251 flatten: self.flatten,
252 impact_type: ImpactType::Reward,
253 weight: move |_: &A| weight,
254 is_hard,
255 _phantom: PhantomData,
256 }
257 }
258
259 pub fn reward_with<W>(
260 self,
261 weight_fn: W,
262 ) -> ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
263 where
264 W: Fn(&A) -> Sc + Send + Sync,
265 {
266 ExistsConstraintBuilder {
267 mode: self.mode,
268 extractor_a: self.extractor_a,
269 extractor_parent: self.extractor_parent,
270 key_a: self.key_a,
271 key_b: self.key_b,
272 filter_a: self.filter_a,
273 filter_parent: self.filter_parent,
274 flatten: self.flatten,
275 impact_type: ImpactType::Reward,
276 weight: weight_fn,
277 is_hard: false,
278 _phantom: PhantomData,
279 }
280 }
281
282 pub fn reward_hard_with<W>(
283 self,
284 weight_fn: W,
285 ) -> ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
286 where
287 W: Fn(&A) -> Sc + Send + Sync,
288 {
289 ExistsConstraintBuilder {
290 mode: self.mode,
291 extractor_a: self.extractor_a,
292 extractor_parent: self.extractor_parent,
293 key_a: self.key_a,
294 key_b: self.key_b,
295 filter_a: self.filter_a,
296 filter_parent: self.filter_parent,
297 flatten: self.flatten,
298 impact_type: ImpactType::Reward,
299 weight: weight_fn,
300 is_hard: true,
301 _phantom: PhantomData,
302 }
303 }
304
305 pub fn reward_hard(
306 self,
307 ) -> ExistsConstraintBuilder<
308 S,
309 A,
310 P,
311 B,
312 K,
313 EA,
314 EP,
315 KA,
316 KB,
317 FA,
318 FP,
319 Flatten,
320 impl Fn(&A) -> Sc + Send + Sync,
321 Sc,
322 >
323 where
324 Sc: Copy,
325 {
326 self.reward(Sc::one_hard())
327 }
328
329 pub fn reward_soft(
330 self,
331 ) -> ExistsConstraintBuilder<
332 S,
333 A,
334 P,
335 B,
336 K,
337 EA,
338 EP,
339 KA,
340 KB,
341 FA,
342 FP,
343 Flatten,
344 impl Fn(&A) -> Sc + Send + Sync,
345 Sc,
346 >
347 where
348 Sc: Copy,
349 {
350 self.reward(Sc::one_soft())
351 }
352}
353
354impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, Sc: Score> std::fmt::Debug
355 for ExistsConstraintStream<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, Sc>
356{
357 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358 f.debug_struct("ExistsConstraintStream")
359 .field("mode", &self.mode)
360 .finish()
361 }
362}
363
364pub struct ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
365where
366 Sc: Score,
367{
368 mode: ExistenceMode,
369 extractor_a: EA,
370 extractor_parent: EP,
371 key_a: KA,
372 key_b: KB,
373 filter_a: FA,
374 filter_parent: FP,
375 flatten: Flatten,
376 impact_type: ImpactType,
377 weight: W,
378 is_hard: bool,
379 _phantom: PhantomData<(
380 fn() -> S,
381 fn() -> A,
382 fn() -> P,
383 fn() -> B,
384 fn() -> K,
385 fn() -> Sc,
386 )>,
387}
388
389impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
390 ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
391where
392 S: Send + Sync + 'static,
393 A: Clone + Send + Sync + 'static,
394 P: Clone + Send + Sync + 'static,
395 B: Clone + Send + Sync + 'static,
396 K: Eq + Hash + Clone + Send + Sync,
397 EA: TrackedCollectionExtract<S, Item = A>,
398 EP: TrackedCollectionExtract<S, Item = P>,
399 KA: Fn(&A) -> K + Send + Sync,
400 KB: Fn(&B) -> K + Send + Sync,
401 FA: UniFilter<S, A> + Send + Sync,
402 FP: UniFilter<S, P> + Send + Sync,
403 Flatten: FlattenExtract<P, Item = B> + Send + Sync,
404 W: Fn(&A) -> Sc + Send + Sync,
405 Sc: Score + 'static,
406{
407 pub fn named(
408 self,
409 name: &str,
410 ) -> IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc> {
411 IncrementalExistsConstraint::new(
412 ConstraintRef::new("", name),
413 self.impact_type,
414 self.mode,
415 self.extractor_a,
416 self.extractor_parent,
417 self.key_a,
418 self.key_b,
419 self.filter_a,
420 self.filter_parent,
421 self.flatten,
422 self.weight,
423 self.is_hard,
424 )
425 }
426}
427
428impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc: Score> std::fmt::Debug
429 for ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
430{
431 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432 f.debug_struct("ExistsConstraintBuilder")
433 .field("mode", &self.mode)
434 .field("impact_type", &self.impact_type)
435 .finish()
436 }
437}
438
439pub(crate) type DirectExistenceStream<S, A, B, K, EA, EP, KA, KB, FA, FP, Sc> =
440 ExistsConstraintStream<S, A, B, B, K, EA, EP, KA, KB, FA, FP, SelfFlatten, Sc>;