solverforge_solver/heuristic/selector/
entity.rs1use std::any::Any;
4use std::fmt::Debug;
5
6use solverforge_core::domain::PlanningSolution;
7use solverforge_scoring::ScoreDirector;
8
/// A lightweight, copyable handle identifying one planning entity.
///
/// The pair is resolved against a `ScoreDirector`:
/// - `descriptor_index` selects one of the solution's entity descriptors.
/// - `entity_index` is the entity's position among that descriptor's entities.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EntityReference {
    /// Index of the entity descriptor the entity belongs to.
    pub descriptor_index: usize,
    /// Index of the entity among all entities of that descriptor.
    pub entity_index: usize,
}

impl EntityReference {
    /// Creates a reference to entity `entity_index` of descriptor `descriptor_index`.
    ///
    /// This is a `const fn`, so references can also be built in constant contexts.
    #[must_use]
    pub const fn new(descriptor_index: usize, entity_index: usize) -> Self {
        Self {
            descriptor_index,
            entity_index,
        }
    }
}

28pub trait EntitySelector<S: PlanningSolution>: Send + Debug {
36 fn iter<'a, D: ScoreDirector<S>>(
41 &'a self,
42 score_director: &'a D,
43 ) -> impl Iterator<Item = EntityReference> + 'a;
44
45 fn size<D: ScoreDirector<S>>(&self, score_director: &D) -> usize;
47
48 fn is_never_ending(&self) -> bool {
50 false
51 }
52}
53
/// Selects every entity of a single entity descriptor, optionally skipping pinned ones.
///
/// Pinned entities are filtered out only when both of these hold:
/// - `skip_pinned` is enabled.
/// - A pin predicate has been supplied via
///   [`with_is_pinned_fn`](Self::with_is_pinned_fn).
#[derive(Clone)]
pub struct FromSolutionEntitySelector {
    /// Entity descriptor whose entities are enumerated.
    descriptor_index: usize,
    /// Whether entities reported as pinned should be skipped.
    skip_pinned: bool,
    /// Predicate deciding whether a type-erased entity is pinned.
    is_pinned_fn: Option<fn(&dyn Any) -> bool>,
}

impl Debug for FromSolutionEntitySelector {
    /// Formats the selector's configuration.
    ///
    /// The pin predicate itself is reported only as a presence flag
    /// (`has_is_pinned_fn`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            descriptor_index,
            skip_pinned,
            is_pinned_fn,
        } = self;

        let mut out = f.debug_struct("FromSolutionEntitySelector");
        out.field("descriptor_index", descriptor_index);
        out.field("skip_pinned", skip_pinned);
        out.field("has_is_pinned_fn", &is_pinned_fn.is_some());
        out.finish()
    }
}

impl FromSolutionEntitySelector {
    /// Creates a selector over descriptor `descriptor_index`.
    ///
    /// Pin filtering is disabled and no pin predicate is set.
    pub fn new(descriptor_index: usize) -> Self {
        Self {
            descriptor_index,
            is_pinned_fn: None,
            skip_pinned: false,
        }
    }

    /// Builder-style setter that enables or disables skipping of pinned entities.
    ///
    /// Filtering only takes effect if a predicate is also set with `with_is_pinned_fn`.
    pub fn with_skip_pinned(self, skip: bool) -> Self {
        Self {
            skip_pinned: skip,
            ..self
        }
    }

    /// Builder-style setter that installs the predicate deciding whether an entity is pinned.
    ///
    /// The predicate receives the entity as `&dyn Any`. It is consulted only
    /// while `skip_pinned` is enabled.
    pub fn with_is_pinned_fn(self, f: fn(&dyn Any) -> bool) -> Self {
        Self {
            is_pinned_fn: Some(f),
            ..self
        }
    }
}

110impl<S: PlanningSolution> EntitySelector<S> for FromSolutionEntitySelector {
111 fn iter<'a, D: ScoreDirector<S>>(
112 &'a self,
113 score_director: &'a D,
114 ) -> impl Iterator<Item = EntityReference> + 'a {
115 let count = score_director
116 .entity_count(self.descriptor_index)
117 .unwrap_or(0);
118
119 let desc_idx = self.descriptor_index;
120 let skip_pinned = self.skip_pinned;
121 let is_pinned_fn = self.is_pinned_fn;
122
123 (0..count)
124 .filter(move |&i| {
125 if skip_pinned {
126 if let Some(pinned_fn) = is_pinned_fn {
127 if let Some(entity) = score_director.get_entity(desc_idx, i) {
128 if pinned_fn(entity) {
129 return false;
130 }
131 }
132 }
133 }
134 true
135 })
136 .map(move |i| EntityReference::new(desc_idx, i))
137 }
138
139 fn size<D: ScoreDirector<S>>(&self, score_director: &D) -> usize {
140 score_director
141 .entity_count(self.descriptor_index)
142 .unwrap_or(0)
143 }
144}
145
/// Selects every entity of every entity descriptor in the working solution.
///
/// The selector carries no state; everything it yields comes from the score director.
#[derive(Debug, Clone, Default)]
pub struct AllEntitiesSelector;

impl AllEntitiesSelector {
    /// Creates the selector; equivalent to `AllEntitiesSelector::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}

157impl<S: PlanningSolution> EntitySelector<S> for AllEntitiesSelector {
158 fn iter<'a, D: ScoreDirector<S>>(
159 &'a self,
160 score_director: &'a D,
161 ) -> impl Iterator<Item = EntityReference> + 'a {
162 let desc = score_director.solution_descriptor();
163 let descriptor_count = desc.entity_descriptors.len();
164
165 let mut refs = Vec::new();
166 for desc_idx in 0..descriptor_count {
167 let count = score_director.entity_count(desc_idx).unwrap_or(0);
168 for entity_idx in 0..count {
169 refs.push(EntityReference::new(desc_idx, entity_idx));
170 }
171 }
172
173 refs.into_iter()
174 }
175
176 fn size<D: ScoreDirector<S>>(&self, score_director: &D) -> usize {
177 score_director.total_entity_count().unwrap_or(0)
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::create_simple_nqueens_director;

    #[test]
    fn test_from_solution_entity_selector() {
        let director = create_simple_nqueens_director(4);

        // Sanity-check the fixture: queen `i` starts in column `i`.
        let solution = director.working_solution();
        for (i, queen) in solution.queens.iter().enumerate() {
            assert_eq!(queen.column, i as i64);
        }

        let selector = FromSolutionEntitySelector::new(0);

        // Without pin filtering, every entity of descriptor 0 is yielded in index order.
        let refs: Vec<_> = selector.iter(&director).collect();
        let expected: Vec<_> = (0..4).map(|i| EntityReference::new(0, i)).collect();
        assert_eq!(refs, expected);

        assert_eq!(selector.size(&director), 4);
    }

    #[test]
    fn test_all_entities_selector() {
        let director = create_simple_nqueens_director(3);
        let selector = AllEntitiesSelector::new();

        // A 3-queen fixture yields three references and reports a size of three.
        assert_eq!(selector.iter(&director).count(), 3);
        assert_eq!(selector.size(&director), 3);
    }
}