// solverforge_scoring/stream/grouped_stream.rs
//! Zero-erasure grouped constraint stream for group-by constraint patterns.
//!
//! A `GroupedConstraintStream` operates on groups of entities and supports
//! filtering, weighting, and constraint finalization.
//! All type information is preserved at compile time - no Arc, no dyn.

use std::hash::Hash;
use std::marker::PhantomData;

use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};

use super::collector::UniCollector;
use super::complemented_stream::ComplementedConstraintStream;
use super::filter::UniFilter;
use crate::constraint::grouped::GroupedUniConstraint;

/// Zero-erasure constraint stream over grouped entities.
///
/// `GroupedConstraintStream` is created by `UniConstraintStream::group_by()`
/// and operates on (key, collector_result) tuples.
///
/// All type parameters are concrete - no trait objects, no Arc allocations.
///
/// # Type Parameters
///
/// - `S` - Solution type
/// - `A` - Entity type
/// - `K` - Group key type
/// - `E` - Extractor function for entities
/// - `Fi` - Filter type (preserved from upstream stream)
/// - `KF` - Key function
/// - `C` - Collector type
/// - `Sc` - Score type
///
/// # Example
///
/// ```
/// use solverforge_scoring::stream::ConstraintFactory;
/// use solverforge_scoring::stream::collector::count;
/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
/// use solverforge_core::score::SoftScore;
///
/// #[derive(Clone, Hash, PartialEq, Eq)]
/// struct Shift { employee_id: usize }
///
/// #[derive(Clone)]
/// struct Solution { shifts: Vec<Shift> }
///
/// let constraint = ConstraintFactory::<Solution, SoftScore>::new()
///     .for_each(|s: &Solution| &s.shifts)
///     .group_by(|shift: &Shift| shift.employee_id, count())
///     .penalize_with(|count: &usize| SoftScore::of((*count * *count) as i64))
///     .as_constraint("Balanced workload");
///
/// let solution = Solution {
///     shifts: vec![
///         Shift { employee_id: 1 },
///         Shift { employee_id: 1 },
///         Shift { employee_id: 1 },
///         Shift { employee_id: 2 },
///     ],
/// };
///
/// // Employee 1: 3² = 9, Employee 2: 1² = 1, Total: -10
/// assert_eq!(constraint.evaluate(&solution), SoftScore::of(-10));
/// ```
pub struct GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
where
    Sc: Score,
{
    /// Extracts the entity slice from the solution.
    extractor: E,
    /// Upstream filter applied to entities before grouping.
    filter: Fi,
    /// Maps each entity to its group key.
    key_fn: KF,
    /// Accumulates entities within each group into a result value.
    collector: C,
    // fn() pointers keep the phantom types Send + Sync without owning values.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
}

