Skip to main content

solverforge_scoring/constraint/projected/
uni.rs

1use std::collections::HashMap;
2use std::marker::PhantomData;
3
4use solverforge_core::score::Score;
5use solverforge_core::{ConstraintRef, ImpactType};
6
7use crate::api::constraint_set::IncrementalConstraint;
8use crate::stream::filter::UniFilter;
9use crate::stream::ProjectedSource;
10
11pub struct ProjectedUniConstraint<S, Out, Src, F, W, Sc>
12where
13    Sc: Score,
14{
15    constraint_ref: ConstraintRef,
16    impact_type: ImpactType,
17    source: Src,
18    filter: F,
19    weight: W,
20    is_hard: bool,
21    entity_contributions: HashMap<(usize, usize), Vec<Sc>>,
22    _phantom: PhantomData<(fn() -> S, fn() -> Out)>,
23}
24
25impl<S, Out, Src, F, W, Sc> ProjectedUniConstraint<S, Out, Src, F, W, Sc>
26where
27    S: Send + Sync + 'static,
28    Out: Clone + Send + Sync + 'static,
29    Src: ProjectedSource<S, Out>,
30    F: UniFilter<S, Out>,
31    W: Fn(&Out) -> Sc + Send + Sync,
32    Sc: Score + 'static,
33{
34    pub fn new(
35        constraint_ref: ConstraintRef,
36        impact_type: ImpactType,
37        source: Src,
38        filter: F,
39        weight: W,
40        is_hard: bool,
41    ) -> Self {
42        Self {
43            constraint_ref,
44            impact_type,
45            source,
46            filter,
47            weight,
48            is_hard,
49            entity_contributions: HashMap::new(),
50            _phantom: PhantomData,
51        }
52    }
53
54    fn compute_score(&self, output: &Out) -> Sc {
55        let base = (self.weight)(output);
56        match self.impact_type {
57            ImpactType::Penalty => -base,
58            ImpactType::Reward => base,
59        }
60    }
61
62    fn insert_entity_outputs(&mut self, solution: &S, slot: usize, entity_index: usize) -> Sc {
63        let mut total = Sc::zero();
64        let mut contributions = Vec::new();
65        let source = &self.source;
66        let filter = &self.filter;
67        let weight = &self.weight;
68        let impact = self.impact_type;
69        source.collect_entity(solution, slot, entity_index, |_, output| {
70            if !filter.test(solution, &output) {
71                return;
72            }
73            let base = weight(&output);
74            let contribution = match impact {
75                ImpactType::Penalty => -base,
76                ImpactType::Reward => base,
77            };
78            total = total + contribution;
79            contributions.push(contribution);
80        });
81        self.entity_contributions
82            .insert((slot, entity_index), contributions);
83        total
84    }
85
86    fn retract_entity_outputs(&mut self, slot: usize, entity_index: usize) -> Sc {
87        self.entity_contributions
88            .remove(&(slot, entity_index))
89            .unwrap_or_default()
90            .into_iter()
91            .fold(Sc::zero(), |total, contribution| total - contribution)
92    }
93
94    fn localized_slots(&self, descriptor_index: usize) -> Vec<usize> {
95        let mut slots = Vec::new();
96        for slot in 0..self.source.source_count() {
97            if self
98                .source
99                .change_source(slot)
100                .assert_localizes(descriptor_index, &self.constraint_ref.name)
101            {
102                slots.push(slot);
103            }
104        }
105        slots
106    }
107}
108
109impl<S, Out, Src, F, W, Sc> IncrementalConstraint<S, Sc>
110    for ProjectedUniConstraint<S, Out, Src, F, W, Sc>
111where
112    S: Send + Sync + 'static,
113    Out: Clone + Send + Sync + 'static,
114    Src: ProjectedSource<S, Out>,
115    F: UniFilter<S, Out>,
116    W: Fn(&Out) -> Sc + Send + Sync,
117    Sc: Score + 'static,
118{
119    fn evaluate(&self, solution: &S) -> Sc {
120        let mut total = Sc::zero();
121        self.source.collect_all(solution, |_, output| {
122            if self.filter.test(solution, &output) {
123                total = total + self.compute_score(&output);
124            }
125        });
126        total
127    }
128
129    fn match_count(&self, solution: &S) -> usize {
130        let mut count = 0;
131        self.source.collect_all(solution, |_, output| {
132            if self.filter.test(solution, &output) {
133                count += 1;
134            }
135        });
136        count
137    }
138
139    fn initialize(&mut self, solution: &S) -> Sc {
140        self.reset();
141        let mut total = Sc::zero();
142        let source = &self.source;
143        let filter = &self.filter;
144        let weight = &self.weight;
145        let impact = self.impact_type;
146        let entity_contributions = &mut self.entity_contributions;
147        source.collect_all(solution, |coordinate, output| {
148            if !filter.test(solution, &output) {
149                return;
150            }
151            let base = weight(&output);
152            let contribution = match impact {
153                ImpactType::Penalty => -base,
154                ImpactType::Reward => base,
155            };
156            let mut contributions = entity_contributions
157                .remove(&(coordinate.source_slot, coordinate.entity_index))
158                .unwrap_or_default();
159            total = total + contribution;
160            contributions.push(contribution);
161            entity_contributions.insert(
162                (coordinate.source_slot, coordinate.entity_index),
163                contributions,
164            );
165        });
166        total
167    }
168
169    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
170        let mut total = Sc::zero();
171        for slot in self.localized_slots(descriptor_index) {
172            total = total + self.insert_entity_outputs(solution, slot, entity_index);
173        }
174        total
175    }
176
177    fn on_retract(&mut self, _solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
178        let mut total = Sc::zero();
179        for slot in self.localized_slots(descriptor_index) {
180            total = total + self.retract_entity_outputs(slot, entity_index);
181        }
182        total
183    }
184
185    fn reset(&mut self) {
186        self.entity_contributions.clear();
187    }
188
189    fn name(&self) -> &str {
190        &self.constraint_ref.name
191    }
192
193    fn is_hard(&self) -> bool {
194        self.is_hard
195    }
196
197    fn constraint_ref(&self) -> ConstraintRef {
198        self.constraint_ref.clone()
199    }
200}