Skip to main content

solverforge_solver/heuristic/selector/
entity.rs

1//! Entity selectors for iterating over planning entities
2
3use std::any::Any;
4use std::fmt::Debug;
5
6use solverforge_core::domain::PlanningSolution;
7use solverforge_scoring::ScoreDirector;
8
/// A reference to an entity within a solution.
///
/// An entity is addressed by two indices: the entity descriptor it belongs to and
/// its position within that descriptor's entity collection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EntityReference {
    /// Index of the entity descriptor.
    pub descriptor_index: usize,
    /// Index of the entity within its collection.
    pub entity_index: usize,
}

impl EntityReference {
    /// Creates a new entity reference.
    ///
    /// This is a `const fn`, so references can also be built in constant contexts.
    #[must_use]
    pub const fn new(descriptor_index: usize, entity_index: usize) -> Self {
        Self {
            descriptor_index,
            entity_index,
        }
    }
}
27
28/// Trait for selecting entities from a planning solution.
29///
30/// Entity selectors provide an iteration order over the entities that
31/// the solver will consider for moves.
32///
33/// # Type Parameters
34/// * `S` - The planning solution type
35pub trait EntitySelector<S: PlanningSolution>: Send + Debug {
36    /// Returns an iterator over entity references.
37    ///
38    /// The iterator yields `EntityReference` values that identify entities
39    /// within the solution.
40    fn iter<'a, D: ScoreDirector<S>>(
41        &'a self,
42        score_director: &'a D,
43    ) -> impl Iterator<Item = EntityReference> + 'a;
44
45    /// Returns the approximate number of entities.
46    fn size<D: ScoreDirector<S>>(&self, score_director: &D) -> usize;
47
48    /// Returns true if this selector may return the same entity multiple times.
49    fn is_never_ending(&self) -> bool {
50        false
51    }
52}
53
/// An entity selector that iterates over all entities from the solution.
///
/// Entities are taken from a single entity descriptor. Pinned entities can
/// optionally be skipped; see [`with_skip_pinned`](Self::with_skip_pinned) and
/// [`with_is_pinned_fn`](Self::with_is_pinned_fn).
#[derive(Clone)]
pub struct FromSolutionEntitySelector {
    /// The descriptor index to select from.
    descriptor_index: usize,
    /// Whether to skip pinned entities.
    skip_pinned: bool,
    /// Optional function to test whether an entity (as `&dyn Any`) is pinned.
    ///
    /// Required when `skip_pinned` is true. The function receives the entity
    /// returned by `ScoreDirector::get_entity` and returns `true` if pinned.
    is_pinned_fn: Option<fn(&dyn Any) -> bool>,
}

// Written by hand so the output reports only whether a pinning predicate is set,
// rather than printing the function pointer itself.
impl Debug for FromSolutionEntitySelector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FromSolutionEntitySelector")
            .field("descriptor_index", &self.descriptor_index)
            .field("skip_pinned", &self.skip_pinned)
            .field("has_is_pinned_fn", &self.is_pinned_fn.is_some())
            .finish()
    }
}

impl FromSolutionEntitySelector {
    /// Creates a new entity selector for the given descriptor index.
    ///
    /// Pinned entities are not skipped and no pinning predicate is set.
    #[must_use]
    pub const fn new(descriptor_index: usize) -> Self {
        Self {
            descriptor_index,
            skip_pinned: false,
            is_pinned_fn: None,
        }
    }

    /// Sets whether pinned entities are skipped during iteration.
    ///
    /// When `skip` is `true`, you must also call
    /// [`with_is_pinned_fn`](Self::with_is_pinned_fn) to supply a predicate,
    /// otherwise no filtering occurs. Passing `false` disables skipping even if
    /// a predicate has been set.
    #[must_use]
    pub fn with_skip_pinned(mut self, skip: bool) -> Self {
        self.skip_pinned = skip;
        self
    }

