Skip to main content

solverforge_scoring/constraint/projected/
grouped.rs

1use std::collections::{hash_map::Entry, HashMap, HashSet};
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collector::{Accumulator, Collector};
10use crate::stream::filter::UniFilter;
11use crate::stream::{ProjectedRowCoordinate, ProjectedRowOwner, ProjectedSource};
12
/// Per-key bookkeeping for one group: the collector's accumulator together
/// with the number of rows currently folded into it.
struct GroupState<Acc> {
    /// Accumulator that receives every value inserted under this group's key.
    accumulator: Acc,
    /// Number of rows currently contributing to `accumulator`; the group is
    /// removed from the group map once a retraction brings this to zero.
    count: usize,
}
17
/// Shorthand for the accumulator's associated `Retraction` type: the value
/// that `accumulate` hands back and that `retract` later consumes.
type CollectorRetraction<Acc, V, R> = <Acc as Accumulator<V, R>>::Retraction;
19
/// Incremental constraint that groups the filtered rows of a projected source
/// by key, folds each group through a collector, and scores every group with
/// a weight function.
///
/// Type parameters, as used by the impls in this file: `S` is the solution,
/// `Out` the row type produced by `Src`, `K` the group key, `V` / `R` the
/// collector's value and result types, `Acc` its accumulator, and `Sc` the
/// score type.
pub struct ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, V, R, Acc, W, Sc>
where
    Src: ProjectedSource<S, Out>,
    Acc: Accumulator<V, R>,
    Sc: Score,
{
    /// Identity of the constraint; its `name` is what `name()` reports.
    constraint_ref: ConstraintRef,
    /// Decides the sign of each group's weight: negated for `Penalty`,
    /// unchanged for `Reward`.
    impact_type: ImpactType,
    /// Producer of the projected rows and of the per-solution source state.
    source: Src,
    /// Rows that fail this filter are never grouped.
    filter: F,
    /// Maps a row to the key of the group it belongs to.
    key_fn: KF,
    /// Extracts the value from each row and creates per-group accumulators.
    collector: C,
    /// Maps a group's key and collector result to an unsigned weight.
    weight_fn: W,
    /// Value returned verbatim by `is_hard()`.
    is_hard: bool,
    /// Cached source state; `None` until it is built by `initialize` or
    /// `on_insert`, and again after `reset`.
    source_state: Option<Src::State>,
    /// Live groups, keyed by the result of `key_fn`.
    groups: HashMap<K, GroupState<Acc>>,
    /// Rows currently counted by the constraint, by coordinate. Kept so the
    /// group key can be recomputed when the row is retracted.
    row_outputs: HashMap<ProjectedRowCoordinate, Out>,
    /// Retraction handed back by the accumulator for each counted row.
    row_retractions: HashMap<ProjectedRowCoordinate, CollectorRetraction<Acc, V, R>>,
    /// Reverse index from an owner to the coordinates of its counted rows.
    rows_by_owner: HashMap<ProjectedRowOwner, Vec<ProjectedRowCoordinate>>,
    /// Ties the type parameters to the struct without storing values of
    /// those types.
    _phantom: PhantomData<(
        fn() -> S,
        fn() -> Out,
        fn() -> V,
        fn() -> R,
        fn() -> Acc,
        fn() -> Sc,
    )>,
}
48
49impl<S, Out, K, Src, F, KF, C, V, R, Acc, W, Sc>
50    ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, V, R, Acc, W, Sc>
51where
52    S: Send + Sync + 'static,
53    Out: Send + Sync + 'static,
54    K: Eq + Hash + Send + Sync + 'static,
55    Src: ProjectedSource<S, Out>,
56    F: UniFilter<S, Out>,
57    KF: Fn(&Out) -> K + Send + Sync,
58    C: for<'i> Collector<&'i Out, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
59    V: Send + Sync,
60    R: Send + Sync,
61    Acc: Accumulator<V, R> + Send + Sync,
62    W: Fn(&K, &R) -> Sc + Send + Sync,
63    Sc: Score + 'static,
64{
65    #[allow(clippy::too_many_arguments)]
66    pub fn new(
67        constraint_ref: ConstraintRef,
68        impact_type: ImpactType,
69        source: Src,
70        filter: F,
71        key_fn: KF,
72        collector: C,
73        weight_fn: W,
74        is_hard: bool,
75    ) -> Self {
76        Self {
77            constraint_ref,
78            impact_type,
79            source,
80            filter,
81            key_fn,
82            collector,
83            weight_fn,
84            is_hard,
85            source_state: None,
86            groups: HashMap::new(),
87            row_outputs: HashMap::new(),
88            row_retractions: HashMap::new(),
89            rows_by_owner: HashMap::new(),
90            _phantom: PhantomData,
91        }
92    }
93
94    fn compute_score(&self, key: &K, result: &R) -> Sc {
95        let base = (self.weight_fn)(key, result);
96        match self.impact_type {
97            ImpactType::Penalty => -base,
98            ImpactType::Reward => base,
99        }
100    }
101
102    fn ensure_source_state(&mut self, solution: &S) {
103        if self.source_state.is_none() {
104            self.source_state = Some(self.source.build_state(solution));
105        }
106    }
107
108    fn index_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
109        coordinate.for_each_owner(|owner| {
110            self.rows_by_owner
111                .entry(owner)
112                .or_default()
113                .push(coordinate);
114        });
115    }
116
117    fn unindex_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
118        coordinate.for_each_owner(|owner| {
119            let mut remove_bucket = false;
120            if let Some(rows) = self.rows_by_owner.get_mut(&owner) {
121                rows.retain(|candidate| *candidate != coordinate);
122                remove_bucket = rows.is_empty();
123            }
124            if remove_bucket {
125                self.rows_by_owner.remove(&owner);
126            }
127        });
128    }
129
130    fn insert_value(&mut self, key: K, value: V) -> (Sc, CollectorRetraction<Acc, V, R>) {
131        let impact = self.impact_type;
132        let weight_fn = &self.weight_fn;
133        match self.groups.entry(key) {
134            Entry::Occupied(mut entry) => {
135                let old_base = entry
136                    .get()
137                    .accumulator
138                    .with_result(|result| weight_fn(entry.key(), result));
139                let old = match impact {
140                    ImpactType::Penalty => -old_base,
141                    ImpactType::Reward => old_base,
142                };
143                let group = entry.get_mut();
144                let retraction = group.accumulator.accumulate(value);
145                group.count += 1;
146                let new_base = entry
147                    .get()
148                    .accumulator
149                    .with_result(|result| weight_fn(entry.key(), result));
150                let new_score = match impact {
151                    ImpactType::Penalty => -new_base,
152                    ImpactType::Reward => new_base,
153                };
154                (new_score - old, retraction)
155            }
156            Entry::Vacant(entry) => {
157                let mut entry = entry.insert_entry(GroupState {
158                    accumulator: self.collector.create_accumulator(),
159                    count: 0,
160                });
161                let group = entry.get_mut();
162                let retraction = group.accumulator.accumulate(value);
163                group.count += 1;
164                let new_base = entry
165                    .get()
166                    .accumulator
167                    .with_result(|result| weight_fn(entry.key(), result));
168                let score = match impact {
169                    ImpactType::Penalty => -new_base,
170                    ImpactType::Reward => new_base,
171                };
172                (score, retraction)
173            }
174        }
175    }
176
177    fn retract_value(&mut self, key: K, retraction: CollectorRetraction<Acc, V, R>) -> Sc {
178        let impact = self.impact_type;
179        let weight_fn = &self.weight_fn;
180        let Entry::Occupied(mut entry) = self.groups.entry(key) else {
181            return Sc::zero();
182        };
183        let old_base = entry
184            .get()
185            .accumulator
186            .with_result(|result| weight_fn(entry.key(), result));
187        let old = match impact {
188            ImpactType::Penalty => -old_base,
189            ImpactType::Reward => old_base,
190        };
191        let group = entry.get_mut();
192        group.accumulator.retract(retraction);
193        group.count = group.count.saturating_sub(1);
194        let new_score = if group.count == 0 {
195            entry.remove();
196            Sc::zero()
197        } else {
198            let new_base = entry
199                .get()
200                .accumulator
201                .with_result(|result| weight_fn(entry.key(), result));
202            match impact {
203                ImpactType::Penalty => -new_base,
204                ImpactType::Reward => new_base,
205            }
206        };
207
208        new_score - old
209    }
210
211    fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
212        if self.row_outputs.contains_key(&coordinate) || !self.filter.test(solution, &output) {
213            return Sc::zero();
214        }
215        let key = (self.key_fn)(&output);
216        let value = self.collector.extract(&output);
217        let (delta, retraction) = self.insert_value(key, value);
218        self.row_outputs.insert(coordinate, output);
219        self.row_retractions.insert(coordinate, retraction);
220        self.index_coordinate(coordinate);
221        delta
222    }
223
224    fn retract_row(&mut self, coordinate: ProjectedRowCoordinate) -> Sc {
225        let Some(output) = self.row_outputs.remove(&coordinate) else {
226            return Sc::zero();
227        };
228        self.unindex_coordinate(coordinate);
229        let key = (self.key_fn)(&output);
230        let Some(retraction) = self.row_retractions.remove(&coordinate) else {
231            return Sc::zero();
232        };
233        self.retract_value(key, retraction)
234    }
235
236    fn localized_owners(
237        &self,
238        descriptor_index: usize,
239        entity_index: usize,
240    ) -> Vec<ProjectedRowOwner> {
241        let mut owners = Vec::new();
242        for slot in 0..self.source.source_count() {
243            if self
244                .source
245                .change_source(slot)
246                .assert_localizes(descriptor_index, &self.constraint_ref.name)
247            {
248                owners.push(ProjectedRowOwner {
249                    source_slot: slot,
250                    entity_index,
251                });
252            }
253        }
254        owners
255    }
256
257    fn coordinates_for_owners(&self, owners: &[ProjectedRowOwner]) -> Vec<ProjectedRowCoordinate> {
258        let mut seen = HashSet::new();
259        let mut coordinates = Vec::new();
260        for owner in owners {
261            let Some(rows) = self.rows_by_owner.get(owner) else {
262                continue;
263            };
264            for &coordinate in rows {
265                if seen.insert(coordinate) {
266                    coordinates.push(coordinate);
267                }
268            }
269        }
270        coordinates
271    }
272}
273
274impl<S, Out, K, Src, F, KF, C, V, R, Acc, W, Sc> IncrementalConstraint<S, Sc>
275    for ProjectedGroupedConstraint<S, Out, K, Src, F, KF, C, V, R, Acc, W, Sc>
276where
277    S: Send + Sync + 'static,
278    Out: Send + Sync + 'static,
279    K: Eq + Hash + Send + Sync + 'static,
280    Src: ProjectedSource<S, Out>,
281    F: UniFilter<S, Out>,
282    KF: Fn(&Out) -> K + Send + Sync,
283    C: for<'i> Collector<&'i Out, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
284    V: Send + Sync,
285    R: Send + Sync,
286    Acc: Accumulator<V, R> + Send + Sync,
287    W: Fn(&K, &R) -> Sc + Send + Sync,
288    Sc: Score + 'static,
289{
290    fn evaluate(&self, solution: &S) -> Sc {
291        let state = self.source.build_state(solution);
292        let mut groups: HashMap<K, Acc> = HashMap::new();
293        self.source.collect_all(solution, &state, |_, output| {
294            if !self.filter.test(solution, &output) {
295                return;
296            }
297            let key = (self.key_fn)(&output);
298            let value = self.collector.extract(&output);
299            groups
300                .entry(key)
301                .or_insert_with(|| self.collector.create_accumulator())
302                .accumulate(value);
303        });
304        groups.iter().fold(Sc::zero(), |total, (key, acc)| {
305            total + acc.with_result(|result| self.compute_score(key, result))
306        })
307    }
308
309    fn match_count(&self, solution: &S) -> usize {
310        let state = self.source.build_state(solution);
311        let mut keys = HashMap::<K, ()>::new();
312        self.source.collect_all(solution, &state, |_, output| {
313            if self.filter.test(solution, &output) {
314                keys.insert((self.key_fn)(&output), ());
315            }
316        });
317        keys.len()
318    }
319
320    fn initialize(&mut self, solution: &S) -> Sc {
321        self.reset();
322        let state = self.source.build_state(solution);
323        let mut total = Sc::zero();
324        let mut rows = Vec::new();
325        self.source
326            .collect_all(solution, &state, |coordinate, output| {
327                rows.push((coordinate, output));
328            });
329        self.source_state = Some(state);
330        for (coordinate, output) in rows {
331            total = total + self.insert_row(solution, coordinate, output);
332        }
333        total
334    }
335
336    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
337        let owners = self.localized_owners(descriptor_index, entity_index);
338        self.ensure_source_state(solution);
339        {
340            let state = self.source_state.as_mut().expect("projected source state");
341            for owner in &owners {
342                self.source.insert_entity_state(
343                    solution,
344                    state,
345                    owner.source_slot,
346                    owner.entity_index,
347                );
348            }
349        }
350        let mut rows = Vec::new();
351        let state = self.source_state.as_ref().expect("projected source state");
352        for owner in &owners {
353            self.source.collect_entity(
354                solution,
355                state,
356                owner.source_slot,
357                owner.entity_index,
358                |coordinate, output| rows.push((coordinate, output)),
359            );
360        }
361        let mut total = Sc::zero();
362        for (coordinate, output) in rows {
363            total = total + self.insert_row(solution, coordinate, output);
364        }
365        total
366    }
367
368    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
369        let owners = self.localized_owners(descriptor_index, entity_index);
370        let mut total = Sc::zero();
371        for coordinate in self.coordinates_for_owners(&owners) {
372            total = total + self.retract_row(coordinate);
373        }
374        if let Some(state) = self.source_state.as_mut() {
375            for owner in &owners {
376                self.source.retract_entity_state(
377                    solution,
378                    state,
379                    owner.source_slot,
380                    owner.entity_index,
381                );
382            }
383        }
384        total
385    }
386
387    fn reset(&mut self) {
388        self.source_state = None;
389        self.groups.clear();
390        self.row_outputs.clear();
391        self.row_retractions.clear();
392        self.rows_by_owner.clear();
393    }
394
    /// Returns the name stored in this constraint's `ConstraintRef`.
    fn name(&self) -> &str {
        &self.constraint_ref.name
    }
398
    /// Returns the `ConstraintRef` this constraint was constructed with.
    fn constraint_ref(&self) -> &ConstraintRef {
        &self.constraint_ref
    }
402
    /// Returns the `is_hard` flag supplied at construction.
    fn is_hard(&self) -> bool {
        self.is_hard
    }
406}