1use std::fmt;
2
3use crate::planning::{
4 ScalarAssignmentDeclaration, ScalarAssignmentRule, ScalarCandidateProvider, ScalarEdit,
5 ScalarGroup, ScalarGroupKind, ScalarGroupLimits,
6};
7
8use super::value_source::ValueSource;
9use super::variable::{ScalarGetter, ScalarSetter, ScalarVariableSlot};
10
/// One scalar planning variable resolved for use inside a scalar group.
///
/// Every field is copied from the matching [`ScalarVariableSlot`] by
/// [`ScalarGroupMemberBinding::from_scalar_slot`].
pub struct ScalarGroupMemberBinding<S> {
    /// Index of the descriptor that owns the variable; together with
    /// `variable_name` it is used to match slots and [`ScalarEdit`]s.
    pub descriptor_index: usize,
    /// Index handed to the getter and the per-entity value hooks to select
    /// this variable.
    pub variable_index: usize,
    /// Entity type name, used in diagnostics (assert messages, `Debug`).
    pub entity_type_name: &'static str,
    /// Variable name; together with `descriptor_index` it identifies the variable.
    pub variable_name: &'static str,
    /// Reads the current value, called as `(solution, entity_index, variable_index)`.
    pub getter: ScalarGetter<S>,
    /// Setter copied from the slot (not called anywhere in this file).
    pub setter: ScalarSetter<S>,
    /// Source of candidates when `candidate_values` is absent, and the set
    /// that `value_is_legal` checks concrete values against.
    pub value_source: ValueSource<S>,
    /// Returns the entity count; entity indices range over `0..count`.
    pub entity_count: fn(&S) -> usize,
    /// Optional per-entity candidate list. When present it replaces
    /// `value_source` as the source of candidates, but not of legality.
    pub candidate_values: Option<super::variable::ScalarCandidateValues<S>>,
    /// Whether `None` (unassigned) is a legal value.
    pub allows_unassigned: bool,
}
23
24impl<S> Clone for ScalarGroupMemberBinding<S> {
25 fn clone(&self) -> Self {
26 *self
27 }
28}
29
30impl<S> Copy for ScalarGroupMemberBinding<S> {}
31
32impl<S> ScalarGroupMemberBinding<S> {
33 pub fn from_scalar_slot(slot: ScalarVariableSlot<S>) -> Self {
34 Self {
35 descriptor_index: slot.descriptor_index,
36 variable_index: slot.variable_index,
37 entity_type_name: slot.entity_type_name,
38 variable_name: slot.variable_name,
39 getter: slot.getter,
40 setter: slot.setter,
41 value_source: slot.value_source,
42 entity_count: slot.entity_count,
43 candidate_values: slot.candidate_values,
44 allows_unassigned: slot.allows_unassigned,
45 }
46 }
47
48 pub fn current_value(&self, solution: &S, entity_index: usize) -> Option<usize> {
49 (self.getter)(solution, entity_index, self.variable_index)
50 }
51
52 pub fn value_is_legal(
53 &self,
54 solution: &S,
55 entity_index: usize,
56 candidate: Option<usize>,
57 ) -> bool {
58 let Some(value) = candidate else {
59 return self.allows_unassigned;
60 };
61 match self.value_source {
62 ValueSource::Empty => false,
63 ValueSource::CountableRange { from, to } => from <= value && value < to,
64 ValueSource::SolutionCount {
65 count_fn,
66 provider_index,
67 } => value < count_fn(solution, provider_index),
68 ValueSource::EntitySlice { values_for_entity } => {
69 values_for_entity(solution, entity_index, self.variable_index).contains(&value)
70 }
71 }
72 }
73
74 pub fn entity_count(&self, solution: &S) -> usize {
75 (self.entity_count)(solution)
76 }
77
78 pub fn candidate_values(
79 &self,
80 solution: &S,
81 entity_index: usize,
82 value_candidate_limit: Option<usize>,
83 ) -> Vec<usize> {
84 if let Some(candidate_values) = self.candidate_values {
85 let values = candidate_values(solution, entity_index, self.variable_index);
86 return match value_candidate_limit {
87 Some(limit) => values.iter().copied().take(limit).collect(),
88 None => values.to_vec(),
89 };
90 }
91 match self.value_source {
92 ValueSource::Empty => Vec::new(),
93 ValueSource::CountableRange { from, to } => {
94 let end = value_candidate_limit
95 .map(|limit| from.saturating_add(limit).min(to))
96 .unwrap_or(to);
97 (from..end).collect()
98 }
99 ValueSource::SolutionCount {
100 count_fn,
101 provider_index,
102 } => {
103 let count = count_fn(solution, provider_index);
104 let end = value_candidate_limit
105 .map(|limit| limit.min(count))
106 .unwrap_or(count);
107 (0..end).collect()
108 }
109 ValueSource::EntitySlice { values_for_entity } => {
110 let values = values_for_entity(solution, entity_index, self.variable_index);
111 match value_candidate_limit {
112 Some(limit) => values.iter().copied().take(limit).collect(),
113 None => values.to_vec(),
114 }
115 }
116 }
117 }
118}
119
120impl<S> fmt::Debug for ScalarGroupMemberBinding<S> {
121 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122 f.debug_struct("ScalarGroupMemberBinding")
123 .field("descriptor_index", &self.descriptor_index)
124 .field("variable_index", &self.variable_index)
125 .field("entity_type_name", &self.entity_type_name)
126 .field("variable_name", &self.variable_name)
127 .field("value_source", &self.value_source)
128 .field("allows_unassigned", &self.allows_unassigned)
129 .finish()
130 }
131}
132
/// A [`ScalarGroup`] whose declared targets have been resolved against the
/// scalar variable slots; built by [`ScalarGroupBinding::bind`].
pub struct ScalarGroupBinding<S> {
    /// Name copied from the source group.
    pub group_name: &'static str,
    /// One binding per declared target, in the group's target order.
    pub members: Vec<ScalarGroupMemberBinding<S>>,
    /// Bound form of the group's kind (candidate provider or assignment).
    pub kind: ScalarGroupBindingKind<S>,
    /// Limits copied from the source group.
    pub limits: ScalarGroupLimits,
}
139
/// Bound form of a [`ScalarGroupKind`].
pub enum ScalarGroupBindingKind<S> {
    /// Carries the group's candidate provider unchanged.
    Candidates {
        candidate_provider: ScalarCandidateProvider<S>,
    },
    /// Validated assignment metadata for the group's single target variable.
    Assignment(ScalarAssignmentBinding<S>),
}
146
147impl<S> Clone for ScalarGroupBindingKind<S> {
148 fn clone(&self) -> Self {
149 *self
150 }
151}
152
153impl<S> Copy for ScalarGroupBindingKind<S> {}
154
/// Assignment metadata for a scalar group targeting exactly one scalar
/// variable; validated by `ScalarAssignmentBinding::bind`.
pub struct ScalarAssignmentBinding<S> {
    /// The single member targeted by the group; it must allow unassigned values.
    pub target: ScalarGroupMemberBinding<S>,
    /// `(solution, entity)` -> whether the entity is required. When absent,
    /// no entity is required.
    pub required_entity: Option<fn(&S, usize) -> bool>,
    /// `(solution, entity, value)` -> optional capacity key; its
    /// interpretation is up to the callers of `capacity_key`.
    pub capacity_key: Option<fn(&S, usize, usize) -> Option<usize>>,
    /// `(solution, entity)` -> position key.
    pub position_key: Option<fn(&S, usize) -> i64>,
    /// `(solution, entity, value)` -> optional sequence key; mandatory when
    /// `assignment_rule` is declared.
    pub sequence_key: Option<fn(&S, usize, usize) -> Option<usize>>,
    /// `(solution, entity)` -> entity ordering key.
    pub entity_order: Option<fn(&S, usize) -> i64>,
    /// `(solution, entity, value)` -> key used to sort candidate values.
    pub value_order: Option<fn(&S, usize, usize) -> i64>,
    /// Pairwise check over two `(entity, value)` assignments. When absent,
    /// every pair is allowed.
    pub assignment_rule: Option<ScalarAssignmentRule<S>>,
}
165
166impl<S> Clone for ScalarAssignmentBinding<S> {
167 fn clone(&self) -> Self {
168 *self
169 }
170}
171
172impl<S> Copy for ScalarAssignmentBinding<S> {}
173
174impl<S> ScalarAssignmentBinding<S> {
175 fn bind(
176 group_name: &'static str,
177 members: &[ScalarGroupMemberBinding<S>],
178 declaration: ScalarAssignmentDeclaration<S>,
179 ) -> Self {
180 assert_eq!(
181 members.len(),
182 1,
183 "assignment scalar group `{group_name}` must target exactly one scalar planning variable",
184 );
185 let target = members[0];
186 assert!(
187 target.allows_unassigned,
188 "assignment scalar group `{group_name}` target {}.{} must allow unassigned values",
189 target.entity_type_name, target.variable_name,
190 );
191 assert!(
192 declaration.assignment_rule.is_none() || declaration.sequence_key.is_some(),
193 "assignment scalar group `{group_name}` with an assignment rule must declare a sequence key",
194 );
195 Self {
196 target,
197 required_entity: declaration.required_entity,
198 capacity_key: declaration.capacity_key,
199 position_key: declaration.position_key,
200 sequence_key: declaration.sequence_key,
201 entity_order: declaration.entity_order,
202 value_order: declaration.value_order,
203 assignment_rule: declaration.assignment_rule,
204 }
205 }
206
207 pub fn entity_count(&self, solution: &S) -> usize {
208 self.target.entity_count(solution)
209 }
210
211 pub fn current_value(&self, solution: &S, entity_index: usize) -> Option<usize> {
212 self.target.current_value(solution, entity_index)
213 }
214
215 pub fn is_required(&self, solution: &S, entity_index: usize) -> bool {
216 self.required_entity
217 .map(|required_entity| required_entity(solution, entity_index))
218 .unwrap_or(false)
219 }
220
221 pub fn capacity_key(&self, solution: &S, entity_index: usize, value: usize) -> Option<usize> {
222 self.capacity_key
223 .and_then(|capacity_key| capacity_key(solution, entity_index, value))
224 }
225
226 pub fn position_key(&self, solution: &S, entity_index: usize) -> Option<i64> {
227 self.position_key
228 .map(|position_key| position_key(solution, entity_index))
229 }
230
231 pub fn sequence_key(&self, solution: &S, entity_index: usize, value: usize) -> Option<usize> {
232 self.sequence_key
233 .and_then(|sequence_key| sequence_key(solution, entity_index, value))
234 }
235
236 pub fn entity_order_key(&self, solution: &S, entity_index: usize) -> Option<i64> {
237 self.entity_order
238 .map(|entity_order| entity_order(solution, entity_index))
239 }
240
241 pub fn value_order_key(&self, solution: &S, entity_index: usize, value: usize) -> Option<i64> {
242 self.value_order
243 .map(|value_order| value_order(solution, entity_index, value))
244 }
245
246 pub fn assignment_edge_allowed(
247 &self,
248 solution: &S,
249 left_entity: usize,
250 left_value: usize,
251 right_entity: usize,
252 right_value: usize,
253 ) -> bool {
254 self.assignment_rule
255 .map(|assignment_rule| {
256 assignment_rule(solution, left_entity, left_value, right_entity, right_value)
257 })
258 .unwrap_or(true)
259 }
260
261 pub fn candidate_values(
262 &self,
263 solution: &S,
264 entity_index: usize,
265 value_candidate_limit: Option<usize>,
266 ) -> Vec<usize> {
267 let mut values =
268 self.target
269 .candidate_values(solution, entity_index, value_candidate_limit);
270 values.sort_by_key(|value| (self.value_order_key(solution, entity_index, *value), *value));
271 values
272 }
273
274 pub fn value_is_legal(&self, solution: &S, entity_index: usize, value: Option<usize>) -> bool {
275 self.target.value_is_legal(solution, entity_index, value)
276 }
277
278 pub fn edit(&self, entity_index: usize, value: Option<usize>) -> ScalarEdit<S> {
279 ScalarEdit::from_descriptor_index(
280 self.target.descriptor_index,
281 entity_index,
282 self.target.variable_name,
283 value,
284 )
285 }
286
287 pub fn remaining_required_count(&self, solution: &S) -> u64 {
288 (0..self.entity_count(solution))
289 .filter(|entity_index| {
290 self.is_required(solution, *entity_index)
291 && self.current_value(solution, *entity_index).is_none()
292 })
293 .fold(0_u64, |count, _| count.saturating_add(1))
294 }
295
296 pub fn unassigned_count(&self, solution: &S) -> u64 {
297 (0..self.entity_count(solution))
298 .filter(|entity_index| self.current_value(solution, *entity_index).is_none())
299 .fold(0_u64, |count, _| count.saturating_add(1))
300 }
301}
302
303impl<S> ScalarGroupBinding<S> {
304 pub fn bind(group: ScalarGroup<S>, scalar_slots: &[ScalarVariableSlot<S>]) -> Self {
305 let members = group
306 .targets()
307 .iter()
308 .map(|target| {
309 let descriptor_index = target.descriptor_index();
310 let variable_name = target.variable_name();
311 let slot = scalar_slots
312 .iter()
313 .copied()
314 .find(|slot| {
315 slot.descriptor_index == descriptor_index
316 && slot.variable_name == variable_name
317 })
318 .unwrap_or_else(|| {
319 panic!(
320 "scalar group `{}` targets unknown scalar variable `{}` on descriptor {}",
321 group.group_name(),
322 variable_name,
323 descriptor_index
324 )
325 });
326 ScalarGroupMemberBinding::from_scalar_slot(slot)
327 })
328 .collect::<Vec<_>>();
329
330 let kind = match group.kind() {
331 ScalarGroupKind::Assignment(declaration) => ScalarGroupBindingKind::Assignment(
332 ScalarAssignmentBinding::bind(group.group_name(), &members, declaration),
333 ),
334 ScalarGroupKind::Candidates { candidate_provider } => {
335 ScalarGroupBindingKind::Candidates { candidate_provider }
336 }
337 };
338
339 Self {
340 group_name: group.group_name(),
341 members,
342 kind,
343 limits: group.limits(),
344 }
345 }
346
347 pub fn member_for_edit(&self, edit: &ScalarEdit<S>) -> Option<ScalarGroupMemberBinding<S>> {
348 self.members.iter().copied().find(|member| {
349 member.descriptor_index == edit.descriptor_index()
350 && member.variable_name == edit.variable_name()
351 })
352 }
353
354 pub fn assignment(&self) -> Option<&ScalarAssignmentBinding<S>> {
355 match &self.kind {
356 ScalarGroupBindingKind::Assignment(assignment) => Some(assignment),
357 ScalarGroupBindingKind::Candidates { .. } => None,
358 }
359 }
360
361 pub fn is_assignment(&self) -> bool {
362 matches!(self.kind, ScalarGroupBindingKind::Assignment(_))
363 }
364
365 pub fn is_candidate_group(&self) -> bool {
366 matches!(self.kind, ScalarGroupBindingKind::Candidates { .. })
367 }
368
369 pub fn has_sequence_metadata(&self) -> bool {
370 self.assignment()
371 .is_some_and(|assignment| assignment.sequence_key.is_some())
372 }
373
374 pub fn has_position_metadata(&self) -> bool {
375 self.assignment()
376 .is_some_and(|assignment| assignment.position_key.is_some())
377 }
378
379 pub fn default_max_moves_per_step(&self) -> Option<usize> {
380 self.limits.max_moves_per_step
381 }
382}
383
384impl<S> Clone for ScalarGroupBinding<S> {
385 fn clone(&self) -> Self {
386 Self {
387 group_name: self.group_name,
388 members: self.members.clone(),
389 kind: self.kind,
390 limits: self.limits,
391 }
392 }
393}
394
395impl<S> fmt::Debug for ScalarGroupBinding<S> {
396 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
397 f.debug_struct("ScalarGroupBinding")
398 .field("group_name", &self.group_name)
399 .field("members", &self.members)
400 .finish_non_exhaustive()
401 }
402}
403
404pub fn bind_scalar_groups<S>(
405 groups: Vec<ScalarGroup<S>>,
406 scalar_slots: &[ScalarVariableSlot<S>],
407) -> Vec<ScalarGroupBinding<S>> {
408 groups
409 .into_iter()
410 .map(|group| ScalarGroupBinding::bind(group, scalar_slots))
411 .collect()
412}