    /// Sets the function used to determine whether an entity is pinned.
    ///
    /// The function receives the entity as `&dyn Any` (the value returned by
    /// `ScoreDirector::get_entity`) and returns `true` if the entity is pinned.
    ///
    /// Only consulted when `skip_pinned` is `true`.
    #[must_use]
    pub fn with_is_pinned_fn(mut self, f: fn(&dyn Any) -> bool) -> Self {
        self.is_pinned_fn = Some(f);
        self
    }
}
109
110impl<S: PlanningSolution> EntitySelector<S> for FromSolutionEntitySelector {
111    fn iter<'a, D: ScoreDirector<S>>(
112        &'a self,
113        score_director: &'a D,
114    ) -> impl Iterator<Item = EntityReference> + 'a {
115        let count = score_director
116            .entity_count(self.descriptor_index)
117            .unwrap_or(0);
118
119        let desc_idx = self.descriptor_index;
120        let skip_pinned = self.skip_pinned;
121        let is_pinned_fn = self.is_pinned_fn;
122
123        (0..count)
124            .filter(move |&i| {
125                if skip_pinned {
126                    if let Some(pinned_fn) = is_pinned_fn {
127                        if let Some(entity) = score_director.get_entity(desc_idx, i) {
128                            if pinned_fn(entity) {
129                                return false;
130                            }
131                        }
132                    }
133                }
134                true
135            })
136            .map(move |i| EntityReference::new(desc_idx, i))
137    }
138
139    fn size<D: ScoreDirector<S>>(&self, score_director: &D) -> usize {
140        score_director
141            .entity_count(self.descriptor_index)
142            .unwrap_or(0)
143    }
144}
145
/// Entity selector that spans every entity of every entity descriptor.
///
/// It carries no configuration. Build it with [`AllEntitiesSelector::new`] or
/// `Default::default()`.
#[derive(Debug, Clone, Default)]
pub struct AllEntitiesSelector;

impl AllEntitiesSelector {
    /// Returns a selector covering all entities.
    pub fn new() -> Self {
        Self::default()
    }
}
156
157impl<S: PlanningSolution> EntitySelector<S> for AllEntitiesSelector {
158    fn iter<'a, D: ScoreDirector<S>>(
159        &'a self,
160        score_director: &'a D,
161    ) -> impl Iterator<Item = EntityReference> + 'a {
162        let desc = score_director.solution_descriptor();
163        let descriptor_count = desc.entity_descriptors.len();
164
165        let mut refs = Vec::new();
166        for desc_idx in 0..descriptor_count {
167            let count = score_director.entity_count(desc_idx).unwrap_or(0);
168            for entity_idx in 0..count {
169                refs.push(EntityReference::new(desc_idx, entity_idx));
170            }
171        }
172
173        refs.into_iter()
174    }
175
176    fn size<D: ScoreDirector<S>>(&self, score_director: &D) -> usize {
177        score_director.total_entity_count().unwrap_or(0)
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::create_simple_nqueens_director;

    /// Pinning predicate that reports every entity as pinned.
    fn always_pinned(_: &dyn std::any::Any) -> bool {
        true
    }

    #[test]
    fn test_from_solution_entity_selector() {
        let director = create_simple_nqueens_director(4);

        // Verify column values match indices (column is set to index in create_uninitialized_nqueens)
        let solution = director.working_solution();
        for (i, queen) in solution.queens.iter().enumerate() {
            assert_eq!(queen.column, i as i64);
        }

        let selector = FromSolutionEntitySelector::new(0);

        let refs: Vec<_> = selector.iter(&director).collect();
        assert_eq!(refs.len(), 4);
        assert_eq!(refs[0], EntityReference::new(0, 0));
        assert_eq!(refs[1], EntityReference::new(0, 1));
        assert_eq!(refs[2], EntityReference::new(0, 2));
        assert_eq!(refs[3], EntityReference::new(0, 3));

        assert_eq!(selector.size(&director), 4);
    }

    #[test]
    fn test_skip_pinned_without_predicate_does_not_filter() {
        let director = create_simple_nqueens_director(4);

        // Skipping is enabled but no predicate is supplied, so nothing is filtered.
        let selector = FromSolutionEntitySelector::new(0).with_skip_pinned(true);

        let refs: Vec<_> = selector.iter(&director).collect();
        assert_eq!(refs.len(), 4);
    }

    #[test]
    fn test_pinned_predicate_ignored_when_skip_disabled() {
        let director = create_simple_nqueens_director(4);

        // The predicate is only consulted when skipping is enabled.
        let selector = FromSolutionEntitySelector::new(0).with_is_pinned_fn(always_pinned);

        let refs: Vec<_> = selector.iter(&director).collect();
        assert_eq!(refs.len(), 4);
    }

    #[test]
    fn test_all_entities_selector() {
        let director = create_simple_nqueens_director(3);

        let selector = AllEntitiesSelector::new();

        let refs: Vec<_> = selector.iter(&director).collect();
        assert_eq!(refs.len(), 3);
        // All queens live in descriptor 0 and are yielded in index order.
        let expected: Vec<_> = (0..3).map(|i| EntityReference::new(0, i)).collect();
        assert_eq!(refs, expected);

        assert_eq!(selector.size(&director), 3);
    }
}