Skip to main content

solverforge_scoring/constraint/grouped/terminal.rs

1use std::hash::Hash;
2use std::marker::PhantomData;
3
4use solverforge_core::score::Score;
5use solverforge_core::{ConstraintRef, ImpactType};
6
7use crate::api::analysis::DetailedConstraintMatch;
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::collection_extract::CollectionExtract;
10use crate::stream::collector::{Accumulator, Collector};
11use crate::stream::filter::UniFilter;
12use crate::stream::ConstraintWeight;
13
14use super::scorer::GroupedTerminalScorer;
15use super::shared_set::SharedGroupedConstraintSet;
16use super::state::GroupedNodeState;
17
18type Inner<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc> = SharedGroupedConstraintSet<
19    S,
20    A,
21    K,
22    E,
23    Fi,
24    KF,
25    C,
26    V,
27    R,
28    Acc,
29    GroupedTerminalScorer<K, R, W, Sc>,
30    Sc,
31>;
32
33pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
34where
35    Acc: Accumulator<V, R>,
36    Sc: Score,
37{
38    is_hard: bool,
39    inner: Inner<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>,
40    _phantom: PhantomData<fn() -> (S, A, V, R, Acc)>,
41}
42
43impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
44    GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
45where
46    S: Send + Sync + 'static,
47    A: Send + Sync + 'static,
48    K: Eq + Hash + Send + Sync + 'static,
49    E: CollectionExtract<S, Item = A>,
50    Fi: UniFilter<S, A>,
51    KF: Fn(&A) -> K + Send + Sync,
52    C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
53    V: Send + Sync + 'static,
54    R: Send + Sync + 'static,
55    Acc: Accumulator<V, R> + Send + Sync + 'static,
56    W: Fn(&K, &R) -> Sc + Send + Sync,
57    Sc: Score + 'static,
58{
59    #[allow(clippy::too_many_arguments)]
60    pub fn new(
61        constraint_ref: ConstraintRef,
62        impact_type: ImpactType,
63        extractor: E,
64        filter: Fi,
65        key_fn: KF,
66        collector: C,
67        weight_fn: W,
68        is_hard: bool,
69    ) -> Self {
70        let state = GroupedNodeState::new(extractor, filter, key_fn, collector);
71        let scorer = GroupedTerminalScorer::new(constraint_ref, impact_type, weight_fn, is_hard);
72        Self {
73            is_hard,
74            inner: SharedGroupedConstraintSet::new(state, scorer),
75            _phantom: PhantomData,
76        }
77    }
78
79    pub fn penalize<W2>(
80        self,
81        weight: W2,
82    ) -> super::shared_set::GroupedConstraintSetBuilder<
83        S,
84        A,
85        K,
86        E,
87        Fi,
88        KF,
89        C,
90        V,
91        R,
92        Acc,
93        GroupedTerminalScorer<K, R, W, Sc>,
94        impl Fn(&K, &R) -> Sc + Send + Sync,
95        Sc,
96    >
97    where
98        W2: for<'w> ConstraintWeight<(&'w K, &'w R), Sc> + Send + Sync,
99    {
100        self.inner.penalize(weight)
101    }
102
103    pub fn reward<W2>(
104        self,
105        weight: W2,
106    ) -> super::shared_set::GroupedConstraintSetBuilder<
107        S,
108        A,
109        K,
110        E,
111        Fi,
112        KF,
113        C,
114        V,
115        R,
116        Acc,
117        GroupedTerminalScorer<K, R, W, Sc>,
118        impl Fn(&K, &R) -> Sc + Send + Sync,
119        Sc,
120    >
121    where
122        W2: for<'w> ConstraintWeight<(&'w K, &'w R), Sc> + Send + Sync,
123    {
124        self.inner.reward(weight)
125    }
126}
127
128impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc> IncrementalConstraint<S, Sc>
129    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
130where
131    S: Send + Sync + 'static,
132    A: Send + Sync + 'static,
133    K: Eq + Hash + Send + Sync + 'static,
134    E: CollectionExtract<S, Item = A>,
135    Fi: UniFilter<S, A>,
136    KF: Fn(&A) -> K + Send + Sync,
137    C: for<'i> Collector<&'i A, Value = V, Result = R, Accumulator = Acc> + Send + Sync + 'static,
138    V: Send + Sync + 'static,
139    R: Send + Sync + 'static,
140    Acc: Accumulator<V, R> + Send + Sync + 'static,
141    W: Fn(&K, &R) -> Sc + Send + Sync,
142    Sc: Score + 'static,
143{
144    fn evaluate(&self, solution: &S) -> Sc {
145        crate::api::constraint_set::ConstraintSet::evaluate_all(&self.inner, solution)
146    }
147
148    fn match_count(&self, solution: &S) -> usize {
149        crate::api::constraint_set::ConstraintSet::evaluate_each(&self.inner, solution)
150            .first()
151            .map_or(0, |result| result.match_count)
152    }
153
154    fn initialize(&mut self, solution: &S) -> Sc {
155        crate::api::constraint_set::ConstraintSet::initialize_all(&mut self.inner, solution)
156    }
157
158    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
159        crate::api::constraint_set::ConstraintSet::on_insert_all(
160            &mut self.inner,
161            solution,
162            entity_index,
163            descriptor_index,
164        )
165    }
166
167    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
168        crate::api::constraint_set::ConstraintSet::on_retract_all(
169            &mut self.inner,
170            solution,
171            entity_index,
172            descriptor_index,
173        )
174    }
175
176    fn reset(&mut self) {
177        crate::api::constraint_set::ConstraintSet::reset_all(&mut self.inner);
178    }
179
180    fn constraint_ref(&self) -> &ConstraintRef {
181        self.inner.primary_constraint_ref()
182    }
183
184    fn is_hard(&self) -> bool {
185        self.is_hard
186    }
187
188    fn get_matches<'a>(&'a self, _solution: &S) -> Vec<DetailedConstraintMatch<'a, Sc>> {
189        Vec::new()
190    }
191
192    fn weight(&self) -> Sc {
193        Sc::zero()
194    }
195}
196
197impl<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc> std::fmt::Debug
198    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, V, R, Acc, W, Sc>
199where
200    Acc: Accumulator<V, R>,
201    Sc: Score,
202{
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        f.debug_struct("GroupedUniConstraint").finish()
205    }
206}