1use std::hash::Hash;
2use std::marker::PhantomData;
3
4use solverforge_core::score::Score;
5use solverforge_core::{ConstraintRef, ImpactType};
6
7use crate::api::analysis::DetailedConstraintMatch;
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collection_extract::CollectionExtract;
10use crate::stream::collector::{Accumulator, Collector};
11use crate::stream::filter::UniFilter;
12use crate::stream::ConstraintWeight;
13
14use super::scorer::GroupedTerminalScorer;
15use super::shared_set::SharedGroupedConstraintSet;
16use super::state::GroupedNodeState;
17
18type Inner<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc> = SharedGroupedConstraintSet<
19 S,
20 A,
21 K,
22 E,
23 Fi,
24 KF,
25 C,
26 V,
27 R,
28 Acc,
29 GroupedTerminalScorer<K, R, W, Sc>,
30 Sc,
31>;
32
33pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
34where
35 Acc: Accumulator<V, R>,
36 Sc: Score,
37{
38 is_hard: bool,
39 inner: Inner<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>,
40 _phantom: PhantomData<fn() -> (S, A, V, R, Acc)>,
41}
42
43impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
44 GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
45where
46 S: Send + Sync + 'static,
47 A: Send + Sync + 'static,
48 K: Eq + Hash + Send + Sync + 'static,
49 E: CollectionExtract<S, Item = A>,
50 Fi: UniFilter<S, A>,
51 KF: Fn(&A) -> K + Send + Sync,
52 C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
53 V: Send + Sync + 'static,
54 R: Send + Sync + 'static,
55 Acc: Accumulator<V, R> + Send + Sync + 'static,
56 W: Fn(&K, &R) -> Sc + Send + Sync,
57 Sc: Score + 'static,
58{
59 #[allow(clippy::too_many_arguments)]
60 pub fn new(
61 constraint_ref: ConstraintRef,
62 impact_type: ImpactType,
63 extractor: E,
64 filter: Fi,
65 key_fn: KF,
66 collector: C,
67 weight_fn: W,
68 is_hard: bool,
69 ) -> Self {
70 let state = GroupedNodeState::new(extractor, filter, key_fn, collector);
71 let scorer = GroupedTerminalScorer::new(constraint_ref, impact_type, weight_fn, is_hard);
72 Self {
73 is_hard,
74 inner: SharedGroupedConstraintSet::new(state, scorer),
75 _phantom: PhantomData,
76 }
77 }
78
79 pub fn penalize<W2>(
80 self,
81 weight: W2,
82 ) -> super::shared_set::GroupedConstraintSetBuilder<
83 S,
84 A,
85 K,
86 E,
87 Fi,
88 KF,
89 C,
90 V,
91 R,
92 Acc,
93 GroupedTerminalScorer<K, R, W, Sc>,
94 impl Fn(&K, &R) -> Sc + Send + Sync,
95 Sc,
96 >
97 where
98 W2: for<'w> ConstraintWeight<(&'w K, &'w R), Sc> + Send + Sync,
99 {
100 self.inner.penalize(weight)
101 }
102
103 pub fn reward<W2>(
104 self,
105 weight: W2,
106 ) -> super::shared_set::GroupedConstraintSetBuilder<
107 S,
108 A,
109 K,
110 E,
111 Fi,
112 KF,
113 C,
114 V,
115 R,
116 Acc,
117 GroupedTerminalScorer<K, R, W, Sc>,
118 impl Fn(&K, &R) -> Sc + Send + Sync,
119 Sc,
120 >
121 where
122 W2: for<'w> ConstraintWeight<(&'w K, &'w R), Sc> + Send + Sync,
123 {
124 self.inner.reward(weight)
125 }
126}
127
128impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc> IncrementalConstraint<S, Sc>
129 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
130where
131 S: Send + Sync + 'static,
132 A: Send + Sync + 'static,
133 K: Eq + Hash + Send + Sync + 'static,
134 E: CollectionExtract<S, Item = A>,
135 Fi: UniFilter<S, A>,
136 KF: Fn(&A) -> K + Send + Sync,
137 C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
138 V: Send + Sync + 'static,
139 R: Send + Sync + 'static,
140 Acc: Accumulator<V, R> + Send + Sync + 'static,
141 W: Fn(&K, &R) -> Sc + Send + Sync,
142 Sc: Score + 'static,
143{
144 fn evaluate(&self, solution: &S) -> Sc {
145 crate::api::constraint_set::ConstraintSet::evaluate_all(&self.inner, solution)
146 }
147
148 fn match_count(&self, solution: &S) -> usize {
149 crate::api::constraint_set::ConstraintSet::evaluate_each(&self.inner, solution)
150 .first()
151 .map_or(0, |result| result.match_count)
152 }
153
154 fn initialize(&mut self, solution: &S) -> Sc {
155 crate::api::constraint_set::ConstraintSet::initialize_all(&mut self.inner, solution)
156 }
157
158 fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
159 crate::api::constraint_set::ConstraintSet::on_insert_all(
160 &mut self.inner,
161 solution,
162 entity_index,
163 descriptor_index,
164 )
165 }
166
167 fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
168 crate::api::constraint_set::ConstraintSet::on_retract_all(
169 &mut self.inner,
170 solution,
171 entity_index,
172 descriptor_index,
173 )
174 }
175
176 fn reset(&mut self) {
177 crate::api::constraint_set::ConstraintSet::reset_all(&mut self.inner);
178 }
179
180 fn constraint_ref(&self) -> &ConstraintRef {
181 self.inner.primary_constraint_ref()
182 }
183
184 fn is_hard(&self) -> bool {
185 self.is_hard
186 }
187
188 fn get_matches<'a>(&'a self, _solution: &S) -> Vec<DetailedConstraintMatch<'a, Sc>> {
189 Vec::new()
190 }
191
192 fn weight(&self) -> Sc {
193 Sc::zero()
194 }
195}
196
197impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc> std::fmt::Debug
198 for GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
199where
200 Acc: Accumulator<V, R>,
201 Sc: Score,
202{
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 f.debug_struct("GroupedUniConstraint").finish()
205 }
206}