solverforge_scoring/stream/uni_stream.rs
1//! Zero-erasure uni-constraint stream for single-entity constraint patterns.
2//!
3//! A `UniConstraintStream` operates on a single entity type and supports
4//! filtering, weighting, and constraint finalization. All type information
5//! is preserved at compile time - no Arc, no dyn, fully monomorphized.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::constraint::incremental::IncrementalUniConstraint;
14
15use crate::constraint::if_exists::ExistenceMode;
16
17use super::balance_stream::BalanceConstraintStream;
18use super::bi_stream::BiConstraintStream;
19use super::collector::UniCollector;
20use super::cross_bi_stream::CrossBiConstraintStream;
21use super::filter::{AndUniFilter, FnUniFilter, TrueFilter, UniFilter, UniLeftBiFilter};
22use super::grouped_stream::GroupedConstraintStream;
23use super::if_exists_stream::IfExistsStream;
24use super::joiner::EqualJoiner;
25
26/// Zero-erasure constraint stream over a single entity type.
27///
28/// `UniConstraintStream` accumulates filters and can be finalized into
29/// an `IncrementalUniConstraint` via `penalize()` or `reward()`.
30///
31/// All type parameters are concrete - no trait objects, no Arc allocations
32/// in the hot path.
33///
34/// # Type Parameters
35///
36/// - `S` - Solution type
37/// - `A` - Entity type
38/// - `E` - Extractor function type
39/// - `F` - Combined filter type
40/// - `Sc` - Score type
41pub struct UniConstraintStream<S, A, E, F, Sc>
42where
43 Sc: Score,
44{
45 extractor: E,
46 filter: F,
47 _phantom: PhantomData<(S, A, Sc)>,
48}
49
50impl<S, A, E, Sc> UniConstraintStream<S, A, E, TrueFilter, Sc>
51where
52 S: Send + Sync + 'static,
53 A: Clone + Send + Sync + 'static,
54 E: Fn(&S) -> &[A] + Send + Sync,
55 Sc: Score + 'static,
56{
57 /// Creates a new uni-constraint stream with the given extractor.
58 pub fn new(extractor: E) -> Self {
59 Self {
60 extractor,
61 filter: TrueFilter,
62 _phantom: PhantomData,
63 }
64 }
65}
66
67impl<S, A, E, F, Sc> UniConstraintStream<S, A, E, F, Sc>
68where
69 S: Send + Sync + 'static,
70 A: Clone + Send + Sync + 'static,
71 E: Fn(&S) -> &[A] + Send + Sync,
72 F: UniFilter<S, A>,
73 Sc: Score + 'static,
74{
75 /// Adds a filter predicate to the stream.
76 ///
77 /// Multiple filters are combined with AND semantics at compile time.
78 /// Each filter adds a new type layer, preserving zero-erasure.
79 ///
80 /// For filters that need access to solution state (shadow variables),
81 /// use [`filter_with_solution`](Self::filter_with_solution) instead.
82 pub fn filter<P>(
83 self,
84 predicate: P,
85 ) -> UniConstraintStream<
86 S,
87 A,
88 E,
89 AndUniFilter<F, FnUniFilter<impl Fn(&S, &A) -> bool + Send + Sync>>,
90 Sc,
91 >
92 where
93 P: Fn(&A) -> bool + Send + Sync + 'static,
94 {
95 UniConstraintStream {
96 extractor: self.extractor,
97 filter: AndUniFilter::new(
98 self.filter,
99 FnUniFilter::new(move |_s: &S, a: &A| predicate(a)),
100 ),
101 _phantom: PhantomData,
102 }
103 }
104
105 /// Adds a solution-aware filter predicate to the stream.
106 ///
107 /// Unlike [`filter`](Self::filter), this method receives the solution
108 /// reference alongside the entity, enabling access to shadow variables
109 /// and computed solution state.
110 ///
111 /// # Example
112 ///
113 /// ```
114 /// use solverforge_scoring::stream::ConstraintFactory;
115 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
116 /// use solverforge_core::score::SimpleScore;
117 ///
118 /// #[derive(Clone, Debug)]
119 /// struct Shift { id: usize, employee_idx: Option<usize> }
120 ///
121 /// #[derive(Clone, Debug)]
122 /// struct Schedule {
123 /// shifts: Vec<Shift>,
124 /// hours_by_employee: Vec<i32>, // shadow variable
125 /// }
126 ///
127 /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
128 /// .for_each(|s: &Schedule| &s.shifts)
129 /// .filter(|shift: &Shift| shift.employee_idx.is_some())
130 /// .filter_with_solution(|schedule: &Schedule, shift: &Shift| {
131 /// // Access shadow variable from solution
132 /// let emp_idx = shift.employee_idx.unwrap();
133 /// schedule.hours_by_employee[emp_idx] > 40
134 /// })
135 /// .penalize(SimpleScore::of(1))
136 /// .as_constraint("Overtime");
137 ///
138 /// let schedule = Schedule {
139 /// shifts: vec![
140 /// Shift { id: 0, employee_idx: Some(0) },
141 /// Shift { id: 1, employee_idx: Some(1) },
142 /// ],
143 /// hours_by_employee: vec![45, 35], // emp 0 has overtime
144 /// };
145 ///
146 /// // Only shift 0 matches (employee 0 has > 40 hours)
147 /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
148 /// ```
149 pub fn filter_with_solution<P>(
150 self,
151 predicate: P,
152 ) -> UniConstraintStream<S, A, E, AndUniFilter<F, FnUniFilter<P>>, Sc>
153 where
154 P: Fn(&S, &A) -> bool + Send + Sync,
155 {
156 UniConstraintStream {
157 extractor: self.extractor,
158 filter: AndUniFilter::new(self.filter, FnUniFilter::new(predicate)),
159 _phantom: PhantomData,
160 }
161 }
162
163 /// Joins this stream with itself to create pairs (zero-erasure).
164 ///
165 /// Requires an `EqualJoiner` to enable key-based indexing for O(k) lookups.
166 /// For self-joins, pairs are ordered (i < j) to avoid duplicates.
167 ///
168 /// Any filters accumulated on this stream are applied to both entities
169 /// individually before the join.
170 pub fn join_self<K, KA, KB>(
171 self,
172 joiner: EqualJoiner<KA, KB, K>,
173 ) -> BiConstraintStream<S, A, K, E, KA, UniLeftBiFilter<F, A>, Sc>
174 where
175 A: Hash + PartialEq,
176 K: Eq + Hash + Clone + Send + Sync,
177 KA: Fn(&A) -> K + Send + Sync,
178 KB: Fn(&A) -> K + Send + Sync,
179 {
180 let (key_extractor, _) = joiner.into_keys();
181
182 // Convert uni-filter to bi-filter that applies to left entity
183 let bi_filter = UniLeftBiFilter::new(self.filter);
184
185 BiConstraintStream::new_self_join_with_filter(self.extractor, key_extractor, bi_filter)
186 }
187
188 /// Joins this stream with another collection to create cross-entity pairs (zero-erasure).
189 ///
190 /// Requires an `EqualJoiner` to enable key-based indexing for O(1) lookups.
191 /// Unlike `join_self` which pairs entities within the same collection,
192 /// `join` creates pairs from two different collections (e.g., Shift joined
193 /// with Employee).
194 ///
195 /// Any filters accumulated on this stream are applied to the A entity
196 /// before the join.
197 pub fn join<B, EB, K, KA, KB>(
198 self,
199 extractor_b: EB,
200 joiner: EqualJoiner<KA, KB, K>,
201 ) -> CrossBiConstraintStream<S, A, B, K, E, EB, KA, KB, UniLeftBiFilter<F, B>, Sc>
202 where
203 B: Clone + Send + Sync + 'static,
204 EB: Fn(&S) -> &[B] + Send + Sync,
205 K: Eq + Hash + Clone + Send + Sync,
206 KA: Fn(&A) -> K + Send + Sync,
207 KB: Fn(&B) -> K + Send + Sync,
208 {
209 let (key_a, key_b) = joiner.into_keys();
210
211 // Convert uni-filter to bi-filter that applies to left entity only
212 let bi_filter = UniLeftBiFilter::new(self.filter);
213
214 CrossBiConstraintStream::new_with_filter(
215 self.extractor,
216 extractor_b,
217 key_a,
218 key_b,
219 bi_filter,
220 )
221 }
222
223 /// Groups entities by key and aggregates with a collector.
224 ///
225 /// Returns a zero-erasure `GroupedConstraintStream` that can be penalized
226 /// or rewarded based on the aggregated result for each group.
227 pub fn group_by<K, KF, C>(
228 self,
229 key_fn: KF,
230 collector: C,
231 ) -> GroupedConstraintStream<S, A, K, E, KF, C, Sc>
232 where
233 K: Clone + Eq + Hash + Send + Sync + 'static,
234 KF: Fn(&A) -> K + Send + Sync,
235 C: UniCollector<A> + Send + Sync + 'static,
236 C::Accumulator: Send + Sync,
237 C::Result: Clone + Send + Sync,
238 {
239 GroupedConstraintStream::new(self.extractor, key_fn, collector)
240 }
241
242 /// Creates a balance constraint that penalizes uneven distribution across groups.
243 ///
244 /// Unlike `group_by` which scores each group independently, `balance` computes
245 /// a GLOBAL standard deviation across all group counts and produces a single score.
246 ///
247 /// The `key_fn` returns `Option<K>` to allow skipping entities (e.g., unassigned shifts).
248 /// Any filters accumulated on this stream are also applied.
249 ///
250 /// # Example
251 ///
252 /// ```
253 /// use solverforge_scoring::stream::ConstraintFactory;
254 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
255 /// use solverforge_core::score::SimpleScore;
256 ///
257 /// #[derive(Clone)]
258 /// struct Shift { employee_id: Option<usize> }
259 ///
260 /// #[derive(Clone)]
261 /// struct Solution { shifts: Vec<Shift> }
262 ///
263 /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
264 /// .for_each(|s: &Solution| &s.shifts)
265 /// .balance(|shift: &Shift| shift.employee_id)
266 /// .penalize(SimpleScore::of(1000))
267 /// .as_constraint("Balance workload");
268 ///
269 /// let solution = Solution {
270 /// shifts: vec![
271 /// Shift { employee_id: Some(0) },
272 /// Shift { employee_id: Some(0) },
273 /// Shift { employee_id: Some(0) },
274 /// Shift { employee_id: Some(1) },
275 /// ],
276 /// };
277 ///
278 /// // Employee 0: 3 shifts, Employee 1: 1 shift
279 /// // std_dev = 1.0, penalty = -1000
280 /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1000));
281 /// ```
282 pub fn balance<K, KF>(self, key_fn: KF) -> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
283 where
284 K: Clone + Eq + Hash + Send + Sync + 'static,
285 KF: Fn(&A) -> Option<K> + Send + Sync,
286 {
287 BalanceConstraintStream::new(self.extractor, self.filter, key_fn)
288 }
289
290 /// Filters A entities based on whether a matching B entity exists.
291 ///
292 /// Use this when the B collection needs filtering (e.g., only vacationing employees).
293 /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
294 ///
295 /// Any filters accumulated on this stream are applied to A entities.
296 ///
297 /// # Example
298 ///
299 /// ```
300 /// use solverforge_scoring::stream::ConstraintFactory;
301 /// use solverforge_scoring::stream::joiner::equal_bi;
302 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
303 /// use solverforge_core::score::SimpleScore;
304 ///
305 /// #[derive(Clone)]
306 /// struct Shift { id: usize, employee_idx: Option<usize> }
307 ///
308 /// #[derive(Clone)]
309 /// struct Employee { id: usize, on_vacation: bool }
310 ///
311 /// #[derive(Clone)]
312 /// struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
313 ///
314 /// // Penalize shifts assigned to employees who are on vacation
315 /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
316 /// .for_each(|s: &Schedule| s.shifts.as_slice())
317 /// .filter(|shift: &Shift| shift.employee_idx.is_some())
318 /// .if_exists_filtered(
319 /// |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect(),
320 /// equal_bi(
321 /// |shift: &Shift| shift.employee_idx,
322 /// |emp: &Employee| Some(emp.id),
323 /// ),
324 /// )
325 /// .penalize(SimpleScore::of(1))
326 /// .as_constraint("Vacation conflict");
327 ///
328 /// let schedule = Schedule {
329 /// shifts: vec![
330 /// Shift { id: 0, employee_idx: Some(0) }, // assigned to vacationing emp
331 /// Shift { id: 1, employee_idx: Some(1) }, // assigned to working emp
332 /// Shift { id: 2, employee_idx: None }, // unassigned (filtered out)
333 /// ],
334 /// employees: vec![
335 /// Employee { id: 0, on_vacation: true },
336 /// Employee { id: 1, on_vacation: false },
337 /// ],
338 /// };
339 ///
340 /// // Only shift 0 matches (assigned to employee 0 who is on vacation)
341 /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
342 /// ```
343 pub fn if_exists_filtered<B, EB, K, KA, KB>(
344 self,
345 extractor_b: EB,
346 joiner: EqualJoiner<KA, KB, K>,
347 ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
348 where
349 B: Clone + Send + Sync + 'static,
350 EB: Fn(&S) -> Vec<B> + Send + Sync,
351 K: Eq + Hash + Clone + Send + Sync,
352 KA: Fn(&A) -> K + Send + Sync,
353 KB: Fn(&B) -> K + Send + Sync,
354 {
355 let (key_a, key_b) = joiner.into_keys();
356 IfExistsStream::new(
357 ExistenceMode::Exists,
358 self.extractor,
359 extractor_b,
360 key_a,
361 key_b,
362 self.filter,
363 )
364 }
365
366 /// Filters A entities based on whether NO matching B entity exists.
367 ///
368 /// Use this when the B collection needs filtering.
369 /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
370 ///
371 /// Any filters accumulated on this stream are applied to A entities.
372 ///
373 /// # Example
374 ///
375 /// ```
376 /// use solverforge_scoring::stream::ConstraintFactory;
377 /// use solverforge_scoring::stream::joiner::equal_bi;
378 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
379 /// use solverforge_core::score::SimpleScore;
380 ///
381 /// #[derive(Clone)]
382 /// struct Task { id: usize, assignee: Option<usize> }
383 ///
384 /// #[derive(Clone)]
385 /// struct Worker { id: usize, available: bool }
386 ///
387 /// #[derive(Clone)]
388 /// struct Schedule { tasks: Vec<Task>, workers: Vec<Worker> }
389 ///
390 /// // Penalize tasks assigned to workers who are not available
391 /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
392 /// .for_each(|s: &Schedule| s.tasks.as_slice())
393 /// .filter(|task: &Task| task.assignee.is_some())
394 /// .if_not_exists_filtered(
395 /// |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
396 /// equal_bi(
397 /// |task: &Task| task.assignee,
398 /// |worker: &Worker| Some(worker.id),
399 /// ),
400 /// )
401 /// .penalize(SimpleScore::of(1))
402 /// .as_constraint("Unavailable worker");
403 ///
404 /// let schedule = Schedule {
405 /// tasks: vec![
406 /// Task { id: 0, assignee: Some(0) }, // worker 0 is unavailable
407 /// Task { id: 1, assignee: Some(1) }, // worker 1 is available
408 /// Task { id: 2, assignee: None }, // unassigned (filtered out)
409 /// ],
410 /// workers: vec![
411 /// Worker { id: 0, available: false },
412 /// Worker { id: 1, available: true },
413 /// ],
414 /// };
415 ///
416 /// // Task 0's worker (id=0) is NOT in the available workers list
417 /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
418 /// ```
419 pub fn if_not_exists_filtered<B, EB, K, KA, KB>(
420 self,
421 extractor_b: EB,
422 joiner: EqualJoiner<KA, KB, K>,
423 ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
424 where
425 B: Clone + Send + Sync + 'static,
426 EB: Fn(&S) -> Vec<B> + Send + Sync,
427 K: Eq + Hash + Clone + Send + Sync,
428 KA: Fn(&A) -> K + Send + Sync,
429 KB: Fn(&B) -> K + Send + Sync,
430 {
431 let (key_a, key_b) = joiner.into_keys();
432 IfExistsStream::new(
433 ExistenceMode::NotExists,
434 self.extractor,
435 extractor_b,
436 key_a,
437 key_b,
438 self.filter,
439 )
440 }
441
442 /// Penalizes each matching entity with a fixed weight.
443 pub fn penalize(
444 self,
445 weight: Sc,
446 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
447 where
448 Sc: Copy,
449 {
450 // Detect if this is a hard constraint by checking if hard level is non-zero
451 let is_hard = weight
452 .to_level_numbers()
453 .first()
454 .map(|&h| h != 0)
455 .unwrap_or(false);
456 UniConstraintBuilder {
457 extractor: self.extractor,
458 filter: self.filter,
459 impact_type: ImpactType::Penalty,
460 weight: move |_: &A| weight,
461 is_hard,
462 _phantom: PhantomData,
463 }
464 }
465
466 /// Penalizes each matching entity with a dynamic weight.
467 ///
468 /// Note: For dynamic weights, use `penalize_hard_with` to explicitly mark as a hard constraint,
469 /// since the weight function cannot be evaluated at build time.
470 pub fn penalize_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
471 where
472 W: Fn(&A) -> Sc + Send + Sync,
473 {
474 UniConstraintBuilder {
475 extractor: self.extractor,
476 filter: self.filter,
477 impact_type: ImpactType::Penalty,
478 weight: weight_fn,
479 is_hard: false, // Can't detect at build time; use penalize_hard_with for hard constraints
480 _phantom: PhantomData,
481 }
482 }
483
484 /// Penalizes each matching entity with a dynamic weight, explicitly marked as a hard constraint.
485 pub fn penalize_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
486 where
487 W: Fn(&A) -> Sc + Send + Sync,
488 {
489 UniConstraintBuilder {
490 extractor: self.extractor,
491 filter: self.filter,
492 impact_type: ImpactType::Penalty,
493 weight: weight_fn,
494 is_hard: true,
495 _phantom: PhantomData,
496 }
497 }
498
499 /// Rewards each matching entity with a fixed weight.
500 pub fn reward(
501 self,
502 weight: Sc,
503 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
504 where
505 Sc: Copy,
506 {
507 // Detect if this is a hard constraint by checking if hard level is non-zero
508 let is_hard = weight
509 .to_level_numbers()
510 .first()
511 .map(|&h| h != 0)
512 .unwrap_or(false);
513 UniConstraintBuilder {
514 extractor: self.extractor,
515 filter: self.filter,
516 impact_type: ImpactType::Reward,
517 weight: move |_: &A| weight,
518 is_hard,
519 _phantom: PhantomData,
520 }
521 }
522
523 /// Rewards each matching entity with a dynamic weight.
524 ///
525 /// Note: For dynamic weights, use `reward_hard_with` to explicitly mark as a hard constraint,
526 /// since the weight function cannot be evaluated at build time.
527 pub fn reward_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
528 where
529 W: Fn(&A) -> Sc + Send + Sync,
530 {
531 UniConstraintBuilder {
532 extractor: self.extractor,
533 filter: self.filter,
534 impact_type: ImpactType::Reward,
535 weight: weight_fn,
536 is_hard: false, // Can't detect at build time; use reward_hard_with for hard constraints
537 _phantom: PhantomData,
538 }
539 }
540
541 /// Rewards each matching entity with a dynamic weight, explicitly marked as a hard constraint.
542 pub fn reward_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
543 where
544 W: Fn(&A) -> Sc + Send + Sync,
545 {
546 UniConstraintBuilder {
547 extractor: self.extractor,
548 filter: self.filter,
549 impact_type: ImpactType::Reward,
550 weight: weight_fn,
551 is_hard: true,
552 _phantom: PhantomData,
553 }
554 }
555}
556
557impl<S, A, E, F, Sc: Score> std::fmt::Debug for UniConstraintStream<S, A, E, F, Sc> {
558 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
559 f.debug_struct("UniConstraintStream").finish()
560 }
561}
562
563/// Zero-erasure builder for finalizing a uni-constraint.
564pub struct UniConstraintBuilder<S, A, E, F, W, Sc>
565where
566 Sc: Score,
567{
568 extractor: E,
569 filter: F,
570 impact_type: ImpactType,
571 weight: W,
572 is_hard: bool,
573 _phantom: PhantomData<(S, A, Sc)>,
574}
575
576impl<S, A, E, F, W, Sc> UniConstraintBuilder<S, A, E, F, W, Sc>
577where
578 S: Send + Sync + 'static,
579 A: Clone + Send + Sync + 'static,
580 E: Fn(&S) -> &[A] + Send + Sync,
581 F: UniFilter<S, A>,
582 W: Fn(&A) -> Sc + Send + Sync,
583 Sc: Score + 'static,
584{
585 /// Finalizes the builder into a zero-erasure `IncrementalUniConstraint`.
586 pub fn as_constraint(
587 self,
588 name: &str,
589 ) -> IncrementalUniConstraint<S, A, E, impl Fn(&S, &A) -> bool + Send + Sync, W, Sc> {
590 let filter = self.filter;
591 let combined_filter = move |s: &S, a: &A| filter.test(s, a);
592
593 IncrementalUniConstraint::new(
594 ConstraintRef::new("", name),
595 self.impact_type,
596 self.extractor,
597 combined_filter,
598 self.weight,
599 self.is_hard,
600 )
601 }
602}
603
604impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for UniConstraintBuilder<S, A, E, F, W, Sc> {
605 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
606 f.debug_struct("UniConstraintBuilder")
607 .field("impact_type", &self.impact_type)
608 .finish()
609 }
610}