Skip to main content

solverforge_scoring/constraint/projected/
grouped.rs

1use std::collections::{hash_map::Entry, HashMap, HashSet};
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collector::{Accumulator, UniCollector};
10use crate::stream::filter::UniFilter;
11use crate::stream::{ProjectedRowCoordinate, ProjectedRowOwner, ProjectedSource};
12
13struct GroupState<Acc> {
14    accumulator: Acc,
15    count: usize,
16}
17
18pub struct ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
19where
20    Src: ProjectedSource<S, Out>,
21    C: UniCollector<Out>,
22    Sc: Score,
23{
24    constraint_ref: ConstraintRef,
25    impact_type: ImpactType,
26    source: Src,
27    filter: F,
28    key_fn: KF,
29    collector: C,
30    weight_fn: W,
31    is_hard: bool,
32    source_state: Option<Src::State>,
33    groups: HashMap<K, GroupState<C::Accumulator>>,
34    row_outputs: HashMap<ProjectedRowCoordinate, Out>,
35    rows_by_owner: HashMap<ProjectedRowOwner, Vec<ProjectedRowCoordinate>>,
36    _phantom: PhantomData<(fn() -> S, fn() -> Out, fn() -> Sc)>,
37}
38
39impl<S, Out, K, Src, F, KF, C, W, Sc> ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
40where
41    S: Send + Sync + 'static,
42    Out: Send + Sync + 'static,
43    K: Eq + Hash + Send + Sync + 'static,
44    Src: ProjectedSource<S, Out>,
45    F: UniFilter<S, Out>,
46    KF: Fn(&Out) -> K + Send + Sync,
47    C: UniCollector<Out> + Send + Sync + 'static,
48    C::Accumulator: Send + Sync,
49    C::Result: Send + Sync,
50    C::Value: Send + Sync,
51    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
52    Sc: Score + 'static,
53{
54    #[allow(clippy::too_many_arguments)]
55    pub fn new(
56        constraint_ref: ConstraintRef,
57        impact_type: ImpactType,
58        source: Src,
59        filter: F,
60        key_fn: KF,
61        collector: C,
62        weight_fn: W,
63        is_hard: bool,
64    ) -> Self {
65        Self {
66            constraint_ref,
67            impact_type,
68            source,
69            filter,
70            key_fn,
71            collector,
72            weight_fn,
73            is_hard,
74            source_state: None,
75            groups: HashMap::new(),
76            row_outputs: HashMap::new(),
77            rows_by_owner: HashMap::new(),
78            _phantom: PhantomData,
79        }
80    }
81
82    fn compute_score(&self, key: &K, result: &C::Result) -> Sc {
83        let base = (self.weight_fn)(key, result);
84        match self.impact_type {
85            ImpactType::Penalty => -base,
86            ImpactType::Reward => base,
87        }
88    }
89
90    fn ensure_source_state(&mut self, solution: &S) {
91        if self.source_state.is_none() {
92            self.source_state = Some(self.source.build_state(solution));
93        }
94    }
95
96    fn index_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
97        coordinate.for_each_owner(|owner| {
98            self.rows_by_owner
99                .entry(owner)
100                .or_default()
101                .push(coordinate);
102        });
103    }
104
105    fn unindex_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
106        coordinate.for_each_owner(|owner| {
107            let mut remove_bucket = false;
108            if let Some(rows) = self.rows_by_owner.get_mut(&owner) {
109                rows.retain(|candidate| *candidate != coordinate);
110                remove_bucket = rows.is_empty();
111            }
112            if remove_bucket {
113                self.rows_by_owner.remove(&owner);
114            }
115        });
116    }
117
118    fn insert_value(&mut self, key: K, value: &C::Value) -> Sc {
119        let impact = self.impact_type;
120        let weight_fn = &self.weight_fn;
121        match self.groups.entry(key) {
122            Entry::Occupied(mut entry) => {
123                let old_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
124                let old = match impact {
125                    ImpactType::Penalty => -old_base,
126                    ImpactType::Reward => old_base,
127                };
128                let group = entry.get_mut();
129                group.accumulator.accumulate(value);
130                group.count += 1;
131                let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
132                let new_score = match impact {
133                    ImpactType::Penalty => -new_base,
134                    ImpactType::Reward => new_base,
135                };
136                new_score - old
137            }
138            Entry::Vacant(entry) => {
139                let mut entry = entry.insert_entry(GroupState {
140                    accumulator: self.collector.create_accumulator(),
141                    count: 0,
142                });
143                let group = entry.get_mut();
144                group.accumulator.accumulate(value);
145                group.count += 1;
146                let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
147                match impact {
148                    ImpactType::Penalty => -new_base,
149                    ImpactType::Reward => new_base,
150                }
151            }
152        }
153    }
154
155    fn retract_value(&mut self, key: K, value: &C::Value) -> Sc {
156        let impact = self.impact_type;
157        let weight_fn = &self.weight_fn;
158        let Entry::Occupied(mut entry) = self.groups.entry(key) else {
159            return Sc::zero();
160        };
161        let old_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
162        let old = match impact {
163            ImpactType::Penalty => -old_base,
164            ImpactType::Reward => old_base,
165        };
166        let group = entry.get_mut();
167        group.accumulator.retract(value);
168        group.count = group.count.saturating_sub(1);
169        let new_score = if group.count == 0 {
170            entry.remove();
171            Sc::zero()
172        } else {
173            let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
174            match impact {
175                ImpactType::Penalty => -new_base,
176                ImpactType::Reward => new_base,
177            }
178        };
179
180        new_score - old
181    }
182
183    fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
184        if self.row_outputs.contains_key(&coordinate) || !self.filter.test(solution, &output) {
185            return Sc::zero();
186        }
187        let key = (self.key_fn)(&output);
188        let value = self.collector.extract(&output);
189        let delta = self.insert_value(key, &value);
190        self.row_outputs.insert(coordinate, output);
191        self.index_coordinate(coordinate);
192        delta
193    }
194
195    fn retract_row(&mut self, coordinate: ProjectedRowCoordinate) -> Sc {
196        let Some(output) = self.row_outputs.remove(&coordinate) else {
197            return Sc::zero();
198        };
199        self.unindex_coordinate(coordinate);
200        let key = (self.key_fn)(&output);
201        let value = self.collector.extract(&output);
202        self.retract_value(key, &value)
203    }
204
205    fn localized_owners(
206        &self,
207        descriptor_index: usize,
208        entity_index: usize,
209    ) -> Vec<ProjectedRowOwner> {
210        let mut owners = Vec::new();
211        for slot in 0..self.source.source_count() {
212            if self
213                .source
214                .change_source(slot)
215                .assert_localizes(descriptor_index, &self.constraint_ref.name)
216            {
217                owners.push(ProjectedRowOwner {
218                    source_slot: slot,
219                    entity_index,
220                });
221            }
222        }
223        owners
224    }
225
226    fn coordinates_for_owners(&self, owners: &[ProjectedRowOwner]) -> Vec<ProjectedRowCoordinate> {
227        let mut seen = HashSet::new();
228        let mut coordinates = Vec::new();
229        for owner in owners {
230            let Some(rows) = self.rows_by_owner.get(owner) else {
231                continue;
232            };
233            for &coordinate in rows {
234                if seen.insert(coordinate) {
235                    coordinates.push(coordinate);
236                }
237            }
238        }
239        coordinates
240    }
241}
242
243impl<S, Out, K, Src, F, KF, C, W, Sc> IncrementalConstraint<S, Sc>
244    for ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
245where
246    S: Send + Sync + 'static,
247    Out: Send + Sync + 'static,
248    K: Eq + Hash + Send + Sync + 'static,
249    Src: ProjectedSource<S, Out>,
250    F: UniFilter<S, Out>,
251    KF: Fn(&Out) -> K + Send + Sync,
252    C: UniCollector<Out> + Send + Sync + 'static,
253    C::Accumulator: Send + Sync,
254    C::Result: Send + Sync,
255    C::Value: Send + Sync,
256    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
257    Sc: Score + 'static,
258{
259    fn evaluate(&self, solution: &S) -> Sc {
260        let state = self.source.build_state(solution);
261        let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
262        self.source.collect_all(solution, &state, |_, output| {
263            if !self.filter.test(solution, &output) {
264                return;
265            }
266            let key = (self.key_fn)(&output);
267            let value = self.collector.extract(&output);
268            groups
269                .entry(key)
270                .or_insert_with(|| self.collector.create_accumulator())
271                .accumulate(&value);
272        });
273        groups.iter().fold(Sc::zero(), |total, (key, acc)| {
274            total + self.compute_score(key, &acc.finish())
275        })
276    }
277
278    fn match_count(&self, solution: &S) -> usize {
279        let state = self.source.build_state(solution);
280        let mut keys = HashMap::<K, ()>::new();
281        self.source.collect_all(solution, &state, |_, output| {
282            if self.filter.test(solution, &output) {
283                keys.insert((self.key_fn)(&output), ());
284            }
285        });
286        keys.len()
287    }
288
289    fn initialize(&mut self, solution: &S) -> Sc {
290        self.reset();
291        let state = self.source.build_state(solution);
292        let mut total = Sc::zero();
293        let mut rows = Vec::new();
294        self.source
295            .collect_all(solution, &state, |coordinate, output| {
296                rows.push((coordinate, output));
297            });
298        self.source_state = Some(state);
299        for (coordinate, output) in rows {
300            total = total + self.insert_row(solution, coordinate, output);
301        }
302        total
303    }
304
305    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
306        let owners = self.localized_owners(descriptor_index, entity_index);
307        self.ensure_source_state(solution);
308        {
309            let state = self.source_state.as_mut().expect("projected source state");
310            for owner in &owners {
311                self.source.insert_entity_state(
312                    solution,
313                    state,
314                    owner.source_slot,
315                    owner.entity_index,
316                );
317            }
318        }
319        let mut rows = Vec::new();
320        let state = self.source_state.as_ref().expect("projected source state");
321        for owner in &owners {
322            self.source.collect_entity(
323                solution,
324                state,
325                owner.source_slot,
326                owner.entity_index,
327                |coordinate, output| rows.push((coordinate, output)),
328            );
329        }
330        let mut total = Sc::zero();
331        for (coordinate, output) in rows {
332            total = total + self.insert_row(solution, coordinate, output);
333        }
334        total
335    }
336
337    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
338        let owners = self.localized_owners(descriptor_index, entity_index);
339        let mut total = Sc::zero();
340        for coordinate in self.coordinates_for_owners(&owners) {
341            total = total + self.retract_row(coordinate);
342        }
343        if let Some(state) = self.source_state.as_mut() {
344            for owner in &owners {
345                self.source.retract_entity_state(
346                    solution,
347                    state,
348                    owner.source_slot,
349                    owner.entity_index,
350                );
351            }
352        }
353        total
354    }
355
356    fn reset(&mut self) {
357        self.source_state = None;
358        self.groups.clear();
359        self.row_outputs.clear();
360        self.rows_by_owner.clear();
361    }
362
363    fn name(&self) -> &str {
364        &self.constraint_ref.name
365    }
366
367    fn constraint_ref(&self) -> &ConstraintRef {
368        &self.constraint_ref
369    }
370
371    fn is_hard(&self) -> bool {
372        self.is_hard
373    }
374}