Skip to main content

solverforge_scoring/constraint/projected/
grouped.rs

1use std::collections::{hash_map::Entry, HashMap, HashSet};
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collector::{Accumulator, UniCollector};
10use crate::stream::filter::UniFilter;
11use crate::stream::{ProjectedRowCoordinate, ProjectedRowOwner, ProjectedSource};
12
/// Aggregation state kept for a single group key.
///
/// `count` is the number of rows currently feeding `accumulator`; it is
/// incremented on every accumulate and decremented on every retract.
struct GroupState<Acc> {
    /// How many rows are presently contributing to this group.
    count: usize,
    /// The collector's running accumulator for this group.
    accumulator: Acc,
}
17
18type CollectorRetraction<C, Out> = <<C as UniCollector<Out>>::Accumulator as Accumulator<
19    <C as UniCollector<Out>>::Value,
20    <C as UniCollector<Out>>::Result,
21>>::Retraction;
22
23pub struct ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
24where
25    Src: ProjectedSource<S, Out>,
26    C: UniCollector<Out>,
27    Sc: Score,
28{
29    constraint_ref: ConstraintRef,
30    impact_type: ImpactType,
31    source: Src,
32    filter: F,
33    key_fn: KF,
34    collector: C,
35    weight_fn: W,
36    is_hard: bool,
37    source_state: Option<Src::State>,
38    groups: HashMap<K, GroupState<C::Accumulator>>,
39    row_outputs: HashMap<ProjectedRowCoordinate, Out>,
40    row_retractions: HashMap<ProjectedRowCoordinate, CollectorRetraction<C, Out>>,
41    rows_by_owner: HashMap<ProjectedRowOwner, Vec<ProjectedRowCoordinate>>,
42    _phantom: PhantomData<(fn() -> S, fn() -> Out, fn() -> Sc)>,
43}
44
45impl<S, Out, K, Src, F, KF, C, W, Sc> ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
46where
47    S: Send + Sync + 'static,
48    Out: Send + Sync + 'static,
49    K: Eq + Hash + Send + Sync + 'static,
50    Src: ProjectedSource<S, Out>,
51    F: UniFilter<S, Out>,
52    KF: Fn(&Out) -> K + Send + Sync,
53    C: UniCollector<Out> + Send + Sync + 'static,
54    C::Accumulator: Send + Sync,
55    C::Result: Send + Sync,
56    C::Value: Send + Sync,
57    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
58    Sc: Score + 'static,
59{
60    #[allow(clippy::too_many_arguments)]
61    pub fn new(
62        constraint_ref: ConstraintRef,
63        impact_type: ImpactType,
64        source: Src,
65        filter: F,
66        key_fn: KF,
67        collector: C,
68        weight_fn: W,
69        is_hard: bool,
70    ) -> Self {
71        Self {
72            constraint_ref,
73            impact_type,
74            source,
75            filter,
76            key_fn,
77            collector,
78            weight_fn,
79            is_hard,
80            source_state: None,
81            groups: HashMap::new(),
82            row_outputs: HashMap::new(),
83            row_retractions: HashMap::new(),
84            rows_by_owner: HashMap::new(),
85            _phantom: PhantomData,
86        }
87    }
88
89    fn compute_score(&self, key: &K, result: &C::Result) -> Sc {
90        let base = (self.weight_fn)(key, result);
91        match self.impact_type {
92            ImpactType::Penalty => -base,
93            ImpactType::Reward => base,
94        }
95    }
96
97    fn ensure_source_state(&mut self, solution: &S) {
98        if self.source_state.is_none() {
99            self.source_state = Some(self.source.build_state(solution));
100        }
101    }
102
103    fn index_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
104        coordinate.for_each_owner(|owner| {
105            self.rows_by_owner
106                .entry(owner)
107                .or_default()
108                .push(coordinate);
109        });
110    }
111
112    fn unindex_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
113        coordinate.for_each_owner(|owner| {
114            let mut remove_bucket = false;
115            if let Some(rows) = self.rows_by_owner.get_mut(&owner) {
116                rows.retain(|candidate| *candidate != coordinate);
117                remove_bucket = rows.is_empty();
118            }
119            if remove_bucket {
120                self.rows_by_owner.remove(&owner);
121            }
122        });
123    }
124
125    fn insert_value(&mut self, key: K, value: C::Value) -> (Sc, CollectorRetraction<C, Out>) {
126        let impact = self.impact_type;
127        let weight_fn = &self.weight_fn;
128        match self.groups.entry(key) {
129            Entry::Occupied(mut entry) => {
130                let old_base = entry
131                    .get()
132                    .accumulator
133                    .with_result(|result| weight_fn(entry.key(), result));
134                let old = match impact {
135                    ImpactType::Penalty => -old_base,
136                    ImpactType::Reward => old_base,
137                };
138                let group = entry.get_mut();
139                let retraction = group.accumulator.accumulate(value);
140                group.count += 1;
141                let new_base = entry
142                    .get()
143                    .accumulator
144                    .with_result(|result| weight_fn(entry.key(), result));
145                let new_score = match impact {
146                    ImpactType::Penalty => -new_base,
147                    ImpactType::Reward => new_base,
148                };
149                (new_score - old, retraction)
150            }
151            Entry::Vacant(entry) => {
152                let mut entry = entry.insert_entry(GroupState {
153                    accumulator: self.collector.create_accumulator(),
154                    count: 0,
155                });
156                let group = entry.get_mut();
157                let retraction = group.accumulator.accumulate(value);
158                group.count += 1;
159                let new_base = entry
160                    .get()
161                    .accumulator
162                    .with_result(|result| weight_fn(entry.key(), result));
163                let score = match impact {
164                    ImpactType::Penalty => -new_base,
165                    ImpactType::Reward => new_base,
166                };
167                (score, retraction)
168            }
169        }
170    }
171
172    fn retract_value(&mut self, key: K, retraction: CollectorRetraction<C, Out>) -> Sc {
173        let impact = self.impact_type;
174        let weight_fn = &self.weight_fn;
175        let Entry::Occupied(mut entry) = self.groups.entry(key) else {
176            return Sc::zero();
177        };
178        let old_base = entry
179            .get()
180            .accumulator
181            .with_result(|result| weight_fn(entry.key(), result));
182        let old = match impact {
183            ImpactType::Penalty => -old_base,
184            ImpactType::Reward => old_base,
185        };
186        let group = entry.get_mut();
187        group.accumulator.retract(retraction);
188        group.count = group.count.saturating_sub(1);
189        let new_score = if group.count == 0 {
190            entry.remove();
191            Sc::zero()
192        } else {
193            let new_base = entry
194                .get()
195                .accumulator
196                .with_result(|result| weight_fn(entry.key(), result));
197            match impact {
198                ImpactType::Penalty => -new_base,
199                ImpactType::Reward => new_base,
200            }
201        };
202
203        new_score - old
204    }
205
206    fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
207        if self.row_outputs.contains_key(&coordinate) || !self.filter.test(solution, &output) {
208            return Sc::zero();
209        }
210        let key = (self.key_fn)(&output);
211        let value = self.collector.extract(&output);
212        let (delta, retraction) = self.insert_value(key, value);
213        self.row_outputs.insert(coordinate, output);
214        self.row_retractions.insert(coordinate, retraction);
215        self.index_coordinate(coordinate);
216        delta
217    }
218
219    fn retract_row(&mut self, coordinate: ProjectedRowCoordinate) -> Sc {
220        let Some(output) = self.row_outputs.remove(&coordinate) else {
221            return Sc::zero();
222        };
223        self.unindex_coordinate(coordinate);
224        let key = (self.key_fn)(&output);
225        let Some(retraction) = self.row_retractions.remove(&coordinate) else {
226            return Sc::zero();
227        };
228        self.retract_value(key, retraction)
229    }
230
231    fn localized_owners(
232        &self,
233        descriptor_index: usize,
234        entity_index: usize,
235    ) -> Vec<ProjectedRowOwner> {
236        let mut owners = Vec::new();
237        for slot in 0..self.source.source_count() {
238            if self
239                .source
240                .change_source(slot)
241                .assert_localizes(descriptor_index, &self.constraint_ref.name)
242            {
243                owners.push(ProjectedRowOwner {
244                    source_slot: slot,
245                    entity_index,
246                });
247            }
248        }
249        owners
250    }
251
252    fn coordinates_for_owners(&self, owners: &[ProjectedRowOwner]) -> Vec<ProjectedRowCoordinate> {
253        let mut seen = HashSet::new();
254        let mut coordinates = Vec::new();
255        for owner in owners {
256            let Some(rows) = self.rows_by_owner.get(owner) else {
257                continue;
258            };
259            for &coordinate in rows {
260                if seen.insert(coordinate) {
261                    coordinates.push(coordinate);
262                }
263            }
264        }
265        coordinates
266    }
267}
268
269impl<S, Out, K, Src, F, KF, C, W, Sc> IncrementalConstraint<S, Sc>
270    for ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, W, Sc>
271where
272    S: Send + Sync + 'static,
273    Out: Send + Sync + 'static,
274    K: Eq + Hash + Send + Sync + 'static,
275    Src: ProjectedSource<S, Out>,
276    F: UniFilter<S, Out>,
277    KF: Fn(&Out) -> K + Send + Sync,
278    C: UniCollector<Out> + Send + Sync + 'static,
279    C::Accumulator: Send + Sync,
280    C::Result: Send + Sync,
281    C::Value: Send + Sync,
282    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
283    Sc: Score + 'static,
284{
285    fn evaluate(&self, solution: &S) -> Sc {
286        let state = self.source.build_state(solution);
287        let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
288        self.source.collect_all(solution, &state, |_, output| {
289            if !self.filter.test(solution, &output) {
290                return;
291            }
292            let key = (self.key_fn)(&output);
293            let value = self.collector.extract(&output);
294            groups
295                .entry(key)
296                .or_insert_with(|| self.collector.create_accumulator())
297                .accumulate(value);
298        });
299        groups.iter().fold(Sc::zero(), |total, (key, acc)| {
300            total + acc.with_result(|result| self.compute_score(key, result))
301        })
302    }
303
304    fn match_count(&self, solution: &S) -> usize {
305        let state = self.source.build_state(solution);
306        let mut keys = HashMap::<K, ()>::new();
307        self.source.collect_all(solution, &state, |_, output| {
308            if self.filter.test(solution, &output) {
309                keys.insert((self.key_fn)(&output), ());
310            }
311        });
312        keys.len()
313    }
314
315    fn initialize(&mut self, solution: &S) -> Sc {
316        self.reset();
317        let state = self.source.build_state(solution);
318        let mut total = Sc::zero();
319        let mut rows = Vec::new();
320        self.source
321            .collect_all(solution, &state, |coordinate, output| {
322                rows.push((coordinate, output));
323            });
324        self.source_state = Some(state);
325        for (coordinate, output) in rows {
326            total = total + self.insert_row(solution, coordinate, output);
327        }
328        total
329    }
330
331    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
332        let owners = self.localized_owners(descriptor_index, entity_index);
333        self.ensure_source_state(solution);
334        {
335            let state = self.source_state.as_mut().expect("projected source state");
336            for owner in &owners {
337                self.source.insert_entity_state(
338                    solution,
339                    state,
340                    owner.source_slot,
341                    owner.entity_index,
342                );
343            }
344        }
345        let mut rows = Vec::new();
346        let state = self.source_state.as_ref().expect("projected source state");
347        for owner in &owners {
348            self.source.collect_entity(
349                solution,
350                state,
351                owner.source_slot,
352                owner.entity_index,
353                |coordinate, output| rows.push((coordinate, output)),
354            );
355        }
356        let mut total = Sc::zero();
357        for (coordinate, output) in rows {
358            total = total + self.insert_row(solution, coordinate, output);
359        }
360        total
361    }
362
363    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
364        let owners = self.localized_owners(descriptor_index, entity_index);
365        let mut total = Sc::zero();
366        for coordinate in self.coordinates_for_owners(&owners) {
367            total = total + self.retract_row(coordinate);
368        }
369        if let Some(state) = self.source_state.as_mut() {
370            for owner in &owners {
371                self.source.retract_entity_state(
372                    solution,
373                    state,
374                    owner.source_slot,
375                    owner.entity_index,
376                );
377            }
378        }
379        total
380    }
381
382    fn reset(&mut self) {
383        self.source_state = None;
384        self.groups.clear();
385        self.row_outputs.clear();
386        self.row_retractions.clear();
387        self.rows_by_owner.clear();
388    }
389
390    fn name(&self) -> &str {
391        &self.constraint_ref.name
392    }
393
394    fn constraint_ref(&self) -> &ConstraintRef {
395        &self.constraint_ref
396    }
397
398    fn is_hard(&self) -> bool {
399        self.is_hard
400    }
401}