1use crate::error::IrError;
65use serde::{Deserialize, Serialize};
66use std::collections::{HashMap, HashSet, VecDeque};
67use std::ops::RangeInclusive;
68
/// The set of values a [`Variable`] may take.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum Domain {
    /// An explicit set of integer values.
    FiniteDomain { values: HashSet<i64> },
    /// The closed interval `[lower, upper]`; it is empty when `lower > upper`.
    Interval { lower: f64, upper: f64 },
    /// The integer values `0` and `1`.
    Boolean,
    /// An explicit set of symbolic (string) values; it never contains integers.
    Enumeration { values: HashSet<String> },
}
81
82impl Domain {
83 pub fn finite_domain(values: Vec<i64>) -> Self {
85 Domain::FiniteDomain {
86 values: values.into_iter().collect(),
87 }
88 }
89
90 pub fn range(range: RangeInclusive<i64>) -> Self {
92 Domain::FiniteDomain {
93 values: range.collect(),
94 }
95 }
96
97 pub fn interval(lower: f64, upper: f64) -> Self {
99 Domain::Interval { lower, upper }
100 }
101
102 pub fn boolean() -> Self {
104 Domain::Boolean
105 }
106
107 pub fn enumeration(values: Vec<String>) -> Self {
109 Domain::Enumeration {
110 values: values.into_iter().collect(),
111 }
112 }
113
114 pub fn is_empty(&self) -> bool {
116 match self {
117 Domain::FiniteDomain { values } => values.is_empty(),
118 Domain::Interval { lower, upper } => lower > upper,
119 Domain::Boolean => false,
120 Domain::Enumeration { values } => values.is_empty(),
121 }
122 }
123
124 pub fn size(&self) -> Option<usize> {
126 match self {
127 Domain::FiniteDomain { values } => Some(values.len()),
128 Domain::Interval { .. } => None, Domain::Boolean => Some(2),
130 Domain::Enumeration { values } => Some(values.len()),
131 }
132 }
133
134 pub fn contains_int(&self, value: i64) -> bool {
136 match self {
137 Domain::FiniteDomain { values } => values.contains(&value),
138 Domain::Interval { lower, upper } => {
139 let v = value as f64;
140 v >= *lower && v <= *upper
141 }
142 Domain::Boolean => value == 0 || value == 1,
143 Domain::Enumeration { .. } => false,
144 }
145 }
146
147 pub fn intersect(&self, other: &Domain) -> Result<Domain, IrError> {
149 match (self, other) {
150 (Domain::FiniteDomain { values: v1 }, Domain::FiniteDomain { values: v2 }) => {
151 Ok(Domain::FiniteDomain {
152 values: v1.intersection(v2).copied().collect(),
153 })
154 }
155 (
156 Domain::Interval {
157 lower: l1,
158 upper: u1,
159 },
160 Domain::Interval {
161 lower: l2,
162 upper: u2,
163 },
164 ) => Ok(Domain::Interval {
165 lower: l1.max(*l2),
166 upper: u1.min(*u2),
167 }),
168 (Domain::Boolean, Domain::Boolean) => Ok(Domain::Boolean),
169 (Domain::Enumeration { values: v1 }, Domain::Enumeration { values: v2 }) => {
170 Ok(Domain::Enumeration {
171 values: v1.intersection(v2).cloned().collect(),
172 })
173 }
174 _ => Err(IrError::DomainMismatch {
175 expected: format!("{:?}", self),
176 found: format!("{:?}", other),
177 }),
178 }
179 }
180
181 pub fn remove_value(&mut self, value: i64) -> bool {
183 match self {
184 Domain::FiniteDomain { values } => values.remove(&value),
185 Domain::Interval { lower: _, upper: _ } => {
186 false
189 }
190 Domain::Boolean => {
191 false
193 }
194 Domain::Enumeration { .. } => false,
195 }
196 }
197}
198
/// A named decision variable with its domain and its current assignment.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Variable {
    /// Identifier that constraints use and that keys the solution map.
    pub name: String,
    /// Values the variable may take; propagation may prune it.
    pub domain: Domain,
    /// Whether a value has been assigned.
    pub assigned: bool,
    /// The assigned value, if any.
    pub value: Option<i64>,
}
209
210impl Variable {
211 pub fn new(name: impl Into<String>, domain: Domain) -> Self {
213 Variable {
214 name: name.into(),
215 domain,
216 assigned: false,
217 value: None,
218 }
219 }
220
221 pub fn assign(&mut self, value: i64) -> Result<(), IrError> {
223 if !self.domain.contains_int(value) {
224 return Err(IrError::ConstraintViolation {
225 message: format!("Value {} not in domain of variable {}", value, self.name),
226 });
227 }
228 self.assigned = true;
229 self.value = Some(value);
230 Ok(())
231 }
232
233 pub fn is_singleton(&self) -> bool {
235 self.domain.size() == Some(1)
236 }
237}
238
/// A constraint over variables, which it references by name.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum Constraint {
    /// Restricts the single variable `var` with `predicate`.
    Unary {
        var: String,
        predicate: UnaryPredicate,
    },
    /// Relates two variables. `relation` is presumably read as
    /// `var1 <relation> var2`, matching [`Constraint::less_than`].
    Binary {
        var1: String,
        var2: String,
        relation: BinaryRelation,
    },
    /// Relates an ordered list of variables.
    NAry {
        vars: Vec<String>,
        relation: NAryRelation,
    },
    /// A global constraint over `vars`. `params` carries type-specific integer
    /// arguments (NOTE(review): the expected keys are not defined here; confirm).
    Global {
        constraint_type: GlobalConstraintType,
        vars: Vec<String>,
        params: HashMap<String, i64>,
    },
}
265
/// A predicate on a single variable `x`, compared against constant(s).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnaryPredicate {
    /// `x == c`
    Equals(i64),
    /// `x != c`
    NotEquals(i64),
    /// `x < c`
    LessThan(i64),
    /// `x > c`
    GreaterThan(i64),
    /// `x` is one of the listed values.
    InSet(Vec<i64>),
    /// `x` is none of the listed values.
    NotInSet(Vec<i64>),
}
282
/// The relation a [`Constraint::Binary`] imposes, presumably read as `var1 <rel> var2`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum BinaryRelation {
    /// `var1 == var2`
    Equal,
    /// `var1 != var2`
    NotEqual,
    /// `var1 < var2`
    LessThan,
    /// `var1 <= var2`
    LessThanOrEqual,
    /// `var1 > var2`
    GreaterThan,
    /// `var1 >= var2`
    GreaterThanOrEqual,
    /// NOTE(review): presumably `var1 == var2 + c`; confirm the orientation.
    EqualsPlusConstant(i64),
    /// NOTE(review): presumably `var1 == var2 * c`; confirm the orientation.
    EqualsTimesConstant(i64),
}
303
/// The relation a [`Constraint::NAry`] imposes on its ordered variable list.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum NAryRelation {
    /// All variables take pairwise distinct values.
    AllDifferent,
    /// The variables sum to the given value (see [`Constraint::sum_equals`]).
    SumEquals(i64),
    /// The variables sum to strictly less than the given value.
    SumLessThan(i64),
    /// NOTE(review): presumably `sum(coefficients[i] * vars[i]) == constant`, with one
    /// coefficient per variable. Confirm whether `constant` is the right-hand
    /// side or an additive term.
    LinearEquation {
        coefficients: Vec<i64>,
        constant: i64,
    },
}
319
/// The kind of a [`Constraint::Global`].
///
/// NOTE(review): how each kind reads the constraint's `params` map is not
/// defined in this file; confirm against the producer of these constraints.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum GlobalConstraintType {
    /// All variables take pairwise distinct values.
    AllDifferent,
    /// Presumably the classic cumulative resource constraint.
    Cumulative,
    /// Presumably the classic element (array indexing) constraint.
    Element,
    /// Presumably the classic (global) cardinality constraint.
    Cardinality,
    /// Presumably the classic regular-language (automaton) constraint.
    Regular,
}
334
335impl Constraint {
336 pub fn less_than(var1: impl Into<String>, var2: impl Into<String>) -> Self {
338 Constraint::Binary {
339 var1: var1.into(),
340 var2: var2.into(),
341 relation: BinaryRelation::LessThan,
342 }
343 }
344
345 pub fn sum_equals(vars: Vec<impl Into<String>>, sum: i64) -> Self {
347 Constraint::NAry {
348 vars: vars.into_iter().map(|v| v.into()).collect(),
349 relation: NAryRelation::SumEquals(sum),
350 }
351 }
352
353 pub fn all_different(vars: Vec<impl Into<String>>) -> Self {
355 Constraint::NAry {
356 vars: vars.into_iter().map(|v| v.into()).collect(),
357 relation: NAryRelation::AllDifferent,
358 }
359 }
360
361 pub fn variables(&self) -> Vec<&str> {
363 match self {
364 Constraint::Unary { var, .. } => vec![var.as_str()],
365 Constraint::Binary { var1, var2, .. } => vec![var1.as_str(), var2.as_str()],
366 Constraint::NAry { vars, .. } => vars.iter().map(|s| s.as_str()).collect(),
367 Constraint::Global { vars, .. } => vars.iter().map(|s| s.as_str()).collect(),
368 }
369 }
370}
371
/// Constraint-propagation strategy that the solver runs before and during search.
///
/// This is a fieldless configuration flag, so it is `Copy` and `Hash`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PropagationAlgorithm {
    /// No propagation.
    None,
    /// Prune the domains of neighbours of assigned variables.
    ForwardChecking,
    /// Arc consistency.
    ArcConsistency,
    /// Path consistency.
    PathConsistency,
    /// Bounds consistency.
    BoundsConsistency,
}
386
/// Rule for choosing the next unassigned variable to branch on.
///
/// The "degree" of a variable is the number of constraints that mention it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum VariableSelectionHeuristic {
    /// Take the first unassigned variable.
    FirstUnassigned,
    /// Prefer the variable with the fewest domain values.
    MinDomain,
    /// Prefer the variable with the most domain values.
    MaxDomain,
    /// Prefer the variable with the highest degree.
    MaxDegree,
    /// Prefer the fewest domain values, breaking ties by highest degree.
    MinDomainMaxDegree,
}
401
/// Order in which a variable's candidate values are tried.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ValueSelectionHeuristic {
    /// Ascending order.
    MinValue,
    /// Descending order.
    MaxValue,
    /// Intended to try values near the middle of the domain first.
    MiddleValue,
    /// No particular order.
    Random,
}
414
415pub struct CspSolver {
417 variables: HashMap<String, Variable>,
419 constraints: Vec<Constraint>,
421 propagation: PropagationAlgorithm,
423 var_heuristic: VariableSelectionHeuristic,
425 val_heuristic: ValueSelectionHeuristic,
427 pub stats: SolverStats,
429}
430
/// Counters describing the work done by a [`CspSolver`].
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SolverStats {
    /// Number of values tried during search.
    pub assignments_tried: usize,
    /// Number of tried values that were abandoned and undone.
    pub backtracks: usize,
    /// Number of constraint checks performed.
    pub constraint_checks: usize,
    /// Number of domain values removed by propagation.
    pub propagations: usize,
}
443
444impl CspSolver {
445 pub fn new() -> Self {
447 CspSolver {
448 variables: HashMap::new(),
449 constraints: Vec::new(),
450 propagation: PropagationAlgorithm::ArcConsistency,
451 var_heuristic: VariableSelectionHeuristic::MinDomainMaxDegree,
452 val_heuristic: ValueSelectionHeuristic::MinValue,
453 stats: SolverStats::default(),
454 }
455 }
456
457 pub fn add_variable(&mut self, variable: Variable) {
459 self.variables.insert(variable.name.clone(), variable);
460 }
461
462 pub fn add_constraint(&mut self, constraint: Constraint) {
464 self.constraints.push(constraint);
465 }
466
467 pub fn set_propagation(&mut self, algorithm: PropagationAlgorithm) {
469 self.propagation = algorithm;
470 }
471
472 pub fn solve(&mut self) -> Option<HashMap<String, i64>> {
474 if !self.propagate() {
476 return None; }
478
479 self.backtrack_search()
481 }
482
483 fn backtrack_search(&mut self) -> Option<HashMap<String, i64>> {
485 if self.is_complete() {
487 return Some(self.get_assignment());
488 }
489
490 let var_name = self.select_variable()?;
492
493 let domain_values: Vec<i64> = self.get_domain_values(&var_name);
495
496 for value in domain_values {
497 self.stats.assignments_tried += 1;
498
499 if self.assign_value(&var_name, value) {
501 let state = self.save_state();
503 if self.propagate() {
504 if let Some(solution) = self.backtrack_search() {
506 return Some(solution);
507 }
508 }
509 self.stats.backtracks += 1;
511 self.restore_state(state);
512 }
513 }
514
515 None
516 }
517
518 fn is_complete(&self) -> bool {
520 self.variables.values().all(|v| v.assigned)
521 }
522
523 fn get_assignment(&self) -> HashMap<String, i64> {
525 self.variables
526 .iter()
527 .filter_map(|(name, var)| var.value.map(|v| (name.clone(), v)))
528 .collect()
529 }
530
531 fn select_variable(&self) -> Option<String> {
533 let unassigned: Vec<&Variable> = self.variables.values().filter(|v| !v.assigned).collect();
534
535 if unassigned.is_empty() {
536 return None;
537 }
538
539 match self.var_heuristic {
540 VariableSelectionHeuristic::FirstUnassigned => Some(unassigned[0].name.clone()),
541 VariableSelectionHeuristic::MinDomain => unassigned
542 .into_iter()
543 .min_by_key(|v| v.domain.size().unwrap_or(usize::MAX))
544 .map(|v| v.name.clone()),
545 VariableSelectionHeuristic::MaxDomain => unassigned
546 .into_iter()
547 .max_by_key(|v| v.domain.size().unwrap_or(0))
548 .map(|v| v.name.clone()),
549 VariableSelectionHeuristic::MinDomainMaxDegree => {
550 unassigned
552 .into_iter()
553 .min_by_key(|v| {
554 let size = v.domain.size().unwrap_or(usize::MAX);
555 let degree = self.count_constraints_involving(&v.name);
556 (size, usize::MAX - degree)
557 })
558 .map(|v| v.name.clone())
559 }
560 _ => Some(unassigned[0].name.clone()),
561 }
562 }
563
564 fn count_constraints_involving(&self, var_name: &str) -> usize {
566 self.constraints
567 .iter()
568 .filter(|c| c.variables().contains(&var_name))
569 .count()
570 }
571
572 fn get_domain_values(&self, var_name: &str) -> Vec<i64> {
574 let var = &self.variables[var_name];
575 match &var.domain {
576 Domain::FiniteDomain { values } => {
577 let mut vals: Vec<i64> = values.iter().copied().collect();
578 match self.val_heuristic {
579 ValueSelectionHeuristic::MinValue => vals.sort(),
580 ValueSelectionHeuristic::MaxValue => vals.sort_by(|a, b| b.cmp(a)),
581 _ => {}
582 }
583 vals
584 }
585 Domain::Boolean => vec![0, 1],
586 _ => vec![],
587 }
588 }
589
590 fn assign_value(&mut self, var_name: &str, value: i64) -> bool {
592 if let Some(var) = self.variables.get_mut(var_name) {
593 var.assign(value).is_ok()
594 } else {
595 false
596 }
597 }
598
599 fn propagate(&mut self) -> bool {
601 match self.propagation {
602 PropagationAlgorithm::None => true,
603 PropagationAlgorithm::ForwardChecking => self.forward_checking(),
604 PropagationAlgorithm::ArcConsistency => self.arc_consistency(),
605 _ => true, }
607 }
608
609 fn forward_checking(&mut self) -> bool {
611 for constraint in self.constraints.clone() {
612 if !self.check_constraint_forward(&constraint) {
613 return false;
614 }
615 }
616 true
617 }
618
619 fn check_constraint_forward(&mut self, constraint: &Constraint) -> bool {
621 self.stats.constraint_checks += 1;
622
623 match constraint {
624 Constraint::Binary {
625 var1,
626 var2,
627 relation: BinaryRelation::NotEqual,
628 } => {
629 if let Some(val1) = self.variables[var1].value {
631 if let Some(var2_obj) = self.variables.get_mut(var2) {
632 if !var2_obj.assigned && var2_obj.domain.remove_value(val1) {
633 self.stats.propagations += 1;
634 }
635 if var2_obj.domain.is_empty() {
636 return false;
637 }
638 }
639 }
640 if let Some(val2) = self.variables[var2].value {
642 if let Some(var1_obj) = self.variables.get_mut(var1) {
643 if !var1_obj.assigned && var1_obj.domain.remove_value(val2) {
644 self.stats.propagations += 1;
645 }
646 if var1_obj.domain.is_empty() {
647 return false;
648 }
649 }
650 }
651 }
652 _ => {
653 }
655 }
656
657 true
658 }
659
660 fn arc_consistency(&mut self) -> bool {
662 let constraints_clone = self.constraints.clone();
663 let mut queue: VecDeque<usize> = VecDeque::new();
664
665 for i in 0..constraints_clone.len() {
667 queue.push_back(i);
668 }
669
670 while let Some(constraint_idx) = queue.pop_front() {
671 let constraint = &constraints_clone[constraint_idx];
672 if !self.revise_constraint(constraint) {
673 return false; }
675 }
676
677 true
678 }
679
680 fn revise_constraint(&mut self, constraint: &Constraint) -> bool {
682 if let Constraint::Binary {
684 var1,
685 var2,
686 relation,
687 } = constraint
688 {
689 self.stats.constraint_checks += 1;
692 if let BinaryRelation::NotEqual = relation {
694 if let (Some(val2), Some(var1_obj)) =
695 (self.variables[var2].value, self.variables.get_mut(var1))
696 {
697 if !var1_obj.assigned && var1_obj.domain.remove_value(val2) {
698 self.stats.propagations += 1;
699 }
700 return !var1_obj.domain.is_empty();
701 }
702 }
703 }
704 true
705 }
706
707 fn save_state(&self) -> SolverState {
709 SolverState {
710 variables: self.variables.clone(),
711 }
712 }
713
714 fn restore_state(&mut self, state: SolverState) {
716 self.variables = state.variables;
717 }
718}
719
720impl Default for CspSolver {
721 fn default() -> Self {
722 Self::new()
723 }
724}
725
/// Snapshot of the solver's variables (domains and assignments).
///
/// The search takes one before trying a value and restores it if that
/// branch fails.
#[derive(Clone)]
struct SolverState {
    variables: HashMap<String, Variable>,
}
731
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_finite_domain_creation() {
        let domain = Domain::finite_domain(vec![1, 2, 3, 4, 5]);
        assert_eq!(domain.size(), Some(5));
        assert!(domain.contains_int(3));
        assert!(!domain.contains_int(6));
    }

    #[test]
    fn test_domain_range() {
        let domain = Domain::range(1..=10);
        assert_eq!(domain.size(), Some(10));
        assert!(domain.contains_int(5));
        assert!(!domain.contains_int(11));
    }

    #[test]
    fn test_domain_intersection() {
        let d1 = Domain::finite_domain(vec![1, 2, 3, 4, 5]);
        let d2 = Domain::finite_domain(vec![3, 4, 5, 6, 7]);
        let intersection = d1.intersect(&d2).unwrap();

        assert_eq!(intersection.size(), Some(3));
        assert!(intersection.contains_int(3));
        assert!(intersection.contains_int(4));
        assert!(intersection.contains_int(5));
    }

    #[test]
    fn test_intersect_mismatched_domains_is_error() {
        let finite = Domain::finite_domain(vec![0, 1]);
        assert!(finite.intersect(&Domain::boolean()).is_err());
    }

    #[test]
    fn test_remove_value_only_affects_finite_domains() {
        let mut finite = Domain::finite_domain(vec![1, 2, 3]);
        assert!(finite.remove_value(2));
        assert!(!finite.remove_value(2));
        assert_eq!(finite.size(), Some(2));

        let mut boolean = Domain::boolean();
        assert!(!boolean.remove_value(0));
        assert_eq!(boolean.size(), Some(2));
    }

    #[test]
    fn test_variable_assignment() {
        let mut var = Variable::new("x", Domain::finite_domain(vec![1, 2, 3]));
        assert!(!var.assigned);

        var.assign(2).unwrap();
        assert!(var.assigned);
        assert_eq!(var.value, Some(2));
    }

    #[test]
    fn test_variable_assignment_out_of_domain() {
        let mut var = Variable::new("x", Domain::finite_domain(vec![1, 2, 3]));
        let result = var.assign(5);
        assert!(result.is_err());
        assert!(!var.assigned);
        assert_eq!(var.value, None);
    }

    #[test]
    fn test_variable_singleton() {
        assert!(Variable::new("x", Domain::finite_domain(vec![7])).is_singleton());
        assert!(!Variable::new("y", Domain::boolean()).is_singleton());
    }

    #[test]
    fn test_simple_csp() {
        let mut solver = CspSolver::new();

        let x = Variable::new("x", Domain::finite_domain(vec![1, 2]));
        let y = Variable::new("y", Domain::finite_domain(vec![1, 2]));

        solver.add_variable(x);
        solver.add_variable(y);

        solver.add_constraint(Constraint::Binary {
            var1: "x".to_string(),
            var2: "y".to_string(),
            relation: BinaryRelation::NotEqual,
        });

        let solution = solver.solve();
        assert!(solution.is_some());

        // Every variable is reported, each with a value from its domain.
        let sol = solution.unwrap();
        assert_eq!(sol.len(), 2);
        assert!(matches!(sol.get("x").copied(), Some(1) | Some(2)));
        assert!(matches!(sol.get("y").copied(), Some(1) | Some(2)));
    }

    #[test]
    fn test_csp_no_solution() {
        let mut solver = CspSolver::new();

        let x = Variable::new("x", Domain::finite_domain(vec![1]));
        let y = Variable::new("y", Domain::finite_domain(vec![1]));

        solver.add_variable(x);
        solver.add_variable(y);

        solver.add_constraint(Constraint::Binary {
            var1: "x".to_string(),
            var2: "y".to_string(),
            relation: BinaryRelation::NotEqual,
        });

        // NOTE(review): this only checks that `solve` terminates. x and y can
        // both only be 1, so x != y is unsatisfiable and this should assert
        // `solution.is_none()`.
        let solution = solver.solve();
        let _ = solution;
    }

    #[test]
    fn test_all_different_constraint() {
        let vars = vec!["x", "y", "z"];
        let constraint = Constraint::all_different(vars.clone());

        assert_eq!(constraint.variables(), vec!["x", "y", "z"]);
    }

    #[test]
    fn test_constraint_constructors_list_variables() {
        assert_eq!(Constraint::less_than("a", "b").variables(), vec!["a", "b"]);
        assert_eq!(
            Constraint::sum_equals(vec!["p", "q", "r"], 6).variables(),
            vec!["p", "q", "r"]
        );
    }

    #[test]
    fn test_solver_statistics() {
        let mut solver = CspSolver::new();

        let x = Variable::new("x", Domain::finite_domain(vec![1, 2, 3]));
        let y = Variable::new("y", Domain::finite_domain(vec![1, 2, 3]));

        solver.add_variable(x);
        solver.add_variable(y);

        solver.add_constraint(Constraint::Binary {
            var1: "x".to_string(),
            var2: "y".to_string(),
            relation: BinaryRelation::LessThan,
        });

        solver.solve();

        assert!(solver.stats.assignments_tried > 0);
        assert!(solver.stats.constraint_checks > 0);
    }

    #[test]
    fn test_min_domain_heuristic() {
        let mut solver = CspSolver::new();
        solver.set_propagation(PropagationAlgorithm::ForwardChecking);

        let x = Variable::new("x", Domain::finite_domain(vec![1, 2, 3, 4, 5]));
        let y = Variable::new("y", Domain::finite_domain(vec![1, 2]));
        solver.add_variable(x);
        solver.add_variable(y);

        let var_name = solver.select_variable();
        assert_eq!(var_name, Some("y".to_string()));
    }

    #[test]
    fn test_boolean_domain() {
        let domain = Domain::boolean();
        assert_eq!(domain.size(), Some(2));
        assert!(domain.contains_int(0));
        assert!(domain.contains_int(1));
        assert!(!domain.contains_int(2));
    }

    #[test]
    fn test_interval_domain() {
        let domain = Domain::interval(0.0, 10.0);
        assert!(domain.contains_int(5));
        assert!(!domain.contains_int(15));
    }

    #[test]
    fn test_interval_intersection() {
        let d1 = Domain::interval(0.0, 10.0);
        let d2 = Domain::interval(5.0, 15.0);
        let intersection = d1.intersect(&d2).unwrap();

        if let Domain::Interval { lower, upper } = intersection {
            assert_eq!(lower, 5.0);
            assert_eq!(upper, 10.0);
        } else {
            panic!("Expected interval domain");
        }
    }
}