solverforge_scoring/director/
mod.rs1use std::any::Any;
12
13use solverforge_core::domain::{PlanningSolution, SolutionDescriptor};
14
15pub mod recording;
16pub mod typed;
17
18#[cfg(test)]
19mod recording_tests;
20#[cfg(test)]
21mod typed_bench;
22
23pub use recording::RecordingScoreDirector;
24
25pub trait ScoreDirector<S: PlanningSolution>: Send {
34 fn working_solution(&self) -> &S;
36
37 fn working_solution_mut(&mut self) -> &mut S;
39
40 fn calculate_score(&mut self) -> S::Score;
42
43 fn solution_descriptor(&self) -> &SolutionDescriptor;
45
46 fn clone_working_solution(&self) -> S;
48
49 fn before_variable_changed(
51 &mut self,
52 descriptor_index: usize,
53 entity_index: usize,
54 variable_name: &str,
55 );
56
57 fn after_variable_changed(
59 &mut self,
60 descriptor_index: usize,
61 entity_index: usize,
62 variable_name: &str,
63 );
64
65 fn trigger_variable_listeners(&mut self);
67
68 fn entity_count(&self, descriptor_index: usize) -> Option<usize>;
70
71 fn total_entity_count(&self) -> Option<usize>;
73
74 fn get_entity(&self, descriptor_index: usize, entity_index: usize) -> Option<&dyn Any>;
76
77 fn is_incremental(&self) -> bool {
79 false
80 }
81
82 fn reset(&mut self) {}
84
85 fn register_undo(&mut self, _undo: Box<dyn FnOnce(&mut S) + Send>) {
92 }
94}
95
96pub struct ScoreDirectorFactory<S: PlanningSolution, C> {
101 solution_descriptor: SolutionDescriptor,
102 score_calculator: C,
103 _phantom: std::marker::PhantomData<S>,
104}
105
106impl<S, C> ScoreDirectorFactory<S, C>
107where
108 S: PlanningSolution,
109 C: Fn(&S) -> S::Score + Send + Sync,
110{
111 pub fn new(solution_descriptor: SolutionDescriptor, score_calculator: C) -> Self {
113 Self {
114 solution_descriptor,
115 score_calculator,
116 _phantom: std::marker::PhantomData,
117 }
118 }
119
120 pub fn build_score_director(&self, solution: S) -> SimpleScoreDirector<S, &C> {
122 SimpleScoreDirector::new(
123 solution,
124 self.solution_descriptor.clone(),
125 &self.score_calculator,
126 )
127 }
128
129 pub fn solution_descriptor(&self) -> &SolutionDescriptor {
131 &self.solution_descriptor
132 }
133}
134
135impl<S: PlanningSolution, C: Clone> Clone for ScoreDirectorFactory<S, C> {
136 fn clone(&self) -> Self {
137 Self {
138 solution_descriptor: self.solution_descriptor.clone(),
139 score_calculator: self.score_calculator.clone(),
140 _phantom: std::marker::PhantomData,
141 }
142 }
143}
144
145pub struct SimpleScoreDirector<S: PlanningSolution, C> {
150 working_solution: S,
151 solution_descriptor: SolutionDescriptor,
152 score_calculator: C,
153 score_dirty: bool,
154 cached_score: Option<S::Score>,
155}
156
157impl<S, C> SimpleScoreDirector<S, C>
158where
159 S: PlanningSolution,
160 C: Fn(&S) -> S::Score + Send + Sync,
161{
162 pub fn new(solution: S, solution_descriptor: SolutionDescriptor, score_calculator: C) -> Self {
164 SimpleScoreDirector {
165 working_solution: solution,
166 solution_descriptor,
167 score_calculator,
168 score_dirty: true,
169 cached_score: None,
170 }
171 }
172
173 pub fn with_calculator(
177 solution: S,
178 solution_descriptor: SolutionDescriptor,
179 calculator: C,
180 ) -> Self {
181 Self::new(solution, solution_descriptor, calculator)
182 }
183
184 fn mark_dirty(&mut self) {
185 self.score_dirty = true;
186 }
187}
188
189impl<S, C> ScoreDirector<S> for SimpleScoreDirector<S, C>
190where
191 S: PlanningSolution,
192 C: Fn(&S) -> S::Score + Send + Sync,
193{
194 fn working_solution(&self) -> &S {
195 &self.working_solution
196 }
197
198 fn working_solution_mut(&mut self) -> &mut S {
199 self.mark_dirty();
200 &mut self.working_solution
201 }
202
203 fn calculate_score(&mut self) -> S::Score {
204 if !self.score_dirty {
205 if let Some(ref score) = self.cached_score {
206 return score.clone();
207 }
208 }
209
210 let score = (self.score_calculator)(&self.working_solution);
211 self.working_solution.set_score(Some(score.clone()));
212 self.cached_score = Some(score.clone());
213 self.score_dirty = false;
214 score
215 }
216
217 fn solution_descriptor(&self) -> &SolutionDescriptor {
218 &self.solution_descriptor
219 }
220
221 fn clone_working_solution(&self) -> S {
222 self.working_solution.clone()
223 }
224
225 fn before_variable_changed(
226 &mut self,
227 _descriptor_index: usize,
228 _entity_index: usize,
229 _variable_name: &str,
230 ) {
231 self.mark_dirty();
232 }
233
234 fn after_variable_changed(
235 &mut self,
236 _descriptor_index: usize,
237 _entity_index: usize,
238 _variable_name: &str,
239 ) {
240 }
242
243 fn trigger_variable_listeners(&mut self) {
244 }
246
247 fn entity_count(&self, descriptor_index: usize) -> Option<usize> {
248 self.solution_descriptor
249 .entity_descriptors
250 .get(descriptor_index)?
251 .entity_count(&self.working_solution as &dyn Any)
252 }
253
254 fn total_entity_count(&self) -> Option<usize> {
255 self.solution_descriptor
256 .total_entity_count(&self.working_solution as &dyn Any)
257 }
258
259 fn get_entity(&self, descriptor_index: usize, entity_index: usize) -> Option<&dyn Any> {
260 self.solution_descriptor.get_entity(
261 &self.working_solution as &dyn Any,
262 descriptor_index,
263 entity_index,
264 )
265 }
266
267 fn is_incremental(&self) -> bool {
268 false
269 }
270
271 fn reset(&mut self) {
272 self.mark_dirty();
273 self.cached_score = None;
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use solverforge_core::domain::{EntityDescriptor, TypedEntityExtractor};
    use solverforge_core::score::SimpleScore;
    use std::any::TypeId;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A queen occupies column `id` (by position in the list) and an optional row.
    #[derive(Clone, Debug, PartialEq)]
    struct Queen {
        id: i64,
        row: Option<i32>,
    }

    #[derive(Clone, Debug)]
    struct NQueensSolution {
        queens: Vec<Queen>,
        score: Option<SimpleScore>,
    }

    impl PlanningSolution for NQueensSolution {
        type Score = SimpleScore;

        fn score(&self) -> Option<Self::Score> {
            self.score
        }

        fn set_score(&mut self, score: Option<Self::Score>) {
            self.score = score;
        }
    }

    fn get_queens(s: &NQueensSolution) -> &Vec<Queen> {
        &s.queens
    }

    fn get_queens_mut(s: &mut NQueensSolution) -> &mut Vec<Queen> {
        &mut s.queens
    }

    /// Counts pairs of placed queens sharing a row or a diagonal; the column is
    /// the queen's index in the list. Returns the negated conflict count.
    fn calculate_conflicts(solution: &NQueensSolution) -> SimpleScore {
        let mut conflicts = 0i64;
        let queens = &solution.queens;

        for i in 0..queens.len() {
            for j in (i + 1)..queens.len() {
                if let (Some(row_i), Some(row_j)) = (queens[i].row, queens[j].row) {
                    if row_i == row_j {
                        conflicts += 1;
                    }
                    let col_diff = (j - i) as i32;
                    if (row_i - row_j).abs() == col_diff {
                        conflicts += 1;
                    }
                }
            }
        }

        SimpleScore::of(-conflicts)
    }

    fn create_test_descriptor() -> SolutionDescriptor {
        let extractor = Box::new(TypedEntityExtractor::new(
            "Queen",
            "queens",
            get_queens,
            get_queens_mut,
        ));
        let entity_desc = EntityDescriptor::new("Queen", TypeId::of::<Queen>(), "queens")
            .with_extractor(extractor);

        SolutionDescriptor::new("NQueensSolution", TypeId::of::<NQueensSolution>())
            .with_entity(entity_desc)
    }

    /// Two queens on a shared diagonal: (col 0, row 0) and (col 1, row 1).
    fn two_diagonal_queens() -> NQueensSolution {
        NQueensSolution {
            queens: vec![
                Queen {
                    id: 0,
                    row: Some(0),
                },
                Queen {
                    id: 1,
                    row: Some(1),
                },
            ],
            score: None,
        }
    }

    #[test]
    fn test_simple_score_director_calculate_score() {
        let solution = NQueensSolution {
            queens: vec![
                Queen {
                    id: 0,
                    row: Some(0),
                },
                Queen {
                    id: 1,
                    row: Some(1),
                },
                Queen {
                    id: 2,
                    row: Some(2),
                },
                Queen {
                    id: 3,
                    row: Some(3),
                },
            ],
            score: None,
        };

        let descriptor = create_test_descriptor();
        let mut director =
            SimpleScoreDirector::with_calculator(solution, descriptor, calculate_conflicts);

        let score = director.calculate_score();
        // All four queens sit on one diagonal: C(4, 2) = 6 conflicting pairs.
        assert_eq!(score, SimpleScore::of(-6));
    }

    #[test]
    fn test_score_director_factory() {
        let solution = NQueensSolution {
            queens: vec![Queen {
                id: 0,
                row: Some(0),
            }],
            score: None,
        };

        let descriptor = create_test_descriptor();
        let factory = ScoreDirectorFactory::new(descriptor, calculate_conflicts);

        let mut director = factory.build_score_director(solution);
        let score = director.calculate_score();
        assert_eq!(score, SimpleScore::of(0));
    }

    /// A clean score must be served from the cache; variable-change
    /// notifications and `reset` must force a recalculation.
    #[test]
    fn test_score_is_cached_until_invalidated() {
        let calls = AtomicUsize::new(0);
        let counting_calculator = |solution: &NQueensSolution| {
            calls.fetch_add(1, Ordering::SeqCst);
            calculate_conflicts(solution)
        };
        let mut director = SimpleScoreDirector::new(
            two_diagonal_queens(),
            create_test_descriptor(),
            counting_calculator,
        );

        assert_eq!(director.calculate_score(), SimpleScore::of(-1));
        assert_eq!(director.calculate_score(), SimpleScore::of(-1));
        assert_eq!(calls.load(Ordering::SeqCst), 1, "clean score must be cached");

        director.before_variable_changed(0, 1, "row");
        director.after_variable_changed(0, 1, "row");
        assert_eq!(director.calculate_score(), SimpleScore::of(-1));
        assert_eq!(calls.load(Ordering::SeqCst), 2, "variable change must invalidate");

        director.reset();
        assert_eq!(director.calculate_score(), SimpleScore::of(-1));
        assert_eq!(calls.load(Ordering::SeqCst), 3, "reset must invalidate");
    }

    /// Mutating through `working_solution_mut` must invalidate the cache, and
    /// each calculation must be written back into the solution's score.
    #[test]
    fn test_mutable_access_invalidates_and_score_is_written_back() {
        let mut director = SimpleScoreDirector::new(
            two_diagonal_queens(),
            create_test_descriptor(),
            calculate_conflicts,
        );

        assert_eq!(director.calculate_score(), SimpleScore::of(-1));
        assert_eq!(director.working_solution().score(), Some(SimpleScore::of(-1)));

        // Move the second queen off the shared diagonal: |0 - 3| != 1.
        director.working_solution_mut().queens[1].row = Some(3);
        assert_eq!(director.calculate_score(), SimpleScore::of(0));
        assert_eq!(
            director.clone_working_solution().score(),
            Some(SimpleScore::of(0))
        );
    }
}