solverforge_solver/heuristic/selector/entity.rs

1//! Entity selectors for iterating over planning entities
2
3use std::fmt::Debug;
4
5use solverforge_core::domain::PlanningSolution;
6use solverforge_scoring::ScoreDirector;
7
/// A lightweight, copyable handle identifying one planning entity in a solution.
///
/// The pair (`descriptor_index`, `entity_index`) first selects an entity
/// descriptor and then a position inside that descriptor's entity collection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EntityReference {
    /// Index of the entity descriptor.
    pub descriptor_index: usize,
    /// Index of the entity within its collection.
    pub entity_index: usize,
}

impl EntityReference {
    /// Creates a reference to entity `entity_index` of descriptor `descriptor_index`.
    ///
    /// This is a `const fn`, so references can also be built in constant
    /// contexts, e.g. `const FIRST: EntityReference = EntityReference::new(0, 0);`.
    #[must_use]
    pub const fn new(descriptor_index: usize, entity_index: usize) -> Self {
        Self {
            descriptor_index,
            entity_index,
        }
    }
}
26
27/// Trait for selecting entities from a planning solution.
28///
29/// Entity selectors provide an iteration order over the entities that
30/// the solver will consider for moves.
31pub trait EntitySelector<S: PlanningSolution>: Send + Debug {
32    /// Returns an iterator over entity references.
33    ///
34    /// The iterator yields `EntityReference` values that identify entities
35    /// within the solution.
36    fn iter<'a>(
37        &'a self,
38        score_director: &'a dyn ScoreDirector<S>,
39    ) -> Box<dyn Iterator<Item = EntityReference> + 'a>;
40
41    /// Returns the approximate number of entities.
42    fn size(&self, score_director: &dyn ScoreDirector<S>) -> usize;
43
44    /// Returns true if this selector may return the same entity multiple times.
45    fn is_never_ending(&self) -> bool {
46        false
47    }
48}
49
/// An entity selector that iterates over all entities of one entity
/// descriptor in the solution.
///
/// NOTE(review): the `skip_pinned` flag is stored and exposed through
/// [`FromSolutionEntitySelector::skip_pinned`], but this type's
/// `EntitySelector` implementation does not currently filter on it.
/// Confirm whether pinned-entity filtering is still to be implemented.
#[derive(Debug, Clone)]
pub struct FromSolutionEntitySelector {
    /// The descriptor index to select from.
    descriptor_index: usize,
    /// Whether pinned entities were requested to be skipped.
    skip_pinned: bool,
}

impl FromSolutionEntitySelector {
    /// Creates a selector over the entities of descriptor `descriptor_index`.
    ///
    /// Pinned entities are not skipped by default.
    #[must_use]
    pub const fn new(descriptor_index: usize) -> Self {
        Self {
            descriptor_index,
            skip_pinned: false,
        }
    }

    /// Returns this selector with its skip-pinned flag set to `skip`.
    ///
    /// `true` requests that pinned entities be skipped. `false`, which is
    /// also what [`new`](Self::new) uses, keeps them.
    #[must_use]
    pub fn with_skip_pinned(mut self, skip: bool) -> Self {
        self.skip_pinned = skip;
        self
    }

    /// Returns the index of the entity descriptor this selector draws from.
    #[must_use]
    pub const fn descriptor_index(&self) -> usize {
        self.descriptor_index
    }

    /// Returns whether this selector was configured to skip pinned entities.
    #[must_use]
    pub const fn skip_pinned(&self) -> bool {
        self.skip_pinned
    }
}
74
75impl<S: PlanningSolution> EntitySelector<S> for FromSolutionEntitySelector {
76    fn iter<'a>(
77        &'a self,
78        score_director: &'a dyn ScoreDirector<S>,
79    ) -> Box<dyn Iterator<Item = EntityReference> + 'a> {
80        let count = score_director
81            .entity_count(self.descriptor_index)
82            .unwrap_or(0);
83
84        let desc_idx = self.descriptor_index;
85
86        Box::new((0..count).map(move |i| EntityReference::new(desc_idx, i)))
87    }
88
89    fn size(&self, score_director: &dyn ScoreDirector<S>) -> usize {
90        score_director
91            .entity_count(self.descriptor_index)
92            .unwrap_or(0)
93    }
94}
95
/// Selects every entity of every entity descriptor in the solution.
///
/// This is a stateless unit struct: [`AllEntitiesSelector::new`] and
/// [`Default::default`] produce the same value.
#[derive(Debug, Clone, Default)]
pub struct AllEntitiesSelector;

