// Source file: solverforge_scoring/constraint/projected/bi.rs

1use std::collections::{HashMap, HashSet};
2use std::hash::Hash;
3use std::marker::PhantomData;
4
5use solverforge_core::score::Score;
6use solverforge_core::{ConstraintRef, ImpactType};
7
8use crate::api::constraint_set::IncrementalConstraint;
9use crate::stream::filter::{BiFilter, UniFilter};
10use crate::stream::{ProjectedRowCoordinate, ProjectedRowOwner, ProjectedSource};
11
12struct ProjectedJoinRow<Out> {
13    output: Out,
14    coordinate: ProjectedRowCoordinate,
15}
16
17pub struct ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
18where
19    Src: ProjectedSource<S, Out>,
20    Sc: Score,
21{
22    constraint_ref: ConstraintRef,
23    impact_type: ImpactType,
24    source: Src,
25    filter: F,
26    key_fn: KF,
27    pair_filter: PF,
28    weight: W,
29    is_hard: bool,
30    source_state: Option<Src::State>,
31    rows: Vec<Option<ProjectedJoinRow<Out>>>,
32    free_row_ids: Vec<usize>,
33    rows_by_owner: HashMap<ProjectedRowOwner, Vec<usize>>,
34    row_ids_by_coordinate: HashMap<ProjectedRowCoordinate, usize>,
35    rows_by_key: HashMap<K, Vec<usize>>,
36    _phantom: PhantomData<(fn() -> S, fn() -> Out, fn() -> Sc)>,
37}
38
39impl<S, Out, K, Src, F, KF, PF, W, Sc> ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
40where
41    S: Send + Sync + 'static,
42    Out: Send + Sync + 'static,
43    K: Eq + Hash + Send + Sync + 'static,
44    Src: ProjectedSource<S, Out>,
45    F: UniFilter<S, Out>,
46    KF: Fn(&Out) -> K + Send + Sync,
47    PF: BiFilter<S, Out, Out>,
48    W: Fn(&Out, &Out) -> Sc + Send + Sync,
49    Sc: Score + 'static,
50{
51    #[allow(clippy::too_many_arguments)]
52    pub fn new(
53        constraint_ref: ConstraintRef,
54        impact_type: ImpactType,
55        source: Src,
56        filter: F,
57        key_fn: KF,
58        pair_filter: PF,
59        weight: W,
60        is_hard: bool,
61    ) -> Self {
62        Self {
63            constraint_ref,
64            impact_type,
65            source,
66            filter,
67            key_fn,
68            pair_filter,
69            weight,
70            is_hard,
71            source_state: None,
72            rows: Vec::new(),
73            free_row_ids: Vec::new(),
74            rows_by_owner: HashMap::new(),
75            row_ids_by_coordinate: HashMap::new(),
76            rows_by_key: HashMap::new(),
77            _phantom: PhantomData,
78        }
79    }
80
81    fn compute_score(&self, left: &Out, right: &Out) -> Sc {
82        let base = (self.weight)(left, right);
83        match self.impact_type {
84            ImpactType::Penalty => -base,
85            ImpactType::Reward => base,
86        }
87    }
88
89    fn score_ordered_rows(
90        &self,
91        solution: &S,
92        first: &ProjectedJoinRow<Out>,
93        second: &ProjectedJoinRow<Out>,
94    ) -> Sc {
95        let (left, right) = if first.coordinate <= second.coordinate {
96            (first, second)
97        } else {
98            (second, first)
99        };
100        if !self
101            .pair_filter
102            .test(solution, &left.output, &right.output, 0, 1)
103        {
104            return Sc::zero();
105        }
106        self.compute_score(&left.output, &right.output)
107    }
108
109    fn score_candidate_row(
110        &self,
111        solution: &S,
112        candidate_output: &Out,
113        candidate_coordinate: ProjectedRowCoordinate,
114        other: &ProjectedJoinRow<Out>,
115    ) -> Sc {
116        let (left, right) = if candidate_coordinate <= other.coordinate {
117            (candidate_output, &other.output)
118        } else {
119            (&other.output, candidate_output)
120        };
121        if !self.pair_filter.test(solution, left, right, 0, 1) {
122            return Sc::zero();
123        }
124        self.compute_score(left, right)
125    }
126
127    fn score_pair(&self, solution: &S, first_id: usize, second_id: usize) -> Sc {
128        let Some(first) = self.rows.get(first_id).and_then(Option::as_ref) else {
129            return Sc::zero();
130        };
131        let Some(second) = self.rows.get(second_id).and_then(Option::as_ref) else {
132            return Sc::zero();
133        };
134        self.score_ordered_rows(solution, first, second)
135    }
136
137    fn ensure_source_state(&mut self, solution: &S) {
138        if self.source_state.is_none() {
139            self.source_state = Some(self.source.build_state(solution));
140        }
141    }
142
143    fn index_row_owners(&mut self, coordinate: ProjectedRowCoordinate, row_id: usize) {
144        coordinate.for_each_owner(|owner| {
145            self.rows_by_owner.entry(owner).or_default().push(row_id);
146        });
147    }
148
149    fn unindex_row_owners(&mut self, coordinate: ProjectedRowCoordinate, row_id: usize) {
150        coordinate.for_each_owner(|owner| {
151            let mut remove_bucket = false;
152            if let Some(ids) = self.rows_by_owner.get_mut(&owner) {
153                ids.retain(|candidate| *candidate != row_id);
154                remove_bucket = ids.is_empty();
155            }
156            if remove_bucket {
157                self.rows_by_owner.remove(&owner);
158            }
159        });
160    }
161
162    fn insert_row(&mut self, solution: &S, coordinate: ProjectedRowCoordinate, output: Out) -> Sc {
163        if self.row_ids_by_coordinate.contains_key(&coordinate) {
164            return Sc::zero();
165        }
166        let key = (self.key_fn)(&output);
167        let mut total = Sc::zero();
168        if let Some(existing) = self.rows_by_key.get(&key) {
169            for &other_id in existing {
170                if let Some(other) = self.rows.get(other_id).and_then(Option::as_ref) {
171                    total = total + self.score_candidate_row(solution, &output, coordinate, other);
172                }
173            }
174        }
175        let row = Some(ProjectedJoinRow { output, coordinate });
176        let row_id = if let Some(row_id) = self.free_row_ids.pop() {
177            debug_assert!(self.rows[row_id].is_none());
178            self.rows[row_id] = row;
179            row_id
180        } else {
181            let row_id = self.rows.len();
182            self.rows.push(row);
183            row_id
184        };
185        self.row_ids_by_coordinate.insert(coordinate, row_id);
186        self.index_row_owners(coordinate, row_id);
187        self.rows_by_key.entry(key).or_default().push(row_id);
188        total
189    }
190
191    fn retract_row(&mut self, solution: &S, row_id: usize) -> Sc {
192        let Some((key, coordinate)) = self
193            .rows
194            .get(row_id)
195            .and_then(Option::as_ref)
196            .map(|row| ((self.key_fn)(&row.output), row.coordinate))
197        else {
198            return Sc::zero();
199        };
200        let mut total = Sc::zero();
201        if let Some(candidates) = self.rows_by_key.get(&key) {
202            for &other_id in candidates {
203                if other_id == row_id {
204                    continue;
205                }
206                total = total - self.score_pair(solution, row_id, other_id);
207            }
208        }
209
210        if let Some(ids) = self.rows_by_key.get_mut(&key) {
211            ids.retain(|&id| id != row_id);
212            if ids.is_empty() {
213                self.rows_by_key.remove(&key);
214            }
215        }
216        self.row_ids_by_coordinate.remove(&coordinate);
217        self.unindex_row_owners(coordinate, row_id);
218        self.rows[row_id] = None;
219        self.free_row_ids.push(row_id);
220        total
221    }
222
223    fn evaluate_rows(&self, solution: &S) -> Vec<ProjectedJoinRow<Out>> {
224        let state = self.source.build_state(solution);
225        let mut rows = Vec::new();
226        self.source
227            .collect_all(solution, &state, |coordinate, output| {
228                if self.filter.test(solution, &output) {
229                    rows.push(ProjectedJoinRow { output, coordinate });
230                }
231            });
232        rows
233    }
234
235    fn score_evaluation_pair(
236        &self,
237        solution: &S,
238        first: &ProjectedJoinRow<Out>,
239        second: &ProjectedJoinRow<Out>,
240    ) -> Sc {
241        if (self.key_fn)(&first.output) == (self.key_fn)(&second.output) {
242            self.score_ordered_rows(solution, first, second)
243        } else {
244            Sc::zero()
245        }
246    }
247
248    fn evaluation_pair_matches(
249        &self,
250        solution: &S,
251        first: &ProjectedJoinRow<Out>,
252        second: &ProjectedJoinRow<Out>,
253    ) -> bool {
254        if (self.key_fn)(&first.output) != (self.key_fn)(&second.output) {
255            return false;
256        }
257        let (left, right) = if first.coordinate <= second.coordinate {
258            (first, second)
259        } else {
260            (second, first)
261        };
262        self.pair_filter
263            .test(solution, &left.output, &right.output, 0, 1)
264    }
265
266    fn localized_owners(
267        &self,
268        descriptor_index: usize,
269        entity_index: usize,
270    ) -> Vec<ProjectedRowOwner> {
271        let mut owners = Vec::new();
272        for slot in 0..self.source.source_count() {
273            if self
274                .source
275                .change_source(slot)
276                .assert_localizes(descriptor_index, &self.constraint_ref.name)
277            {
278                owners.push(ProjectedRowOwner {
279                    source_slot: slot,
280                    entity_index,
281                });
282            }
283        }
284        owners
285    }
286
287    fn row_ids_for_owners(&self, owners: &[ProjectedRowOwner]) -> Vec<usize> {
288        let mut seen = HashSet::new();
289        let mut row_ids = Vec::new();
290        for owner in owners {
291            let Some(ids) = self.rows_by_owner.get(owner) else {
292                continue;
293            };
294            for &row_id in ids {
295                if seen.insert(row_id) {
296                    row_ids.push(row_id);
297                }
298            }
299        }
300        row_ids
301    }
302
303    #[cfg(test)]
304    pub(crate) fn debug_row_storage_len(&self) -> usize {
305        self.rows.len()
306    }
307
308    #[cfg(test)]
309    pub(crate) fn debug_free_row_count(&self) -> usize {
310        self.free_row_ids.len()
311    }
312}
313
314impl<S, Out, K, Src, F, KF, PF, W, Sc> IncrementalConstraint<S, Sc>
315    for ProjectedBiConstraint<S, Out, K, Src, F, KF, PF, W, Sc>
316where
317    S: Send + Sync + 'static,
318    Out: Send + Sync + 'static,
319    K: Eq + Hash + Send + Sync + 'static,
320    Src: ProjectedSource<S, Out>,
321    F: UniFilter<S, Out>,
322    KF: Fn(&Out) -> K + Send + Sync,
323    PF: BiFilter<S, Out, Out>,
324    W: Fn(&Out, &Out) -> Sc + Send + Sync,
325    Sc: Score + 'static,
326{
327    fn evaluate(&self, solution: &S) -> Sc {
328        let rows = self.evaluate_rows(solution);
329
330        let mut total = Sc::zero();
331        for left_index in 0..rows.len() {
332            for right_index in (left_index + 1)..rows.len() {
333                total = total
334                    + self.score_evaluation_pair(solution, &rows[left_index], &rows[right_index]);
335            }
336        }
337        total
338    }
339
340    fn match_count(&self, solution: &S) -> usize {
341        let rows = self.evaluate_rows(solution);
342
343        let mut count = 0;
344        for left_index in 0..rows.len() {
345            for right_index in (left_index + 1)..rows.len() {
346                if self.evaluation_pair_matches(solution, &rows[left_index], &rows[right_index]) {
347                    count += 1;
348                }
349            }
350        }
351        count
352    }
353
354    fn initialize(&mut self, solution: &S) -> Sc {
355        self.reset();
356        let state = self.source.build_state(solution);
357        let mut rows = Vec::new();
358        self.source
359            .collect_all(solution, &state, |coordinate, output| {
360                if self.filter.test(solution, &output) {
361                    rows.push((coordinate, output));
362                }
363            });
364        self.source_state = Some(state);
365
366        rows.into_iter()
367            .fold(Sc::zero(), |total, (coordinate, output)| {
368                total + self.insert_row(solution, coordinate, output)
369            })
370    }
371
372    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
373        let owners = self.localized_owners(descriptor_index, entity_index);
374        self.ensure_source_state(solution);
375        {
376            let state = self.source_state.as_mut().expect("projected source state");
377            for owner in &owners {
378                self.source.insert_entity_state(
379                    solution,
380                    state,
381                    owner.source_slot,
382                    owner.entity_index,
383                );
384            }
385        }
386        let mut rows = Vec::new();
387        let state = self.source_state.as_ref().expect("projected source state");
388        for owner in &owners {
389            self.source.collect_entity(
390                solution,
391                state,
392                owner.source_slot,
393                owner.entity_index,
394                |coordinate, output| {
395                    if self.filter.test(solution, &output) {
396                        rows.push((coordinate, output));
397                    }
398                },
399            );
400        }
401        let mut total = Sc::zero();
402        for (coordinate, output) in rows {
403            total = total + self.insert_row(solution, coordinate, output);
404        }
405        total
406    }
407
408    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
409        let owners = self.localized_owners(descriptor_index, entity_index);
410        let mut total = Sc::zero();
411        for row_id in self.row_ids_for_owners(&owners) {
412            total = total + self.retract_row(solution, row_id);
413        }
414        if let Some(state) = self.source_state.as_mut() {
415            for owner in &owners {
416                self.source.retract_entity_state(
417                    solution,
418                    state,
419                    owner.source_slot,
420                    owner.entity_index,
421                );
422            }
423        }
424        total
425    }
426
427    fn reset(&mut self) {
428        self.source_state = None;
429        self.rows.clear();
430        self.free_row_ids.clear();
431        self.rows_by_owner.clear();
432        self.row_ids_by_coordinate.clear();
433        self.rows_by_key.clear();
434    }
435
436    fn name(&self) -> &str {
437        &self.constraint_ref.name
438    }
439
440    fn constraint_ref(&self) -> &ConstraintRef {
441        &self.constraint_ref
442    }
443
444    fn is_hard(&self) -> bool {
445        self.is_hard
446    }
447}