1use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collection_extract::CollectionExtract;
14use super::filter::UniFilter;
15use crate::constraint::balance::BalanceConstraint;
16
/// Intermediate stream for a balance constraint, before an impact (penalty or
/// reward) has been chosen.
///
/// It bundles three pieces:
/// - `extractor` pulls items of type `A` out of a solution `S` (a
///   [`CollectionExtract`] with `Item = A`),
/// - `filter` is a [`UniFilter`] over `S` and `A` applied to those items,
/// - `key_fn` maps each item to an optional key `K`.
///
/// `penalize`/`reward` (or their `_hard`/`_soft` shorthands) turn this into a
/// [`BalanceConstraintBuilder`], whose `named` produces the finished
/// [`BalanceConstraint`]. How items are actually scored, and what a `None` key
/// means, is decided by `BalanceConstraint`, which is not visible here.
/// NOTE(review): presumably it measures how evenly items spread across keys;
/// confirm against `crate::constraint::balance`.
pub struct BalanceConstraintStream<S, A, K, E, F, KF, Sc>
where
    Sc: Score,
{
    /// Source of the items the constraint inspects.
    extractor: E,
    /// Filter applied to the extracted items.
    filter: F,
    /// Maps an item to its (optional) grouping key.
    key_fn: KF,
    // The stream never stores values of S/A/K/Sc. `fn() -> T` markers record the
    // type parameters without making the struct's Send/Sync depend on them.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
}
73
74impl<S, A, K, E, F, KF, Sc> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
75where
76 S: Send + Sync + 'static,
77 A: Clone + Send + Sync + 'static,
78 K: Clone + Eq + Hash + Send + Sync + 'static,
79 E: CollectionExtract<S, Item = A>,
80 F: UniFilter<S, A>,
81 KF: Fn(&A) -> Option<K> + Send + Sync,
82 Sc: Score + 'static,
83{
84 pub(crate) fn new(extractor: E, filter: F, key_fn: KF) -> Self {
86 Self {
87 extractor,
88 filter,
89 key_fn,
90 _phantom: PhantomData,
91 }
92 }
93
94 pub fn penalize(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
99 let is_hard = base_score
100 .to_level_numbers()
101 .first()
102 .map(|&h| h != 0)
103 .unwrap_or(false);
104 BalanceConstraintBuilder {
105 extractor: self.extractor,
106 filter: self.filter,
107 key_fn: self.key_fn,
108 impact_type: ImpactType::Penalty,
109 base_score,
110 is_hard,
111 _phantom: PhantomData,
112 }
113 }
114
115 pub fn penalize_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
117 where
118 Sc: Copy,
119 {
120 self.penalize(Sc::one_hard())
121 }
122
123 pub fn penalize_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
125 where
126 Sc: Copy,
127 {
128 self.penalize(Sc::one_soft())
129 }
130
131 pub fn reward_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
133 where
134 Sc: Copy,
135 {
136 self.reward(Sc::one_hard())
137 }
138
139 pub fn reward_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
141 where
142 Sc: Copy,
143 {
144 self.reward(Sc::one_soft())
145 }
146
147 pub fn reward(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
152 let is_hard = base_score
153 .to_level_numbers()
154 .first()
155 .map(|&h| h != 0)
156 .unwrap_or(false);
157 BalanceConstraintBuilder {
158 extractor: self.extractor,
159 filter: self.filter,
160 key_fn: self.key_fn,
161 impact_type: ImpactType::Reward,
162 base_score,
163 is_hard,
164 _phantom: PhantomData,
165 }
166 }
167}
168
169impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
170 for BalanceConstraintStream<S, A, K, E, F, KF, Sc>
171{
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 f.debug_struct("BalanceConstraintStream").finish()
174 }
175}
176
/// A balance stream whose impact has been chosen; only the name is missing.
///
/// Returned by `BalanceConstraintStream::penalize`/`reward` and their
/// `_hard`/`_soft` shorthands. `named` consumes it into a
/// [`BalanceConstraint`].
pub struct BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
where
    Sc: Score,
{
    /// Source of the items the constraint inspects.
    extractor: E,
    /// Filter applied to the extracted items.
    filter: F,
    /// Maps an item to its (optional) grouping key.
    key_fn: KF,
    /// Whether the constraint penalizes or rewards.
    impact_type: ImpactType,
    /// Weight handed to `BalanceConstraint::new`.
    base_score: Sc,
    /// Set by the stream's `penalize`/`reward`. It is `true` when the first
    /// level number of `base_score` is nonzero.
    is_hard: bool,
    // `Sc` is already owned through `base_score`; only S/A/K need phantom markers.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K)>,
}
190
191impl<S, A, K, E, F, KF, Sc> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
192where
193 S: Send + Sync + 'static,
194 A: Clone + Send + Sync + 'static,
195 K: Clone + Eq + Hash + Send + Sync + 'static,
196 E: CollectionExtract<S, Item = A>,
197 F: UniFilter<S, A>,
198 KF: Fn(&A) -> Option<K> + Send + Sync,
199 Sc: Score + 'static,
200{
201 pub fn named(self, name: &str) -> BalanceConstraint<S, A, K, E, F, KF, Sc> {
202 BalanceConstraint::new(
203 ConstraintRef::new("", name),
204 self.impact_type,
205 self.extractor,
206 self.filter,
207 self.key_fn,
208 self.base_score,
209 self.is_hard,
210 )
211 }
212
213 }
215
216impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
217 for BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
218{
219 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220 f.debug_struct("BalanceConstraintBuilder")
221 .field("impact_type", &self.impact_type)
222 .finish()
223 }
224}