1use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collection_extract::CollectionExtract;
14use super::collector::UniCollector;
15use crate::constraint::complemented::ComplementedGroupConstraint;
16
17pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
86where
87 Sc: Score,
88{
89 extractor_a: EA,
90 extractor_b: EB,
91 key_a: KA,
92 key_b: KB,
93 collector: C,
94 default_fn: D,
95 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
96}
97
98impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
99 ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
100where
101 S: Send + Sync + 'static,
102 A: Clone + Send + Sync + 'static,
103 B: Clone + Send + Sync + 'static,
104 K: Clone + Eq + Hash + Send + Sync + 'static,
105 EA: CollectionExtract<S, Item = A>,
106 EB: CollectionExtract<S, Item = B>,
107 KA: Fn(&A) -> Option<K> + Send + Sync,
108 KB: Fn(&B) -> K + Send + Sync,
109 C: UniCollector<A> + Send + Sync + 'static,
110 C::Accumulator: Send + Sync,
111 C::Result: Clone + Send + Sync,
112 D: Fn(&B) -> C::Result + Send + Sync,
113 Sc: Score + 'static,
114{
115 fn into_weighted_builder<W>(
116 self,
117 impact_type: ImpactType,
118 weight_fn: W,
119 is_hard: bool,
120 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
121 where
122 W: Fn(&C::Result) -> Sc + Send + Sync,
123 {
124 ComplementedConstraintBuilder {
125 extractor_a: self.extractor_a,
126 extractor_b: self.extractor_b,
127 key_a: self.key_a,
128 key_b: self.key_b,
129 collector: self.collector,
130 default_fn: self.default_fn,
131 impact_type,
132 weight_fn,
133 is_hard,
134 _phantom: PhantomData,
135 }
136 }
137
138 pub(crate) fn new(
140 extractor_a: EA,
141 extractor_b: EB,
142 key_a: KA,
143 key_b: KB,
144 collector: C,
145 default_fn: D,
146 ) -> Self {
147 Self {
148 extractor_a,
149 extractor_b,
150 key_a,
151 key_b,
152 collector,
153 default_fn,
154 _phantom: PhantomData,
155 }
156 }
157
158 pub fn penalize_with<W>(
160 self,
161 weight_fn: W,
162 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
163 where
164 W: Fn(&C::Result) -> Sc + Send + Sync,
165 {
166 self.into_weighted_builder(ImpactType::Penalty, weight_fn, false)
167 }
168
169 pub fn penalize_hard_with<W>(
171 self,
172 weight_fn: W,
173 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
174 where
175 W: Fn(&C::Result) -> Sc + Send + Sync,
176 {
177 self.into_weighted_builder(ImpactType::Penalty, weight_fn, true)
178 }
179
180 pub fn reward_with<W>(
182 self,
183 weight_fn: W,
184 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
185 where
186 W: Fn(&C::Result) -> Sc + Send + Sync,
187 {
188 self.into_weighted_builder(ImpactType::Reward, weight_fn, false)
189 }
190
191 pub fn reward_hard_with<W>(
193 self,
194 weight_fn: W,
195 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
196 where
197 W: Fn(&C::Result) -> Sc + Send + Sync,
198 {
199 self.into_weighted_builder(ImpactType::Reward, weight_fn, true)
200 }
201
202 pub fn penalize_hard(
204 self,
205 ) -> ComplementedConstraintBuilder<
206 S,
207 A,
208 B,
209 K,
210 EA,
211 EB,
212 KA,
213 KB,
214 C,
215 D,
216 impl Fn(&C::Result) -> Sc + Send + Sync,
217 Sc,
218 >
219 where
220 Sc: Copy,
221 {
222 let w = Sc::one_hard();
223 self.penalize_hard_with(move |_: &C::Result| w)
224 }
225
226 pub fn penalize_soft(
228 self,
229 ) -> ComplementedConstraintBuilder<
230 S,
231 A,
232 B,
233 K,
234 EA,
235 EB,
236 KA,
237 KB,
238 C,
239 D,
240 impl Fn(&C::Result) -> Sc + Send + Sync,
241 Sc,
242 >
243 where
244 Sc: Copy,
245 {
246 let w = Sc::one_soft();
247 self.penalize_with(move |_: &C::Result| w)
248 }
249
250 pub fn reward_hard(
252 self,
253 ) -> ComplementedConstraintBuilder<
254 S,
255 A,
256 B,
257 K,
258 EA,
259 EB,
260 KA,
261 KB,
262 C,
263 D,
264 impl Fn(&C::Result) -> Sc + Send + Sync,
265 Sc,
266 >
267 where
268 Sc: Copy,
269 {
270 let w = Sc::one_hard();
271 self.reward_hard_with(move |_: &C::Result| w)
272 }
273
274 pub fn reward_soft(
276 self,
277 ) -> ComplementedConstraintBuilder<
278 S,
279 A,
280 B,
281 K,
282 EA,
283 EB,
284 KA,
285 KB,
286 C,
287 D,
288 impl Fn(&C::Result) -> Sc + Send + Sync,
289 Sc,
290 >
291 where
292 Sc: Copy,
293 {
294 let w = Sc::one_soft();
295 self.reward_with(move |_: &C::Result| w)
296 }
297}
298
299impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc: Score> std::fmt::Debug
300 for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
301{
302 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303 f.debug_struct("ComplementedConstraintStream").finish()
304 }
305}
306
307pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
309where
310 Sc: Score,
311{
312 extractor_a: EA,
313 extractor_b: EB,
314 key_a: KA,
315 key_b: KB,
316 collector: C,
317 default_fn: D,
318 impact_type: ImpactType,
319 weight_fn: W,
320 is_hard: bool,
321 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
322}
323
324impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
325 ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
326where
327 S: Send + Sync + 'static,
328 A: Clone + Send + Sync + 'static,
329 B: Clone + Send + Sync + 'static,
330 K: Clone + Eq + Hash + Send + Sync + 'static,
331 EA: CollectionExtract<S, Item = A>,
332 EB: CollectionExtract<S, Item = B>,
333 KA: Fn(&A) -> Option<K> + Send + Sync,
334 KB: Fn(&B) -> K + Send + Sync,
335 C: UniCollector<A> + Send + Sync + 'static,
336 C::Accumulator: Send + Sync,
337 C::Result: Clone + Send + Sync,
338 D: Fn(&B) -> C::Result + Send + Sync,
339 W: Fn(&C::Result) -> Sc + Send + Sync,
340 Sc: Score + 'static,
341{
342 pub fn named(
343 self,
344 name: &str,
345 ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> {
346 ComplementedGroupConstraint::new(
347 ConstraintRef::new("", name),
348 self.impact_type,
349 self.extractor_a,
350 self.extractor_b,
351 self.key_a,
352 self.key_b,
353 self.collector,
354 self.default_fn,
355 self.weight_fn,
356 self.is_hard,
357 )
358 }
359
360 }
362
363impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc: Score> std::fmt::Debug
364 for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
365{
366 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367 f.debug_struct("ComplementedConstraintBuilder")
368 .field("impact_type", &self.impact_type)
369 .finish()
370 }
371}