1use std::hash::Hash;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use super::collector::UniCollector;
13use crate::constraint::complemented::ComplementedGroupConstraint;
14
15pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
83where
84 Sc: Score,
85{
86 extractor_a: EA,
87 extractor_b: EB,
88 key_a: KA,
89 key_b: KB,
90 collector: C,
91 default_fn: D,
92 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
93}
94
95impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
96 ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
97where
98 S: Send + Sync + 'static,
99 A: Clone + Send + Sync + 'static,
100 B: Clone + Send + Sync + 'static,
101 K: Clone + Eq + Hash + Send + Sync + 'static,
102 EA: Fn(&S) -> &[A] + Send + Sync,
103 EB: Fn(&S) -> &[B] + Send + Sync,
104 KA: Fn(&A) -> Option<K> + Send + Sync,
105 KB: Fn(&B) -> K + Send + Sync,
106 C: UniCollector<A> + Send + Sync + 'static,
107 C::Accumulator: Send + Sync,
108 C::Result: Clone + Send + Sync,
109 D: Fn(&B) -> C::Result + Send + Sync,
110 Sc: Score + 'static,
111{
112 pub(crate) fn new(
114 extractor_a: EA,
115 extractor_b: EB,
116 key_a: KA,
117 key_b: KB,
118 collector: C,
119 default_fn: D,
120 ) -> Self {
121 Self {
122 extractor_a,
123 extractor_b,
124 key_a,
125 key_b,
126 collector,
127 default_fn,
128 _phantom: PhantomData,
129 }
130 }
131
132 pub fn penalize_with<W>(
134 self,
135 weight_fn: W,
136 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
137 where
138 W: Fn(&C::Result) -> Sc + Send + Sync,
139 {
140 ComplementedConstraintBuilder {
141 extractor_a: self.extractor_a,
142 extractor_b: self.extractor_b,
143 key_a: self.key_a,
144 key_b: self.key_b,
145 collector: self.collector,
146 default_fn: self.default_fn,
147 impact_type: ImpactType::Penalty,
148 weight_fn,
149 is_hard: false,
150 _phantom: PhantomData,
151 }
152 }
153
154 pub fn penalize_hard_with<W>(
156 self,
157 weight_fn: W,
158 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
159 where
160 W: Fn(&C::Result) -> Sc + Send + Sync,
161 {
162 ComplementedConstraintBuilder {
163 extractor_a: self.extractor_a,
164 extractor_b: self.extractor_b,
165 key_a: self.key_a,
166 key_b: self.key_b,
167 collector: self.collector,
168 default_fn: self.default_fn,
169 impact_type: ImpactType::Penalty,
170 weight_fn,
171 is_hard: true,
172 _phantom: PhantomData,
173 }
174 }
175
176 pub fn reward_with<W>(
178 self,
179 weight_fn: W,
180 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
181 where
182 W: Fn(&C::Result) -> Sc + Send + Sync,
183 {
184 ComplementedConstraintBuilder {
185 extractor_a: self.extractor_a,
186 extractor_b: self.extractor_b,
187 key_a: self.key_a,
188 key_b: self.key_b,
189 collector: self.collector,
190 default_fn: self.default_fn,
191 impact_type: ImpactType::Reward,
192 weight_fn,
193 is_hard: false,
194 _phantom: PhantomData,
195 }
196 }
197
198 pub fn reward_hard_with<W>(
200 self,
201 weight_fn: W,
202 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
203 where
204 W: Fn(&C::Result) -> Sc + Send + Sync,
205 {
206 ComplementedConstraintBuilder {
207 extractor_a: self.extractor_a,
208 extractor_b: self.extractor_b,
209 key_a: self.key_a,
210 key_b: self.key_b,
211 collector: self.collector,
212 default_fn: self.default_fn,
213 impact_type: ImpactType::Reward,
214 weight_fn,
215 is_hard: true,
216 _phantom: PhantomData,
217 }
218 }
219
220 pub fn penalize_hard(
222 self,
223 ) -> ComplementedConstraintBuilder<
224 S,
225 A,
226 B,
227 K,
228 EA,
229 EB,
230 KA,
231 KB,
232 C,
233 D,
234 impl Fn(&C::Result) -> Sc + Send + Sync,
235 Sc,
236 >
237 where
238 Sc: Copy,
239 {
240 let w = Sc::one_hard();
241 self.penalize_hard_with(move |_: &C::Result| w)
242 }
243
244 pub fn penalize_soft(
246 self,
247 ) -> ComplementedConstraintBuilder<
248 S,
249 A,
250 B,
251 K,
252 EA,
253 EB,
254 KA,
255 KB,
256 C,
257 D,
258 impl Fn(&C::Result) -> Sc + Send + Sync,
259 Sc,
260 >
261 where
262 Sc: Copy,
263 {
264 let w = Sc::one_soft();
265 self.penalize_with(move |_: &C::Result| w)
266 }
267
268 pub fn reward_hard(
270 self,
271 ) -> ComplementedConstraintBuilder<
272 S,
273 A,
274 B,
275 K,
276 EA,
277 EB,
278 KA,
279 KB,
280 C,
281 D,
282 impl Fn(&C::Result) -> Sc + Send + Sync,
283 Sc,
284 >
285 where
286 Sc: Copy,
287 {
288 let w = Sc::one_hard();
289 self.reward_hard_with(move |_: &C::Result| w)
290 }
291
292 pub fn reward_soft(
294 self,
295 ) -> ComplementedConstraintBuilder<
296 S,
297 A,
298 B,
299 K,
300 EA,
301 EB,
302 KA,
303 KB,
304 C,
305 D,
306 impl Fn(&C::Result) -> Sc + Send + Sync,
307 Sc,
308 >
309 where
310 Sc: Copy,
311 {
312 let w = Sc::one_soft();
313 self.reward_with(move |_: &C::Result| w)
314 }
315}
316
317impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc: Score> std::fmt::Debug
318 for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
319{
320 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321 f.debug_struct("ComplementedConstraintStream").finish()
322 }
323}
324
325pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
327where
328 Sc: Score,
329{
330 extractor_a: EA,
331 extractor_b: EB,
332 key_a: KA,
333 key_b: KB,
334 collector: C,
335 default_fn: D,
336 impact_type: ImpactType,
337 weight_fn: W,
338 is_hard: bool,
339 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
340}
341
342impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
343 ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
344where
345 S: Send + Sync + 'static,
346 A: Clone + Send + Sync + 'static,
347 B: Clone + Send + Sync + 'static,
348 K: Clone + Eq + Hash + Send + Sync + 'static,
349 EA: Fn(&S) -> &[A] + Send + Sync,
350 EB: Fn(&S) -> &[B] + Send + Sync,
351 KA: Fn(&A) -> Option<K> + Send + Sync,
352 KB: Fn(&B) -> K + Send + Sync,
353 C: UniCollector<A> + Send + Sync + 'static,
354 C::Accumulator: Send + Sync,
355 C::Result: Clone + Send + Sync,
356 D: Fn(&B) -> C::Result + Send + Sync,
357 W: Fn(&C::Result) -> Sc + Send + Sync,
358 Sc: Score + 'static,
359{
360 pub fn named(
361 self,
362 name: &str,
363 ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> {
364 ComplementedGroupConstraint::new(
365 ConstraintRef::new("", name),
366 self.impact_type,
367 self.extractor_a,
368 self.extractor_b,
369 self.key_a,
370 self.key_b,
371 self.collector,
372 self.default_fn,
373 self.weight_fn,
374 self.is_hard,
375 )
376 }
377
378 }
380
381impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc: Score> std::fmt::Debug
382 for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
383{
384 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385 f.debug_struct("ComplementedConstraintBuilder")
386 .field("impact_type", &self.impact_type)
387 .finish()
388 }
389}