solverforge_scoring/stream/
bi_stream.rs

//! Zero-erasure bi-constraint stream for self-join patterns.
//!
//! A `BiConstraintStream` operates on pairs of entities drawn from the same
//! collection (a self-join), such as comparing shifts against each other.
//! All type information is preserved at compile time — no `Arc`, no `dyn`.
//!
//! # Example
//!
//! ```
//! use solverforge_scoring::stream::ConstraintFactory;
//! use solverforge_scoring::stream::joiner::equal;
//! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
//! use solverforge_core::score::SimpleScore;
//!
//! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
//! struct Task { team: u32 }
//!
//! #[derive(Clone)]
//! struct Solution { tasks: Vec<Task> }
//!
//! // Penalize when two tasks are on the same team
//! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
//!     .for_each(|s: &Solution| s.tasks.as_slice())
//!     .join_self(equal(|t: &Task| t.team))
//!     .penalize(SimpleScore::of(1))
//!     .as_constraint("Team conflict");
//!
//! let solution = Solution {
//!     tasks: vec![
//!         Task { team: 1 },
//!         Task { team: 1 },
//!         Task { team: 2 },
//!     ],
//! };
//!
//! // One pair on team 1: (0, 1) = -1 penalty
//! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
//! ```
39
40use std::hash::Hash;
41
42use solverforge_core::score::Score;
43
44use crate::constraint::bi_incremental::IncrementalBiConstraint;
45
46use super::filter::{BiFilter, FnTriFilter};
47use super::joiner::Joiner;
48use super::tri_stream::TriConstraintStream;
49
50super::arity_stream_macros::impl_arity_stream!(
51    bi,
52    BiConstraintStream,
53    BiConstraintBuilder,
54    IncrementalBiConstraint
55);
56
57// join_self method - transitions to TriConstraintStream
58impl<S, A, K, E, KE, F, Sc> BiConstraintStream<S, A, K, E, KE, F, Sc>
59where
60    S: Send + Sync + 'static,
61    A: Clone + Hash + PartialEq + Send + Sync + 'static,
62    K: Eq + Hash + Clone + Send + Sync,
63    E: Fn(&S) -> &[A] + Send + Sync,
64    KE: Fn(&A) -> K + Send + Sync,
65    F: BiFilter<A, A>,
66    Sc: Score + 'static,
67{
68    /// Joins this stream with a third element to create triples.
69    ///
70    /// # Example
71    ///
72    /// ```
73    /// use solverforge_scoring::stream::ConstraintFactory;
74    /// use solverforge_scoring::stream::joiner::equal;
75    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
76    /// use solverforge_core::score::SimpleScore;
77    ///
78    /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
79    /// struct Task { team: u32 }
80    ///
81    /// #[derive(Clone)]
82    /// struct Solution { tasks: Vec<Task> }
83    ///
84    /// // Penalize when three tasks are on the same team
85    /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
86    ///     .for_each(|s: &Solution| s.tasks.as_slice())
87    ///     .join_self(equal(|t: &Task| t.team))
88    ///     .join_self(equal(|t: &Task| t.team))
89    ///     .penalize(SimpleScore::of(1))
90    ///     .as_constraint("Team clustering");
91    ///
92    /// let solution = Solution {
93    ///     tasks: vec![
94    ///         Task { team: 1 },
95    ///         Task { team: 1 },
96    ///         Task { team: 1 },
97    ///         Task { team: 2 },
98    ///     ],
99    /// };
100    ///
101    /// // One triple on team 1: (0, 1, 2) = -1 penalty
102    /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
103    /// ```
104    pub fn join_self<J>(
105        self,
106        joiner: J,
107    ) -> TriConstraintStream<S, A, K, E, KE, impl super::filter::TriFilter<A, A, A>, Sc>
108    where
109        J: Joiner<A, A> + 'static,
110        F: 'static,
111    {
112        let filter = self.filter;
113        let combined_filter = move |a: &A, b: &A, c: &A| filter.test(a, b) && joiner.matches(a, c);
114
115        TriConstraintStream::new_self_join_with_filter(
116            self.extractor,
117            self.key_extractor,
118            FnTriFilter::new(combined_filter),
119        )
120    }
121}
122
// Additional doctests for the individual `BiConstraintStream` methods:
// `filter`, `penalize`, `penalize_with`, `reward` and `as_constraint`.
//
// The module is gated on `#[cfg(doctest)]`, so it is only compiled while
// rustdoc collects doctests. The examples below therefore run under
// `cargo test --doc` without adding any item to the regular build.

#[cfg(doctest)]
mod doctests {
    //! # Filter method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Item { group: u32, value: i32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { items: Vec<Item> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.items.as_slice())
    //!     .join_self(equal(|i: &Item| i.group))
    //!     .filter(|a: &Item, b: &Item| a.value + b.value > 10)
    //!     .penalize(SimpleScore::of(1))
    //!     .as_constraint("High sum pairs");
    //!
    //! let solution = Solution {
    //!     items: vec![
    //!         Item { group: 1, value: 6 },
    //!         Item { group: 1, value: 7 },
    //!     ],
    //! };
    //!
    //! // 6+7=13 > 10, matches
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
    //! ```
    //!
    //! # Penalize method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Task { priority: u32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { tasks: Vec<Task> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.tasks.as_slice())
    //!     .join_self(equal(|t: &Task| t.priority))
    //!     .penalize(SimpleScore::of(5))
    //!     .as_constraint("Pair priority conflict");
    //!
    //! let solution = Solution {
    //!     tasks: vec![
    //!         Task { priority: 1 },
    //!         Task { priority: 1 },
    //!     ],
    //! };
    //!
    //! // One pair = -5
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-5));
    //! ```
    //!
    //! # Penalize with dynamic weight
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Task { team: u32, cost: i64 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { tasks: Vec<Task> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.tasks.as_slice())
    //!     .join_self(equal(|t: &Task| t.team))
    //!     .penalize_with(|a: &Task, b: &Task| {
    //!         SimpleScore::of(a.cost + b.cost)
    //!     })
    //!     .as_constraint("Team cost");
    //!
    //! let solution = Solution {
    //!     tasks: vec![
    //!         Task { team: 1, cost: 2 },
    //!         Task { team: 1, cost: 3 },
    //!     ],
    //! };
    //!
    //! // Penalty: 2+3 = -5
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-5));
    //! ```
    //!
    //! # Reward method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Person { team: u32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { people: Vec<Person> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.people.as_slice())
    //!     .join_self(equal(|p: &Person| p.team))
    //!     .reward(SimpleScore::of(10))
    //!     .as_constraint("Team synergy");
    //!
    //! let solution = Solution {
    //!     people: vec![
    //!         Person { team: 1 },
    //!         Person { team: 1 },
    //!     ],
    //! };
    //!
    //! // One pair = +10
    //! assert_eq!(constraint.evaluate(&solution), SimpleScore::of(10));
    //! ```
    //!
    //! # as_constraint method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SimpleScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Item { id: usize }
    //!
    //! #[derive(Clone)]
    //! struct Solution { items: Vec<Item> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
    //!     .for_each(|s: &Solution| s.items.as_slice())
    //!     .join_self(equal(|i: &Item| i.id))
    //!     .penalize(SimpleScore::of(1))
    //!     .as_constraint("Pair items");
    //!
    //! assert_eq!(constraint.name(), "Pair items");
    //! ```
}