1use crate::constraints::{StreamComponent, WasmFunction};
2use crate::wasm::PredicateDefinition;
3use serde::{Deserialize, Serialize};
4
/// A single named constraint: an ordered list of stream components plus
/// optional metadata (package, description, group) and optional
/// constraint-level indictment/justification functions.
///
/// All `Option` fields are left out of the serialized form when they are `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Constraint {
    /// Constraint name; joined with `package` by `full_name()` as `package/name`.
    pub name: String,
    /// Optional package (namespace) the constraint belongs to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub package: Option<String>,
    /// Optional free-text description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional group label (e.g. "Hard constraints").
    #[serde(skip_serializing_if = "Option::is_none")]
    pub group: Option<String>,
    /// Stream components, kept in the order they were added.
    pub components: Vec<StreamComponent>,
    /// Optional constraint-level indictment function.
    /// NOTE(review): within this file it is not read by `ConstraintSet::to_dto`
    /// or `ConstraintSet::extract_predicates`; confirm where it is consumed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub indictment: Option<WasmFunction>,
    /// Optional constraint-level justification function (same caveat as `indictment`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub justification: Option<WasmFunction>,
}
20
21impl Constraint {
22 pub fn new(name: impl Into<String>) -> Self {
23 Self {
24 name: name.into(),
25 package: None,
26 description: None,
27 group: None,
28 components: Vec::new(),
29 indictment: None,
30 justification: None,
31 }
32 }
33
34 pub fn with_package(mut self, package: impl Into<String>) -> Self {
35 self.package = Some(package.into());
36 self
37 }
38
39 pub fn with_description(mut self, description: impl Into<String>) -> Self {
40 self.description = Some(description.into());
41 self
42 }
43
44 pub fn with_group(mut self, group: impl Into<String>) -> Self {
45 self.group = Some(group.into());
46 self
47 }
48
49 pub fn with_component(mut self, component: StreamComponent) -> Self {
50 self.components.push(component);
51 self
52 }
53
54 pub fn with_components(mut self, components: Vec<StreamComponent>) -> Self {
55 self.components = components;
56 self
57 }
58
59 pub fn with_indictment(mut self, indictment: WasmFunction) -> Self {
60 self.indictment = Some(indictment);
61 self
62 }
63
64 pub fn with_justification(mut self, justification: WasmFunction) -> Self {
65 self.justification = Some(justification);
66 self
67 }
68
69 pub fn full_name(&self) -> String {
70 match &self.package {
71 Some(pkg) => format!("{}/{}", pkg, self.name),
72 None => self.name.clone(),
73 }
74 }
75}
76
/// An ordered collection of [`Constraint`]s. The derived `Default` is an empty set.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct ConstraintSet {
    /// Constraints in insertion order; duplicates are not rejected.
    pub constraints: Vec<Constraint>,
}
81
82impl ConstraintSet {
83 pub fn new() -> Self {
84 Self {
85 constraints: Vec::new(),
86 }
87 }
88
89 pub fn with_constraint(mut self, constraint: Constraint) -> Self {
90 self.constraints.push(constraint);
91 self
92 }
93
94 pub fn add_constraint(&mut self, constraint: Constraint) {
95 self.constraints.push(constraint);
96 }
97
98 pub fn len(&self) -> usize {
99 self.constraints.len()
100 }
101
102 pub fn is_empty(&self) -> bool {
103 self.constraints.is_empty()
104 }
105
106 pub fn iter(&self) -> impl Iterator<Item = &Constraint> {
107 self.constraints.iter()
108 }
109
110 pub fn to_dto(&self) -> indexmap::IndexMap<String, Vec<StreamComponent>> {
111 self.constraints
112 .iter()
113 .map(|c| (c.name.clone(), c.components.clone()))
114 .collect()
115 }
116
117 pub fn extract_predicates(&self) -> Vec<PredicateDefinition> {
123 let mut predicates = Vec::new();
124 let mut seen = std::collections::HashSet::new();
125
126 for constraint in &self.constraints {
127 Self::collect_predicates_from_components(
128 &constraint.components,
129 &mut predicates,
130 &mut seen,
131 );
132 }
133
134 predicates
135 }
136
137 fn collect_predicates_from_components(
138 components: &[StreamComponent],
139 predicates: &mut Vec<PredicateDefinition>,
140 seen: &mut std::collections::HashSet<String>,
141 ) {
142 for component in components {
143 Self::collect_from_component(component, predicates, seen);
144 }
145 }
146
147 fn collect_from_component(
148 component: &StreamComponent,
149 predicates: &mut Vec<PredicateDefinition>,
150 seen: &mut std::collections::HashSet<String>,
151 ) {
152 match component {
153 StreamComponent::Filter { predicate } => {
154 Self::add_predicate_if_new(predicate, 1, predicates, seen);
155 }
156 StreamComponent::Penalize {
157 scale_by: Some(scale_by),
158 ..
159 }
160 | StreamComponent::Reward {
161 scale_by: Some(scale_by),
162 ..
163 }
164 | StreamComponent::Impact {
165 scale_by: Some(scale_by),
166 ..
167 } => {
168 Self::add_predicate_if_new(scale_by, 1, predicates, seen);
171 }
172 StreamComponent::Map { mappers } | StreamComponent::Expand { mappers } => {
173 for mapper in mappers {
174 Self::add_predicate_if_new(mapper, 1, predicates, seen);
175 }
176 }
177 StreamComponent::GroupBy { keys, .. } => {
178 for key in keys {
179 Self::add_predicate_if_new(key, 1, predicates, seen);
180 }
181 }
182 StreamComponent::FlattenLast { map: Some(map) } => {
183 Self::add_predicate_if_new(map, 1, predicates, seen);
184 }
185 StreamComponent::IndictWith {
186 indicted_object_provider,
187 } => {
188 Self::add_predicate_if_new(indicted_object_provider, 1, predicates, seen);
189 }
190 StreamComponent::JustifyWith {
191 justification_supplier,
192 } => {
193 Self::add_predicate_if_new(justification_supplier, 1, predicates, seen);
194 }
195 StreamComponent::Concat { other_components } => {
196 Self::collect_predicates_from_components(other_components, predicates, seen);
197 }
198 StreamComponent::ForEachUniquePair { joiners, .. }
199 | StreamComponent::Join { joiners, .. }
200 | StreamComponent::IfExists { joiners, .. }
201 | StreamComponent::IfNotExists { joiners, .. }
202 | StreamComponent::IfExistsOther { joiners, .. }
203 | StreamComponent::IfNotExistsOther { joiners, .. }
204 | StreamComponent::IfExistsIncludingUnassigned { joiners, .. }
205 | StreamComponent::IfNotExistsIncludingUnassigned { joiners, .. } => {
206 Self::collect_from_joiners(joiners, predicates, seen);
207 }
208 StreamComponent::ForEach { .. }
210 | StreamComponent::ForEachIncludingUnassigned { .. }
211 | StreamComponent::Complement { .. }
212 | StreamComponent::Distinct
213 | StreamComponent::Penalize { scale_by: None, .. }
214 | StreamComponent::Reward { scale_by: None, .. }
215 | StreamComponent::Impact { scale_by: None, .. }
216 | StreamComponent::FlattenLast { map: None } => {}
217 }
218 }
219
220 fn collect_from_joiners(
221 joiners: &[crate::constraints::Joiner],
222 predicates: &mut Vec<PredicateDefinition>,
223 seen: &mut std::collections::HashSet<String>,
224 ) {
225 use crate::constraints::Joiner;
226 for joiner in joiners {
227 match joiner {
228 Joiner::Equal {
229 map,
230 left_map,
231 right_map,
232 relation_predicate,
233 hasher,
234 } => {
235 if let Some(f) = map {
236 Self::add_predicate_if_new(f, 1, predicates, seen);
237 }
238 if let Some(f) = left_map {
239 Self::add_predicate_if_new(f, 1, predicates, seen);
240 }
241 if let Some(f) = right_map {
242 Self::add_predicate_if_new(f, 1, predicates, seen);
243 }
244 if let Some(f) = relation_predicate {
245 Self::add_predicate_if_new(f, 2, predicates, seen);
246 }
247 if let Some(f) = hasher {
248 Self::add_predicate_if_new(f, 1, predicates, seen);
249 }
250 }
251 Joiner::LessThan {
252 map,
253 left_map,
254 right_map,
255 comparator,
256 }
257 | Joiner::LessThanOrEqual {
258 map,
259 left_map,
260 right_map,
261 comparator,
262 }
263 | Joiner::GreaterThan {
264 map,
265 left_map,
266 right_map,
267 comparator,
268 }
269 | Joiner::GreaterThanOrEqual {
270 map,
271 left_map,
272 right_map,
273 comparator,
274 } => {
275 if let Some(f) = map {
276 Self::add_predicate_if_new(f, 1, predicates, seen);
277 }
278 if let Some(f) = left_map {
279 Self::add_predicate_if_new(f, 1, predicates, seen);
280 }
281 if let Some(f) = right_map {
282 Self::add_predicate_if_new(f, 1, predicates, seen);
283 }
284 Self::add_predicate_if_new(comparator, 2, predicates, seen);
285 }
286 Joiner::Overlapping {
287 start_map,
288 end_map,
289 left_start_map,
290 left_end_map,
291 right_start_map,
292 right_end_map,
293 comparator,
294 } => {
295 if let Some(f) = start_map {
296 Self::add_predicate_if_new(f, 1, predicates, seen);
297 }
298 if let Some(f) = end_map {
299 Self::add_predicate_if_new(f, 1, predicates, seen);
300 }
301 if let Some(f) = left_start_map {
302 Self::add_predicate_if_new(f, 1, predicates, seen);
303 }
304 if let Some(f) = left_end_map {
305 Self::add_predicate_if_new(f, 1, predicates, seen);
306 }
307 if let Some(f) = right_start_map {
308 Self::add_predicate_if_new(f, 1, predicates, seen);
309 }
310 if let Some(f) = right_end_map {
311 Self::add_predicate_if_new(f, 1, predicates, seen);
312 }
313 if let Some(f) = comparator {
314 Self::add_predicate_if_new(f, 2, predicates, seen);
315 }
316 }
317 Joiner::Filtering { filter } => {
318 Self::add_predicate_if_new(filter, 2, predicates, seen);
319 }
320 }
321 }
322 }
323
324 fn add_predicate_if_new(
325 func: &WasmFunction,
326 arity: u32,
327 predicates: &mut Vec<PredicateDefinition>,
328 seen: &mut std::collections::HashSet<String>,
329 ) {
330 if let Some(expr) = func.expression() {
331 if !seen.contains(func.name()) {
332 seen.insert(func.name().to_string());
333 predicates.push(PredicateDefinition::from_expression(
334 func.name(),
335 arity,
336 expr.clone(),
337 ));
338 }
339 }
340 }
341}
342
343impl FromIterator<Constraint> for ConstraintSet {
344 fn from_iter<I: IntoIterator<Item = Constraint>>(iter: I) -> Self {
345 ConstraintSet {
346 constraints: iter.into_iter().collect(),
347 }
348 }
349}
350
#[cfg(test)]
mod tests {
    // Unit tests for the `Constraint` builder and the `ConstraintSet` collection,
    // including JSON round-trips and the `to_dto` / `extract_predicates` views.
    use super::*;
    use crate::constraints::Joiner;

