1use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collection_extract::CollectionExtract;
14use super::collector::UniCollector;
15use crate::constraint::complemented::ComplementedGroupConstraint;
16
17pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
86where
87 Sc: Score,
88{
89 extractor_a: EA,
90 extractor_b: EB,
91 key_a: KA,
92 key_b: KB,
93 collector: C,
94 default_fn: D,
95 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
96}
97
98impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
99 ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
100where
101 S: Send + Sync + 'static,
102 A: Clone + Send + Sync + 'static,
103 B: Clone + Send + Sync + 'static,
104 K: Clone + Eq + Hash + Send + Sync + 'static,
105 EA: CollectionExtract<S, Item = A>,
106 EB: CollectionExtract<S, Item = B>,
107 KA: Fn(&A) -> Option<K> + Send + Sync,
108 KB: Fn(&B) -> K + Send + Sync,
109 C: UniCollector<A> + Send + Sync + 'static,
110 C::Accumulator: Send + Sync,
111 C::Result: Clone + Send + Sync,
112 D: Fn(&B) -> C::Result + Send + Sync,
113 Sc: Score + 'static,
114{
115 pub(crate) fn new(
117 extractor_a: EA,
118 extractor_b: EB,
119 key_a: KA,
120 key_b: KB,
121 collector: C,
122 default_fn: D,
123 ) -> Self {
124 Self {
125 extractor_a,
126 extractor_b,
127 key_a,
128 key_b,
129 collector,
130 default_fn,
131 _phantom: PhantomData,
132 }
133 }
134
135 pub fn penalize_with<W>(
137 self,
138 weight_fn: W,
139 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
140 where
141 W: Fn(&C::Result) -> Sc + Send + Sync,
142 {
143 ComplementedConstraintBuilder {
144 extractor_a: self.extractor_a,
145 extractor_b: self.extractor_b,
146 key_a: self.key_a,
147 key_b: self.key_b,
148 collector: self.collector,
149 default_fn: self.default_fn,
150 impact_type: ImpactType::Penalty,
151 weight_fn,
152 is_hard: false,
153 _phantom: PhantomData,
154 }
155 }
156
157 pub fn penalize_hard_with<W>(
159 self,
160 weight_fn: W,
161 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
162 where
163 W: Fn(&C::Result) -> Sc + Send + Sync,
164 {
165 ComplementedConstraintBuilder {
166 extractor_a: self.extractor_a,
167 extractor_b: self.extractor_b,
168 key_a: self.key_a,
169 key_b: self.key_b,
170 collector: self.collector,
171 default_fn: self.default_fn,
172 impact_type: ImpactType::Penalty,
173 weight_fn,
174 is_hard: true,
175 _phantom: PhantomData,
176 }
177 }
178
179 pub fn reward_with<W>(
181 self,
182 weight_fn: W,
183 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
184 where
185 W: Fn(&C::Result) -> Sc + Send + Sync,
186 {
187 ComplementedConstraintBuilder {
188 extractor_a: self.extractor_a,
189 extractor_b: self.extractor_b,
190 key_a: self.key_a,
191 key_b: self.key_b,
192 collector: self.collector,
193 default_fn: self.default_fn,
194 impact_type: ImpactType::Reward,
195 weight_fn,
196 is_hard: false,
197 _phantom: PhantomData,
198 }
199 }
200
201 pub fn reward_hard_with<W>(
203 self,
204 weight_fn: W,
205 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
206 where
207 W: Fn(&C::Result) -> Sc + Send + Sync,
208 {
209 ComplementedConstraintBuilder {
210 extractor_a: self.extractor_a,
211 extractor_b: self.extractor_b,
212 key_a: self.key_a,
213 key_b: self.key_b,
214 collector: self.collector,
215 default_fn: self.default_fn,
216 impact_type: ImpactType::Reward,
217 weight_fn,
218 is_hard: true,
219 _phantom: PhantomData,
220 }
221 }
222
223 pub fn penalize_hard(
225 self,
226 ) -> ComplementedConstraintBuilder<
227 S,
228 A,
229 B,
230 K,
231 EA,
232 EB,
233 KA,
234 KB,
235 C,
236 D,
237 impl Fn(&C::Result) -> Sc + Send + Sync,
238 Sc,
239 >
240 where
241 Sc: Copy,
242 {
243 let w = Sc::one_hard();
244 self.penalize_hard_with(move |_: &C::Result| w)
245 }
246
247 pub fn penalize_soft(
249 self,
250 ) -> ComplementedConstraintBuilder<
251 S,
252 A,
253 B,
254 K,
255 EA,
256 EB,
257 KA,
258 KB,
259 C,
260 D,
261 impl Fn(&C::Result) -> Sc + Send + Sync,
262 Sc,
263 >
264 where
265 Sc: Copy,
266 {
267 let w = Sc::one_soft();
268 self.penalize_with(move |_: &C::Result| w)
269 }
270
271 pub fn reward_hard(
273 self,
274 ) -> ComplementedConstraintBuilder<
275 S,
276 A,
277 B,
278 K,
279 EA,
280 EB,
281 KA,
282 KB,
283 C,
284 D,
285 impl Fn(&C::Result) -> Sc + Send + Sync,
286 Sc,
287 >
288 where
289 Sc: Copy,
290 {
291 let w = Sc::one_hard();
292 self.reward_hard_with(move |_: &C::Result| w)
293 }
294
295 pub fn reward_soft(
297 self,
298 ) -> ComplementedConstraintBuilder<
299 S,
300 A,
301 B,
302 K,
303 EA,
304 EB,
305 KA,
306 KB,
307 C,
308 D,
309 impl Fn(&C::Result) -> Sc + Send + Sync,
310 Sc,
311 >
312 where
313 Sc: Copy,
314 {
315 let w = Sc::one_soft();
316 self.reward_with(move |_: &C::Result| w)
317 }
318}
319
320impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc: Score> std::fmt::Debug
321 for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
322{
323 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324 f.debug_struct("ComplementedConstraintStream").finish()
325 }
326}
327
328pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
330where
331 Sc: Score,
332{
333 extractor_a: EA,
334 extractor_b: EB,
335 key_a: KA,
336 key_b: KB,
337 collector: C,
338 default_fn: D,
339 impact_type: ImpactType,
340 weight_fn: W,
341 is_hard: bool,
342 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
343}
344
345impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
346 ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
347where
348 S: Send + Sync + 'static,
349 A: Clone + Send + Sync + 'static,
350 B: Clone + Send + Sync + 'static,
351 K: Clone + Eq + Hash + Send + Sync + 'static,
352 EA: CollectionExtract<S, Item = A>,
353 EB: CollectionExtract<S, Item = B>,
354 KA: Fn(&A) -> Option<K> + Send + Sync,
355 KB: Fn(&B) -> K + Send + Sync,
356 C: UniCollector<A> + Send + Sync + 'static,
357 C::Accumulator: Send + Sync,
358 C::Result: Clone + Send + Sync,
359 D: Fn(&B) -> C::Result + Send + Sync,
360 W: Fn(&C::Result) -> Sc + Send + Sync,
361 Sc: Score + 'static,
362{
363 pub fn named(
364 self,
365 name: &str,
366 ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> {
367 ComplementedGroupConstraint::new(
368 ConstraintRef::new("", name),
369 self.impact_type,
370 self.extractor_a,
371 self.extractor_b,
372 self.key_a,
373 self.key_b,
374 self.collector,
375 self.default_fn,
376 self.weight_fn,
377 self.is_hard,
378 )
379 }
380
381 }
383
384impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc: Score> std::fmt::Debug
385 for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
386{
387 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388 f.debug_struct("ComplementedConstraintBuilder")
389 .field("impact_type", &self.impact_type)
390 .finish()
391 }
392}