Skip to main content

solverforge_scoring/constraint/projected/
bi.rs

1use std::collections::{HashMap, HashSet};
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::filter::{BiFilter, UniFilter};
10use crate::stream::{ProjectedRowCoordinate, ProjectedRowOwner, ProjectedSource};
11
12struct ProjectedJoinRow<Out> {
13    output: Out,
14    coordinate: ProjectedRowCoordinate,
15}
16
17pub struct ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
18where
19    Src: ProjectedSource<S, Out>,
20    Sc: Score,
21{
22    constraint_ref: ConstraintRef,
23    impact_type: ImpactType,
24    source: Src,
25    filter: F,
26    key_fn: KF,
27    pair_filter: PF,
28    weight: W,
29    is_hard: bool,
30    source_state: Option<Src::State>,
31    rows: Vec<Option<ProjectedJoinRow<Out>>>,
32    free_row_ids: Vec<usize>,
33    rows_by_owner: HashMap<ProjectedRowOwner, Vec<usize>>,
34    row_ids_by_coordinate: HashMap<ProjectedRowCoordinate, usize>,
35    rows_by_key: HashMap<K, Vec<usize>>,
36    _phantom: PhantomData<(fn() -> S, fn() -> Out, fn() -> Sc)>,
37}
38
39impl<S, Out, K, Src, F, KF, PF, W, Sc> ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
40where
41    S: Send + Sync + 'static,
42    Out: Send + Sync + 'static,
43    K: Eq + Hash + Send + Sync + 'static,
44    Src: ProjectedSource<S, Out>,
45    F: UniFilter<S, Out>,
46    KF: Fn(&Out) -> K + Send + Sync,
47    PF: BiFilter<S, Out, Out>,
48    W: Fn(&Out, &Out) -> Sc + Send + Sync,
49    Sc: Score + 'static,
50{
51    #[allow(clippy::too_many_arguments)]
52    pub fn new(
53        constraint_ref: ConstraintRef,
54        impact_type: ImpactType,
55        source: Src,
56        filter: F,
57        key_fn: KF,
58        pair_filter: PF,
59        weight: W,
60        is_hard: bool,
61    ) -> Self {
62        Self {
63            constraint_ref,
64            impact_type,
65            source,
66            filter,
67            key_fn,
68            pair_filter,
69            weight,
70            is_hard,
71            source_state: None,
72            rows: Vec::new(),
73            free_row_ids: Vec::new(),
74            rows_by_owner: HashMap::new(),
75            row_ids_by_coordinate: HashMap::new(),
76            rows_by_key: HashMap::new(),
77            _phantom: PhantomData,
78        }
79    }
80
81    fn compute_score(&self, left: &Out, right: &Out) -> Sc {
82        let base = (self.weight)(left, right);
83        match self.impact_type {
84            ImpactType::Penalty => -base,
85            ImpactType::Reward => base,
86        }
87    }
88
89    fn score_outputs(
90        &self,
91        solution: &S,
92        left: &Out,
93        right: &Out,
94        left_idx: usize,
95        right_idx: usize,
96    ) -> Sc {
97        if !self
98            .pair_filter
99            .test(solution, left, right, left_idx, right_idx)
100        {
101            return Sc::zero();
102        }
103        self.compute_score(left, right)
104    }
105
106    fn filter_index(coordinate: ProjectedRowCoordinate) -> usize {
107        coordinate.primary_owner.entity_index
108    }
109
110    fn score_retained_rows(
111        &self,
112        solution: &S,
113        first: &ProjectedJoinRow<Out>,
114        second: &ProjectedJoinRow<Out>,
115    ) -> Sc {
116        let (left, right) = if first.coordinate <= second.coordinate {
117            (first, second)
118        } else {
119            (second, first)
120        };
121        self.score_outputs(
122            solution,
123            &left.output,
124            &right.output,
125            Self::filter_index(left.coordinate),
126            Self::filter_index(right.coordinate),
127        )
128    }
129
130    fn score_candidate_row(
131        &self,
132        solution: &S,
133        candidate_output: &Out,
134        candidate_coordinate: ProjectedRowCoordinate,
135        other: &ProjectedJoinRow<Out>,
136    ) -> Sc {
137        let (left, right, left_idx, right_idx) = if candidate_coordinate <= other.coordinate {
138            (
139                candidate_output,
140                &other.output,
141                Self::filter_index(candidate_coordinate),
142                Self::filter_index(other.coordinate),
143            )
144        } else {
145            (
146                &other.output,
147                candidate_output,
148                Self::filter_index(other.coordinate),
149                Self::filter_index(candidate_coordinate),
150            )
151        };
152        self.score_outputs(solution, left, right, left_idx, right_idx)
153    }
154
155    fn score_pair(&self, solution: &S, first_id: usize, second_id: usize) -> Sc {
156        let Some(first) = self.rows.get(first_id).and_then(Option::as_ref) else {
157            return Sc::zero();
158        };
159        let Some(second) = self.rows.get(second_id).and_then(Option::as_ref) else {
160            return Sc::zero();
161        };
162        self.score_retained_rows(solution, first, second)
163    }
164
165    fn ensure_source_state(&mut self, solution: &S) {
166        if self.source_state.is_none() {
167            self.source_state = Some(self.source.build_state(solution));
168        }
169    }
170
171    fn index_row_owners(&mut self, coordinate: ProjectedRowCoordinate, row_id: usize) {
172        coordinate.for_each_owner(|owner| {
173            self.rows_by_owner.entry(owner).or_default().push(row_id);
174        });
175    }
176
177    fn unindex_row_owners(&mut self, coordinate: ProjectedRowCoordinate, row_id: usize) {
178        coordinate.for_each_owner(|owner| {
179            let mut remove_bucket = false;
180            if let Some(ids) = self.rows_by_owner.get_mut(&owner) {
181                ids.retain(|candidate| *candidate != row_id);
182                remove_bucket = ids.is_empty();
183            }
184            if remove_bucket {
185                self.rows_by_owner.remove(&owner);
186            }
187        });
188    }
189
190    fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
191        if self.row_ids_by_coordinate.contains_key(&coordinate) {
192            return Sc::zero();
193        }
194        let key = (self.key_fn)(&output);
195        let mut total = Sc::zero();
196        if let Some(existing) = self.rows_by_key.get(&key) {
197            for &other_id in existing {
198                if let Some(other) = self.rows.get(other_id).and_then(Option::as_ref) {
199                    total = total + self.score_candidate_row(solution, &output, coordinate, other);
200                }
201            }
202        }
203        let row = Some(ProjectedJoinRow { output, coordinate });
204        let row_id = if let Some(row_id) = self.free_row_ids.pop() {
205            debug_assert!(self.rows[row_id].is_none());
206            self.rows[row_id] = row;
207            row_id
208        } else {
209            let row_id = self.rows.len();
210            self.rows.push(row);
211            row_id
212        };
213        self.row_ids_by_coordinate.insert(coordinate, row_id);
214        self.index_row_owners(coordinate, row_id);
215        self.rows_by_key.entry(key).or_default().push(row_id);
216        total
217    }
218
219    fn retract_row(&mut self, solution: &S, row_id: usize) -> Sc {
220        let Some((key, coordinate)) = self
221            .rows
222            .get(row_id)
223            .and_then(Option::as_ref)
224            .map(|row| ((self.key_fn)(&row.output), row.coordinate))
225        else {
226            return Sc::zero();
227        };
228        let mut total = Sc::zero();
229        if let Some(candidates) = self.rows_by_key.get(&key) {
230            for &other_id in candidates {
231                if other_id == row_id {
232                    continue;
233                }
234                total = total - self.score_pair(solution, row_id, other_id);
235            }
236        }
237
238        if let Some(ids) = self.rows_by_key.get_mut(&key) {
239            ids.retain(|&id| id != row_id);
240            if ids.is_empty() {
241                self.rows_by_key.remove(&key);
242            }
243        }
244        self.row_ids_by_coordinate.remove(&coordinate);
245        self.unindex_row_owners(coordinate, row_id);
246        self.rows[row_id] = None;
247        self.free_row_ids.push(row_id);
248        total
249    }
250
251    fn evaluate_rows(&self, solution: &S) -> Vec<ProjectedJoinRow<Out>> {
252        let state = self.source.build_state(solution);
253        let mut rows = Vec::new();
254        self.source
255            .collect_all(solution, &state, |coordinate, output| {
256                if self.filter.test(solution, &output) {
257                    rows.push(ProjectedJoinRow { output, coordinate });
258                }
259            });
260        rows
261    }
262
263    fn score_evaluation_pair(
264        &self,
265        solution: &S,
266        first: &ProjectedJoinRow<Out>,
267        second: &ProjectedJoinRow<Out>,
268    ) -> Sc {
269        if (self.key_fn)(&first.output) == (self.key_fn)(&second.output) {
270            let (left, right) = if first.coordinate <= second.coordinate {
271                (first, second)
272            } else {
273                (second, first)
274            };
275            self.score_outputs(
276                solution,
277                &left.output,
278                &right.output,
279                Self::filter_index(left.coordinate),
280                Self::filter_index(right.coordinate),
281            )
282        } else {
283            Sc::zero()
284        }
285    }
286
287    fn evaluation_pair_matches(
288        &self,
289        solution: &S,
290        first: &ProjectedJoinRow<Out>,
291        second: &ProjectedJoinRow<Out>,
292    ) -> bool {
293        if (self.key_fn)(&first.output) != (self.key_fn)(&second.output) {
294            return false;
295        }
296        let (left, right) = if first.coordinate <= second.coordinate {
297            (first, second)
298        } else {
299            (second, first)
300        };
301        self.pair_filter.test(
302            solution,
303            &left.output,
304            &right.output,
305            Self::filter_index(left.coordinate),
306            Self::filter_index(right.coordinate),
307        )
308    }
309
310    fn localized_owners(
311        &self,
312        descriptor_index: usize,
313        entity_index: usize,
314    ) -> Vec<ProjectedRowOwner> {
315        let mut owners = Vec::new();
316        for slot in 0..self.source.source_count() {
317            if self
318                .source
319                .change_source(slot)
320                .assert_localizes(descriptor_index, &self.constraint_ref.name)
321            {
322                owners.push(ProjectedRowOwner {
323                    source_slot: slot,
324                    entity_index,
325                });
326            }
327        }
328        owners
329    }
330
331    fn row_ids_for_owners(&self, owners: &[ProjectedRowOwner]) -> Vec<usize> {
332        let mut seen = HashSet::new();
333        let mut row_ids = Vec::new();
334        for owner in owners {
335            let Some(ids) = self.rows_by_owner.get(owner) else {
336                continue;
337            };
338            for &row_id in ids {
339                if seen.insert(row_id) {
340                    row_ids.push(row_id);
341                }
342            }
343        }
344        row_ids
345    }
346
347    #[cfg(test)]
348    pub(crate) fn debug_row_storage_len(&self) -> usize {
349        self.rows.len()
350    }
351
352    #[cfg(test)]
353    pub(crate) fn debug_free_row_count(&self) -> usize {
354        self.free_row_ids.len()
355    }
356}
357
358impl<S, Out, K, Src, F, KF, PF, W, Sc> IncrementalConstraint<S, Sc>
359    for ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
360where
361    S: Send + Sync + 'static,
362    Out: Send + Sync + 'static,
363    K: Eq + Hash + Send + Sync + 'static,
364    Src: ProjectedSource<S, Out>,
365    F: UniFilter<S, Out>,
366    KF: Fn(&Out) -> K + Send + Sync,
367    PF: BiFilter<S, Out, Out>,
368    W: Fn(&Out, &Out) -> Sc + Send + Sync,
369    Sc: Score + 'static,
370{
371    fn evaluate(&self, solution: &S) -> Sc {
372        let rows = self.evaluate_rows(solution);
373
374        let mut total = Sc::zero();
375        for left_index in 0..rows.len() {
376            for right_index in (left_index + 1)..rows.len() {
377                total = total
378                    + self.score_evaluation_pair(solution, &rows[left_index], &rows[right_index]);
379            }
380        }
381        total
382    }
383
384    fn match_count(&self, solution: &S) -> usize {
385        let rows = self.evaluate_rows(solution);
386
387        let mut count = 0;
388        for left_index in 0..rows.len() {
389            for right_index in (left_index + 1)..rows.len() {
390                if self.evaluation_pair_matches(solution, &rows[left_index], &rows[right_index]) {
391                    count += 1;
392                }
393            }
394        }
395        count
396    }
397
398    fn initialize(&mut self, solution: &S) -> Sc {
399        self.reset();
400        let state = self.source.build_state(solution);
401        let mut rows = Vec::new();
402        self.source
403            .collect_all(solution, &state, |coordinate, output| {
404                if self.filter.test(solution, &output) {
405                    rows.push((coordinate, output));
406                }
407            });
408        self.source_state = Some(state);
409
410        rows.into_iter()
411            .fold(Sc::zero(), |total, (coordinate, output)| {
412                total + self.insert_row(solution, coordinate, output)
413            })
414    }
415
416    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
417        let owners = self.localized_owners(descriptor_index, entity_index);
418        self.ensure_source_state(solution);
419        {
420            let state = self.source_state.as_mut().expect("projected source state");
421            for owner in &owners {
422                self.source.insert_entity_state(
423                    solution,
424                    state,
425                    owner.source_slot,
426                    owner.entity_index,
427                );
428            }
429        }
430        let mut rows = Vec::new();
431        let state = self.source_state.as_ref().expect("projected source state");
432        for owner in &owners {
433            self.source.collect_entity(
434                solution,
435                state,
436                owner.source_slot,
437                owner.entity_index,
438                |coordinate, output| {
439                    if self.filter.test(solution, &output) {
440                        rows.push((coordinate, output));
441                    }
442                },
443            );
444        }
445        let mut total = Sc::zero();
446        for (coordinate, output) in rows {
447            total = total + self.insert_row(solution, coordinate, output);
448        }
449        total
450    }
451
452    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
453        let owners = self.localized_owners(descriptor_index, entity_index);
454        let mut total = Sc::zero();
455        for row_id in self.row_ids_for_owners(&owners) {
456            total = total + self.retract_row(solution, row_id);
457        }
458        if let Some(state) = self.source_state.as_mut() {
459            for owner in &owners {
460                self.source.retract_entity_state(
461                    solution,
462                    state,
463                    owner.source_slot,
464                    owner.entity_index,
465                );
466            }
467        }
468        total
469    }
470
471    fn reset(&mut self) {
472        self.source_state = None;
473        self.rows.clear();
474        self.free_row_ids.clear();
475        self.rows_by_owner.clear();
476        self.row_ids_by_coordinate.clear();
477        self.rows_by_key.clear();
478    }
479
480    fn name(&self) -> &str {
481        &self.constraint_ref.name
482    }
483
484    fn constraint_ref(&self) -> &ConstraintRef {
485        &self.constraint_ref
486    }
487
488    fn is_hard(&self) -> bool {
489        self.is_hard
490    }
491}