1use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collection_extract::CollectionExtract;
14use super::filter::UniFilter;
15use super::weighting_support::fixed_weight_is_hard;
16use crate::constraint::balance::BalanceConstraint;
17
18pub struct BalanceConstraintStream<S, A, K, E, F, KF, Sc>
66where
67 Sc: Score,
68{
69 extractor: E,
70 filter: F,
71 key_fn: KF,
72 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
73}
74
75impl<S, A, K, E, F, KF, Sc> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
76where
77 S: Send + Sync + 'static,
78 A: Clone + Send + Sync + 'static,
79 K: Clone + Eq + Hash + Send + Sync + 'static,
80 E: CollectionExtract<S, Item = A>,
81 F: UniFilter<S, A>,
82 KF: Fn(&A) -> Option<K> + Send + Sync,
83 Sc: Score + 'static,
84{
85 fn into_builder(
86 self,
87 impact_type: ImpactType,
88 base_score: Sc,
89 ) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
90 BalanceConstraintBuilder {
91 extractor: self.extractor,
92 filter: self.filter,
93 key_fn: self.key_fn,
94 impact_type,
95 base_score,
96 is_hard: fixed_weight_is_hard(base_score),
97 _phantom: PhantomData,
98 }
99 }
100
101 pub(crate) fn new(extractor: E, filter: F, key_fn: KF) -> Self {
103 Self {
104 extractor,
105 filter,
106 key_fn,
107 _phantom: PhantomData,
108 }
109 }
110
111 pub fn penalize(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
116 self.into_builder(ImpactType::Penalty, base_score)
117 }
118
119 pub fn penalize_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
121 where
122 Sc: Copy,
123 {
124 self.penalize(Sc::one_hard())
125 }
126
127 pub fn penalize_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
129 where
130 Sc: Copy,
131 {
132 self.penalize(Sc::one_soft())
133 }
134
135 pub fn reward_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
137 where
138 Sc: Copy,
139 {
140 self.reward(Sc::one_hard())
141 }
142
143 pub fn reward_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
145 where
146 Sc: Copy,
147 {
148 self.reward(Sc::one_soft())
149 }
150
151 pub fn reward(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
156 self.into_builder(ImpactType::Reward, base_score)
157 }
158}
159
160impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
161 for BalanceConstraintStream<S, A, K, E, F, KF, Sc>
162{
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.debug_struct("BalanceConstraintStream").finish()
165 }
166}
167
168pub struct BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
170where
171 Sc: Score,
172{
173 extractor: E,
174 filter: F,
175 key_fn: KF,
176 impact_type: ImpactType,
177 base_score: Sc,
178 is_hard: bool,
179 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K)>,
180}
181
182impl<S, A, K, E, F, KF, Sc> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
183where
184 S: Send + Sync + 'static,
185 A: Clone + Send + Sync + 'static,
186 K: Clone + Eq + Hash + Send + Sync + 'static,
187 E: CollectionExtract<S, Item = A>,
188 F: UniFilter<S, A>,
189 KF: Fn(&A) -> Option<K> + Send + Sync,
190 Sc: Score + 'static,
191{
192 pub fn named(self, name: &str) -> BalanceConstraint<S, A, K, E, F, KF, Sc> {
193 BalanceConstraint::new(
194 ConstraintRef::new("", name),
195 self.impact_type,
196 self.extractor,
197 self.filter,
198 self.key_fn,
199 self.base_score,
200 self.is_hard,
201 )
202 }
203
204 }
206
207impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
208 for BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
209{
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 f.debug_struct("BalanceConstraintBuilder")
212 .field("impact_type", &self.impact_type)
213 .finish()
214 }
215}