    #[test]
    fn test_constraint_new() {
        let constraint = Constraint::new("Room conflict");
        assert_eq!(constraint.name, "Room conflict");
        assert!(constraint.package.is_none());
        assert!(constraint.components.is_empty());
    }

    #[test]
    fn test_constraint_with_package() {
        let constraint = Constraint::new("Room conflict").with_package("timetabling");
        assert_eq!(constraint.package, Some("timetabling".to_string()));
    }

    #[test]
    fn test_constraint_with_description() {
        let constraint =
            Constraint::new("Room conflict").with_description("Two lessons in same room");
        assert_eq!(
            constraint.description,
            Some("Two lessons in same room".to_string())
        );
    }

    #[test]
    fn test_constraint_with_group() {
        let constraint = Constraint::new("Room conflict").with_group("Hard constraints");
        assert_eq!(constraint.group, Some("Hard constraints".to_string()));
    }

    #[test]
    fn test_constraint_with_component() {
        let constraint = Constraint::new("Room conflict")
            .with_component(StreamComponent::for_each("Lesson"))
            .with_component(StreamComponent::penalize("1hard"));
        assert_eq!(constraint.components.len(), 2);
    }

    #[test]
    fn test_constraint_with_components() {
        let components = vec![
            StreamComponent::for_each("Lesson"),
            StreamComponent::penalize("1hard"),
        ];
        let constraint = Constraint::new("Room conflict").with_components(components);
        assert_eq!(constraint.components.len(), 2);
    }

