solverforge_scoring/stream/bi_stream.rs
1//! Zero-erasure bi-constraint stream for self-join patterns.
2//!
3//! A `BiConstraintStream` operates on pairs of entities from the same
4//! collection (self-join), such as comparing Shifts to each other.
5//! All type information is preserved at compile time - no Arc, no dyn.
6//!
7//! # Example
8//!
9//! ```
10//! use solverforge_scoring::stream::ConstraintFactory;
11//! use solverforge_scoring::stream::joiner::equal;
12//! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
13//! use solverforge_core::score::SimpleScore;
14//!
15//! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
16//! struct Task { team: u32 }
17//!
18//! #[derive(Clone)]
19//! struct Solution { tasks: Vec<Task> }
20//!
21//! // Penalize when two tasks are on the same team
22//! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
23//! .for_each(|s: &Solution| s.tasks.as_slice())
24//! .join_self(equal(|t: &Task| t.team))
25//! .penalize(SimpleScore::of(1))
26//! .as_constraint("Team conflict");
27//!
28//! let solution = Solution {
29//! tasks: vec![
30//! Task { team: 1 },
31//! Task { team: 1 },
32//! Task { team: 2 },
33//! ],
34//! };
35//!
36//! // One pair on team 1: (0, 1) = -1 penalty
37//! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
38//! ```
39
40use std::hash::Hash;
41
42use solverforge_core::score::Score;
43
44use crate::constraint::bi_incremental::IncrementalBiConstraint;
45
46use super::filter::{BiFilter, FnTriFilter};
47use super::joiner::Joiner;
48use super::tri_stream::TriConstraintStream;
49
50super::arity_stream_macros::impl_arity_stream!(
51 bi,
52 BiConstraintStream,
53 BiConstraintBuilder,
54 IncrementalBiConstraint
55);
56
57// join_self method - transitions to TriConstraintStream
58impl<S, A, K, E, KE, F, Sc> BiConstraintStream<S, A, K, E, KE, F, Sc>
59where
60 S: Send + Sync + 'static,
61 A: Clone + Hash + PartialEq + Send + Sync + 'static,
62 K: Eq + Hash + Clone + Send + Sync,
63 E: Fn(&S) -> &[A] + Send + Sync,
64 KE: Fn(&A) -> K + Send + Sync,
65 F: BiFilter<A, A>,
66 Sc: Score + 'static,
67{
68 /// Joins this stream with a third element to create triples.
69 ///
70 /// # Example
71 ///
72 /// ```
73 /// use solverforge_scoring::stream::ConstraintFactory;
74 /// use solverforge_scoring::stream::joiner::equal;
75 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
76 /// use solverforge_core::score::SimpleScore;
77 ///
78 /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
79 /// struct Task { team: u32 }
80 ///
81 /// #[derive(Clone)]
82 /// struct Solution { tasks: Vec<Task> }
83 ///
84 /// // Penalize when three tasks are on the same team
85 /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
86 /// .for_each(|s: &Solution| s.tasks.as_slice())
87 /// .join_self(equal(|t: &Task| t.team))
88 /// .join_self(equal(|t: &Task| t.team))
89 /// .penalize(SimpleScore::of(1))
90 /// .as_constraint("Team clustering");
91 ///
92 /// let solution = Solution {
93 /// tasks: vec![
94 /// Task { team: 1 },
95 /// Task { team: 1 },
96 /// Task { team: 1 },
97 /// Task { team: 2 },
98 /// ],
99 /// };
100 ///
101 /// // One triple on team 1: (0, 1, 2) = -1 penalty
102 /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
103 /// ```
104 pub fn join_self<J>(
105 self,
106 joiner: J,
107 ) -> TriConstraintStream<S, A, K, E, KE, impl super::filter::TriFilter<A, A, A>, Sc>
108 where
109 J: Joiner<A, A> + 'static,
110 F: 'static,
111 {
112 let filter = self.filter;
113 let combined_filter = move |a: &A, b: &A, c: &A| filter.test(a, b) && joiner.matches(a, c);
114
115 TriConstraintStream::new_self_join_with_filter(
116 self.extractor,
117 self.key_extractor,
118 FnTriFilter::new(combined_filter),
119 )
120 }
121}
122
// Additional doctests for individual methods.
//
// `#[cfg(doctest)]` means this module is only compiled while rustdoc is
// collecting doctests, so it never appears in a normal build or in the public
// API. Its only purpose is to carry the runnable examples below. Each one
// builds a bi-stream via `join_self` and checks the score (or name) of a tiny
// solution.
#[cfg(doctest)]
mod doctests {
    //! Runnable examples for the methods available on a `BiConstraintStream`
    //! produced by `join_self`: `filter`, `penalize`, `penalize_with`,
    //! `reward`, and `as_constraint`.
    //!
    //! # Filter method
    //!
    //! Keeps only pairs for which the predicate returns `true`.
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Item { group: u32, value: i32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { items: Vec<Item> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //! .for_each(|s: &Solution| s.items.as_slice())
    //! .join_self(equal(|i: &Item| i.group))
    //! .filter(|a: &Item, b: &Item| a.value + b.value > 10)
    //! .penalize(SimpleScore::of(1))
    //! .as_constraint("High sum pairs");
    //!
    //! let solution = Solution {
    //! items: vec![
    //! Item { group: 1, value: 6 },
    //! Item { group: 1, value: 7 },
    //! ],
    //! };
    //!
    //! // 6+7=13 > 10, matches
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
    //! ```
    //!
    //! # Penalize method
    //!
    //! Applies a fixed penalty per matching pair.
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Task { priority: u32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { tasks: Vec<Task> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //! .for_each(|s: &Solution| s.tasks.as_slice())
    //! .join_self(equal(|t: &Task| t.priority))
    //! .penalize(SimpleScore::of(5))
    //! .as_constraint("Pair priority conflict");
    //!
    //! let solution = Solution {
    //! tasks: vec![
    //! Task { priority: 1 },
    //! Task { priority: 1 },
    //! ],
    //! };
    //!
    //! // One pair = -5
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-5));
    //! ```
    //!
    //! # Penalize with dynamic weight
    //!
    //! The penalty is computed from both elements of each matching pair.
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Task { team: u32, cost: i64 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { tasks: Vec<Task> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //! .for_each(|s: &Solution| s.tasks.as_slice())
    //! .join_self(equal(|t: &Task| t.team))
    //! .penalize_with(|a: &Task, b: &Task| {
    //! SimpleScore::of(a.cost + b.cost)
    //! })
    //! .as_constraint("Team cost");
    //!
    //! let solution = Solution {
    //! tasks: vec![
    //! Task { team: 1, cost: 2 },
    //! Task { team: 1, cost: 3 },
    //! ],
    //! };
    //!
    //! // Penalty: 2+3 = -5
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-5));
    //! ```
    //!
    //! # Reward method
    //!
    //! Adds a fixed reward per matching pair (positive score contribution).
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Person { team: u32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { people: Vec<Person> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //! .for_each(|s: &Solution| s.people.as_slice())
    //! .join_self(equal(|p: &Person| p.team))
    //! .reward(SimpleScore::of(10))
    //! .as_constraint("Team synergy");
    //!
    //! let solution = Solution {
    //! people: vec![
    //! Person { team: 1 },
    //! Person { team: 1 },
    //! ],
    //! };
    //!
    //! // One pair = +10
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(10));
    //! ```
    //!
    //! # as_constraint method
    //!
    //! The name passed to `as_constraint` is reported by `name()`.
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Item { id: usize }
    //!
    //! #[derive(Clone)]
    //! struct Solution { items: Vec<Item> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //! .for_each(|s: &Solution| s.items.as_slice())
    //! .join_self(equal(|i: &Item| i.id))
    //! .penalize(SimpleScore::of(1))
    //! .as_constraint("Pair items");
    //!
    //! assert_eq!(constraint.name(), "Pair items");
    //! ```
}
276}