solverforge_scoring/constraint/
incremental.rs1use std::fmt::Debug;
6use std::marker::PhantomData;
7
8use solverforge_core::score::Score;
9use solverforge_core::{ConstraintRef, ImpactType};
10
11use crate::api::analysis::{ConstraintJustification, DetailedConstraintMatch, EntityRef};
12use crate::api::constraint_set::IncrementalConstraint;
13
14pub struct IncrementalUniConstraint<S, A, E, F, W, Sc>
18where
19 Sc: Score,
20{
21 constraint_ref: ConstraintRef,
22 impact_type: ImpactType,
23 extractor: E,
24 filter: F,
25 weight: W,
26 is_hard: bool,
27 _phantom: PhantomData<(S, A, Sc)>,
28}
29
30impl<S, A, E, F, W, Sc> IncrementalUniConstraint<S, A, E, F, W, Sc>
31where
32 S: Send + Sync + 'static,
33 A: Clone + Send + Sync + 'static,
34 E: Fn(&S) -> &[A] + Send + Sync,
35 F: Fn(&A) -> bool + Send + Sync,
36 W: Fn(&A) -> Sc + Send + Sync,
37 Sc: Score,
38{
39 pub fn new(
41 constraint_ref: ConstraintRef,
42 impact_type: ImpactType,
43 extractor: E,
44 filter: F,
45 weight: W,
46 is_hard: bool,
47 ) -> Self {
48 Self {
49 constraint_ref,
50 impact_type,
51 extractor,
52 filter,
53 weight,
54 is_hard,
55 _phantom: PhantomData,
56 }
57 }
58
59 #[inline]
60 fn matches(&self, entity: &A) -> bool {
61 (self.filter)(entity)
62 }
63
64 #[inline]
65 fn compute_delta(&self, entity: &A) -> Sc {
66 let base = (self.weight)(entity);
67 match self.impact_type {
68 ImpactType::Penalty => -base,
69 ImpactType::Reward => base,
70 }
71 }
72
73 #[inline]
74 fn reverse_delta(&self, entity: &A) -> Sc {
75 let base = (self.weight)(entity);
76 match self.impact_type {
77 ImpactType::Penalty => base,
78 ImpactType::Reward => -base,
79 }
80 }
81}
82
83impl<S, A, E, F, W, Sc> IncrementalConstraint<S, Sc> for IncrementalUniConstraint<S, A, E, F, W, Sc>
84where
85 S: Send + Sync + 'static,
86 A: Clone + Debug + Send + Sync + 'static,
87 E: Fn(&S) -> &[A] + Send + Sync,
88 F: Fn(&A) -> bool + Send + Sync,
89 W: Fn(&A) -> Sc + Send + Sync,
90 Sc: Score,
91{
92 fn evaluate(&self, solution: &S) -> Sc {
93 let entities = (self.extractor)(solution);
94 let mut total = Sc::zero();
95 for entity in entities {
96 if self.matches(entity) {
97 total = total + self.compute_delta(entity);
98 }
99 }
100 total
101 }
102
103 fn match_count(&self, solution: &S) -> usize {
104 let entities = (self.extractor)(solution);
105 entities.iter().filter(|e| self.matches(e)).count()
106 }
107
108 fn initialize(&mut self, solution: &S) -> Sc {
109 self.evaluate(solution)
110 }
111
112 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
113 let entities = (self.extractor)(solution);
114 if entity_index >= entities.len() {
115 return Sc::zero();
116 }
117 let entity = &entities[entity_index];
118 if self.matches(entity) {
119 self.compute_delta(entity)
120 } else {
121 Sc::zero()
122 }
123 }
124
125 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
126 let entities = (self.extractor)(solution);
127 if entity_index >= entities.len() {
128 return Sc::zero();
129 }
130 let entity = &entities[entity_index];
131 if self.matches(entity) {
132 self.reverse_delta(entity)
133 } else {
134 Sc::zero()
135 }
136 }
137
138 fn reset(&mut self) {
139 }
141
142 fn name(&self) -> &str {
143 &self.constraint_ref.name
144 }
145
146 fn is_hard(&self) -> bool {
147 self.is_hard
148 }
149
150 fn constraint_ref(&self) -> ConstraintRef {
151 self.constraint_ref.clone()
152 }
153
154 fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
155 let entities = (self.extractor)(solution);
156 let cref = self.constraint_ref.clone();
157 entities
158 .iter()
159 .filter(|e| self.matches(e))
160 .map(|entity| {
161 let entity_ref = EntityRef::new(entity);
162 let justification = ConstraintJustification::new(vec![entity_ref]);
163 DetailedConstraintMatch::new(
164 cref.clone(),
165 self.compute_delta(entity),
166 justification,
167 )
168 })
169 .collect()
170 }
171
172 fn weight(&self) -> Sc {
173 Sc::zero()
176 }
177}
178
179impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for IncrementalUniConstraint<S, A, E, F, W, Sc> {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 f.debug_struct("IncrementalUniConstraint")
182 .field("name", &self.constraint_ref.name)
183 .field("impact_type", &self.impact_type)
184 .finish()
185 }
186}