solverforge_scoring/stream/factory.rs
1//! Constraint factory for creating typed constraint streams.
2//!
3//! The factory is the entry point for the fluent constraint API.
4
5use std::hash::Hash;
6use std::marker::PhantomData;
7
8use solverforge_core::score::Score;
9
10use super::bi_stream::BiConstraintStream;
11use super::filter::TrueFilter;
12use super::joiner::EqualJoiner;
13use super::UniConstraintStream;
14
/// Factory for creating constraint streams.
///
/// `ConstraintFactory` is parameterized by the solution type `S` and score type `Sc`.
/// It serves as the entry point for defining constraints using the fluent API.
///
/// The factory is a zero-sized marker: it stores no data, only the two type
/// parameters. Bounds on `S` and `Sc` are placed on the `impl` blocks that
/// actually require them rather than on the struct definition, so merely
/// naming the type never forces those bounds on downstream code.
///
/// # Example
///
/// ```
/// use solverforge_scoring::stream::ConstraintFactory;
/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
/// use solverforge_core::score::SimpleScore;
///
/// #[derive(Clone)]
/// struct Solution {
///     values: Vec<Option<i32>>,
/// }
///
/// let factory = ConstraintFactory::<Solution, SimpleScore>::new();
///
/// let constraint = factory
///     .for_each(|s: &Solution| &s.values)
///     .filter(|v: &Option<i32>| v.is_none())
///     .penalize(SimpleScore::of(1))
///     .as_constraint("Unassigned");
///
/// let solution = Solution { values: vec![Some(1), None, None] };
/// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-2));
/// ```
pub struct ConstraintFactory<S, Sc> {
    // Zero-sized marker tying the factory to its solution and score types.
    _phantom: PhantomData<(S, Sc)>,
}
46
47impl<S, Sc> ConstraintFactory<S, Sc>
48where
49 S: Send + Sync + 'static,
50 Sc: Score + 'static,
51{
52 /// Creates a new constraint factory.
53 pub fn new() -> Self {
54 Self {
55 _phantom: PhantomData,
56 }
57 }
58
59 /// Creates a zero-erasure uni-constraint stream over entities extracted from the solution.
60 ///
61 /// The extractor function receives a reference to the solution and returns
62 /// a slice of entities to iterate over. The extractor type is preserved
63 /// as a concrete generic for full zero-erasure.
64 pub fn for_each<A, E>(self, extractor: E) -> UniConstraintStream<S, A, E, TrueFilter, Sc>
65 where
66 A: Clone + Send + Sync + 'static,
67 E: Fn(&S) -> &[A] + Send + Sync,
68 {
69 UniConstraintStream::new(extractor)
70 }
71
72 /// Creates a zero-erasure bi-constraint stream over unique pairs of entities.
73 ///
74 /// This is equivalent to `for_each(extractor).join_self(joiner)` but provides
75 /// a more concise API for the common case of self-joins with key-based grouping.
76 ///
77 /// Pairs are ordered (i, j) where i < j to avoid duplicates and self-pairs.
78 ///
79 /// # Example
80 ///
81 /// ```
82 /// use solverforge_scoring::stream::{ConstraintFactory, joiner::equal};
83 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
84 /// use solverforge_core::score::SimpleScore;
85 ///
86 /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
87 /// struct Task { team: u32, priority: u32 }
88 ///
89 /// #[derive(Clone)]
90 /// struct Solution { tasks: Vec<Task> }
91 ///
92 /// let factory = ConstraintFactory::<Solution, SimpleScore>::new();
93 ///
94 /// // Penalize when two tasks on the same team conflict
95 /// let constraint = factory
96 /// .for_each_unique_pair(
97 /// |s: &Solution| s.tasks.as_slice(),
98 /// equal(|t: &Task| t.team)
99 /// )
100 /// .penalize(SimpleScore::of(1))
101 /// .as_constraint("Team conflict");
102 ///
103 /// let solution = Solution {
104 /// tasks: vec![
105 /// Task { team: 1, priority: 1 },
106 /// Task { team: 1, priority: 2 }, // Same team as first
107 /// Task { team: 2, priority: 1 },
108 /// ],
109 /// };
110 ///
111 /// // One pair on same team: (0, 1) = -1 penalty
112 /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
113 /// ```
114 pub fn for_each_unique_pair<A, E, K, KA>(
115 self,
116 extractor: E,
117 joiner: EqualJoiner<KA, KA, K>,
118 ) -> BiConstraintStream<S, A, K, E, KA, TrueFilter, Sc>
119 where
120 A: Clone + Hash + PartialEq + Send + Sync + 'static,
121 E: Fn(&S) -> &[A] + Send + Sync,
122 K: Eq + Hash + Clone + Send + Sync,
123 KA: Fn(&A) -> K + Send + Sync,
124 {
125 let (key_extractor, _) = joiner.into_keys();
126 BiConstraintStream::new_self_join(extractor, key_extractor)
127 }
128}
129
130impl<S, Sc> Default for ConstraintFactory<S, Sc>
131where
132 S: Send + Sync + 'static,
133 Sc: Score + 'static,
134{
135 fn default() -> Self {
136 Self::new()
137 }
138}
139
140impl<S, Sc: Score> Clone for ConstraintFactory<S, Sc> {
141 fn clone(&self) -> Self {
142 Self {
143 _phantom: PhantomData,
144 }
145 }
146}
147
148impl<S, Sc: Score> std::fmt::Debug for ConstraintFactory<S, Sc> {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("ConstraintFactory").finish()
151 }
152}