// solverforge_scoring/stream/factory.rs
// Constraint factory for creating typed constraint streams.
//
// The factory is the entry point for the fluent constraint API.
4
5use std::hash::Hash;
6use std::marker::PhantomData;
7
8use solverforge_core::score::Score;
9
10use super::bi_stream::BiConstraintStream;
11use super::filter::TrueFilter;
12use super::joiner::EqualJoiner;
13use super::UniConstraintStream;
14
15// Factory for creating constraint streams.
16//
17// `ConstraintFactory` is parameterized by the solution type `S` and score type `Sc`.
18// It serves as the entry point for defining constraints using the fluent API.
19//
20// # Example
21//
22// ```
23// use solverforge_scoring::stream::ConstraintFactory;
24// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
25// use solverforge_core::score::SoftScore;
26//
27// #[derive(Clone)]
28// struct Solution {
29// values: Vec<Option<i32>>,
30// }
31//
32// let factory = ConstraintFactory::<Solution, SoftScore>::new();
33//
34// let constraint = factory
35// .for_each(|s: &Solution| &s.values)
36// .filter(|v: &Option<i32>| v.is_none())
37// .penalize(SoftScore::of(1))
38// .as_constraint("Unassigned");
39//
40// let solution = Solution { values: vec![Some(1), None, None] };
41// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-2));
42// ```
43pub struct ConstraintFactory<S, Sc: Score> {
44 _phantom: PhantomData<(fn() -> S, fn() -> Sc)>,
45}
46
47impl<S, Sc> ConstraintFactory<S, Sc>
48where
49 S: Send + Sync + 'static,
50 Sc: Score + 'static,
51{
52 // Creates a new constraint factory.
53 pub fn new() -> Self {
54 Self {
55 _phantom: PhantomData,
56 }
57 }
58
59 // Creates a zero-erasure uni-constraint stream over entities extracted from the solution.
60 //
61 // The extractor function receives a reference to the solution and returns
62 // a slice of entities to iterate over. The extractor type is preserved
63 // as a concrete generic for full zero-erasure.
64 pub fn for_each<A, E>(self, extractor: E) -> UniConstraintStream<S, A, E, TrueFilter, Sc>
65 where
66 A: Clone + Send + Sync + 'static,
67 E: Fn(&S) -> &[A] + Send + Sync,
68 {
69 UniConstraintStream::new(extractor)
70 }
71
72 // Creates a zero-erasure bi-constraint stream over unique pairs of entities.
73 //
74 // This is equivalent to `for_each(extractor).join_self(joiner)` but provides
75 // a more concise API for the common case of self-joins with key-based grouping.
76 //
77 // Pairs are ordered (i, j) where i < j to avoid duplicates and self-pairs.
78 //
79 // # Example
80 //
81 // ```
82 // use solverforge_scoring::stream::{ConstraintFactory, joiner::equal};
83 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
84 // use solverforge_core::score::SoftScore;
85 //
86 // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
87 // struct Task { team: u32, priority: u32 }
88 //
89 // #[derive(Clone)]
90 // struct Solution { tasks: Vec<Task> }
91 //
92 // let factory = ConstraintFactory::<Solution, SoftScore>::new();
93 //
94 // // Penalize when two tasks on the same team conflict
95 // let constraint = factory
96 // .for_each_unique_pair(
97 // |s: &Solution| s.tasks.as_slice(),
98 // equal(|t: &Task| t.team)
99 // )
100 // .penalize(SoftScore::of(1))
101 // .as_constraint("Team conflict");
102 //
103 // let solution = Solution {
104 // tasks: vec![
105 // Task { team: 1, priority: 1 },
106 // Task { team: 1, priority: 2 }, // Same team as first
107 // Task { team: 2, priority: 1 },
108 // ],
109 // };
110 //
111 // // One pair on same team: (0, 1) = -1 penalty
112 // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
113 // ```
114 pub fn for_each_unique_pair<A, E, K, KA>(
115 self,
116 extractor: E,
117 joiner: EqualJoiner<KA, KA, K>,
118 ) -> BiConstraintStream<S, A, K, E, impl Fn(&S, &A, usize) -> K + Send + Sync, TrueFilter, Sc>
119 where
120 A: Clone + Hash + PartialEq + Send + Sync + 'static,
121 E: Fn(&S) -> &[A] + Send + Sync,
122 K: Eq + Hash + Clone + Send + Sync,
123 KA: Fn(&A) -> K + Send + Sync,
124 {
125 let (key_extractor, _) = joiner.into_keys();
126 // Wrap to match the new KE: Fn(&S, &A, usize) -> K signature
127 let wrapped_ke = move |_s: &S, a: &A, _idx: usize| key_extractor(a);
128 BiConstraintStream::new_self_join(extractor, wrapped_ke)
129 }
130}
131
132impl<S, Sc> Default for ConstraintFactory<S, Sc>
133where
134 S: Send + Sync + 'static,
135 Sc: Score + 'static,
136{
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142impl<S, Sc: Score> Clone for ConstraintFactory<S, Sc> {
143 fn clone(&self) -> Self {
144 Self {
145 _phantom: PhantomData,
146 }
147 }
148}
149
150impl<S, Sc: Score> std::fmt::Debug for ConstraintFactory<S, Sc> {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 f.debug_struct("ConstraintFactory").finish()
153 }
154}