Skip to main content

solverforge_scoring/stream/
tri_stream.rs

// Zero-erasure tri-constraint stream for three-entity constraint patterns.
//
// A `TriConstraintStream` operates on triples of entities and supports
// filtering, weighting, joining with a fourth entity, and constraint
// finalization. All type information is preserved at compile time — no
// `Arc`, no `dyn`.
//
// # Example
//
// ```
// use solverforge_scoring::stream::ConstraintFactory;
// use solverforge_scoring::stream::joiner::equal;
// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
// use solverforge_core::score::SoftScore;
//
// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
// struct Task { team: u32 }
//
// #[derive(Clone)]
// struct Solution { tasks: Vec<Task> }
//
// // Penalize when three tasks are on the same team
// let constraint = ConstraintFactory::<Solution, SoftScore>::new()
//     .for_each(|s: &Solution| s.tasks.as_slice())
//     .join(equal(|t: &Task| t.team))
//     .join(equal(|t: &Task| t.team))
//     .penalize(SoftScore::of(1))
//     .named("Team clustering");
//
// let solution = Solution {
//     tasks: vec![
//         Task { team: 1 },
//         Task { team: 1 },
//         Task { team: 1 },
//         Task { team: 2 },
//     ],
// };
//
// // One triple on team 1: (0, 1, 2) = -1 penalty
// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
// ```
41
42use std::hash::Hash;
43
44use solverforge_core::score::Score;
45
46use crate::constraint::IncrementalTriConstraint;
47
48use super::filter::{FnQuadFilter, QuadFilter, TriFilter};
49use super::joiner::Joiner;
50use super::quad_stream::QuadConstraintStream;
51
52super::arity_stream_macros::impl_arity_stream!(
53    tri,
54    TriConstraintStream,
55    TriConstraintBuilder,
56    IncrementalTriConstraint
57);
58
59// join method - transitions to QuadConstraintStream
60impl<S, A, K, E, KE, F, Sc> TriConstraintStream<S, A, K, E, KE, F, Sc>
61where
62    S: Send + Sync + 'static,
63    A: Clone + Hash + PartialEq + Send + Sync + 'static,
64    K: Eq + Hash + Clone + Send + Sync,
65    E: Fn(&S) -> &[A] + Send + Sync,
66    KE: Fn(&S, &A, usize) -> K + Send + Sync,
67    F: TriFilter<S, A, A, A>,
68    Sc: Score + 'static,
69{
70    // Joins this stream with a fourth element to create quadruples.
71    //
72    // # Example
73    //
74    // ```
75    // use solverforge_scoring::stream::ConstraintFactory;
76    // use solverforge_scoring::stream::joiner::equal;
77    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
78    // use solverforge_core::score::SoftScore;
79    //
80    // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
81    // struct Task { team: u32 }
82    //
83    // #[derive(Clone)]
84    // struct Solution { tasks: Vec<Task> }
85    //
86    // // Penalize when four tasks are on the same team
87    // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
88    //     .for_each(|s: &Solution| s.tasks.as_slice())
89    //     .join(equal(|t: &Task| t.team))
90    //     .join(equal(|t: &Task| t.team))
91    //     .join(equal(|t: &Task| t.team))
92    //     .penalize(SoftScore::of(1))
93    //     .named("Team clustering");
94    //
95    // let solution = Solution {
96    //     tasks: vec![
97    //         Task { team: 1 },
98    //         Task { team: 1 },
99    //         Task { team: 1 },
100    //         Task { team: 1 },
101    //         Task { team: 2 },
102    //     ],
103    // };
104    //
105    // // One quadruple on team 1: (0, 1, 2, 3) = -1 penalty
106    // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
107    // ```
108    pub fn join<J>(
109        self,
110        joiner: J,
111    ) -> QuadConstraintStream<S, A, K, E, KE, impl QuadFilter<S, A, A, A, A>, Sc>
112    where
113        J: Joiner<A, A> + 'static,
114        F: 'static,
115    {
116        let filter = self.filter;
117        let combined_filter = move |s: &S, a: &A, b: &A, c: &A, d: &A| {
118            filter.test(s, a, b, c) && joiner.matches(a, d)
119        };
120
121        QuadConstraintStream::new_self_join_with_filter(
122            self.extractor,
123            self.key_extractor,
124            FnQuadFilter::new(combined_filter),
125        )
126    }
127}
128
// Additional doctests for individual methods.
//
// `#[cfg(doctest)]` keeps this module out of normal builds; it only exists
// while rustdoc collects doctests. Rustdoc extracts examples from doc
// comments only, so the examples are written as inner doc comments (`//!`).
// Plain `//` comments here would never be compiled or run.
#[cfg(doctest)]
mod doctests {
    //! # Filter method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SoftScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Item { group: u32, value: i32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { items: Vec<Item> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    //!     .for_each(|s: &Solution| s.items.as_slice())
    //!     .join(equal(|i: &Item| i.group))
    //!     .join(equal(|i: &Item| i.group))
    //!     .filter(|a: &Item, b: &Item, c: &Item| a.value + b.value + c.value > 10)
    //!     .penalize(SoftScore::of(1))
    //!     .named("High sum triples");
    //!
    //! let solution = Solution {
    //!     items: vec![
    //!         Item { group: 1, value: 3 },
    //!         Item { group: 1, value: 4 },
    //!         Item { group: 1, value: 5 },
    //!     ],
    //! };
    //!
    //! // 3+4+5=12 > 10, matches
    //! assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
    //! ```
    //!
    //! # Penalize method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SoftScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Task { priority: u32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { tasks: Vec<Task> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    //!     .for_each(|s: &Solution| s.tasks.as_slice())
    //!     .join(equal(|t: &Task| t.priority))
    //!     .join(equal(|t: &Task| t.priority))
    //!     .penalize(SoftScore::of(5))
    //!     .named("Triple priority conflict");
    //!
    //! let solution = Solution {
    //!     tasks: vec![
    //!         Task { priority: 1 },
    //!         Task { priority: 1 },
    //!         Task { priority: 1 },
    //!     ],
    //! };
    //!
    //! // One triple = -5
    //! assert_eq!(constraint.evaluate(&solution), SoftScore::of(-5));
    //! ```
    //!
    //! # Penalize with dynamic weight
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SoftScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Task { team: u32, cost: i64 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { tasks: Vec<Task> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    //!     .for_each(|s: &Solution| s.tasks.as_slice())
    //!     .join(equal(|t: &Task| t.team))
    //!     .join(equal(|t: &Task| t.team))
    //!     .penalize_with(|a: &Task, b: &Task, c: &Task| {
    //!         SoftScore::of(a.cost + b.cost + c.cost)
    //!     })
    //!     .named("Team cost");
    //!
    //! let solution = Solution {
    //!     tasks: vec![
    //!         Task { team: 1, cost: 2 },
    //!         Task { team: 1, cost: 3 },
    //!         Task { team: 1, cost: 5 },
    //!     ],
    //! };
    //!
    //! // Penalty: 2+3+5 = -10
    //! assert_eq!(constraint.evaluate(&solution), SoftScore::of(-10));
    //! ```
    //!
    //! # Reward method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SoftScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Person { team: u32 }
    //!
    //! #[derive(Clone)]
    //! struct Solution { people: Vec<Person> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    //!     .for_each(|s: &Solution| s.people.as_slice())
    //!     .join(equal(|p: &Person| p.team))
    //!     .join(equal(|p: &Person| p.team))
    //!     .reward(SoftScore::of(10))
    //!     .named("Team synergy");
    //!
    //! let solution = Solution {
    //!     people: vec![
    //!         Person { team: 1 },
    //!         Person { team: 1 },
    //!         Person { team: 1 },
    //!     ],
    //! };
    //!
    //! // One triple = +10
    //! assert_eq!(constraint.evaluate(&solution), SoftScore::of(10));
    //! ```
    //!
    //! # named method
    //!
    //! ```
    //! use solverforge_scoring::stream::ConstraintFactory;
    //! use solverforge_scoring::stream::joiner::equal;
    //! use solverforge_scoring::api::constraint_set::IncrementalConstraint;
    //! use solverforge_core::score::SoftScore;
    //!
    //! #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    //! struct Item { id: usize }
    //!
    //! #[derive(Clone)]
    //! struct Solution { items: Vec<Item> }
    //!
    //! let constraint = ConstraintFactory::<Solution, SoftScore>::new()
    //!     .for_each(|s: &Solution| s.items.as_slice())
    //!     .join(equal(|i: &Item| i.id))
    //!     .join(equal(|i: &Item| i.id))
    //!     .penalize(SoftScore::of(1))
    //!     .named("Triple items");
    //!
    //! assert_eq!(constraint.name(), "Triple items");
    //! ```
}