1use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collection_extract::CollectionExtract;
14use super::collector::UniCollector;
15use super::weighting_support::ConstraintWeight;
16use crate::constraint::complemented::ComplementedGroupConstraint;
17
18pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
87where
88 Sc: Score,
89{
90 extractor_a: EA,
91 extractor_b: EB,
92 key_a: KA,
93 key_b: KB,
94 collector: C,
95 default_fn: D,
96 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
97}
98
99impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
100 ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
101where
102 S: Send + Sync + 'static,
103 A: Clone + Send + Sync + 'static,
104 B: Clone + Send + Sync + 'static,
105 K: Clone + Eq + Hash + Send + Sync + 'static,
106 EA: CollectionExtract<S, Item = A>,
107 EB: CollectionExtract<S, Item = B>,
108 KA: Fn(&A) -> Option<K> + Send + Sync,
109 KB: Fn(&B) -> K + Send + Sync,
110 C: UniCollector<A> + Send + Sync + 'static,
111 C::Accumulator: Send + Sync,
112 C::Result: Send + Sync,
113 D: Fn(&B) -> C::Result + Send + Sync,
114 Sc: Score + 'static,
115{
116 fn into_weighted_builder<W>(
117 self,
118 impact_type: ImpactType,
119 weight_fn: W,
120 is_hard: bool,
121 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
122 where
123 W: Fn(&K, &C::Result) -> Sc + Send + Sync,
124 {
125 ComplementedConstraintBuilder {
126 extractor_a: self.extractor_a,
127 extractor_b: self.extractor_b,
128 key_a: self.key_a,
129 key_b: self.key_b,
130 collector: self.collector,
131 default_fn: self.default_fn,
132 impact_type,
133 weight_fn,
134 is_hard,
135 _phantom: PhantomData,
136 }
137 }
138
139 pub(crate) fn new(
141 extractor_a: EA,
142 extractor_b: EB,
143 key_a: KA,
144 key_b: KB,
145 collector: C,
146 default_fn: D,
147 ) -> Self {
148 Self {
149 extractor_a,
150 extractor_b,
151 key_a,
152 key_b,
153 collector,
154 default_fn,
155 _phantom: PhantomData,
156 }
157 }
158
159 pub fn penalize<W>(
160 self,
161 weight: W,
162 ) -> ComplementedConstraintBuilder<
163 S,
164 A,
165 B,
166 K,
167 EA,
168 EB,
169 KA,
170 KB,
171 C,
172 D,
173 impl Fn(&K, &C::Result) -> Sc + Send + Sync,
174 Sc,
175 >
176 where
177 W: for<'w> ConstraintWeight<(&'w K, &'w C::Result), Sc> + Send + Sync,
178 {
179 let is_hard = weight.is_hard();
180 self.into_weighted_builder(
181 ImpactType::Penalty,
182 move |key: &K, result: &C::Result| weight.score((key, result)),
183 is_hard,
184 )
185 }
186
187 pub fn reward<W>(
188 self,
189 weight: W,
190 ) -> ComplementedConstraintBuilder<
191 S,
192 A,
193 B,
194 K,
195 EA,
196 EB,
197 KA,
198 KB,
199 C,
200 D,
201 impl Fn(&K, &C::Result) -> Sc + Send + Sync,
202 Sc,
203 >
204 where
205 W: for<'w> ConstraintWeight<(&'w K, &'w C::Result), Sc> + Send + Sync,
206 {
207 let is_hard = weight.is_hard();
208 self.into_weighted_builder(
209 ImpactType::Reward,
210 move |key: &K, result: &C::Result| weight.score((key, result)),
211 is_hard,
212 )
213 }
214}
215
216impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc: Score> std::fmt::Debug
217 for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
218{
219 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220 f.debug_struct("ComplementedConstraintStream").finish()
221 }
222}
223
224pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
226where
227 Sc: Score,
228{
229 extractor_a: EA,
230 extractor_b: EB,
231 key_a: KA,
232 key_b: KB,
233 collector: C,
234 default_fn: D,
235 impact_type: ImpactType,
236 weight_fn: W,
237 is_hard: bool,
238 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
239}
240
241impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
242 ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
243where
244 S: Send + Sync + 'static,
245 A: Clone + Send + Sync + 'static,
246 B: Clone + Send + Sync + 'static,
247 K: Clone + Eq + Hash + Send + Sync + 'static,
248 EA: CollectionExtract<S, Item = A>,
249 EB: CollectionExtract<S, Item = B>,
250 KA: Fn(&A) -> Option<K> + Send + Sync,
251 KB: Fn(&B) -> K + Send + Sync,
252 C: UniCollector<A> + Send + Sync + 'static,
253 C::Accumulator: Send + Sync,
254 C::Result: Send + Sync,
255 D: Fn(&B) -> C::Result + Send + Sync,
256 W: Fn(&K, &C::Result) -> Sc + Send + Sync,
257 Sc: Score + 'static,
258{
259 pub fn named(
260 self,
261 name: &str,
262 ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> {
263 ComplementedGroupConstraint::new(
264 ConstraintRef::new("", name),
265 self.impact_type,
266 self.extractor_a,
267 self.extractor_b,
268 self.key_a,
269 self.key_b,
270 self.collector,
271 self.default_fn,
272 self.weight_fn,
273 self.is_hard,
274 )
275 }
276
277 }
279
280impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc: Score> std::fmt::Debug
281 for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
282{
283 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284 f.debug_struct("ComplementedConstraintBuilder")
285 .field("impact_type", &self.impact_type)
286 .finish()
287 }
288}