Skip to main content

solverforge_scoring/constraint/projected/
uni.rs

1use std::collections::{HashMap, HashSet};
2use std::marker::PhantomData;
3
4use solverforge_core::score::Score;
5use solverforge_core::{ConstraintRef, ImpactType};
6
7use crate::api::constraint_set::IncrementalConstraint;
8use crate::stream::filter::UniFilter;
9use crate::stream::{ProjectedRowCoordinate, ProjectedRowOwner, ProjectedSource};
10
11pub struct ProjectedUniConstraint<S, Out, Src, F, W, Sc>
12where
13    Src: ProjectedSource<S, Out>,
14    Sc: Score,
15{
16    constraint_ref: ConstraintRef,
17    impact_type: ImpactType,
18    source: Src,
19    filter: F,
20    weight: W,
21    is_hard: bool,
22    source_state: Option<Src::State>,
23    row_contributions: HashMap<ProjectedRowCoordinate, Sc>,
24    rows_by_owner: HashMap<ProjectedRowOwner, Vec<ProjectedRowCoordinate>>,
25    _phantom: PhantomData<(fn() -> S, fn() -> Out)>,
26}
27
28impl<S, Out, Src, F, W, Sc> ProjectedUniConstraint<S, Out, Src, F, W, Sc>
29where
30    S: Send + Sync + 'static,
31    Out: Send + Sync + 'static,
32    Src: ProjectedSource<S, Out>,
33    F: UniFilter<S, Out>,
34    W: Fn(&Out) -> Sc + Send + Sync,
35    Sc: Score + 'static,
36{
37    pub fn new(
38        constraint_ref: ConstraintRef,
39        impact_type: ImpactType,
40        source: Src,
41        filter: F,
42        weight: W,
43        is_hard: bool,
44    ) -> Self {
45        Self {
46            constraint_ref,
47            impact_type,
48            source,
49            filter,
50            weight,
51            is_hard,
52            source_state: None,
53            row_contributions: HashMap::new(),
54            rows_by_owner: HashMap::new(),
55            _phantom: PhantomData,
56        }
57    }
58
59    fn compute_score(&self, output: &Out) -> Sc {
60        let base = (self.weight)(output);
61        match self.impact_type {
62            ImpactType::Penalty => -base,
63            ImpactType::Reward => base,
64        }
65    }
66
67    fn ensure_source_state(&mut self, solution: &S) {
68        if self.source_state.is_none() {
69            self.source_state = Some(self.source.build_state(solution));
70        }
71    }
72
73    fn index_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
74        coordinate.for_each_owner(|owner| {
75            self.rows_by_owner
76                .entry(owner)
77                .or_default()
78                .push(coordinate);
79        });
80    }
81
82    fn unindex_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
83        coordinate.for_each_owner(|owner| {
84            let mut remove_bucket = false;
85            if let Some(rows) = self.rows_by_owner.get_mut(&owner) {
86                rows.retain(|candidate| *candidate != coordinate);
87                remove_bucket = rows.is_empty();
88            }
89            if remove_bucket {
90                self.rows_by_owner.remove(&owner);
91            }
92        });
93    }
94
95    fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
96        if self.row_contributions.contains_key(&coordinate) || !self.filter.test(solution, &output)
97        {
98            return Sc::zero();
99        }
100        let contribution = self.compute_score(&output);
101        self.row_contributions.insert(coordinate, contribution);
102        self.index_coordinate(coordinate);
103        contribution
104    }
105
106    fn retract_row(&mut self, coordinate: ProjectedRowCoordinate) -> Sc {
107        let Some(contribution) = self.row_contributions.remove(&coordinate) else {
108            return Sc::zero();
109        };
110        self.unindex_coordinate(coordinate);
111        -contribution
112    }
113
114    fn localized_owners(
115        &self,
116        descriptor_index: usize,
117        entity_index: usize,
118    ) -> Vec<ProjectedRowOwner> {
119        let mut owners = Vec::new();
120        for slot in 0..self.source.source_count() {
121            if self
122                .source
123                .change_source(slot)
124                .assert_localizes(descriptor_index, &self.constraint_ref.name)
125            {
126                owners.push(ProjectedRowOwner {
127                    source_slot: slot,
128                    entity_index,
129                });
130            }
131        }
132        owners
133    }
134
135    fn coordinates_for_owners(&self, owners: &[ProjectedRowOwner]) -> Vec<ProjectedRowCoordinate> {
136        let mut seen = HashSet::new();
137        let mut coordinates = Vec::new();
138        for owner in owners {
139            let Some(rows) = self.rows_by_owner.get(owner) else {
140                continue;
141            };
142            for &coordinate in rows {
143                if seen.insert(coordinate) {
144                    coordinates.push(coordinate);
145                }
146            }
147        }
148        coordinates
149    }
150}
151
152impl<S, Out, Src, F, W, Sc> IncrementalConstraint<S, Sc>
153    for ProjectedUniConstraint<S, Out, Src, F, W, Sc>
154where
155    S: Send + Sync + 'static,
156    Out: Send + Sync + 'static,
157    Src: ProjectedSource<S, Out>,
158    F: UniFilter<S, Out>,
159    W: Fn(&Out) -> Sc + Send + Sync,
160    Sc: Score + 'static,
161{
162    fn evaluate(&self, solution: &S) -> Sc {
163        let state = self.source.build_state(solution);
164        let mut total = Sc::zero();
165        self.source.collect_all(solution, &state, |_, output| {
166            if self.filter.test(solution, &output) {
167                total = total + self.compute_score(&output);
168            }
169        });
170        total
171    }
172
173    fn match_count(&self, solution: &S) -> usize {
174        let state = self.source.build_state(solution);
175        let mut count = 0;
176        self.source.collect_all(solution, &state, |_, output| {
177            if self.filter.test(solution, &output) {
178                count += 1;
179            }
180        });
181        count
182    }
183
184    fn initialize(&mut self, solution: &S) -> Sc {
185        self.reset();
186        let state = self.source.build_state(solution);
187        let mut total = Sc::zero();
188        let mut rows = Vec::new();
189        self.source
190            .collect_all(solution, &state, |coordinate, output| {
191                rows.push((coordinate, output));
192            });
193        self.source_state = Some(state);
194        for (coordinate, output) in rows {
195            total = total + self.insert_row(solution, coordinate, output);
196        }
197        total
198    }
199
200    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
201        let owners = self.localized_owners(descriptor_index, entity_index);
202        self.ensure_source_state(solution);
203        {
204            let state = self.source_state.as_mut().expect("projected source state");
205            for owner in &owners {
206                self.source.insert_entity_state(
207                    solution,
208                    state,
209                    owner.source_slot,
210                    owner.entity_index,
211                );
212            }
213        }
214        let mut rows = Vec::new();
215        let state = self.source_state.as_ref().expect("projected source state");
216        for owner in &owners {
217            self.source.collect_entity(
218                solution,
219                state,
220                owner.source_slot,
221                owner.entity_index,
222                |coordinate, output| rows.push((coordinate, output)),
223            );
224        }
225        let mut total = Sc::zero();
226        for (coordinate, output) in rows {
227            total = total + self.insert_row(solution, coordinate, output);
228        }
229        total
230    }
231
232    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
233        let owners = self.localized_owners(descriptor_index, entity_index);
234        let mut total = Sc::zero();
235        for coordinate in self.coordinates_for_owners(&owners) {
236            total = total + self.retract_row(coordinate);
237        }
238        if let Some(state) = self.source_state.as_mut() {
239            for owner in &owners {
240                self.source.retract_entity_state(
241                    solution,
242                    state,
243                    owner.source_slot,
244                    owner.entity_index,
245                );
246            }
247        }
248        total
249    }
250
251    fn reset(&mut self) {
252        self.source_state = None;
253        self.row_contributions.clear();
254        self.rows_by_owner.clear();
255    }
256
257    fn name(&self) -> &str {
258        &self.constraint_ref.name
259    }
260
261    fn constraint_ref(&self) -> &ConstraintRef {
262        &self.constraint_ref
263    }
264
265    fn is_hard(&self) -> bool {
266        self.is_hard
267    }
268}