Skip to main content

solverforge_scoring/constraint/projected/
grouped.rs

1use std::collections::HashMap;
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collector::{Accumulator, UniCollector};
10use crate::stream::filter::UniFilter;
11use crate::stream::ProjectedSource;
12
13pub struct ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
14where
15    C: UniCollector<Out>,
16    Sc: Score,
17{
18    constraint_ref: ConstraintRef,
19    impact_type: ImpactType,
20    source: Src,
21    filter: F,
22    key_fn: KF,
23    collector: C,
24    weight_fn: W,
25    is_hard: bool,
26    groups: HashMap<K, C::Accumulator>,
27    group_counts: HashMap<K, usize>,
28    entity_values: HashMap<(usize, usize), Vec<(K, C::Value)>>,
29    _phantom: PhantomData<(fn() -> S, fn() -> Out, fn() -> Sc)>,
30}
31
32impl<S, Out, K, Src, F, KF, C, W, Sc> ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
33where
34    S: Send + Sync + 'static,
35    Out: Clone + Send + Sync + 'static,
36    K: Clone + Eq + Hash + Send + Sync + 'static,
37    Src: ProjectedSource<S, Out>,
38    F: UniFilter<S, Out>,
39    KF: Fn(&Out) -> K + Send + Sync,
40    C: UniCollector<Out> + Send + Sync + 'static,
41    C::Accumulator: Send + Sync,
42    C::Result: Send + Sync,
43    C::Value: Clone + Send + Sync,
44    W: Fn(&C::Result) -> Sc + Send + Sync,
45    Sc: Score + 'static,
46{
47    #[allow(clippy::too_many_arguments)]
48    pub fn new(
49        constraint_ref: ConstraintRef,
50        impact_type: ImpactType,
51        source: Src,
52        filter: F,
53        key_fn: KF,
54        collector: C,
55        weight_fn: W,
56        is_hard: bool,
57    ) -> Self {
58        Self {
59            constraint_ref,
60            impact_type,
61            source,
62            filter,
63            key_fn,
64            collector,
65            weight_fn,
66            is_hard,
67            groups: HashMap::new(),
68            group_counts: HashMap::new(),
69            entity_values: HashMap::new(),
70            _phantom: PhantomData,
71        }
72    }
73
74    fn compute_score(&self, result: &C::Result) -> Sc {
75        let base = (self.weight_fn)(result);
76        match self.impact_type {
77            ImpactType::Penalty => -base,
78            ImpactType::Reward => base,
79        }
80    }
81
82    fn retract_output(&mut self, key: &K, value: &C::Value) -> Sc {
83        let Some(acc) = self.groups.get_mut(key) else {
84            return Sc::zero();
85        };
86        let impact = self.impact_type;
87        let old_base = (self.weight_fn)(&acc.finish());
88        let old = match impact {
89            ImpactType::Penalty => -old_base,
90            ImpactType::Reward => old_base,
91        };
92
93        let is_empty = {
94            let count = self.group_counts.entry(key.clone()).or_insert(0);
95            *count = count.saturating_sub(1);
96            *count == 0
97        };
98        if is_empty {
99            self.group_counts.remove(key);
100        }
101
102        acc.retract(value);
103        let new_score = if is_empty {
104            self.groups.remove(key);
105            Sc::zero()
106        } else {
107            let new_base = (self.weight_fn)(&acc.finish());
108            match impact {
109                ImpactType::Penalty => -new_base,
110                ImpactType::Reward => new_base,
111            }
112        };
113
114        new_score - old
115    }
116
117    fn insert_entity_outputs(&mut self, solution: &S, slot: usize, entity_index: usize) -> Sc {
118        let mut total = Sc::zero();
119        let mut cached = Vec::new();
120        let source = &self.source;
121        let filter = &self.filter;
122        let key_fn = &self.key_fn;
123        let collector = &self.collector;
124        let weight_fn = &self.weight_fn;
125        let impact = self.impact_type;
126        let groups = &mut self.groups;
127        let group_counts = &mut self.group_counts;
128        source.collect_entity(solution, slot, entity_index, |_, output| {
129            if !filter.test(solution, &output) {
130                return;
131            }
132            let key = key_fn(&output);
133            let value = collector.extract(&output);
134            let is_new = !groups.contains_key(&key);
135            let acc = groups
136                .entry(key.clone())
137                .or_insert_with(|| collector.create_accumulator());
138            let old = if is_new {
139                Sc::zero()
140            } else {
141                let old_base = weight_fn(&acc.finish());
142                match impact {
143                    ImpactType::Penalty => -old_base,
144                    ImpactType::Reward => old_base,
145                }
146            };
147            acc.accumulate(&value);
148            let new_base = weight_fn(&acc.finish());
149            let new_score = match impact {
150                ImpactType::Penalty => -new_base,
151                ImpactType::Reward => new_base,
152            };
153            *group_counts.entry(key.clone()).or_insert(0) += 1;
154            cached.push((key, value));
155            total = total + (new_score - old);
156        });
157        self.entity_values.insert((slot, entity_index), cached);
158        total
159    }
160
161    fn retract_entity_outputs(&mut self, slot: usize, entity_index: usize) -> Sc {
162        let Some(cached) = self.entity_values.remove(&(slot, entity_index)) else {
163            return Sc::zero();
164        };
165        let mut total = Sc::zero();
166        for (key, value) in cached {
167            total = total + self.retract_output(&key, &value);
168        }
169        total
170    }
171
172    fn localized_slots(&self, descriptor_index: usize) -> Vec<usize> {
173        let mut slots = Vec::new();
174        for slot in 0..self.source.source_count() {
175            if self
176                .source
177                .change_source(slot)
178                .assert_localizes(descriptor_index, &self.constraint_ref.name)
179            {
180                slots.push(slot);
181            }
182        }
183        slots
184    }
185}
186
187impl<S, Out, K, Src, F, KF, C, W, Sc> IncrementalConstraint<S, Sc>
188    for ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
189where
190    S: Send + Sync + 'static,
191    Out: Clone + Send + Sync + 'static,
192    K: Clone + Eq + Hash + Send + Sync + 'static,
193    Src: ProjectedSource<S, Out>,
194    F: UniFilter<S, Out>,
195    KF: Fn(&Out) -> K + Send + Sync,
196    C: UniCollector<Out> + Send + Sync + 'static,
197    C::Accumulator: Send + Sync,
198    C::Result: Send + Sync,
199    C::Value: Clone + Send + Sync,
200    W: Fn(&C::Result) -> Sc + Send + Sync,
201    Sc: Score + 'static,
202{
203    fn evaluate(&self, solution: &S) -> Sc {
204        let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
205        self.source.collect_all(solution, |_, output| {
206            if !self.filter.test(solution, &output) {
207                return;
208            }
209            let key = (self.key_fn)(&output);
210            let value = self.collector.extract(&output);
211            groups
212                .entry(key)
213                .or_insert_with(|| self.collector.create_accumulator())
214                .accumulate(&value);
215        });
216        groups.values().fold(Sc::zero(), |total, acc| {
217            total + self.compute_score(&acc.finish())
218        })
219    }
220
221    fn match_count(&self, solution: &S) -> usize {
222        let mut keys = HashMap::<K, ()>::new();
223        self.source.collect_all(solution, |_, output| {
224            if self.filter.test(solution, &output) {
225                keys.insert((self.key_fn)(&output), ());
226            }
227        });
228        keys.len()
229    }
230
231    fn initialize(&mut self, solution: &S) -> Sc {
232        self.reset();
233        let mut total = Sc::zero();
234        let source = &self.source;
235        let filter = &self.filter;
236        let key_fn = &self.key_fn;
237        let collector = &self.collector;
238        let weight_fn = &self.weight_fn;
239        let impact = self.impact_type;
240        let groups = &mut self.groups;
241        let group_counts = &mut self.group_counts;
242        let entity_values = &mut self.entity_values;
243        source.collect_all(solution, |coordinate, output| {
244            if !filter.test(solution, &output) {
245                return;
246            }
247            let key = key_fn(&output);
248            let value = collector.extract(&output);
249            let is_new = !groups.contains_key(&key);
250            let acc = groups
251                .entry(key.clone())
252                .or_insert_with(|| collector.create_accumulator());
253            let old = if is_new {
254                Sc::zero()
255            } else {
256                let old_base = weight_fn(&acc.finish());
257                match impact {
258                    ImpactType::Penalty => -old_base,
259                    ImpactType::Reward => old_base,
260                }
261            };
262            acc.accumulate(&value);
263            let new_base = weight_fn(&acc.finish());
264            let new_score = match impact {
265                ImpactType::Penalty => -new_base,
266                ImpactType::Reward => new_base,
267            };
268            *group_counts.entry(key.clone()).or_insert(0) += 1;
269            entity_values
270                .entry((coordinate.source_slot, coordinate.entity_index))
271                .or_default()
272                .push((key, value));
273            total = total + (new_score - old);
274        });
275        total
276    }
277
278    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
279        let mut total = Sc::zero();
280        for slot in self.localized_slots(descriptor_index) {
281            total = total + self.insert_entity_outputs(solution, slot, entity_index);
282        }
283        total
284    }
285
286    fn on_retract(&mut self, _solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
287        let mut total = Sc::zero();
288        for slot in self.localized_slots(descriptor_index) {
289            total = total + self.retract_entity_outputs(slot, entity_index);
290        }
291        total
292    }
293
294    fn reset(&mut self) {
295        self.groups.clear();
296        self.group_counts.clear();
297        self.entity_values.clear();
298    }
299
300    fn name(&self) -> &str {
301        &self.constraint_ref.name
302    }
303
304    fn is_hard(&self) -> bool {
305        self.is_hard
306    }
307
308    fn constraint_ref(&self) -> ConstraintRef {
309        self.constraint_ref.clone()
310    }
311}