79impl<S, A, K, E, Fi, KF, C, Sc> GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
80where
81 S: Send + Sync + 'static,
82 A: Clone + Send + Sync + 'static,
83 K: Clone + Eq + Hash + Send + Sync + 'static,
84 E: Fn(&S) -> &[A] + Send + Sync,
85 Fi: UniFilter<S, A>,
86 KF: Fn(&A) -> K + Send + Sync,
87 C: UniCollector<A> + Send + Sync + 'static,
88 C::Accumulator: Send + Sync,
89 C::Result: Clone + Send + Sync,
90 Sc: Score + 'static,
91{
92 // Creates a new zero-erasure grouped constraint stream.
93 pub(crate) fn new(extractor: E, filter: Fi, key_fn: KF, collector: C) -> Self {
94 Self {
95 extractor,
96 filter,
97 key_fn,
98 collector,
99 _phantom: PhantomData,
100 }
101 }
102
103 // Penalizes each group with a weight based on the collector result.
104 //
105 // # Example
106 //
107 // ```
108 // use solverforge_scoring::stream::ConstraintFactory;
109 // use solverforge_scoring::stream::collector::count;
110 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
111 // use solverforge_core::score::SoftScore;
112 //
113 // #[derive(Clone, Hash, PartialEq, Eq)]
114 // struct Task { priority: u32 }
115 //
116 // #[derive(Clone)]
117 // struct Solution { tasks: Vec<Task> }
118 //
119 // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
120 // .for_each(|s: &Solution| &s.tasks)
121 // .group_by(|t: &Task| t.priority, count())
122 // .penalize_with(|count: &usize| SoftScore::of(*count as i64))
123 // .as_constraint("Priority distribution");
124 //
125 // let solution = Solution {
126 // tasks: vec![
127 // Task { priority: 1 },
128 // Task { priority: 1 },
129 // Task { priority: 2 },
130 // ],
131 // };
132 //
133 // // Priority 1: 2 tasks, Priority 2: 1 task, Total: -3
134 // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-3));
135 // ```
136 pub fn penalize_with<W>(
137 self,
138 weight_fn: W,
139 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
140 where
141 W: Fn(&C::Result) -> Sc + Send + Sync,
142 {
143 GroupedConstraintBuilder {
144 extractor: self.extractor,
145 filter: self.filter,
146 key_fn: self.key_fn,
147 collector: self.collector,
148 impact_type: ImpactType::Penalty,
149 weight_fn,
150 is_hard: false,
151 expected_descriptor: None,
152 _phantom: PhantomData,
153 }
154 }
155
156 // Penalizes each group with a weight, explicitly marked as hard constraint.
157 pub fn penalize_hard_with<W>(
158 self,
159 weight_fn: W,
160 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
161 where
162 W: Fn(&C::Result) -> Sc + Send + Sync,
163 {
164 GroupedConstraintBuilder {
165 extractor: self.extractor,
166 filter: self.filter,
167 key_fn: self.key_fn,
168 collector: self.collector,
169 impact_type: ImpactType::Penalty,
170 weight_fn,
171 is_hard: true,
172 expected_descriptor: None,
173 _phantom: PhantomData,
174 }
175 }
176
177 // Rewards each group with a weight based on the collector result.
178 pub fn reward_with<W>(
179 self,
180 weight_fn: W,
181 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
182 where
183 W: Fn(&C::Result) -> Sc + Send + Sync,
184 {
185 GroupedConstraintBuilder {
186 extractor: self.extractor,
187 filter: self.filter,
188 key_fn: self.key_fn,
189 collector: self.collector,
190 impact_type: ImpactType::Reward,
191 weight_fn,
192 is_hard: false,
193 expected_descriptor: None,
194 _phantom: PhantomData,
195 }
196 }
197
198 // Rewards each group with a weight, explicitly marked as hard constraint.
199 pub fn reward_hard_with<W>(
200 self,
201 weight_fn: W,
202 ) -> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
203 where
204 W: Fn(&C::Result) -> Sc + Send + Sync,
205 {
206 GroupedConstraintBuilder {
207 extractor: self.extractor,
208 filter: self.filter,
209 key_fn: self.key_fn,
210 collector: self.collector,
211 impact_type: ImpactType::Reward,
212 weight_fn,
213 is_hard: true,
214 expected_descriptor: None,
215 _phantom: PhantomData,
216 }
217 }
218
219 // Adds complement entities with default values for missing keys.
220 //
221 // This ensures all keys from the complement source are represented,
222 // using the grouped value if present, or the default value otherwise.
223 //
224 // **Note:** The key function for A entities wraps the original key to
225 // return `Some(K)`. For filtering (skipping entities without valid keys),
226 // use `complement_filtered` instead.
227 //
228 // # Example
229 //
230 // ```
231 // use solverforge_scoring::stream::ConstraintFactory;
232 // use solverforge_scoring::stream::collector::count;
233 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
234 // use solverforge_core::score::SoftScore;
235 //
236 // #[derive(Clone, Hash, PartialEq, Eq)]
237 // struct Employee { id: usize }
238 //
239 // #[derive(Clone, Hash, PartialEq, Eq)]
240 // struct Shift { employee_id: usize }
241 //
242 // #[derive(Clone)]
243 // struct Schedule {
244 // employees: Vec<Employee>,
245 // shifts: Vec<Shift>,
246 // }
247 //
248 // // Count shifts per employee, including employees with 0 shifts
249 // let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
250 // .for_each(|s: &Schedule| &s.shifts)
251 // .group_by(|shift: &Shift| shift.employee_id, count())
252 // .complement(
253 // |s: &Schedule| s.employees.as_slice(),
254 // |emp: &Employee| emp.id,
255 // |_emp: &Employee| 0usize,
256 // )
257 // .penalize_with(|count: &usize| SoftScore::of(*count as i64))
258 // .as_constraint("Shift count");
259 //
260 // let schedule = Schedule {
261 // employees: vec![Employee { id: 0 }, Employee { id: 1 }],
262 // shifts: vec![
263 // Shift { employee_id: 0 },
264 // Shift { employee_id: 0 },
265 // ],
266 // };
267 //
268 // // Employee 0: 2, Employee 1: 0 → Total: -2
269 // assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
270 // ```
271 pub fn complement<B, EB, KB, D>(
272 self,
273 extractor_b: EB,
274 key_b: KB,
275 default_fn: D,
276 ) -> ComplementedConstraintStream<
277 S,
278 A,
279 B,
280 K,
281 E,
282 EB,
283 impl Fn(&A) -> Option<K> + Send + Sync,
284 KB,
285 C,
286 D,
287 Sc,
288 >
289 where
290 B: Clone + Send + Sync + 'static,
291 EB: Fn(&S) -> &[B] + Send + Sync,
292 KB: Fn(&B) -> K + Send + Sync,
293 D: Fn(&B) -> C::Result + Send + Sync,
294 {
295 let key_fn = self.key_fn;
296 let wrapped_key_fn = move |a: &A| Some((key_fn)(a));
297 ComplementedConstraintStream::new(
298 self.extractor,
299 extractor_b,
300 wrapped_key_fn,
301 key_b,
302 self.collector,
303 default_fn,
304 )
305 }
306
307 // Adds complement entities with a custom key function for filtering.
308 //
309 // Like `complement`, but allows providing a custom key function for A entities
310 // that returns `Option<K>`. Entities returning `None` are skipped.
311 //
312 // # Example
313 //
314 // ```
315 // use solverforge_scoring::stream::ConstraintFactory;
316 // use solverforge_scoring::stream::collector::count;
317 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
318 // use solverforge_core::score::SoftScore;
319 //
320 // #[derive(Clone, Hash, PartialEq, Eq)]
321 // struct Employee { id: usize }
322 //
323 // #[derive(Clone, Hash, PartialEq, Eq)]
324 // struct Shift { employee_id: Option<usize> }
325 //
326 // #[derive(Clone)]
327 // struct Schedule {
328 // employees: Vec<Employee>,
329 // shifts: Vec<Shift>,
330 // }
331 //
332 // // Count shifts per employee, skipping unassigned shifts
333 // // The group_by key is ignored; complement_with_key provides its own
334 // let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
335 // .for_each(|s: &Schedule| &s.shifts)
336 // .group_by(|_shift: &Shift| 0usize, count()) // Placeholder key, will be overridden
337 // .complement_with_key(
338 // |s: &Schedule| s.employees.as_slice(),
339 // |shift: &Shift| shift.employee_id, // Option<usize>
340 // |emp: &Employee| emp.id, // usize
341 // |_emp: &Employee| 0usize,
342 // )
343 // .penalize_with(|count: &usize| SoftScore::of(*count as i64))
344 // .as_constraint("Shift count");
345 //
346 // let schedule = Schedule {
347 // employees: vec![Employee { id: 0 }, Employee { id: 1 }],
348 // shifts: vec![
349 // Shift { employee_id: Some(0) },
350 // Shift { employee_id: Some(0) },
351 // Shift { employee_id: None }, // Skipped
352 // ],
353 // };
354 //
355 // // Employee 0: 2, Employee 1: 0 → Total: -2
356 // assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-2));
357 // ```
358 pub fn complement_with_key<B, EB, KA2, KB, D>(
359 self,
360 extractor_b: EB,
361 key_a: KA2,
362 key_b: KB,
363 default_fn: D,
364 ) -> ComplementedConstraintStream<S, A, B, K, E, EB, KA2, KB, C, D, Sc>
365 where
366 B: Clone + Send + Sync + 'static,
367 EB: Fn(&S) -> &[B] + Send + Sync,
368 KA2: Fn(&A) -> Option<K> + Send + Sync,
369 KB: Fn(&B) -> K + Send + Sync,
370 D: Fn(&B) -> C::Result + Send + Sync,
371 {
372 ComplementedConstraintStream::new(
373 self.extractor,
374 extractor_b,
375 key_a,
376 key_b,
377 self.collector,
378 default_fn,
379 )
380 }
381}
382
383impl<S, A, K, E, Fi, KF, C, Sc: Score> std::fmt::Debug
384 for GroupedConstraintStream<S, A, K, E, Fi, KF, C, Sc>
385{
386 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387 f.debug_struct("GroupedConstraintStream").finish()
388 }
389}
390
/// Zero-erasure builder for finalizing a grouped constraint.
pub struct GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
where
    Sc: Score,
{
    /// Extracts the entity slice from the solution.
    extractor: E,
    /// Upstream filter applied to entities before grouping.
    filter: Fi,
    /// Maps each entity to its group key.
    key_fn: KF,
    /// Accumulates entities within each group into a result value.
    collector: C,
    /// Whether the finalized constraint penalizes or rewards.
    impact_type: ImpactType,
    /// Maps each group's collector result to a score weight.
    weight_fn: W,
    /// Marks the finalized constraint as hard (vs. soft).
    is_hard: bool,
    /// Optional descriptor index forwarded to the constraint on finalization.
    expected_descriptor: Option<usize>,
    // fn() pointers keep the phantom types Send + Sync without owning values.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
}

407impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
408where
409 S: Send + Sync + 'static,
410 A: Clone + Send + Sync + 'static,
411 K: Clone + Eq + Hash + Send + Sync + 'static,
412 E: Fn(&S) -> &[A] + Send + Sync,
413 Fi: UniFilter<S, A>,
414 KF: Fn(&A) -> K + Send + Sync,
415 C: UniCollector<A> + Send + Sync + 'static,
416 C::Accumulator: Send + Sync,
417 C::Result: Clone + Send + Sync,
418 W: Fn(&C::Result) -> Sc + Send + Sync,
419 Sc: Score + 'static,
420{
421 // Finalizes the builder into a zero-erasure `GroupedUniConstraint`.
422 //
423 // # Example
424 //
425 // ```
426 // use solverforge_scoring::stream::ConstraintFactory;
427 // use solverforge_scoring::stream::collector::count;
428 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
429 // use solverforge_core::score::SoftScore;
430 //
431 // #[derive(Clone, Hash, PartialEq, Eq)]
432 // struct Item { category: u32 }
433 //
434 // #[derive(Clone)]
435 // struct Solution { items: Vec<Item> }
436 //
437 // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
438 // .for_each(|s: &Solution| &s.items)
439 // .group_by(|i: &Item| i.category, count())
440 // .penalize_with(|n: &usize| SoftScore::of(*n as i64))
441 // .as_constraint("Category penalty");
442 //
443 // assert_eq!(constraint.name(), "Category penalty");
444 // ```
445 pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
446 self.expected_descriptor = Some(descriptor_index);
447 self
448 }
449
450 pub fn as_constraint(self, name: &str) -> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc> {
451 let mut constraint = GroupedUniConstraint::new(
452 ConstraintRef::new("", name),
453 self.impact_type,
454 self.extractor,
455 self.filter,
456 self.key_fn,
457 self.collector,
458 self.weight_fn,
459 self.is_hard,
460 );
461 if let Some(d) = self.expected_descriptor {
462 constraint = constraint.with_descriptor(d);
463 }
464 constraint
465 }
466}
467
468impl<S, A, K, E, Fi, KF, C, W, Sc: Score> std::fmt::Debug
469 for GroupedConstraintBuilder<S, A, K, E, Fi, KF, C, W, Sc>
470{
471 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472 f.debug_struct("GroupedConstraintBuilder")
473 .field("impact_type", &self.impact_type)
474 .finish()
475 }
476}