Skip to main content

solverforge_scoring/constraint/
incremental.rs

1/* Zero-erasure incremental uni-constraint.
2
3All closure types are concrete generics - no Arc, no dyn, fully monomorphized.
4*/
5
6use std::fmt::Debug;
7use std::marker::PhantomData;
8
9use solverforge_core::score::Score;
10use solverforge_core::{ConstraintRef, ImpactType};
11
12use crate::api::analysis::{ConstraintJustification, DetailedConstraintMatch, EntityRef};
13use crate::api::constraint_set::IncrementalConstraint;
14
15/* Zero-erasure incremental uni-constraint.
16
17All closure types are concrete generics - no Arc, no dyn, fully monomorphized.
18*/
19pub struct IncrementalUniConstraint<S, A, E, F, W, Sc>
20where
21    Sc: Score,
22{
23    constraint_ref: ConstraintRef,
24    impact_type: ImpactType,
25    extractor: E,
26    filter: F,
27    weight: W,
28    is_hard: bool,
29    expected_descriptor: Option<usize>,
30    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
31}
32
33impl<S, A, E, F, W, Sc> IncrementalUniConstraint<S, A, E, F, W, Sc>
34where
35    S: Send + Sync + 'static,
36    A: Clone + Send + Sync + 'static,
37    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
38    F: Fn(&S, &A) -> bool + Send + Sync,
39    W: Fn(&A) -> Sc + Send + Sync,
40    Sc: Score,
41{
42    // Creates a new zero-erasure incremental uni-constraint.
43    pub fn new(
44        constraint_ref: ConstraintRef,
45        impact_type: ImpactType,
46        extractor: E,
47        filter: F,
48        weight: W,
49        is_hard: bool,
50    ) -> Self {
51        Self {
52            constraint_ref,
53            impact_type,
54            extractor,
55            filter,
56            weight,
57            is_hard,
58            expected_descriptor: None,
59            _phantom: PhantomData,
60        }
61    }
62
63    pub fn with_descriptor(mut self, descriptor_index: usize) -> Self {
64        self.expected_descriptor = Some(descriptor_index);
65        self
66    }
67
68    #[inline]
69    fn matches(&self, solution: &S, entity: &A) -> bool {
70        (self.filter)(solution, entity)
71    }
72
73    #[inline]
74    fn compute_delta(&self, entity: &A) -> Sc {
75        let base = (self.weight)(entity);
76        match self.impact_type {
77            ImpactType::Penalty => -base,
78            ImpactType::Reward => base,
79        }
80    }
81
82    #[inline]
83    fn reverse_delta(&self, entity: &A) -> Sc {
84        let base = (self.weight)(entity);
85        match self.impact_type {
86            ImpactType::Penalty => base,
87            ImpactType::Reward => -base,
88        }
89    }
90}
91
92impl<S, A, E, F, W, Sc> IncrementalConstraint<S, Sc> for IncrementalUniConstraint<S, A, E, F, W, Sc>
93where
94    S: Send + Sync + 'static,
95    A: Clone + Debug + Send + Sync + 'static,
96    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
97    F: Fn(&S, &A) -> bool + Send + Sync,
98    W: Fn(&A) -> Sc + Send + Sync,
99    Sc: Score,
100{
101    fn evaluate(&self, solution: &S) -> Sc {
102        let entities = self.extractor.extract(solution);
103        let mut total = Sc::zero();
104        for entity in entities {
105            if self.matches(solution, entity) {
106                total = total + self.compute_delta(entity);
107            }
108        }
109        total
110    }
111
112    fn match_count(&self, solution: &S) -> usize {
113        let entities = self.extractor.extract(solution);
114        entities
115            .iter()
116            .filter(|e| self.matches(solution, e))
117            .count()
118    }
119
120    fn initialize(&mut self, solution: &S) -> Sc {
121        self.evaluate(solution)
122    }
123
124    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
125        if let Some(expected) = self.expected_descriptor {
126            if descriptor_index != expected {
127                return Sc::zero();
128            }
129        }
130        let entities = self.extractor.extract(solution);
131        if entity_index >= entities.len() {
132            return Sc::zero();
133        }
134        let entity = &entities[entity_index];
135        if self.matches(solution, entity) {
136            self.compute_delta(entity)
137        } else {
138            Sc::zero()
139        }
140    }
141
142    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
143        if let Some(expected) = self.expected_descriptor {
144            if descriptor_index != expected {
145                return Sc::zero();
146            }
147        }
148        let entities = self.extractor.extract(solution);
149        if entity_index >= entities.len() {
150            return Sc::zero();
151        }
152        let entity = &entities[entity_index];
153        if self.matches(solution, entity) {
154            self.reverse_delta(entity)
155        } else {
156            Sc::zero()
157        }
158    }
159
160    fn reset(&mut self) {
161        // Stateless
162    }
163
164    fn name(&self) -> &str {
165        &self.constraint_ref.name
166    }
167
168    fn is_hard(&self) -> bool {
169        self.is_hard
170    }
171
172    fn constraint_ref(&self) -> ConstraintRef {
173        self.constraint_ref.clone()
174    }
175
176    fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
177        let entities = self.extractor.extract(solution);
178        let cref = self.constraint_ref.clone();
179        entities
180            .iter()
181            .filter(|e| self.matches(solution, e))
182            .map(|entity| {
183                let entity_ref = EntityRef::new(entity);
184                let justification = ConstraintJustification::new(vec![entity_ref]);
185                DetailedConstraintMatch::new(
186                    cref.clone(),
187                    self.compute_delta(entity),
188                    justification,
189                )
190            })
191            .collect()
192    }
193
194    fn weight(&self) -> Sc {
195        // For uni-constraints, we use a unit entity to compute the base weight.
196        // This works for constant weights; for dynamic weights, returns zero.
197        Sc::zero()
198    }
199}
200
201impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for IncrementalUniConstraint<S, A, E, F, W, Sc> {
202    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203        f.debug_struct("IncrementalUniConstraint")
204            .field("name", &self.constraint_ref.name)
205            .field("impact_type", &self.impact_type)
206            .finish()
207    }
208}