solverforge_scoring/constraint/
incremental.rs1use std::fmt::Debug;
6use std::marker::PhantomData;
7
8use solverforge_core::score::Score;
9use solverforge_core::{ConstraintRef, ImpactType};
10
11use crate::api::analysis::{ConstraintJustification, DetailedConstraintMatch, EntityRef};
12use crate::api::constraint_set::IncrementalConstraint;
13
14pub struct IncrementalUniConstraint<S, A, E, F, W, Sc>
18where
19 Sc: Score,
20{
21 constraint_ref: ConstraintRef,
22 impact_type: ImpactType,
23 extractor: E,
24 filter: F,
25 weight: W,
26 is_hard: bool,
27 expected_descriptor: Option<usize>,
28 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
29}
30
31impl<S, A, E, F, W, Sc> IncrementalUniConstraint<S, A, E, F, W, Sc>
32where
33 S: Send + Sync + 'static,
34 A: Clone + Send + Sync + 'static,
35 E: Fn(&S) -> &[A] + Send + Sync,
36 F: Fn(&S, &A) -> bool + Send + Sync,
37 W: Fn(&A) -> Sc + Send + Sync,
38 Sc: Score,
39{
40 pub fn new(
42 constraint_ref: ConstraintRef,
43 impact_type: ImpactType,
44 extractor: E,
45 filter: F,
46 weight: W,
47 is_hard: bool,
48 ) -> Self {
49 Self {
50 constraint_ref,
51 impact_type,
52 extractor,
53 filter,
54 weight,
55 is_hard,
56 expected_descriptor: None,
57 _phantom: PhantomData,
58 }
59 }
60
61 pub fn with_descriptor(mut self, descriptor_index: usize) -> Self {
62 self.expected_descriptor = Some(descriptor_index);
63 self
64 }
65
66 #[inline]
67 fn matches(&self, solution: &S, entity: &A) -> bool {
68 (self.filter)(solution, entity)
69 }
70
71 #[inline]
72 fn compute_delta(&self, entity: &A) -> Sc {
73 let base = (self.weight)(entity);
74 match self.impact_type {
75 ImpactType::Penalty => -base,
76 ImpactType::Reward => base,
77 }
78 }
79
80 #[inline]
81 fn reverse_delta(&self, entity: &A) -> Sc {
82 let base = (self.weight)(entity);
83 match self.impact_type {
84 ImpactType::Penalty => base,
85 ImpactType::Reward => -base,
86 }
87 }
88}
89
90impl<S, A, E, F, W, Sc> IncrementalConstraint<S, Sc> for IncrementalUniConstraint<S, A, E, F, W, Sc>
91where
92 S: Send + Sync + 'static,
93 A: Clone + Debug + Send + Sync + 'static,
94 E: Fn(&S) -> &[A] + Send + Sync,
95 F: Fn(&S, &A) -> bool + Send + Sync,
96 W: Fn(&A) -> Sc + Send + Sync,
97 Sc: Score,
98{
99 fn evaluate(&self, solution: &S) -> Sc {
100 let entities = (self.extractor)(solution);
101 let mut total = Sc::zero();
102 for entity in entities {
103 if self.matches(solution, entity) {
104 total = total + self.compute_delta(entity);
105 }
106 }
107 total
108 }
109
110 fn match_count(&self, solution: &S) -> usize {
111 let entities = (self.extractor)(solution);
112 entities
113 .iter()
114 .filter(|e| self.matches(solution, e))
115 .count()
116 }
117
118 fn initialize(&mut self, solution: &S) -> Sc {
119 self.evaluate(solution)
120 }
121
122 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
123 if let Some(expected) = self.expected_descriptor {
124 if descriptor_index != expected {
125 return Sc::zero();
126 }
127 }
128 let entities = (self.extractor)(solution);
129 if entity_index >= entities.len() {
130 return Sc::zero();
131 }
132 let entity = &entities[entity_index];
133 if self.matches(solution, entity) {
134 self.compute_delta(entity)
135 } else {
136 Sc::zero()
137 }
138 }
139
140 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
141 if let Some(expected) = self.expected_descriptor {
142 if descriptor_index != expected {
143 return Sc::zero();
144 }
145 }
146 let entities = (self.extractor)(solution);
147 if entity_index >= entities.len() {
148 return Sc::zero();
149 }
150 let entity = &entities[entity_index];
151 if self.matches(solution, entity) {
152 self.reverse_delta(entity)
153 } else {
154 Sc::zero()
155 }
156 }
157
158 fn reset(&mut self) {
159 }
161
162 fn name(&self) -> &str {
163 &self.constraint_ref.name
164 }
165
166 fn is_hard(&self) -> bool {
167 self.is_hard
168 }
169
170 fn constraint_ref(&self) -> ConstraintRef {
171 self.constraint_ref.clone()
172 }
173
174 fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
175 let entities = (self.extractor)(solution);
176 let cref = self.constraint_ref.clone();
177 entities
178 .iter()
179 .filter(|e| self.matches(solution, e))
180 .map(|entity| {
181 let entity_ref = EntityRef::new(entity);
182 let justification = ConstraintJustification::new(vec![entity_ref]);
183 DetailedConstraintMatch::new(
184 cref.clone(),
185 self.compute_delta(entity),
186 justification,
187 )
188 })
189 .collect()
190 }
191
192 fn weight(&self) -> Sc {
193 Sc::zero()
196 }
197}
198
199impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for IncrementalUniConstraint<S, A, E, F, W, Sc> {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 f.debug_struct("IncrementalUniConstraint")
202 .field("name", &self.constraint_ref.name)
203 .field("impact_type", &self.impact_type)
204 .finish()
205 }
206}