// solverforge_scoring/constraint/incremental.rs
use std::fmt::Debug;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use crate::api::analysis::{ConstraintJustification, DetailedConstraintMatch, EntityRef};
13use crate::api::constraint_set::IncrementalConstraint;
14
15pub struct IncrementalUniConstraint<S, A, E, F, W, Sc>
20where
21 Sc: Score,
22{
23 constraint_ref: ConstraintRef,
24 impact_type: ImpactType,
25 extractor: E,
26 filter: F,
27 weight: W,
28 is_hard: bool,
29 change_source: crate::stream::collection_extract::ChangeSource,
30 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
31}
32
33impl<S, A, E, F, W, Sc> IncrementalUniConstraint<S, A, E, F, W, Sc>
34where
35 S: Send + Sync + 'static,
36 A: Clone + Send + Sync + 'static,
37 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
38 F: Fn(&S, &A) -> bool + Send + Sync,
39 W: Fn(&A) -> Sc + Send + Sync,
40 Sc: Score,
41{
42 pub fn new(
44 constraint_ref: ConstraintRef,
45 impact_type: ImpactType,
46 extractor: E,
47 filter: F,
48 weight: W,
49 is_hard: bool,
50 ) -> Self {
51 let change_source = extractor.change_source();
52 Self {
53 constraint_ref,
54 impact_type,
55 extractor,
56 filter,
57 weight,
58 is_hard,
59 change_source,
60 _phantom: PhantomData,
61 }
62 }
63
64 #[inline]
65 fn matches(&self, solution: &S, entity: &A) -> bool {
66 (self.filter)(solution, entity)
67 }
68
69 #[inline]
70 fn compute_delta(&self, entity: &A) -> Sc {
71 let base = (self.weight)(entity);
72 match self.impact_type {
73 ImpactType::Penalty => -base,
74 ImpactType::Reward => base,
75 }
76 }
77
78 #[inline]
79 fn reverse_delta(&self, entity: &A) -> Sc {
80 let base = (self.weight)(entity);
81 match self.impact_type {
82 ImpactType::Penalty => base,
83 ImpactType::Reward => -base,
84 }
85 }
86}
87
88impl<S, A, E, F, W, Sc> IncrementalConstraint<S, Sc> for IncrementalUniConstraint<S, A, E, F, W, Sc>
89where
90 S: Send + Sync + 'static,
91 A: Clone + Debug + Send + Sync + 'static,
92 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
93 F: Fn(&S, &A) -> bool + Send + Sync,
94 W: Fn(&A) -> Sc + Send + Sync,
95 Sc: Score,
96{
97 fn evaluate(&self, solution: &S) -> Sc {
98 let entities = self.extractor.extract(solution);
99 let mut total = Sc::zero();
100 for entity in entities {
101 if self.matches(solution, entity) {
102 total = total + self.compute_delta(entity);
103 }
104 }
105 total
106 }
107
108 fn match_count(&self, solution: &S) -> usize {
109 let entities = self.extractor.extract(solution);
110 entities
111 .iter()
112 .filter(|e| self.matches(solution, e))
113 .count()
114 }
115
116 fn initialize(&mut self, solution: &S) -> Sc {
117 self.evaluate(solution)
118 }
119
120 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
121 if !self
122 .change_source
123 .assert_localizes(descriptor_index, &self.constraint_ref.name)
124 {
125 return Sc::zero();
126 }
127 let entities = self.extractor.extract(solution);
128 if entity_index >= entities.len() {
129 return Sc::zero();
130 }
131 let entity = &entities[entity_index];
132 if self.matches(solution, entity) {
133 self.compute_delta(entity)
134 } else {
135 Sc::zero()
136 }
137 }
138
139 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
140 if !self
141 .change_source
142 .assert_localizes(descriptor_index, &self.constraint_ref.name)
143 {
144 return Sc::zero();
145 }
146 let entities = self.extractor.extract(solution);
147 if entity_index >= entities.len() {
148 return Sc::zero();
149 }
150 let entity = &entities[entity_index];
151 if self.matches(solution, entity) {
152 self.reverse_delta(entity)
153 } else {
154 Sc::zero()
155 }
156 }
157
158 fn reset(&mut self) {
159 }
161
162 fn name(&self) -> &str {
163 &self.constraint_ref.name
164 }
165
166 fn is_hard(&self) -> bool {
167 self.is_hard
168 }
169
170 fn constraint_ref(&self) -> &ConstraintRef {
171 &self.constraint_ref
172 }
173
174 fn get_matches<'a>(&'a self, solution: &S) -> Vec<DetailedConstraintMatch<'a, Sc>> {
175 let entities = self.extractor.extract(solution);
176 let cref = self.constraint_ref();
177 entities
178 .iter()
179 .filter(|e| self.matches(solution, e))
180 .map(|entity| {
181 let entity_ref = EntityRef::new(entity);
182 let justification = ConstraintJustification::new(vec![entity_ref]);
183 DetailedConstraintMatch::new(cref, self.compute_delta(entity), justification)
184 })
185 .collect()
186 }
187
188 fn weight(&self) -> Sc {
189 Sc::zero()
192 }
193}
194
195impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for IncrementalUniConstraint<S, A, E, F, W, Sc> {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 f.debug_struct("IncrementalUniConstraint")
198 .field("name", &self.constraint_ref.name)
199 .field("impact_type", &self.impact_type)
200 .finish()
201 }
202}