Skip to main content

solverforge_scoring/constraint/
incremental.rs

//! Zero-erasure incremental uni-constraint.
//!
//! All closure types are concrete generics - no Arc, no dyn, fully monomorphized.

use std::fmt::Debug;
use std::marker::PhantomData;

use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};

use crate::api::analysis::{ConstraintJustification, DetailedConstraintMatch, EntityRef};
use crate::api::constraint_set::IncrementalConstraint;

15/* Zero-erasure incremental uni-constraint.
16
17All closure types are concrete generics - no Arc, no dyn, fully monomorphized.
18*/
19pub struct IncrementalUniConstraint<S, A, E, F, W, Sc>
20where
21    Sc: Score,
22{
23    constraint_ref: ConstraintRef,
24    impact_type: ImpactType,
25    extractor: E,
26    filter: F,
27    weight: W,
28    is_hard: bool,
29    change_source: crate::stream::collection_extract::ChangeSource,
30    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
31}
32
33impl<S, A, E, F, W, Sc> IncrementalUniConstraint<S, A, E, F, W, Sc>
34where
35    S: Send + Sync + 'static,
36    A: Clone + Send + Sync + 'static,
37    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
38    F: Fn(&S, &A) -> bool + Send + Sync,
39    W: Fn(&A) -> Sc + Send + Sync,
40    Sc: Score,
41{
42    // Creates a new zero-erasure incremental uni-constraint.
43    pub fn new(
44        constraint_ref: ConstraintRef,
45        impact_type: ImpactType,
46        extractor: E,
47        filter: F,
48        weight: W,
49        is_hard: bool,
50    ) -> Self {
51        let change_source = extractor.change_source();
52        Self {
53            constraint_ref,
54            impact_type,
55            extractor,
56            filter,
57            weight,
58            is_hard,
59            change_source,
60            _phantom: PhantomData,
61        }
62    }
63
64    #[inline]
65    fn matches(&self, solution: &S, entity: &A) -> bool {
66        (self.filter)(solution, entity)
67    }
68
69    #[inline]
70    fn compute_delta(&self, entity: &A) -> Sc {
71        let base = (self.weight)(entity);
72        match self.impact_type {
73            ImpactType::Penalty => -base,
74            ImpactType::Reward => base,
75        }
76    }
77
78    #[inline]
79    fn reverse_delta(&self, entity: &A) -> Sc {
80        let base = (self.weight)(entity);
81        match self.impact_type {
82            ImpactType::Penalty => base,
83            ImpactType::Reward => -base,
84        }
85    }
86}
87
88impl<S, A, E, F, W, Sc> IncrementalConstraint<S, Sc> for IncrementalUniConstraint<S, A, E, F, W, Sc>
89where
90    S: Send + Sync + 'static,
91    A: Clone + Debug + Send + Sync + 'static,
92    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
93    F: Fn(&S, &A) -> bool + Send + Sync,
94    W: Fn(&A) -> Sc + Send + Sync,
95    Sc: Score,
96{
97    fn evaluate(&self, solution: &S) -> Sc {
98        let entities = self.extractor.extract(solution);
99        let mut total = Sc::zero();
100        for entity in entities {
101            if self.matches(solution, entity) {
102                total = total + self.compute_delta(entity);
103            }
104        }
105        total
106    }
107
108    fn match_count(&self, solution: &S) -> usize {
109        let entities = self.extractor.extract(solution);
110        entities
111            .iter()
112            .filter(|e| self.matches(solution, e))
113            .count()
114    }
115
116    fn initialize(&mut self, solution: &S) -> Sc {
117        self.evaluate(solution)
118    }
119
120    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
121        if !self
122            .change_source
123            .assert_localizes(descriptor_index, &self.constraint_ref.name)
124        {
125            return Sc::zero();
126        }
127        let entities = self.extractor.extract(solution);
128        if entity_index >= entities.len() {
129            return Sc::zero();
130        }
131        let entity = &entities[entity_index];
132        if self.matches(solution, entity) {
133            self.compute_delta(entity)
134        } else {
135            Sc::zero()
136        }
137    }
138
139    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
140        if !self
141            .change_source
142            .assert_localizes(descriptor_index, &self.constraint_ref.name)
143        {
144            return Sc::zero();
145        }
146        let entities = self.extractor.extract(solution);
147        if entity_index >= entities.len() {
148            return Sc::zero();
149        }
150        let entity = &entities[entity_index];
151        if self.matches(solution, entity) {
152            self.reverse_delta(entity)
153        } else {
154            Sc::zero()
155        }
156    }
157
158    fn reset(&mut self) {
159        // Stateless
160    }
161
162    fn name(&self) -> &str {
163        &self.constraint_ref.name
164    }
165
166    fn is_hard(&self) -> bool {
167        self.is_hard
168    }
169
170    fn constraint_ref(&self) -> ConstraintRef {
171        self.constraint_ref.clone()
172    }
173
174    fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
175        let entities = self.extractor.extract(solution);
176        let cref = self.constraint_ref.clone();
177        entities
178            .iter()
179            .filter(|e| self.matches(solution, e))
180            .map(|entity| {
181                let entity_ref = EntityRef::new(entity);
182                let justification = ConstraintJustification::new(vec![entity_ref]);
183                DetailedConstraintMatch::new(
184                    cref.clone(),
185                    self.compute_delta(entity),
186                    justification,
187                )
188            })
189            .collect()
190    }
191
192    fn weight(&self) -> Sc {
193        // For uni-constraints, we use a unit entity to compute the base weight.
194        // This works for constant weights; for dynamic weights, returns zero.
195        Sc::zero()
196    }
197}
198
199impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for IncrementalUniConstraint<S, A, E, F, W, Sc> {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        f.debug_struct("IncrementalUniConstraint")
202            .field("name", &self.constraint_ref.name)
203            .field("impact_type", &self.impact_type)
204            .finish()
205    }
206}