solverforge_scoring/constraint/
balance.rs1use std::collections::HashMap;
11use std::hash::Hash;
12use std::marker::PhantomData;
13
14use solverforge_core::score::Score;
15use solverforge_core::{ConstraintRef, ImpactType};
16
17use crate::api::constraint_set::IncrementalConstraint;
18use crate::stream::filter::UniFilter;
19
20pub struct BalanceConstraint<S, A, K, E, F, KF, Sc>
84where
85 Sc: Score,
86{
87 constraint_ref: ConstraintRef,
88 impact_type: ImpactType,
89 extractor: E,
90 filter: F,
91 key_fn: KF,
92 base_score: Sc,
94 is_hard: bool,
95 counts: HashMap<K, i64>,
97 entity_keys: HashMap<usize, K>,
99 group_count: i64,
102 total_count: i64,
104 sum_squared: i64,
106 _phantom: PhantomData<(fn() -> S, fn() -> A)>,
107}
108
109impl<S, A, K, E, F, KF, Sc> BalanceConstraint<S, A, K, E, F, KF, Sc>
110where
111 S: Send + Sync + 'static,
112 A: Clone + Send + Sync + 'static,
113 K: Clone + Eq + Hash + Send + Sync + 'static,
114 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
115 F: UniFilter<S, A>,
116 KF: Fn(&A) -> Option<K> + Send + Sync,
117 Sc: Score + 'static,
118{
119 pub fn new(
132 constraint_ref: ConstraintRef,
133 impact_type: ImpactType,
134 extractor: E,
135 filter: F,
136 key_fn: KF,
137 base_score: Sc,
138 is_hard: bool,
139 ) -> Self {
140 Self {
141 constraint_ref,
142 impact_type,
143 extractor,
144 filter,
145 key_fn,
146 base_score,
147 is_hard,
148 counts: HashMap::new(),
149 entity_keys: HashMap::new(),
150 group_count: 0,
151 total_count: 0,
152 sum_squared: 0,
153 _phantom: PhantomData,
154 }
155 }
156
157 fn compute_std_dev(&self) -> f64 {
159 if self.group_count == 0 {
160 return 0.0;
161 }
162 let n = self.group_count as f64;
163 let mean = self.total_count as f64 / n;
164 let variance = (self.sum_squared as f64 / n) - (mean * mean);
165 if variance <= 0.0 {
166 return 0.0;
167 }
168 variance.sqrt()
169 }
170
171 fn compute_score(&self) -> Sc {
173 let std_dev = self.compute_std_dev();
174 let base = self.base_score.multiply(std_dev);
175 match self.impact_type {
176 ImpactType::Penalty => -base,
177 ImpactType::Reward => base,
178 }
179 }
180
181 fn compute_std_dev_from_counts(counts: &HashMap<K, i64>) -> f64 {
183 if counts.is_empty() {
184 return 0.0;
185 }
186 let n = counts.len() as f64;
187 let total: i64 = counts.values().sum();
188 let sum_sq: i64 = counts.values().map(|&c| c * c).sum();
189 let mean = total as f64 / n;
190 let variance = (sum_sq as f64 / n) - (mean * mean);
191 if variance > 0.0 {
192 variance.sqrt()
193 } else {
194 0.0
195 }
196 }
197}
198
199impl<S, A, K, E, F, KF, Sc> IncrementalConstraint<S, Sc>
200 for BalanceConstraint<S, A, K, E, F, KF, Sc>
201where
202 S: Send + Sync + 'static,
203 A: Clone + Send + Sync + 'static,
204 K: Clone + Eq + Hash + Send + Sync + 'static,
205 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
206 F: UniFilter<S, A>,
207 KF: Fn(&A) -> Option<K> + Send + Sync,
208 Sc: Score + 'static,
209{
210 fn evaluate(&self, solution: &S) -> Sc {
211 let entities = self.extractor.extract(solution);
212
213 let mut counts: HashMap<K, i64> = HashMap::new();
215 for entity in entities {
216 if !self.filter.test(solution, entity) {
217 continue;
218 }
219 if let Some(key) = (self.key_fn)(entity) {
220 *counts.entry(key).or_insert(0) += 1;
221 }
222 }
223
224 if counts.is_empty() {
225 return Sc::zero();
226 }
227
228 let std_dev = Self::compute_std_dev_from_counts(&counts);
229 let base = self.base_score.multiply(std_dev);
230 match self.impact_type {
231 ImpactType::Penalty => -base,
232 ImpactType::Reward => base,
233 }
234 }
235
236 fn match_count(&self, solution: &S) -> usize {
237 let entities = self.extractor.extract(solution);
238
239 let mut counts: HashMap<K, i64> = HashMap::new();
241 for entity in entities {
242 if !self.filter.test(solution, entity) {
243 continue;
244 }
245 if let Some(key) = (self.key_fn)(entity) {
246 *counts.entry(key).or_insert(0) += 1;
247 }
248 }
249
250 if counts.is_empty() {
251 return 0;
252 }
253
254 let total: i64 = counts.values().sum();
255 let mean = total as f64 / counts.len() as f64;
256
257 counts
259 .values()
260 .filter(|&&c| (c as f64 - mean).abs() > 0.5)
261 .count()
262 }
263
264 fn initialize(&mut self, solution: &S) -> Sc {
265 self.reset();
266
267 let entities = self.extractor.extract(solution);
268
269 for (idx, entity) in entities.iter().enumerate() {
270 if !self.filter.test(solution, entity) {
271 continue;
272 }
273 if let Some(key) = (self.key_fn)(entity) {
274 let old_count = *self.counts.get(&key).unwrap_or(&0);
275 let new_count = old_count + 1;
276 self.counts.insert(key.clone(), new_count);
277 self.entity_keys.insert(idx, key);
278
279 if old_count == 0 {
280 self.group_count += 1;
281 }
282 self.total_count += 1;
283 self.sum_squared += new_count * new_count - old_count * old_count;
284 }
285 }
286
287 self.compute_score()
288 }
289
290 fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
291 let entities = self.extractor.extract(solution);
292 if entity_index >= entities.len() {
293 return Sc::zero();
294 }
295
296 let entity = &entities[entity_index];
297 if !self.filter.test(solution, entity) {
298 return Sc::zero();
299 }
300
301 let Some(key) = (self.key_fn)(entity) else {
302 return Sc::zero();
303 };
304
305 let old_score = self.compute_score();
306
307 let old_count = *self.counts.get(&key).unwrap_or(&0);
308 let new_count = old_count + 1;
309 self.counts.insert(key.clone(), new_count);
310 self.entity_keys.insert(entity_index, key);
311
312 if old_count == 0 {
313 self.group_count += 1;
314 }
315 self.total_count += 1;
316 self.sum_squared += new_count * new_count - old_count * old_count;
317
318 let new_score = self.compute_score();
319 new_score - old_score
320 }
321
322 fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
323 let entities = self.extractor.extract(solution);
324 if entity_index >= entities.len() {
325 return Sc::zero();
326 }
327
328 let Some(key) = self.entity_keys.remove(&entity_index) else {
330 return Sc::zero();
331 };
332
333 let old_score = self.compute_score();
334
335 let old_count = *self.counts.get(&key).unwrap_or(&0);
336 if old_count == 0 {
337 return Sc::zero();
338 }
339
340 let new_count = old_count - 1;
341 if new_count == 0 {
342 self.counts.remove(&key);
343 self.group_count -= 1;
344 } else {
345 self.counts.insert(key, new_count);
346 }
347 self.total_count -= 1;
348 self.sum_squared += new_count * new_count - old_count * old_count;
349
350 let new_score = self.compute_score();
351 new_score - old_score
352 }
353
354 fn reset(&mut self) {
355 self.counts.clear();
356 self.entity_keys.clear();
357 self.group_count = 0;
358 self.total_count = 0;
359 self.sum_squared = 0;
360 }
361
362 fn name(&self) -> &str {
363 &self.constraint_ref.name
364 }
365
366 fn is_hard(&self) -> bool {
367 self.is_hard
368 }
369
370 fn constraint_ref(&self) -> ConstraintRef {
371 self.constraint_ref.clone()
372 }
373}
374
375impl<S, A, K, E, F, KF, Sc> std::fmt::Debug for BalanceConstraint<S, A, K, E, F, KF, Sc>
376where
377 Sc: Score,
378{
379 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380 f.debug_struct("BalanceConstraint")
381 .field("name", &self.constraint_ref.name)
382 .field("impact_type", &self.impact_type)
383 .field("groups", &self.counts.len())
384 .finish()
385 }
386}