solverforge_scoring/constraint/projected/
uni.rs1use std::collections::{HashMap, HashSet};
2use std::marker::PhantomData;
3
4use solverforge_core::score::Score;
5use solverforge_core::{ConstraintRef, ImpactType};
6
7use crate::api::constraint_set::IncrementalConstraint;
8use crate::stream::filter::UniFilter;
9use crate::stream::{ProjectedRowCoordinate, ProjectedRowOwner, ProjectedSource};
10
11pub struct ProjectedUniConstraint<S, Out, Src, F, W, Sc>
12where
13 Src: ProjectedSource<S, Out>,
14 Sc: Score,
15{
16 constraint_ref: ConstraintRef,
17 impact_type: ImpactType,
18 source: Src,
19 filter: F,
20 weight: W,
21 is_hard: bool,
22 source_state: Option<Src::State>,
23 row_contributions: HashMap<ProjectedRowCoordinate, Sc>,
24 rows_by_owner: HashMap<ProjectedRowOwner, Vec<ProjectedRowCoordinate>>,
25 _phantom: PhantomData<(fn() -> S, fn() -> Out)>,
26}
27
28impl<S, Out, Src, F, W, Sc> ProjectedUniConstraint<S, Out, Src, F, W, Sc>
29where
30 S: Send + Sync + 'static,
31 Out: Send + Sync + 'static,
32 Src: ProjectedSource<S, Out>,
33 F: UniFilter<S, Out>,
34 W: Fn(&Out) -> Sc + Send + Sync,
35 Sc: Score + 'static,
36{
37 pub fn new(
38 constraint_ref: ConstraintRef,
39 impact_type: ImpactType,
40 source: Src,
41 filter: F,
42 weight: W,
43 is_hard: bool,
44 ) -> Self {
45 Self {
46 constraint_ref,
47 impact_type,
48 source,
49 filter,
50 weight,
51 is_hard,
52 source_state: None,
53 row_contributions: HashMap::new(),
54 rows_by_owner: HashMap::new(),
55 _phantom: PhantomData,
56 }
57 }
58
59 fn compute_score(&self, output: &Out) -> Sc {
60 let base = (self.weight)(output);
61 match self.impact_type {
62 ImpactType::Penalty => -base,
63 ImpactType::Reward => base,
64 }
65 }
66
67 fn ensure_source_state(&mut self, solution: &S) {
68 if self.source_state.is_none() {
69 self.source_state = Some(self.source.build_state(solution));
70 }
71 }
72
73 fn index_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
74 coordinate.for_each_owner(|owner| {
75 self.rows_by_owner
76 .entry(owner)
77 .or_default()
78 .push(coordinate);
79 });
80 }
81
82 fn unindex_coordinate(&mut self, coordinate: ProjectedRowCoordinate) {
83 coordinate.for_each_owner(|owner| {
84 let mut remove_bucket = false;
85 if let Some(rows) = self.rows_by_owner.get_mut(&owner) {
86 rows.retain(|candidate| *candidate != coordinate);
87 remove_bucket = rows.is_empty();
88 }
89 if remove_bucket {
90 self.rows_by_owner.remove(&owner);
91 }
92 });
93 }
94
95 fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
96 if self.row_contributions.contains_key(&coordinate) || !self.filter.test(solution, &output)
97 {
98 return Sc::zero();
99 }
100 let contribution = self.compute_score(&output);
101 self.row_contributions.insert(coordinate, contribution);
102 self.index_coordinate(coordinate);
103 contribution
104 }
105
106 fn retract_row(&mut self, coordinate: ProjectedRowCoordinate) -> Sc {
107 let Some(contribution) = self.row_contributions.remove(&coordinate) else {
108 return Sc::zero();
109 };
110 self.unindex_coordinate(coordinate);
111 -contribution
112 }
113
114 fn localized_owners(
115 &self,
116 descriptor_index: usize,
117 entity_index: usize,
118 ) -> Vec<ProjectedRowOwner> {
119 let mut owners = Vec::new();
120 for slot in 0..self.source.source_count() {
121 if self
122 .source
123 .change_source(slot)
124 .assert_localizes(descriptor_index, &self.constraint_ref.name)
125 {
126 owners.push(ProjectedRowOwner {
127 source_slot: slot,
128 entity_index,
129 });
130 }
131 }
132 owners
133 }
134
135 fn coordinates_for_owners(&self, owners: &[ProjectedRowOwner]) -> Vec<ProjectedRowCoordinate> {
136 let mut seen = HashSet::new();
137 let mut coordinates = Vec::new();
138 for owner in owners {
139 let Some(rows) = self.rows_by_owner.get(owner) else {
140 continue;
141 };
142 for &coordinate in rows {
143 if seen.insert(coordinate) {
144 coordinates.push(coordinate);
145 }
146 }
147 }
148 coordinates
149 }
150}
151
152impl<S, Out, Src, F, W, Sc> IncrementalConstraint<S, Sc>
153 for ProjectedUniConstraint<S, Out, Src, F, W, Sc>
154where
155 S: Send + Sync + 'static,
156 Out: Send + Sync + 'static,
157 Src: ProjectedSource<S, Out>,
158 F: UniFilter<S, Out>,
159 W: Fn(&Out) -> Sc + Send + Sync,
160 Sc: Score + 'static,
161{
162 fn evaluate(&self, solution: &S) -> Sc {
163 let state = self.source.build_state(solution);
164 let mut total = Sc::zero();
165 self.source.collect_all(solution, &state, |_, output| {
166 if self.filter.test(solution, &output) {
167 total = total + self.compute_score(&output);
168 }
169 });
170 total
171 }
172
173 fn match_count(&self, solution: &S) -> usize {
174 let state = self.source.build_state(solution);
175 let mut count = 0;
176 self.source.collect_all(solution, &state, |_, output| {
177 if self.filter.test(solution, &output) {
178 count += 1;
179 }
180 });
181 count
182 }
183
184 fn initialize(&mut self, solution: &S) -> Sc {
185 self.reset();
186 let state = self.source.build_state(solution);
187 let mut total = Sc::zero();
188 let mut rows = Vec::new();
189 self.source
190 .collect_all(solution, &state, |coordinate, output| {
191 rows.push((coordinate, output));
192 });
193 self.source_state = Some(state);
194 for (coordinate, output) in rows {
195 total = total + self.insert_row(solution, coordinate, output);
196 }
197 total
198 }
199
200 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
201 let owners = self.localized_owners(descriptor_index, entity_index);
202 self.ensure_source_state(solution);
203 {
204 let state = self.source_state.as_mut().expect("projected source state");
205 for owner in &owners {
206 self.source.insert_entity_state(
207 solution,
208 state,
209 owner.source_slot,
210 owner.entity_index,
211 );
212 }
213 }
214 let mut rows = Vec::new();
215 let state = self.source_state.as_ref().expect("projected source state");
216 for owner in &owners {
217 self.source.collect_entity(
218 solution,
219 state,
220 owner.source_slot,
221 owner.entity_index,
222 |coordinate, output| rows.push((coordinate, output)),
223 );
224 }
225 let mut total = Sc::zero();
226 for (coordinate, output) in rows {
227 total = total + self.insert_row(solution, coordinate, output);
228 }
229 total
230 }
231
232 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
233 let owners = self.localized_owners(descriptor_index, entity_index);
234 let mut total = Sc::zero();
235 for coordinate in self.coordinates_for_owners(&owners) {
236 total = total + self.retract_row(coordinate);
237 }
238 if let Some(state) = self.source_state.as_mut() {
239 for owner in &owners {
240 self.source.retract_entity_state(
241 solution,
242 state,
243 owner.source_slot,
244 owner.entity_index,
245 );
246 }
247 }
248 total
249 }
250
251 fn reset(&mut self) {
252 self.source_state = None;
253 self.row_contributions.clear();
254 self.rows_by_owner.clear();
255 }
256
257 fn name(&self) -> &str {
258 &self.constraint_ref.name
259 }
260
261 fn constraint_ref(&self) -> &ConstraintRef {
262 &self.constraint_ref
263 }
264
265 fn is_hard(&self) -> bool {
266 self.is_hard
267 }
268}