solverforge_scoring/constraint/
balance.rs1use std::collections::HashMap;
11use std::hash::Hash;
12use std::marker::PhantomData;
13
14use solverforge_core::score::Score;
15use solverforge_core::{ConstraintRef, ImpactType};
16
17use crate::api::constraint_set::IncrementalConstraint;
18use crate::stream::collection_extract::ChangeSource;
19use crate::stream::filter::UniFilter;
20
21pub struct BalanceConstraint<S, A, K, E, F, KF, Sc>
85where
86 Sc: Score,
87{
88 constraint_ref: ConstraintRef,
89 impact_type: ImpactType,
90 extractor: E,
91 filter: F,
92 key_fn: KF,
93 change_source: ChangeSource,
94 base_score: Sc,
96 is_hard: bool,
97 counts: HashMap<K, i64>,
99 entity_keys: HashMap<usize, K>,
101 group_count: i64,
104 total_count: i64,
106 sum_squared: i64,
108 _phantom: PhantomData<(fn() -> S, fn() -> A)>,
109}
110
111impl<S, A, K, E, F, KF, Sc> BalanceConstraint<S, A, K, E, F, KF, Sc>
112where
113 S: Send + Sync + 'static,
114 A: Clone + Send + Sync + 'static,
115 K: Clone + Eq + Hash + Send + Sync + 'static,
116 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
117 F: UniFilter<S, A>,
118 KF: Fn(&A) -> Option<K> + Send + Sync,
119 Sc: Score + 'static,
120{
121 pub fn new(
134 constraint_ref: ConstraintRef,
135 impact_type: ImpactType,
136 extractor: E,
137 filter: F,
138 key_fn: KF,
139 base_score: Sc,
140 is_hard: bool,
141 ) -> Self {
142 let change_source = extractor.change_source();
143 Self {
144 constraint_ref,
145 impact_type,
146 extractor,
147 filter,
148 key_fn,
149 change_source,
150 base_score,
151 is_hard,
152 counts: HashMap::new(),
153 entity_keys: HashMap::new(),
154 group_count: 0,
155 total_count: 0,
156 sum_squared: 0,
157 _phantom: PhantomData,
158 }
159 }
160
161 fn compute_std_dev(&self) -> f64 {
163 if self.group_count == 0 {
164 return 0.0;
165 }
166 let n = self.group_count as f64;
167 let mean = self.total_count as f64 / n;
168 let variance = (self.sum_squared as f64 / n) - (mean * mean);
169 if variance <= 0.0 {
170 return 0.0;
171 }
172 variance.sqrt()
173 }
174
175 fn compute_score(&self) -> Sc {
177 let std_dev = self.compute_std_dev();
178 let base = self.base_score.multiply(std_dev);
179 match self.impact_type {
180 ImpactType::Penalty => -base,
181 ImpactType::Reward => base,
182 }
183 }
184
185 fn compute_std_dev_from_counts(counts: &HashMap<K, i64>) -> f64 {
187 if counts.is_empty() {
188 return 0.0;
189 }
190 let n = counts.len() as f64;
191 let total: i64 = counts.values().sum();
192 let sum_sq: i64 = counts.values().map(|&c| c * c).sum();
193 let mean = total as f64 / n;
194 let variance = (sum_sq as f64 / n) - (mean * mean);
195 if variance > 0.0 {
196 variance.sqrt()
197 } else {
198 0.0
199 }
200 }
201}
202
203impl<S, A, K, E, F, KF, Sc> IncrementalConstraint<S, Sc>
204 for BalanceConstraint<S, A, K, E, F, KF, Sc>
205where
206 S: Send + Sync + 'static,
207 A: Clone + Send + Sync + 'static,
208 K: Clone + Eq + Hash + Send + Sync + 'static,
209 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
210 F: UniFilter<S, A>,
211 KF: Fn(&A) -> Option<K> + Send + Sync,
212 Sc: Score + 'static,
213{
214 fn evaluate(&self, solution: &S) -> Sc {
215 let entities = self.extractor.extract(solution);
216
217 let mut counts: HashMap<K, i64> = HashMap::new();
219 for entity in entities {
220 if !self.filter.test(solution, entity) {
221 continue;
222 }
223 if let Some(key) = (self.key_fn)(entity) {
224 *counts.entry(key).or_insert(0) += 1;
225 }
226 }
227
228 if counts.is_empty() {
229 return Sc::zero();
230 }
231
232 let std_dev = Self::compute_std_dev_from_counts(&counts);
233 let base = self.base_score.multiply(std_dev);
234 match self.impact_type {
235 ImpactType::Penalty => -base,
236 ImpactType::Reward => base,
237 }
238 }
239
240 fn match_count(&self, solution: &S) -> usize {
241 let entities = self.extractor.extract(solution);
242
243 let mut counts: HashMap<K, i64> = HashMap::new();
245 for entity in entities {
246 if !self.filter.test(solution, entity) {
247 continue;
248 }
249 if let Some(key) = (self.key_fn)(entity) {
250 *counts.entry(key).or_insert(0) += 1;
251 }
252 }
253
254 if counts.is_empty() {
255 return 0;
256 }
257
258 let total: i64 = counts.values().sum();
259 let mean = total as f64 / counts.len() as f64;
260
261 counts
263 .values()
264 .filter(|&&c| (c as f64 - mean).abs() > 0.5)
265 .count()
266 }
267
268 fn initialize(&mut self, solution: &S) -> Sc {
269 self.reset();
270
271 let entities = self.extractor.extract(solution);
272
273 for (idx, entity) in entities.iter().enumerate() {
274 if !self.filter.test(solution, entity) {
275 continue;
276 }
277 if let Some(key) = (self.key_fn)(entity) {
278 let old_count = *self.counts.get(&key).unwrap_or(&0);
279 let new_count = old_count + 1;
280 self.counts.insert(key.clone(), new_count);
281 self.entity_keys.insert(idx, key);
282
283 if old_count == 0 {
284 self.group_count += 1;
285 }
286 self.total_count += 1;
287 self.sum_squared += new_count * new_count - old_count * old_count;
288 }
289 }
290
291 self.compute_score()
292 }
293
294 fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
295 if !self
296 .change_source
297 .assert_localizes(_descriptor_index, &self.constraint_ref.name)
298 {
299 return Sc::zero();
300 }
301 let entities = self.extractor.extract(solution);
302 if entity_index >= entities.len() {
303 return Sc::zero();
304 }
305
306 let entity = &entities[entity_index];
307 if !self.filter.test(solution, entity) {
308 return Sc::zero();
309 }
310
311 let Some(key) = (self.key_fn)(entity) else {
312 return Sc::zero();
313 };
314
315 let old_score = self.compute_score();
316
317 let old_count = *self.counts.get(&key).unwrap_or(&0);
318 let new_count = old_count + 1;
319 self.counts.insert(key.clone(), new_count);
320 self.entity_keys.insert(entity_index, key);
321
322 if old_count == 0 {
323 self.group_count += 1;
324 }
325 self.total_count += 1;
326 self.sum_squared += new_count * new_count - old_count * old_count;
327
328 let new_score = self.compute_score();
329 new_score - old_score
330 }
331
332 fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
333 if !self
334 .change_source
335 .assert_localizes(_descriptor_index, &self.constraint_ref.name)
336 {
337 return Sc::zero();
338 }
339 let entities = self.extractor.extract(solution);
340 if entity_index >= entities.len() {
341 return Sc::zero();
342 }
343
344 let Some(key) = self.entity_keys.remove(&entity_index) else {
346 return Sc::zero();
347 };
348
349 let old_score = self.compute_score();
350
351 let old_count = *self.counts.get(&key).unwrap_or(&0);
352 if old_count == 0 {
353 return Sc::zero();
354 }
355
356 let new_count = old_count - 1;
357 if new_count == 0 {
358 self.counts.remove(&key);
359 self.group_count -= 1;
360 } else {
361 self.counts.insert(key, new_count);
362 }
363 self.total_count -= 1;
364 self.sum_squared += new_count * new_count - old_count * old_count;
365
366 let new_score = self.compute_score();
367 new_score - old_score
368 }
369
370 fn reset(&mut self) {
371 self.counts.clear();
372 self.entity_keys.clear();
373 self.group_count = 0;
374 self.total_count = 0;
375 self.sum_squared = 0;
376 }
377
378 fn name(&self) -> &str {
379 &self.constraint_ref.name
380 }
381
382 fn is_hard(&self) -> bool {
383 self.is_hard
384 }
385
386 fn constraint_ref(&self) -> &ConstraintRef {
387 &self.constraint_ref
388 }
389}
390
391impl<S, A, K, E, F, KF, Sc> std::fmt::Debug for BalanceConstraint<S, A, K, E, F, KF, Sc>
392where
393 Sc: Score,
394{
395 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396 f.debug_struct("BalanceConstraint")
397 .field("name", &self.constraint_ref.name)
398 .field("impact_type", &self.impact_type)
399 .field("groups", &self.counts.len())
400 .finish()
401 }
402}