solverforge_scoring/constraint/projected/
uni.rs1use std::collections::HashMap;
2use std::marker::PhantomData;
3
4use solverforge_core::score::Score;
5use solverforge_core::{ConstraintRef, ImpactType};
6
7use crate::api::constraint_set::IncrementalConstraint;
8use crate::stream::filter::UniFilter;
9use crate::stream::ProjectedSource;
10
11pub struct ProjectedUniConstraint<S, Out, Src, F, W, Sc>
12where
13 Sc: Score,
14{
15 constraint_ref: ConstraintRef,
16 impact_type: ImpactType,
17 source: Src,
18 filter: F,
19 weight: W,
20 is_hard: bool,
21 entity_contributions: HashMap<(usize, usize), Vec<Sc>>,
22 _phantom: PhantomData<(fn() -> S, fn() -> Out)>,
23}
24
25impl<S, Out, Src, F, W, Sc> ProjectedUniConstraint<S, Out, Src, F, W, Sc>
26where
27 S: Send + Sync + 'static,
28 Out: Clone + Send + Sync + 'static,
29 Src: ProjectedSource<S, Out>,
30 F: UniFilter<S, Out>,
31 W: Fn(&Out) -> Sc + Send + Sync,
32 Sc: Score + 'static,
33{
34 pub fn new(
35 constraint_ref: ConstraintRef,
36 impact_type: ImpactType,
37 source: Src,
38 filter: F,
39 weight: W,
40 is_hard: bool,
41 ) -> Self {
42 Self {
43 constraint_ref,
44 impact_type,
45 source,
46 filter,
47 weight,
48 is_hard,
49 entity_contributions: HashMap::new(),
50 _phantom: PhantomData,
51 }
52 }
53
54 fn compute_score(&self, output: &Out) -> Sc {
55 let base = (self.weight)(output);
56 match self.impact_type {
57 ImpactType::Penalty => -base,
58 ImpactType::Reward => base,
59 }
60 }
61
62 fn insert_entity_outputs(&mut self, solution: &S, slot: usize, entity_index: usize) -> Sc {
63 let mut total = Sc::zero();
64 let mut contributions = Vec::new();
65 let source = &self.source;
66 let filter = &self.filter;
67 let weight = &self.weight;
68 let impact = self.impact_type;
69 source.collect_entity(solution, slot, entity_index, |_, output| {
70 if !filter.test(solution, &output) {
71 return;
72 }
73 let base = weight(&output);
74 let contribution = match impact {
75 ImpactType::Penalty => -base,
76 ImpactType::Reward => base,
77 };
78 total = total + contribution;
79 contributions.push(contribution);
80 });
81 self.entity_contributions
82 .insert((slot, entity_index), contributions);
83 total
84 }
85
86 fn retract_entity_outputs(&mut self, slot: usize, entity_index: usize) -> Sc {
87 self.entity_contributions
88 .remove(&(slot, entity_index))
89 .unwrap_or_default()
90 .into_iter()
91 .fold(Sc::zero(), |total, contribution| total - contribution)
92 }
93
94 fn localized_slots(&self, descriptor_index: usize) -> Vec<usize> {
95 let mut slots = Vec::new();
96 for slot in 0..self.source.source_count() {
97 if self
98 .source
99 .change_source(slot)
100 .assert_localizes(descriptor_index, &self.constraint_ref.name)
101 {
102 slots.push(slot);
103 }
104 }
105 slots
106 }
107}
108
109impl<S, Out, Src, F, W, Sc> IncrementalConstraint<S, Sc>
110 for ProjectedUniConstraint<S, Out, Src, F, W, Sc>
111where
112 S: Send + Sync + 'static,
113 Out: Clone + Send + Sync + 'static,
114 Src: ProjectedSource<S, Out>,
115 F: UniFilter<S, Out>,
116 W: Fn(&Out) -> Sc + Send + Sync,
117 Sc: Score + 'static,
118{
119 fn evaluate(&self, solution: &S) -> Sc {
120 let mut total = Sc::zero();
121 self.source.collect_all(solution, |_, output| {
122 if self.filter.test(solution, &output) {
123 total = total + self.compute_score(&output);
124 }
125 });
126 total
127 }
128
129 fn match_count(&self, solution: &S) -> usize {
130 let mut count = 0;
131 self.source.collect_all(solution, |_, output| {
132 if self.filter.test(solution, &output) {
133 count += 1;
134 }
135 });
136 count
137 }
138
139 fn initialize(&mut self, solution: &S) -> Sc {
140 self.reset();
141 let mut total = Sc::zero();
142 let source = &self.source;
143 let filter = &self.filter;
144 let weight = &self.weight;
145 let impact = self.impact_type;
146 let entity_contributions = &mut self.entity_contributions;
147 source.collect_all(solution, |coordinate, output| {
148 if !filter.test(solution, &output) {
149 return;
150 }
151 let base = weight(&output);
152 let contribution = match impact {
153 ImpactType::Penalty => -base,
154 ImpactType::Reward => base,
155 };
156 let mut contributions = entity_contributions
157 .remove(&(coordinate.source_slot, coordinate.entity_index))
158 .unwrap_or_default();
159 total = total + contribution;
160 contributions.push(contribution);
161 entity_contributions.insert(
162 (coordinate.source_slot, coordinate.entity_index),
163 contributions,
164 );
165 });
166 total
167 }
168
169 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
170 let mut total = Sc::zero();
171 for slot in self.localized_slots(descriptor_index) {
172 total = total + self.insert_entity_outputs(solution, slot, entity_index);
173 }
174 total
175 }
176
177 fn on_retract(&mut self, _solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
178 let mut total = Sc::zero();
179 for slot in self.localized_slots(descriptor_index) {
180 total = total + self.retract_entity_outputs(slot, entity_index);
181 }
182 total
183 }
184
185 fn reset(&mut self) {
186 self.entity_contributions.clear();
187 }
188
189 fn name(&self) -> &str {
190 &self.constraint_ref.name
191 }
192
193 fn is_hard(&self) -> bool {
194 self.is_hard
195 }
196
197 fn constraint_ref(&self) -> ConstraintRef {
198 self.constraint_ref.clone()
199 }
200}