solverforge_scoring/constraint/
incremental.rs1use std::fmt::Debug;
6use std::marker::PhantomData;
7
8use solverforge_core::score::Score;
9use solverforge_core::{ConstraintRef, ImpactType};
10
11use crate::api::analysis::{ConstraintJustification, DetailedConstraintMatch, EntityRef};
12use crate::api::constraint_set::IncrementalConstraint;
13
14pub struct IncrementalUniConstraint<S, A, E, F, W, Sc>
18where
19 Sc: Score,
20{
21 constraint_ref: ConstraintRef,
22 impact_type: ImpactType,
23 extractor: E,
24 filter: F,
25 weight: W,
26 is_hard: bool,
27 _phantom: PhantomData<(S, A, Sc)>,
28}
29
30impl<S, A, E, F, W, Sc> IncrementalUniConstraint<S, A, E, F, W, Sc>
31where
32 S: Send + Sync + 'static,
33 A: Clone + Send + Sync + 'static,
34 E: Fn(&S) -> &[A] + Send + Sync,
35 F: Fn(&S, &A) -> bool + Send + Sync,
36 W: Fn(&A) -> Sc + Send + Sync,
37 Sc: Score,
38{
39 pub fn new(
41 constraint_ref: ConstraintRef,
42 impact_type: ImpactType,
43 extractor: E,
44 filter: F,
45 weight: W,
46 is_hard: bool,
47 ) -> Self {
48 Self {
49 constraint_ref,
50 impact_type,
51 extractor,
52 filter,
53 weight,
54 is_hard,
55 _phantom: PhantomData,
56 }
57 }
58
59 #[inline]
60 fn matches(&self, solution: &S, entity: &A) -> bool {
61 (self.filter)(solution, entity)
62 }
63
64 #[inline]
65 fn compute_delta(&self, entity: &A) -> Sc {
66 let base = (self.weight)(entity);
67 match self.impact_type {
68 ImpactType::Penalty => -base,
69 ImpactType::Reward => base,
70 }
71 }
72
73 #[inline]
74 fn reverse_delta(&self, entity: &A) -> Sc {
75 let base = (self.weight)(entity);
76 match self.impact_type {
77 ImpactType::Penalty => base,
78 ImpactType::Reward => -base,
79 }
80 }
81}
82
83impl<S, A, E, F, W, Sc> IncrementalConstraint<S, Sc> for IncrementalUniConstraint<S, A, E, F, W, Sc>
84where
85 S: Send + Sync + 'static,
86 A: Clone + Debug + Send + Sync + 'static,
87 E: Fn(&S) -> &[A] + Send + Sync,
88 F: Fn(&S, &A) -> bool + Send + Sync,
89 W: Fn(&A) -> Sc + Send + Sync,
90 Sc: Score,
91{
92 fn evaluate(&self, solution: &S) -> Sc {
93 let entities = (self.extractor)(solution);
94 let mut total = Sc::zero();
95 for entity in entities {
96 if self.matches(solution, entity) {
97 total = total + self.compute_delta(entity);
98 }
99 }
100 total
101 }
102
103 fn match_count(&self, solution: &S) -> usize {
104 let entities = (self.extractor)(solution);
105 entities
106 .iter()
107 .filter(|e| self.matches(solution, e))
108 .count()
109 }
110
111 fn initialize(&mut self, solution: &S) -> Sc {
112 self.evaluate(solution)
113 }
114
115 fn on_insert(&mut self, solution: &S, entity_index: usize) -> Sc {
116 let entities = (self.extractor)(solution);
117 if entity_index >= entities.len() {
118 return Sc::zero();
119 }
120 let entity = &entities[entity_index];
121 if self.matches(solution, entity) {
122 self.compute_delta(entity)
123 } else {
124 Sc::zero()
125 }
126 }
127
128 fn on_retract(&mut self, solution: &S, entity_index: usize) -> Sc {
129 let entities = (self.extractor)(solution);
130 if entity_index >= entities.len() {
131 return Sc::zero();
132 }
133 let entity = &entities[entity_index];
134 if self.matches(solution, entity) {
135 self.reverse_delta(entity)
136 } else {
137 Sc::zero()
138 }
139 }
140
141 fn reset(&mut self) {
142 }
144
145 fn name(&self) -> &str {
146 &self.constraint_ref.name
147 }
148
149 fn is_hard(&self) -> bool {
150 self.is_hard
151 }
152
153 fn constraint_ref(&self) -> ConstraintRef {
154 self.constraint_ref.clone()
155 }
156
157 fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
158 let entities = (self.extractor)(solution);
159 let cref = self.constraint_ref.clone();
160 entities
161 .iter()
162 .filter(|e| self.matches(solution, e))
163 .map(|entity| {
164 let entity_ref = EntityRef::new(entity);
165 let justification = ConstraintJustification::new(vec![entity_ref]);
166 DetailedConstraintMatch::new(
167 cref.clone(),
168 self.compute_delta(entity),
169 justification,
170 )
171 })
172 .collect()
173 }
174
175 fn weight(&self) -> Sc {
176 Sc::zero()
179 }
180}
181
182impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for IncrementalUniConstraint<S, A, E, F, W, Sc> {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 f.debug_struct("IncrementalUniConstraint")
185 .field("name", &self.constraint_ref.name)
186 .field("impact_type", &self.impact_type)
187 .finish()
188 }
189}