solverforge_core/constraints/
constraint.rs

1use crate::constraints::{StreamComponent, WasmFunction};
2use crate::wasm::PredicateDefinition;
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub struct Constraint {
7    pub name: String,
8    #[serde(skip_serializing_if = "Option::is_none")]
9    pub package: Option<String>,
10    #[serde(skip_serializing_if = "Option::is_none")]
11    pub description: Option<String>,
12    #[serde(skip_serializing_if = "Option::is_none")]
13    pub group: Option<String>,
14    pub components: Vec<StreamComponent>,
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub indictment: Option<WasmFunction>,
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub justification: Option<WasmFunction>,
19}
20
21impl Constraint {
22    pub fn new(name: impl Into<String>) -> Self {
23        Self {
24            name: name.into(),
25            package: None,
26            description: None,
27            group: None,
28            components: Vec::new(),
29            indictment: None,
30            justification: None,
31        }
32    }
33
34    pub fn with_package(mut self, package: impl Into<String>) -> Self {
35        self.package = Some(package.into());
36        self
37    }
38
39    pub fn with_description(mut self, description: impl Into<String>) -> Self {
40        self.description = Some(description.into());
41        self
42    }
43
44    pub fn with_group(mut self, group: impl Into<String>) -> Self {
45        self.group = Some(group.into());
46        self
47    }
48
49    pub fn with_component(mut self, component: StreamComponent) -> Self {
50        self.components.push(component);
51        self
52    }
53
54    pub fn with_components(mut self, components: Vec<StreamComponent>) -> Self {
55        self.components = components;
56        self
57    }
58
59    pub fn with_indictment(mut self, indictment: WasmFunction) -> Self {
60        self.indictment = Some(indictment);
61        self
62    }
63
64    pub fn with_justification(mut self, justification: WasmFunction) -> Self {
65        self.justification = Some(justification);
66        self
67    }
68
69    pub fn full_name(&self) -> String {
70        match &self.package {
71            Some(pkg) => format!("{}/{}", pkg, self.name),
72            None => self.name.clone(),
73        }
74    }
75}
76
77#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
78pub struct ConstraintSet {
79    pub constraints: Vec<Constraint>,
80}
81
82impl ConstraintSet {
83    pub fn new() -> Self {
84        Self {
85            constraints: Vec::new(),
86        }
87    }
88
89    pub fn with_constraint(mut self, constraint: Constraint) -> Self {
90        self.constraints.push(constraint);
91        self
92    }
93
94    pub fn add_constraint(&mut self, constraint: Constraint) {
95        self.constraints.push(constraint);
96    }
97
98    pub fn len(&self) -> usize {
99        self.constraints.len()
100    }
101
102    pub fn is_empty(&self) -> bool {
103        self.constraints.is_empty()
104    }
105
106    pub fn iter(&self) -> impl Iterator<Item = &Constraint> {
107        self.constraints.iter()
108    }
109
110    pub fn to_dto(&self) -> indexmap::IndexMap<String, Vec<StreamComponent>> {
111        self.constraints
112            .iter()
113            .map(|c| (c.name.clone(), c.components.clone()))
114            .collect()
115    }
116
117    /// Extracts all predicate definitions from the constraints.
118    ///
119    /// This walks through all constraint components and extracts `WasmFunction`s
120    /// that have associated expressions, converting them to `PredicateDefinition`s
121    /// for compilation into the WASM module.
122    pub fn extract_predicates(&self) -> Vec<PredicateDefinition> {
123        let mut predicates = Vec::new();
124        let mut seen = std::collections::HashSet::new();
125
126        for constraint in &self.constraints {
127            Self::collect_predicates_from_components(
128                &constraint.components,
129                &mut predicates,
130                &mut seen,
131            );
132        }
133
134        predicates
135    }
136
137    fn collect_predicates_from_components(
138        components: &[StreamComponent],
139        predicates: &mut Vec<PredicateDefinition>,
140        seen: &mut std::collections::HashSet<String>,
141    ) {
142        for component in components {
143            Self::collect_from_component(component, predicates, seen);
144        }
145    }
146
147    fn collect_from_component(
148        component: &StreamComponent,
149        predicates: &mut Vec<PredicateDefinition>,
150        seen: &mut std::collections::HashSet<String>,
151    ) {
152        match component {
153            StreamComponent::Filter { predicate } => {
154                Self::add_predicate_if_new(predicate, 1, predicates, seen);
155            }
156            StreamComponent::Penalize {
157                scale_by: Some(scale_by),
158                ..
159            }
160            | StreamComponent::Reward {
161                scale_by: Some(scale_by),
162                ..
163            }
164            | StreamComponent::Impact {
165                scale_by: Some(scale_by),
166                ..
167            } => {
168                // Scale functions can have different arities depending on stream type
169                // Most commonly it's 1 (single entity) but could be 2+ for joins
170                Self::add_predicate_if_new(scale_by, 1, predicates, seen);
171            }
172            StreamComponent::Map { mappers } | StreamComponent::Expand { mappers } => {
173                for mapper in mappers {
174                    Self::add_predicate_if_new(mapper, 1, predicates, seen);
175                }
176            }
177            StreamComponent::GroupBy { keys, .. } => {
178                for key in keys {
179                    Self::add_predicate_if_new(key, 1, predicates, seen);
180                }
181            }
182            StreamComponent::FlattenLast { map: Some(map) } => {
183                Self::add_predicate_if_new(map, 1, predicates, seen);
184            }
185            StreamComponent::IndictWith {
186                indicted_object_provider,
187            } => {
188                Self::add_predicate_if_new(indicted_object_provider, 1, predicates, seen);
189            }
190            StreamComponent::JustifyWith {
191                justification_supplier,
192            } => {
193                Self::add_predicate_if_new(justification_supplier, 1, predicates, seen);
194            }
195            StreamComponent::Concat { other_components } => {
196                Self::collect_predicates_from_components(other_components, predicates, seen);
197            }
198            StreamComponent::ForEachUniquePair { joiners, .. }
199            | StreamComponent::Join { joiners, .. }
200            | StreamComponent::IfExists { joiners, .. }
201            | StreamComponent::IfNotExists { joiners, .. }
202            | StreamComponent::IfExistsOther { joiners, .. }
203            | StreamComponent::IfNotExistsOther { joiners, .. }
204            | StreamComponent::IfExistsIncludingUnassigned { joiners, .. }
205            | StreamComponent::IfNotExistsIncludingUnassigned { joiners, .. } => {
206                Self::collect_from_joiners(joiners, predicates, seen);
207            }
208            // Components without functions
209            StreamComponent::ForEach { .. }
210            | StreamComponent::ForEachIncludingUnassigned { .. }
211            | StreamComponent::Complement { .. }
212            | StreamComponent::Distinct
213            | StreamComponent::Penalize { scale_by: None, .. }
214            | StreamComponent::Reward { scale_by: None, .. }
215            | StreamComponent::Impact { scale_by: None, .. }
216            | StreamComponent::FlattenLast { map: None } => {}
217        }
218    }
219
220    fn collect_from_joiners(
221        joiners: &[crate::constraints::Joiner],
222        predicates: &mut Vec<PredicateDefinition>,
223        seen: &mut std::collections::HashSet<String>,
224    ) {
225        use crate::constraints::Joiner;
226        for joiner in joiners {
227            match joiner {
228                Joiner::Equal {
229                    map,
230                    left_map,
231                    right_map,
232                    relation_predicate,
233                    hasher,
234                } => {
235                    if let Some(f) = map {
236                        Self::add_predicate_if_new(f, 1, predicates, seen);
237                    }
238                    if let Some(f) = left_map {
239                        Self::add_predicate_if_new(f, 1, predicates, seen);
240                    }
241                    if let Some(f) = right_map {
242                        Self::add_predicate_if_new(f, 1, predicates, seen);
243                    }
244                    if let Some(f) = relation_predicate {
245                        Self::add_predicate_if_new(f, 2, predicates, seen);
246                    }
247                    if let Some(f) = hasher {
248                        Self::add_predicate_if_new(f, 1, predicates, seen);
249                    }
250                }
251                Joiner::LessThan {
252                    map,
253                    left_map,
254                    right_map,
255                    comparator,
256                }
257                | Joiner::LessThanOrEqual {
258                    map,
259                    left_map,
260                    right_map,
261                    comparator,
262                }
263                | Joiner::GreaterThan {
264                    map,
265                    left_map,
266                    right_map,
267                    comparator,
268                }
269                | Joiner::GreaterThanOrEqual {
270                    map,
271                    left_map,
272                    right_map,
273                    comparator,
274                } => {
275                    if let Some(f) = map {
276                        Self::add_predicate_if_new(f, 1, predicates, seen);
277                    }
278                    if let Some(f) = left_map {
279                        Self::add_predicate_if_new(f, 1, predicates, seen);
280                    }
281                    if let Some(f) = right_map {
282                        Self::add_predicate_if_new(f, 1, predicates, seen);
283                    }
284                    Self::add_predicate_if_new(comparator, 2, predicates, seen);
285                }
286                Joiner::Overlapping {
287                    start_map,
288                    end_map,
289                    left_start_map,
290                    left_end_map,
291                    right_start_map,
292                    right_end_map,
293                    comparator,
294                } => {
295                    if let Some(f) = start_map {
296                        Self::add_predicate_if_new(f, 1, predicates, seen);
297                    }
298                    if let Some(f) = end_map {
299                        Self::add_predicate_if_new(f, 1, predicates, seen);
300                    }
301                    if let Some(f) = left_start_map {
302                        Self::add_predicate_if_new(f, 1, predicates, seen);
303                    }
304                    if let Some(f) = left_end_map {
305                        Self::add_predicate_if_new(f, 1, predicates, seen);
306                    }
307                    if let Some(f) = right_start_map {
308                        Self::add_predicate_if_new(f, 1, predicates, seen);
309                    }
310                    if let Some(f) = right_end_map {
311                        Self::add_predicate_if_new(f, 1, predicates, seen);
312                    }
313                    if let Some(f) = comparator {
314                        Self::add_predicate_if_new(f, 2, predicates, seen);
315                    }
316                }
317                Joiner::Filtering { filter } => {
318                    Self::add_predicate_if_new(filter, 2, predicates, seen);
319                }
320            }
321        }
322    }
323
324    fn add_predicate_if_new(
325        func: &WasmFunction,
326        arity: u32,
327        predicates: &mut Vec<PredicateDefinition>,
328        seen: &mut std::collections::HashSet<String>,
329    ) {
330        if let Some(expr) = func.expression() {
331            if !seen.contains(func.name()) {
332                seen.insert(func.name().to_string());
333                predicates.push(PredicateDefinition::from_expression(
334                    func.name(),
335                    arity,
336                    expr.clone(),
337                ));
338            }
339        }
340    }
341}
342
343impl FromIterator<Constraint> for ConstraintSet {
344    fn from_iter<I: IntoIterator<Item = Constraint>>(iter: I) -> Self {
345        ConstraintSet {
346            constraints: iter.into_iter().collect(),
347        }
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constraints::Joiner;

    #[test]
    fn test_constraint_new() {
        let constraint = Constraint::new("Room conflict");
        assert_eq!(constraint.name, "Room conflict");
        assert!(constraint.package.is_none());
        assert!(constraint.components.is_empty());
    }

    #[test]
    fn test_constraint_with_package() {
        let constraint = Constraint::new("Room conflict").with_package("timetabling");
        assert_eq!(constraint.package, Some("timetabling".to_string()));
    }

    #[test]
    fn test_constraint_with_description() {
        let constraint =
            Constraint::new("Room conflict").with_description("Two lessons in same room");
        assert_eq!(
            constraint.description,
            Some("Two lessons in same room".to_string())
        );
    }

    #[test]
    fn test_constraint_with_group() {
        let constraint = Constraint::new("Room conflict").with_group("Hard constraints");
        assert_eq!(constraint.group, Some("Hard constraints".to_string()));
    }

    #[test]
    fn test_constraint_with_component() {
        let constraint = Constraint::new("Room conflict")
            .with_component(StreamComponent::for_each("Lesson"))
            .with_component(StreamComponent::penalize("1hard"));
        assert_eq!(constraint.components.len(), 2);
    }

    #[test]
    fn test_constraint_with_components() {
        let components = vec![
            StreamComponent::for_each("Lesson"),
            StreamComponent::penalize("1hard"),
        ];
        let constraint = Constraint::new("Room conflict").with_components(components);
        assert_eq!(constraint.components.len(), 2);
    }

    /// `with_components` replaces the pipeline rather than appending to it.
    #[test]
    fn test_constraint_with_components_replaces_existing() {
        let constraint = Constraint::new("Room conflict")
            .with_component(StreamComponent::for_each("Lesson"))
            .with_components(vec![StreamComponent::penalize("1hard")]);
        assert_eq!(
            constraint.components,
            vec![StreamComponent::penalize("1hard")]
        );
    }

    #[test]
    fn test_constraint_with_indictment() {
        let constraint =
            Constraint::new("Room conflict").with_indictment(WasmFunction::new("get_room"));
        assert!(constraint.indictment.is_some());
    }

    #[test]
    fn test_constraint_with_justification() {
        let constraint = Constraint::new("Room conflict")
            .with_justification(WasmFunction::new("create_justification"));
        assert!(constraint.justification.is_some());
    }

    #[test]
    fn test_constraint_full_name() {
        let constraint1 = Constraint::new("Room conflict");
        assert_eq!(constraint1.full_name(), "Room conflict");

        let constraint2 = Constraint::new("Room conflict").with_package("timetabling");
        assert_eq!(constraint2.full_name(), "timetabling/Room conflict");
    }

    #[test]
    fn test_constraint_set_new() {
        let set = ConstraintSet::new();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
    }

    #[test]
    fn test_constraint_set_default_matches_new() {
        assert_eq!(ConstraintSet::default(), ConstraintSet::new());
    }

    #[test]
    fn test_constraint_set_with_constraint() {
        let set = ConstraintSet::new()
            .with_constraint(Constraint::new("Constraint 1"))
            .with_constraint(Constraint::new("Constraint 2"));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_constraint_set_add_constraint() {
        let mut set = ConstraintSet::new();
        set.add_constraint(Constraint::new("Constraint 1"));
        set.add_constraint(Constraint::new("Constraint 2"));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_constraint_set_iter() {
        let set = ConstraintSet::new()
            .with_constraint(Constraint::new("C1"))
            .with_constraint(Constraint::new("C2"));

        let names: Vec<_> = set.iter().map(|c| c.name.as_str()).collect();
        assert_eq!(names, vec!["C1", "C2"]);
    }

    #[test]
    fn test_constraint_set_from_iter() {
        let constraints = vec![Constraint::new("C1"), Constraint::new("C2")];
        let set: ConstraintSet = constraints.into_iter().collect();
        assert_eq!(set.len(), 2);
    }

    /// `to_dto` keeps insertion order (not alphabetical) and carries each
    /// constraint's components unchanged.
    #[test]
    fn test_constraint_set_to_dto_preserves_order_and_components() {
        let set = ConstraintSet::new()
            .with_constraint(
                Constraint::new("C2")
                    .with_component(StreamComponent::for_each("Room"))
                    .with_component(StreamComponent::reward("1soft")),
            )
            .with_constraint(
                Constraint::new("C1").with_component(StreamComponent::for_each("Lesson")),
            );

        let dto = set.to_dto();
        let names: Vec<&str> = dto.keys().map(String::as_str).collect();
        assert_eq!(names, vec!["C2", "C1"]);
        assert_eq!(dto.get("C2").map(Vec::len), Some(2));
        assert_eq!(
            dto.get("C1"),
            Some(&vec![StreamComponent::for_each("Lesson")])
        );
    }

    /// `to_dto` keys on the bare name, not on `full_name()`.
    #[test]
    fn test_constraint_set_to_dto_uses_bare_name() {
        let set = ConstraintSet::new().with_constraint(Constraint::new("C1").with_package("pkg"));
        let dto = set.to_dto();
        assert!(dto.contains_key("C1"));
        assert!(!dto.contains_key("pkg/C1"));
    }

    #[test]
    fn test_extract_predicates_empty_set() {
        assert!(ConstraintSet::new().extract_predicates().is_empty());
    }

    #[test]
    fn test_constraint_json_serialization() {
        let constraint = Constraint::new("Room conflict")
            .with_package("timetabling")
            .with_component(StreamComponent::for_each_unique_pair_with_joiners(
                "Lesson",
                vec![Joiner::equal(WasmFunction::new("get_timeslot"))],
            ))
            .with_component(StreamComponent::filter(WasmFunction::new("same_room")))
            .with_component(StreamComponent::penalize("1hard"));

        let json = serde_json::to_string(&constraint).unwrap();
        assert!(json.contains("\"name\":\"Room conflict\""));
        assert!(json.contains("\"package\":\"timetabling\""));
        assert!(json.contains("\"components\""));

        let parsed: Constraint = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, constraint);
    }

    #[test]
    fn test_constraint_set_json_serialization() {
        let set = ConstraintSet::new()
            .with_constraint(
                Constraint::new("C1")
                    .with_component(StreamComponent::for_each("Lesson"))
                    .with_component(StreamComponent::penalize("1hard")),
            )
            .with_constraint(
                Constraint::new("C2")
                    .with_component(StreamComponent::for_each("Room"))
                    .with_component(StreamComponent::reward("1soft")),
            );

        let json = serde_json::to_string(&set).unwrap();
        let parsed: ConstraintSet = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.len(), 2);
        // Full round-trip, not just the count.
        assert_eq!(parsed, set);
    }

    #[test]
    fn test_realistic_room_conflict_constraint() {
        let constraint = Constraint::new("Room conflict")
            .with_package("school.timetabling")
            .with_description("A room can accommodate at most one lesson at the same time.")
            .with_group("Hard constraints")
            .with_component(StreamComponent::for_each_unique_pair_with_joiners(
                "Lesson",
                vec![
                    Joiner::equal(WasmFunction::new("get_timeslot")),
                    Joiner::equal(WasmFunction::new("get_room")),
                ],
            ))
            .with_component(StreamComponent::penalize("1hard"));

        assert_eq!(constraint.components.len(), 2);
        assert_eq!(constraint.full_name(), "school.timetabling/Room conflict");
    }

    #[test]
    fn test_constraint_clone() {
        let constraint = Constraint::new("Test")
            .with_package("pkg")
            .with_component(StreamComponent::for_each("Entity"));
        let cloned = constraint.clone();
        assert_eq!(constraint, cloned);
    }

    #[test]
    fn test_constraint_debug() {
        let constraint = Constraint::new("Test");
        let debug = format!("{:?}", constraint);
        assert!(debug.contains("Constraint"));
        assert!(debug.contains("Test"));
    }
}