solverforge_scoring/stream/uni_stream.rs
1//! Zero-erasure uni-constraint stream for single-entity constraint patterns.
2//!
3//! A `UniConstraintStream` operates on a single entity type and supports
4//! filtering, weighting, and constraint finalization. All type information
5//! is preserved at compile time - no Arc, no dyn, fully monomorphized.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::constraint::incremental::IncrementalUniConstraint;
14
15use crate::constraint::if_exists::ExistenceMode;
16
17use super::balance_stream::BalanceConstraintStream;
18use super::bi_stream::BiConstraintStream;
19use super::collector::UniCollector;
20use super::cross_bi_stream::CrossBiConstraintStream;
21use super::filter::{AndUniFilter, FnUniFilter, TrueFilter, UniFilter, UniLeftBiFilter};
22use super::grouped_stream::GroupedConstraintStream;
23use super::if_exists_stream::IfExistsStream;
24use super::joiner::EqualJoiner;
25
26/// Zero-erasure constraint stream over a single entity type.
27///
28/// `UniConstraintStream` accumulates filters and can be finalized into
29/// an `IncrementalUniConstraint` via `penalize()` or `reward()`.
30///
31/// All type parameters are concrete - no trait objects, no Arc allocations
32/// in the hot path.
33///
34/// # Type Parameters
35///
36/// - `S` - Solution type
37/// - `A` - Entity type
38/// - `E` - Extractor function type
39/// - `F` - Combined filter type
40/// - `Sc` - Score type
41pub struct UniConstraintStream<S, A, E, F, Sc>
42where
43 Sc: Score,
44{
45 extractor: E,
46 filter: F,
47 _phantom: PhantomData<(S, A, Sc)>,
48}
49
50impl<S, A, E, Sc> UniConstraintStream<S, A, E, TrueFilter, Sc>
51where
52 S: Send + Sync + 'static,
53 A: Clone + Send + Sync + 'static,
54 E: Fn(&S) -> &[A] + Send + Sync,
55 Sc: Score + 'static,
56{
57 /// Creates a new uni-constraint stream with the given extractor.
58 pub fn new(extractor: E) -> Self {
59 Self {
60 extractor,
61 filter: TrueFilter,
62 _phantom: PhantomData,
63 }
64 }
65}
66
67impl<S, A, E, F, Sc> UniConstraintStream<S, A, E, F, Sc>
68where
69 S: Send + Sync + 'static,
70 A: Clone + Send + Sync + 'static,
71 E: Fn(&S) -> &[A] + Send + Sync,
72 F: UniFilter<A>,
73 Sc: Score + 'static,
74{
75 /// Adds a filter predicate to the stream.
76 ///
77 /// Multiple filters are combined with AND semantics at compile time.
78 /// Each filter adds a new type layer, preserving zero-erasure.
79 pub fn filter<P>(
80 self,
81 predicate: P,
82 ) -> UniConstraintStream<S, A, E, AndUniFilter<F, FnUniFilter<P>>, Sc>
83 where
84 P: Fn(&A) -> bool + Send + Sync,
85 {
86 UniConstraintStream {
87 extractor: self.extractor,
88 filter: AndUniFilter::new(self.filter, FnUniFilter::new(predicate)),
89 _phantom: PhantomData,
90 }
91 }
92
93 /// Joins this stream with itself to create pairs (zero-erasure).
94 ///
95 /// Requires an `EqualJoiner` to enable key-based indexing for O(k) lookups.
96 /// For self-joins, pairs are ordered (i < j) to avoid duplicates.
97 ///
98 /// Any filters accumulated on this stream are applied to both entities
99 /// individually before the join.
100 pub fn join_self<K, KA, KB>(
101 self,
102 joiner: EqualJoiner<KA, KB, K>,
103 ) -> BiConstraintStream<S, A, K, E, KA, UniLeftBiFilter<F, A>, Sc>
104 where
105 A: Hash + PartialEq,
106 K: Eq + Hash + Clone + Send + Sync,
107 KA: Fn(&A) -> K + Send + Sync,
108 KB: Fn(&A) -> K + Send + Sync,
109 {
110 let (key_extractor, _) = joiner.into_keys();
111
112 // Convert uni-filter to bi-filter that applies to left entity
113 let bi_filter = UniLeftBiFilter::new(self.filter);
114
115 BiConstraintStream::new_self_join_with_filter(self.extractor, key_extractor, bi_filter)
116 }
117
118 /// Joins this stream with another collection to create cross-entity pairs (zero-erasure).
119 ///
120 /// Requires an `EqualJoiner` to enable key-based indexing for O(1) lookups.
121 /// Unlike `join_self` which pairs entities within the same collection,
122 /// `join` creates pairs from two different collections (e.g., Shift joined
123 /// with Employee).
124 ///
125 /// Any filters accumulated on this stream are applied to the A entity
126 /// before the join.
127 pub fn join<B, EB, K, KA, KB>(
128 self,
129 extractor_b: EB,
130 joiner: EqualJoiner<KA, KB, K>,
131 ) -> CrossBiConstraintStream<S, A, B, K, E, EB, KA, KB, UniLeftBiFilter<F, B>, Sc>
132 where
133 B: Clone + Send + Sync + 'static,
134 EB: Fn(&S) -> &[B] + Send + Sync,
135 K: Eq + Hash + Clone + Send + Sync,
136 KA: Fn(&A) -> K + Send + Sync,
137 KB: Fn(&B) -> K + Send + Sync,
138 {
139 let (key_a, key_b) = joiner.into_keys();
140
141 // Convert uni-filter to bi-filter that applies to left entity only
142 let bi_filter = UniLeftBiFilter::new(self.filter);
143
144 CrossBiConstraintStream::new_with_filter(
145 self.extractor,
146 extractor_b,
147 key_a,
148 key_b,
149 bi_filter,
150 )
151 }
152
153 /// Groups entities by key and aggregates with a collector.
154 ///
155 /// Returns a zero-erasure `GroupedConstraintStream` that can be penalized
156 /// or rewarded based on the aggregated result for each group.
157 pub fn group_by<K, KF, C>(
158 self,
159 key_fn: KF,
160 collector: C,
161 ) -> GroupedConstraintStream<S, A, K, E, KF, C, Sc>
162 where
163 K: Clone + Eq + Hash + Send + Sync + 'static,
164 KF: Fn(&A) -> K + Send + Sync,
165 C: UniCollector<A> + Send + Sync + 'static,
166 C::Accumulator: Send + Sync,
167 C::Result: Clone + Send + Sync,
168 {
169 GroupedConstraintStream::new(self.extractor, key_fn, collector)
170 }
171
172 /// Creates a balance constraint that penalizes uneven distribution across groups.
173 ///
174 /// Unlike `group_by` which scores each group independently, `balance` computes
175 /// a GLOBAL standard deviation across all group counts and produces a single score.
176 ///
177 /// The `key_fn` returns `Option<K>` to allow skipping entities (e.g., unassigned shifts).
178 /// Any filters accumulated on this stream are also applied.
179 ///
180 /// # Example
181 ///
182 /// ```
183 /// use solverforge_scoring::stream::ConstraintFactory;
184 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
185 /// use solverforge_core::score::SimpleScore;
186 ///
187 /// #[derive(Clone)]
188 /// struct Shift { employee_id: Option<usize> }
189 ///
190 /// #[derive(Clone)]
191 /// struct Solution { shifts: Vec<Shift> }
192 ///
193 /// let constraint = ConstraintFactory::<Solution, SimpleScore>::new()
194 /// .for_each(|s: &Solution| &s.shifts)
195 /// .balance(|shift: &Shift| shift.employee_id)
196 /// .penalize(SimpleScore::of(1000))
197 /// .as_constraint("Balance workload");
198 ///
199 /// let solution = Solution {
200 /// shifts: vec![
201 /// Shift { employee_id: Some(0) },
202 /// Shift { employee_id: Some(0) },
203 /// Shift { employee_id: Some(0) },
204 /// Shift { employee_id: Some(1) },
205 /// ],
206 /// };
207 ///
208 /// // Employee 0: 3 shifts, Employee 1: 1 shift
209 /// // std_dev = 1.0, penalty = -1000
210 /// assert_eq!(constraint.evaluate(&solution), SimpleScore::of(-1000));
211 /// ```
212 pub fn balance<K, KF>(self, key_fn: KF) -> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
213 where
214 K: Clone + Eq + Hash + Send + Sync + 'static,
215 KF: Fn(&A) -> Option<K> + Send + Sync,
216 {
217 BalanceConstraintStream::new(self.extractor, self.filter, key_fn)
218 }
219
220 /// Filters A entities based on whether a matching B entity exists.
221 ///
222 /// Use this when the B collection needs filtering (e.g., only vacationing employees).
223 /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
224 ///
225 /// Any filters accumulated on this stream are applied to A entities.
226 ///
227 /// # Example
228 ///
229 /// ```
230 /// use solverforge_scoring::stream::ConstraintFactory;
231 /// use solverforge_scoring::stream::joiner::equal_bi;
232 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
233 /// use solverforge_core::score::SimpleScore;
234 ///
235 /// #[derive(Clone)]
236 /// struct Shift { id: usize, employee_idx: Option<usize> }
237 ///
238 /// #[derive(Clone)]
239 /// struct Employee { id: usize, on_vacation: bool }
240 ///
241 /// #[derive(Clone)]
242 /// struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
243 ///
244 /// // Penalize shifts assigned to employees who are on vacation
245 /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
246 /// .for_each(|s: &Schedule| s.shifts.as_slice())
247 /// .filter(|shift: &Shift| shift.employee_idx.is_some())
248 /// .if_exists_filtered(
249 /// |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect(),
250 /// equal_bi(
251 /// |shift: &Shift| shift.employee_idx,
252 /// |emp: &Employee| Some(emp.id),
253 /// ),
254 /// )
255 /// .penalize(SimpleScore::of(1))
256 /// .as_constraint("Vacation conflict");
257 ///
258 /// let schedule = Schedule {
259 /// shifts: vec![
260 /// Shift { id: 0, employee_idx: Some(0) }, // assigned to vacationing emp
261 /// Shift { id: 1, employee_idx: Some(1) }, // assigned to working emp
262 /// Shift { id: 2, employee_idx: None }, // unassigned (filtered out)
263 /// ],
264 /// employees: vec![
265 /// Employee { id: 0, on_vacation: true },
266 /// Employee { id: 1, on_vacation: false },
267 /// ],
268 /// };
269 ///
270 /// // Only shift 0 matches (assigned to employee 0 who is on vacation)
271 /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
272 /// ```
273 pub fn if_exists_filtered<B, EB, K, KA, KB>(
274 self,
275 extractor_b: EB,
276 joiner: EqualJoiner<KA, KB, K>,
277 ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
278 where
279 B: Clone + Send + Sync + 'static,
280 EB: Fn(&S) -> Vec<B> + Send + Sync,
281 K: Eq + Hash + Clone + Send + Sync,
282 KA: Fn(&A) -> K + Send + Sync,
283 KB: Fn(&B) -> K + Send + Sync,
284 {
285 let (key_a, key_b) = joiner.into_keys();
286 IfExistsStream::new(
287 ExistenceMode::Exists,
288 self.extractor,
289 extractor_b,
290 key_a,
291 key_b,
292 self.filter,
293 )
294 }
295
296 /// Filters A entities based on whether NO matching B entity exists.
297 ///
298 /// Use this when the B collection needs filtering.
299 /// The `extractor_b` returns a `Vec<B>` to allow for filtering.
300 ///
301 /// Any filters accumulated on this stream are applied to A entities.
302 ///
303 /// # Example
304 ///
305 /// ```
306 /// use solverforge_scoring::stream::ConstraintFactory;
307 /// use solverforge_scoring::stream::joiner::equal_bi;
308 /// use solverforge_scoring::api::constraint_set::IncrementalConstraint;
309 /// use solverforge_core::score::SimpleScore;
310 ///
311 /// #[derive(Clone)]
312 /// struct Task { id: usize, assignee: Option<usize> }
313 ///
314 /// #[derive(Clone)]
315 /// struct Worker { id: usize, available: bool }
316 ///
317 /// #[derive(Clone)]
318 /// struct Schedule { tasks: Vec<Task>, workers: Vec<Worker> }
319 ///
320 /// // Penalize tasks assigned to workers who are not available
321 /// let constraint = ConstraintFactory::<Schedule, SimpleScore>::new()
322 /// .for_each(|s: &Schedule| s.tasks.as_slice())
323 /// .filter(|task: &Task| task.assignee.is_some())
324 /// .if_not_exists_filtered(
325 /// |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
326 /// equal_bi(
327 /// |task: &Task| task.assignee,
328 /// |worker: &Worker| Some(worker.id),
329 /// ),
330 /// )
331 /// .penalize(SimpleScore::of(1))
332 /// .as_constraint("Unavailable worker");
333 ///
334 /// let schedule = Schedule {
335 /// tasks: vec![
336 /// Task { id: 0, assignee: Some(0) }, // worker 0 is unavailable
337 /// Task { id: 1, assignee: Some(1) }, // worker 1 is available
338 /// Task { id: 2, assignee: None }, // unassigned (filtered out)
339 /// ],
340 /// workers: vec![
341 /// Worker { id: 0, available: false },
342 /// Worker { id: 1, available: true },
343 /// ],
344 /// };
345 ///
346 /// // Task 0's worker (id=0) is NOT in the available workers list
347 /// assert_eq!(constraint.evaluate(&schedule), SimpleScore::of(-1));
348 /// ```
349 pub fn if_not_exists_filtered<B, EB, K, KA, KB>(
350 self,
351 extractor_b: EB,
352 joiner: EqualJoiner<KA, KB, K>,
353 ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
354 where
355 B: Clone + Send + Sync + 'static,
356 EB: Fn(&S) -> Vec<B> + Send + Sync,
357 K: Eq + Hash + Clone + Send + Sync,
358 KA: Fn(&A) -> K + Send + Sync,
359 KB: Fn(&B) -> K + Send + Sync,
360 {
361 let (key_a, key_b) = joiner.into_keys();
362 IfExistsStream::new(
363 ExistenceMode::NotExists,
364 self.extractor,
365 extractor_b,
366 key_a,
367 key_b,
368 self.filter,
369 )
370 }
371
372 /// Penalizes each matching entity with a fixed weight.
373 pub fn penalize(
374 self,
375 weight: Sc,
376 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
377 where
378 Sc: Clone,
379 {
380 // Detect if this is a hard constraint by checking if hard level is non-zero
381 let is_hard = weight
382 .to_level_numbers()
383 .first()
384 .map(|&h| h != 0)
385 .unwrap_or(false);
386 UniConstraintBuilder {
387 extractor: self.extractor,
388 filter: self.filter,
389 impact_type: ImpactType::Penalty,
390 weight: move |_: &A| weight.clone(),
391 is_hard,
392 _phantom: PhantomData,
393 }
394 }
395
396 /// Penalizes each matching entity with a dynamic weight.
397 ///
398 /// Note: For dynamic weights, use `penalize_hard_with` to explicitly mark as a hard constraint,
399 /// since the weight function cannot be evaluated at build time.
400 pub fn penalize_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
401 where
402 W: Fn(&A) -> Sc + Send + Sync,
403 {
404 UniConstraintBuilder {
405 extractor: self.extractor,
406 filter: self.filter,
407 impact_type: ImpactType::Penalty,
408 weight: weight_fn,
409 is_hard: false, // Can't detect at build time; use penalize_hard_with for hard constraints
410 _phantom: PhantomData,
411 }
412 }
413
414 /// Penalizes each matching entity with a dynamic weight, explicitly marked as a hard constraint.
415 pub fn penalize_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
416 where
417 W: Fn(&A) -> Sc + Send + Sync,
418 {
419 UniConstraintBuilder {
420 extractor: self.extractor,
421 filter: self.filter,
422 impact_type: ImpactType::Penalty,
423 weight: weight_fn,
424 is_hard: true,
425 _phantom: PhantomData,
426 }
427 }
428
429 /// Rewards each matching entity with a fixed weight.
430 pub fn reward(
431 self,
432 weight: Sc,
433 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
434 where
435 Sc: Clone,
436 {
437 // Detect if this is a hard constraint by checking if hard level is non-zero
438 let is_hard = weight
439 .to_level_numbers()
440 .first()
441 .map(|&h| h != 0)
442 .unwrap_or(false);
443 UniConstraintBuilder {
444 extractor: self.extractor,
445 filter: self.filter,
446 impact_type: ImpactType::Reward,
447 weight: move |_: &A| weight.clone(),
448 is_hard,
449 _phantom: PhantomData,
450 }
451 }
452
453 /// Rewards each matching entity with a dynamic weight.
454 ///
455 /// Note: For dynamic weights, use `reward_hard_with` to explicitly mark as a hard constraint,
456 /// since the weight function cannot be evaluated at build time.
457 pub fn reward_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
458 where
459 W: Fn(&A) -> Sc + Send + Sync,
460 {
461 UniConstraintBuilder {
462 extractor: self.extractor,
463 filter: self.filter,
464 impact_type: ImpactType::Reward,
465 weight: weight_fn,
466 is_hard: false, // Can't detect at build time; use reward_hard_with for hard constraints
467 _phantom: PhantomData,
468 }
469 }
470
471 /// Rewards each matching entity with a dynamic weight, explicitly marked as a hard constraint.
472 pub fn reward_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
473 where
474 W: Fn(&A) -> Sc + Send + Sync,
475 {
476 UniConstraintBuilder {
477 extractor: self.extractor,
478 filter: self.filter,
479 impact_type: ImpactType::Reward,
480 weight: weight_fn,
481 is_hard: true,
482 _phantom: PhantomData,
483 }
484 }
485}
486
487impl<S, A, E, F, Sc: Score> std::fmt::Debug for UniConstraintStream<S, A, E, F, Sc> {
488 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
489 f.debug_struct("UniConstraintStream").finish()
490 }
491}
492
493/// Zero-erasure builder for finalizing a uni-constraint.
494pub struct UniConstraintBuilder<S, A, E, F, W, Sc>
495where
496 Sc: Score,
497{
498 extractor: E,
499 filter: F,
500 impact_type: ImpactType,
501 weight: W,
502 is_hard: bool,
503 _phantom: PhantomData<(S, A, Sc)>,
504}
505
506impl<S, A, E, F, W, Sc> UniConstraintBuilder<S, A, E, F, W, Sc>
507where
508 S: Send + Sync + 'static,
509 A: Clone + Send + Sync + 'static,
510 E: Fn(&S) -> &[A] + Send + Sync,
511 F: UniFilter<A>,
512 W: Fn(&A) -> Sc + Send + Sync,
513 Sc: Score + 'static,
514{
515 /// Finalizes the builder into a zero-erasure `IncrementalUniConstraint`.
516 pub fn as_constraint(
517 self,
518 name: &str,
519 ) -> IncrementalUniConstraint<S, A, E, impl Fn(&A) -> bool + Send + Sync, W, Sc> {
520 let filter = self.filter;
521 let combined_filter = move |a: &A| filter.test(a);
522
523 IncrementalUniConstraint::new(
524 ConstraintRef::new("", name),
525 self.impact_type,
526 self.extractor,
527 combined_filter,
528 self.weight,
529 self.is_hard,
530 )
531 }
532}
533
534impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for UniConstraintBuilder<S, A, E, F, W, Sc> {
535 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
536 f.debug_struct("UniConstraintBuilder")
537 .field("impact_type", &self.impact_type)
538 .finish()
539 }
540}