solverforge_scoring/stream/uni_stream.rs
1// Zero-erasure uni-constraint stream for single-entity constraint patterns.
2//
3// A `UniConstraintStream` operates on a single entity type and supports
4// filtering, weighting, and constraint finalization. All type information
5// is preserved at compile time - no Arc, no dyn, fully monomorphized.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::constraint::incremental::IncrementalUniConstraint;
14
15use crate::constraint::if_exists::ExistenceMode;
16
17use super::balance_stream::BalanceConstraintStream;
18use super::bi_stream::BiConstraintStream;
19use super::collector::UniCollector;
20use super::cross_bi_stream::CrossBiConstraintStream;
21use super::filter::{AndUniFilter, FnUniFilter, TrueFilter, UniFilter, UniLeftBiFilter};
22use super::grouped_stream::GroupedConstraintStream;
23use super::if_exists_stream::IfExistsStream;
24use super::joiner::EqualJoiner;
25
26// Zero-erasure constraint stream over a single entity type.
27//
28// `UniConstraintStream` accumulates filters and can be finalized into
29// an `IncrementalUniConstraint` via `penalize()` or `reward()`.
30//
31// All type parameters are concrete - no trait objects, no Arc allocations
32// in the hot path.
33//
34// # Type Parameters
35//
36// - `S` - Solution type
37// - `A` - Entity type
38// - `E` - Extractor function type
39// - `F` - Combined filter type
40// - `Sc` - Score type
41pub struct UniConstraintStream<S, A, E, F, Sc>
42where
43 Sc: Score,
44{
45 extractor: E,
46 filter: F,
47 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
48}
49
50impl<S, A, E, Sc> UniConstraintStream<S, A, E, TrueFilter, Sc>
51where
52 S: Send + Sync + 'static,
53 A: Clone + Send + Sync + 'static,
54 E: Fn(&S) -> &[A] + Send + Sync,
55 Sc: Score + 'static,
56{
57 // Creates a new uni-constraint stream with the given extractor.
58 pub fn new(extractor: E) -> Self {
59 Self {
60 extractor,
61 filter: TrueFilter,
62 _phantom: PhantomData,
63 }
64 }
65}
66
67impl<S, A, E, F, Sc> UniConstraintStream<S, A, E, F, Sc>
68where
69 S: Send + Sync + 'static,
70 A: Clone + Send + Sync + 'static,
71 E: Fn(&S) -> &[A] + Send + Sync,
72 F: UniFilter<S, A>,
73 Sc: Score + 'static,
74{
75 // Adds a filter predicate to the stream.
76 //
77 // Multiple filters are combined with AND semantics at compile time.
78 // Each filter adds a new type layer, preserving zero-erasure.
79 //
80 // To access related entities, use shadow variables on your entity type
81 // (e.g., `#[inverse_relation_shadow_variable]`) rather than solution traversal.
82 pub fn filter<P>(
83 self,
84 predicate: P,
85 ) -> UniConstraintStream<
86 S,
87 A,
88 E,
89 AndUniFilter<F, FnUniFilter<impl Fn(&S, &A) -> bool + Send + Sync>>,
90 Sc,
91 >
92 where
93 P: Fn(&A) -> bool + Send + Sync + 'static,
94 {
95 UniConstraintStream {
96 extractor: self.extractor,
97 filter: AndUniFilter::new(
98 self.filter,
99 FnUniFilter::new(move |_s: &S, a: &A| predicate(a)),
100 ),
101 _phantom: PhantomData,
102 }
103 }
104
105 // Joins this stream with itself to create pairs (zero-erasure).
106 //
107 // Requires an `EqualJoiner` to enable key-based indexing for O(k) lookups.
108 // For self-joins, pairs are ordered (i < j) to avoid duplicates.
109 //
110 // Any filters accumulated on this stream are applied to both entities
111 // individually before the join.
112 pub fn join_self<K, KA, KB>(
113 self,
114 joiner: EqualJoiner<KA, KB, K>,
115 ) -> BiConstraintStream<
116 S,
117 A,
118 K,
119 E,
120 impl Fn(&S, &A, usize) -> K + Send + Sync,
121 UniLeftBiFilter<F, A>,
122 Sc,
123 >
124 where
125 A: Hash + PartialEq,
126 K: Eq + Hash + Clone + Send + Sync,
127 KA: Fn(&A) -> K + Send + Sync,
128 KB: Fn(&A) -> K + Send + Sync,
129 {
130 let (key_extractor, _) = joiner.into_keys();
131
132 // Wrap key_extractor to match the new KE: Fn(&S, &A, usize) -> K signature.
133 // The static stream API doesn't need solution/index, so ignore them.
134 let wrapped_ke = move |_s: &S, a: &A, _idx: usize| key_extractor(a);
135
136 // Convert uni-filter to bi-filter that applies to left entity
137 let bi_filter = UniLeftBiFilter::new(self.filter);
138
139 BiConstraintStream::new_self_join_with_filter(self.extractor, wrapped_ke, bi_filter)
140 }
141
142 // Joins this stream with another collection to create cross-entity pairs (zero-erasure).
143 //
144 // Requires an `EqualJoiner` to enable key-based indexing for O(1) lookups.
145 // Unlike `join_self` which pairs entities within the same collection,
146 // `join` creates pairs from two different collections (e.g., Shift joined
147 // with Employee).
148 //
149 // Any filters accumulated on this stream are applied to the A entity
150 // before the join.
151 pub fn join<B, EB, K, KA, KB>(
152 self,
153 extractor_b: EB,
154 joiner: EqualJoiner<KA, KB, K>,
155 ) -> CrossBiConstraintStream<S, A, B, K, E, EB, KA, KB, UniLeftBiFilter<F, B>, Sc>
156 where
157 B: Clone + Send + Sync + 'static,
158 EB: Fn(&S) -> &[B] + Send + Sync,
159 K: Eq + Hash + Clone + Send + Sync,
160 KA: Fn(&A) -> K + Send + Sync,
161 KB: Fn(&B) -> K + Send + Sync,
162 {
163 let (key_a, key_b) = joiner.into_keys();
164
165 // Convert uni-filter to bi-filter that applies to left entity only
166 let bi_filter = UniLeftBiFilter::new(self.filter);
167
168 CrossBiConstraintStream::new_with_filter(
169 self.extractor,
170 extractor_b,
171 key_a,
172 key_b,
173 bi_filter,
174 )
175 }
176
177 // Groups entities by key and aggregates with a collector.
178 //
179 // Returns a zero-erasure `GroupedConstraintStream` that can be penalized
180 // or rewarded based on the aggregated result for each group.
181 pub fn group_by<K, KF, C>(
182 self,
183 key_fn: KF,
184 collector: C,
185 ) -> GroupedConstraintStream<S, A, K, E, F, KF, C, Sc>
186 where
187 K: Clone + Eq + Hash + Send + Sync + 'static,
188 KF: Fn(&A) -> K + Send + Sync,
189 C: UniCollector<A> + Send + Sync + 'static,
190 C::Accumulator: Send + Sync,
191 C::Result: Clone + Send + Sync,
192 {
193 GroupedConstraintStream::new(self.extractor, self.filter, key_fn, collector)
194 }
195
196 // Creates a balance constraint that penalizes uneven distribution across groups.
197 //
198 // Unlike `group_by` which scores each group independently, `balance` computes
199 // a GLOBAL standard deviation across all group counts and produces a single score.
200 //
201 // The `key_fn` returns `Option<K>` to allow skipping entities (e.g., unassigned shifts).
202 // Any filters accumulated on this stream are also applied.
203 //
204 // # Example
205 //
206 // ```
207 // use solverforge_scoring::stream::ConstraintFactory;
208 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
209 // use solverforge_core::score::SimpleScore;
210 //
211 // #[derive(Clone)]
212 // struct Shift { employee_id: Option<usize> }
213 //
214 // #[derive(Clone)]
215 // struct Solution { shifts: Vec<Shift> }
216 //
217 // let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
218 // .for_each(|s: &Solution| &s.shifts)
219 // .balance(|shift: &Shift| shift.employee_id)
220 // .penalize(SimpleScore::of(1000))
221 // .as_constraint("Balance workload");
222 //
223 // let solution = Solution {
224 // shifts: vec![
225 // Shift { employee_id: Some(0) },
226 // Shift { employee_id: Some(0) },
227 // Shift { employee_id: Some(0) },
228 // Shift { employee_id: Some(1) },
229 // ],
230 // };
231 //
232 // // Employee 0: 3 shifts, Employee 1: 1 shift
233 // // std_dev = 1.0, penalty = -1000
234 // assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1000));
235 // ```
236 pub fn balance<K, KF>(self, key_fn: KF) -> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
237 where
238 K: Clone + Eq + Hash + Send + Sync + 'static,
239 KF: Fn(&A) -> Option<K> + Send + Sync,
240 {
241 BalanceConstraintStream::new(self.extractor, self.filter, key_fn)
242 }
243
244 // Filters A entities based on whether a matching B entity exists.
245 //
246 // Use this when the B collection needs filtering (e.g., only vacationing employees).
247 // The `extractor_b` returns a `Vec<B>` to allow for filtering.
248 //
249 // Any filters accumulated on this stream are applied to A entities.
250 //
251 // # Example
252 //
253 // ```
254 // use solverforge_scoring::stream::ConstraintFactory;
255 // use solverforge_scoring::stream::joiner::equal_bi;
256 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
257 // use solverforge_core::score::SimpleScore;
258 //
259 // #[derive(Clone)]
260 // struct Shift { id: usize, employee_idx: Option<usize> }
261 //
262 // #[derive(Clone)]
263 // struct Employee { id: usize, on_vacation: bool }
264 //
265 // #[derive(Clone)]
266 // struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
267 //
268 // // Penalize shifts assigned to employees who are on vacation
269 // let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
270 // .for_each(|s: &Schedule| s.shifts.as_slice())
271 // .filter(|shift: &Shift| shift.employee_idx.is_some())
272 // .if_exists_filtered(
273 // |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect(),
274 // equal_bi(
275 // |shift: &Shift| shift.employee_idx,
276 // |emp: &Employee| Some(emp.id),
277 // ),
278 // )
279 // .penalize(SimpleScore::of(1))
280 // .as_constraint("Vacation conflict");
281 //
282 // let schedule = Schedule {
283 // shifts: vec![
284 // Shift { id: 0, employee_idx: Some(0) }, // assigned to vacationing emp
285 // Shift { id: 1, employee_idx: Some(1) }, // assigned to working emp
286 // Shift { id: 2, employee_idx: None }, // unassigned (filtered out)
287 // ],
288 // employees: vec![
289 // Employee { id: 0, on_vacation: true },
290 // Employee { id: 1, on_vacation: false },
291 // ],
292 // };
293 //
294 // // Only shift 0 matches (assigned to employee 0 who is on vacation)
295 // assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
296 // ```
297 pub fn if_exists_filtered<B, EB, K, KA, KB>(
298 self,
299 extractor_b: EB,
300 joiner: EqualJoiner<KA, KB, K>,
301 ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
302 where
303 B: Clone + Send + Sync + 'static,
304 EB: Fn(&S) -> Vec<B> + Send + Sync,
305 K: Eq + Hash + Clone + Send + Sync,
306 KA: Fn(&A) -> K + Send + Sync,
307 KB: Fn(&B) -> K + Send + Sync,
308 {
309 let (key_a, key_b) = joiner.into_keys();
310 IfExistsStream::new(
311 ExistenceMode::Exists,
312 self.extractor,
313 extractor_b,
314 key_a,
315 key_b,
316 self.filter,
317 )
318 }
319
320 // Filters A entities based on whether NO matching B entity exists.
321 //
322 // Use this when the B collection needs filtering.
323 // The `extractor_b` returns a `Vec<B>` to allow for filtering.
324 //
325 // Any filters accumulated on this stream are applied to A entities.
326 //
327 // # Example
328 //
329 // ```
330 // use solverforge_scoring::stream::ConstraintFactory;
331 // use solverforge_scoring::stream::joiner::equal_bi;
332 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
333 // use solverforge_core::score::SimpleScore;
334 //
335 // #[derive(Clone)]
336 // struct Task { id: usize, assignee: Option<usize> }
337 //
338 // #[derive(Clone)]
339 // struct Worker { id: usize, available: bool }
340 //
341 // #[derive(Clone)]
342 // struct Schedule { tasks: Vec<Task>, workers: Vec<Worker> }
343 //
344 // // Penalize tasks assigned to workers who are not available
345 // let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
346 // .for_each(|s: &Schedule| s.tasks.as_slice())
347 // .filter(|task: &Task| task.assignee.is_some())
348 // .if_not_exists_filtered(
349 // |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
350 // equal_bi(
351 // |task: &Task| task.assignee,
352 // |worker: &Worker| Some(worker.id),
353 // ),
354 // )
355 // .penalize(SimpleScore::of(1))
356 // .as_constraint("Unavailable worker");
357 //
358 // let schedule = Schedule {
359 // tasks: vec![
360 // Task { id: 0, assignee: Some(0) }, // worker 0 is unavailable
361 // Task { id: 1, assignee: Some(1) }, // worker 1 is available
362 // Task { id: 2, assignee: None }, // unassigned (filtered out)
363 // ],
364 // workers: vec![
365 // Worker { id: 0, available: false },
366 // Worker { id: 1, available: true },
367 // ],
368 // };
369 //
370 // // Task 0's worker (id=0) is NOT in the available workers list
371 // assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
372 // ```
373 pub fn if_not_exists_filtered<B, EB, K, KA, KB>(
374 self,
375 extractor_b: EB,
376 joiner: EqualJoiner<KA, KB, K>,
377 ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
378 where
379 B: Clone + Send + Sync + 'static,
380 EB: Fn(&S) -> Vec<B> + Send + Sync,
381 K: Eq + Hash + Clone + Send + Sync,
382 KA: Fn(&A) -> K + Send + Sync,
383 KB: Fn(&B) -> K + Send + Sync,
384 {
385 let (key_a, key_b) = joiner.into_keys();
386 IfExistsStream::new(
387 ExistenceMode::NotExists,
388 self.extractor,
389 extractor_b,
390 key_a,
391 key_b,
392 self.filter,
393 )
394 }
395
396 // Penalizes each matching entity with a fixed weight.
397 pub fn penalize(
398 self,
399 weight: Sc,
400 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
401 where
402 Sc: Copy,
403 {
404 // Detect if this is a hard constraint by checking if hard level is non-zero
405 let is_hard = weight
406 .to_level_numbers()
407 .first()
408 .map(|&h| h != 0)
409 .unwrap_or(false);
410 UniConstraintBuilder {
411 extractor: self.extractor,
412 filter: self.filter,
413 impact_type: ImpactType::Penalty,
414 weight: move |_: &A| weight,
415 is_hard,
416 expected_descriptor: None,
417 _phantom: PhantomData,
418 }
419 }
420
421 // Penalizes each matching entity with a dynamic weight.
422 //
423 // Note: For dynamic weights, use `penalize_hard_with` to explicitly mark as a hard constraint,
424 // since the weight function cannot be evaluated at build time.
425 pub fn penalize_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
426 where
427 W: Fn(&A) -> Sc + Send + Sync,
428 {
429 UniConstraintBuilder {
430 extractor: self.extractor,
431 filter: self.filter,
432 impact_type: ImpactType::Penalty,
433 weight: weight_fn,
434 is_hard: false, // Can't detect at build time; use penalize_hard_with for hard constraints
435 expected_descriptor: None,
436 _phantom: PhantomData,
437 }
438 }
439
440 // Penalizes each matching entity with a dynamic weight, explicitly marked as a hard constraint.
441 pub fn penalize_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
442 where
443 W: Fn(&A) -> Sc + Send + Sync,
444 {
445 UniConstraintBuilder {
446 extractor: self.extractor,
447 filter: self.filter,
448 impact_type: ImpactType::Penalty,
449 weight: weight_fn,
450 is_hard: true,
451 expected_descriptor: None,
452 _phantom: PhantomData,
453 }
454 }
455
456 // Rewards each matching entity with a fixed weight.
457 pub fn reward(
458 self,
459 weight: Sc,
460 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
461 where
462 Sc: Copy,
463 {
464 // Detect if this is a hard constraint by checking if hard level is non-zero
465 let is_hard = weight
466 .to_level_numbers()
467 .first()
468 .map(|&h| h != 0)
469 .unwrap_or(false);
470 UniConstraintBuilder {
471 extractor: self.extractor,
472 filter: self.filter,
473 impact_type: ImpactType::Reward,
474 weight: move |_: &A| weight,
475 is_hard,
476 expected_descriptor: None,
477 _phantom: PhantomData,
478 }
479 }
480
481 // Rewards each matching entity with a dynamic weight.
482 //
483 // Note: For dynamic weights, use `reward_hard_with` to explicitly mark as a hard constraint,
484 // since the weight function cannot be evaluated at build time.
485 pub fn reward_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
486 where
487 W: Fn(&A) -> Sc + Send + Sync,
488 {
489 UniConstraintBuilder {
490 extractor: self.extractor,
491 filter: self.filter,
492 impact_type: ImpactType::Reward,
493 weight: weight_fn,
494 is_hard: false, // Can't detect at build time; use reward_hard_with for hard constraints
495 expected_descriptor: None,
496 _phantom: PhantomData,
497 }
498 }
499
500 // Rewards each matching entity with a dynamic weight, explicitly marked as a hard constraint.
501 pub fn reward_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
502 where
503 W: Fn(&A) -> Sc + Send + Sync,
504 {
505 UniConstraintBuilder {
506 extractor: self.extractor,
507 filter: self.filter,
508 impact_type: ImpactType::Reward,
509 weight: weight_fn,
510 is_hard: true,
511 expected_descriptor: None,
512 _phantom: PhantomData,
513 }
514 }
515}
516
517impl<S, A, E, F, Sc: Score> std::fmt::Debug for UniConstraintStream<S, A, E, F, Sc> {
518 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
519 f.debug_struct("UniConstraintStream").finish()
520 }
521}
522
523// Zero-erasure builder for finalizing a uni-constraint.
524pub struct UniConstraintBuilder<S, A, E, F, W, Sc>
525where
526 Sc: Score,
527{
528 extractor: E,
529 filter: F,
530 impact_type: ImpactType,
531 weight: W,
532 is_hard: bool,
533 expected_descriptor: Option<usize>,
534 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
535}
536
537impl<S, A, E, F, W, Sc> UniConstraintBuilder<S, A, E, F, W, Sc>
538where
539 S: Send + Sync + 'static,
540 A: Clone + Send + Sync + 'static,
541 E: Fn(&S) -> &[A] + Send + Sync,
542 F: UniFilter<S, A>,
543 W: Fn(&A) -> Sc + Send + Sync,
544 Sc: Score + 'static,
545{
546 // Restricts this constraint to only fire for the given descriptor index.
547 //
548 // Required when multiple entity classes exist (e.g., FurnaceAssignment at 0,
549 // ShiftAssignment at 1). Without this, on_insert/on_retract fire for all entity
550 // classes using the constraint's entity_index, which indexes into the wrong slice.
551 pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
552 self.expected_descriptor = Some(descriptor_index);
553 self
554 }
555
556 // Finalizes the builder into a zero-erasure `IncrementalUniConstraint`.
557 pub fn as_constraint(
558 self,
559 name: &str,
560 ) -> IncrementalUniConstraint<S, A, E, impl Fn(&S, &A) -> bool + Send + Sync, W, Sc> {
561 let filter = self.filter;
562 let combined_filter = move |s: &S, a: &A| filter.test(s, a);
563
564 let mut constraint = IncrementalUniConstraint::new(
565 ConstraintRef::new("", name),
566 self.impact_type,
567 self.extractor,
568 combined_filter,
569 self.weight,
570 self.is_hard,
571 );
572 if let Some(d) = self.expected_descriptor {
573 constraint = constraint.with_descriptor(d);
574 }
575 constraint
576 }
577}
578
579impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for UniConstraintBuilder<S, A, E, F, W, Sc> {
580 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581 f.debug_struct("UniConstraintBuilder")
582 .field("impact_type", &self.impact_type)
583 .finish()
584 }
585}