solverforge_scoring/director/
mod.rs1use std::any::Any;
13
14use solverforge_core::domain::{PlanningSolution, SolutionDescriptor};
15
16pub mod recording;
17pub mod shadow_aware;
18pub mod typed;
19
20#[cfg(test)]
21mod recording_tests;
22#[cfg(test)]
23mod typed_bench;
24
25pub use recording::RecordingScoreDirector;
26pub use shadow_aware::{ShadowAwareScoreDirector, ShadowVariableSupport, SolvableSolution};
27
28pub trait ScoreDirector<S: PlanningSolution>: Send {
37 fn working_solution(&self) -> &S;
39
40 fn working_solution_mut(&mut self) -> &mut S;
42
43 fn calculate_score(&mut self) -> S::Score;
45
46 fn solution_descriptor(&self) -> &SolutionDescriptor;
48
49 fn clone_working_solution(&self) -> S;
51
52 fn before_variable_changed(
56 &mut self,
57 descriptor_index: usize,
58 entity_index: usize,
59 variable_name: &str,
60 );
61
62 fn after_variable_changed(
66 &mut self,
67 descriptor_index: usize,
68 entity_index: usize,
69 variable_name: &str,
70 );
71
72 fn before_entity_changed(&mut self, entity_index: usize) {
76 self.before_variable_changed(0, entity_index, "");
77 }
78
79 fn after_entity_changed(&mut self, entity_index: usize) {
83 self.after_variable_changed(0, entity_index, "");
84 }
85
86 fn trigger_variable_listeners(&mut self);
88
89 fn entity_count(&self, descriptor_index: usize) -> Option<usize>;
91
92 fn total_entity_count(&self) -> Option<usize>;
94
95 fn get_entity(&self, descriptor_index: usize, entity_index: usize) -> Option<&dyn Any>;
97
98 fn is_incremental(&self) -> bool {
100 false
101 }
102
103 fn reset(&mut self) {}
105
106 fn register_undo(&mut self, _undo: Box<dyn FnOnce(&mut S) + Send>) {
113 }
115}
116
117pub struct ScoreDirectorFactory<S: PlanningSolution, C> {
122 solution_descriptor: SolutionDescriptor,
123 score_calculator: C,
124 _phantom: std::marker::PhantomData<S>,
125}
126
127impl<S, C> ScoreDirectorFactory<S, C>
128where
129 S: PlanningSolution,
130 C: Fn(&S) -> S::Score + Send + Sync,
131{
132 pub fn new(solution_descriptor: SolutionDescriptor, score_calculator: C) -> Self {
134 Self {
135 solution_descriptor,
136 score_calculator,
137 _phantom: std::marker::PhantomData,
138 }
139 }
140
141 pub fn build_score_director(&self, solution: S) -> SimpleScoreDirector<S, &C> {
143 SimpleScoreDirector::new(
144 solution,
145 self.solution_descriptor.clone(),
146 &self.score_calculator,
147 )
148 }
149
150 pub fn solution_descriptor(&self) -> &SolutionDescriptor {
152 &self.solution_descriptor
153 }
154}
155
156impl<S: PlanningSolution, C: Clone> Clone for ScoreDirectorFactory<S, C> {
157 fn clone(&self) -> Self {
158 Self {
159 solution_descriptor: self.solution_descriptor.clone(),
160 score_calculator: self.score_calculator.clone(),
161 _phantom: std::marker::PhantomData,
162 }
163 }
164}
165
166pub struct SimpleScoreDirector<S: PlanningSolution, C> {
171 working_solution: S,
172 solution_descriptor: SolutionDescriptor,
173 score_calculator: C,
174 score_dirty: bool,
175 cached_score: Option<S::Score>,
176}
177
178impl<S, C> SimpleScoreDirector<S, C>
179where
180 S: PlanningSolution,
181 C: Fn(&S) -> S::Score + Send + Sync,
182{
183 pub fn new(solution: S, solution_descriptor: SolutionDescriptor, score_calculator: C) -> Self {
185 SimpleScoreDirector {
186 working_solution: solution,
187 solution_descriptor,
188 score_calculator,
189 score_dirty: true,
190 cached_score: None,
191 }
192 }
193
194 pub fn with_calculator(
198 solution: S,
199 solution_descriptor: SolutionDescriptor,
200 calculator: C,
201 ) -> Self {
202 Self::new(solution, solution_descriptor, calculator)
203 }
204
205 fn mark_dirty(&mut self) {
206 self.score_dirty = true;
207 }
208}
209
210impl<S, C> ScoreDirector<S> for SimpleScoreDirector<S, C>
211where
212 S: PlanningSolution,
213 C: Fn(&S) -> S::Score + Send + Sync,
214{
215 fn working_solution(&self) -> &S {
216 &self.working_solution
217 }
218
219 fn working_solution_mut(&mut self) -> &mut S {
220 self.mark_dirty();
221 &mut self.working_solution
222 }
223
224 fn calculate_score(&mut self) -> S::Score {
225 if !self.score_dirty {
226 if let Some(ref score) = self.cached_score {
227 return *score;
228 }
229 }
230
231 let score = (self.score_calculator)(&self.working_solution);
232 self.working_solution.set_score(Some(score));
233 self.cached_score = Some(score);
234 self.score_dirty = false;
235 score
236 }
237
238 fn solution_descriptor(&self) -> &SolutionDescriptor {
239 &self.solution_descriptor
240 }
241
242 fn clone_working_solution(&self) -> S {
243 self.working_solution.clone()
244 }
245
246 fn before_variable_changed(
247 &mut self,
248 _descriptor_index: usize,
249 _entity_index: usize,
250 _variable_name: &str,
251 ) {
252 self.mark_dirty();
253 }
254
255 fn after_variable_changed(
256 &mut self,
257 _descriptor_index: usize,
258 _entity_index: usize,
259 _variable_name: &str,
260 ) {
261 }
263
264 fn trigger_variable_listeners(&mut self) {
265 }
267
268 fn entity_count(&self, descriptor_index: usize) -> Option<usize> {
269 self.solution_descriptor
270 .entity_descriptors
271 .get(descriptor_index)?
272 .entity_count(&self.working_solution as &dyn Any)
273 }
274
275 fn total_entity_count(&self) -> Option<usize> {
276 self.solution_descriptor
277 .total_entity_count(&self.working_solution as &dyn Any)
278 }
279
280 fn get_entity(&self, descriptor_index: usize, entity_index: usize) -> Option<&dyn Any> {
281 self.solution_descriptor.get_entity(
282 &self.working_solution as &dyn Any,
283 descriptor_index,
284 entity_index,
285 )
286 }
287
288 fn is_incremental(&self) -> bool {
289 false
290 }
291
292 fn reset(&mut self) {
293 self.mark_dirty();
294 self.cached_score = None;
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use solverforge_core::domain::{EntityDescriptor, TypedEntityExtractor};
    use solverforge_core::score::SimpleScore;
    use std::any::TypeId;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Clone, Debug, PartialEq)]
    struct Queen {
        id: i64,
        row: Option<i32>,
    }

    #[derive(Clone, Debug)]
    struct NQueensSolution {
        queens: Vec<Queen>,
        score: Option<SimpleScore>,
    }

    impl PlanningSolution for NQueensSolution {
        type Score = SimpleScore;

        fn score(&self) -> Option<Self::Score> {
            self.score
        }

        fn set_score(&mut self, score: Option<Self::Score>) {
            self.score = score;
        }
    }

    fn get_queens(s: &NQueensSolution) -> &Vec<Queen> {
        &s.queens
    }

    fn get_queens_mut(s: &mut NQueensSolution) -> &mut Vec<Queen> {
        &mut s.queens
    }

    /// Builds an unscored solution with one queen per column: queen `i` gets id `i`
    /// and row `rows[i]`.
    fn solution_with_rows(rows: &[Option<i32>]) -> NQueensSolution {
        NQueensSolution {
            queens: (0..)
                .zip(rows)
                .map(|(id, &row)| Queen { id, row })
                .collect(),
            score: None,
        }
    }

    /// Scores a board as minus the number of conflicts. A pair of assigned queens adds one
    /// conflict for sharing a row and one for sharing a diagonal. Unassigned queens never
    /// conflict.
    fn calculate_conflicts(solution: &NQueensSolution) -> SimpleScore {
        let mut conflicts = 0i64;
        let queens = &solution.queens;

        for i in 0..queens.len() {
            for j in (i + 1)..queens.len() {
                if let (Some(row_i), Some(row_j)) = (queens[i].row, queens[j].row) {
                    if row_i == row_j {
                        conflicts += 1;
                    }
                    let col_diff = (j - i) as i32;
                    if (row_i - row_j).abs() == col_diff {
                        conflicts += 1;
                    }
                }
            }
        }

        SimpleScore::of(-conflicts)
    }

    fn create_test_descriptor() -> SolutionDescriptor {
        let extractor = Box::new(TypedEntityExtractor::new(
            "Queen",
            "queens",
            get_queens,
            get_queens_mut,
        ));
        let entity_desc = EntityDescriptor::new("Queen", TypeId::of::<Queen>(), "queens")
            .with_extractor(extractor);

        SolutionDescriptor::new("NQueensSolution", TypeId::of::<NQueensSolution>())
            .with_entity(entity_desc)
    }

    #[test]
    fn test_simple_score_director_calculate_score() {
        // All four queens share one diagonal: C(4, 2) = 6 attacking pairs.
        let solution = solution_with_rows(&[Some(0), Some(1), Some(2), Some(3)]);
        let mut director = SimpleScoreDirector::with_calculator(
            solution,
            create_test_descriptor(),
            calculate_conflicts,
        );

        assert_eq!(director.calculate_score(), SimpleScore::of(-6));
    }

    #[test]
    fn test_score_director_factory() {
        let factory = ScoreDirectorFactory::new(create_test_descriptor(), calculate_conflicts);

        let mut director = factory.build_score_director(solution_with_rows(&[Some(0)]));
        assert_eq!(director.calculate_score(), SimpleScore::of(0));
    }

    #[test]
    fn test_working_solution_mut_invalidates_cached_score() {
        // Rows 0 and 2 in adjacent columns: no shared row, no shared diagonal.
        let mut director = SimpleScoreDirector::with_calculator(
            solution_with_rows(&[Some(0), Some(2)]),
            create_test_descriptor(),
            calculate_conflicts,
        );
        assert_eq!(director.calculate_score(), SimpleScore::of(0));

        // Moving the second queen onto row 0 creates exactly one row conflict.
        director.working_solution_mut().queens[1].row = Some(0);
        assert_eq!(director.calculate_score(), SimpleScore::of(-1));
        // The recomputed score is also written back onto the solution.
        assert_eq!(
            director.working_solution().score(),
            Some(SimpleScore::of(-1))
        );
    }

    #[test]
    fn test_cached_score_is_reused_until_invalidated() {
        let calls = AtomicUsize::new(0);
        let counting_calculator = |solution: &NQueensSolution| {
            calls.fetch_add(1, Ordering::SeqCst);
            calculate_conflicts(solution)
        };
        let mut director = SimpleScoreDirector::with_calculator(
            solution_with_rows(&[Some(0), Some(2)]),
            create_test_descriptor(),
            counting_calculator,
        );

        // A clean director answers repeated requests from its cache.
        assert_eq!(director.calculate_score(), SimpleScore::of(0));
        assert_eq!(director.calculate_score(), SimpleScore::of(0));
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // A change notification forces a recomputation.
        director.before_variable_changed(0, 1, "row");
        director.after_variable_changed(0, 1, "row");
        assert_eq!(director.calculate_score(), SimpleScore::of(0));
        assert_eq!(calls.load(Ordering::SeqCst), 2);

        // So does an explicit reset.
        director.reset();
        assert_eq!(director.calculate_score(), SimpleScore::of(0));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }
}