solverforge_scoring/stream/
balance_stream.rs1use std::hash::Hash;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use super::filter::UniFilter;
13use crate::constraint::balance::BalanceConstraint;
14
15pub struct BalanceConstraintStream<S, A, K, E, F, KF, Sc>
62where
63 Sc: Score,
64{
65 extractor: E,
66 filter: F,
67 key_fn: KF,
68 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
69}
70
71impl<S, A, K, E, F, KF, Sc> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
72where
73 S: Send + Sync + 'static,
74 A: Clone + Send + Sync + 'static,
75 K: Clone + Eq + Hash + Send + Sync + 'static,
76 E: Fn(&S) -> &[A] + Send + Sync,
77 F: UniFilter<S, A>,
78 KF: Fn(&A) -> Option<K> + Send + Sync,
79 Sc: Score + 'static,
80{
81 pub(crate) fn new(extractor: E, filter: F, key_fn: KF) -> Self {
83 Self {
84 extractor,
85 filter,
86 key_fn,
87 _phantom: PhantomData,
88 }
89 }
90
91 pub fn penalize(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
95 let is_hard = base_score
96 .to_level_numbers()
97 .first()
98 .map(|&h| h != 0)
99 .unwrap_or(false);
100 BalanceConstraintBuilder {
101 extractor: self.extractor,
102 filter: self.filter,
103 key_fn: self.key_fn,
104 impact_type: ImpactType::Penalty,
105 base_score,
106 is_hard,
107 _phantom: PhantomData,
108 }
109 }
110
111 pub fn reward(self, base_score: Sc) -> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc> {
115 let is_hard = base_score
116 .to_level_numbers()
117 .first()
118 .map(|&h| h != 0)
119 .unwrap_or(false);
120 BalanceConstraintBuilder {
121 extractor: self.extractor,
122 filter: self.filter,
123 key_fn: self.key_fn,
124 impact_type: ImpactType::Reward,
125 base_score,
126 is_hard,
127 _phantom: PhantomData,
128 }
129 }
130}
131
132impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
133 for BalanceConstraintStream<S, A, K, E, F, KF, Sc>
134{
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 f.debug_struct("BalanceConstraintStream").finish()
137 }
138}
139
140pub struct BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
142where
143 Sc: Score,
144{
145 extractor: E,
146 filter: F,
147 key_fn: KF,
148 impact_type: ImpactType,
149 base_score: Sc,
150 is_hard: bool,
151 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K)>,
152}
153
154impl<S, A, K, E, F, KF, Sc> BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
155where
156 S: Send + Sync + 'static,
157 A: Clone + Send + Sync + 'static,
158 K: Clone + Eq + Hash + Send + Sync + 'static,
159 E: Fn(&S) -> &[A] + Send + Sync,
160 F: UniFilter<S, A>,
161 KF: Fn(&A) -> Option<K> + Send + Sync,
162 Sc: Score + 'static,
163{
164 pub fn as_constraint(self, name: &str) -> BalanceConstraint<S, A, K, E, F, KF, Sc> {
166 BalanceConstraint::new(
167 ConstraintRef::new("", name),
168 self.impact_type,
169 self.extractor,
170 self.filter,
171 self.key_fn,
172 self.base_score,
173 self.is_hard,
174 )
175 }
176}
177
178impl<S, A, K, E, F, KF, Sc: Score> std::fmt::Debug
179 for BalanceConstraintBuilder<S, A, K, E, F, KF, Sc>
180{
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 f.debug_struct("BalanceConstraintBuilder")
183 .field("impact_type", &self.impact_type)
184 .finish()
185 }
186}