solverforge_scoring/stream/uni_stream.rs
1// Zero-erasure uni-constraint stream for single-entity constraint patterns.
2//
3// A `UniConstraintStream` operates on a single entity type and supports
4// filtering, weighting, and constraint finalization. All type information
5// is preserved at compile time - no Arc, no dyn, fully monomorphized.
6
7use std::hash::Hash;
8use std::marker::PhantomData;
9
10use solverforge_core::score::Score;
11use solverforge_core::{ConstraintRef, ImpactType};
12
13use crate::constraint::incremental::IncrementalUniConstraint;
14
15use crate::constraint::if_exists::ExistenceMode;
16
17use super::balance_stream::BalanceConstraintStream;
18use super::collector::UniCollector;
19use super::filter::{AndUniFilter, FnUniFilter, TrueFilter, UniFilter};
20use super::grouped_stream::GroupedConstraintStream;
21use super::if_exists_stream::IfExistsStream;
22use super::joiner::EqualJoiner;
23
24// Zero-erasure constraint stream over a single entity type.
25//
26// `UniConstraintStream` accumulates filters and can be finalized into
27// an `IncrementalUniConstraint` via `penalize()` or `reward()`.
28//
29// All type parameters are concrete - no trait objects, no Arc allocations
30// in the hot path.
31//
32// # Type Parameters
33//
34// - `S` - Solution type
35// - `A` - Entity type
36// - `E` - Extractor function type
37// - `F` - Combined filter type
38// - `Sc` - Score type
39pub struct UniConstraintStream<S, A, E, F, Sc>
40where
41 Sc: Score,
42{
43 extractor: E,
44 filter: F,
45 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
46}
47
48impl<S, A, E, Sc> UniConstraintStream<S, A, E, TrueFilter, Sc>
49where
50 S: Send + Sync + 'static,
51 A: Clone + Send + Sync + 'static,
52 E: Fn(&S) -> &[A] + Send + Sync,
53 Sc: Score + 'static,
54{
55 // Creates a new uni-constraint stream with the given extractor.
56 pub fn new(extractor: E) -> Self {
57 Self {
58 extractor,
59 filter: TrueFilter,
60 _phantom: PhantomData,
61 }
62 }
63}
64
65impl<S, A, E, F, Sc> UniConstraintStream<S, A, E, F, Sc>
66where
67 S: Send + Sync + 'static,
68 A: Clone + Send + Sync + 'static,
69 E: Fn(&S) -> &[A] + Send + Sync,
70 F: UniFilter<S, A>,
71 Sc: Score + 'static,
72{
73 // Adds a filter predicate to the stream.
74 //
75 // Multiple filters are combined with AND semantics at compile time.
76 // Each filter adds a new type layer, preserving zero-erasure.
77 //
78 // To access related entities, use shadow variables on your entity type
79 // (e.g., `#[inverse_relation_shadow_variable]`) rather than solution traversal.
80 pub fn filter<P>(
81 self,
82 predicate: P,
83 ) -> UniConstraintStream<
84 S,
85 A,
86 E,
87 AndUniFilter<F, FnUniFilter<impl Fn(&S, &A) -> bool + Send + Sync>>,
88 Sc,
89 >
90 where
91 P: Fn(&A) -> bool + Send + Sync + 'static,
92 {
93 UniConstraintStream {
94 extractor: self.extractor,
95 filter: AndUniFilter::new(
96 self.filter,
97 FnUniFilter::new(move |_s: &S, a: &A| predicate(a)),
98 ),
99 _phantom: PhantomData,
100 }
101 }
102
103 // Joins this stream using the provided join target.
104 //
105 // Dispatch depends on argument type:
106 // - `equal(|a: &A| key_fn(a))` — self-join, returns `BiConstraintStream`
107 // - `(extractor_b, equal_bi(ka, kb))` — keyed cross-join, returns `CrossBiConstraintStream`
108 // - `(other_stream, |a, b| predicate)` — predicate cross-join O(n*m)
109 pub fn join<J>(self, target: J) -> J::Output
110 where
111 J: super::join_target::JoinTarget<S, A, E, F, Sc>,
112 {
113 target.apply(self.extractor, self.filter)
114 }
115
116 // Groups entities by key and aggregates with a collector.
117 //
118 // Returns a zero-erasure `GroupedConstraintStream` that can be penalized
119 // or rewarded based on the aggregated result for each group.
120 pub fn group_by<K, KF, C>(
121 self,
122 key_fn: KF,
123 collector: C,
124 ) -> GroupedConstraintStream<S, A, K, E, F, KF, C, Sc>
125 where
126 K: Clone + Eq + Hash + Send + Sync + 'static,
127 KF: Fn(&A) -> K + Send + Sync,
128 C: UniCollector<A> + Send + Sync + 'static,
129 C::Accumulator: Send + Sync,
130 C::Result: Clone + Send + Sync,
131 {
132 GroupedConstraintStream::new(self.extractor, self.filter, key_fn, collector)
133 }
134
135 // Creates a balance constraint that penalizes uneven distribution across groups.
136 //
137 // Unlike `group_by` which scores each group independently, `balance` computes
138 // a GLOBAL standard deviation across all group counts and produces a single score.
139 //
140 // The `key_fn` returns `Option<K>` to allow skipping entities (e.g., unassigned shifts).
141 // Any filters accumulated on this stream are also applied.
142 //
143 // # Example
144 //
145 // ```
146 // use solverforge_scoring::stream::ConstraintFactory;
147 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
148 // use solverforge_core::score::SoftScore;
149 //
150 // #[derive(Clone)]
151 // struct Shift { employee_id: Option<usize> }
152 //
153 // #[derive(Clone)]
154 // struct Solution { shifts: Vec<Shift> }
155 //
156 // let constraint = ConstraintFactory::<Solution, SoftScore>::new()
157 // .for_each(|s: &Solution| &s.shifts)
158 // .balance(|shift: &Shift| shift.employee_id)
159 // .penalize(SoftScore::of(1000))
160 // .named("Balance workload");
161 //
162 // let solution = Solution {
163 // shifts: vec![
164 // Shift { employee_id: Some(0) },
165 // Shift { employee_id: Some(0) },
166 // Shift { employee_id: Some(0) },
167 // Shift { employee_id: Some(1) },
168 // ],
169 // };
170 //
171 // // Employee 0: 3 shifts, Employee 1: 1 shift
172 // // std_dev = 1.0, penalty = -1000
173 // assert_eq!(constraint.evaluate(&solution), SoftScore::of(-1000));
174 // ```
175 pub fn balance<K, KF>(self, key_fn: KF) -> BalanceConstraintStream<S, A, K, E, F, KF, Sc>
176 where
177 K: Clone + Eq + Hash + Send + Sync + 'static,
178 KF: Fn(&A) -> Option<K> + Send + Sync,
179 {
180 BalanceConstraintStream::new(self.extractor, self.filter, key_fn)
181 }
182
183 // Filters A entities based on whether a matching B entity exists.
184 //
185 // Use this when the B collection needs filtering (e.g., only vacationing employees).
186 // The `extractor_b` returns a `Vec<B>` to allow for filtering.
187 //
188 // Any filters accumulated on this stream are applied to A entities.
189 //
190 // # Example
191 //
192 // ```
193 // use solverforge_scoring::stream::ConstraintFactory;
194 // use solverforge_scoring::stream::joiner::equal_bi;
195 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
196 // use solverforge_core::score::SoftScore;
197 //
198 // #[derive(Clone)]
199 // struct Shift { id: usize, employee_idx: Option<usize> }
200 //
201 // #[derive(Clone)]
202 // struct Employee { id: usize, on_vacation: bool }
203 //
204 // #[derive(Clone)]
205 // struct Schedule { shifts: Vec<Shift>, employees: Vec<Employee> }
206 //
207 // // Penalize shifts assigned to employees who are on vacation
208 // let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
209 // .for_each(|s: &Schedule| s.shifts.as_slice())
210 // .filter(|shift: &Shift| shift.employee_idx.is_some())
211 // .if_exists_filtered(
212 // |s: &Schedule| s.employees.iter().filter(|e| e.on_vacation).cloned().collect(),
213 // equal_bi(
214 // |shift: &Shift| shift.employee_idx,
215 // |emp: &Employee| Some(emp.id),
216 // ),
217 // )
218 // .penalize(SoftScore::of(1))
219 // .named("Vacation conflict");
220 //
221 // let schedule = Schedule {
222 // shifts: vec![
223 // Shift { id: 0, employee_idx: Some(0) }, // assigned to vacationing emp
224 // Shift { id: 1, employee_idx: Some(1) }, // assigned to working emp
225 // Shift { id: 2, employee_idx: None }, // unassigned (filtered out)
226 // ],
227 // employees: vec![
228 // Employee { id: 0, on_vacation: true },
229 // Employee { id: 1, on_vacation: false },
230 // ],
231 // };
232 //
233 // // Only shift 0 matches (assigned to employee 0 who is on vacation)
234 // assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-1));
235 // ```
236 pub fn if_exists_filtered<B, EB, K, KA, KB>(
237 self,
238 extractor_b: EB,
239 joiner: EqualJoiner<KA, KB, K>,
240 ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
241 where
242 B: Clone + Send + Sync + 'static,
243 EB: Fn(&S) -> Vec<B> + Send + Sync,
244 K: Eq + Hash + Clone + Send + Sync,
245 KA: Fn(&A) -> K + Send + Sync,
246 KB: Fn(&B) -> K + Send + Sync,
247 {
248 let (key_a, key_b) = joiner.into_keys();
249 IfExistsStream::new(
250 ExistenceMode::Exists,
251 self.extractor,
252 extractor_b,
253 key_a,
254 key_b,
255 self.filter,
256 )
257 }
258
259 // Filters A entities based on whether NO matching B entity exists.
260 //
261 // Use this when the B collection needs filtering.
262 // The `extractor_b` returns a `Vec<B>` to allow for filtering.
263 //
264 // Any filters accumulated on this stream are applied to A entities.
265 //
266 // # Example
267 //
268 // ```
269 // use solverforge_scoring::stream::ConstraintFactory;
270 // use solverforge_scoring::stream::joiner::equal_bi;
271 // use solverforge_scoring::api::constraint_set::IncrementalConstraint;
272 // use solverforge_core::score::SoftScore;
273 //
274 // #[derive(Clone)]
275 // struct Task { id: usize, assignee: Option<usize> }
276 //
277 // #[derive(Clone)]
278 // struct Worker { id: usize, available: bool }
279 //
280 // #[derive(Clone)]
281 // struct Schedule { tasks: Vec<Task>, workers: Vec<Worker> }
282 //
283 // // Penalize tasks assigned to workers who are not available
284 // let constraint = ConstraintFactory::<Schedule, SoftScore>::new()
285 // .for_each(|s: &Schedule| s.tasks.as_slice())
286 // .filter(|task: &Task| task.assignee.is_some())
287 // .if_not_exists_filtered(
288 // |s: &Schedule| s.workers.iter().filter(|w| w.available).cloned().collect(),
289 // equal_bi(
290 // |task: &Task| task.assignee,
291 // |worker: &Worker| Some(worker.id),
292 // ),
293 // )
294 // .penalize(SoftScore::of(1))
295 // .named("Unavailable worker");
296 //
297 // let schedule = Schedule {
298 // tasks: vec![
299 // Task { id: 0, assignee: Some(0) }, // worker 0 is unavailable
300 // Task { id: 1, assignee: Some(1) }, // worker 1 is available
301 // Task { id: 2, assignee: None }, // unassigned (filtered out)
302 // ],
303 // workers: vec![
304 // Worker { id: 0, available: false },
305 // Worker { id: 1, available: true },
306 // ],
307 // };
308 //
309 // // Task 0's worker (id=0) is NOT in the available workers list
310 // assert_eq!(constraint.evaluate(&schedule), SoftScore::of(-1));
311 // ```
312 pub fn if_not_exists_filtered<B, EB, K, KA, KB>(
313 self,
314 extractor_b: EB,
315 joiner: EqualJoiner<KA, KB, K>,
316 ) -> IfExistsStream<S, A, B, K, E, EB, KA, KB, F, Sc>
317 where
318 B: Clone + Send + Sync + 'static,
319 EB: Fn(&S) -> Vec<B> + Send + Sync,
320 K: Eq + Hash + Clone + Send + Sync,
321 KA: Fn(&A) -> K + Send + Sync,
322 KB: Fn(&B) -> K + Send + Sync,
323 {
324 let (key_a, key_b) = joiner.into_keys();
325 IfExistsStream::new(
326 ExistenceMode::NotExists,
327 self.extractor,
328 extractor_b,
329 key_a,
330 key_b,
331 self.filter,
332 )
333 }
334
335 // Penalizes each matching entity with a fixed weight.
336 pub fn penalize(
337 self,
338 weight: Sc,
339 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
340 where
341 Sc: Copy,
342 {
343 // Detect if this is a hard constraint by checking if hard level is non-zero
344 let is_hard = weight
345 .to_level_numbers()
346 .first()
347 .map(|&h| h != 0)
348 .unwrap_or(false);
349 UniConstraintBuilder {
350 extractor: self.extractor,
351 filter: self.filter,
352 impact_type: ImpactType::Penalty,
353 weight: move |_: &A| weight,
354 is_hard,
355 expected_descriptor: None,
356 _phantom: PhantomData,
357 }
358 }
359
360 // Penalizes each matching entity with a dynamic weight.
361 //
362 // Note: For dynamic weights, use `penalize_hard_with` to explicitly mark as a hard constraint,
363 // since the weight function cannot be evaluated at build time.
364 pub fn penalize_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
365 where
366 W: Fn(&A) -> Sc + Send + Sync,
367 {
368 UniConstraintBuilder {
369 extractor: self.extractor,
370 filter: self.filter,
371 impact_type: ImpactType::Penalty,
372 weight: weight_fn,
373 is_hard: false, // Can't detect at build time; use penalize_hard_with for hard constraints
374 expected_descriptor: None,
375 _phantom: PhantomData,
376 }
377 }
378
379 // Penalizes each matching entity with a dynamic weight, explicitly marked as a hard constraint.
380 pub fn penalize_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
381 where
382 W: Fn(&A) -> Sc + Send + Sync,
383 {
384 UniConstraintBuilder {
385 extractor: self.extractor,
386 filter: self.filter,
387 impact_type: ImpactType::Penalty,
388 weight: weight_fn,
389 is_hard: true,
390 expected_descriptor: None,
391 _phantom: PhantomData,
392 }
393 }
394
395 // Rewards each matching entity with a fixed weight.
396 pub fn reward(
397 self,
398 weight: Sc,
399 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
400 where
401 Sc: Copy,
402 {
403 // Detect if this is a hard constraint by checking if hard level is non-zero
404 let is_hard = weight
405 .to_level_numbers()
406 .first()
407 .map(|&h| h != 0)
408 .unwrap_or(false);
409 UniConstraintBuilder {
410 extractor: self.extractor,
411 filter: self.filter,
412 impact_type: ImpactType::Reward,
413 weight: move |_: &A| weight,
414 is_hard,
415 expected_descriptor: None,
416 _phantom: PhantomData,
417 }
418 }
419
420 // Rewards each matching entity with a dynamic weight.
421 //
422 // Note: For dynamic weights, use `reward_hard_with` to explicitly mark as a hard constraint,
423 // since the weight function cannot be evaluated at build time.
424 pub fn reward_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
425 where
426 W: Fn(&A) -> Sc + Send + Sync,
427 {
428 UniConstraintBuilder {
429 extractor: self.extractor,
430 filter: self.filter,
431 impact_type: ImpactType::Reward,
432 weight: weight_fn,
433 is_hard: false, // Can't detect at build time; use reward_hard_with for hard constraints
434 expected_descriptor: None,
435 _phantom: PhantomData,
436 }
437 }
438
439 // Rewards each matching entity with a dynamic weight, explicitly marked as a hard constraint.
440 pub fn reward_hard_with<W>(self, weight_fn: W) -> UniConstraintBuilder<S, A, E, F, W, Sc>
441 where
442 W: Fn(&A) -> Sc + Send + Sync,
443 {
444 UniConstraintBuilder {
445 extractor: self.extractor,
446 filter: self.filter,
447 impact_type: ImpactType::Reward,
448 weight: weight_fn,
449 is_hard: true,
450 expected_descriptor: None,
451 _phantom: PhantomData,
452 }
453 }
454
455 // Penalizes each matching entity with one hard score unit.
456 pub fn penalize_hard(
457 self,
458 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
459 where
460 Sc: Copy,
461 {
462 self.penalize(Sc::one_hard())
463 }
464
465 // Penalizes each matching entity with one soft score unit.
466 pub fn penalize_soft(
467 self,
468 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
469 where
470 Sc: Copy,
471 {
472 self.penalize(Sc::one_soft())
473 }
474
475 // Rewards each matching entity with one hard score unit.
476 pub fn reward_hard(
477 self,
478 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
479 where
480 Sc: Copy,
481 {
482 self.reward(Sc::one_hard())
483 }
484
485 // Rewards each matching entity with one soft score unit.
486 pub fn reward_soft(
487 self,
488 ) -> UniConstraintBuilder<S, A, E, F, impl Fn(&A) -> Sc + Send + Sync, Sc>
489 where
490 Sc: Copy,
491 {
492 self.reward(Sc::one_soft())
493 }
494}
495
496impl<S, A, E, F, Sc: Score> UniConstraintStream<S, A, E, F, Sc> {
497 #[doc(hidden)]
498 pub fn extractor(&self) -> &E {
499 &self.extractor
500 }
501
502 #[doc(hidden)]
503 pub fn into_parts(self) -> (E, F) {
504 (self.extractor, self.filter)
505 }
506
507 #[doc(hidden)]
508 pub fn from_parts(extractor: E, filter: F) -> Self {
509 Self {
510 extractor,
511 filter,
512 _phantom: PhantomData,
513 }
514 }
515}
516
517impl<S, A, E, F, Sc: Score> std::fmt::Debug for UniConstraintStream<S, A, E, F, Sc> {
518 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
519 f.debug_struct("UniConstraintStream").finish()
520 }
521}
522
523// Zero-erasure builder for finalizing a uni-constraint.
524pub struct UniConstraintBuilder<S, A, E, F, W, Sc>
525where
526 Sc: Score,
527{
528 extractor: E,
529 filter: F,
530 impact_type: ImpactType,
531 weight: W,
532 is_hard: bool,
533 expected_descriptor: Option<usize>,
534 _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
535}
536
537impl<S, A, E, F, W, Sc> UniConstraintBuilder<S, A, E, F, W, Sc>
538where
539 S: Send + Sync + 'static,
540 A: Clone + Send + Sync + 'static,
541 E: Fn(&S) -> &[A] + Send + Sync,
542 F: UniFilter<S, A>,
543 W: Fn(&A) -> Sc + Send + Sync,
544 Sc: Score + 'static,
545{
546 // Restricts this constraint to only fire for the given descriptor index.
547 //
548 // Required when multiple entity classes exist (e.g., FurnaceAssignment at 0,
549 // ShiftAssignment at 1). Without this, on_insert/on_retract fire for all entity
550 // classes using the constraint's entity_index, which indexes into the wrong slice.
551 pub fn for_descriptor(mut self, descriptor_index: usize) -> Self {
552 self.expected_descriptor = Some(descriptor_index);
553 self
554 }
555
556 // Finalizes the builder into a zero-erasure `IncrementalUniConstraint`.
557 pub fn named(
558 self,
559 name: &str,
560 ) -> IncrementalUniConstraint<S, A, E, impl Fn(&S, &A) -> bool + Send + Sync, W, Sc> {
561 let filter = self.filter;
562 let combined_filter = move |s: &S, a: &A| filter.test(s, a);
563
564 let mut constraint = IncrementalUniConstraint::new(
565 ConstraintRef::new("", name),
566 self.impact_type,
567 self.extractor,
568 combined_filter,
569 self.weight,
570 self.is_hard,
571 );
572 if let Some(d) = self.expected_descriptor {
573 constraint = constraint.with_descriptor(d);
574 }
575 constraint
576 }
577}
578
579impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for UniConstraintBuilder<S, A, E, F, W, Sc> {
580 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581 f.debug_struct("UniConstraintBuilder")
582 .field("impact_type", &self.impact_type)
583 .finish()
584 }
585}