Skip to main content

solverforge_scoring/stream/
existence_stream.rs

1use std::hash::Hash;
2use std::marker::PhantomData;
3
4use solverforge_core::score::Score;
5use solverforge_core::{ConstraintRef, ImpactType};
6
7use crate::constraint::exists::{IncrementalExistsConstraint, SelfFlatten};
8use crate::stream::collection_extract::{FlattenExtract, TrackedCollectionExtract};
9use crate::stream::filter::UniFilter;
10
/// Which of the two existence checks a stream performs.
///
/// The variant is stored on [`ExistsConstraintStream`] / [`ExistsConstraintBuilder`]
/// and forwarded unchanged to `IncrementalExistsConstraint::new`. Judging by the
/// names, `Exists` keeps `A` items that have a keyed counterpart and `NotExists`
/// keeps those that have none. The actual evaluation lives in the constraint
/// type (not visible here), so confirm there.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExistenceMode {
    /// Match when a counterpart exists.
    Exists,
    /// Match when no counterpart exists.
    NotExists,
}
16
17pub struct ExistsConstraintStream<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, Sc>
18where
19    Sc: Score,
20{
21    pub(super) mode: ExistenceMode,
22    pub(super) extractor_a: EA,
23    pub(super) extractor_parent: EP,
24    pub(super) key_a: KA,
25    pub(super) key_b: KB,
26    pub(super) filter_a: FA,
27    pub(super) filter_parent: FP,
28    pub(super) flatten: Flatten,
29    pub(super) _phantom: PhantomData<(
30        fn() -> S,
31        fn() -> A,
32        fn() -> P,
33        fn() -> B,
34        fn() -> K,
35        fn() -> Sc,
36    )>,
37}
38
39impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, Sc>
40    ExistsConstraintStream<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, Sc>
41where
42    S: Send + Sync + 'static,
43    A: Clone + Send + Sync + 'static,
44    P: Clone + Send + Sync + 'static,
45    B: Clone + Send + Sync + 'static,
46    K: Eq + Hash + Clone + Send + Sync,
47    EA: TrackedCollectionExtract<S, Item = A>,
48    EP: TrackedCollectionExtract<S, Item = P>,
49    KA: Fn(&A) -> K + Send + Sync,
50    KB: Fn(&B) -> K + Send + Sync,
51    FA: UniFilter<S, A>,
52    FP: UniFilter<S, P>,
53    Flatten: FlattenExtract<P, Item = B>,
54    Sc: Score + 'static,
55{
56    pub fn new(
57        mode: ExistenceMode,
58        extractor_a: EA,
59        extractor_parent: EP,
60        keys: (KA, KB),
61        filter_a: FA,
62        filter_parent: FP,
63        flatten: Flatten,
64    ) -> Self {
65        let (key_a, key_b) = keys;
66        Self {
67            mode,
68            extractor_a,
69            extractor_parent,
70            key_a,
71            key_b,
72            filter_a,
73            filter_parent,
74            flatten,
75            _phantom: PhantomData,
76        }
77    }
78
79    pub fn penalize(
80        self,
81        weight: Sc,
82    ) -> ExistsConstraintBuilder<
83        S,
84        A,
85        P,
86        B,
87        K,
88        EA,
89        EP,
90        KA,
91        KB,
92        FA,
93        FP,
94        Flatten,
95        impl Fn(&A) -> Sc + Send + Sync,
96        Sc,
97    >
98    where
99        Sc: Copy,
100    {
101        let is_hard = weight
102            .to_level_numbers()
103            .first()
104            .map(|&h| h != 0)
105            .unwrap_or(false);
106        ExistsConstraintBuilder {
107            mode: self.mode,
108            extractor_a: self.extractor_a,
109            extractor_parent: self.extractor_parent,
110            key_a: self.key_a,
111            key_b: self.key_b,
112            filter_a: self.filter_a,
113            filter_parent: self.filter_parent,
114            flatten: self.flatten,
115            impact_type: ImpactType::Penalty,
116            weight: move |_: &A| weight,
117            is_hard,
118            _phantom: PhantomData,
119        }
120    }
121
122    pub fn penalize_with<W>(
123        self,
124        weight_fn: W,
125    ) -> ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
126    where
127        W: Fn(&A) -> Sc + Send + Sync,
128    {
129        ExistsConstraintBuilder {
130            mode: self.mode,
131            extractor_a: self.extractor_a,
132            extractor_parent: self.extractor_parent,
133            key_a: self.key_a,
134            key_b: self.key_b,
135            filter_a: self.filter_a,
136            filter_parent: self.filter_parent,
137            flatten: self.flatten,
138            impact_type: ImpactType::Penalty,
139            weight: weight_fn,
140            is_hard: false,
141            _phantom: PhantomData,
142        }
143    }
144
145    pub fn penalize_hard_with<W>(
146        self,
147        weight_fn: W,
148    ) -> ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
149    where
150        W: Fn(&A) -> Sc + Send + Sync,
151    {
152        ExistsConstraintBuilder {
153            mode: self.mode,
154            extractor_a: self.extractor_a,
155            extractor_parent: self.extractor_parent,
156            key_a: self.key_a,
157            key_b: self.key_b,
158            filter_a: self.filter_a,
159            filter_parent: self.filter_parent,
160            flatten: self.flatten,
161            impact_type: ImpactType::Penalty,
162            weight: weight_fn,
163            is_hard: true,
164            _phantom: PhantomData,
165        }
166    }
167
168    pub fn penalize_hard(
169        self,
170    ) -> ExistsConstraintBuilder<
171        S,
172        A,
173        P,
174        B,
175        K,
176        EA,
177        EP,
178        KA,
179        KB,
180        FA,
181        FP,
182        Flatten,
183        impl Fn(&A) -> Sc + Send + Sync,
184        Sc,
185    >
186    where
187        Sc: Copy,
188    {
189        self.penalize(Sc::one_hard())
190    }
191
192    pub fn penalize_soft(
193        self,
194    ) -> ExistsConstraintBuilder<
195        S,
196        A,
197        P,
198        B,
199        K,
200        EA,
201        EP,
202        KA,
203        KB,
204        FA,
205        FP,
206        Flatten,
207        impl Fn(&A) -> Sc + Send + Sync,
208        Sc,
209    >
210    where
211        Sc: Copy,
212    {
213        self.penalize(Sc::one_soft())
214    }
215
216    pub fn reward(
217        self,
218        weight: Sc,
219    ) -> ExistsConstraintBuilder<
220        S,
221        A,
222        P,
223        B,
224        K,
225        EA,
226        EP,
227        KA,
228        KB,
229        FA,
230        FP,
231        Flatten,
232        impl Fn(&A) -> Sc + Send + Sync,
233        Sc,
234    >
235    where
236        Sc: Copy,
237    {
238        let is_hard = weight
239            .to_level_numbers()
240            .first()
241            .map(|&h| h != 0)
242            .unwrap_or(false);
243        ExistsConstraintBuilder {
244            mode: self.mode,
245            extractor_a: self.extractor_a,
246            extractor_parent: self.extractor_parent,
247            key_a: self.key_a,
248            key_b: self.key_b,
249            filter_a: self.filter_a,
250            filter_parent: self.filter_parent,
251            flatten: self.flatten,
252            impact_type: ImpactType::Reward,
253            weight: move |_: &A| weight,
254            is_hard,
255            _phantom: PhantomData,
256        }
257    }
258
259    pub fn reward_with<W>(
260        self,
261        weight_fn: W,
262    ) -> ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
263    where
264        W: Fn(&A) -> Sc + Send + Sync,
265    {
266        ExistsConstraintBuilder {
267            mode: self.mode,
268            extractor_a: self.extractor_a,
269            extractor_parent: self.extractor_parent,
270            key_a: self.key_a,
271            key_b: self.key_b,
272            filter_a: self.filter_a,
273            filter_parent: self.filter_parent,
274            flatten: self.flatten,
275            impact_type: ImpactType::Reward,
276            weight: weight_fn,
277            is_hard: false,
278            _phantom: PhantomData,
279        }
280    }
281
282    pub fn reward_hard_with<W>(
283        self,
284        weight_fn: W,
285    ) -> ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
286    where
287        W: Fn(&A) -> Sc + Send + Sync,
288    {
289        ExistsConstraintBuilder {
290            mode: self.mode,
291            extractor_a: self.extractor_a,
292            extractor_parent: self.extractor_parent,
293            key_a: self.key_a,
294            key_b: self.key_b,
295            filter_a: self.filter_a,
296            filter_parent: self.filter_parent,
297            flatten: self.flatten,
298            impact_type: ImpactType::Reward,
299            weight: weight_fn,
300            is_hard: true,
301            _phantom: PhantomData,
302        }
303    }
304
305    pub fn reward_hard(
306        self,
307    ) -> ExistsConstraintBuilder<
308        S,
309        A,
310        P,
311        B,
312        K,
313        EA,
314        EP,
315        KA,
316        KB,
317        FA,
318        FP,
319        Flatten,
320        impl Fn(&A) -> Sc + Send + Sync,
321        Sc,
322    >
323    where
324        Sc: Copy,
325    {
326        self.reward(Sc::one_hard())
327    }
328
329    pub fn reward_soft(
330        self,
331    ) -> ExistsConstraintBuilder<
332        S,
333        A,
334        P,
335        B,
336        K,
337        EA,
338        EP,
339        KA,
340        KB,
341        FA,
342        FP,
343        Flatten,
344        impl Fn(&A) -> Sc + Send + Sync,
345        Sc,
346    >
347    where
348        Sc: Copy,
349    {
350        self.reward(Sc::one_soft())
351    }
352}
353
354impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, Sc: Score> std::fmt::Debug
355    for ExistsConstraintStream<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, Sc>
356{
357    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358        f.debug_struct("ExistsConstraintStream")
359            .field("mode", &self.mode)
360            .finish()
361    }
362}
363
364pub struct ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
365where
366    Sc: Score,
367{
368    mode: ExistenceMode,
369    extractor_a: EA,
370    extractor_parent: EP,
371    key_a: KA,
372    key_b: KB,
373    filter_a: FA,
374    filter_parent: FP,
375    flatten: Flatten,
376    impact_type: ImpactType,
377    weight: W,
378    is_hard: bool,
379    _phantom: PhantomData<(
380        fn() -> S,
381        fn() -> A,
382        fn() -> P,
383        fn() -> B,
384        fn() -> K,
385        fn() -> Sc,
386    )>,
387}
388
389impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
390    ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
391where
392    S: Send + Sync + 'static,
393    A: Clone + Send + Sync + 'static,
394    P: Clone + Send + Sync + 'static,
395    B: Clone + Send + Sync + 'static,
396    K: Eq + Hash + Clone + Send + Sync,
397    EA: TrackedCollectionExtract<S, Item = A>,
398    EP: TrackedCollectionExtract<S, Item = P>,
399    KA: Fn(&A) -> K + Send + Sync,
400    KB: Fn(&B) -> K + Send + Sync,
401    FA: UniFilter<S, A> + Send + Sync,
402    FP: UniFilter<S, P> + Send + Sync,
403    Flatten: FlattenExtract<P, Item = B> + Send + Sync,
404    W: Fn(&A) -> Sc + Send + Sync,
405    Sc: Score + 'static,
406{
407    pub fn named(
408        self,
409        name: &str,
410    ) -> IncrementalExistsConstraint<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc> {
411        IncrementalExistsConstraint::new(
412            ConstraintRef::new("", name),
413            self.impact_type,
414            self.mode,
415            self.extractor_a,
416            self.extractor_parent,
417            self.key_a,
418            self.key_b,
419            self.filter_a,
420            self.filter_parent,
421            self.flatten,
422            self.weight,
423            self.is_hard,
424        )
425    }
426}
427
428impl<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc: Score> std::fmt::Debug
429    for ExistsConstraintBuilder<S, A, P, B, K, EA, EP, KA, KB, FA, FP, Flatten, W, Sc>
430{
431    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432        f.debug_struct("ExistsConstraintBuilder")
433            .field("mode", &self.mode)
434            .field("impact_type", &self.impact_type)
435            .finish()
436    }
437}
438
439pub(crate) type DirectExistenceStream<S, A, B, K, EA, EP, KA, KB, FA, FP, Sc> =
440    ExistsConstraintStream<S, A, B, B, K, EA, EP, KA, KB, FA, FP, SelfFlatten, Sc>;