solverforge_core/domain/entity_ref.rs

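//! Entity reference types for dynamic entity access.
//!
//! These types let the solver locate, read, and modify planning entities at
//! runtime without knowing their concrete types at compile time:
//! [`EntityRef`] identifies an entity by collection field and index, and
//! [`EntityExtractor`] provides type-erased access to one entity collection
//! of a solution.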
5
6use std::any::Any;
7use std::fmt::Debug;
8
/// A reference to a planning entity by its position in the solution.
///
/// An `EntityRef` does not borrow the entity. It records where the entity
/// lives: the name of the solution field holding its collection and its
/// index within that collection. This lets the solver identify entities
/// during solving without knowing the concrete entity type.
///
/// Equality and hashing compare all three fields, so refs can be used as
/// set or map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EntityRef {
    /// Index of this entity in its collection.
    pub index: usize,
    /// Name of the entity type.
    pub type_name: &'static str,
    /// Name of the collection field in the solution.
    pub collection_field: &'static str,
}

impl EntityRef {
    /// Creates a new entity reference.
    ///
    /// # Arguments
    /// * `index` - position of the entity within its collection
    /// * `type_name` - name of the entity type
    /// * `collection_field` - name of the solution field holding the collection
    #[must_use]
    pub fn new(index: usize, type_name: &'static str, collection_field: &'static str) -> Self {
        Self {
            index,
            type_name,
            collection_field,
        }
    }
}
33
34/// Trait for extracting entities from a planning solution.
35///
36/// This trait is implemented by closures/functions that can extract
37/// entity references from a solution of a specific type.
38pub trait EntityExtractor: Send + Sync {
39    /// Returns the number of entities in the collection.
40    fn count(&self, solution: &dyn Any) -> Option<usize>;
41
42    /// Gets a reference to an entity by index.
43    fn get<'a>(&self, solution: &'a dyn Any, index: usize) -> Option<&'a dyn Any>;
44
45    /// Gets a mutable reference to an entity by index.
46    fn get_mut<'a>(&self, solution: &'a mut dyn Any, index: usize) -> Option<&'a mut dyn Any>;
47
48    /// Returns an iterator over entity references.
49    fn entity_refs(&self, solution: &dyn Any) -> Vec<EntityRef>;
50
51    /// Clone this extractor.
52    fn clone_box(&self) -> Box<dyn EntityExtractor>;
53
54    /// Clones an entity as a boxed value for insertion into the constraint session.
55    ///
56    /// This is used for incremental scoring where entities need to be inserted
57    /// into the BAVET session as owned, type-erased values.
58    fn clone_entity_boxed(
59        &self,
60        solution: &dyn Any,
61        index: usize,
62    ) -> Option<Box<dyn Any + Send + Sync>>;
63
64    /// Returns the TypeId of the entity type.
65    fn entity_type_id(&self) -> std::any::TypeId;
66}
67
68impl Clone for Box<dyn EntityExtractor> {
69    fn clone(&self) -> Self {
70        self.clone_box()
71    }
72}
73
/// A concrete [`EntityExtractor`] for one entity collection of a solution.
///
/// The collection is located through a pair of plain function pointers: a
/// shared accessor and a mutable accessor. The extractor stores no value of
/// type `S` or `E`.
///
/// # Type Parameters
/// * `S` - The solution type
/// * `E` - The entity type
pub struct TypedEntityExtractor<S, E> {
    /// Name of the entity type.
    type_name: &'static str,
    /// Name of the collection field in the solution.
    collection_field: &'static str,
    /// Function to get the entity collection from a solution.
    get_collection: fn(&S) -> &Vec<E>,
    /// Function to get the mutable entity collection from a solution.
    get_collection_mut: fn(&mut S) -> &mut Vec<E>,
}

// Implemented by hand because `#[derive(Clone)]` would add spurious
// `S: Clone` and `E: Clone` bounds. Every field (two `&'static str`s and two
// fn pointers) is cloneable regardless of `S` and `E`.
impl<S, E> Clone for TypedEntityExtractor<S, E> {
    fn clone(&self) -> Self {
        Self {
            type_name: self.type_name,
            collection_field: self.collection_field,
            get_collection: self.get_collection,
            get_collection_mut: self.get_collection_mut,
        }
    }
}

impl<S, E> TypedEntityExtractor<S, E>
where
    S: 'static,
    E: 'static,
{
    /// Creates a new typed entity extractor.
    ///
    /// # Arguments
    /// * `type_name` - name of the entity type; copied into every
    ///   [`EntityRef`] this extractor produces
    /// * `collection_field` - name of the solution field holding the entities
    /// * `get_collection` - returns the entity collection of a solution
    /// * `get_collection_mut` - returns the entity collection of a solution mutably
    #[must_use]
    pub fn new(
        type_name: &'static str,
        collection_field: &'static str,
        get_collection: fn(&S) -> &Vec<E>,
        get_collection_mut: fn(&mut S) -> &mut Vec<E>,
    ) -> Self {
        Self {
            type_name,
            collection_field,
            get_collection,
            get_collection_mut,
        }
    }
}
110
111impl<S, E> EntityExtractor for TypedEntityExtractor<S, E>
112where
113    S: Send + Sync + 'static,
114    E: Clone + Send + Sync + 'static,
115{
116    fn count(&self, solution: &dyn Any) -> Option<usize> {
117        let solution = solution.downcast_ref::<S>()?;
118        Some((self.get_collection)(solution).len())
119    }
120
121    fn get<'a>(&self, solution: &'a dyn Any, index: usize) -> Option<&'a dyn Any> {
122        let solution = solution.downcast_ref::<S>()?;
123        let collection = (self.get_collection)(solution);
124        collection.get(index).map(|e| e as &dyn Any)
125    }
126
127    fn get_mut<'a>(&self, solution: &'a mut dyn Any, index: usize) -> Option<&'a mut dyn Any> {
128        let solution = solution.downcast_mut::<S>()?;
129        let collection = (self.get_collection_mut)(solution);
130        collection.get_mut(index).map(|e| e as &mut dyn Any)
131    }
132
133    fn entity_refs(&self, solution: &dyn Any) -> Vec<EntityRef> {
134        let Some(solution) = solution.downcast_ref::<S>() else {
135            return Vec::new();
136        };
137        let collection = (self.get_collection)(solution);
138        (0..collection.len())
139            .map(|i| EntityRef::new(i, self.type_name, self.collection_field))
140            .collect()
141    }
142
143    fn clone_box(&self) -> Box<dyn EntityExtractor> {
144        Box::new(Self {
145            type_name: self.type_name,
146            collection_field: self.collection_field,
147            get_collection: self.get_collection,
148            get_collection_mut: self.get_collection_mut,
149        })
150    }
151
152    fn clone_entity_boxed(
153        &self,
154        solution: &dyn Any,
155        index: usize,
156    ) -> Option<Box<dyn Any + Send + Sync>> {
157        let solution = solution.downcast_ref::<S>()?;
158        let collection = (self.get_collection)(solution);
159        let entity = collection.get(index)?;
160        Some(Box::new(entity.clone()) as Box<dyn Any + Send + Sync>)
161    }
162
163    fn entity_type_id(&self) -> std::any::TypeId {
164        std::any::TypeId::of::<E>()
165    }
166}
167
168impl<S, E> Debug for TypedEntityExtractor<S, E> {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        f.debug_struct("TypedEntityExtractor")
171            .field("type_name", &self.type_name)
172            .field("collection_field", &self.collection_field)
173            .finish()
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug)]
    struct TestEntity {
        id: i64,
        value: Option<i32>,
    }

    #[derive(Clone, Debug)]
    struct TestSolution {
        entities: Vec<TestEntity>,
    }

    fn get_entities(s: &TestSolution) -> &Vec<TestEntity> {
        &s.entities
    }

    fn get_entities_mut(s: &mut TestSolution) -> &mut Vec<TestEntity> {
        &mut s.entities
    }

    /// Builds the extractor shared by every test.
    fn make_extractor() -> TypedEntityExtractor<TestSolution, TestEntity> {
        TypedEntityExtractor::new("TestEntity", "entities", get_entities, get_entities_mut)
    }

    /// Builds a solution whose entities have ids `1..=values.len()` and the given values.
    fn solution_with(values: &[Option<i32>]) -> TestSolution {
        TestSolution {
            entities: values
                .iter()
                .zip(1..)
                .map(|(&value, id)| TestEntity { id, value })
                .collect(),
        }
    }

    #[test]
    fn test_typed_entity_extractor_count() {
        let solution = solution_with(&[Some(10), Some(20), None]);
        assert_eq!(make_extractor().count(&solution as &dyn Any), Some(3));
    }

    #[test]
    fn test_empty_collection() {
        let extractor = make_extractor();
        let solution = solution_with(&[]);
        assert_eq!(extractor.count(&solution as &dyn Any), Some(0));
        assert!(extractor.entity_refs(&solution as &dyn Any).is_empty());
        assert!(extractor.get(&solution as &dyn Any, 0).is_none());
    }

    #[test]
    fn test_typed_entity_extractor_get() {
        let extractor = make_extractor();
        let solution = solution_with(&[Some(10), Some(20)]);

        let entity = extractor
            .get(&solution as &dyn Any, 0)
            .expect("index 0 is in bounds")
            .downcast_ref::<TestEntity>()
            .expect("extractor yields TestEntity");
        assert_eq!(entity.id, 1);
        assert_eq!(entity.value, Some(10));
        // The reference points into the solution rather than at a copy.
        assert!(std::ptr::eq(entity, &solution.entities[0]));

        // Out of bounds
        assert!(extractor.get(&solution as &dyn Any, 5).is_none());
    }

    #[test]
    fn test_typed_entity_extractor_get_mut() {
        let extractor = make_extractor();
        let mut solution = solution_with(&[Some(10)]);

        let entity = extractor
            .get_mut(&mut solution as &mut dyn Any, 0)
            .expect("index 0 is in bounds")
            .downcast_mut::<TestEntity>()
            .expect("extractor yields TestEntity");
        entity.value = Some(100);
        assert_eq!(solution.entities[0].value, Some(100));

        // Out of bounds
        assert!(extractor
            .get_mut(&mut solution as &mut dyn Any, 1)
            .is_none());
    }

    #[test]
    fn test_typed_entity_extractor_entity_refs() {
        let solution = solution_with(&[Some(10), Some(20)]);
        let refs = make_extractor().entity_refs(&solution as &dyn Any);

        assert_eq!(refs.len(), 2);
        for (expected_index, entity_ref) in refs.iter().enumerate() {
            assert_eq!(entity_ref.index, expected_index);
            assert_eq!(entity_ref.type_name, "TestEntity");
            assert_eq!(entity_ref.collection_field, "entities");
        }
    }

    #[test]
    fn test_extractor_wrong_solution_type() {
        let extractor = make_extractor();
        let wrong_solution = "not a solution";
        let wrong = &wrong_solution as &dyn Any;

        assert!(extractor.count(wrong).is_none());
        assert!(extractor.get(wrong, 0).is_none());
        assert!(extractor.entity_refs(wrong).is_empty());
        assert!(extractor.clone_entity_boxed(wrong, 0).is_none());

        let mut wrong_mut = "not a solution";
        assert!(extractor
            .get_mut(&mut wrong_mut as &mut dyn Any, 0)
            .is_none());
    }

    #[test]
    fn test_extractor_clone() {
        let extractor: Box<dyn EntityExtractor> = Box::new(make_extractor());
        let cloned = extractor.clone();

        let solution = solution_with(&[Some(10)]);
        assert_eq!(cloned.count(&solution as &dyn Any), Some(1));
        assert_eq!(cloned.entity_type_id(), extractor.entity_type_id());
    }

    #[test]
    fn test_clone_entity_boxed() {
        let extractor = make_extractor();
        let solution = solution_with(&[Some(10), Some(20)]);

        // Clone first entity
        let boxed = extractor
            .clone_entity_boxed(&solution as &dyn Any, 0)
            .expect("index 0 is in bounds");
        let entity = boxed
            .downcast_ref::<TestEntity>()
            .expect("boxed value is a TestEntity");
        assert_eq!(entity.id, 1);
        assert_eq!(entity.value, Some(10));

        // Clone second entity
        let boxed = extractor
            .clone_entity_boxed(&solution as &dyn Any, 1)
            .expect("index 1 is in bounds");
        let entity = boxed
            .downcast_ref::<TestEntity>()
            .expect("boxed value is a TestEntity");
        assert_eq!(entity.id, 2);
        assert_eq!(entity.value, Some(20));

        // Out of bounds returns None
        assert!(extractor
            .clone_entity_boxed(&solution as &dyn Any, 5)
            .is_none());
    }

    #[test]
    fn test_entity_type_id() {
        assert_eq!(
            make_extractor().entity_type_id(),
            std::any::TypeId::of::<TestEntity>()
        );
    }
}