Skip to main content

solverforge_scoring/constraint/projected/
bi.rs

1use std::collections::HashMap;
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::filter::{BiFilter, UniFilter};
10use crate::stream::{ProjectedRowCoordinate, ProjectedSource};
11
12struct ProjectedJoinRow<Out, K> {
13    key: K,
14    output: Out,
15    order: ProjectedRowCoordinate,
16}
17
18pub struct ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
19where
20    Sc: Score,
21{
22    constraint_ref: ConstraintRef,
23    impact_type: ImpactType,
24    source: Src,
25    filter: F,
26    key_fn: KF,
27    pair_filter: PF,
28    weight: W,
29    is_hard: bool,
30    rows: Vec<Option<ProjectedJoinRow<Out, K>>>,
31    free_row_ids: Vec<usize>,
32    rows_by_entity: HashMap<(usize, usize), Vec<usize>>,
33    rows_by_key: HashMap<K, Vec<usize>>,
34    _phantom: PhantomData<(fn() -> S, fn() -> Out, fn() -> Sc)>,
35}
36
37impl<S, Out, K, Src, F, KF, PF, W, Sc> ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
38where
39    S: Send + Sync + 'static,
40    Out: Clone + Send + Sync + 'static,
41    K: Clone + Eq + Hash + Send + Sync + 'static,
42    Src: ProjectedSource<S, Out>,
43    F: UniFilter<S, Out>,
44    KF: Fn(&Out) -> K + Send + Sync,
45    PF: BiFilter<S, Out, Out>,
46    W: Fn(&Out, &Out) -> Sc + Send + Sync,
47    Sc: Score + 'static,
48{
49    #[allow(clippy::too_many_arguments)]
50    pub fn new(
51        constraint_ref: ConstraintRef,
52        impact_type: ImpactType,
53        source: Src,
54        filter: F,
55        key_fn: KF,
56        pair_filter: PF,
57        weight: W,
58        is_hard: bool,
59    ) -> Self {
60        Self {
61            constraint_ref,
62            impact_type,
63            source,
64            filter,
65            key_fn,
66            pair_filter,
67            weight,
68            is_hard,
69            rows: Vec::new(),
70            free_row_ids: Vec::new(),
71            rows_by_entity: HashMap::new(),
72            rows_by_key: HashMap::new(),
73            _phantom: PhantomData,
74        }
75    }
76
77    fn compute_score(&self, left: &Out, right: &Out) -> Sc {
78        let base = (self.weight)(left, right);
79        match self.impact_type {
80            ImpactType::Penalty => -base,
81            ImpactType::Reward => base,
82        }
83    }
84
85    fn score_ordered_rows(
86        &self,
87        solution: &S,
88        first: &ProjectedJoinRow<Out, K>,
89        second: &ProjectedJoinRow<Out, K>,
90    ) -> Sc {
91        let (left, right) = if first.order <= second.order {
92            (first, second)
93        } else {
94            (second, first)
95        };
96        if !self
97            .pair_filter
98            .test(solution, &left.output, &right.output, 0, 1)
99        {
100            return Sc::zero();
101        }
102        self.compute_score(&left.output, &right.output)
103    }
104
105    fn score_pair(&self, solution: &S, first_id: usize, second_id: usize) -> Sc {
106        let Some(first) = self.rows.get(first_id).and_then(Option::as_ref) else {
107            return Sc::zero();
108        };
109        let Some(second) = self.rows.get(second_id).and_then(Option::as_ref) else {
110            return Sc::zero();
111        };
112        self.score_ordered_rows(solution, first, second)
113    }
114
115    fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
116        let key = (self.key_fn)(&output);
117        let existing = self.rows_by_key.get(&key).cloned().unwrap_or_default();
118        let row = Some(ProjectedJoinRow {
119            key: key.clone(),
120            output,
121            order: coordinate,
122        });
123        let row_id = if let Some(row_id) = self.free_row_ids.pop() {
124            debug_assert!(self.rows[row_id].is_none());
125            self.rows[row_id] = row;
126            row_id
127        } else {
128            let row_id = self.rows.len();
129            self.rows.push(row);
130            row_id
131        };
132        self.rows_by_entity
133            .entry((coordinate.source_slot, coordinate.entity_index))
134            .or_default()
135            .push(row_id);
136
137        let mut total = Sc::zero();
138        for other_id in existing {
139            total = total + self.score_pair(solution, row_id, other_id);
140        }
141        self.rows_by_key.entry(key).or_default().push(row_id);
142        total
143    }
144
145    fn retract_row(&mut self, solution: &S, row_id: usize) -> Sc {
146        let Some(row) = self.rows.get(row_id).and_then(Option::as_ref) else {
147            return Sc::zero();
148        };
149        let key = row.key.clone();
150        let candidates = self.rows_by_key.get(&key).cloned().unwrap_or_default();
151        let mut total = Sc::zero();
152        for other_id in candidates {
153            if other_id == row_id {
154                continue;
155            }
156            total = total - self.score_pair(solution, row_id, other_id);
157        }
158
159        if let Some(ids) = self.rows_by_key.get_mut(&key) {
160            ids.retain(|&id| id != row_id);
161            if ids.is_empty() {
162                self.rows_by_key.remove(&key);
163            }
164        }
165        self.rows[row_id] = None;
166        self.free_row_ids.push(row_id);
167        total
168    }
169
170    fn insert_entity_outputs(&mut self, solution: &S, slot: usize, entity_index: usize) -> Sc {
171        let mut outputs = Vec::new();
172        self.source
173            .collect_entity(solution, slot, entity_index, |coordinate, output| {
174                if self.filter.test(solution, &output) {
175                    outputs.push((coordinate, output));
176                }
177            });
178
179        outputs
180            .into_iter()
181            .fold(Sc::zero(), |total, (coordinate, output)| {
182                total + self.insert_row(solution, coordinate, output)
183            })
184    }
185
186    fn retract_entity_outputs(&mut self, solution: &S, slot: usize, entity_index: usize) -> Sc {
187        let Some(row_ids) = self.rows_by_entity.remove(&(slot, entity_index)) else {
188            return Sc::zero();
189        };
190        row_ids.into_iter().fold(Sc::zero(), |total, row_id| {
191            total + self.retract_row(solution, row_id)
192        })
193    }
194
195    fn evaluate_rows(&self, solution: &S) -> Vec<ProjectedJoinRow<Out, K>> {
196        let mut rows = Vec::new();
197        self.source.collect_all(solution, |coordinate, output| {
198            if self.filter.test(solution, &output) {
199                rows.push(ProjectedJoinRow {
200                    key: (self.key_fn)(&output),
201                    output,
202                    order: coordinate,
203                });
204            }
205        });
206        rows
207    }
208
209    fn score_evaluation_pair(
210        &self,
211        solution: &S,
212        first: &ProjectedJoinRow<Out, K>,
213        second: &ProjectedJoinRow<Out, K>,
214    ) -> Sc {
215        if first.key == second.key {
216            self.score_ordered_rows(solution, first, second)
217        } else {
218            Sc::zero()
219        }
220    }
221
222    fn evaluation_pair_matches(
223        &self,
224        solution: &S,
225        first: &ProjectedJoinRow<Out, K>,
226        second: &ProjectedJoinRow<Out, K>,
227    ) -> bool {
228        if first.key != second.key {
229            return false;
230        }
231        let (left, right) = if first.order <= second.order {
232            (first, second)
233        } else {
234            (second, first)
235        };
236        self.pair_filter
237            .test(solution, &left.output, &right.output, 0, 1)
238    }
239
240    fn localized_slots(&self, descriptor_index: usize) -> Vec<usize> {
241        let mut slots = Vec::new();
242        for slot in 0..self.source.source_count() {
243            if self
244                .source
245                .change_source(slot)
246                .assert_localizes(descriptor_index, &self.constraint_ref.name)
247            {
248                slots.push(slot);
249            }
250        }
251        slots
252    }
253
254    #[cfg(test)]
255    pub(crate) fn debug_row_storage_len(&self) -> usize {
256        self.rows.len()
257    }
258
259    #[cfg(test)]
260    pub(crate) fn debug_free_row_count(&self) -> usize {
261        self.free_row_ids.len()
262    }
263}
264
265impl<S, Out, K, Src, F, KF, PF, W, Sc> IncrementalConstraint<S, Sc>
266    for ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
267where
268    S: Send + Sync + 'static,
269    Out: Clone + Send + Sync + 'static,
270    K: Clone + Eq + Hash + Send + Sync + 'static,
271    Src: ProjectedSource<S, Out>,
272    F: UniFilter<S, Out>,
273    KF: Fn(&Out) -> K + Send + Sync,
274    PF: BiFilter<S, Out, Out>,
275    W: Fn(&Out, &Out) -> Sc + Send + Sync,
276    Sc: Score + 'static,
277{
278    fn evaluate(&self, solution: &S) -> Sc {
279        let rows = self.evaluate_rows(solution);
280
281        let mut total = Sc::zero();
282        for left_index in 0..rows.len() {
283            for right_index in (left_index + 1)..rows.len() {
284                total = total
285                    + self.score_evaluation_pair(solution, &rows[left_index], &rows[right_index]);
286            }
287        }
288        total
289    }
290
291    fn match_count(&self, solution: &S) -> usize {
292        let rows = self.evaluate_rows(solution);
293
294        let mut count = 0;
295        for left_index in 0..rows.len() {
296            for right_index in (left_index + 1)..rows.len() {
297                if self.evaluation_pair_matches(solution, &rows[left_index], &rows[right_index]) {
298                    count += 1;
299                }
300            }
301        }
302        count
303    }
304
305    fn initialize(&mut self, solution: &S) -> Sc {
306        self.reset();
307        let mut rows = Vec::new();
308        self.source.collect_all(solution, |coordinate, output| {
309            if self.filter.test(solution, &output) {
310                rows.push((coordinate, output));
311            }
312        });
313
314        rows.into_iter()
315            .fold(Sc::zero(), |total, (coordinate, output)| {
316                total + self.insert_row(solution, coordinate, output)
317            })
318    }
319
320    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
321        let mut total = Sc::zero();
322        for slot in self.localized_slots(descriptor_index) {
323            total = total + self.insert_entity_outputs(solution, slot, entity_index);
324        }
325        total
326    }
327
328    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
329        let mut total = Sc::zero();
330        for slot in self.localized_slots(descriptor_index) {
331            total = total + self.retract_entity_outputs(solution, slot, entity_index);
332        }
333        total
334    }
335
336    fn reset(&mut self) {
337        self.rows.clear();
338        self.free_row_ids.clear();
339        self.rows_by_entity.clear();
340        self.rows_by_key.clear();
341    }
342
343    fn name(&self) -> &str {
344        &self.constraint_ref.name
345    }
346
347    fn is_hard(&self) -> bool {
348        self.is_hard
349    }
350
351    fn constraint_ref(&self) -> ConstraintRef {
352        self.constraint_ref.clone()
353    }
354}