impl AllEntitiesSelector {
    /// Builds the (stateless) all-entities selector.
    pub fn new() -> Self {
        Self::default()
    }
}
106
107impl<S: PlanningSolution> EntitySelector<S> for AllEntitiesSelector {
108    fn iter<'a>(
109        &'a self,
110        score_director: &'a dyn ScoreDirector<S>,
111    ) -> Box<dyn Iterator<Item = EntityReference> + 'a> {
112        let desc = score_director.solution_descriptor();
113        let descriptor_count = desc.entity_descriptors.len();
114
115        let mut refs = Vec::new();
116        for desc_idx in 0..descriptor_count {
117            let count = score_director.entity_count(desc_idx).unwrap_or(0);
118            for entity_idx in 0..count {
119                refs.push(EntityReference::new(desc_idx, entity_idx));
120            }
121        }
122
123        Box::new(refs.into_iter())
124    }
125
126    fn size(&self, score_director: &dyn ScoreDirector<S>) -> usize {
127        score_director.total_entity_count().unwrap_or(0)
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use solverforge_core::domain::{EntityDescriptor, SolutionDescriptor, TypedEntityExtractor};
    use solverforge_core::score::SimpleScore;
    use solverforge_scoring::SimpleScoreDirector;
    use std::any::TypeId;

    // `row` is never read by these tests; it only gives the entity a
    // realistic shape, so the dead-code lint is silenced for it.
    #[allow(dead_code)]
    #[derive(Clone, Debug)]
    struct Queen {
        id: i64,
        row: Option<i32>,
    }

    #[derive(Clone, Debug)]
    struct NQueensSolution {
        queens: Vec<Queen>,
        score: Option<SimpleScore>,
    }

    impl PlanningSolution for NQueensSolution {
        type Score = SimpleScore;

        fn score(&self) -> Option<Self::Score> {
            self.score
        }

        fn set_score(&mut self, score: Option<Self::Score>) {
            self.score = score;
        }
    }

    fn get_queens(s: &NQueensSolution) -> &Vec<Queen> {
        &s.queens
    }

    fn get_queens_mut(s: &mut NQueensSolution) -> &mut Vec<Queen> {
        &mut s.queens
    }

    /// Builds a score director over a solution with `n` queens and a single
    /// `Queen` entity descriptor (descriptor index 0).
    fn create_test_director(
        n: usize,
    ) -> SimpleScoreDirector<NQueensSolution, impl Fn(&NQueensSolution) -> SimpleScore> {
        let queens: Vec<_> = (0..n)
            .map(|i| Queen {
                id: i as i64,
                row: Some(i as i32),
            })
            .collect();

        let solution = NQueensSolution {
            queens,
            score: None,
        };

        let extractor = Box::new(TypedEntityExtractor::new(
            "Queen",
            "queens",
            get_queens,
            get_queens_mut,
        ));
        let entity_desc = EntityDescriptor::new("Queen", TypeId::of::<Queen>(), "queens")
            .with_extractor(extractor);

        let descriptor =
            SolutionDescriptor::new("NQueensSolution", TypeId::of::<NQueensSolution>())
                .with_entity(entity_desc);

        SimpleScoreDirector::with_calculator(solution, descriptor, |_| SimpleScore::of(0))
    }

    #[test]
    fn test_from_solution_entity_selector() {
        let director = create_test_director(4);

        // Verify entity IDs match indices
        let solution = director.working_solution();
        for (i, queen) in solution.queens.iter().enumerate() {
            assert_eq!(queen.id, i as i64);
        }

        let selector = FromSolutionEntitySelector::new(0);

        let refs: Vec<_> = selector.iter(&director).collect();
        assert_eq!(refs.len(), 4);
        assert_eq!(refs[0], EntityReference::new(0, 0));
        assert_eq!(refs[1], EntityReference::new(0, 1));
        assert_eq!(refs[2], EntityReference::new(0, 2));
        assert_eq!(refs[3], EntityReference::new(0, 3));

        assert_eq!(selector.size(&director), 4);
    }

    #[test]
    fn test_all_entities_selector() {
        let director = create_test_director(3);

        let selector = AllEntitiesSelector::new();

        // Check the exact references and their order, not just the count.
        let refs: Vec<_> = selector.iter(&director).collect();
        assert_eq!(
            refs,
            vec![
                EntityReference::new(0, 0),
                EntityReference::new(0, 1),
                EntityReference::new(0, 2),
            ]
        );

        assert_eq!(selector.size(&director), 3);
    }

    #[test]
    fn test_selectors_on_empty_solution() {
        let director = create_test_director(0);

        let from_solution = FromSolutionEntitySelector::new(0);
        assert_eq!(from_solution.iter(&director).count(), 0);
        assert_eq!(from_solution.size(&director), 0);

        let all = AllEntitiesSelector::new();
        assert_eq!(all.iter(&director).count(), 0);
        assert_eq!(all.size(&director), 0);
    }
}