solverforge_scoring/stream/bi_stream.rs
1// Zero-erasure bi-constraint stream for self-join patterns.
2//
3// A `BiConstraintStream` operates on pairs of entities from the same
4// collection (self-join), such as comparing Shifts to each other.
5// All type information is preserved at compile time - no Arc, no dyn.
6//
7// # Example
8//
9// ```
10// use solverforge_scoring::stream::ConstraintFactory;
11// use solverforge_scoring::stream::joiner::equal;
12// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
13// use solverforge_core::score::SoftScore;
14//
15// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
16// struct Task { team: u32 }
17//
18// #[derive(Clone)]
19// struct Solution { tasks: Vec<Task> }
20//
21// // Penalize when two tasks are on the same team
22// let constraint = ConstraintFactory::<Solution, SoftScore>::new()
23// .for_each(|s: &Solution| s.tasks.as_slice())
24// .join_self(equal(|t: &Task| t.team))
25// .penalize(SoftScore::of(1))
26// .as_constraint("Team conflict");
27//
28// let solution = Solution {
29// tasks: vec![
30// Task { team: 1 },
31// Task { team: 1 },
32// Task { team: 2 },
33// ],
34// };
35//
36// // One pair on team 1: (0, 1) = -1 penalty
37// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
38// ```
39
40use std::hash::Hash;
41
42use solverforge_core::score::Score;
43
44use crate::constraint::IncrementalBiConstraint;
45
46use super::filter::{BiFilter, FnTriFilter, TriFilter};
47use super::joiner::Joiner;
48use super::tri_stream::TriConstraintStream;
49
// Invokes `impl_arity_stream!` from the sibling `arity_stream_macros` module
// for the two-entity case. The arguments are, in order: the arity tag (`bi`),
// the stream identifier, the builder identifier, and the incremental
// constraint type imported from `crate::constraint` above.
//
// NOTE(review): `BiConstraintStream` and `BiConstraintBuilder` are not
// imported anywhere in this file, yet `BiConstraintStream` is used by the
// `impl` block below, so both are presumably defined by this expansion.
// Confirm against the macro definition.
super::arity_stream_macros::impl_arity_stream!(
    bi,
    BiConstraintStream,
    BiConstraintBuilder,
    IncrementalBiConstraint
);
56
57// join_self method - transitions to TriConstraintStream
58impl<S, A, K, E, KE, F, Sc> BiConstraintStream<S, A, K, E, KE, F, Sc>
59where
60 S: Send + Sync + 'static,
61 A: Clone + Hash + PartialEq + Send + Sync + 'static,
62 K: Eq + Hash + Clone + Send + Sync,
63 E: Fn(&S) -> &[A] + Send + Sync,
64 KE: Fn(&S, &A, usize) -> K + Send + Sync,
65 F: BiFilter<S, A, A>,
66 Sc: Score + 'static,
67{
68 // Joins this stream with a third element to create triples.
69 //
70 // # Example
71 //
72 // ```
73 // use solverforge_scoring::stream::ConstraintFactory;
74 // use solverforge_scoring::stream::joiner::equal;
75 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
76 // use solverforge_core::score::SoftScore;
77 //
78 // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
79 // struct Task { team: u32 }
80 //
81 // #[derive(Clone)]
82 // struct Solution { tasks: Vec<Task> }
83 //
84 // // Penalize when three tasks are on the same team
85 // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
86 // .for_each(|s: &Solution| s.tasks.as_slice())
87 // .join_self(equal(|t: &Task| t.team))
88 // .join_self(equal(|t: &Task| t.team))
89 // .penalize(SoftScore::of(1))
90 // .as_constraint("Team clustering");
91 //
92 // let solution = Solution {
93 // tasks: vec![
94 // Task { team: 1 },
95 // Task { team: 1 },
96 // Task { team: 1 },
97 // Task { team: 2 },
98 // ],
99 // };
100 //
101 // // One triple on team 1: (0, 1, 2) = -1 penalty
102 // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
103 // ```
104 pub fn join_self<J>(
105 self,
106 joiner: J,
107 ) -> TriConstraintStream<S, A, K, E, KE, impl TriFilter<S, A, A, A>, Sc>
108 where
109 J: Joiner<A, A> + 'static,
110 F: 'static,
111 {
112 let filter = self.filter;
113 let combined_filter =
114 move |s: &S, a: &A, b: &A, c: &A| filter.test(s, a, b, 0, 0) && joiner.matches(a, c);
115
116 TriConstraintStream::new_self_join_with_filter(
117 self.extractor,
118 self.key_extractor,
119 FnTriFilter::new(combined_filter),
120 )
121 }
122}
123
124// Additional doctests for individual methods
125
// Usage examples for the pair-stream methods (`filter`, `penalize`,
// `penalize_with`, `reward`, `as_constraint`), kept in a module that is only
// compiled when rustdoc collects doctests.
//
// NOTE(review): every example below is a plain `//` comment, not an inner doc
// comment (`//!`), so rustdoc does not see them and none of them is compiled
// or executed. If they are meant to run as doctests they need to become `//!`
// lines; confirm they still pass against the current API before enabling.
#[cfg(doctest)]
mod doctests {
    // # Filter method
    //
    // ```
    // use solverforge_scoring::stream::ConstraintFactory;
    // use solverforge_scoring::stream::joiner::equal;
    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    // use solverforge_core::score::SoftScore;
    //
    // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    // struct Item { group: u32, value: i32 }
    //
    // #[derive(Clone)]
    // struct Solution { items: Vec<Item> }
    //
    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    //     .for_each(|s: &Solution| s.items.as_slice())
    //     .join_self(equal(|i: &Item| i.group))
    //     .filter(|a: &Item, b: &Item| a.value + b.value > 10)
    //     .penalize(SoftScore::of(1))
    //     .as_constraint("High sum pairs");
    //
    // let solution = Solution {
    //     items: vec![
    //         Item { group: 1, value: 6 },
    //         Item { group: 1, value: 7 },
    //     ],
    // };
    //
    // // 6+7=13 > 10, matches
    // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
    // ```
    //
    // # Penalize method
    //
    // ```
    // use solverforge_scoring::stream::ConstraintFactory;
    // use solverforge_scoring::stream::joiner::equal;
    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    // use solverforge_core::score::SoftScore;
    //
    // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    // struct Task { priority: u32 }
    //
    // #[derive(Clone)]
    // struct Solution { tasks: Vec<Task> }
    //
    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    //     .for_each(|s: &Solution| s.tasks.as_slice())
    //     .join_self(equal(|t: &Task| t.priority))
    //     .penalize(SoftScore::of(5))
    //     .as_constraint("Pair priority conflict");
    //
    // let solution = Solution {
    //     tasks: vec![
    //         Task { priority: 1 },
    //         Task { priority: 1 },
    //     ],
    // };
    //
    // // One pair = -5
    // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-5));
    // ```
    //
    // # Penalize with dynamic weight
    //
    // ```
    // use solverforge_scoring::stream::ConstraintFactory;
    // use solverforge_scoring::stream::joiner::equal;
    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    // use solverforge_core::score::SoftScore;
    //
    // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    // struct Task { team: u32, cost: i64 }
    //
    // #[derive(Clone)]
    // struct Solution { tasks: Vec<Task> }
    //
    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    //     .for_each(|s: &Solution| s.tasks.as_slice())
    //     .join_self(equal(|t: &Task| t.team))
    //     .penalize_with(|a: &Task, b: &Task| {
    //         SoftScore::of(a.cost + b.cost)
    //     })
    //     .as_constraint("Team cost");
    //
    // let solution = Solution {
    //     tasks: vec![
    //         Task { team: 1, cost: 2 },
    //         Task { team: 1, cost: 3 },
    //     ],
    // };
    //
    // // Penalty: 2+3 = -5
    // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-5));
    // ```
    //
    // # Reward method
    //
    // ```
    // use solverforge_scoring::stream::ConstraintFactory;
    // use solverforge_scoring::stream::joiner::equal;
    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    // use solverforge_core::score::SoftScore;
    //
    // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    // struct Person { team: u32 }
    //
    // #[derive(Clone)]
    // struct Solution { people: Vec<Person> }
    //
    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    //     .for_each(|s: &Solution| s.people.as_slice())
    //     .join_self(equal(|p: &Person| p.team))
    //     .reward(SoftScore::of(10))
    //     .as_constraint("Team synergy");
    //
    // let solution = Solution {
    //     people: vec![
    //         Person { team: 1 },
    //         Person { team: 1 },
    //     ],
    // };
    //
    // // One pair = +10
    // assert_eq!(constraint.evaluate(&solution), SoftScore::of(10));
    // ```
    //
    // # as_constraint method
    //
    // ```
    // use solverforge_scoring::stream::ConstraintFactory;
    // use solverforge_scoring::stream::joiner::equal;
    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    // use solverforge_core::score::SoftScore;
    //
    // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    // struct Item { id: usize }
    //
    // #[derive(Clone)]
    // struct Solution { items: Vec<Item> }
    //
    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    //     .for_each(|s: &Solution| s.items.as_slice())
    //     .join_self(equal(|i: &Item| i.id))
    //     .penalize(SoftScore::of(1))
    //     .as_constraint("Pair items");
    //
    // assert_eq!(constraint.name(), "Pair items");
    // ```
}