solverforge_scoring/stream/grouped_stream.rs
1/* Zero-erasure grouped constraint stream for group-by constraint patterns.
2
3A `GroupedConstraintStream` operates on groups of entities and supports
4filtering, weighting, and constraint finalization.
5All type information is preserved at compile time - no Arc, no dyn.
6*/
7
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use super::collection_extract::CollectionExtract;
15use super::collector::UniCollector;
16use super::complemented_stream::ComplementedConstraintStream;
17use super::filter::UniFilter;
18use crate::constraint::grouped::GroupedUniConstraint;
19
20/* Zero-erasure constraint stream over grouped entities.
21
22`GroupedConstraintStream` is created by `UniConstraintStream::group_by()`
23and operates on (key, collector_result) tuples.
24
25All type parameters are concrete - no trait objects, no Arc allocations.
26
27# Type Parameters
28
29- `S` - Solution type
30- `A` - Entity type
31- `K` - Group key type
32- `E` - Extractor function for entities
33- `Fi` - Filter type (preserved from upstream stream)
34- `KF` - Key function
35- `C` - Collector type
36- `Sc` - Score type
37
38# Example
39
40```
41use solverforge_scoring::stream::ConstraintFactory;
42use solverforge_scoring::stream::collector::count;
43use solverforge_scoring::api::constraint_set::IncrementalConstraint;
44use solverforge_core::score::SoftScore;
45
46#[derive(Clone, Hash, PartialEq, Eq)]
47struct Shift { employee_id: usize }
48
49#[derive(Clone)]
50struct Solution { shifts: Vec<Shift> }
51
52let constraint = ConstraintFactory::<Solution, SoftScore>::new()
53.for_each(|s: &Solution| &s.shifts)
54.group_by(|shift: &Shift| shift.employee_id, count())
55.penalize_with(|count: &usize| SoftScore::of((*count * *count) as i64))
56.named("Balanced workload");
57
58let solution = Solution {
59shifts: vec![
60Shift { employee_id: 1 },
61Shift { employee_id: 1 },
62Shift { employee_id: 1 },
63Shift { employee_id: 2 },
64],
65};
66
67// Employee 1: 3² = 9, Employee 2: 1² = 1, Total: -10
68assert_eq!(constraint.evaluate(&solution), SoftScore::of(-10));
69```
70*/
71pub struct GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
72where
73 Sc: Score,
74{
75 extractor: E,
76 filter: Fi,
77 key_fn: KF,
78 collector: C,
79 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
80}
81
82impl<S, A, K, E, Fi, KF, C, Sc> GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
83where
84 S: Send + Sync + 'static,
85 A: Clone + Send + Sync + 'static,
86 K: Clone + Eq + Hash + Send + Sync + 'static,
87 E: CollectionExtract<S, Item = A>,
88 Fi: UniFilter<S, A>,
89 KF: Fn(&A) -> K + Send + Sync,
90 C: UniCollector<A> + Send + Sync + 'static,
91 C::Accumulator: Send + Sync,
92 C::Result: Clone + Send + Sync,
93 Sc: Score + 'static,
94{
95 // Creates a new zero-erasure grouped constraint stream.
96 pub(crate) fn new(extractor: E, filter: Fi, key_fn: KF, collector: C) -> Self {
97 Self {
98 extractor,
99 filter,
100 key_fn,
101 collector,
102 _phantom: PhantomData,
103 }
104 }
105
106 /* Penalizes each group with a weight based on the collector result.
107
108 # Example
109
110 ```
111 use solverforge_scoring::stream::ConstraintFactory;
112 use solverforge_scoring::stream::collector::count;
113 use solverforge_scoring::api::constraint_set::IncrementalConstraint;
114 use solverforge_core::score::SoftScore;
115
116 #[derive(Clone, Hash, PartialEq, Eq)]
117 struct Task { priority: u32 }
118
119 #[derive(Clone)]
120 struct Solution { tasks: Vec<Task> }
121
122 let constraint = ConstraintFactory::<Solution, SoftScore>::new()
123 .for_each(|s: &Solution| &s.tasks)
124 .group_by(|t: &Task| t.priority, count())
125 .penalize_with(|count: &usize| SoftScore::of(*count as i64))
126 .named("Priority distribution");
127
128 let solution = Solution {
129 tasks: vec![
130 Task { priority: 1 },
131 Task { priority: 1 },
132 Task { priority: 2 },
133 ],
134 };
135
136 // Priority 1: 2 tasks, Priority 2: 1 task, Total: -3
137 assert_eq!(constraint.evaluate(&solution), SoftScore::of(-3));
138 ```
139 */
140 pub fn penalize_with<W>(
141 self,
142 weight_fn: W,
143 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
144 where
145 W: Fn(&C::Result) -> Sc + Send + Sync,
146 {
147 GroupedConstraintBuilder {
148 extractor: self.extractor,
149 filter: self.filter,
150 key_fn: self.key_fn,
151 collector: self.collector,
152 impact_type: ImpactType::Penalty,
153 weight_fn,
154 is_hard: false,
155 expected_descriptor: None,
156 _phantom: PhantomData,
157 }
158 }
159
160 // Penalizes each group with a weight, explicitly marked as hard constraint.
161 pub fn penalize_hard_with<W>(
162 self,
163 weight_fn: W,
164 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
165 where
166 W: Fn(&C::Result) -> Sc + Send + Sync,
167 {
168 GroupedConstraintBuilder {
169 extractor: self.extractor,
170 filter: self.filter,
171 key_fn: self.key_fn,
172 collector: self.collector,
173 impact_type: ImpactType::Penalty,
174 weight_fn,
175 is_hard: true,
176 expected_descriptor: None,
177 _phantom: PhantomData,
178 }
179 }
180
181 // Rewards each group with a weight based on the collector result.
182 pub fn reward_with<W>(
183 self,
184 weight_fn: W,
185 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
186 where
187 W: Fn(&C::Result) -> Sc + Send + Sync,
188 {
189 GroupedConstraintBuilder {
190 extractor: self.extractor,
191 filter: self.filter,
192 key_fn: self.key_fn,
193 collector: self.collector,
194 impact_type: ImpactType::Reward,
195 weight_fn,
196 is_hard: false,
197 expected_descriptor: None,
198 _phantom: PhantomData,
199 }
200 }
201
202 // Rewards each group with a weight, explicitly marked as hard constraint.
203 pub fn reward_hard_with<W>(
204 self,
205 weight_fn: W,
206 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
207 where
208 W: Fn(&C::Result) -> Sc + Send + Sync,
209 {
210 GroupedConstraintBuilder {
211 extractor: self.extractor,
212 filter: self.filter,
213 key_fn: self.key_fn,
214 collector: self.collector,
215 impact_type: ImpactType::Reward,
216 weight_fn,
217 is_hard: true,
218 expected_descriptor: None,
219 _phantom: PhantomData,
220 }
221 }
222
223 // Penalizes each group with one hard score unit.
224 pub fn penalize_hard(
225 self,
226 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
227 where
228 Sc: Copy,
229 {
230 let w = Sc::one_hard();
231 self.penalize_hard_with(move |_: &C::Result| w)
232 }
233
234 // Penalizes each group with one soft score unit.
235 pub fn penalize_soft(
236 self,
237 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
238 where
239 Sc: Copy,
240 {
241 let w = Sc::one_soft();
242 self.penalize_with(move |_: &C::Result| w)
243 }
244
245 // Rewards each group with one hard score unit.
246 pub fn reward_hard(
247 self,
248 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
249 where
250 Sc: Copy,
251 {
252 let w = Sc::one_hard();
253 self.reward_hard_with(move |_: &C::Result| w)
254 }
255
256 // Rewards each group with one soft score unit.
257 pub fn reward_soft(
258 self,
259 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
260 where
261 Sc: Copy,
262 {
263 let w = Sc::one_soft();
264 self.reward_with(move |_: &C::Result| w)
265 }
266
267 /* Adds complement entities with default values for missing keys.
268
269 This ensures all keys from the complement source are represented,
270 using the grouped value if present, or the default value otherwise.
271
    **Note:** The key function for A entities wraps the original key to
    return `Some(K)`. For filtering (skipping entities without valid keys),
    use `complement_with_key` instead.
275
276 # Example
277
278 ```
279 use solverforge_scoring::stream::ConstraintFactory;
280 use solverforge_scoring::stream::collector::count;
281 use solverforge_scoring::api::constraint_set::IncrementalConstraint;
282 use solverforge_core::score::SoftScore;
283
284 #[derive(Clone, Hash, PartialEq, Eq)]
285 struct Employee { id: usize }
286
287 #[derive(Clone, Hash, PartialEq, Eq)]
288 struct Shift { employee_id: usize }
289
290 #[derive(Clone)]
291 struct Schedule {
292 employees: Vec<Employee>,
293 shifts: Vec<Shift>,
294 }
295
296 // Count shifts per employee, including employees with 0 shifts
297 let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
298 .for_each(|s: &Schedule| &s.shifts)
299 .group_by(|shift: &Shift| shift.employee_id, count())
300 .complement(
301 |s: &Schedule| s.employees.as_slice(),
302 |emp: &Employee| emp.id,
303 |_emp: &Employee| 0usize,
304 )
305 .penalize_with(|count: &usize| SoftScore::of(*count as i64))
306 .named("Shift count");
307
308 let schedule = Schedule {
309 employees: vec![Employee { id: 0 }, Employee { id: 1 }],
310 shifts: vec![
311 Shift { employee_id: 0 },
312 Shift { employee_id: 0 },
313 ],
314 };
315
316 // Employee 0: 2, Employee 1: 0 → Total: -2
317 assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
318 ```
319 */
320 pub fn complement<B, EB, KB, D>(
321 self,
322 extractor_b: EB,
323 key_b: KB,
324 default_fn: D,
325 ) -> ComplementedConstraintStream<
326 S,
327 A,
328 B,
329 K,
330 E,
331 EB,
332 impl Fn(&A) -> Option<K> + Send + Sync,
333 KB,
334 C,
335 D,
336 Sc,
337 >
338 where
339 B: Clone + Send + Sync + 'static,
340 EB: CollectionExtract<S, Item = B>,
341 KB: Fn(&B) -> K + Send + Sync,
342 D: Fn(&B) -> C::Result + Send + Sync,
343 {
344 let key_fn = self.key_fn;
345 let wrapped_key_fn = move |a: &A| Some((key_fn)(a));
346 ComplementedConstraintStream::new(
347 self.extractor,
348 extractor_b,
349 wrapped_key_fn,
350 key_b,
351 self.collector,
352 default_fn,
353 )
354 }
355
356 /* Adds complement entities with a custom key function for filtering.
357
358 Like `complement`, but allows providing a custom key function for A entities
359 that returns `Option<K>`. Entities returning `None` are skipped.
360
361 # Example
362
363 ```
364 use solverforge_scoring::stream::ConstraintFactory;
365 use solverforge_scoring::stream::collector::count;
366 use solverforge_scoring::api::constraint_set::IncrementalConstraint;
367 use solverforge_core::score::SoftScore;
368
369 #[derive(Clone, Hash, PartialEq, Eq)]
370 struct Employee { id: usize }
371
372 #[derive(Clone, Hash, PartialEq, Eq)]
373 struct Shift { employee_id: Option<usize> }
374
375 #[derive(Clone)]
376 struct Schedule {
377 employees: Vec<Employee>,
378 shifts: Vec<Shift>,
379 }
380
381 // Count shifts per employee, skipping unassigned shifts
382 // The group_by key is ignored; complement_with_key provides its own
383 let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
384 .for_each(|s: &Schedule| &s.shifts)
385 .group_by(|_shift: &Shift| 0usize, count()) // Placeholder key, will be overridden
386 .complement_with_key(
387 |s: &Schedule| s.employees.as_slice(),
388 |shift: &Shift| shift.employee_id, // Option<usize>
389 |emp: &Employee| emp.id, // usize
390 |_emp: &Employee| 0usize,
391 )
392 .penalize_with(|count: &usize| SoftScore::of(*count as i64))
393 .named("Shift count");
394
395 let schedule = Schedule {
396 employees: vec![Employee { id: 0 }, Employee { id: 1 }],
397 shifts: vec![
398 Shift { employee_id: Some(0) },
399 Shift { employee_id: Some(0) },
400 Shift { employee_id: None }, // Skipped
401 ],
402 };
403
404 // Employee 0: 2, Employee 1: 0 → Total: -2
405 assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
406 ```
407 */
408 pub fn complement_with_key<B, EB, KA2, KB, D>(
409 self,
410 extractor_b: EB,
411 key_a: KA2,
412 key_b: KB,
413 default_fn: D,
414 ) -> ComplementedConstraintStream<S, A, B, K, E, EB, KA2, KB, C, D, Sc>
415 where
416 B: Clone + Send + Sync + 'static,
417 EB: CollectionExtract<S, Item = B>,
418 KA2: Fn(&A) -> Option<K> + Send + Sync,
419 KB: Fn(&B) -> K + Send + Sync,
420 D: Fn(&B) -> C::Result + Send + Sync,
421 {
422 ComplementedConstraintStream::new(
423 self.extractor,
424 extractor_b,
425 key_a,
426 key_b,
427 self.collector,
428 default_fn,
429 )
430 }
431}
432
433impl<S, A, K, E, Fi, KF, C, Sc: Score> std::fmt::Debug
434 for GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
435{
436 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
437 f.debug_struct("GroupedConstraintStream").finish()
438 }
439}
440
441// Zero-erasure builder for finalizing a grouped constraint.
442pub struct GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
443where
444 Sc: Score,
445{
446 extractor: E,
447 filter: Fi,
448 key_fn: KF,
449 collector: C,
450 impact_type: ImpactType,
451 weight_fn: W,
452 is_hard: bool,
453 expected_descriptor: Option<usize>,
454 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
455}
456
457impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
458where
459 S: Send + Sync + 'static,
460 A: Clone + Send + Sync + 'static,
461 K: Clone + Eq + Hash + Send + Sync + 'static,
462 E: CollectionExtract<S, Item = A>,
463 Fi: UniFilter<S, A>,
464 KF: Fn(&A) -> K + Send + Sync,
465 C: UniCollector<A> + Send + Sync + 'static,
466 C::Accumulator: Send + Sync,
467 C::Result: Clone + Send + Sync,
468 W: Fn(&C::Result) -> Sc + Send + Sync,
469 Sc: Score + 'static,
470{
471 /* Finalizes the builder into a zero-erasure `GroupedUniConstraint`.
472
473 # Example
474
475 ```
476 use solverforge_scoring::stream::ConstraintFactory;
477 use solverforge_scoring::stream::collector::count;
478 use solverforge_scoring::api::constraint_set::IncrementalConstraint;
479 use solverforge_core::score::SoftScore;
480
481 #[derive(Clone, Hash, PartialEq, Eq)]
482 struct Item { category: u32 }
483
484 #[derive(Clone)]
485 struct Solution { items: Vec<Item> }
486
487 let constraint = ConstraintFactory::<Solution, SoftScore>::new()
488 .for_each(|s: &Solution| &s.items)
489 .group_by(|i: &Item| i.category, count())
490 .penalize_with(|n: &usize| SoftScore::of(*n as i64))
491 .named("Category penalty");
492
493 assert_eq!(constraint.name(), "Category penalty");
494 ```
495 */
496 pub fn named(self, name: &str) -> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc> {
497 let mut constraint = GroupedUniConstraint::new(
498 ConstraintRef::new("", name),
499 self.impact_type,
500 self.extractor,
501 self.filter,
502 self.key_fn,
503 self.collector,
504 self.weight_fn,
505 self.is_hard,
506 );
507 if let Some(d) = self.expected_descriptor {
508 constraint = constraint.with_descriptor(d);
509 }
510 constraint
511 }
512
513 pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
514 self.expected_descriptor = Some(descriptor_index);
515 self
516 }
517}
518
519impl<S, A, K, E, Fi, KF, C, W, Sc: Score> std::fmt::Debug
520 for GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
521{
522 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523 f.debug_struct("GroupedConstraintBuilder")
524 .field("impact_type", &self.impact_type)
525 .finish()
526 }
527}