1use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collection_extract::CollectionExtract;
14use super::collector::{Accumulator, Collector};
15use super::weighting_support::ConstraintWeight;
16use crate::constraint::complemented::ComplementedGroupConstraint;
17
/// Unweighted stage of a complemented constraint.
///
/// The stream stores two extractors, two key functions, a collector and a
/// default-value function until `penalize` or `reward` attaches a weight and
/// turns it into a [`ComplementedConstraintBuilder`].
///
/// The struct is a plain container and therefore declares no trait bounds of
/// its own; the inherent `impl` states every bound (including `Sc: Score`)
/// that its methods require.
///
/// Type parameters, as bounded on the inherent `impl`:
/// - `S`, `A`, `B`: `EA: CollectionExtract<S, Item = A>` and
///   `EB: CollectionExtract<S, Item = B>`.
/// - `K`: key type; `KA: Fn(&A) -> Option<K>` and `KB: Fn(&B) -> K`.
/// - `C`, `V`, `R`, `Acc`: a collector over `&A` together with its value,
///   result and accumulator types.
/// - `D`: `Fn(&B) -> R`.
/// - `Sc`: score type returned by the weight function.
pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, Sc> {
    /// Extractor whose items have type `A`.
    extractor_a: EA,
    /// Extractor whose items have type `B`.
    extractor_b: EB,
    /// Key function for `A` items; it may return `None`.
    key_a: KA,
    /// Key function for `B` items; it always returns a key.
    key_b: KB,
    /// Collector applied to `&A` items.
    collector: C,
    /// Builds an `R` from a `B` item.
    // NOTE(review): presumably the result used for a `B` without matching `A`
    // items — confirm against `ComplementedGroupConstraint`.
    default_fn: D,
    /// Ties the type parameters that no field stores to the struct. The
    /// `fn() -> T` form holds no `T`, so these parameters never influence
    /// whether the stream is `Send` or `Sync`.
    _phantom: PhantomData<(
        fn() -> S,
        fn() -> A,
        fn() -> B,
        fn() -> K,
        fn() -> V,
        fn() -> R,
        fn() -> Acc,
        fn() -> Sc,
    )>,
}
107
108impl<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, Sc>
109 ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, Sc>
110where
111 S: Send + Sync + 'static,
112 A: Clone + Send + Sync + 'static,
113 B: Clone + Send + Sync + 'static,
114 K: Clone + Eq + Hash + Send + Sync + 'static,
115 EA: CollectionExtract<S, Item = A>,
116 EB: CollectionExtract<S, Item = B>,
117 KA: Fn(&A) -> Option<K> + Send + Sync,
118 KB: Fn(&B) -> K + Send + Sync,
119 C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
120 V: Send + Sync + 'static,
121 R: Send + Sync + 'static,
122 Acc: Accumulator<V, R> + Send + Sync + 'static,
123 D: Fn(&B) -> R + Send + Sync,
124 Sc: Score + 'static,
125{
126 fn into_weighted_builder<W>(
127 self,
128 impact_type: ImpactType,
129 weight_fn: W,
130 is_hard: bool,
131 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc>
132 where
133 W: Fn(&K, &R) -> Sc + Send + Sync,
134 {
135 ComplementedConstraintBuilder {
136 extractor_a: self.extractor_a,
137 extractor_b: self.extractor_b,
138 key_a: self.key_a,
139 key_b: self.key_b,
140 collector: self.collector,
141 default_fn: self.default_fn,
142 impact_type,
143 weight_fn,
144 is_hard,
145 _phantom: PhantomData,
146 }
147 }
148
149 pub(crate) fn new(
151 extractor_a: EA,
152 extractor_b: EB,
153 key_a: KA,
154 key_b: KB,
155 collector: C,
156 default_fn: D,
157 ) -> Self {
158 Self {
159 extractor_a,
160 extractor_b,
161 key_a,
162 key_b,
163 collector,
164 default_fn,
165 _phantom: PhantomData,
166 }
167 }
168
169 pub fn penalize<W>(
170 self,
171 weight: W,
172 ) -> ComplementedConstraintBuilder<
173 S,
174 A,
175 B,
176 K,
177 EA,
178 EB,
179 KA,
180 KB,
181 C,
182 V,
183 R,
184 Acc,
185 D,
186 impl Fn(&K, &R) -> Sc + Send + Sync,
187 Sc,
188 >
189 where
190 W: for<'w> ConstraintWeight<(&'w K, &'w R), Sc> + Send + Sync,
191 {
192 let is_hard = weight.is_hard();
193 self.into_weighted_builder(
194 ImpactType::Penalty,
195 move |key: &K, result: &R| weight.score((key, result)),
196 is_hard,
197 )
198 }
199
200 pub fn reward<W>(
201 self,
202 weight: W,
203 ) -> ComplementedConstraintBuilder<
204 S,
205 A,
206 B,
207 K,
208 EA,
209 EB,
210 KA,
211 KB,
212 C,
213 V,
214 R,
215 Acc,
216 D,
217 impl Fn(&K, &R) -> Sc + Send + Sync,
218 Sc,
219 >
220 where
221 W: for<'w> ConstraintWeight<(&'w K, &'w R), Sc> + Send + Sync,
222 {
223 let is_hard = weight.is_hard();
224 self.into_weighted_builder(
225 ImpactType::Reward,
226 move |key: &K, result: &R| weight.score((key, result)),
227 is_hard,
228 )
229 }
230}
231
232impl<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, Sc: Score> std::fmt::Debug
233 for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, Sc>
234{
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 f.debug_struct("ComplementedConstraintStream").finish()
237 }
238}
239
240pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc>
242where
243 Sc: Score,
244{
245 extractor_a: EA,
246 extractor_b: EB,
247 key_a: KA,
248 key_b: KB,
249 collector: C,
250 default_fn: D,
251 impact_type: ImpactType,
252 weight_fn: W,
253 is_hard: bool,
254 _phantom: PhantomData<(
255 fn() -> S,
256 fn() -> A,
257 fn() -> B,
258 fn() -> K,
259 fn() -> V,
260 fn() -> R,
261 fn() -> Acc,
262 fn() -> Sc,
263 )>,
264}
265
266impl<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc>
267 ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc>
268where
269 S: Send + Sync + 'static,
270 A: Clone + Send + Sync + 'static,
271 B: Clone + Send + Sync + 'static,
272 K: Clone + Eq + Hash + Send + Sync + 'static,
273 EA: CollectionExtract<S, Item = A>,
274 EB: CollectionExtract<S, Item = B>,
275 KA: Fn(&A) -> Option<K> + Send + Sync,
276 KB: Fn(&B) -> K + Send + Sync,
277 C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
278 V: Send + Sync + 'static,
279 R: Send + Sync + 'static,
280 Acc: Accumulator<V, R> + Send + Sync + 'static,
281 D: Fn(&B) -> R + Send + Sync,
282 W: Fn(&K, &R) -> Sc + Send + Sync,
283 Sc: Score + 'static,
284{
285 pub fn named(
286 self,
287 name: &str,
288 ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc> {
289 ComplementedGroupConstraint::new(
290 ConstraintRef::new("", name),
291 self.impact_type,
292 self.extractor_a,
293 self.extractor_b,
294 self.key_a,
295 self.key_b,
296 self.collector,
297 self.default_fn,
298 self.weight_fn,
299 self.is_hard,
300 )
301 }
302
303 }
305
306impl<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc: Score> std::fmt::Debug
307 for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, V, R, Acc, D, W, Sc>
308{
309 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310 f.debug_struct("ComplementedConstraintBuilder")
311 .field("impact_type", &self.impact_type)
312 .finish()
313 }
314}