Skip to main content

solverforge_scoring/constraint/
incremental.rs

1// Zero-erasure incremental uni-constraint.
2//
3// All closure types are concrete generics - no Arc, no dyn, fully monomorphized.
4
5use std::fmt::Debug;
6use std::marker::PhantomData;
7
8use solverforge_core::score::Score;
9use solverforge_core::{ConstraintRef, ImpactType};
10
11use crate::api::analysis::{ConstraintJustification, DetailedConstraintMatch, EntityRef};
12use crate::api::constraint_set::IncrementalConstraint;
13
14// Zero-erasure incremental uni-constraint.
15//
16// All closure types are concrete generics - no Arc, no dyn, fully monomorphized.
17pub struct IncrementalUniConstraint<S, A, E, F, W, Sc>
18where
19    Sc: Score,
20{
21    constraint_ref: ConstraintRef,
22    impact_type: ImpactType,
23    extractor: E,
24    filter: F,
25    weight: W,
26    is_hard: bool,
27    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
28}
29
30impl<S, A, E, F, W, Sc> IncrementalUniConstraint<S, A, E, F, W, Sc>
31where
32    S: Send + Sync + 'static,
33    A: Clone + Send + Sync + 'static,
34    E: Fn(&S) -> &[A] + Send + Sync,
35    F: Fn(&S, &A) -> bool + Send + Sync,
36    W: Fn(&A) -> Sc + Send + Sync,
37    Sc: Score,
38{
39    // Creates a new zero-erasure incremental uni-constraint.
40    pub fn new(
41        constraint_ref: ConstraintRef,
42        impact_type: ImpactType,
43        extractor: E,
44        filter: F,
45        weight: W,
46        is_hard: bool,
47    ) -> Self {
48        Self {
49            constraint_ref,
50            impact_type,
51            extractor,
52            filter,
53            weight,
54            is_hard,
55            _phantom: PhantomData,
56        }
57    }
58
59    #[inline]
60    fn matches(&self, solution: &S, entity: &A) -> bool {
61        (self.filter)(solution, entity)
62    }
63
64    #[inline]
65    fn compute_delta(&self, entity: &A) -> Sc {
66        let base = (self.weight)(entity);
67        match self.impact_type {
68            ImpactType::Penalty => -base,
69            ImpactType::Reward => base,
70        }
71    }
72
73    #[inline]
74    fn reverse_delta(&self, entity: &A) -> Sc {
75        let base = (self.weight)(entity);
76        match self.impact_type {
77            ImpactType::Penalty => base,
78            ImpactType::Reward => -base,
79        }
80    }
81}
82
83impl<S, A, E, F, W, Sc> IncrementalConstraint<S, Sc> for IncrementalUniConstraint<S, A, E, F, W, Sc>
84where
85    S: Send + Sync + 'static,
86    A: Clone + Debug + Send + Sync + 'static,
87    E: Fn(&S) -> &[A] + Send + Sync,
88    F: Fn(&S, &A) -> bool + Send + Sync,
89    W: Fn(&A) -> Sc + Send + Sync,
90    Sc: Score,
91{
92    fn evaluate(&self, solution: &S) -> Sc {
93        let entities = (self.extractor)(solution);
94        let mut total = Sc::zero();
95        for entity in entities {
96            if self.matches(solution, entity) {
97                total = total + self.compute_delta(entity);
98            }
99        }
100        total
101    }
102
103    fn match_count(&self, solution: &S) -> usize {
104        let entities = (self.extractor)(solution);
105        entities
106            .iter()
107            .filter(|e| self.matches(solution, e))
108            .count()
109    }
110
111    fn initialize(&mut self, solution: &S) -> Sc {
112        self.evaluate(solution)
113    }
114
115    fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
116        let entities = (self.extractor)(solution);
117        if entity_index >= entities.len() {
118            return Sc::zero();
119        }
120        let entity = &entities[entity_index];
121        if self.matches(solution, entity) {
122            self.compute_delta(entity)
123        } else {
124            Sc::zero()
125        }
126    }
127
128    fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
129        let entities = (self.extractor)(solution);
130        if entity_index >= entities.len() {
131            return Sc::zero();
132        }
133        let entity = &entities[entity_index];
134        if self.matches(solution, entity) {
135            self.reverse_delta(entity)
136        } else {
137            Sc::zero()
138        }
139    }
140
141    fn reset(&mut self) {
142        // Stateless
143    }
144
145    fn name(&self) -> &str {
146        &self.constraint_ref.name
147    }
148
149    fn is_hard(&self) -> bool {
150        self.is_hard
151    }
152
153    fn constraint_ref(&self) -> ConstraintRef {
154        self.constraint_ref.clone()
155    }
156
157    fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
158        let entities = (self.extractor)(solution);
159        let cref = self.constraint_ref.clone();
160        entities
161            .iter()
162            .filter(|e| self.matches(solution, e))
163            .map(|entity| {
164                let entity_ref = EntityRef::new(entity);
165                let justification = ConstraintJustification::new(vec![entity_ref]);
166                DetailedConstraintMatch::new(
167                    cref.clone(),
168                    self.compute_delta(entity),
169                    justification,
170                )
171            })
172            .collect()
173    }
174
175    fn weight(&self) -> Sc {
176        // For uni-constraints, we use a unit entity to compute the base weight.
177        // This works for constant weights; for dynamic weights, returns zero.
178        Sc::zero()
179    }
180}
181
182impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for IncrementalUniConstraint<S, A, E, F, W, Sc> {
183    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184        f.debug_struct("IncrementalUniConstraint")
185            .field("name", &self.constraint_ref.name)
186            .field("impact_type", &self.impact_type)
187            .finish()
188    }
189}