1use std::hash::Hash;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use super::collector::UniCollector;
13use crate::constraint::complemented::ComplementedGroupConstraint;
14
/// A constraint stream that groups `A` items by key and complements the groups with `B` items.
///
/// The stream only stores its components: the two slice extractors, the two key functions,
/// the collector and the default-result function. It becomes a
/// [`ComplementedConstraintBuilder`] once an impact and a weight function are chosen with
/// `penalize_with`, `penalize_hard_with`, `reward_with` or `reward_hard_with`.
///
/// Type parameters:
/// * `S` — the solution type the extractors `EA` / `EB` borrow slices from.
/// * `A`, `B` — the item types of those slices.
/// * `K` — the group key produced by `KA` (for `A` items) and `KB` (for `B` items).
/// * `C` — the collector; `D` — the function producing a default collector result from a `B`.
/// * `Sc` — the score type.
///
/// NOTE(review): the exact complement semantics (presumably: keys coming from `B` items that
/// no `A` item maps to receive `default_fn(&b)`) live in `ComplementedGroupConstraint`;
/// confirm there.
///
/// The definition itself places no trait bounds on its parameters; the `impl` blocks state
/// the bounds they actually need.
#[must_use = "a complemented stream does nothing until it is weighted and built into a constraint"]
pub struct ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc> {
    extractor_a: EA,
    extractor_b: EB,
    key_a: KA,
    key_b: KB,
    collector: C,
    default_fn: D,
    // No values of S/A/B/K/Sc are stored. Wrapping them in `fn() -> T` keeps the stream
    // covariant in those types and keeps it `Send`/`Sync` regardless of them.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
}
94
95impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
96 ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
97where
98 S: Send + Sync + 'static,
99 A: Clone + Send + Sync + 'static,
100 B: Clone + Send + Sync + 'static,
101 K: Clone + Eq + Hash + Send + Sync + 'static,
102 EA: Fn(&S) -> &[A] + Send + Sync,
103 EB: Fn(&S) -> &[B] + Send + Sync,
104 KA: Fn(&A) -> Option<K> + Send + Sync,
105 KB: Fn(&B) -> K + Send + Sync,
106 C: UniCollector<A> + Send + Sync + 'static,
107 C::Accumulator: Send + Sync,
108 C::Result: Clone + Send + Sync,
109 D: Fn(&B) -> C::Result + Send + Sync,
110 Sc: Score + 'static,
111{
112 pub(crate) fn new(
114 extractor_a: EA,
115 extractor_b: EB,
116 key_a: KA,
117 key_b: KB,
118 collector: C,
119 default_fn: D,
120 ) -> Self {
121 Self {
122 extractor_a,
123 extractor_b,
124 key_a,
125 key_b,
126 collector,
127 default_fn,
128 _phantom: PhantomData,
129 }
130 }
131
132 pub fn penalize_with<W>(
134 self,
135 weight_fn: W,
136 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
137 where
138 W: Fn(&C::Result) -> Sc + Send + Sync,
139 {
140 ComplementedConstraintBuilder {
141 extractor_a: self.extractor_a,
142 extractor_b: self.extractor_b,
143 key_a: self.key_a,
144 key_b: self.key_b,
145 collector: self.collector,
146 default_fn: self.default_fn,
147 impact_type: ImpactType::Penalty,
148 weight_fn,
149 is_hard: false,
150 _phantom: PhantomData,
151 }
152 }
153
154 pub fn penalize_hard_with<W>(
156 self,
157 weight_fn: W,
158 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
159 where
160 W: Fn(&C::Result) -> Sc + Send + Sync,
161 {
162 ComplementedConstraintBuilder {
163 extractor_a: self.extractor_a,
164 extractor_b: self.extractor_b,
165 key_a: self.key_a,
166 key_b: self.key_b,
167 collector: self.collector,
168 default_fn: self.default_fn,
169 impact_type: ImpactType::Penalty,
170 weight_fn,
171 is_hard: true,
172 _phantom: PhantomData,
173 }
174 }
175
176 pub fn reward_with<W>(
178 self,
179 weight_fn: W,
180 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
181 where
182 W: Fn(&C::Result) -> Sc + Send + Sync,
183 {
184 ComplementedConstraintBuilder {
185 extractor_a: self.extractor_a,
186 extractor_b: self.extractor_b,
187 key_a: self.key_a,
188 key_b: self.key_b,
189 collector: self.collector,
190 default_fn: self.default_fn,
191 impact_type: ImpactType::Reward,
192 weight_fn,
193 is_hard: false,
194 _phantom: PhantomData,
195 }
196 }
197
198 pub fn reward_hard_with<W>(
200 self,
201 weight_fn: W,
202 ) -> ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
203 where
204 W: Fn(&C::Result) -> Sc + Send + Sync,
205 {
206 ComplementedConstraintBuilder {
207 extractor_a: self.extractor_a,
208 extractor_b: self.extractor_b,
209 key_a: self.key_a,
210 key_b: self.key_b,
211 collector: self.collector,
212 default_fn: self.default_fn,
213 impact_type: ImpactType::Reward,
214 weight_fn,
215 is_hard: true,
216 _phantom: PhantomData,
217 }
218 }
219}
220
221impl<S, A, B, K, EA, EB, KA, KB, C, D, Sc: Score> std::fmt::Debug
222 for ComplementedConstraintStream<S, A, B, K, EA, EB, KA, KB, C, D, Sc>
223{
224 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225 f.debug_struct("ComplementedConstraintStream").finish()
226 }
227}
228
229pub struct ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
231where
232 Sc: Score,
233{
234 extractor_a: EA,
235 extractor_b: EB,
236 key_a: KA,
237 key_b: KB,
238 collector: C,
239 default_fn: D,
240 impact_type: ImpactType,
241 weight_fn: W,
242 is_hard: bool,
243 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> K, fn() -> Sc)>,
244}
245
246impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
247 ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
248where
249 S: Send + Sync + 'static,
250 A: Clone + Send + Sync + 'static,
251 B: Clone + Send + Sync + 'static,
252 K: Clone + Eq + Hash + Send + Sync + 'static,
253 EA: Fn(&S) -> &[A] + Send + Sync,
254 EB: Fn(&S) -> &[B] + Send + Sync,
255 KA: Fn(&A) -> Option<K> + Send + Sync,
256 KB: Fn(&B) -> K + Send + Sync,
257 C: UniCollector<A> + Send + Sync + 'static,
258 C::Accumulator: Send + Sync,
259 C::Result: Clone + Send + Sync,
260 D: Fn(&B) -> C::Result + Send + Sync,
261 W: Fn(&C::Result) -> Sc + Send + Sync,
262 Sc: Score + 'static,
263{
264 pub fn as_constraint(
266 self,
267 name: &str,
268 ) -> ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> {
269 ComplementedGroupConstraint::new(
270 ConstraintRef::new("", name),
271 self.impact_type,
272 self.extractor_a,
273 self.extractor_b,
274 self.key_a,
275 self.key_b,
276 self.collector,
277 self.default_fn,
278 self.weight_fn,
279 self.is_hard,
280 )
281 }
282}
283
284impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc: Score> std::fmt::Debug
285 for ComplementedConstraintBuilder<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
286{
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 f.debug_struct("ComplementedConstraintBuilder")
289 .field("impact_type", &self.impact_type)
290 .finish()
291 }
292}