// solverforge_scoring/stream/grouped_stream.rs
//! Zero-erasure grouped constraint stream for group-by constraint patterns.
//!
//! A `GroupedConstraintStream` operates on groups of entities and supports
//! filtering, weighting, and constraint finalization.
//! All type information is preserved at compile time - no Arc, no dyn.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collector::UniCollector;
14use super::complemented_stream::ComplementedConstraintStream;
15use crate::constraint::grouped::GroupedUniConstraint;
16
/// Zero-erasure constraint stream over grouped entities.
///
/// `GroupedConstraintStream` is created by `UniConstraintStream::group_by()`
/// and operates on (key, collector_result) tuples.
///
/// All type parameters are concrete - no trait objects, no Arc allocations.
///
/// # Type Parameters
///
/// - `S` - Solution type
/// - `A` - Entity type
/// - `K` - Group key type
/// - `E` - Extractor function for entities
/// - `KF` - Key function
/// - `C` - Collector type
/// - `Sc` - Score type
///
/// # Example
///
/// ```
/// use solverforge_scoring::stream::ConstraintFactory;
/// use solverforge_scoring::stream::collector::count;
/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
/// use solverforge_core::score::SimpleScore;
///
/// #[derive(Clone, Hash, PartialEq, Eq)]
/// struct Shift { employee_id: usize }
///
/// #[derive(Clone)]
/// struct Solution { shifts: Vec<Shift> }
///
/// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
///     .for_each(|s: &Solution| &s.shifts)
///     .group_by(|shift: &Shift| shift.employee_id, count())
///     .penalize_with(|count: &usize| SimpleScore::of((*count * *count) as i64))
///     .as_constraint("Balanced workload");
///
/// let solution = Solution {
///     shifts: vec![
///         Shift { employee_id: 1 },
///         Shift { employee_id: 1 },
///         Shift { employee_id: 1 },
///         Shift { employee_id: 2 },
///     ],
/// };
///
/// // Employee 1: 3² = 9, Employee 2: 1² = 1, Total: -10
/// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-10));
/// ```
// No `Sc: Score` bound on the struct itself: per Rust API guidelines, trait
// bounds belong on the impl blocks that need them (all impls below restate it),
// which keeps type aliases and derive-free usages unconstrained.
pub struct GroupedConstraintStream<S, A, K, E, KF, C, Sc> {
    /// Extracts the entity slice from the solution.
    extractor: E,
    /// Maps an entity to its group key.
    key_fn: KF,
    /// Aggregates each group's entities into a single result.
    collector: C,
    /// `fn() -> T` pointers record the unused type parameters while keeping
    /// the struct `Send + Sync` regardless of `S`, `A`, `K`, and `Sc`.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
}
75
76impl<S, A, K, E, KF, C, Sc> GroupedConstraintStream<S, A, K, E, KF, C, Sc>
77where
78 S: Send + Sync + 'static,
79 A: Clone + Send + Sync + 'static,
80 K: Clone + Eq + Hash + Send + Sync + 'static,
81 E: Fn(&S) -> &[A] + Send + Sync,
82 KF: Fn(&A) -> K + Send + Sync,
83 C: UniCollector<A> + Send + Sync + 'static,
84 C::Accumulator: Send + Sync,
85 C::Result: Clone + Send + Sync,
86 Sc: Score + 'static,
87{
88 // Creates a new zero-erasure grouped constraint stream.
89 pub(crate) fn new(extractor: E, key_fn: KF, collector: C) -> Self {
90 Self {
91 extractor,
92 key_fn,
93 collector,
94 _phantom: PhantomData,
95 }
96 }
97
98 // Penalizes each group with a weight based on the collector result.
99 //
100 // # Example
101 //
102 // ```
103 // use solverforge_scoring::stream::ConstraintFactory;
104 // use solverforge_scoring::stream::collector::count;
105 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
106 // use solverforge_core::score::SimpleScore;
107 //
108 // #[derive(Clone, Hash, PartialEq, Eq)]
109 // struct Task { priority: u32 }
110 //
111 // #[derive(Clone)]
112 // struct Solution { tasks: Vec<Task> }
113 //
114 // let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
115 // .for_each(|s: &Solution| &s.tasks)
116 // .group_by(|t: &Task| t.priority, count())
117 // .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
118 // .as_constraint("Priority distribution");
119 //
120 // let solution = Solution {
121 // tasks: vec![
122 // Task { priority: 1 },
123 // Task { priority: 1 },
124 // Task { priority: 2 },
125 // ],
126 // };
127 //
128 // // Priority 1: 2 tasks, Priority 2: 1 task, Total: -3
129 // assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-3));
130 // ```
131 pub fn penalize_with<W>(
132 self,
133 weight_fn: W,
134 ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
135 where
136 W: Fn(&C::Result) -> Sc + Send + Sync,
137 {
138 GroupedConstraintBuilder {
139 extractor: self.extractor,
140 key_fn: self.key_fn,
141 collector: self.collector,
142 impact_type: ImpactType::Penalty,
143 weight_fn,
144 is_hard: false,
145 expected_descriptor: None,
146 _phantom: PhantomData,
147 }
148 }
149
150 // Penalizes each group with a weight, explicitly marked as hard constraint.
151 pub fn penalize_hard_with<W>(
152 self,
153 weight_fn: W,
154 ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
155 where
156 W: Fn(&C::Result) -> Sc + Send + Sync,
157 {
158 GroupedConstraintBuilder {
159 extractor: self.extractor,
160 key_fn: self.key_fn,
161 collector: self.collector,
162 impact_type: ImpactType::Penalty,
163 weight_fn,
164 is_hard: true,
165 expected_descriptor: None,
166 _phantom: PhantomData,
167 }
168 }
169
170 // Rewards each group with a weight based on the collector result.
171 pub fn reward_with<W>(self, weight_fn: W) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
172 where
173 W: Fn(&C::Result) -> Sc + Send + Sync,
174 {
175 GroupedConstraintBuilder {
176 extractor: self.extractor,
177 key_fn: self.key_fn,
178 collector: self.collector,
179 impact_type: ImpactType::Reward,
180 weight_fn,
181 is_hard: false,
182 expected_descriptor: None,
183 _phantom: PhantomData,
184 }
185 }
186
187 // Rewards each group with a weight, explicitly marked as hard constraint.
188 pub fn reward_hard_with<W>(
189 self,
190 weight_fn: W,
191 ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
192 where
193 W: Fn(&C::Result) -> Sc + Send + Sync,
194 {
195 GroupedConstraintBuilder {
196 extractor: self.extractor,
197 key_fn: self.key_fn,
198 collector: self.collector,
199 impact_type: ImpactType::Reward,
200 weight_fn,
201 is_hard: true,
202 expected_descriptor: None,
203 _phantom: PhantomData,
204 }
205 }
206
207 // Adds complement entities with default values for missing keys.
208 //
209 // This ensures all keys from the complement source are represented,
210 // using the grouped value if present, or the default value otherwise.
211 //
212 // **Note:** The key function for A entities wraps the original key to
213 // return `Some(K)`. For filtering (skipping entities without valid keys),
214 // use `complement_filtered` instead.
215 //
216 // # Example
217 //
218 // ```
219 // use solverforge_scoring::stream::ConstraintFactory;
220 // use solverforge_scoring::stream::collector::count;
221 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
222 // use solverforge_core::score::SimpleScore;
223 //
224 // #[derive(Clone, Hash, PartialEq, Eq)]
225 // struct Employee { id: usize }
226 //
227 // #[derive(Clone, Hash, PartialEq, Eq)]
228 // struct Shift { employee_id: usize }
229 //
230 // #[derive(Clone)]
231 // struct Schedule {
232 // employees: Vec<Employee>,
233 // shifts: Vec<Shift>,
234 // }
235 //
236 // // Count shifts per employee, including employees with 0 shifts
237 // let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
238 // .for_each(|s: &Schedule| &s.shifts)
239 // .group_by(|shift: &Shift| shift.employee_id, count())
240 // .complement(
241 // |s: &Schedule| s.employees.as_slice(),
242 // |emp: &Employee| emp.id,
243 // |_emp: &Employee| 0usize,
244 // )
245 // .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
246 // .as_constraint("Shift count");
247 //
248 // let schedule = Schedule {
249 // employees: vec![Employee { id: 0 }, Employee { id: 1 }],
250 // shifts: vec![
251 // Shift { employee_id: 0 },
252 // Shift { employee_id: 0 },
253 // ],
254 // };
255 //
256 // // Employee 0: 2, Employee 1: 0 → Total: -2
257 // assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
258 // ```
259 pub fn complement<B, EB, KB, D>(
260 self,
261 extractor_b: EB,
262 key_b: KB,
263 default_fn: D,
264 ) -> ComplementedConstraintStream<
265 S,
266 A,
267 B,
268 K,
269 E,
270 EB,
271 impl Fn(&A) -> Option<K> + Send + Sync,
272 KB,
273 C,
274 D,
275 Sc,
276 >
277 where
278 B: Clone + Send + Sync + 'static,
279 EB: Fn(&S) -> &[B] + Send + Sync,
280 KB: Fn(&B) -> K + Send + Sync,
281 D: Fn(&B) -> C::Result + Send + Sync,
282 {
283 let key_fn = self.key_fn;
284 let wrapped_key_fn = move |a: &A| Some((key_fn)(a));
285 ComplementedConstraintStream::new(
286 self.extractor,
287 extractor_b,
288 wrapped_key_fn,
289 key_b,
290 self.collector,
291 default_fn,
292 )
293 }
294
295 // Adds complement entities with a custom key function for filtering.
296 //
297 // Like `complement`, but allows providing a custom key function for A entities
298 // that returns `Option<K>`. Entities returning `None` are skipped.
299 //
300 // # Example
301 //
302 // ```
303 // use solverforge_scoring::stream::ConstraintFactory;
304 // use solverforge_scoring::stream::collector::count;
305 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
306 // use solverforge_core::score::SimpleScore;
307 //
308 // #[derive(Clone, Hash, PartialEq, Eq)]
309 // struct Employee { id: usize }
310 //
311 // #[derive(Clone, Hash, PartialEq, Eq)]
312 // struct Shift { employee_id: Option<usize> }
313 //
314 // #[derive(Clone)]
315 // struct Schedule {
316 // employees: Vec<Employee>,
317 // shifts: Vec<Shift>,
318 // }
319 //
320 // // Count shifts per employee, skipping unassigned shifts
321 // // The group_by key is ignored; complement_with_key provides its own
322 // let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
323 // .for_each(|s: &Schedule| &s.shifts)
324 // .group_by(|_shift: &Shift| 0usize, count()) // Placeholder key, will be overridden
325 // .complement_with_key(
326 // |s: &Schedule| s.employees.as_slice(),
327 // |shift: &Shift| shift.employee_id, // Option<usize>
328 // |emp: &Employee| emp.id, // usize
329 // |_emp: &Employee| 0usize,
330 // )
331 // .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
332 // .as_constraint("Shift count");
333 //
334 // let schedule = Schedule {
335 // employees: vec![Employee { id: 0 }, Employee { id: 1 }],
336 // shifts: vec![
337 // Shift { employee_id: Some(0) },
338 // Shift { employee_id: Some(0) },
339 // Shift { employee_id: None }, // Skipped
340 // ],
341 // };
342 //
343 // // Employee 0: 2, Employee 1: 0 → Total: -2
344 // assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
345 // ```
346 pub fn complement_with_key<B, EB, KA2, KB, D>(
347 self,
348 extractor_b: EB,
349 key_a: KA2,
350 key_b: KB,
351 default_fn: D,
352 ) -> ComplementedConstraintStream<S, A, B, K, E, EB, KA2, KB, C, D, Sc>
353 where
354 B: Clone + Send + Sync + 'static,
355 EB: Fn(&S) -> &[B] + Send + Sync,
356 KA2: Fn(&A) -> Option<K> + Send + Sync,
357 KB: Fn(&B) -> K + Send + Sync,
358 D: Fn(&B) -> C::Result + Send + Sync,
359 {
360 ComplementedConstraintStream::new(
361 self.extractor,
362 extractor_b,
363 key_a,
364 key_b,
365 self.collector,
366 default_fn,
367 )
368 }
369}
370
// Manual `Debug` impl: the extractor/key-fn/collector fields are closures,
// which generally do not implement `Debug`, so `#[derive(Debug)]` is not an
// option. Only the type name is printed; no field values are exposed.
impl<S, A, K, E, KF, C, Sc: Score> std::fmt::Debug
    for GroupedConstraintStream<S, A, K, E, KF, C, Sc>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroupedConstraintStream").finish()
    }
}
378
379// Zero-erasure builder for finalizing a grouped constraint.
380pub struct GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
381where
382 Sc: Score,
383{
384 extractor: E,
385 key_fn: KF,
386 collector: C,
387 impact_type: ImpactType,
388 weight_fn: W,
389 is_hard: bool,
390 expected_descriptor: Option<usize>,
391 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
392}
393
394impl<S, A, K, E, KF, C, W, Sc> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
395where
396 S: Send + Sync + 'static,
397 A: Clone + Send + Sync + 'static,
398 K: Clone + Eq + Hash + Send + Sync + 'static,
399 E: Fn(&S) -> &[A] + Send + Sync,
400 KF: Fn(&A) -> K + Send + Sync,
401 C: UniCollector<A> + Send + Sync + 'static,
402 C::Accumulator: Send + Sync,
403 C::Result: Clone + Send + Sync,
404 W: Fn(&C::Result) -> Sc + Send + Sync,
405 Sc: Score + 'static,
406{
407 // Finalizes the builder into a zero-erasure `GroupedUniConstraint`.
408 //
409 // # Example
410 //
411 // ```
412 // use solverforge_scoring::stream::ConstraintFactory;
413 // use solverforge_scoring::stream::collector::count;
414 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
415 // use solverforge_core::score::SimpleScore;
416 //
417 // #[derive(Clone, Hash, PartialEq, Eq)]
418 // struct Item { category: u32 }
419 //
420 // #[derive(Clone)]
421 // struct Solution { items: Vec<Item> }
422 //
423 // let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
424 // .for_each(|s: &Solution| &s.items)
425 // .group_by(|i: &Item| i.category, count())
426 // .penalize_with(|n: &usize| SimpleScore::of(*n as i64))
427 // .as_constraint("Category penalty");
428 //
429 // assert_eq!(constraint.name(), "Category penalty");
430 // ```
431 pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
432 self.expected_descriptor = Some(descriptor_index);
433 self
434 }
435
436 pub fn as_constraint(self, name: &str) -> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc> {
437 let mut constraint = GroupedUniConstraint::new(
438 ConstraintRef::new("", name),
439 self.impact_type,
440 self.extractor,
441 self.key_fn,
442 self.collector,
443 self.weight_fn,
444 self.is_hard,
445 );
446 if let Some(d) = self.expected_descriptor {
447 constraint = constraint.with_descriptor(d);
448 }
449 constraint
450 }
451}
452
// Manual `Debug` impl: the closure-typed fields (extractor, key_fn, weight_fn,
// collector) generally do not implement `Debug`, so `#[derive(Debug)]` is not
// an option. Only `impact_type` is shown; the other fields are omitted.
impl<S, A, K, E, KF, C, W, Sc: Score> std::fmt::Debug
    for GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroupedConstraintBuilder")
            .field("impact_type", &self.impact_type)
            .finish()
    }
}