// solverforge_scoring/stream/grouped_stream.rs
//! Zero-erasure grouped constraint stream for group-by constraint patterns.
//!
//! A `GroupedConstraintStream` operates on groups of entities and supports
//! filtering, weighting, and constraint finalization.
//! All type information is preserved at compile time — no `Arc`, no `dyn`.
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use super::collector::UniCollector;
14use super::complemented_stream::ComplementedConstraintStream;
15use crate::constraint::grouped::GroupedUniConstraint;
16
/// Zero-erasure constraint stream over grouped entities.
///
/// `GroupedConstraintStream` is created by `UniConstraintStream::group_by()`
/// and operates on (key, collector_result) tuples.
///
/// All type parameters are concrete — no trait objects, no `Arc` allocations.
///
/// # Type Parameters
///
/// - `S` - Solution type
/// - `A` - Entity type
/// - `K` - Group key type
/// - `E` - Extractor function for entities
/// - `KF` - Key function
/// - `C` - Collector type
/// - `Sc` - Score type
///
/// # Example
///
/// ```
/// use solverforge_scoring::stream::ConstraintFactory;
/// use solverforge_scoring::stream::collector::count;
/// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
/// use solverforge_core::score::SimpleScore;
///
/// #[derive(Clone, Hash, PartialEq, Eq)]
/// struct Shift { employee_id: usize }
///
/// #[derive(Clone)]
/// struct Solution { shifts: Vec<Shift> }
///
/// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
///     .for_each(|s: &Solution| &s.shifts)
///     .group_by(|shift: &Shift| shift.employee_id, count())
///     .penalize_with(|count: &usize| SimpleScore::of((*count * *count) as i64))
///     .as_constraint("Balanced workload");
///
/// let solution = Solution {
///     shifts: vec![
///         Shift { employee_id: 1 },
///         Shift { employee_id: 1 },
///         Shift { employee_id: 1 },
///         Shift { employee_id: 2 },
///     ],
/// };
///
/// // Employee 1: 3² = 9, Employee 2: 1² = 1, Total: -10
/// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-10));
/// ```
pub struct GroupedConstraintStream<S, A, K, E, KF, C, Sc>
where
    Sc: Score,
{
    // Extracts the entity slice from a solution.
    extractor: E,
    // Maps each entity to its group key.
    key_fn: KF,
    // Aggregates the entities within each group into a single result.
    collector: C,
    // Ties the otherwise-unused type parameters to this struct. `fn() -> T`
    // pointers are used (rather than `T`) so the struct stays Send + Sync
    // regardless of the parameters' own auto traits.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
}
75
76impl<S, A, K, E, KF, C, Sc> GroupedConstraintStream<S, A, K, E, KF, C, Sc>
77where
78 S: Send + Sync + 'static,
79 A: Clone + Send + Sync + 'static,
80 K: Clone + Eq + Hash + Send + Sync + 'static,
81 E: Fn(&S) -> &[A] + Send + Sync,
82 KF: Fn(&A) -> K + Send + Sync,
83 C: UniCollector<A> + Send + Sync + 'static,
84 C::Accumulator: Send + Sync,
85 C::Result: Clone + Send + Sync,
86 Sc: Score + 'static,
87{
88 // Creates a new zero-erasure grouped constraint stream.
89 pub(crate) fn new(extractor: E, key_fn: KF, collector: C) -> Self {
90 Self {
91 extractor,
92 key_fn,
93 collector,
94 _phantom: PhantomData,
95 }
96 }
97
98 // Penalizes each group with a weight based on the collector result.
99 //
100 // # Example
101 //
102 // ```
103 // use solverforge_scoring::stream::ConstraintFactory;
104 // use solverforge_scoring::stream::collector::count;
105 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
106 // use solverforge_core::score::SimpleScore;
107 //
108 // #[derive(Clone, Hash, PartialEq, Eq)]
109 // struct Task { priority: u32 }
110 //
111 // #[derive(Clone)]
112 // struct Solution { tasks: Vec<Task> }
113 //
114 // let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
115 // .for_each(|s: &Solution| &s.tasks)
116 // .group_by(|t: &Task| t.priority, count())
117 // .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
118 // .as_constraint("Priority distribution");
119 //
120 // let solution = Solution {
121 // tasks: vec![
122 // Task { priority: 1 },
123 // Task { priority: 1 },
124 // Task { priority: 2 },
125 // ],
126 // };
127 //
128 // // Priority 1: 2 tasks, Priority 2: 1 task, Total: -3
129 // assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-3));
130 // ```
131 pub fn penalize_with<W>(
132 self,
133 weight_fn: W,
134 ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
135 where
136 W: Fn(&C::Result) -> Sc + Send + Sync,
137 {
138 GroupedConstraintBuilder {
139 extractor: self.extractor,
140 key_fn: self.key_fn,
141 collector: self.collector,
142 impact_type: ImpactType::Penalty,
143 weight_fn,
144 is_hard: false,
145 _phantom: PhantomData,
146 }
147 }
148
149 // Penalizes each group with a weight, explicitly marked as hard constraint.
150 pub fn penalize_hard_with<W>(
151 self,
152 weight_fn: W,
153 ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
154 where
155 W: Fn(&C::Result) -> Sc + Send + Sync,
156 {
157 GroupedConstraintBuilder {
158 extractor: self.extractor,
159 key_fn: self.key_fn,
160 collector: self.collector,
161 impact_type: ImpactType::Penalty,
162 weight_fn,
163 is_hard: true,
164 _phantom: PhantomData,
165 }
166 }
167
168 // Rewards each group with a weight based on the collector result.
169 pub fn reward_with<W>(self, weight_fn: W) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
170 where
171 W: Fn(&C::Result) -> Sc + Send + Sync,
172 {
173 GroupedConstraintBuilder {
174 extractor: self.extractor,
175 key_fn: self.key_fn,
176 collector: self.collector,
177 impact_type: ImpactType::Reward,
178 weight_fn,
179 is_hard: false,
180 _phantom: PhantomData,
181 }
182 }
183
184 // Rewards each group with a weight, explicitly marked as hard constraint.
185 pub fn reward_hard_with<W>(
186 self,
187 weight_fn: W,
188 ) -> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
189 where
190 W: Fn(&C::Result) -> Sc + Send + Sync,
191 {
192 GroupedConstraintBuilder {
193 extractor: self.extractor,
194 key_fn: self.key_fn,
195 collector: self.collector,
196 impact_type: ImpactType::Reward,
197 weight_fn,
198 is_hard: true,
199 _phantom: PhantomData,
200 }
201 }
202
203 // Adds complement entities with default values for missing keys.
204 //
205 // This ensures all keys from the complement source are represented,
206 // using the grouped value if present, or the default value otherwise.
207 //
208 // **Note:** The key function for A entities wraps the original key to
209 // return `Some(K)`. For filtering (skipping entities without valid keys),
210 // use `complement_filtered` instead.
211 //
212 // # Example
213 //
214 // ```
215 // use solverforge_scoring::stream::ConstraintFactory;
216 // use solverforge_scoring::stream::collector::count;
217 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
218 // use solverforge_core::score::SimpleScore;
219 //
220 // #[derive(Clone, Hash, PartialEq, Eq)]
221 // struct Employee { id: usize }
222 //
223 // #[derive(Clone, Hash, PartialEq, Eq)]
224 // struct Shift { employee_id: usize }
225 //
226 // #[derive(Clone)]
227 // struct Schedule {
228 // employees: Vec<Employee>,
229 // shifts: Vec<Shift>,
230 // }
231 //
232 // // Count shifts per employee, including employees with 0 shifts
233 // let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
234 // .for_each(|s: &Schedule| &s.shifts)
235 // .group_by(|shift: &Shift| shift.employee_id, count())
236 // .complement(
237 // |s: &Schedule| s.employees.as_slice(),
238 // |emp: &Employee| emp.id,
239 // |_emp: &Employee| 0usize,
240 // )
241 // .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
242 // .as_constraint("Shift count");
243 //
244 // let schedule = Schedule {
245 // employees: vec![Employee { id: 0 }, Employee { id: 1 }],
246 // shifts: vec![
247 // Shift { employee_id: 0 },
248 // Shift { employee_id: 0 },
249 // ],
250 // };
251 //
252 // // Employee 0: 2, Employee 1: 0 → Total: -2
253 // assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
254 // ```
255 pub fn complement<B, EB, KB, D>(
256 self,
257 extractor_b: EB,
258 key_b: KB,
259 default_fn: D,
260 ) -> ComplementedConstraintStream<
261 S,
262 A,
263 B,
264 K,
265 E,
266 EB,
267 impl Fn(&A) -> Option<K> + Send + Sync,
268 KB,
269 C,
270 D,
271 Sc,
272 >
273 where
274 B: Clone + Send + Sync + 'static,
275 EB: Fn(&S) -> &[B] + Send + Sync,
276 KB: Fn(&B) -> K + Send + Sync,
277 D: Fn(&B) -> C::Result + Send + Sync,
278 {
279 let key_fn = self.key_fn;
280 let wrapped_key_fn = move |a: &A| Some((key_fn)(a));
281 ComplementedConstraintStream::new(
282 self.extractor,
283 extractor_b,
284 wrapped_key_fn,
285 key_b,
286 self.collector,
287 default_fn,
288 )
289 }
290
291 // Adds complement entities with a custom key function for filtering.
292 //
293 // Like `complement`, but allows providing a custom key function for A entities
294 // that returns `Option<K>`. Entities returning `None` are skipped.
295 //
296 // # Example
297 //
298 // ```
299 // use solverforge_scoring::stream::ConstraintFactory;
300 // use solverforge_scoring::stream::collector::count;
301 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
302 // use solverforge_core::score::SimpleScore;
303 //
304 // #[derive(Clone, Hash, PartialEq, Eq)]
305 // struct Employee { id: usize }
306 //
307 // #[derive(Clone, Hash, PartialEq, Eq)]
308 // struct Shift { employee_id: Option<usize> }
309 //
310 // #[derive(Clone)]
311 // struct Schedule {
312 // employees: Vec<Employee>,
313 // shifts: Vec<Shift>,
314 // }
315 //
316 // // Count shifts per employee, skipping unassigned shifts
317 // // The group_by key is ignored; complement_with_key provides its own
318 // let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
319 // .for_each(|s: &Schedule| &s.shifts)
320 // .group_by(|_shift: &Shift| 0usize, count()) // Placeholder key, will be overridden
321 // .complement_with_key(
322 // |s: &Schedule| s.employees.as_slice(),
323 // |shift: &Shift| shift.employee_id, // Option<usize>
324 // |emp: &Employee| emp.id, // usize
325 // |_emp: &Employee| 0usize,
326 // )
327 // .penalize_with(|count: &usize| SimpleScore::of(*count as i64))
328 // .as_constraint("Shift count");
329 //
330 // let schedule = Schedule {
331 // employees: vec![Employee { id: 0 }, Employee { id: 1 }],
332 // shifts: vec![
333 // Shift { employee_id: Some(0) },
334 // Shift { employee_id: Some(0) },
335 // Shift { employee_id: None }, // Skipped
336 // ],
337 // };
338 //
339 // // Employee 0: 2, Employee 1: 0 → Total: -2
340 // assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-2));
341 // ```
342 pub fn complement_with_key<B, EB, KA2, KB, D>(
343 self,
344 extractor_b: EB,
345 key_a: KA2,
346 key_b: KB,
347 default_fn: D,
348 ) -> ComplementedConstraintStream<S, A, B, K, E, EB, KA2, KB, C, D, Sc>
349 where
350 B: Clone + Send + Sync + 'static,
351 EB: Fn(&S) -> &[B] + Send + Sync,
352 KA2: Fn(&A) -> Option<K> + Send + Sync,
353 KB: Fn(&B) -> K + Send + Sync,
354 D: Fn(&B) -> C::Result + Send + Sync,
355 {
356 ComplementedConstraintStream::new(
357 self.extractor,
358 extractor_b,
359 key_a,
360 key_b,
361 self.collector,
362 default_fn,
363 )
364 }
365}
366
367impl<S, A, K, E, KF, C, Sc: Score> std::fmt::Debug
368 for GroupedConstraintStream<S, A, K, E, KF, C, Sc>
369{
370 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371 f.debug_struct("GroupedConstraintStream").finish()
372 }
373}
374
/// Zero-erasure builder for finalizing a grouped constraint.
pub struct GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
where
    Sc: Score,
{
    // Extracts the entity slice from a solution.
    extractor: E,
    // Maps each entity to its group key.
    key_fn: KF,
    // Aggregates the entities within each group into a single result.
    collector: C,
    // Whether the finalized constraint penalizes or rewards.
    impact_type: ImpactType,
    // Maps a group's collector result to a score weight.
    weight_fn: W,
    // Marks the finalized constraint as hard (vs. soft).
    is_hard: bool,
    // Ties the otherwise-unused type parameters to this struct; `fn() -> T`
    // pointers keep the struct Send + Sync independent of the parameters.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> K, fn() -> Sc)>,
}
388
389impl<S, A, K, E, KF, C, W, Sc> GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
390where
391 S: Send + Sync + 'static,
392 A: Clone + Send + Sync + 'static,
393 K: Clone + Eq + Hash + Send + Sync + 'static,
394 E: Fn(&S) -> &[A] + Send + Sync,
395 KF: Fn(&A) -> K + Send + Sync,
396 C: UniCollector<A> + Send + Sync + 'static,
397 C::Accumulator: Send + Sync,
398 C::Result: Clone + Send + Sync,
399 W: Fn(&C::Result) -> Sc + Send + Sync,
400 Sc: Score + 'static,
401{
402 // Finalizes the builder into a zero-erasure `GroupedUniConstraint`.
403 //
404 // # Example
405 //
406 // ```
407 // use solverforge_scoring::stream::ConstraintFactory;
408 // use solverforge_scoring::stream::collector::count;
409 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
410 // use solverforge_core::score::SimpleScore;
411 //
412 // #[derive(Clone, Hash, PartialEq, Eq)]
413 // struct Item { category: u32 }
414 //
415 // #[derive(Clone)]
416 // struct Solution { items: Vec<Item> }
417 //
418 // let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
419 // .for_each(|s: &Solution| &s.items)
420 // .group_by(|i: &Item| i.category, count())
421 // .penalize_with(|n: &usize| SimpleScore::of(*n as i64))
422 // .as_constraint("Category penalty");
423 //
424 // assert_eq!(constraint.name(), "Category penalty");
425 // ```
426 pub fn as_constraint(self, name: &str) -> GroupedUniConstraint<S, A, K, E, KF, C, W, Sc> {
427 GroupedUniConstraint::new(
428 ConstraintRef::new("", name),
429 self.impact_type,
430 self.extractor,
431 self.key_fn,
432 self.collector,
433 self.weight_fn,
434 self.is_hard,
435 )
436 }
437}
438
439impl<S, A, K, E, KF, C, W, Sc: Score> std::fmt::Debug
440 for GroupedConstraintBuilder<S, A, K, E, KF, C, W, Sc>
441{
442 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
443 f.debug_struct("GroupedConstraintBuilder")
444 .field("impact_type", &self.impact_type)
445 .finish()
446 }
447}