solverforge_scoring/stream/
factory.rs

//! Constraint factory for creating typed constraint streams.
//!
//! [`ConstraintFactory`] is the entry point for the fluent constraint API.
4
5use std::hash::Hash;
6use std::marker::PhantomData;
7
8use solverforge_core::score::Score;
9
10use super::bi_stream::BiConstraintStream;
11use super::filter::TrueFilter;
12use super::joiner::EqualJoiner;
13use super::UniConstraintStream;
14
/// Factory for creating constraint streams.
///
/// `ConstraintFactory` is parameterized by the solution type `S` and the score type `Sc`.
/// It is the entry point for defining constraints with the fluent API.
///
/// The factory stores no data (it is zero-sized). Its stream constructors take `self` by
/// value, so clone the factory when you need to start more than one stream.
///
/// # Example
///
/// ```
/// use solverforge_scoring::stream::ConstraintFactory;
/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
/// use solverforge_core::score::SimpleScore;
///
/// #[derive(Clone)]
/// struct Solution {
///     values: Vec<Option<i32>>,
/// }
///
/// let factory = ConstraintFactory::<Solution, SimpleScore>::new();
///
/// let constraint = factory
///     .for_each(|s: &Solution| s.values.as_slice())
///     .filter(|v: &Option<i32>| v.is_none())
///     .penalize(SimpleScore::of(1))
///     .as_constraint("Unassigned");
///
/// let solution = Solution { values: vec![Some(1), None, None] };
/// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-2));
/// ```
// Trait bounds (e.g. `Sc: Score`) live on the impl blocks that need them, not on the
// type definition, so naming the type never forces bounds on unrelated code.
pub struct ConstraintFactory<S, Sc> {
    // Marker only: binds `S` and `Sc` to the type without storing any values.
    _phantom: PhantomData<(S, Sc)>,
}
46
47impl<S, Sc> ConstraintFactory<S, Sc>
48where
49    S: Send + Sync + 'static,
50    Sc: Score + 'static,
51{
52    /// Creates a new constraint factory.
53    pub fn new() -> Self {
54        Self {
55            _phantom: PhantomData,
56        }
57    }
58
59    /// Creates a zero-erasure uni-constraint stream over entities extracted from the solution.
60    ///
61    /// The extractor function receives a reference to the solution and returns
62    /// a slice of entities to iterate over. The extractor type is preserved
63    /// as a concrete generic for full zero-erasure.
64    pub fn for_each<A, E>(self, extractor: E) -> UniConstraintStream<S, A, E, TrueFilter, Sc>
65    where
66        A: Clone + Send + Sync + 'static,
67        E: Fn(&S) -> &[A] + Send + Sync,
68    {
69        UniConstraintStream::new(extractor)
70    }
71
72    /// Creates a zero-erasure bi-constraint stream over unique pairs of entities.
73    ///
74    /// This is equivalent to `for_each(extractor).join_self(joiner)` but provides
75    /// a more concise API for the common case of self-joins with key-based grouping.
76    ///
77    /// Pairs are ordered (i, j) where i < j to avoid duplicates and self-pairs.
78    ///
79    /// # Example
80    ///
81    /// ```
82    /// use solverforge_scoring::stream::{ConstraintFactory, joiner::equal};
83    /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
84    /// use solverforge_core::score::SimpleScore;
85    ///
86    /// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
87    /// struct Task { team: u32, priority: u32 }
88    ///
89    /// #[derive(Clone)]
90    /// struct Solution { tasks: Vec<Task> }
91    ///
92    /// let factory = ConstraintFactory::<Solution, SimpleScore>::new();
93    ///
94    /// // Penalize when two tasks on the same team conflict
95    /// let constraint = factory
96    ///     .for_each_unique_pair(
97    ///         |s: &Solution| s.tasks.as_slice(),
98    ///         equal(|t: &Task| t.team)
99    ///     )
100    ///     .penalize(SimpleScore::of(1))
101    ///     .as_constraint("Team conflict");
102    ///
103    /// let solution = Solution {
104    ///     tasks: vec![
105    ///         Task { team: 1, priority: 1 },
106    ///         Task { team: 1, priority: 2 },  // Same team as first
107    ///         Task { team: 2, priority: 1 },
108    ///     ],
109    /// };
110    ///
111    /// // One pair on same team: (0, 1) = -1 penalty
112    /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1));
113    /// ```
114    pub fn for_each_unique_pair<A, E, K, KA>(
115        self,
116        extractor: E,
117        joiner: EqualJoiner<KA, KA, K>,
118    ) -> BiConstraintStream<S, A, K, E, KA, TrueFilter, Sc>
119    where
120        A: Clone + Hash + PartialEq + Send + Sync + 'static,
121        E: Fn(&S) -> &[A] + Send + Sync,
122        K: Eq + Hash + Clone + Send + Sync,
123        KA: Fn(&A) -> K + Send + Sync,
124    {
125        let (key_extractor, _) = joiner.into_keys();
126        BiConstraintStream::new_self_join(extractor, key_extractor)
127    }
128}
129
130impl<S, Sc> Default for ConstraintFactory<S, Sc>
131where
132    S: Send + Sync + 'static,
133    Sc: Score + 'static,
134{
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
140impl<S, Sc: Score> Clone for ConstraintFactory<S, Sc> {
141    fn clone(&self) -> Self {
142        Self {
143            _phantom: PhantomData,
144        }
145    }
146}
147
148impl<S, Sc: Score> std::fmt::Debug for ConstraintFactory<S, Sc> {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("ConstraintFactory").finish()
151    }
152}