solverforge_scoring/stream/
bi_stream.rs

//! Zero-erasure bi-constraint stream for self-join patterns.
//!
//! A `BiConstraintStream` operates on pairs of entities drawn from the same
//! collection (a self-join), for example when comparing shifts with each other.
//! All type information is preserved at compile time: no `Arc`, no `dyn`.
//!
7//! # Example
8//!
9//! ```
10//! use solverforge_scoring::stream::ConstraintFactory;
11//! use solverforge_scoring::stream::joiner::equal;
12//! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
13//! use solverforge_core::score::SimpleScore;
14//!
15//! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
16//! struct Task { team: u32 }
17//!
18//! #[derive(Clone)]
19//! struct Solution { tasks: Vec<Task> }
20//!
21//! // Penalize when two tasks are on the same team
22//! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
23//!     .for_each(|s: &Solution| s.tasks.as_slice())
24//!     .join_self(equal(|t: &Task| t.team))
25//!     .penalize(SimpleScore::of(1))
26//!     .as_constraint("Team conflict");
27//!
28//! let solution = Solution {
29//!     tasks: vec![
30//!         Task { team: 1 },
31//!         Task { team: 1 },
32//!         Task { team: 2 },
33//!     ],
34//! };
35//!
36//! // One pair on team 1: (0, 1) = -1 penalty
37//! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
38//! ```
39
40use std::hash::Hash;
41
42use solverforge_core::score::Score;
43
44use crate::constraint::IncrementalBiConstraint;
45
46use super::filter::{BiFilter, FnTriFilter, TriFilter};
47use super::joiner::Joiner;
48use super::tri_stream::TriConstraintStream;
49
50super::arity_stream_macros::impl_arity_stream!(
51    bi,
52    BiConstraintStream,
53    BiConstraintBuilder,
54    IncrementalBiConstraint
55);
56
57// join_self method - transitions to TriConstraintStream
58impl<S, A, K, E, KE, F, Sc> BiConstraintStream<S, A, K, E, KE, F, Sc>
59where
60    S: Send + Sync + 'static,
61    A: Clone + Hash + PartialEq + Send + Sync + 'static,
62    K: Eq + Hash + Clone + Send + Sync,
63    E: Fn(&S) -> &[A] + Send + Sync,
64    KE: Fn(&A) -> K + Send + Sync,
65    F: BiFilter<S, A, A>,
66    Sc: Score + 'static,
67{
68    /// Joins this stream with a third element to create triples.
69    ///
70    /// # Example
71    ///
72    /// ```
73    /// use solverforge_scoring::stream::ConstraintFactory;
74    /// use solverforge_scoring::stream::joiner::equal;
75    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
76    /// use solverforge_core::score::SimpleScore;
77    ///
78    /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
79    /// struct Task { team: u32 }
80    ///
81    /// #[derive(Clone)]
82    /// struct Solution { tasks: Vec<Task> }
83    ///
84    /// // Penalize when three tasks are on the same team
85    /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
86    ///     .for_each(|s: &Solution| s.tasks.as_slice())
87    ///     .join_self(equal(|t: &Task| t.team))
88    ///     .join_self(equal(|t: &Task| t.team))
89    ///     .penalize(SimpleScore::of(1))
90    ///     .as_constraint("Team clustering");
91    ///
92    /// let solution = Solution {
93    ///     tasks: vec![
94    ///         Task { team: 1 },
95    ///         Task { team: 1 },
96    ///         Task { team: 1 },
97    ///         Task { team: 2 },
98    ///     ],
99    /// };
100    ///
101    /// // One triple on team 1: (0, 1, 2) = -1 penalty
102    /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
103    /// ```
104    pub fn join_self<J>(
105        self,
106        joiner: J,
107    ) -> TriConstraintStream<S, A, K, E, KE, impl TriFilter<S, A, A, A>, Sc>
108    where
109        J: Joiner<A, A> + 'static,
110        F: 'static,
111    {
112        let filter = self.filter;
113        let combined_filter =
114            move |s: &S, a: &A, b: &A, c: &A| filter.test(s, a, b) && joiner.matches(a, c);
115
116        TriConstraintStream::new_self_join_with_filter(
117            self.extractor,
118            self.key_extractor,
119            FnTriFilter::new(combined_filter),
120        )
121    }
122}
123
// Additional doctests for individual stream methods (`filter`, `penalize`,
// `penalize_with`, `reward`, `as_constraint`).
// NOTE(review): none of these methods is defined in the visible `impl` of this
// file; presumably they are generated by the `impl_arity_stream!` invocation —
// confirm.
//
// `#[cfg(doctest)]` compiles this module only while rustdoc collects doctests,
// so the examples below run under `cargo test --doc` but the module is absent
// from normal builds.
#[cfg(doctest)]
mod doctests {
    //! # Filter method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Item { group: u32, value: i32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { items: Vec<Item> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.items.as_slice())
    //!     .join_self(equal(|i: &Item| i.group))
    //!     .filter(|a: &Item, b: &Item| a.value + b.value > 10)
    //!     .penalize(SimpleScore::of(1))
    //!     .as_constraint("High sum pairs");
    //!
    //! let solution = Solution {
    //!     items: vec![
    //!         Item { group: 1, value: 6 },
    //!         Item { group: 1, value: 7 },
    //!     ],
    //! };
    //!
    //! // 6 + 7 = 13 > 10, so the single pair matches: score -1
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
    //! ```
    //!
    //! # Penalize method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Task { priority: u32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { tasks: Vec<Task> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.tasks.as_slice())
    //!     .join_self(equal(|t: &Task| t.priority))
    //!     .penalize(SimpleScore::of(5))
    //!     .as_constraint("Pair priority conflict");
    //!
    //! let solution = Solution {
    //!     tasks: vec![
    //!         Task { priority: 1 },
    //!         Task { priority: 1 },
    //!     ],
    //! };
    //!
    //! // One matching pair with weight 5: score -5
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-5));
    //! ```
    //!
    //! # Penalize with dynamic weight
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Task { team: u32, cost: i64 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { tasks: Vec<Task> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.tasks.as_slice())
    //!     .join_self(equal(|t: &Task| t.team))
    //!     .penalize_with(|a: &Task, b: &Task| {
    //!         SimpleScore::of(a.cost + b.cost)
    //!     })
    //!     .as_constraint("Team cost");
    //!
    //! let solution = Solution {
    //!     tasks: vec![
    //!         Task { team: 1, cost: 2 },
    //!         Task { team: 1, cost: 3 },
    //!     ],
    //! };
    //!
    //! // Penalty weight 2 + 3 = 5 for the single pair: score -5
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-5));
    //! ```
    //!
    //! # Reward method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Person { team: u32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { people: Vec<Person> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.people.as_slice())
    //!     .join_self(equal(|p: &Person| p.team))
    //!     .reward(SimpleScore::of(10))
    //!     .as_constraint("Team synergy");
    //!
    //! let solution = Solution {
    //!     people: vec![
    //!         Person { team: 1 },
    //!         Person { team: 1 },
    //!     ],
    //! };
    //!
    //! // One matching pair with weight 10: score +10
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(10));
    //! ```
    //!
    //! # as_constraint method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Item { id: usize }
    //!
    //! #[derive(Clone)]
    //! struct Solution { items: Vec<Item> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.items.as_slice())
    //!     .join_self(equal(|i: &Item| i.id))
    //!     .penalize(SimpleScore::of(1))
    //!     .as_constraint("Pair items");
    //!
    //! assert_eq!(constraint.name(), "Pair items");
    //! ```
}