// solverforge_scoring/stream/uni_stream.rs
1//! Zero-erasure uni-constraint stream for single-entity constraint patterns.
2//!
3//! A `UniConstraintStream` operates on a single entity type and supports
4//! filtering, weighting, and constraint finalization. All type information
5//! is preserved at compile time - no Arc, no dyn, fully monomorphized.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::constraint::incremental::IncrementalUniConstraint;
14
15use crate::constraint::if_exists::ExistenceMode;
16
17use super::balance_stream::BalanceConstraintStream;
18use super::bi_stream::BiConstraintStream;
19use super::collector::UniCollector;
20use super::cross_bi_stream::CrossBiConstraintStream;
21use super::filter::{AndUniFilter, FnUniFilter, TrueFilter, UniFilter, UniLeftBiFilter};
22use super::grouped_stream::GroupedConstraintStream;
23use super::if_exists_stream::IfExistsStream;
24use super::joiner::EqualJoiner;
25
/// Zero-erasure constraint stream over a single entity type.
///
/// `UniConstraintStream` accumulates filters and can be finalized into
/// an `IncrementalUniConstraint` via `penalize()` or `reward()`.
///
/// All type parameters are concrete - no trait objects, no Arc allocations
/// in the hot path.
///
/// # Type Parameters
///
/// - `S` - Solution type
/// - `A` - Entity type
/// - `E` - Extractor function type
/// - `F` - Combined filter type
/// - `Sc` - Score type
pub struct UniConstraintStream<S, A, E, F, Sc>
where
    Sc: Score,
{
    // Extracts the `&[A]` entity slice from a solution `&S`.
    extractor: E,
    // Compile-time AND-composed filter chain; starts as `TrueFilter` and grows
    // one `AndUniFilter` layer per `filter()` call.
    filter: F,
    // Ties the otherwise-unused `S`, `A`, and `Sc` parameters to this type.
    _phantom: PhantomData<(S, A, Sc)>,
}
49
50impl<S, A, E, Sc> UniConstraintStream<S, A, E, TrueFilter, Sc>
51where
52 S: Send + Sync + 'static,
53 A: Clone + Send + Sync + 'static,
54 E: Fn(&S) -> &[A] + Send + Sync,
55 Sc: Score + 'static,
56{
57 /// Creates a new uni-constraint stream with the given extractor.
58 pub fn new(extractor: E) -> Self {
59 Self {
60 extractor,
61 filter: TrueFilter,
62 _phantom: PhantomData,
63 }
64 }
65}
66
67impl<S, A, E, F, Sc> UniConstraintStream<S, A, E, F, Sc>
68where
69 S: Send + Sync + 'static,
70 A: Clone + Send + Sync + 'static,
71 E: Fn(&S) -> &[A] + Send + Sync,
72 F: UniFilter<S, A>,
73 Sc: Score + 'static,
74{
75 /// Adds a filter predicate to the stream.
76 ///
77 /// Multiple filters are combined with AND semantics at compile time.
78 /// Each filter adds a new type layer, preserving zero-erasure.
79 ///
80 /// To access related entities, use shadow variables on your entity type
81 /// (e.g., `#[inverse_relation_shadow_variable]`) rather than solution traversal.
82 pub fn filter<P>(
83 self,
84 predicate: P,
85 ) -> UniConstraintStream<
86 S,
87 A,
88 E,
89 AndUniFilter<F, FnUniFilter<impl Fn(&S, &A) -> bool + Send + Sync>>,
90 Sc,
91 >
92 where
93 P: Fn(&A) -> bool + Send + Sync + 'static,
94 {
95 UniConstraintStream {
96 extractor: self.extractor,
97 filter: AndUniFilter::new(
98 self.filter,
99 FnUniFilter::new(move |_s: &S, a: &A| predicate(a)),
100 ),
101 _phantom: PhantomData,
102 }
103 }
104
105 /// Joins this stream with itself to create pairs (zero-erasure).
106 ///
107 /// Requires an `EqualJoiner` to enable key-based indexing for O(k) lookups.
108 /// For self-joins, pairs are ordered (i < j) to avoid duplicates.
109 ///
110 /// Any filters accumulated on this stream are applied to both entities
111 /// individually before the join.
112 pub fn join_self<K, KA, KB>(
113 self,
114 joiner: EqualJoiner<KA, KB, K>,
115 ) -> BiConstraintStream<S, A, K, E, KA, UniLeftBiFilter<F, A>, Sc>
116 where
117 A: Hash + PartialEq,
118 K: Eq + Hash + Clone + Send + Sync,
119 KA: Fn(&A) -> K + Send + Sync,
120 KB: Fn(&A) -> K + Send + Sync,
121 {
122 let (key_extractor, _) = joiner.into_keys();
123
124 // Convert uni-filter to bi-filter that applies to left entity
125 let bi_filter = UniLeftBiFilter::new(self.filter);
126
127 BiConstraintStream::new_self_join_with_filter(self.extractor, key_extractor, bi_filter)
128 }
129
130 /// Joins this stream with another collection to create cross-entity pairs (zero-erasure).
131 ///
132 /// Requires an `EqualJoiner` to enable key-based indexing for O(1) lookups.
133 /// Unlike `join_self` which pairs entities within the same collection,
134 /// `join` creates pairs from two different collections (e.g., Shift joined
135 /// with Employee).
136 ///
137 /// Any filters accumulated on this stream are applied to the A entity
138 /// before the join.
139 pub fn join<B, EB, K, KA, KB>(
140 self,
141 extractor_b: EB,
142 joiner: EqualJoiner<KA, KB, K>,
143 ) -> CrossBiConstraintStream<S, A, B, K, E, EB, KA, KB, UniLeftBiFilter<F, B>, Sc>
144 where
145 B: Clone + Send + Sync + 'static,
146 EB: Fn(&S) -> &[B] + Send + Sync,
147 K: Eq + Hash + Clone + Send + Sync,
148 KA: Fn(&A) -> K + Send + Sync,
149 KB: Fn(&B) -> K + Send + Sync,
150 {
151 let (key_a, key_b) = joiner.into_keys();
152
153 // Convert uni-filter to bi-filter that applies to left entity only
154 let bi_filter = UniLeftBiFilter::new(self.filter);
155
156 CrossBiConstraintStream::new_with_filter(
157 self.extractor,
158 extractor_b,
159 key_a,
160 key_b,
161 bi_filter,
162 )
163 }
164
165 /// Groups entities by key and aggregates with a collector.
166 ///
167 /// Returns a zero-erasure `GroupedConstraintStream` that can be penalized
168 /// or rewarded based on the aggregated result for each group.
169 pub fn group_by<K, KF, C>(
170 self,
171 key_fn: KF,
172 collector: C,
173 ) -> GroupedConstraintStream<S, A, K, E, KF, C, Sc>
174 where
175 K: Clone + Eq + Hash + Send + Sync + 'static,
176 KF: Fn(&A) -> K + Send + Sync,
177 C: UniCollector<A> + Send + Sync + 'static,
178 C::Accumulator: Send + Sync,
179 C::Result: Clone + Send + Sync,
180 {
181 GroupedConstraintStream::new(self.extractor, key_fn, collector)
182 }
183
184 /// Creates a balance constraint that penalizes uneven distribution across groups.
185 ///
186 /// Unlike `group_by` which scores each group independently, `balance` computes
187 /// a GLOBAL standard deviation across all group counts and produces a single score.
188 ///
189 /// The `key_fn` returns `Option<K>` to allow skipping entities (e.g., unassigned shifts).
190 /// Any filters accumulated on this stream are also applied.
191 ///
192 /// # Example
193 ///
194 /// ```
195 /// use solverforge_scoring::stream::ConstraintFactory;
196 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
197 /// use solverforge_core::score::SimpleScore;
198 ///
199 /// #[derive(Clone)]
200 /// struct Shift { employee_id: Option<usize> }
201 ///
202 /// #[derive(Clone)]
203 /// struct Solution { shifts: Vec<Shift> }
204 ///
205 /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
206 /// .for_each(|s: &Solution| &s.shifts)
207 /// .balance(|shift: &Shift| shift.employee_id)
208 /// .penalize(SimpleScore::of(1000))
209 /// .as_constraint("Balance workload");
210 ///
211 /// let solution = Solution {
212 /// shifts: vec![
213 /// Shift { employee_id: Some(0) },
214 /// Shift { employee_id: Some(0) },
215 /// Shift { employee_id: Some(0) },
216 /// Shift { employee_id: Some(1) },
217 /// ],
218 /// };
219 ///
220 /// // Employee 0: 3 shifts, Employee 1: 1 shift
221 /// // std_dev = 1.0, penalty = -1000
222 /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1000));
223 /// ```
224 pub fn balance<K, KF>(self, key_fn: KF) -> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
225 where
226 K: Clone + Eq + Hash + Send + Sync + 'static,
227 KF: Fn(&A) -> Option<K> + Send + Sync,
228 {
229 BalanceConstraintStream::new(self.extractor, self.filter, key_fn)
230 }
231
232 /// Filters A entities based on whether a matching B entity exists.
233 ///
234 /// Use this when the B collection needs filtering (e.g., only vacationing employees).
235 /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
236 ///
237 /// Any filters accumulated on this stream are applied to A entities.
238 ///
239 /// # Example
240 ///
241 /// ```
242 /// use solverforge_scoring::stream::ConstraintFactory;
243 /// use solverforge_scoring::stream::joiner::equal_bi;
244 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
245 /// use solverforge_core::score::SimpleScore;
246 ///
247 /// #[derive(Clone)]
248 /// struct Shift { id: usize, employee_idx: Option<usize> }
249 ///
250 /// #[derive(Clone)]
251 /// struct Employee { id: usize, on_vacation: bool }
252 ///
253 /// #[derive(Clone)]
254 /// struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
255 ///
256 /// // Penalize shifts assigned to employees who are on vacation
257 /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
258 /// .for_each(|s: &Schedule| s.shifts.as_slice())
259 /// .filter(|shift: &Shift| shift.employee_idx.is_some())
260 /// .if_exists_filtered(
261 /// |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect(),
262 /// equal_bi(
263 /// |shift: &Shift| shift.employee_idx,
264 /// |emp: &Employee| Some(emp.id),
265 /// ),
266 /// )
267 /// .penalize(SimpleScore::of(1))
268 /// .as_constraint("Vacation conflict");
269 ///
270 /// let schedule = Schedule {
271 /// shifts: vec![
272 /// Shift { id: 0, employee_idx: Some(0) }, // assigned to vacationing emp
273 /// Shift { id: 1, employee_idx: Some(1) }, // assigned to working emp
274 /// Shift { id: 2, employee_idx: None }, // unassigned (filtered out)
275 /// ],
276 /// employees: vec![
277 /// Employee { id: 0, on_vacation: true },
278 /// Employee { id: 1, on_vacation: false },
279 /// ],
280 /// };
281 ///
282 /// // Only shift 0 matches (assigned to employee 0 who is on vacation)
283 /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
284 /// ```
285 pub fn if_exists_filtered<B, EB, K, KA, KB>(
286 self,
287 extractor_b: EB,
288 joiner: EqualJoiner<KA, KB, K>,
289 ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
290 where
291 B: Clone + Send + Sync + 'static,
292 EB: Fn(&S) -> Vec<B> + Send + Sync,
293 K: Eq + Hash + Clone + Send + Sync,
294 KA: Fn(&A) -> K + Send + Sync,
295 KB: Fn(&B) -> K + Send + Sync,
296 {
297 let (key_a, key_b) = joiner.into_keys();
298 IfExistsStream::new(
299 ExistenceMode::Exists,
300 self.extractor,
301 extractor_b,
302 key_a,
303 key_b,
304 self.filter,
305 )
306 }
307
308 /// Filters A entities based on whether NO matching B entity exists.
309 ///
310 /// Use this when the B collection needs filtering.
311 /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
312 ///
313 /// Any filters accumulated on this stream are applied to A entities.
314 ///
315 /// # Example
316 ///
317 /// ```
318 /// use solverforge_scoring::stream::ConstraintFactory;
319 /// use solverforge_scoring::stream::joiner::equal_bi;
320 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
321 /// use solverforge_core::score::SimpleScore;
322 ///
323 /// #[derive(Clone)]
324 /// struct Task { id: usize, assignee: Option<usize> }
325 ///
326 /// #[derive(Clone)]
327 /// struct Worker { id: usize, available: bool }
328 ///
329 /// #[derive(Clone)]
330 /// struct Schedule { tasks: Vec<Task>, workers: Vec<Worker> }
331 ///
332 /// // Penalize tasks assigned to workers who are not available
333 /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
334 /// .for_each(|s: &Schedule| s.tasks.as_slice())
335 /// .filter(|task: &Task| task.assignee.is_some())
336 /// .if_not_exists_filtered(
337 /// |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
338 /// equal_bi(
339 /// |task: &Task| task.assignee,
340 /// |worker: &Worker| Some(worker.id),
341 /// ),
342 /// )
343 /// .penalize(SimpleScore::of(1))
344 /// .as_constraint("Unavailable worker");
345 ///
346 /// let schedule = Schedule {
347 /// tasks: vec![
348 /// Task { id: 0, assignee: Some(0) }, // worker 0 is unavailable
349 /// Task { id: 1, assignee: Some(1) }, // worker 1 is available
350 /// Task { id: 2, assignee: None }, // unassigned (filtered out)
351 /// ],
352 /// workers: vec![
353 /// Worker { id: 0, available: false },
354 /// Worker { id: 1, available: true },
355 /// ],
356 /// };
357 ///
358 /// // Task 0's worker (id=0) is NOT in the available workers list
359 /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
360 /// ```
361 pub fn if_not_exists_filtered<B, EB, K, KA, KB>(
362 self,
363 extractor_b: EB,
364 joiner: EqualJoiner<KA, KB, K>,
365 ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
366 where
367 B: Clone + Send + Sync + 'static,
368 EB: Fn(&S) -> Vec<B> + Send + Sync,
369 K: Eq + Hash + Clone + Send + Sync,
370 KA: Fn(&A) -> K + Send + Sync,
371 KB: Fn(&B) -> K + Send + Sync,
372 {
373 let (key_a, key_b) = joiner.into_keys();
374 IfExistsStream::new(
375 ExistenceMode::NotExists,
376 self.extractor,
377 extractor_b,
378 key_a,
379 key_b,
380 self.filter,
381 )
382 }
383
384 /// Penalizes each matching entity with a fixed weight.
385 pub fn penalize(
386 self,
387 weight: Sc,
388 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
389 where
390 Sc: Copy,
391 {
392 // Detect if this is a hard constraint by checking if hard level is non-zero
393 let is_hard = weight
394 .to_level_numbers()
395 .first()
396 .map(|&h| h != 0)
397 .unwrap_or(false);
398 UniConstraintBuilder {
399 extractor: self.extractor,
400 filter: self.filter,
401 impact_type: ImpactType::Penalty,
402 weight: move |_: &A| weight,
403 is_hard,
404 _phantom: PhantomData,
405 }
406 }
407
408 /// Penalizes each matching entity with a dynamic weight.
409 ///
410 /// Note: For dynamic weights, use `penalize_hard_with` to explicitly mark as a hard constraint,
411 /// since the weight function cannot be evaluated at build time.
412 pub fn penalize_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
413 where
414 W: Fn(&A) -> Sc + Send + Sync,
415 {
416 UniConstraintBuilder {
417 extractor: self.extractor,
418 filter: self.filter,
419 impact_type: ImpactType::Penalty,
420 weight: weight_fn,
421 is_hard: false, // Can't detect at build time; use penalize_hard_with for hard constraints
422 _phantom: PhantomData,
423 }
424 }
425
426 /// Penalizes each matching entity with a dynamic weight, explicitly marked as a hard constraint.
427 pub fn penalize_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
428 where
429 W: Fn(&A) -> Sc + Send + Sync,
430 {
431 UniConstraintBuilder {
432 extractor: self.extractor,
433 filter: self.filter,
434 impact_type: ImpactType::Penalty,
435 weight: weight_fn,
436 is_hard: true,
437 _phantom: PhantomData,
438 }
439 }
440
441 /// Rewards each matching entity with a fixed weight.
442 pub fn reward(
443 self,
444 weight: Sc,
445 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
446 where
447 Sc: Copy,
448 {
449 // Detect if this is a hard constraint by checking if hard level is non-zero
450 let is_hard = weight
451 .to_level_numbers()
452 .first()
453 .map(|&h| h != 0)
454 .unwrap_or(false);
455 UniConstraintBuilder {
456 extractor: self.extractor,
457 filter: self.filter,
458 impact_type: ImpactType::Reward,
459 weight: move |_: &A| weight,
460 is_hard,
461 _phantom: PhantomData,
462 }
463 }
464
465 /// Rewards each matching entity with a dynamic weight.
466 ///
467 /// Note: For dynamic weights, use `reward_hard_with` to explicitly mark as a hard constraint,
468 /// since the weight function cannot be evaluated at build time.
469 pub fn reward_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
470 where
471 W: Fn(&A) -> Sc + Send + Sync,
472 {
473 UniConstraintBuilder {
474 extractor: self.extractor,
475 filter: self.filter,
476 impact_type: ImpactType::Reward,
477 weight: weight_fn,
478 is_hard: false, // Can't detect at build time; use reward_hard_with for hard constraints
479 _phantom: PhantomData,
480 }
481 }
482
483 /// Rewards each matching entity with a dynamic weight, explicitly marked as a hard constraint.
484 pub fn reward_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
485 where
486 W: Fn(&A) -> Sc + Send + Sync,
487 {
488 UniConstraintBuilder {
489 extractor: self.extractor,
490 filter: self.filter,
491 impact_type: ImpactType::Reward,
492 weight: weight_fn,
493 is_hard: true,
494 _phantom: PhantomData,
495 }
496 }
497}
498
499impl<S, A, E, F, Sc: Score> std::fmt::Debug for UniConstraintStream<S, A, E, F, Sc> {
500 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
501 f.debug_struct("UniConstraintStream").finish()
502 }
503}
504
/// Zero-erasure builder for finalizing a uni-constraint.
pub struct UniConstraintBuilder<S, A, E, F, W, Sc>
where
    Sc: Score,
{
    // Extracts the `&[A]` entity slice from a solution `&S`.
    extractor: E,
    // Filter chain carried over from the originating stream.
    filter: F,
    // Whether matches penalize or reward the score.
    impact_type: ImpactType,
    // Per-entity weight function (constant weights are wrapped in a closure).
    weight: W,
    // True when this constraint participates in the hard score level.
    is_hard: bool,
    // Ties the otherwise-unused `S`, `A`, and `Sc` parameters to this type.
    _phantom: PhantomData<(S, A, Sc)>,
}
517
518impl<S, A, E, F, W, Sc> UniConstraintBuilder<S, A, E, F, W, Sc>
519where
520 S: Send + Sync + 'static,
521 A: Clone + Send + Sync + 'static,
522 E: Fn(&S) -> &[A] + Send + Sync,
523 F: UniFilter<S, A>,
524 W: Fn(&A) -> Sc + Send + Sync,
525 Sc: Score + 'static,
526{
527 /// Finalizes the builder into a zero-erasure `IncrementalUniConstraint`.
528 pub fn as_constraint(
529 self,
530 name: &str,
531 ) -> IncrementalUniConstraint<S, A, E, impl Fn(&S, &A) -> bool + Send + Sync, W, Sc> {
532 let filter = self.filter;
533 let combined_filter = move |s: &S, a: &A| filter.test(s, a);
534
535 IncrementalUniConstraint::new(
536 ConstraintRef::new("", name),
537 self.impact_type,
538 self.extractor,
539 combined_filter,
540 self.weight,
541 self.is_hard,
542 )
543 }
544}
545
546impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for UniConstraintBuilder<S, A, E, F, W, Sc> {
547 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
548 f.debug_struct("UniConstraintBuilder")
549 .field("impact_type", &self.impact_type)
550 .finish()
551 }
552}