solverforge_scoring/stream/uni_stream.rs
/* Zero-erasure uni-constraint stream for single-entity constraint patterns.

A `UniConstraintStream` operates on a single entity type and supports
filtering, joining, grouping, balancing, existence checks, weighting, and
constraint finalization. All type information is preserved at compile
time — no Arc, no dyn, fully monomorphized.
*/
7
8use std::hash::Hash;
9use std::marker::PhantomData;
10
11use solverforge_core::score::Score;
12use solverforge_core::{ConstraintRef, ImpactType};
13
14use crate::constraint::incremental::IncrementalUniConstraint;
15
16use crate::constraint::if_exists::ExistenceMode;
17
18use super::balance_stream::BalanceConstraintStream;
19use super::collection_extract::CollectionExtract;
20use super::collector::UniCollector;
21use super::filter::{AndUniFilter, FnUniFilter, TrueFilter, UniFilter};
22use super::grouped_stream::GroupedConstraintStream;
23use super::if_exists_stream::IfExistsStream;
24use super::joiner::EqualJoiner;
25
26/* Zero-erasure constraint stream over a single entity type.
27
28`UniConstraintStream` accumulates filters and can be finalized into
29an `IncrementalUniConstraint` via `penalize()` or `reward()`.
30
31All type parameters are concrete - no trait objects, no Arc allocations
32in the hot path.
33
34# Type Parameters
35
36- `S` - Solution type
37- `A` - Entity type
38- `E` - Extractor function type
39- `F` - Combined filter type
40- `Sc` - Score type
41*/
42pub struct UniConstraintStream<S, A, E, F, Sc>
43where
44 Sc: Score,
45{
46 extractor: E,
47 filter: F,
48 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
49}
50
51impl<S, A, E, Sc> UniConstraintStream<S, A, E, TrueFilter, Sc>
52where
53 S: Send + Sync + 'static,
54 A: Clone + Send + Sync + 'static,
55 E: CollectionExtract<S, Item = A>,
56 Sc: Score + 'static,
57{
58 // Creates a new uni-constraint stream with the given extractor.
59 pub fn new(extractor: E) -> Self {
60 Self {
61 extractor,
62 filter: TrueFilter,
63 _phantom: PhantomData,
64 }
65 }
66}
67
68impl<S, A, E, F, Sc> UniConstraintStream<S, A, E, F, Sc>
69where
70 S: Send + Sync + 'static,
71 A: Clone + Send + Sync + 'static,
72 E: CollectionExtract<S, Item = A>,
73 F: UniFilter<S, A>,
74 Sc: Score + 'static,
75{
76 /* Adds a filter predicate to the stream.
77
78 Multiple filters are combined with AND semantics at compile time.
79 Each filter adds a new type layer, preserving zero-erasure.
80
81 To access related entities, use shadow variables on your entity type
82 (e.g., `#[inverse_relation_shadow_variable]`) rather than solution traversal.
83 */
84 pub fn filter<P>(
85 self,
86 predicate: P,
87 ) -> UniConstraintStream<
88 S,
89 A,
90 E,
91 AndUniFilter<F, FnUniFilter<impl Fn(&S, &A) -> bool + Send + Sync>>,
92 Sc,
93 >
94 where
95 P: Fn(&A) -> bool + Send + Sync + 'static,
96 {
97 UniConstraintStream {
98 extractor: self.extractor,
99 filter: AndUniFilter::new(
100 self.filter,
101 FnUniFilter::new(move |_s: &S, a: &A| predicate(a)),
102 ),
103 _phantom: PhantomData,
104 }
105 }
106
107 /* Joins this stream using the provided join target.
108
109 Dispatch depends on argument type:
110 - `equal(|a: &A| key_fn(a))` — self-join, returns `BiConstraintStream`
111 - `(extractor_b, equal_bi(ka, kb))` — keyed cross-join, returns `CrossBiConstraintStream`
112 - `(other_stream, |a, b| predicate)` — predicate cross-join O(n*m)
113 */
114 pub fn join<J>(self, target: J) -> J::Output
115 where
116 J: super::join_target::JoinTarget<S, A, E, F, Sc>,
117 {
118 target.apply(self.extractor, self.filter)
119 }
120
121 /* Groups entities by key and aggregates with a collector.
122
123 Returns a zero-erasure `GroupedConstraintStream` that can be penalized
124 or rewarded based on the aggregated result for each group.
125 */
126 pub fn group_by<K, KF, C>(
127 self,
128 key_fn: KF,
129 collector: C,
130 ) -> GroupedConstraintStream<S, A, K, E, F, KF, C, Sc>
131 where
132 K: Clone + Eq + Hash + Send + Sync + 'static,
133 KF: Fn(&A) -> K + Send + Sync,
134 C: UniCollector<A> + Send + Sync + 'static,
135 C::Accumulator: Send + Sync,
136 C::Result: Clone + Send + Sync,
137 {
138 GroupedConstraintStream::new(self.extractor, self.filter, key_fn, collector)
139 }
140
141 /* Creates a balance constraint that penalizes uneven distribution across groups.
142
143 Unlike `group_by` which scores each group independently, `balance` computes
144 a GLOBAL standard deviation across all group counts and produces a single score.
145
146 The `key_fn` returns `Option<K>` to allow skipping entities (e.g., unassigned shifts).
147 Any filters accumulated on this stream are also applied.
148
149 # Example
150
151 ```
152 use solverforge_scoring::stream::ConstraintFactory;
153 use solverforge_scoring::api::constraint_set::IncrementalConstraint;
154 use solverforge_core::score::SoftScore;
155
156 #[derive(Clone)]
157 struct Shift { employee_id: Option<usize> }
158
159 #[derive(Clone)]
160 struct Solution { shifts: Vec<Shift> }
161
162 let constraint = ConstraintFactory::<Solution, SoftScore>::new()
163 .for_each(|s: &Solution| &s.shifts)
164 .balance(|shift: &Shift| shift.employee_id)
165 .penalize(SoftScore::of(1000))
166 .named("Balance workload");
167
168 let solution = Solution {
169 shifts: vec![
170 Shift { employee_id: Some(0) },
171 Shift { employee_id: Some(0) },
172 Shift { employee_id: Some(0) },
173 Shift { employee_id: Some(1) },
174 ],
175 };
176
177 // Employee 0: 3 shifts, Employee 1: 1 shift
178 // std_dev = 1.0, penalty = -1000
179 assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1000));
180 ```
181 */
182 pub fn balance<K, KF>(self, key_fn: KF) -> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
183 where
184 K: Clone + Eq + Hash + Send + Sync + 'static,
185 KF: Fn(&A) -> Option<K> + Send + Sync,
186 {
187 BalanceConstraintStream::new(self.extractor, self.filter, key_fn)
188 }
189
190 /* Filters A entities based on whether a matching B entity exists.
191
192 Use this when the B collection needs filtering (e.g., only vacationing employees).
193 The `extractor_b` returns a `Vec<B>` to allow for filtering.
194
195 Any filters accumulated on this stream are applied to A entities.
196
197 # Example
198
199 ```
200 use solverforge_scoring::stream::ConstraintFactory;
201 use solverforge_scoring::stream::joiner::equal_bi;
202 use solverforge_scoring::api::constraint_set::IncrementalConstraint;
203 use solverforge_core::score::SoftScore;
204
205 #[derive(Clone)]
206 struct Shift { id: usize, employee_idx: Option<usize> }
207
208 #[derive(Clone)]
209 struct Employee { id: usize, on_vacation: bool }
210
211 #[derive(Clone)]
212 struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
213
214 // Penalize shifts assigned to employees who are on vacation
215 let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
216 .for_each(|s: &Schedule| s.shifts.as_slice())
217 .filter(|shift: &Shift| shift.employee_idx.is_some())
218 .if_exists_filtered(
219 |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect(),
220 equal_bi(
221 |shift: &Shift| shift.employee_idx,
222 |emp: &Employee| Some(emp.id),
223 ),
224 )
225 .penalize(SoftScore::of(1))
226 .named("Vacation conflict");
227
228 let schedule = Schedule {
229 shifts: vec![
230 Shift { id: 0, employee_idx: Some(0) }, // assigned to vacationing emp
231 Shift { id: 1, employee_idx: Some(1) }, // assigned to working emp
232 Shift { id: 2, employee_idx: None }, // unassigned (filtered out)
233 ],
234 employees: vec![
235 Employee { id: 0, on_vacation: true },
236 Employee { id: 1, on_vacation: false },
237 ],
238 };
239
240 // Only shift 0 matches (assigned to employee 0 who is on vacation)
241 assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-1));
242 ```
243 */
244 pub fn if_exists_filtered<B, EB, K, KA, KB>(
245 self,
246 extractor_b: EB,
247 joiner: EqualJoiner<KA, KB, K>,
248 ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
249 where
250 B: Clone + Send + Sync + 'static,
251 EB: Fn(&S) -> Vec<B> + Send + Sync,
252 K: Eq + Hash + Clone + Send + Sync,
253 KA: Fn(&A) -> K + Send + Sync,
254 KB: Fn(&B) -> K + Send + Sync,
255 {
256 let (key_a, key_b) = joiner.into_keys();
257 IfExistsStream::new(
258 ExistenceMode::Exists,
259 self.extractor,
260 extractor_b,
261 key_a,
262 key_b,
263 self.filter,
264 )
265 }
266
267 /* Filters A entities based on whether NO matching B entity exists.
268
269 Use this when the B collection needs filtering.
270 The `extractor_b` returns a `Vec<B>` to allow for filtering.
271
272 Any filters accumulated on this stream are applied to A entities.
273
274 # Example
275
276 ```
277 use solverforge_scoring::stream::ConstraintFactory;
278 use solverforge_scoring::stream::joiner::equal_bi;
279 use solverforge_scoring::api::constraint_set::IncrementalConstraint;
280 use solverforge_core::score::SoftScore;
281
282 #[derive(Clone)]
283 struct Task { id: usize, assignee: Option<usize> }
284
285 #[derive(Clone)]
286 struct Worker { id: usize, available: bool }
287
288 #[derive(Clone)]
289 struct Schedule { tasks: Vec<Task>, workers: Vec<Worker> }
290
291 // Penalize tasks assigned to workers who are not available
292 let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
293 .for_each(|s: &Schedule| s.tasks.as_slice())
294 .filter(|task: &Task| task.assignee.is_some())
295 .if_not_exists_filtered(
296 |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
297 equal_bi(
298 |task: &Task| task.assignee,
299 |worker: &Worker| Some(worker.id),
300 ),
301 )
302 .penalize(SoftScore::of(1))
303 .named("Unavailable worker");
304
305 let schedule = Schedule {
306 tasks: vec![
307 Task { id: 0, assignee: Some(0) }, // worker 0 is unavailable
308 Task { id: 1, assignee: Some(1) }, // worker 1 is available
309 Task { id: 2, assignee: None }, // unassigned (filtered out)
310 ],
311 workers: vec![
312 Worker { id: 0, available: false },
313 Worker { id: 1, available: true },
314 ],
315 };
316
317 // Task 0's worker (id=0) is NOT in the available workers list
318 assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-1));
319 ```
320 */
321 pub fn if_not_exists_filtered<B, EB, K, KA, KB>(
322 self,
323 extractor_b: EB,
324 joiner: EqualJoiner<KA, KB, K>,
325 ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
326 where
327 B: Clone + Send + Sync + 'static,
328 EB: Fn(&S) -> Vec<B> + Send + Sync,
329 K: Eq + Hash + Clone + Send + Sync,
330 KA: Fn(&A) -> K + Send + Sync,
331 KB: Fn(&B) -> K + Send + Sync,
332 {
333 let (key_a, key_b) = joiner.into_keys();
334 IfExistsStream::new(
335 ExistenceMode::NotExists,
336 self.extractor,
337 extractor_b,
338 key_a,
339 key_b,
340 self.filter,
341 )
342 }
343
344 // Penalizes each matching entity with a fixed weight.
345 pub fn penalize(
346 self,
347 weight: Sc,
348 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
349 where
350 Sc: Copy,
351 {
352 // Detect if this is a hard constraint by checking if hard level is non-zero
353 let is_hard = weight
354 .to_level_numbers()
355 .first()
356 .map(|&h| h != 0)
357 .unwrap_or(false);
358 UniConstraintBuilder {
359 extractor: self.extractor,
360 filter: self.filter,
361 impact_type: ImpactType::Penalty,
362 weight: move |_: &A| weight,
363 is_hard,
364 expected_descriptor: None,
365 _phantom: PhantomData,
366 }
367 }
368
369 /* Penalizes each matching entity with a dynamic weight.
370
371 Note: For dynamic weights, use `penalize_hard_with` to explicitly mark as a hard constraint,
372 since the weight function cannot be evaluated at build time.
373 */
374 pub fn penalize_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
375 where
376 W: Fn(&A) -> Sc + Send + Sync,
377 {
378 UniConstraintBuilder {
379 extractor: self.extractor,
380 filter: self.filter,
381 impact_type: ImpactType::Penalty,
382 weight: weight_fn,
383 is_hard: false, // Can't detect at build time; use penalize_hard_with for hard constraints
384 expected_descriptor: None,
385 _phantom: PhantomData,
386 }
387 }
388
389 // Penalizes each matching entity with a dynamic weight, explicitly marked as a hard constraint.
390 pub fn penalize_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
391 where
392 W: Fn(&A) -> Sc + Send + Sync,
393 {
394 UniConstraintBuilder {
395 extractor: self.extractor,
396 filter: self.filter,
397 impact_type: ImpactType::Penalty,
398 weight: weight_fn,
399 is_hard: true,
400 expected_descriptor: None,
401 _phantom: PhantomData,
402 }
403 }
404
405 // Rewards each matching entity with a fixed weight.
406 pub fn reward(
407 self,
408 weight: Sc,
409 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
410 where
411 Sc: Copy,
412 {
413 // Detect if this is a hard constraint by checking if hard level is non-zero
414 let is_hard = weight
415 .to_level_numbers()
416 .first()
417 .map(|&h| h != 0)
418 .unwrap_or(false);
419 UniConstraintBuilder {
420 extractor: self.extractor,
421 filter: self.filter,
422 impact_type: ImpactType::Reward,
423 weight: move |_: &A| weight,
424 is_hard,
425 expected_descriptor: None,
426 _phantom: PhantomData,
427 }
428 }
429
430 /* Rewards each matching entity with a dynamic weight.
431
432 Note: For dynamic weights, use `reward_hard_with` to explicitly mark as a hard constraint,
433 since the weight function cannot be evaluated at build time.
434 */
435 pub fn reward_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
436 where
437 W: Fn(&A) -> Sc + Send + Sync,
438 {
439 UniConstraintBuilder {
440 extractor: self.extractor,
441 filter: self.filter,
442 impact_type: ImpactType::Reward,
443 weight: weight_fn,
444 is_hard: false, // Can't detect at build time; use reward_hard_with for hard constraints
445 expected_descriptor: None,
446 _phantom: PhantomData,
447 }
448 }
449
450 // Rewards each matching entity with a dynamic weight, explicitly marked as a hard constraint.
451 pub fn reward_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
452 where
453 W: Fn(&A) -> Sc + Send + Sync,
454 {
455 UniConstraintBuilder {
456 extractor: self.extractor,
457 filter: self.filter,
458 impact_type: ImpactType::Reward,
459 weight: weight_fn,
460 is_hard: true,
461 expected_descriptor: None,
462 _phantom: PhantomData,
463 }
464 }
465
466 // Penalizes each matching entity with one hard score unit.
467 pub fn penalize_hard(
468 self,
469 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
470 where
471 Sc: Copy,
472 {
473 self.penalize(Sc::one_hard())
474 }
475
476 // Penalizes each matching entity with one soft score unit.
477 pub fn penalize_soft(
478 self,
479 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
480 where
481 Sc: Copy,
482 {
483 self.penalize(Sc::one_soft())
484 }
485
486 // Rewards each matching entity with one hard score unit.
487 pub fn reward_hard(
488 self,
489 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
490 where
491 Sc: Copy,
492 {
493 self.reward(Sc::one_hard())
494 }
495
496 // Rewards each matching entity with one soft score unit.
497 pub fn reward_soft(
498 self,
499 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
500 where
501 Sc: Copy,
502 {
503 self.reward(Sc::one_soft())
504 }
505}
506
507impl<S, A, E, F, Sc: Score> UniConstraintStream<S, A, E, F, Sc> {
508 #[doc(hidden)]
509 pub fn extractor(&self) -> &E {
510 &self.extractor
511 }
512
513 #[doc(hidden)]
514 pub fn into_parts(self) -> (E, F) {
515 (self.extractor, self.filter)
516 }
517
518 #[doc(hidden)]
519 pub fn from_parts(extractor: E, filter: F) -> Self {
520 Self {
521 extractor,
522 filter,
523 _phantom: PhantomData,
524 }
525 }
526}
527
528impl<S, A, E, F, Sc: Score> std::fmt::Debug for UniConstraintStream<S, A, E, F, Sc> {
529 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
530 f.debug_struct("UniConstraintStream").finish()
531 }
532}
533
534// Zero-erasure builder for finalizing a uni-constraint.
535pub struct UniConstraintBuilder<S, A, E, F, W, Sc>
536where
537 Sc: Score,
538{
539 extractor: E,
540 filter: F,
541 impact_type: ImpactType,
542 weight: W,
543 is_hard: bool,
544 expected_descriptor: Option<usize>,
545 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
546}
547
548impl<S, A, E, F, W, Sc> UniConstraintBuilder<S, A, E, F, W, Sc>
549where
550 S: Send + Sync + 'static,
551 A: Clone + Send + Sync + 'static,
552 E: CollectionExtract<S, Item = A>,
553 F: UniFilter<S, A>,
554 W: Fn(&A) -> Sc + Send + Sync,
555 Sc: Score + 'static,
556{
557 /* Restricts this constraint to only fire for the given descriptor index.
558
559 Required when multiple entity classes exist (e.g., FurnaceAssignment at 0,
560 ShiftAssignment at 1). Without this, on_insert/on_retract fire for all entity
561 classes using the constraint's entity_index, which indexes into the wrong slice.
562 */
563 pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
564 self.expected_descriptor = Some(descriptor_index);
565 self
566 }
567
568 // Finalizes the builder into a zero-erasure `IncrementalUniConstraint`.
569 pub fn named(
570 self,
571 name: &str,
572 ) -> IncrementalUniConstraint<S, A, E, impl Fn(&S, &A) -> bool + Send + Sync, W, Sc> {
573 let filter = self.filter;
574 let combined_filter = move |s: &S, a: &A| filter.test(s, a);
575
576 let mut constraint = IncrementalUniConstraint::new(
577 ConstraintRef::new("", name),
578 self.impact_type,
579 self.extractor,
580 combined_filter,
581 self.weight,
582 self.is_hard,
583 );
584 if let Some(d) = self.expected_descriptor {
585 constraint = constraint.with_descriptor(d);
586 }
587 constraint
588 }
589}
590
591impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for UniConstraintBuilder<S, A, E, F, W, Sc> {
592 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593 f.debug_struct("UniConstraintBuilder")
594 .field("impact_type", &self.impact_type)
595 .finish()
596 }
597}