    #[test]
    fn test_constraint_with_components_replaces_existing() {
        // `with_components` replaces the list rather than appending to it.
        let constraint = Constraint::new("Room conflict")
            .with_component(StreamComponent::for_each("Lesson"))
            .with_components(vec![StreamComponent::for_each("Room")]);
        assert_eq!(
            constraint.components,
            vec![StreamComponent::for_each("Room")]
        );
    }

    #[test]
    fn test_constraint_with_indictment() {
        let constraint =
            Constraint::new("Room conflict").with_indictment(WasmFunction::new("get_room"));
        assert!(constraint.indictment.is_some());
    }

    #[test]
    fn test_constraint_with_justification() {
        let constraint = Constraint::new("Room conflict")
            .with_justification(WasmFunction::new("create_justification"));
        assert!(constraint.justification.is_some());
    }

    #[test]
    fn test_constraint_full_name() {
        let constraint1 = Constraint::new("Room conflict");
        assert_eq!(constraint1.full_name(), "Room conflict");

        let constraint2 = Constraint::new("Room conflict").with_package("timetabling");
        assert_eq!(constraint2.full_name(), "timetabling/Room conflict");
    }

    #[test]
    fn test_constraint_set_new() {
        let set = ConstraintSet::new();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
    }

    #[test]
    fn test_constraint_set_with_constraint() {
        let set = ConstraintSet::new()
            .with_constraint(Constraint::new("Constraint 1"))
            .with_constraint(Constraint::new("Constraint 2"));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_constraint_set_add_constraint() {
        let mut set = ConstraintSet::new();
        set.add_constraint(Constraint::new("Constraint 1"));
        set.add_constraint(Constraint::new("Constraint 2"));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_constraint_set_iter() {
        let set = ConstraintSet::new()
            .with_constraint(Constraint::new("C1"))
            .with_constraint(Constraint::new("C2"));

        let names: Vec<_> = set.iter().map(|c| c.name.as_str()).collect();
        assert_eq!(names, vec!["C1", "C2"]);
    }

    #[test]
    fn test_constraint_set_from_iter() {
        let constraints = vec![Constraint::new("C1"), Constraint::new("C2")];
        let set: ConstraintSet = constraints.into_iter().collect();
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_constraint_set_to_dto() {
        // Keys follow insertion order and use the bare `name` (package is not part of the key).
        let set = ConstraintSet::new()
            .with_constraint(
                Constraint::new("C1")
                    .with_package("pkg")
                    .with_component(StreamComponent::for_each("Lesson")),
            )
            .with_constraint(
                Constraint::new("C2").with_component(StreamComponent::for_each("Room")),
            );

        let dto = set.to_dto();
        let keys: Vec<&str> = dto.keys().map(String::as_str).collect();
        assert_eq!(keys, vec!["C1", "C2"]);
        assert_eq!(
            dto.get("C1"),
            Some(&vec![StreamComponent::for_each("Lesson")])
        );
        assert_eq!(
            dto.get("C2"),
            Some(&vec![StreamComponent::for_each("Room")])
        );
    }

    #[test]
    fn test_extract_predicates_without_functions() {
        assert!(ConstraintSet::new().extract_predicates().is_empty());

        let set = ConstraintSet::new().with_constraint(
            Constraint::new("C1")
                .with_component(StreamComponent::for_each("Lesson"))
                .with_component(StreamComponent::penalize("1hard")),
        );
        assert!(set.extract_predicates().is_empty());
    }

    #[test]
    fn test_constraint_json_serialization() {
        let constraint = Constraint::new("Room conflict")
            .with_package("timetabling")
            .with_component(StreamComponent::for_each_unique_pair_with_joiners(
                "Lesson",
                vec![Joiner::equal(WasmFunction::new("get_timeslot"))],
            ))
            .with_component(StreamComponent::filter(WasmFunction::new("same_room")))
            .with_component(StreamComponent::penalize("1hard"));

        let json = serde_json::to_string(&constraint).unwrap();
        assert!(json.contains("\"name\":\"Room conflict\""));
        assert!(json.contains("\"package\":\"timetabling\""));
        assert!(json.contains("\"components\""));

        let parsed: Constraint = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, constraint);
    }

    #[test]
    fn test_constraint_set_json_serialization() {
        let set = ConstraintSet::new()
            .with_constraint(
                Constraint::new("C1")
                    .with_component(StreamComponent::for_each("Lesson"))
                    .with_component(StreamComponent::penalize("1hard")),
            )
            .with_constraint(
                Constraint::new("C2")
                    .with_component(StreamComponent::for_each("Room"))
                    .with_component(StreamComponent::reward("1soft")),
            );

        let json = serde_json::to_string(&set).unwrap();
        let parsed: ConstraintSet = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.len(), 2);
        // A length check alone would miss lost or corrupted content; compare the whole set.
        assert_eq!(parsed, set);
    }

    #[test]
    fn test_realistic_room_conflict_constraint() {
        let constraint = Constraint::new("Room conflict")
            .with_package("school.timetabling")
            .with_description("A room can accommodate at most one lesson at the same time.")
            .with_group("Hard constraints")
            .with_component(StreamComponent::for_each_unique_pair_with_joiners(
                "Lesson",
                vec![
                    Joiner::equal(WasmFunction::new("get_timeslot")),
                    Joiner::equal(WasmFunction::new("get_room")),
                ],
            ))
            .with_component(StreamComponent::penalize("1hard"));

        assert_eq!(constraint.components.len(), 2);
        assert_eq!(constraint.full_name(), "school.timetabling/Room conflict");
    }

    #[test]
    fn test_constraint_clone() {
        let constraint = Constraint::new("Test")
            .with_package("pkg")
            .with_component(StreamComponent::for_each("Entity"));
        let cloned = constraint.clone();
        assert_eq!(constraint, cloned);
    }

    #[test]
    fn test_constraint_debug() {
        let constraint = Constraint::new("Test");
        let debug = format!("{:?}", constraint);
        assert!(debug.contains("Constraint"));
        assert!(debug.contains("Test"));
    }
}