1use std::hash::Hash;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use super::filter::UniFilter;
13use crate::constraint::balance::BalanceConstraint;
14
15pub struct BalanceConstraintStream<S, A, K, E, F, KF, Sc>
62where
63 Sc: Score,
64{
65 extractor: E,
66 filter: F,
67 key_fn: KF,
68 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
69}
70
71impl<S, A, K, E, F, KF, Sc> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
72where
73 S: Send + Sync + 'static,
74 A: Clone + Send + Sync + 'static,
75 K: Clone + Eq + Hash + Send + Sync + 'static,
76 E: Fn(&S) -> &[A] + Send + Sync,
77 F: UniFilter<S, A>,
78 KF: Fn(&A) -> Option<K> + Send + Sync,
79 Sc: Score + 'static,
80{
81 pub(crate) fn new(extractor: E, filter: F, key_fn: KF) -> Self {
83 Self {
84 extractor,
85 filter,
86 key_fn,
87 _phantom: PhantomData,
88 }
89 }
90
91 pub fn penalize(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
95 let is_hard = base_score
96 .to_level_numbers()
97 .first()
98 .map(|&h| h != 0)
99 .unwrap_or(false);
100 BalanceConstraintBuilder {
101 extractor: self.extractor,
102 filter: self.filter,
103 key_fn: self.key_fn,
104 impact_type: ImpactType::Penalty,
105 base_score,
106 is_hard,
107 _phantom: PhantomData,
108 }
109 }
110
111 pub fn penalize_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
113 where
114 Sc: Copy,
115 {
116 self.penalize(Sc::one_hard())
117 }
118
119 pub fn penalize_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
121 where
122 Sc: Copy,
123 {
124 self.penalize(Sc::one_soft())
125 }
126
127 pub fn reward_hard(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
129 where
130 Sc: Copy,
131 {
132 self.reward(Sc::one_hard())
133 }
134
135 pub fn reward_soft(self) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
137 where
138 Sc: Copy,
139 {
140 self.reward(Sc::one_soft())
141 }
142
143 pub fn reward(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
147 let is_hard = base_score
148 .to_level_numbers()
149 .first()
150 .map(|&h| h != 0)
151 .unwrap_or(false);
152 BalanceConstraintBuilder {
153 extractor: self.extractor,
154 filter: self.filter,
155 key_fn: self.key_fn,
156 impact_type: ImpactType::Reward,
157 base_score,
158 is_hard,
159 _phantom: PhantomData,
160 }
161 }
162}
163
164impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
165 for BalanceConstraintStream<S, A, K, E, F, KF, Sc>
166{
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 f.debug_struct("BalanceConstraintStream").finish()
169 }
170}
171
172pub struct BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
174where
175 Sc: Score,
176{
177 extractor: E,
178 filter: F,
179 key_fn: KF,
180 impact_type: ImpactType,
181 base_score: Sc,
182 is_hard: bool,
183 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K)>,
184}
185
186impl<S, A, K, E, F, KF, Sc> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
187where
188 S: Send + Sync + 'static,
189 A: Clone + Send + Sync + 'static,
190 K: Clone + Eq + Hash + Send + Sync + 'static,
191 E: Fn(&S) -> &[A] + Send + Sync,
192 F: UniFilter<S, A>,
193 KF: Fn(&A) -> Option<K> + Send + Sync,
194 Sc: Score + 'static,
195{
196 pub fn named(self, name: &str) -> BalanceConstraint<S, A, K, E, F, KF, Sc> {
198 self.as_constraint(name)
199 }
200
201 pub fn as_constraint(self, name: &str) -> BalanceConstraint<S, A, K, E, F, KF, Sc> {
203 BalanceConstraint::new(
204 ConstraintRef::new("", name),
205 self.impact_type,
206 self.extractor,
207 self.filter,
208 self.key_fn,
209 self.base_score,
210 self.is_hard,
211 )
212 }
213}
214
215impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
216 for BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
217{
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 f.debug_struct("BalanceConstraintBuilder")
220 .field("impact_type", &self.impact_type)
221 .finish()
222 }
223}