Skip to main content

solverforge_scoring/stream/
bi_stream.rs

/*! Zero-erasure bi-constraint stream for self-join patterns.

A `BiConstraintStream` operates on pairs of entities from the same
collection (self-join), for example comparing shifts with each other.
All type information is preserved at compile time: no `Arc`, no `dyn`.

# Example

```
use solverforge_scoring::stream::ConstraintFactory;
use solverforge_scoring::stream::joiner::equal;
use solverforge_scoring::api::constraint_set::IncrementalConstraint;
use solverforge_core::score::SoftScore;

#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct Task { team: u32 }

#[derive(Clone)]
struct Solution { tasks: Vec<Task> }

// Penalize when two tasks are on the same team
let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    .for_each(|s: &Solution| s.tasks.as_slice())
    .join(equal(|t: &Task| t.team))
    .penalize(SoftScore::of(1))
    .named("Team conflict");

let solution = Solution {
    tasks: vec![
        Task { team: 1 },
        Task { team: 1 },
        Task { team: 2 },
    ],
};

// One pair on team 1: (0, 1) = -1 penalty
assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
```
*/
40
41use std::hash::Hash;
42
43use solverforge_core::score::Score;
44
45use crate::constraint::IncrementalBiConstraint;
46
47use super::filter::{BiFilter, FnTriFilter, TriFilter};
48use super::joiner::Joiner;
49use super::tri_stream::TriConstraintStream;
50
51super::arity_stream_macros::impl_arity_stream!(
52    bi,
53    BiConstraintStream,
54    BiConstraintBuilder,
55    IncrementalBiConstraint
56);
57
58// join method - transitions to TriConstraintStream
59impl<S, A, K, E, KE, F, Sc> BiConstraintStream<S, A, K, E, KE, F, Sc>
60where
61    S: Send + Sync + 'static,
62    A: Clone + Hash + PartialEq + Send + Sync + 'static,
63    K: Eq + Hash + Clone + Send + Sync,
64    E: super::collection_extract::CollectionExtract<S, Item = A>,
65    KE: Fn(&S, &A, usize) -> K + Send + Sync,
66    F: BiFilter<S, A, A>,
67    Sc: Score + 'static,
68{
69    /* Joins this stream with a third element to create triples.
70
71    # Example
72
73    ```
74    use solverforge_scoring::stream::ConstraintFactory;
75    use solverforge_scoring::stream::joiner::equal;
76    use solverforge_scoring::api::constraint_set::IncrementalConstraint;
77    use solverforge_core::score::SoftScore;
78
79    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
80    struct Task { team: u32 }
81
82    #[derive(Clone)]
83    struct Solution { tasks: Vec<Task> }
84
85    // Penalize when three tasks are on the same team
86    let constraint = ConstraintFactory::<Solution, SoftScore>::new()
87    .for_each(|s: &Solution| s.tasks.as_slice())
88    .join(equal(|t: &Task| t.team))
89    .join(equal(|t: &Task| t.team))
90    .penalize(SoftScore::of(1))
91    .named("Team clustering");
92
93    let solution = Solution {
94    tasks: vec![
95    Task { team: 1 },
96    Task { team: 1 },
97    Task { team: 1 },
98    Task { team: 2 },
99    ],
100    };
101
102    // One triple on team 1: (0, 1, 2) = -1 penalty
103    assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
104    ```
105    */
106    pub fn join<J>(
107        self,
108        joiner: J,
109    ) -> TriConstraintStream<S, A, K, E, KE, impl TriFilter<S, A, A, A>, Sc>
110    where
111        J: Joiner<A, A> + 'static,
112        F: 'static,
113    {
114        let filter = self.filter;
115        let combined_filter =
116            move |s: &S, a: &A, b: &A, c: &A| filter.test(s, a, b, 0, 0) && joiner.matches(a, c);
117
118        TriConstraintStream::new_self_join_with_filter(
119            self.extractor,
120            self.key_extractor,
121            FnTriFilter::new(combined_filter),
122        )
123    }
124}
125
// Doctests for the pair-stream builder methods (`filter`, `penalize`,
// `penalize_with`, `reward`, `named`). The examples live in this module's inner
// doc comment so rustdoc collects and runs them. `#[cfg(doctest)]` keeps the
// otherwise-empty module out of regular builds.
#[cfg(doctest)]
mod doctests {
    /*!
    # Filter method

    ```
    use solverforge_scoring::stream::ConstraintFactory;
    use solverforge_scoring::stream::joiner::equal;
    use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    use solverforge_core::score::SoftScore;

    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    struct Item { group: u32, value: i32 }

    #[derive(Clone)]
    struct Solution { items: Vec<Item> }

    let constraint = ConstraintFactory::<Solution, SoftScore>::new()
        .for_each(|s: &Solution| s.items.as_slice())
        .join(equal(|i: &Item| i.group))
        .filter(|a: &Item, b: &Item| a.value + b.value > 10)
        .penalize(SoftScore::of(1))
        .named("High sum pairs");

    let solution = Solution {
        items: vec![
            Item { group: 1, value: 6 },
            Item { group: 1, value: 7 },
        ],
    };

    // 6 + 7 = 13 > 10, so the pair matches
    assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
    ```

    # Penalize method

    ```
    use solverforge_scoring::stream::ConstraintFactory;
    use solverforge_scoring::stream::joiner::equal;
    use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    use solverforge_core::score::SoftScore;

    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    struct Task { priority: u32 }

    #[derive(Clone)]
    struct Solution { tasks: Vec<Task> }

    let constraint = ConstraintFactory::<Solution, SoftScore>::new()
        .for_each(|s: &Solution| s.tasks.as_slice())
        .join(equal(|t: &Task| t.priority))
        .penalize(SoftScore::of(5))
        .named("Pair priority conflict");

    let solution = Solution {
        tasks: vec![
            Task { priority: 1 },
            Task { priority: 1 },
        ],
    };

    // One pair = -5
    assert_eq!(constraint.evaluate(&solution), SoftScore::of(-5));
    ```

    # Penalize with dynamic weight

    ```
    use solverforge_scoring::stream::ConstraintFactory;
    use solverforge_scoring::stream::joiner::equal;
    use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    use solverforge_core::score::SoftScore;

    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    struct Task { team: u32, cost: i64 }

    #[derive(Clone)]
    struct Solution { tasks: Vec<Task> }

    let constraint = ConstraintFactory::<Solution, SoftScore>::new()
        .for_each(|s: &Solution| s.tasks.as_slice())
        .join(equal(|t: &Task| t.team))
        .penalize_with(|a: &Task, b: &Task| {
            SoftScore::of(a.cost + b.cost)
        })
        .named("Team cost");

    let solution = Solution {
        tasks: vec![
            Task { team: 1, cost: 2 },
            Task { team: 1, cost: 3 },
        ],
    };

    // Penalty: 2 + 3 = -5
    assert_eq!(constraint.evaluate(&solution), SoftScore::of(-5));
    ```

    # Reward method

    ```
    use solverforge_scoring::stream::ConstraintFactory;
    use solverforge_scoring::stream::joiner::equal;
    use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    use solverforge_core::score::SoftScore;

    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    struct Person { team: u32 }

    #[derive(Clone)]
    struct Solution { people: Vec<Person> }

    let constraint = ConstraintFactory::<Solution, SoftScore>::new()
        .for_each(|s: &Solution| s.people.as_slice())
        .join(equal(|p: &Person| p.team))
        .reward(SoftScore::of(10))
        .named("Team synergy");

    let solution = Solution {
        people: vec![
            Person { team: 1 },
            Person { team: 1 },
        ],
    };

    // One pair = +10
    assert_eq!(constraint.evaluate(&solution), SoftScore::of(10));
    ```

    # Named method

    ```
    use solverforge_scoring::stream::ConstraintFactory;
    use solverforge_scoring::stream::joiner::equal;
    use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    use solverforge_core::score::SoftScore;

    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    struct Item { id: usize }

    #[derive(Clone)]
    struct Solution { items: Vec<Item> }

    let constraint = ConstraintFactory::<Solution, SoftScore>::new()
        .for_each(|s: &Solution| s.items.as_slice())
        .join(equal(|i: &Item| i.id))
        .penalize(SoftScore::of(1))
        .named("Pair items");

    assert_eq!(constraint.name(), "Pair items");
    ```
    */
}