solverforge_scoring/stream/grouped_stream.rs
1//! Zero-erasure grouped constraint stream for group-by constraint patterns.
2//!
3//! A `GroupedConstraintStream` operates on groups of entities and supports
4//! filtering, weighting, and constraint finalization.
5//! All type information is preserved at compile time - no Arc, no dyn.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collector::UniCollector;
14use super::complemented_stream::ComplementedConstraintStream;
15use crate::constraint::grouped::GroupedUniConstraint;
16
17/// Zero-erasure constraint stream over grouped entities.
18///
19/// `GroupedConstraintStream` is created by `UniConstraintStream::group_by()`
20/// and operates on (key, collector_result) tuples.
21///
22/// All type parameters are concrete - no trait objects, no Arc allocations.
23///
24/// # Type Parameters
25///
26/// - `S` - Solution type
27/// - `A` - Entity type
28/// - `K` - Group key type
29/// - `E` - Extractor function for entities
30/// - `KF` - Key function
31/// - `C` - Collector type
32/// - `Sc` - Score type
33///
34/// # Example
35///
36/// ```
37/// use solverforge_scoring::stream::ConstraintFactory;
38/// use solverforge_scoring::stream::collector::count;
39/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
40/// use solverforge_core::score::SimpleScore;
41///
42/// #[derive(Clone, Hash, PartialEq, Eq)]
43/// struct Shift { employee_id: usize }
44///
45/// #[derive(Clone)]
46/// struct Solution { shifts: Vec<Shift> }
47///
48/// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
49/// .for_each(|s: &Solution| &s.shifts)
50/// .group_by(|shift: &Shift| shift.employee_id, count())
51/// .penalize_with(|count: &usize| SimpleScore::of((*count * *count) as i64))
52/// .as_constraint("Balanced workload");
53///
54/// let solution = Solution {
55/// shifts: vec![
56/// Shift { employee_id: 1 },
57/// Shift { employee_id: 1 },
58/// Shift { employee_id: 1 },
59/// Shift { employee_id: 2 },
60/// ],
61/// };
62///
63/// // Employee 1: 3² = 9, Employee 2: 1² = 1, Total: -10
64/// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-10));
65/// ```
66pub struct GroupedConstraintStream<S, A, K, E, KF, C, Sc>
67where
68 Sc: Score,
69{
70 extractor: E,
71 key_fn: KF,
72 collector: C,
73 _phantom: PhantomData<(S, A, K, Sc)>,
74}
75
76impl<S, A, K, E, KF, C, Sc> GroupedConstraintStream<S, A, K, E, KF, C, Sc>
77where
78 S: Send + Sync + 'static,
79 A: Clone + Send + Sync + 'static,
80 K: Clone + Eq + Hash + Send + Sync + 'static,
81 E: Fn(&S) -> &[A] + Send + Sync,
82 KF: Fn(&A) -> K + Send + Sync,
83 C: UniCollector<A> + Send + Sync + 'static,
84 C::Accumulator: Send + Sync,
85 C::Result: Clone + Send + Sync,
86 Sc: Score + 'static,
87{
88 /// Creates a new zero-erasure grouped constraint stream.
89 pub(crate) fn new(extractor: E, key_fn: KF, collector: C) -> Self {
90 Self {
91 extractor,
92 key_fn,
93 collector,
94 _phantom: PhantomData,
95 }
96 }
97
98 /// Penalizes each group with a weight based on the collector result.
99 ///
100 /// # Example
101 ///
102 /// ```
103 /// use solverforge_scoring::stream::ConstraintFactory;
104 /// use solverforge_scoring::stream::collector::count;
105 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
106 /// use solverforge_core::score::SimpleScore;
107 ///
108 /// #[derive(Clone, Hash, PartialEq, Eq)]
109 /// struct Task { priority: u32 }
110 ///
111 /// #[derive(Clone)]
112 /// struct Solution { tasks: Vec<Task> }
113 ///
114 /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
115 /// .for_each(|s: &Solution| &s.tasks)
116 /// .group_by(|t: &Task| t.priority, count())
117 /// .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
118 /// .as_constraint("Priority distribution");
119 ///
120 /// let solution = Solution {
121 /// tasks: vec![
122 /// Task { priority: 1 },
123 /// Task { priority: 1 },
124 /// Task { priority: 2 },
125 /// ],
126 /// };
127 ///
128 /// // Priority 1: 2 tasks, Priority 2: 1 task, Total: -3
129 /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-3));
130 /// ```
131 pub fn penalize_with<W>(
132 self,
133 weight_fn: W,
134 ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
135 where
136 W: Fn(&C::Result) -> Sc + Send + Sync,
137 {
138 GroupedConstraintBuilder {
139 extractor: self.extractor,
140 key_fn: self.key_fn,
141 collector: self.collector,
142 impact_type: ImpactType::Penalty,
143 weight_fn,
144 is_hard: false,
145 _phantom: PhantomData,
146 }
147 }
148
149 /// Penalizes each group with a weight, explicitly marked as hard constraint.
150 pub fn penalize_hard_with<W>(
151 self,
152 weight_fn: W,
153 ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
154 where
155 W: Fn(&C::Result) -> Sc + Send + Sync,
156 {
157 GroupedConstraintBuilder {
158 extractor: self.extractor,
159 key_fn: self.key_fn,
160 collector: self.collector,
161 impact_type: ImpactType::Penalty,
162 weight_fn,
163 is_hard: true,
164 _phantom: PhantomData,
165 }
166 }
167
168 /// Rewards each group with a weight based on the collector result.
169 pub fn reward_with<W>(self, weight_fn: W) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
170 where
171 W: Fn(&C::Result) -> Sc + Send + Sync,
172 {
173 GroupedConstraintBuilder {
174 extractor: self.extractor,
175 key_fn: self.key_fn,
176 collector: self.collector,
177 impact_type: ImpactType::Reward,
178 weight_fn,
179 is_hard: false,
180 _phantom: PhantomData,
181 }
182 }
183
184 /// Rewards each group with a weight, explicitly marked as hard constraint.
185 pub fn reward_hard_with<W>(
186 self,
187 weight_fn: W,
188 ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
189 where
190 W: Fn(&C::Result) -> Sc + Send + Sync,
191 {
192 GroupedConstraintBuilder {
193 extractor: self.extractor,
194 key_fn: self.key_fn,
195 collector: self.collector,
196 impact_type: ImpactType::Reward,
197 weight_fn,
198 is_hard: true,
199 _phantom: PhantomData,
200 }
201 }
202
203 /// Adds complement entities with default values for missing keys.
204 ///
205 /// This ensures all keys from the complement source are represented,
206 /// using the grouped value if present, or the default value otherwise.
207 ///
208 /// **Note:** The key function for A entities wraps the original key to
209 /// return `Some(K)`. For filtering (skipping entities without valid keys),
210 /// use `complement_filtered` instead.
211 ///
212 /// # Example
213 ///
214 /// ```
215 /// use solverforge_scoring::stream::ConstraintFactory;
216 /// use solverforge_scoring::stream::collector::count;
217 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
218 /// use solverforge_core::score::SimpleScore;
219 ///
220 /// #[derive(Clone, Hash, PartialEq, Eq)]
221 /// struct Employee { id: usize }
222 ///
223 /// #[derive(Clone, Hash, PartialEq, Eq)]
224 /// struct Shift { employee_id: usize }
225 ///
226 /// #[derive(Clone)]
227 /// struct Schedule {
228 /// employees: Vec<Employee>,
229 /// shifts: Vec<Shift>,
230 /// }
231 ///
232 /// // Count shifts per employee, including employees with 0 shifts
233 /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
234 /// .for_each(|s: &Schedule| &s.shifts)
235 /// .group_by(|shift: &Shift| shift.employee_id, count())
236 /// .complement(
237 /// |s: &Schedule| s.employees.as_slice(),
238 /// |emp: &Employee| emp.id,
239 /// |_emp: &Employee| 0usize,
240 /// )
241 /// .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
242 /// .as_constraint("Shift count");
243 ///
244 /// let schedule = Schedule {
245 /// employees: vec![Employee { id: 0 }, Employee { id: 1 }],
246 /// shifts: vec![
247 /// Shift { employee_id: 0 },
248 /// Shift { employee_id: 0 },
249 /// ],
250 /// };
251 ///
252 /// // Employee 0: 2, Employee 1: 0 → Total: -2
253 /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
254 /// ```
255 pub fn complement<B, EB, KB, D>(
256 self,
257 extractor_b: EB,
258 key_b: KB,
259 default_fn: D,
260 ) -> ComplementedConstraintStream<
261 S,
262 A,
263 B,
264 K,
265 E,
266 EB,
267 impl Fn(&A) -> Option<K> + Send + Sync,
268 KB,
269 C,
270 D,
271 Sc,
272 >
273 where
274 B: Clone + Send + Sync + 'static,
275 EB: Fn(&S) -> &[B] + Send + Sync,
276 KB: Fn(&B) -> K + Send + Sync,
277 D: Fn(&B) -> C::Result + Send + Sync,
278 {
279 let key_fn = self.key_fn;
280 let wrapped_key_fn = move |a: &A| Some((key_fn)(a));
281 ComplementedConstraintStream::new(
282 self.extractor,
283 extractor_b,
284 wrapped_key_fn,
285 key_b,
286 self.collector,
287 default_fn,
288 )
289 }
290
291 /// Adds complement entities with a custom key function for filtering.
292 ///
293 /// Like `complement`, but allows providing a custom key function for A entities
294 /// that returns `Option<K>`. Entities returning `None` are skipped.
295 ///
296 /// # Example
297 ///
298 /// ```
299 /// use solverforge_scoring::stream::ConstraintFactory;
300 /// use solverforge_scoring::stream::collector::count;
301 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
302 /// use solverforge_core::score::SimpleScore;
303 ///
304 /// #[derive(Clone, Hash, PartialEq, Eq)]
305 /// struct Employee { id: usize }
306 ///
307 /// #[derive(Clone, Hash, PartialEq, Eq)]
308 /// struct Shift { employee_id: Option<usize> }
309 ///
310 /// #[derive(Clone)]
311 /// struct Schedule {
312 /// employees: Vec<Employee>,
313 /// shifts: Vec<Shift>,
314 /// }
315 ///
316 /// // Count shifts per employee, skipping unassigned shifts
317 /// // The group_by key is ignored; complement_with_key provides its own
318 /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
319 /// .for_each(|s: &Schedule| &s.shifts)
320 /// .group_by(|_shift: &Shift| 0usize, count()) // Placeholder key, will be overridden
321 /// .complement_with_key(
322 /// |s: &Schedule| s.employees.as_slice(),
323 /// |shift: &Shift| shift.employee_id, // Option<usize>
324 /// |emp: &Employee| emp.id, // usize
325 /// |_emp: &Employee| 0usize,
326 /// )
327 /// .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
328 /// .as_constraint("Shift count");
329 ///
330 /// let schedule = Schedule {
331 /// employees: vec![Employee { id: 0 }, Employee { id: 1 }],
332 /// shifts: vec![
333 /// Shift { employee_id: Some(0) },
334 /// Shift { employee_id: Some(0) },
335 /// Shift { employee_id: None }, // Skipped
336 /// ],
337 /// };
338 ///
339 /// // Employee 0: 2, Employee 1: 0 → Total: -2
340 /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
341 /// ```
342 pub fn complement_with_key<B, EB, KA2, KB, D>(
343 self,
344 extractor_b: EB,
345 key_a: KA2,
346 key_b: KB,
347 default_fn: D,
348 ) -> ComplementedConstraintStream<S, A, B, K, E, EB, KA2, KB, C, D, Sc>
349 where
350 B: Clone + Send + Sync + 'static,
351 EB: Fn(&S) -> &[B] + Send + Sync,
352 KA2: Fn(&A) -> Option<K> + Send + Sync,
353 KB: Fn(&B) -> K + Send + Sync,
354 D: Fn(&B) -> C::Result + Send + Sync,
355 {
356 ComplementedConstraintStream::new(
357 self.extractor,
358 extractor_b,
359 key_a,
360 key_b,
361 self.collector,
362 default_fn,
363 )
364 }
365}
366
367impl<S, A, K, E, KF, C, Sc: Score> std::fmt::Debug
368 for GroupedConstraintStream<S, A, K, E, KF, C, Sc>
369{
370 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371 f.debug_struct("GroupedConstraintStream").finish()
372 }
373}
374
375/// Zero-erasure builder for finalizing a grouped constraint.
376pub struct GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
377where
378 Sc: Score,
379{
380 extractor: E,
381 key_fn: KF,
382 collector: C,
383 impact_type: ImpactType,
384 weight_fn: W,
385 is_hard: bool,
386 _phantom: PhantomData<(S, A, K, Sc)>,
387}
388
389impl<S, A, K, E, KF, C, W, Sc> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
390where
391 S: Send + Sync + 'static,
392 A: Clone + Send + Sync + 'static,
393 K: Clone + Eq + Hash + Send + Sync + 'static,
394 E: Fn(&S) -> &[A] + Send + Sync,
395 KF: Fn(&A) -> K + Send + Sync,
396 C: UniCollector<A> + Send + Sync + 'static,
397 C::Accumulator: Send + Sync,
398 C::Result: Clone + Send + Sync,
399 W: Fn(&C::Result) -> Sc + Send + Sync,
400 Sc: Score + 'static,
401{
402 /// Finalizes the builder into a zero-erasure `GroupedUniConstraint`.
403 ///
404 /// # Example
405 ///
406 /// ```
407 /// use solverforge_scoring::stream::ConstraintFactory;
408 /// use solverforge_scoring::stream::collector::count;
409 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
410 /// use solverforge_core::score::SimpleScore;
411 ///
412 /// #[derive(Clone, Hash, PartialEq, Eq)]
413 /// struct Item { category: u32 }
414 ///
415 /// #[derive(Clone)]
416 /// struct Solution { items: Vec<Item> }
417 ///
418 /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
419 /// .for_each(|s: &Solution| &s.items)
420 /// .group_by(|i: &Item| i.category, count())
421 /// .penalize_with(|n: &usize| SimpleScore::of(*n as i64))
422 /// .as_constraint("Category penalty");
423 ///
424 /// assert_eq!(constraint.name(), "Category penalty");
425 /// ```
426 pub fn as_constraint(self, name: &str) -> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc> {
427 GroupedUniConstraint::new(
428 ConstraintRef::new("", name),
429 self.impact_type,
430 self.extractor,
431 self.key_fn,
432 self.collector,
433 self.weight_fn,
434 self.is_hard,
435 )
436 }
437}
438
439impl<S, A, K, E, KF, C, W, Sc: Score> std::fmt::Debug
440 for GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
441{
442 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
443 f.debug_struct("GroupedConstraintBuilder")
444 .field("impact_type", &self.impact_type)
445 .finish()
446 }
447}