// solverforge_scoring/stream/factory.rs

// Constraint factory for creating typed constraint streams.
//
// The factory is the entry point for the fluent constraint API.

5use std::hash::Hash;
6use std::marker::PhantomData;
7
8use solverforge_core::score::Score;
9
10use super::bi_stream::BiConstraintStream;
11use super::filter::TrueFilter;
12use super::joiner::EqualJoiner;
13use super::UniConstraintStream;
14
15// Factory for creating constraint streams.
16//
17// `ConstraintFactory` is parameterized by the solution type `S` and score type `Sc`.
18// It serves as the entry point for defining constraints using the fluent API.
19//
20// # Example
21//
22// ```
23// use solverforge_scoring::stream::ConstraintFactory;
24// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
25// use solverforge_core::score::SoftScore;
26//
27// #[derive(Clone)]
28// struct Solution {
29//     values: Vec<Option<i32>>,
30// }
31//
32// let factory = ConstraintFactory::<Solution, SoftScore>::new();
33//
34// let constraint = factory
35//     .for_each(|s: &Solution| &s.values)
36//     .filter(|v: &Option<i32>| v.is_none())
37//     .penalize(SoftScore::of(1))
38//     .as_constraint("Unassigned");
39//
40// let solution = Solution { values: vec![Some(1), None, None] };
41// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-2));
42// ```
43pub struct ConstraintFactory<S, Sc: Score> {
44    _phantom: PhantomData<(fn() -> S, fn() -> Sc)>,
45}
46
47impl<S, Sc> ConstraintFactory<S, Sc>
48where
49    S: Send + Sync + 'static,
50    Sc: Score + 'static,
51{
52    // Creates a new constraint factory.
53    pub fn new() -> Self {
54        Self {
55            _phantom: PhantomData,
56        }
57    }
58
59    // Creates a zero-erasure uni-constraint stream over entities extracted from the solution.
60    //
61    // The extractor function receives a reference to the solution and returns
62    // a slice of entities to iterate over. The extractor type is preserved
63    // as a concrete generic for full zero-erasure.
64    pub fn for_each<A, E>(self, extractor: E) -> UniConstraintStream<S, A, E, TrueFilter, Sc>
65    where
66        A: Clone + Send + Sync + 'static,
67        E: Fn(&S) -> &[A] + Send + Sync,
68    {
69        UniConstraintStream::new(extractor)
70    }
71
72    // Creates a zero-erasure bi-constraint stream over unique pairs of entities.
73    //
74    // This is equivalent to `for_each(extractor).join_self(joiner)` but provides
75    // a more concise API for the common case of self-joins with key-based grouping.
76    //
77    // Pairs are ordered (i, j) where i < j to avoid duplicates and self-pairs.
78    //
79    // # Example
80    //
81    // ```
82    // use solverforge_scoring::stream::{ConstraintFactory, joiner::equal};
83    // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
84    // use solverforge_core::score::SoftScore;
85    //
86    // #[derive(Clone, Debug, Hash, PartialEq, Eq)]
87    // struct Task { team: u32, priority: u32 }
88    //
89    // #[derive(Clone)]
90    // struct Solution { tasks: Vec<Task> }
91    //
92    // let factory = ConstraintFactory::<Solution, SoftScore>::new();
93    //
94    // // Penalize when two tasks on the same team conflict
95    // let constraint = factory
96    //     .for_each_unique_pair(
97    //         |s: &Solution| s.tasks.as_slice(),
98    //         equal(|t: &Task| t.team)
99    //     )
100    //     .penalize(SoftScore::of(1))
101    //     .as_constraint("Team conflict");
102    //
103    // let solution = Solution {
104    //     tasks: vec![
105    //         Task { team: 1, priority: 1 },
106    //         Task { team: 1, priority: 2 },  // Same team as first
107    //         Task { team: 2, priority: 1 },
108    //     ],
109    // };
110    //
111    // // One pair on same team: (0, 1) = -1 penalty
112    // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1));
113    // ```
114    pub fn for_each_unique_pair<A, E, K, KA>(
115        self,
116        extractor: E,
117        joiner: EqualJoiner<KA, KA, K>,
118    ) -> BiConstraintStream<S, A, K, E, impl Fn(&S, &A, usize) -> K + Send + Sync, TrueFilter, Sc>
119    where
120        A: Clone + Hash + PartialEq + Send + Sync + 'static,
121        E: Fn(&S) -> &[A] + Send + Sync,
122        K: Eq + Hash + Clone + Send + Sync,
123        KA: Fn(&A) -> K + Send + Sync,
124    {
125        let (key_extractor, _) = joiner.into_keys();
126        // Wrap to match the new KE: Fn(&S, &A, usize) -> K signature
127        let wrapped_ke = move |_s: &S, a: &A, _idx: usize| key_extractor(a);
128        BiConstraintStream::new_self_join(extractor, wrapped_ke)
129    }
130}
131
132impl<S, Sc> Default for ConstraintFactory<S, Sc>
133where
134    S: Send + Sync + 'static,
135    Sc: Score + 'static,
136{
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142impl<S, Sc: Score> Clone for ConstraintFactory<S, Sc> {
143    fn clone(&self) -> Self {
144        Self {
145            _phantom: PhantomData,
146        }
147    }
148}
149
150impl<S, Sc: Score> std::fmt::Debug for ConstraintFactory<S, Sc> {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        f.debug_struct("ConstraintFactory").finish()
153    }
154}