Skip to main content

solverforge_scoring/stream/
bi_stream.rs

1// Zero-erasure bi-constraint stream for self-join patterns.
2//
3// A `BiConstraintStream` operates on pairs of entities from the same
4// collection (self-join), such as comparing Shifts to each other.
5// All type information is preserved at compile time - no Arc, no dyn.
6//
7// # Example
8//
9// ```
10// use solverforge_scoring::stream::ConstraintFactory;
11// use solverforge_scoring::stream::joiner::equal;
12// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
13// use solverforge_core::score::SoftScore;
14//
15// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
16// struct Task { team: u32 }
17//
18// #[derive(Clone)]
19// struct Solution { tasks: Vec<Task> }
20//
21// // Penalize when two tasks are on the same team
22// let constraint = ConstraintFactory::<Solution, SoftScore>::new()
23//     .for_each(|s: &Solution| s.tasks.as_slice())
24//     .join_self(equal(|t: &Task| t.team))
25//     .penalize(SoftScore::of(1))
26//     .as_constraint("Team conflict");
27//
28// let solution = Solution {
29//     tasks: vec![
30//         Task { team: 1 },
31//         Task { team: 1 },
32//         Task { team: 2 },
33//     ],
34// };
35//
36// // One pair on team 1: (0, 1) = -1 penalty
37// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
38// ```
39
40use std::hash::Hash;
41
42use solverforge_core::score::Score;
43
44use crate::constraint::IncrementalBiConstraint;
45
46use super::filter::{BiFilter, FnTriFilter, TriFilter};
47use super::joiner::Joiner;
48use super::tri_stream::TriConstraintStream;
49
50super::arity_stream_macros::impl_arity_stream!(
51    bi,
52    BiConstraintStream,
53    BiConstraintBuilder,
54    IncrementalBiConstraint
55);
56
57// join_self method - transitions to TriConstraintStream
58impl<S, A, K, E, KE, F, Sc> BiConstraintStream<S, A, K, E, KE, F, Sc>
59where
60    S: Send + Sync + 'static,
61    A: Clone + Hash + PartialEq + Send + Sync + 'static,
62    K: Eq + Hash + Clone + Send + Sync,
63    E: Fn(&S) -> &[A] + Send + Sync,
64    KE: Fn(&S, &A, usize) -> K + Send + Sync,
65    F: BiFilter<S, A, A>,
66    Sc: Score + 'static,
67{
68    // Joins this stream with a third element to create triples.
69    //
70    // # Example
71    //
72    // ```
73    // use solverforge_scoring::stream::ConstraintFactory;
74    // use solverforge_scoring::stream::joiner::equal;
75    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
76    // use solverforge_core::score::SoftScore;
77    //
78    // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
79    // struct Task { team: u32 }
80    //
81    // #[derive(Clone)]
82    // struct Solution { tasks: Vec<Task> }
83    //
84    // // Penalize when three tasks are on the same team
85    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
86    //     .for_each(|s: &Solution| s.tasks.as_slice())
87    //     .join_self(equal(|t: &Task| t.team))
88    //     .join_self(equal(|t: &Task| t.team))
89    //     .penalize(SoftScore::of(1))
90    //     .as_constraint("Team clustering");
91    //
92    // let solution = Solution {
93    //     tasks: vec![
94    //         Task { team: 1 },
95    //         Task { team: 1 },
96    //         Task { team: 1 },
97    //         Task { team: 2 },
98    //     ],
99    // };
100    //
101    // // One triple on team 1: (0, 1, 2) = -1 penalty
102    // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
103    // ```
104    pub fn join_self<J>(
105        self,
106        joiner: J,
107    ) -> TriConstraintStream<S, A, K, E, KE, impl TriFilter<S, A, A, A>, Sc>
108    where
109        J: Joiner<A, A> + 'static,
110        F: 'static,
111    {
112        let filter = self.filter;
113        let combined_filter =
114            move |s: &S, a: &A, b: &A, c: &A| filter.test(s, a, b, 0, 0) && joiner.matches(a, c);
115
116        TriConstraintStream::new_self_join_with_filter(
117            self.extractor,
118            self.key_extractor,
119            FnTriFilter::new(combined_filter),
120        )
121    }
122}
123
124// Additional doctests for individual methods
125
126#[cfg(doctest)]
127mod doctests {
128    // # Filter method
129    //
130    // ```
131    // use solverforge_scoring::stream::ConstraintFactory;
132    // use solverforge_scoring::stream::joiner::equal;
133    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
134    // use solverforge_core::score::SoftScore;
135    //
136    // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
137    // struct Item { group: u32, value: i32 }
138    //
139    // #[derive(Clone)]
140    // struct Solution { items: Vec<Item> }
141    //
142    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
143    //     .for_each(|s: &Solution| s.items.as_slice())
144    //     .join_self(equal(|i: &Item| i.group))
145    //     .filter(|a: &Item, b: &Item| a.value + b.value > 10)
146    //     .penalize(SoftScore::of(1))
147    //     .as_constraint("High sum pairs");
148    //
149    // let solution = Solution {
150    //     items: vec![
151    //         Item { group: 1, value: 6 },
152    //         Item { group: 1, value: 7 },
153    //     ],
154    // };
155    //
156    // // 6+7=13 > 10, matches
157    // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
158    // ```
159    //
160    // # Penalize method
161    //
162    // ```
163    // use solverforge_scoring::stream::ConstraintFactory;
164    // use solverforge_scoring::stream::joiner::equal;
165    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
166    // use solverforge_core::score::SoftScore;
167    //
168    // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
169    // struct Task { priority: u32 }
170    //
171    // #[derive(Clone)]
172    // struct Solution { tasks: Vec<Task> }
173    //
174    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
175    //     .for_each(|s: &Solution| s.tasks.as_slice())
176    //     .join_self(equal(|t: &Task| t.priority))
177    //     .penalize(SoftScore::of(5))
178    //     .as_constraint("Pair priority conflict");
179    //
180    // let solution = Solution {
181    //     tasks: vec![
182    //         Task { priority: 1 },
183    //         Task { priority: 1 },
184    //     ],
185    // };
186    //
187    // // One pair = -5
188    // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-5));
189    // ```
190    //
191    // # Penalize with dynamic weight
192    //
193    // ```
194    // use solverforge_scoring::stream::ConstraintFactory;
195    // use solverforge_scoring::stream::joiner::equal;
196    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
197    // use solverforge_core::score::SoftScore;
198    //
199    // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
200    // struct Task { team: u32, cost: i64 }
201    //
202    // #[derive(Clone)]
203    // struct Solution { tasks: Vec<Task> }
204    //
205    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
206    //     .for_each(|s: &Solution| s.tasks.as_slice())
207    //     .join_self(equal(|t: &Task| t.team))
208    //     .penalize_with(|a: &Task, b: &Task| {
209    //         SoftScore::of(a.cost + b.cost)
210    //     })
211    //     .as_constraint("Team cost");
212    //
213    // let solution = Solution {
214    //     tasks: vec![
215    //         Task { team: 1, cost: 2 },
216    //         Task { team: 1, cost: 3 },
217    //     ],
218    // };
219    //
220    // // Penalty: 2+3 = -5
221    // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-5));
222    // ```
223    //
224    // # Reward method
225    //
226    // ```
227    // use solverforge_scoring::stream::ConstraintFactory;
228    // use solverforge_scoring::stream::joiner::equal;
229    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
230    // use solverforge_core::score::SoftScore;
231    //
232    // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
233    // struct Person { team: u32 }
234    //
235    // #[derive(Clone)]
236    // struct Solution { people: Vec<Person> }
237    //
238    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
239    //     .for_each(|s: &Solution| s.people.as_slice())
240    //     .join_self(equal(|p: &Person| p.team))
241    //     .reward(SoftScore::of(10))
242    //     .as_constraint("Team synergy");
243    //
244    // let solution = Solution {
245    //     people: vec![
246    //         Person { team: 1 },
247    //         Person { team: 1 },
248    //     ],
249    // };
250    //
251    // // One pair = +10
252    // assert_eq!(constraint.evaluate(&solution), SoftScore::of(10));
253    // ```
254    //
255    // # as_constraint method
256    //
257    // ```
258    // use solverforge_scoring::stream::ConstraintFactory;
259    // use solverforge_scoring::stream::joiner::equal;
260    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
261    // use solverforge_core::score::SoftScore;
262    //
263    // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
264    // struct Item { id: usize }
265    //
266    // #[derive(Clone)]
267    // struct Solution { items: Vec<Item> }
268    //
269    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
270    //     .for_each(|s: &Solution| s.items.as_slice())
271    //     .join_self(equal(|i: &Item| i.id))
272    //     .penalize(SoftScore::of(1))
273    //     .as_constraint("Pair items");
274    //
275    // assert_eq!(constraint.name(), "Pair items");
276    // ```
277}