1use std::collections::HashMap;
30use std::sync::Arc;
31
/// A predicate over a single numeric value `x`, used to refine a base type.
///
/// Evaluate with [`RefinementPredicate::check`] or, when the predicate may
/// reference other variables, [`RefinementPredicate::check_with_context`].
#[derive(Clone)]
pub enum RefinementPredicate {
    /// `x == v`, compared with an absolute tolerance of `f64::EPSILON`.
    Equal(f64),
    /// `x != v`, compared with an absolute tolerance of `f64::EPSILON`.
    NotEqual(f64),
    /// `x > v`.
    GreaterThan(f64),
    /// `x >= v`.
    GreaterThanOrEqual(f64),
    /// `x < v`.
    LessThan(f64),
    /// `x <= v`.
    LessThanOrEqual(f64),
    /// Closed interval `min <= x <= max`.
    Range { min: f64, max: f64 },
    /// Half-open interval `min <= x < max` (only the upper end is excluded).
    RangeExclusive { min: f64, max: f64 },
    /// Integer remainder constraint: `x` modulo `divisor` must match `remainder`.
    /// See [`RefinementPredicate::check`] for how non-integral and negative
    /// values are treated.
    Modulo { divisor: i64, remainder: i64 },
    /// `x` equals one of the listed values (within `f64::EPSILON`).
    InSet(Vec<f64>),
    /// `x` equals none of the listed values (within `f64::EPSILON`).
    NotInSet(Vec<f64>),
    /// Conjunction: every sub-predicate must hold (an empty list always holds).
    And(Vec<RefinementPredicate>),
    /// Disjunction: at least one sub-predicate must hold (an empty list never holds).
    Or(Vec<RefinementPredicate>),
    /// Negation of the boxed predicate.
    Not(Box<RefinementPredicate>),
    /// Named predicate backed by an arbitrary closure.
    ///
    /// Its `Debug` output shows only `name` and `description`; the closure
    /// itself is not printable.
    Custom {
        /// Identifier used in string representations (`name(x)`).
        name: String,
        /// Human-readable explanation of the predicate.
        description: String,
        /// The closure that decides membership.
        checker: Arc<dyn Fn(f64) -> bool + Send + Sync>,
    },
    /// Relates `x` to the value of a named variable.
    ///
    /// Only evaluated by `check_with_context`; a context-free `check` treats it
    /// as satisfied.
    Dependent {
        /// Name of the variable looked up in the context.
        variable: String,
        /// How `x` must relate to that variable's value.
        relation: DependentRelation,
    },
    /// Length bounds for string values.
    ///
    /// Not evaluated by the numeric `check` methods (always satisfied there).
    StringLength {
        /// Minimum length, if any.
        min: Option<usize>,
        /// Maximum length, if any.
        max: Option<usize>,
    },
    /// Pattern a string value must match. Not evaluated by the numeric `check`
    /// methods (always satisfied there).
    Pattern(String),
}
82
/// How a value `x` must relate to another variable's value in a
/// [`RefinementPredicate::Dependent`] predicate.
#[derive(Debug, Clone, PartialEq)]
pub enum DependentRelation {
    /// `x < other`.
    LessThan,
    /// `x <= other`.
    LessThanOrEqual,
    /// `x > other`.
    GreaterThan,
    /// `x >= other`.
    GreaterThanOrEqual,
    /// `x == other` (within `f64::EPSILON`).
    Equal,
    /// `x != other` (within `f64::EPSILON`).
    NotEqual,
    /// `x` divides `other` (i.e. `other` is a multiple of `x`).
    Divides,
    /// `x` is divisible by `other` (i.e. `x` is a multiple of `other`).
    DivisibleBy,
}
103
104impl std::fmt::Debug for RefinementPredicate {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 match self {
107 RefinementPredicate::Equal(v) => f.debug_tuple("Equal").field(v).finish(),
108 RefinementPredicate::NotEqual(v) => f.debug_tuple("NotEqual").field(v).finish(),
109 RefinementPredicate::GreaterThan(v) => f.debug_tuple("GreaterThan").field(v).finish(),
110 RefinementPredicate::GreaterThanOrEqual(v) => {
111 f.debug_tuple("GreaterThanOrEqual").field(v).finish()
112 }
113 RefinementPredicate::LessThan(v) => f.debug_tuple("LessThan").field(v).finish(),
114 RefinementPredicate::LessThanOrEqual(v) => {
115 f.debug_tuple("LessThanOrEqual").field(v).finish()
116 }
117 RefinementPredicate::Range { min, max } => f
118 .debug_struct("Range")
119 .field("min", min)
120 .field("max", max)
121 .finish(),
122 RefinementPredicate::RangeExclusive { min, max } => f
123 .debug_struct("RangeExclusive")
124 .field("min", min)
125 .field("max", max)
126 .finish(),
127 RefinementPredicate::Modulo { divisor, remainder } => f
128 .debug_struct("Modulo")
129 .field("divisor", divisor)
130 .field("remainder", remainder)
131 .finish(),
132 RefinementPredicate::InSet(set) => f.debug_tuple("InSet").field(set).finish(),
133 RefinementPredicate::NotInSet(set) => f.debug_tuple("NotInSet").field(set).finish(),
134 RefinementPredicate::And(preds) => f.debug_tuple("And").field(preds).finish(),
135 RefinementPredicate::Or(preds) => f.debug_tuple("Or").field(preds).finish(),
136 RefinementPredicate::Not(pred) => f.debug_tuple("Not").field(pred).finish(),
137 RefinementPredicate::Custom {
138 name, description, ..
139 } => f
140 .debug_struct("Custom")
141 .field("name", name)
142 .field("description", description)
143 .finish(),
144 RefinementPredicate::Dependent { variable, relation } => f
145 .debug_struct("Dependent")
146 .field("variable", variable)
147 .field("relation", relation)
148 .finish(),
149 RefinementPredicate::StringLength { min, max } => f
150 .debug_struct("StringLength")
151 .field("min", min)
152 .field("max", max)
153 .finish(),
154 RefinementPredicate::Pattern(pattern) => {
155 f.debug_tuple("Pattern").field(pattern).finish()
156 }
157 }
158 }
159}
160
161impl RefinementPredicate {
162 pub fn greater_than(value: f64) -> Self {
164 RefinementPredicate::GreaterThan(value)
165 }
166
167 pub fn greater_than_or_equal(value: f64) -> Self {
169 RefinementPredicate::GreaterThanOrEqual(value)
170 }
171
172 pub fn less_than(value: f64) -> Self {
174 RefinementPredicate::LessThan(value)
175 }
176
177 pub fn less_than_or_equal(value: f64) -> Self {
179 RefinementPredicate::LessThanOrEqual(value)
180 }
181
182 pub fn range(min: f64, max: f64) -> Self {
184 RefinementPredicate::Range { min, max }
185 }
186
187 pub fn modulo(divisor: i64, remainder: i64) -> Self {
189 RefinementPredicate::Modulo { divisor, remainder }
190 }
191
192 pub fn in_set(values: Vec<f64>) -> Self {
194 RefinementPredicate::InSet(values)
195 }
196
197 pub fn and(predicates: Vec<RefinementPredicate>) -> Self {
199 RefinementPredicate::And(predicates)
200 }
201
202 pub fn or(predicates: Vec<RefinementPredicate>) -> Self {
204 RefinementPredicate::Or(predicates)
205 }
206
207 #[allow(clippy::should_implement_trait)]
209 pub fn not(predicate: RefinementPredicate) -> Self {
210 RefinementPredicate::Not(Box::new(predicate))
211 }
212
213 pub fn custom<F>(name: impl Into<String>, description: impl Into<String>, checker: F) -> Self
215 where
216 F: Fn(f64) -> bool + Send + Sync + 'static,
217 {
218 RefinementPredicate::Custom {
219 name: name.into(),
220 description: description.into(),
221 checker: Arc::new(checker),
222 }
223 }
224
225 pub fn dependent(variable: impl Into<String>, relation: DependentRelation) -> Self {
227 RefinementPredicate::Dependent {
228 variable: variable.into(),
229 relation,
230 }
231 }
232
233 pub fn check(&self, value: f64) -> bool {
237 match self {
238 RefinementPredicate::Equal(v) => (value - v).abs() < f64::EPSILON,
239 RefinementPredicate::NotEqual(v) => (value - v).abs() >= f64::EPSILON,
240 RefinementPredicate::GreaterThan(v) => value > *v,
241 RefinementPredicate::GreaterThanOrEqual(v) => value >= *v,
242 RefinementPredicate::LessThan(v) => value < *v,
243 RefinementPredicate::LessThanOrEqual(v) => value <= *v,
244 RefinementPredicate::Range { min, max } => value >= *min && value <= *max,
245 RefinementPredicate::RangeExclusive { min, max } => value >= *min && value < *max,
246 RefinementPredicate::Modulo { divisor, remainder } => {
247 (value as i64) % divisor == *remainder
248 }
249 RefinementPredicate::InSet(set) => set.iter().any(|v| (value - v).abs() < f64::EPSILON),
250 RefinementPredicate::NotInSet(set) => {
251 !set.iter().any(|v| (value - v).abs() < f64::EPSILON)
252 }
253 RefinementPredicate::And(preds) => preds.iter().all(|p| p.check(value)),
254 RefinementPredicate::Or(preds) => preds.iter().any(|p| p.check(value)),
255 RefinementPredicate::Not(pred) => !pred.check(value),
256 RefinementPredicate::Custom { checker, .. } => checker(value),
257 RefinementPredicate::Dependent { .. } => true, RefinementPredicate::StringLength { .. } => true, RefinementPredicate::Pattern(_) => true, }
261 }
262
263 pub fn check_with_context(&self, value: f64, context: &RefinementContext) -> bool {
265 match self {
266 RefinementPredicate::Dependent { variable, relation } => {
267 if let Some(&other) = context.get_value(variable) {
268 match relation {
269 DependentRelation::LessThan => value < other,
270 DependentRelation::LessThanOrEqual => value <= other,
271 DependentRelation::GreaterThan => value > other,
272 DependentRelation::GreaterThanOrEqual => value >= other,
273 DependentRelation::Equal => (value - other).abs() < f64::EPSILON,
274 DependentRelation::NotEqual => (value - other).abs() >= f64::EPSILON,
275 DependentRelation::Divides => {
276 other != 0.0 && (other as i64) % (value as i64) == 0
277 }
278 DependentRelation::DivisibleBy => {
279 value != 0.0 && (value as i64) % (other as i64) == 0
280 }
281 }
282 } else {
283 false }
285 }
286 RefinementPredicate::And(preds) => {
287 preds.iter().all(|p| p.check_with_context(value, context))
288 }
289 RefinementPredicate::Or(preds) => {
290 preds.iter().any(|p| p.check_with_context(value, context))
291 }
292 RefinementPredicate::Not(pred) => !pred.check_with_context(value, context),
293 _ => self.check(value),
294 }
295 }
296
297 pub fn free_variables(&self) -> Vec<String> {
299 match self {
300 RefinementPredicate::Dependent { variable, .. } => vec![variable.clone()],
301 RefinementPredicate::And(preds) | RefinementPredicate::Or(preds) => {
302 let mut vars = Vec::new();
303 for pred in preds {
304 vars.extend(pred.free_variables());
305 }
306 vars.sort();
307 vars.dedup();
308 vars
309 }
310 RefinementPredicate::Not(pred) => pred.free_variables(),
311 _ => vec![],
312 }
313 }
314
315 pub fn simplify(&self) -> RefinementPredicate {
317 match self {
318 RefinementPredicate::And(preds) => {
319 let simplified: Vec<_> = preds.iter().map(|p| p.simplify()).collect();
320 if simplified.len() == 1 {
321 simplified.into_iter().next().unwrap()
322 } else {
323 let mut min_val = f64::NEG_INFINITY;
325 let mut max_val = f64::INFINITY;
326 let mut others = Vec::new();
327
328 for pred in simplified {
329 match pred {
330 RefinementPredicate::GreaterThan(v) => {
331 min_val = min_val.max(v);
332 }
333 RefinementPredicate::GreaterThanOrEqual(v) => {
334 min_val = min_val.max(v);
335 }
336 RefinementPredicate::LessThan(v) => {
337 max_val = max_val.min(v);
338 }
339 RefinementPredicate::LessThanOrEqual(v) => {
340 max_val = max_val.min(v);
341 }
342 RefinementPredicate::Range { min, max } => {
343 min_val = min_val.max(min);
344 max_val = max_val.min(max);
345 }
346 other => others.push(other),
347 }
348 }
349
350 if min_val > f64::NEG_INFINITY || max_val < f64::INFINITY {
352 if min_val > f64::NEG_INFINITY && max_val < f64::INFINITY {
353 others.insert(
354 0,
355 RefinementPredicate::Range {
356 min: min_val,
357 max: max_val,
358 },
359 );
360 } else if min_val > f64::NEG_INFINITY {
361 others.insert(0, RefinementPredicate::GreaterThanOrEqual(min_val));
362 } else {
363 others.insert(0, RefinementPredicate::LessThanOrEqual(max_val));
364 }
365 }
366
367 if others.len() == 1 {
368 others.into_iter().next().unwrap()
369 } else {
370 RefinementPredicate::And(others)
371 }
372 }
373 }
374 RefinementPredicate::Or(preds) => {
375 let simplified: Vec<_> = preds.iter().map(|p| p.simplify()).collect();
376 if simplified.len() == 1 {
377 simplified.into_iter().next().unwrap()
378 } else {
379 RefinementPredicate::Or(simplified)
380 }
381 }
382 RefinementPredicate::Not(pred) => {
383 let inner = pred.simplify();
384 match inner {
385 RefinementPredicate::Not(p) => *p, other => RefinementPredicate::Not(Box::new(other)),
387 }
388 }
389 other => other.clone(),
390 }
391 }
392
393 pub fn to_string_repr(&self) -> String {
395 match self {
396 RefinementPredicate::Equal(v) => format!("x == {}", v),
397 RefinementPredicate::NotEqual(v) => format!("x != {}", v),
398 RefinementPredicate::GreaterThan(v) => format!("x > {}", v),
399 RefinementPredicate::GreaterThanOrEqual(v) => format!("x >= {}", v),
400 RefinementPredicate::LessThan(v) => format!("x < {}", v),
401 RefinementPredicate::LessThanOrEqual(v) => format!("x <= {}", v),
402 RefinementPredicate::Range { min, max } => format!("{} <= x <= {}", min, max),
403 RefinementPredicate::RangeExclusive { min, max } => format!("{} <= x < {}", min, max),
404 RefinementPredicate::Modulo { divisor, remainder } => {
405 format!("x % {} == {}", divisor, remainder)
406 }
407 RefinementPredicate::InSet(set) => format!("x in {:?}", set),
408 RefinementPredicate::NotInSet(set) => format!("x not in {:?}", set),
409 RefinementPredicate::And(preds) => {
410 let parts: Vec<_> = preds.iter().map(|p| p.to_string_repr()).collect();
411 format!("({})", parts.join(" && "))
412 }
413 RefinementPredicate::Or(preds) => {
414 let parts: Vec<_> = preds.iter().map(|p| p.to_string_repr()).collect();
415 format!("({})", parts.join(" || "))
416 }
417 RefinementPredicate::Not(pred) => format!("!({})", pred.to_string_repr()),
418 RefinementPredicate::Custom { name, .. } => format!("{}(x)", name),
419 RefinementPredicate::Dependent { variable, relation } => {
420 let rel_str = match relation {
421 DependentRelation::LessThan => "<",
422 DependentRelation::LessThanOrEqual => "<=",
423 DependentRelation::GreaterThan => ">",
424 DependentRelation::GreaterThanOrEqual => ">=",
425 DependentRelation::Equal => "==",
426 DependentRelation::NotEqual => "!=",
427 DependentRelation::Divides => "divides",
428 DependentRelation::DivisibleBy => "divisible_by",
429 };
430 format!("x {} {}", rel_str, variable)
431 }
432 RefinementPredicate::StringLength { min, max } => match (min, max) {
433 (Some(min), Some(max)) => format!("{} <= len(x) <= {}", min, max),
434 (Some(min), None) => format!("len(x) >= {}", min),
435 (None, Some(max)) => format!("len(x) <= {}", max),
436 (None, None) => "true".to_string(),
437 },
438 RefinementPredicate::Pattern(pattern) => format!("x matches \"{}\"", pattern),
439 }
440 }
441}
442
/// A base type narrowed by a conjunction of [`RefinementPredicate`]s.
#[derive(Debug, Clone)]
pub struct RefinementType {
    /// Name of the underlying type (e.g. `"Int"`, `"Float"`).
    pub base_type: String,
    /// Optional display name; `type_name()` falls back to `base_type` when unset.
    pub name: Option<String>,
    /// Predicates that must all hold for a value to inhabit this type.
    pub predicates: Vec<RefinementPredicate>,
    /// Optional human-readable description.
    pub description: Option<String>,
}
455
456impl RefinementType {
457 pub fn new(base_type: impl Into<String>) -> Self {
459 RefinementType {
460 base_type: base_type.into(),
461 name: None,
462 predicates: Vec::new(),
463 description: None,
464 }
465 }
466
467 pub fn with_name(mut self, name: impl Into<String>) -> Self {
469 self.name = Some(name.into());
470 self
471 }
472
473 pub fn with_predicate(mut self, predicate: RefinementPredicate) -> Self {
475 self.predicates.push(predicate);
476 self
477 }
478
479 pub fn with_description(mut self, description: impl Into<String>) -> Self {
481 self.description = Some(description.into());
482 self
483 }
484
485 pub fn check(&self, value: f64) -> bool {
487 self.predicates.iter().all(|p| p.check(value))
488 }
489
490 pub fn check_with_context(&self, value: f64, context: &RefinementContext) -> bool {
492 self.predicates
493 .iter()
494 .all(|p| p.check_with_context(value, context))
495 }
496
497 pub fn type_name(&self) -> &str {
499 self.name.as_deref().unwrap_or(&self.base_type)
500 }
501
502 pub fn is_subtype_of(&self, other: &RefinementType) -> bool {
511 if self.base_type != other.base_type {
512 return false;
513 }
514
515 if other.predicates.is_empty() {
517 return true;
518 }
519
520 if self.predicates.is_empty() && !other.predicates.is_empty() {
522 return false;
523 }
524
525 for other_pred in &other.predicates {
527 if !self.implies_predicate(other_pred) {
528 return false;
529 }
530 }
531
532 true
533 }
534
535 fn implies_predicate(&self, target: &RefinementPredicate) -> bool {
542 let target_repr = format!("{:?}", target);
545 if self
546 .predicates
547 .iter()
548 .any(|p| format!("{:?}", p) == target_repr)
549 {
550 return true;
551 }
552
553 for pred in &self.predicates {
555 if Self::semantic_implies(pred, target) {
556 return true;
557 }
558 }
559
560 Self::conjunction_implies(&self.predicates, target)
562 }
563
564 fn semantic_implies(source: &RefinementPredicate, target: &RefinementPredicate) -> bool {
566 use RefinementPredicate::*;
567
568 match (source, target) {
569 (
571 Range {
572 min: min1,
573 max: max1,
574 },
575 Range {
576 min: min2,
577 max: max2,
578 },
579 ) => {
580 min1 >= min2 && max1 <= max2
582 }
583 (
584 RangeExclusive {
585 min: min1,
586 max: max1,
587 },
588 RangeExclusive {
589 min: min2,
590 max: max2,
591 },
592 ) => min1 >= min2 && max1 <= max2,
593 (GreaterThan(v1), GreaterThan(v2)) => v1 >= v2,
595 (GreaterThanOrEqual(v1), GreaterThanOrEqual(v2)) => v1 >= v2,
596 (GreaterThan(v1), GreaterThanOrEqual(v2)) => v1 >= v2, (LessThan(v1), LessThan(v2)) => v1 <= v2,
599 (LessThanOrEqual(v1), LessThanOrEqual(v2)) => v1 <= v2,
600 (LessThan(v1), LessThanOrEqual(v2)) => v1 <= v2, (Equal(v1), GreaterThan(v2)) => v1 > v2,
603 (Equal(v1), GreaterThanOrEqual(v2)) => v1 >= v2,
604 (Equal(v1), LessThan(v2)) => v1 < v2,
605 (Equal(v1), LessThanOrEqual(v2)) => v1 <= v2,
606 (Equal(v1), Range { min, max }) => v1 >= min && v1 <= max,
607 (
609 Modulo {
610 divisor: d1,
611 remainder: r1,
612 },
613 Modulo {
614 divisor: d2,
615 remainder: r2,
616 },
617 ) => r1 == r2 && d1 % d2 == 0,
618 (
620 Dependent {
621 variable: v1,
622 relation: rel1,
623 },
624 Dependent {
625 variable: v2,
626 relation: rel2,
627 },
628 ) => {
629 if v1 != v2 {
630 return false;
631 }
632 use DependentRelation::*;
634 matches!(
635 (rel1, rel2),
636 (Equal, Equal)
637 | (GreaterThan, GreaterThan)
638 | (GreaterThan, GreaterThanOrEqual)
639 | (LessThan, LessThan)
640 | (LessThan, LessThanOrEqual)
641 | (GreaterThanOrEqual, GreaterThanOrEqual)
642 | (LessThanOrEqual, LessThanOrEqual)
643 )
644 }
645 _ => false,
646 }
647 }
648
649 fn conjunction_implies(
653 predicates: &[RefinementPredicate],
654 target: &RefinementPredicate,
655 ) -> bool {
656 use RefinementPredicate::*;
657
658 let mut lower_bounds = Vec::new();
660 let mut upper_bounds = Vec::new();
661
662 for pred in predicates {
663 match pred {
664 GreaterThan(v) | GreaterThanOrEqual(v) => {
665 lower_bounds.push(*v);
666 }
667 LessThan(v) | LessThanOrEqual(v) => {
668 upper_bounds.push(*v);
669 }
670 Range { min, max } => {
671 lower_bounds.push(*min);
672 upper_bounds.push(*max);
673 }
674 Equal(v) => {
675 lower_bounds.push(*v);
676 upper_bounds.push(*v);
677 }
678 _ => {}
679 }
680 }
681
682 match target {
684 GreaterThan(v) | GreaterThanOrEqual(v) => lower_bounds.iter().any(|lb| lb >= v),
685 LessThan(v) | LessThanOrEqual(v) => upper_bounds.iter().any(|ub| ub <= v),
686 Range { min, max } => {
687 lower_bounds.iter().any(|lb| lb >= min) && upper_bounds.iter().any(|ub| ub <= max)
688 }
689 _ => false,
690 }
691 }
692
693 pub fn free_variables(&self) -> Vec<String> {
695 let mut vars = Vec::new();
696 for pred in &self.predicates {
697 vars.extend(pred.free_variables());
698 }
699 vars.sort();
700 vars.dedup();
701 vars
702 }
703
704 pub fn to_string_repr(&self) -> String {
706 if self.predicates.is_empty() {
707 return self.base_type.clone();
708 }
709
710 let pred_strs: Vec<_> = self.predicates.iter().map(|p| p.to_string_repr()).collect();
711 format!("{}{{{}}}", self.base_type, pred_strs.join(" && "))
712 }
713}
714
/// Variable bindings consulted when evaluating dependent predicates.
#[derive(Debug, Clone, Default)]
pub struct RefinementContext {
    /// Concrete numeric values by variable name; read by `check_with_context`
    /// through `get_value`.
    values: HashMap<String, f64>,
    /// Refinement types recorded for variables by name.
    types: HashMap<String, RefinementType>,
}
723
724impl RefinementContext {
725 pub fn new() -> Self {
727 RefinementContext {
728 values: HashMap::new(),
729 types: HashMap::new(),
730 }
731 }
732
733 pub fn set_value(&mut self, var: impl Into<String>, value: f64) {
735 self.values.insert(var.into(), value);
736 }
737
738 pub fn get_value(&self, var: &str) -> Option<&f64> {
740 self.values.get(var)
741 }
742
743 pub fn set_type(&mut self, var: impl Into<String>, ty: RefinementType) {
745 self.types.insert(var.into(), ty);
746 }
747
748 pub fn get_type(&self, var: &str) -> Option<&RefinementType> {
750 self.types.get(var)
751 }
752
753 pub fn has_variable(&self, var: &str) -> bool {
755 self.values.contains_key(var) || self.types.contains_key(var)
756 }
757
758 pub fn variables(&self) -> Vec<&str> {
760 let mut vars: Vec<_> = self.values.keys().map(|s| s.as_str()).collect();
761 for key in self.types.keys() {
762 if !self.values.contains_key(key) {
763 vars.push(key.as_str());
764 }
765 }
766 vars
767 }
768}
769
/// A lookup table of refinement types keyed by their `type_name()`.
#[derive(Debug, Clone, Default)]
pub struct RefinementRegistry {
    /// Registered types by name; registering the same name again replaces the entry.
    types: HashMap<String, RefinementType>,
}
776
777impl RefinementRegistry {
778 pub fn new() -> Self {
780 RefinementRegistry {
781 types: HashMap::new(),
782 }
783 }
784
785 pub fn with_builtins() -> Self {
787 let mut registry = RefinementRegistry::new();
788
789 registry.register(
791 RefinementType::new("Int")
792 .with_name("PositiveInt")
793 .with_predicate(RefinementPredicate::GreaterThan(0.0))
794 .with_description("Strictly positive integer"),
795 );
796
797 registry.register(
799 RefinementType::new("Int")
800 .with_name("NonNegativeInt")
801 .with_predicate(RefinementPredicate::GreaterThanOrEqual(0.0))
802 .with_description("Non-negative integer (zero or positive)"),
803 );
804
805 registry.register(
807 RefinementType::new("Float")
808 .with_name("Probability")
809 .with_predicate(RefinementPredicate::Range { min: 0.0, max: 1.0 })
810 .with_description("Probability value between 0 and 1"),
811 );
812
813 registry.register(
815 RefinementType::new("Float")
816 .with_name("Percentage")
817 .with_predicate(RefinementPredicate::Range {
818 min: 0.0,
819 max: 100.0,
820 })
821 .with_description("Percentage value between 0 and 100"),
822 );
823
824 registry.register(
826 RefinementType::new("Float")
827 .with_name("Normalized")
828 .with_predicate(RefinementPredicate::Range {
829 min: -1.0,
830 max: 1.0,
831 })
832 .with_description("Normalized value between -1 and 1"),
833 );
834
835 registry.register(
837 RefinementType::new("Int")
838 .with_name("Natural")
839 .with_predicate(RefinementPredicate::And(vec![
840 RefinementPredicate::GreaterThanOrEqual(0.0),
841 RefinementPredicate::Modulo {
842 divisor: 1,
843 remainder: 0,
844 },
845 ]))
846 .with_description("Natural number (non-negative integer)"),
847 );
848
849 registry.register(
851 RefinementType::new("Int")
852 .with_name("Even")
853 .with_predicate(RefinementPredicate::Modulo {
854 divisor: 2,
855 remainder: 0,
856 })
857 .with_description("Even integer"),
858 );
859
860 registry.register(
862 RefinementType::new("Int")
863 .with_name("Odd")
864 .with_predicate(RefinementPredicate::Modulo {
865 divisor: 2,
866 remainder: 1,
867 })
868 .with_description("Odd integer"),
869 );
870
871 registry
872 }
873
874 pub fn register(&mut self, refinement: RefinementType) {
876 let name = refinement.type_name().to_string();
877 self.types.insert(name, refinement);
878 }
879
880 pub fn get(&self, name: &str) -> Option<&RefinementType> {
882 self.types.get(name)
883 }
884
885 pub fn contains(&self, name: &str) -> bool {
887 self.types.contains_key(name)
888 }
889
890 pub fn type_names(&self) -> Vec<&str> {
892 self.types.keys().map(|s| s.as_str()).collect()
893 }
894
895 pub fn len(&self) -> usize {
897 self.types.len()
898 }
899
900 pub fn is_empty(&self) -> bool {
902 self.types.is_empty()
903 }
904
905 pub fn check(&self, type_name: &str, value: f64) -> Option<bool> {
907 self.types.get(type_name).map(|t| t.check(value))
908 }
909
910 pub fn iter(&self) -> impl Iterator<Item = (&str, &RefinementType)> {
912 self.types.iter().map(|(k, v)| (k.as_str(), v))
913 }
914}
915
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `pred` accepts every value in `accepted` and rejects every value in `rejected`.
    fn assert_partition(pred: &RefinementPredicate, accepted: &[f64], rejected: &[f64]) {
        for &value in accepted {
            assert!(pred.check(value), "{:?} should accept {}", pred, value);
        }
        for &value in rejected {
            assert!(!pred.check(value), "{:?} should reject {}", pred, value);
        }
    }

    /// Builds an `Int` refinement carrying a single predicate.
    fn int_with(pred: RefinementPredicate) -> RefinementType {
        RefinementType::new("Int").with_predicate(pred)
    }

    #[test]
    fn test_basic_predicates() {
        assert_partition(&RefinementPredicate::greater_than(0.0), &[5.0], &[-1.0, 0.0]);
    }

    #[test]
    fn test_range_predicate() {
        // Both endpoints of a Range are inclusive.
        let unit = RefinementPredicate::range(0.0, 1.0);
        assert_partition(&unit, &[0.5, 0.0, 1.0], &[-0.1, 1.1]);
    }

    #[test]
    fn test_modulo_predicate() {
        let even = RefinementPredicate::modulo(2, 0);
        assert_partition(&even, &[4.0, 0.0], &[3.0]);
    }

    #[test]
    fn test_compound_predicates() {
        let positive_even = RefinementPredicate::and(vec![
            RefinementPredicate::greater_than(0.0),
            RefinementPredicate::modulo(2, 0),
        ]);
        // -2 is even but not positive; 3 is positive but odd.
        assert_partition(&positive_even, &[4.0], &[-2.0, 3.0]);
    }

    #[test]
    fn test_in_set_predicate() {
        let small = RefinementPredicate::in_set(vec![1.0, 2.0, 3.0]);
        assert_partition(&small, &[1.0, 2.0], &[4.0]);
    }

    #[test]
    fn test_custom_predicate() {
        let is_prime = RefinementPredicate::custom("is_prime", "Checks if number is prime", |n| {
            if n < 2.0 {
                return false;
            }
            let n = n as i64;
            let limit = (n as f64).sqrt() as i64;
            (2..=limit).all(|d| n % d != 0)
        });
        assert_partition(&is_prime, &[2.0, 7.0], &[4.0, 1.0]);
    }

    #[test]
    fn test_refinement_type() {
        let pos_int = int_with(RefinementPredicate::greater_than(0.0)).with_name("PositiveInt");
        assert_eq!(pos_int.type_name(), "PositiveInt");
        assert!(pos_int.check(5.0));
        assert!(!pos_int.check(-1.0));
    }

    #[test]
    fn test_dependent_predicate() {
        let below_n = RefinementPredicate::dependent("n", DependentRelation::LessThan);
        let mut context = RefinementContext::new();
        context.set_value("n", 10.0);

        assert!(below_n.check_with_context(5.0, &context));
        assert!(!below_n.check_with_context(15.0, &context));
    }

    #[test]
    fn test_registry_builtins() {
        let registry = RefinementRegistry::with_builtins();
        let cases = [
            ("PositiveInt", 5.0, true),
            ("PositiveInt", -1.0, false),
            ("Probability", 0.5, true),
            ("Probability", 1.5, false),
            ("Even", 4.0, true),
            ("Even", 3.0, false),
        ];
        for &(name, value, expected) in &cases {
            assert_eq!(registry.check(name, value), Some(expected), "{} on {}", name, value);
        }
    }

    #[test]
    fn test_predicate_simplification() {
        let simplified = RefinementPredicate::and(vec![
            RefinementPredicate::greater_than(0.0),
            RefinementPredicate::less_than(10.0),
            RefinementPredicate::greater_than_or_equal(1.0),
        ])
        .simplify();
        assert_partition(&simplified, &[5.0, 1.0], &[0.0]);
    }

    #[test]
    fn test_predicate_string_repr() {
        assert_eq!(RefinementPredicate::range(0.0, 1.0).to_string_repr(), "0 <= x <= 1");

        let open_interval = RefinementPredicate::and(vec![
            RefinementPredicate::greater_than(0.0),
            RefinementPredicate::less_than(10.0),
        ]);
        assert_eq!(open_interval.to_string_repr(), "(x > 0 && x < 10)");
    }

    #[test]
    fn test_free_variables() {
        let pred = RefinementPredicate::and(vec![
            RefinementPredicate::greater_than(0.0),
            RefinementPredicate::dependent("n", DependentRelation::LessThan),
            RefinementPredicate::dependent("m", DependentRelation::GreaterThan),
        ]);

        let vars = pred.free_variables();
        assert_eq!(vars.len(), 2);
        for expected in &["m", "n"] {
            assert!(vars.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_refinement_type_repr() {
        let bounded = int_with(RefinementPredicate::range(0.0, 100.0)).with_name("BoundedInt");
        assert_eq!(bounded.to_string_repr(), "Int{0 <= x <= 100}");
    }

    #[test]
    fn test_context_operations() {
        let mut ctx = RefinementContext::new();
        for &(name, value) in &[("x", 5.0), ("y", 10.0)] {
            ctx.set_value(name, value);
        }

        assert_eq!(ctx.get_value("x"), Some(&5.0));
        assert!(ctx.has_variable("x"));
        assert!(!ctx.has_variable("z"));
        assert_eq!(ctx.variables().len(), 2);
    }

    #[test]
    fn test_negation_predicate() {
        let nonzero = RefinementPredicate::not(RefinementPredicate::Equal(0.0));
        assert_partition(&nonzero, &[5.0], &[0.0]);
    }

    #[test]
    fn test_or_predicate() {
        let outside = RefinementPredicate::or(vec![
            RefinementPredicate::less_than(0.0),
            RefinementPredicate::greater_than(10.0),
        ]);
        assert_partition(&outside, &[-5.0, 15.0], &[5.0]);
    }

    #[test]
    fn test_double_negation_simplification() {
        let doubled = RefinementPredicate::not(RefinementPredicate::not(
            RefinementPredicate::greater_than(0.0),
        ));
        assert_partition(&doubled.simplify(), &[5.0], &[-1.0]);
    }

    #[test]
    fn test_registry_custom_type() {
        let mut registry = RefinementRegistry::new();
        registry.register(
            RefinementType::new("Float")
                .with_name("SmallPositive")
                .with_predicate(RefinementPredicate::range(0.0, 1e-6)),
        );

        assert!(registry.contains("SmallPositive"));
        assert_eq!(registry.check("SmallPositive", 1e-7), Some(true));
        assert_eq!(registry.check("SmallPositive", 1.0), Some(false));
    }

    #[test]
    fn test_subtyping_basic() {
        let int_type = RefinementType::new("Int");
        let float_type = RefinementType::new("Float");

        // Different base types never relate; an unrefined type is a subtype of itself.
        assert!(!int_type.is_subtype_of(&float_type));
        assert!(int_type.is_subtype_of(&int_type));
    }

    #[test]
    fn test_subtyping_range_implication() {
        let stricter = int_with(RefinementPredicate::range(5.0, 10.0));
        let looser = int_with(RefinementPredicate::range(0.0, 15.0));

        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter));
    }

    #[test]
    fn test_subtyping_greater_than_implication() {
        let stricter = int_with(RefinementPredicate::greater_than(10.0));
        let looser = int_with(RefinementPredicate::greater_than(5.0));

        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter));
    }

    #[test]
    fn test_subtyping_less_than_implication() {
        let stricter = int_with(RefinementPredicate::less_than(5.0));
        let looser = int_with(RefinementPredicate::less_than(10.0));

        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter));
    }

    #[test]
    fn test_subtyping_modulo_implication() {
        let divisible_by_4 = int_with(RefinementPredicate::modulo(4, 0));
        let divisible_by_2 = int_with(RefinementPredicate::modulo(2, 0));

        assert!(divisible_by_4.is_subtype_of(&divisible_by_2));
        assert!(!divisible_by_2.is_subtype_of(&divisible_by_4));
    }

    #[test]
    fn test_subtyping_conjunction() {
        let bounded = int_with(RefinementPredicate::greater_than(5.0))
            .with_predicate(RefinementPredicate::less_than(10.0));
        let positive = int_with(RefinementPredicate::greater_than(0.0));

        assert!(bounded.is_subtype_of(&positive));
    }

    #[test]
    fn test_subtyping_equality_implies_bounds() {
        let exact = int_with(RefinementPredicate::Equal(7.0));

        assert!(exact.is_subtype_of(&int_with(RefinementPredicate::greater_than(5.0))));
        assert!(exact.is_subtype_of(&int_with(RefinementPredicate::less_than(10.0))));
    }

    #[test]
    fn test_subtyping_no_implication() {
        let even = int_with(RefinementPredicate::modulo(2, 0));
        let gt_5 = int_with(RefinementPredicate::greater_than(5.0));

        assert!(!even.is_subtype_of(&gt_5));
        assert!(!gt_5.is_subtype_of(&even));
    }
}