solverforge_scoring/stream/bi_stream.rs
1//! Zero-erasure bi-constraint stream for self-join patterns.
2//!
3//! A `BiConstraintStream` operates on pairs of entities from the same
4//! collection (self-join), such as comparing Shifts to each other.
5//! All type information is preserved at compile time - no Arc, no dyn.
6//!
7//! # Example
8//!
9//! ```
10//! use solverforge_scoring::stream::ConstraintFactory;
11//! use solverforge_scoring::stream::joiner::equal;
12//! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
13//! use solverforge_core::score::SimpleScore;
14//!
15//! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
16//! struct Task { team: u32 }
17//!
18//! #[derive(Clone)]
19//! struct Solution { tasks: Vec<Task> }
20//!
21//! // Penalize when two tasks are on the same team
22//! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
23//! .for_each(|s: &Solution| s.tasks.as_slice())
24//! .join_self(equal(|t: &Task| t.team))
25//! .penalize(SimpleScore::of(1))
26//! .as_constraint("Team conflict");
27//!
28//! let solution = Solution {
29//! tasks: vec![
30//! Task { team: 1 },
31//! Task { team: 1 },
32//! Task { team: 2 },
33//! ],
34//! };
35//!
36//! // One pair on team 1: (0, 1) = -1 penalty
37//! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
38//! ```
39
40use std::hash::Hash;
41
42use solverforge_core::score::Score;
43
44use crate::constraint::IncrementalBiConstraint;
45
46use super::filter::{BiFilter, FnTriFilter, TriFilter};
47use super::joiner::Joiner;
48use super::tri_stream::TriConstraintStream;
49
// Instantiate the shared arity-stream template for pairs (`bi`). The remaining
// arguments name the stream type, its builder type, and (presumably) the
// incremental constraint type that the builder produces.
//
// NOTE(review): judging from the rest of this file, the expansion presumably
// defines the `BiConstraintStream` struct, including the `extractor`,
// `key_extractor` and `filter` fields that `join_self` reads. The doctests
// also use `filter`/`penalize`/`penalize_with`/`reward` on the stream and
// `as_constraint` on the builder, which presumably come from this expansion.
// Confirm in `arity_stream_macros`.
super::arity_stream_macros::impl_arity_stream!(
    bi,
    BiConstraintStream,
    BiConstraintBuilder,
    IncrementalBiConstraint
);
56
57// join_self method - transitions to TriConstraintStream
58impl<S, A, K, E, KE, F, Sc> BiConstraintStream<S, A, K, E, KE, F, Sc>
59where
60 S: Send + Sync + 'static,
61 A: Clone + Hash + PartialEq + Send + Sync + 'static,
62 K: Eq + Hash + Clone + Send + Sync,
63 E: Fn(&S) -> &[A] + Send + Sync,
64 KE: Fn(&A) -> K + Send + Sync,
65 F: BiFilter<S, A, A>,
66 Sc: Score + 'static,
67{
68 /// Joins this stream with a third element to create triples.
69 ///
70 /// # Example
71 ///
72 /// ```
73 /// use solverforge_scoring::stream::ConstraintFactory;
74 /// use solverforge_scoring::stream::joiner::equal;
75 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
76 /// use solverforge_core::score::SimpleScore;
77 ///
78 /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
79 /// struct Task { team: u32 }
80 ///
81 /// #[derive(Clone)]
82 /// struct Solution { tasks: Vec<Task> }
83 ///
84 /// // Penalize when three tasks are on the same team
85 /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
86 /// .for_each(|s: &Solution| s.tasks.as_slice())
87 /// .join_self(equal(|t: &Task| t.team))
88 /// .join_self(equal(|t: &Task| t.team))
89 /// .penalize(SimpleScore::of(1))
90 /// .as_constraint("Team clustering");
91 ///
92 /// let solution = Solution {
93 /// tasks: vec![
94 /// Task { team: 1 },
95 /// Task { team: 1 },
96 /// Task { team: 1 },
97 /// Task { team: 2 },
98 /// ],
99 /// };
100 ///
101 /// // One triple on team 1: (0, 1, 2) = -1 penalty
102 /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
103 /// ```
104 pub fn join_self<J>(
105 self,
106 joiner: J,
107 ) -> TriConstraintStream<S, A, K, E, KE, impl TriFilter<S, A, A, A>, Sc>
108 where
109 J: Joiner<A, A> + 'static,
110 F: 'static,
111 {
112 let filter = self.filter;
113 let combined_filter =
114 move |s: &S, a: &A, b: &A, c: &A| filter.test(s, a, b) && joiner.matches(a, c);
115
116 TriConstraintStream::new_self_join_with_filter(
117 self.extractor,
118 self.key_extractor,
119 FnTriFilter::new(combined_filter),
120 )
121 }
122}
123
// Additional doctests for individual methods of the bi-stream API:
// `filter`, `penalize`, `penalize_with`, `reward`, and `as_constraint`.
//
// `#[cfg(doctest)]` means this module is only compiled while rustdoc collects
// doctests, so it contributes nothing to normal builds. The examples live in
// the module's inner (`//!`) docs so that rustdoc picks them up and runs them.
#[cfg(doctest)]
mod doctests {
    //! # Filter method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Item { group: u32, value: i32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { items: Vec<Item> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.items.as_slice())
    //!     .join_self(equal(|i: &Item| i.group))
    //!     .filter(|a: &Item, b: &Item| a.value + b.value > 10)
    //!     .penalize(SimpleScore::of(1))
    //!     .as_constraint("High sum pairs");
    //!
    //! let solution = Solution {
    //!     items: vec![
    //!         Item { group: 1, value: 6 },
    //!         Item { group: 1, value: 7 },
    //!     ],
    //! };
    //!
    //! // 6 + 7 = 13 > 10, so the single same-group pair passes the filter
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
    //! ```
    //!
    //! # Penalize method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Task { priority: u32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { tasks: Vec<Task> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.tasks.as_slice())
    //!     .join_self(equal(|t: &Task| t.priority))
    //!     .penalize(SimpleScore::of(5))
    //!     .as_constraint("Pair priority conflict");
    //!
    //! let solution = Solution {
    //!     tasks: vec![
    //!         Task { priority: 1 },
    //!         Task { priority: 1 },
    //!     ],
    //! };
    //!
    //! // One matching pair, weighted 5 -> -5
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-5));
    //! ```
    //!
    //! # Penalize with dynamic weight
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Task { team: u32, cost: i64 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { tasks: Vec<Task> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.tasks.as_slice())
    //!     .join_self(equal(|t: &Task| t.team))
    //!     .penalize_with(|a: &Task, b: &Task| {
    //!         SimpleScore::of(a.cost + b.cost)
    //!     })
    //!     .as_constraint("Team cost");
    //!
    //! let solution = Solution {
    //!     tasks: vec![
    //!         Task { team: 1, cost: 2 },
    //!         Task { team: 1, cost: 3 },
    //!     ],
    //! };
    //!
    //! // The weight is computed per pair: 2 + 3 = 5, penalized -> -5
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-5));
    //! ```
    //!
    //! # Reward method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Person { team: u32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { people: Vec<Person> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.people.as_slice())
    //!     .join_self(equal(|p: &Person| p.team))
    //!     .reward(SimpleScore::of(10))
    //!     .as_constraint("Team synergy");
    //!
    //! let solution = Solution {
    //!     people: vec![
    //!         Person { team: 1 },
    //!         Person { team: 1 },
    //!     ],
    //! };
    //!
    //! // One matching pair, rewarded 10 -> +10
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(10));
    //! ```
    //!
    //! # as_constraint method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Item { id: usize }
    //!
    //! #[derive(Clone)]
    //! struct Solution { items: Vec<Item> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.items.as_slice())
    //!     .join_self(equal(|i: &Item| i.id))
    //!     .penalize(SimpleScore::of(1))
    //!     .as_constraint("Pair items");
    //!
    //! // The name passed to `as_constraint` is exposed via `name()`
    //! assert_eq!(constraint.name(), "Pair items");
    //! ```
}
277}