solverforge_scoring/constraint/
incremental.rs1use std::fmt::Debug;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use crate::api::analysis::{ConstraintJustification, DetailedConstraintMatch, EntityRef};
13use crate::api::constraint_set::IncrementalConstraint;
14
15pub struct IncrementalUniConstraint<S, A, E, F, W, Sc>
20where
21 Sc: Score,
22{
23 constraint_ref: ConstraintRef,
24 impact_type: ImpactType,
25 extractor: E,
26 filter: F,
27 weight: W,
28 is_hard: bool,
29 expected_descriptor: Option<usize>,
30 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
31}
32
33impl<S, A, E, F, W, Sc> IncrementalUniConstraint<S, A, E, F, W, Sc>
34where
35 S: Send + Sync + 'static,
36 A: Clone + Send + Sync + 'static,
37 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
38 F: Fn(&S, &A) -> bool + Send + Sync,
39 W: Fn(&A) -> Sc + Send + Sync,
40 Sc: Score,
41{
42 pub fn new(
44 constraint_ref: ConstraintRef,
45 impact_type: ImpactType,
46 extractor: E,
47 filter: F,
48 weight: W,
49 is_hard: bool,
50 ) -> Self {
51 Self {
52 constraint_ref,
53 impact_type,
54 extractor,
55 filter,
56 weight,
57 is_hard,
58 expected_descriptor: None,
59 _phantom: PhantomData,
60 }
61 }
62
63 pub fn with_descriptor(mut self, descriptor_index: usize) -> Self {
64 self.expected_descriptor = Some(descriptor_index);
65 self
66 }
67
68 #[inline]
69 fn matches(&self, solution: &S, entity: &A) -> bool {
70 (self.filter)(solution, entity)
71 }
72
73 #[inline]
74 fn compute_delta(&self, entity: &A) -> Sc {
75 let base = (self.weight)(entity);
76 match self.impact_type {
77 ImpactType::Penalty => -base,
78 ImpactType::Reward => base,
79 }
80 }
81
82 #[inline]
83 fn reverse_delta(&self, entity: &A) -> Sc {
84 let base = (self.weight)(entity);
85 match self.impact_type {
86 ImpactType::Penalty => base,
87 ImpactType::Reward => -base,
88 }
89 }
90}
91
92impl<S, A, E, F, W, Sc> IncrementalConstraint<S, Sc> for IncrementalUniConstraint<S, A, E, F, W, Sc>
93where
94 S: Send + Sync + 'static,
95 A: Clone + Debug + Send + Sync + 'static,
96 E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
97 F: Fn(&S, &A) -> bool + Send + Sync,
98 W: Fn(&A) -> Sc + Send + Sync,
99 Sc: Score,
100{
101 fn evaluate(&self, solution: &S) -> Sc {
102 let entities = self.extractor.extract(solution);
103 let mut total = Sc::zero();
104 for entity in entities {
105 if self.matches(solution, entity) {
106 total = total + self.compute_delta(entity);
107 }
108 }
109 total
110 }
111
112 fn match_count(&self, solution: &S) -> usize {
113 let entities = self.extractor.extract(solution);
114 entities
115 .iter()
116 .filter(|e| self.matches(solution, e))
117 .count()
118 }
119
120 fn initialize(&mut self, solution: &S) -> Sc {
121 self.evaluate(solution)
122 }
123
124 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
125 if let Some(expected) = self.expected_descriptor {
126 if descriptor_index != expected {
127 return Sc::zero();
128 }
129 }
130 let entities = self.extractor.extract(solution);
131 if entity_index >= entities.len() {
132 return Sc::zero();
133 }
134 let entity = &entities[entity_index];
135 if self.matches(solution, entity) {
136 self.compute_delta(entity)
137 } else {
138 Sc::zero()
139 }
140 }
141
142 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
143 if let Some(expected) = self.expected_descriptor {
144 if descriptor_index != expected {
145 return Sc::zero();
146 }
147 }
148 let entities = self.extractor.extract(solution);
149 if entity_index >= entities.len() {
150 return Sc::zero();
151 }
152 let entity = &entities[entity_index];
153 if self.matches(solution, entity) {
154 self.reverse_delta(entity)
155 } else {
156 Sc::zero()
157 }
158 }
159
160 fn reset(&mut self) {
161 }
163
164 fn name(&self) -> &str {
165 &self.constraint_ref.name
166 }
167
168 fn is_hard(&self) -> bool {
169 self.is_hard
170 }
171
172 fn constraint_ref(&self) -> ConstraintRef {
173 self.constraint_ref.clone()
174 }
175
176 fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
177 let entities = self.extractor.extract(solution);
178 let cref = self.constraint_ref.clone();
179 entities
180 .iter()
181 .filter(|e| self.matches(solution, e))
182 .map(|entity| {
183 let entity_ref = EntityRef::new(entity);
184 let justification = ConstraintJustification::new(vec![entity_ref]);
185 DetailedConstraintMatch::new(
186 cref.clone(),
187 self.compute_delta(entity),
188 justification,
189 )
190 })
191 .collect()
192 }
193
194 fn weight(&self) -> Sc {
195 Sc::zero()
198 }
199}
200
201impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for IncrementalUniConstraint<S, A, E, F, W, Sc> {
202 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203 f.debug_struct("IncrementalUniConstraint")
204 .field("name", &self.constraint_ref.name)
205 .field("impact_type", &self.impact_type)
206 .finish()
207 }
208}