Skip to main content

solverforge_scoring/constraint/
incremental.rs

// Zero-erasure incremental uni-constraint.
//
// All closure types are concrete generics - no Arc, no dyn, fully monomorphized.
4
5use std::fmt::Debug;
6use std::marker::PhantomData;
7
8use solverforge_core::score::Score;
9use solverforge_core::{ConstraintRef, ImpactType};
10
11use crate::api::analysis::{ConstraintJustification, DetailedConstraintMatch, EntityRef};
12use crate::api::constraint_set::IncrementalConstraint;
13
14// Zero-erasure incremental uni-constraint.
15//
16// All closure types are concrete generics - no Arc, no dyn, fully monomorphized.
17pub struct IncrementalUniConstraint<S, A, E, F, W, Sc>
18where
19    Sc: Score,
20{
21    constraint_ref: ConstraintRef,
22    impact_type: ImpactType,
23    extractor: E,
24    filter: F,
25    weight: W,
26    is_hard: bool,
27    expected_descriptor: Option<usize>,
28    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
29}
30
31impl<S, A, E, F, W, Sc> IncrementalUniConstraint<S, A, E, F, W, Sc>
32where
33    S: Send + Sync + 'static,
34    A: Clone + Send + Sync + 'static,
35    E: Fn(&S) -> &[A] + Send + Sync,
36    F: Fn(&S, &A) -> bool + Send + Sync,
37    W: Fn(&A) -> Sc + Send + Sync,
38    Sc: Score,
39{
40    // Creates a new zero-erasure incremental uni-constraint.
41    pub fn new(
42        constraint_ref: ConstraintRef,
43        impact_type: ImpactType,
44        extractor: E,
45        filter: F,
46        weight: W,
47        is_hard: bool,
48    ) -> Self {
49        Self {
50            constraint_ref,
51            impact_type,
52            extractor,
53            filter,
54            weight,
55            is_hard,
56            expected_descriptor: None,
57            _phantom: PhantomData,
58        }
59    }
60
61    pub fn with_descriptor(mut self, descriptor_index: usize) -> Self {
62        self.expected_descriptor = Some(descriptor_index);
63        self
64    }
65
66    #[inline]
67    fn matches(&self, solution: &S, entity: &A) -> bool {
68        (self.filter)(solution, entity)
69    }
70
71    #[inline]
72    fn compute_delta(&self, entity: &A) -> Sc {
73        let base = (self.weight)(entity);
74        match self.impact_type {
75            ImpactType::Penalty => -base,
76            ImpactType::Reward => base,
77        }
78    }
79
80    #[inline]
81    fn reverse_delta(&self, entity: &A) -> Sc {
82        let base = (self.weight)(entity);
83        match self.impact_type {
84            ImpactType::Penalty => base,
85            ImpactType::Reward => -base,
86        }
87    }
88}
89
90impl<S, A, E, F, W, Sc> IncrementalConstraint<S, Sc> for IncrementalUniConstraint<S, A, E, F, W, Sc>
91where
92    S: Send + Sync + 'static,
93    A: Clone + Debug + Send + Sync + 'static,
94    E: Fn(&S) -> &[A] + Send + Sync,
95    F: Fn(&S, &A) -> bool + Send + Sync,
96    W: Fn(&A) -> Sc + Send + Sync,
97    Sc: Score,
98{
99    fn evaluate(&self, solution: &S) -> Sc {
100        let entities = (self.extractor)(solution);
101        let mut total = Sc::zero();
102        for entity in entities {
103            if self.matches(solution, entity) {
104                total = total + self.compute_delta(entity);
105            }
106        }
107        total
108    }
109
110    fn match_count(&self, solution: &S) -> usize {
111        let entities = (self.extractor)(solution);
112        entities
113            .iter()
114            .filter(|e| self.matches(solution, e))
115            .count()
116    }
117
118    fn initialize(&mut self, solution: &S) -> Sc {
119        self.evaluate(solution)
120    }
121
122    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
123        if let Some(expected) = self.expected_descriptor {
124            if descriptor_index != expected {
125                return Sc::zero();
126            }
127        }
128        let entities = (self.extractor)(solution);
129        if entity_index >= entities.len() {
130            return Sc::zero();
131        }
132        let entity = &entities[entity_index];
133        if self.matches(solution, entity) {
134            self.compute_delta(entity)
135        } else {
136            Sc::zero()
137        }
138    }
139
140    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
141        if let Some(expected) = self.expected_descriptor {
142            if descriptor_index != expected {
143                return Sc::zero();
144            }
145        }
146        let entities = (self.extractor)(solution);
147        if entity_index >= entities.len() {
148            return Sc::zero();
149        }
150        let entity = &entities[entity_index];
151        if self.matches(solution, entity) {
152            self.reverse_delta(entity)
153        } else {
154            Sc::zero()
155        }
156    }
157
158    fn reset(&mut self) {
159        // Stateless
160    }
161
162    fn name(&self) -> &str {
163        &self.constraint_ref.name
164    }
165
166    fn is_hard(&self) -> bool {
167        self.is_hard
168    }
169
170    fn constraint_ref(&self) -> ConstraintRef {
171        self.constraint_ref.clone()
172    }
173
174    fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
175        let entities = (self.extractor)(solution);
176        let cref = self.constraint_ref.clone();
177        entities
178            .iter()
179            .filter(|e| self.matches(solution, e))
180            .map(|entity| {
181                let entity_ref = EntityRef::new(entity);
182                let justification = ConstraintJustification::new(vec![entity_ref]);
183                DetailedConstraintMatch::new(
184                    cref.clone(),
185                    self.compute_delta(entity),
186                    justification,
187                )
188            })
189            .collect()
190    }
191
192    fn weight(&self) -> Sc {
193        // For uni-constraints, we use a unit entity to compute the base weight.
194        // This works for constant weights; for dynamic weights, returns zero.
195        Sc::zero()
196    }
197}
198
199impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for IncrementalUniConstraint<S, A, E, F, W, Sc> {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        f.debug_struct("IncrementalUniConstraint")
202            .field("name", &self.constraint_ref.name)
203            .field("impact_type", &self.impact_type)
204            .finish()
205    }
206}