solverforge_scoring/stream/grouped_stream.rs
// Zero-erasure grouped constraint stream for group-by constraint patterns.
//
// A `GroupedConstraintStream` operates on groups of entities (one
// `(key, collector_result)` pair per group) and supports weighting
// (penalize/reward), complementing with entities whose keys have no group,
// and constraint finalization.
// All type information is preserved at compile time - no Arc, no dyn.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collector::UniCollector;
14use super::complemented_stream::ComplementedConstraintStream;
15use super::filter::UniFilter;
16use crate::constraint::grouped::GroupedUniConstraint;
17
18// Zero-erasure constraint stream over grouped entities.
19//
20// `GroupedConstraintStream` is created by `UniConstraintStream::group_by()`
21// and operates on (key, collector_result) tuples.
22//
23// All type parameters are concrete - no trait objects, no Arc allocations.
24//
25// # Type Parameters
26//
27// - `S` - Solution type
28// - `A` - Entity type
29// - `K` - Group key type
30// - `E` - Extractor function for entities
31// - `Fi` - Filter type (preserved from upstream stream)
32// - `KF` - Key function
33// - `C` - Collector type
34// - `Sc` - Score type
35//
36// # Example
37//
38// ```
39// use solverforge_scoring::stream::ConstraintFactory;
40// use solverforge_scoring::stream::collector::count;
41// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
42// use solverforge_core::score::SoftScore;
43//
44// #[derive(Clone, Hash, PartialEq, Eq)]
45// struct Shift { employee_id: usize }
46//
47// #[derive(Clone)]
48// struct Solution { shifts: Vec<Shift> }
49//
50// let constraint = ConstraintFactory::<Solution, SoftScore>::new()
51// .for_each(|s: &Solution| &s.shifts)
52// .group_by(|shift: &Shift| shift.employee_id, count())
53// .penalize_with(|count: &usize| SoftScore::of((*count * *count) as i64))
54// .as_constraint("Balanced workload");
55//
56// let solution = Solution {
57// shifts: vec![
58// Shift { employee_id: 1 },
59// Shift { employee_id: 1 },
60// Shift { employee_id: 1 },
61// Shift { employee_id: 2 },
62// ],
63// };
64//
65// // Employee 1: 3² = 9, Employee 2: 1² = 1, Total: -10
66// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-10));
67// ```
68pub struct GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
69where
70 Sc: Score,
71{
72 extractor: E,
73 filter: Fi,
74 key_fn: KF,
75 collector: C,
76 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
77}
78
79impl<S, A, K, E, Fi, KF, C, Sc> GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
80where
81 S: Send + Sync + 'static,
82 A: Clone + Send + Sync + 'static,
83 K: Clone + Eq + Hash + Send + Sync + 'static,
84 E: Fn(&S) -> &[A] + Send + Sync,
85 Fi: UniFilter<S, A>,
86 KF: Fn(&A) -> K + Send + Sync,
87 C: UniCollector<A> + Send + Sync + 'static,
88 C::Accumulator: Send + Sync,
89 C::Result: Clone + Send + Sync,
90 Sc: Score + 'static,
91{
92 // Creates a new zero-erasure grouped constraint stream.
93 pub(crate) fn new(extractor: E, filter: Fi, key_fn: KF, collector: C) -> Self {
94 Self {
95 extractor,
96 filter,
97 key_fn,
98 collector,
99 _phantom: PhantomData,
100 }
101 }
102
103 // Penalizes each group with a weight based on the collector result.
104 //
105 // # Example
106 //
107 // ```
108 // use solverforge_scoring::stream::ConstraintFactory;
109 // use solverforge_scoring::stream::collector::count;
110 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
111 // use solverforge_core::score::SoftScore;
112 //
113 // #[derive(Clone, Hash, PartialEq, Eq)]
114 // struct Task { priority: u32 }
115 //
116 // #[derive(Clone)]
117 // struct Solution { tasks: Vec<Task> }
118 //
119 // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
120 // .for_each(|s: &Solution| &s.tasks)
121 // .group_by(|t: &Task| t.priority, count())
122 // .penalize_with(|count: &usize| SoftScore::of(*count as i64))
123 // .as_constraint("Priority distribution");
124 //
125 // let solution = Solution {
126 // tasks: vec![
127 // Task { priority: 1 },
128 // Task { priority: 1 },
129 // Task { priority: 2 },
130 // ],
131 // };
132 //
133 // // Priority 1: 2 tasks, Priority 2: 1 task, Total: -3
134 // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-3));
135 // ```
136 pub fn penalize_with<W>(
137 self,
138 weight_fn: W,
139 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
140 where
141 W: Fn(&C::Result) -> Sc + Send + Sync,
142 {
143 GroupedConstraintBuilder {
144 extractor: self.extractor,
145 filter: self.filter,
146 key_fn: self.key_fn,
147 collector: self.collector,
148 impact_type: ImpactType::Penalty,
149 weight_fn,
150 is_hard: false,
151 expected_descriptor: None,
152 _phantom: PhantomData,
153 }
154 }
155
156 // Penalizes each group with a weight, explicitly marked as hard constraint.
157 pub fn penalize_hard_with<W>(
158 self,
159 weight_fn: W,
160 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
161 where
162 W: Fn(&C::Result) -> Sc + Send + Sync,
163 {
164 GroupedConstraintBuilder {
165 extractor: self.extractor,
166 filter: self.filter,
167 key_fn: self.key_fn,
168 collector: self.collector,
169 impact_type: ImpactType::Penalty,
170 weight_fn,
171 is_hard: true,
172 expected_descriptor: None,
173 _phantom: PhantomData,
174 }
175 }
176
177 // Rewards each group with a weight based on the collector result.
178 pub fn reward_with<W>(
179 self,
180 weight_fn: W,
181 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
182 where
183 W: Fn(&C::Result) -> Sc + Send + Sync,
184 {
185 GroupedConstraintBuilder {
186 extractor: self.extractor,
187 filter: self.filter,
188 key_fn: self.key_fn,
189 collector: self.collector,
190 impact_type: ImpactType::Reward,
191 weight_fn,
192 is_hard: false,
193 expected_descriptor: None,
194 _phantom: PhantomData,
195 }
196 }
197
198 // Rewards each group with a weight, explicitly marked as hard constraint.
199 pub fn reward_hard_with<W>(
200 self,
201 weight_fn: W,
202 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
203 where
204 W: Fn(&C::Result) -> Sc + Send + Sync,
205 {
206 GroupedConstraintBuilder {
207 extractor: self.extractor,
208 filter: self.filter,
209 key_fn: self.key_fn,
210 collector: self.collector,
211 impact_type: ImpactType::Reward,
212 weight_fn,
213 is_hard: true,
214 expected_descriptor: None,
215 _phantom: PhantomData,
216 }
217 }
218
219 // Penalizes each group with one hard score unit.
220 pub fn penalize_hard(
221 self,
222 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
223 where
224 Sc: Copy,
225 {
226 let w = Sc::one_hard();
227 self.penalize_hard_with(move |_: &C::Result| w)
228 }
229
230 // Penalizes each group with one soft score unit.
231 pub fn penalize_soft(
232 self,
233 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
234 where
235 Sc: Copy,
236 {
237 let w = Sc::one_soft();
238 self.penalize_with(move |_: &C::Result| w)
239 }
240
241 // Rewards each group with one hard score unit.
242 pub fn reward_hard(
243 self,
244 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
245 where
246 Sc: Copy,
247 {
248 let w = Sc::one_hard();
249 self.reward_hard_with(move |_: &C::Result| w)
250 }
251
252 // Rewards each group with one soft score unit.
253 pub fn reward_soft(
254 self,
255 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, impl Fn(&C::Result) -> Sc + Send + Sync, Sc>
256 where
257 Sc: Copy,
258 {
259 let w = Sc::one_soft();
260 self.reward_with(move |_: &C::Result| w)
261 }
262
263 // Adds complement entities with default values for missing keys.
264 //
265 // This ensures all keys from the complement source are represented,
266 // using the grouped value if present, or the default value otherwise.
267 //
268 // **Note:** The key function for A entities wraps the original key to
269 // return `Some(K)`. For filtering (skipping entities without valid keys),
270 // use `complement_filtered` instead.
271 //
272 // # Example
273 //
274 // ```
275 // use solverforge_scoring::stream::ConstraintFactory;
276 // use solverforge_scoring::stream::collector::count;
277 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
278 // use solverforge_core::score::SoftScore;
279 //
280 // #[derive(Clone, Hash, PartialEq, Eq)]
281 // struct Employee { id: usize }
282 //
283 // #[derive(Clone, Hash, PartialEq, Eq)]
284 // struct Shift { employee_id: usize }
285 //
286 // #[derive(Clone)]
287 // struct Schedule {
288 // employees: Vec<Employee>,
289 // shifts: Vec<Shift>,
290 // }
291 //
292 // // Count shifts per employee, including employees with 0 shifts
293 // let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
294 // .for_each(|s: &Schedule| &s.shifts)
295 // .group_by(|shift: &Shift| shift.employee_id, count())
296 // .complement(
297 // |s: &Schedule| s.employees.as_slice(),
298 // |emp: &Employee| emp.id,
299 // |_emp: &Employee| 0usize,
300 // )
301 // .penalize_with(|count: &usize| SoftScore::of(*count as i64))
302 // .as_constraint("Shift count");
303 //
304 // let schedule = Schedule {
305 // employees: vec![Employee { id: 0 }, Employee { id: 1 }],
306 // shifts: vec![
307 // Shift { employee_id: 0 },
308 // Shift { employee_id: 0 },
309 // ],
310 // };
311 //
312 // // Employee 0: 2, Employee 1: 0 → Total: -2
313 // assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
314 // ```
315 pub fn complement<B, EB, KB, D>(
316 self,
317 extractor_b: EB,
318 key_b: KB,
319 default_fn: D,
320 ) -> ComplementedConstraintStream<
321 S,
322 A,
323 B,
324 K,
325 E,
326 EB,
327 impl Fn(&A) -> Option<K> + Send + Sync,
328 KB,
329 C,
330 D,
331 Sc,
332 >
333 where
334 B: Clone + Send + Sync + 'static,
335 EB: Fn(&S) -> &[B] + Send + Sync,
336 KB: Fn(&B) -> K + Send + Sync,
337 D: Fn(&B) -> C::Result + Send + Sync,
338 {
339 let key_fn = self.key_fn;
340 let wrapped_key_fn = move |a: &A| Some((key_fn)(a));
341 ComplementedConstraintStream::new(
342 self.extractor,
343 extractor_b,
344 wrapped_key_fn,
345 key_b,
346 self.collector,
347 default_fn,
348 )
349 }
350
351 // Adds complement entities with a custom key function for filtering.
352 //
353 // Like `complement`, but allows providing a custom key function for A entities
354 // that returns `Option<K>`. Entities returning `None` are skipped.
355 //
356 // # Example
357 //
358 // ```
359 // use solverforge_scoring::stream::ConstraintFactory;
360 // use solverforge_scoring::stream::collector::count;
361 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
362 // use solverforge_core::score::SoftScore;
363 //
364 // #[derive(Clone, Hash, PartialEq, Eq)]
365 // struct Employee { id: usize }
366 //
367 // #[derive(Clone, Hash, PartialEq, Eq)]
368 // struct Shift { employee_id: Option<usize> }
369 //
370 // #[derive(Clone)]
371 // struct Schedule {
372 // employees: Vec<Employee>,
373 // shifts: Vec<Shift>,
374 // }
375 //
376 // // Count shifts per employee, skipping unassigned shifts
377 // // The group_by key is ignored; complement_with_key provides its own
378 // let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
379 // .for_each(|s: &Schedule| &s.shifts)
380 // .group_by(|_shift: &Shift| 0usize, count()) // Placeholder key, will be overridden
381 // .complement_with_key(
382 // |s: &Schedule| s.employees.as_slice(),
383 // |shift: &Shift| shift.employee_id, // Option<usize>
384 // |emp: &Employee| emp.id, // usize
385 // |_emp: &Employee| 0usize,
386 // )
387 // .penalize_with(|count: &usize| SoftScore::of(*count as i64))
388 // .as_constraint("Shift count");
389 //
390 // let schedule = Schedule {
391 // employees: vec![Employee { id: 0 }, Employee { id: 1 }],
392 // shifts: vec![
393 // Shift { employee_id: Some(0) },
394 // Shift { employee_id: Some(0) },
395 // Shift { employee_id: None }, // Skipped
396 // ],
397 // };
398 //
399 // // Employee 0: 2, Employee 1: 0 → Total: -2
400 // assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
401 // ```
402 pub fn complement_with_key<B, EB, KA2, KB, D>(
403 self,
404 extractor_b: EB,
405 key_a: KA2,
406 key_b: KB,
407 default_fn: D,
408 ) -> ComplementedConstraintStream<S, A, B, K, E, EB, KA2, KB, C, D, Sc>
409 where
410 B: Clone + Send + Sync + 'static,
411 EB: Fn(&S) -> &[B] + Send + Sync,
412 KA2: Fn(&A) -> Option<K> + Send + Sync,
413 KB: Fn(&B) -> K + Send + Sync,
414 D: Fn(&B) -> C::Result + Send + Sync,
415 {
416 ComplementedConstraintStream::new(
417 self.extractor,
418 extractor_b,
419 key_a,
420 key_b,
421 self.collector,
422 default_fn,
423 )
424 }
425}
426
427impl<S, A, K, E, Fi, KF, C, Sc: Score> std::fmt::Debug
428 for GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
429{
430 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
431 f.debug_struct("GroupedConstraintStream").finish()
432 }
433}
434
435// Zero-erasure builder for finalizing a grouped constraint.
436pub struct GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
437where
438 Sc: Score,
439{
440 extractor: E,
441 filter: Fi,
442 key_fn: KF,
443 collector: C,
444 impact_type: ImpactType,
445 weight_fn: W,
446 is_hard: bool,
447 expected_descriptor: Option<usize>,
448 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
449}
450
451impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
452where
453 S: Send + Sync + 'static,
454 A: Clone + Send + Sync + 'static,
455 K: Clone + Eq + Hash + Send + Sync + 'static,
456 E: Fn(&S) -> &[A] + Send + Sync,
457 Fi: UniFilter<S, A>,
458 KF: Fn(&A) -> K + Send + Sync,
459 C: UniCollector<A> + Send + Sync + 'static,
460 C::Accumulator: Send + Sync,
461 C::Result: Clone + Send + Sync,
462 W: Fn(&C::Result) -> Sc + Send + Sync,
463 Sc: Score + 'static,
464{
465 // Finalizes the builder into a zero-erasure `GroupedUniConstraint`.
466 //
467 // # Example
468 //
469 // ```
470 // use solverforge_scoring::stream::ConstraintFactory;
471 // use solverforge_scoring::stream::collector::count;
472 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
473 // use solverforge_core::score::SoftScore;
474 //
475 // #[derive(Clone, Hash, PartialEq, Eq)]
476 // struct Item { category: u32 }
477 //
478 // #[derive(Clone)]
479 // struct Solution { items: Vec<Item> }
480 //
481 // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
482 // .for_each(|s: &Solution| &s.items)
483 // .group_by(|i: &Item| i.category, count())
484 // .penalize_with(|n: &usize| SoftScore::of(*n as i64))
485 // .as_constraint("Category penalty");
486 //
487 // assert_eq!(constraint.name(), "Category penalty");
488 // ```
489 // Alias for `as_constraint`.
490 pub fn named(self, name: &str) -> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc> {
491 self.as_constraint(name)
492 }
493
494 pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
495 self.expected_descriptor = Some(descriptor_index);
496 self
497 }
498
499 pub fn as_constraint(self, name: &str) -> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc> {
500 let mut constraint = GroupedUniConstraint::new(
501 ConstraintRef::new("", name),
502 self.impact_type,
503 self.extractor,
504 self.filter,
505 self.key_fn,
506 self.collector,
507 self.weight_fn,
508 self.is_hard,
509 );
510 if let Some(d) = self.expected_descriptor {
511 constraint = constraint.with_descriptor(d);
512 }
513 constraint
514 }
515}
516
517impl<S, A, K, E, Fi, KF, C, W, Sc: Score> std::fmt::Debug
518 for GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
519{
520 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
521 f.debug_struct("GroupedConstraintBuilder")
522 .field("impact_type", &self.impact_type)
523 .finish()
524 }
525}