1use std::collections::HashMap;
30use std::sync::Arc;
31
/// A condition on a numeric value `x` (an `f64`) that inhabitants of a refined
/// type must satisfy.
///
/// `Debug` is implemented by hand (not derived) because `Custom` holds a closure.
#[derive(Clone)]
pub enum RefinementPredicate {
    /// `x == v`; `check` compares with an absolute tolerance of `f64::EPSILON`.
    Equal(f64),
    /// `x != v`; the complement of `Equal` under the same tolerance.
    NotEqual(f64),
    /// `x > v`.
    GreaterThan(f64),
    /// `x >= v`.
    GreaterThanOrEqual(f64),
    /// `x < v`.
    LessThan(f64),
    /// `x <= v`.
    LessThanOrEqual(f64),
    /// Closed interval: `min <= x <= max`.
    Range { min: f64, max: f64 },
    /// Half-open interval: `min <= x < max` (only the upper end is excluded).
    RangeExclusive { min: f64, max: f64 },
    /// Remainder test on the integer form of `x` (`x % divisor == remainder`);
    /// see `RefinementPredicate::check` for how non-integral, negative and
    /// zero-divisor cases are treated.
    Modulo { divisor: i64, remainder: i64 },
    /// `x` equals one of the listed values (`f64::EPSILON` tolerance).
    InSet(Vec<f64>),
    /// `x` equals none of the listed values (`f64::EPSILON` tolerance).
    NotInSet(Vec<f64>),
    /// Every sub-predicate must hold; an empty list is always satisfied.
    And(Vec<RefinementPredicate>),
    /// At least one sub-predicate must hold; an empty list is never satisfied.
    Or(Vec<RefinementPredicate>),
    /// Logical negation of the boxed predicate.
    Not(Box<RefinementPredicate>),
    /// Named, user-supplied check. `checker` decides membership; only `name`
    /// and `description` appear in `Debug` output.
    Custom {
        name: String,
        description: String,
        checker: Arc<dyn Fn(f64) -> bool + Send + Sync>,
    },
    /// Relates `x` to the value bound to `variable` in a `RefinementContext`.
    /// Only decided by `check_with_context`; plain `check` treats it as satisfied.
    Dependent {
        variable: String,
        relation: DependentRelation,
    },
    /// Length bounds for string-valued refinements (either bound optional).
    /// Not evaluated by the numeric `check`, which treats it as satisfied.
    StringLength {
        min: Option<usize>,
        max: Option<usize>,
    },
    /// Pattern text for string-valued refinements, rendered as `x matches "..."`.
    /// Not evaluated by the numeric `check`, which treats it as satisfied.
    Pattern(String),
}
82
/// How the refined value `x` must relate to a context variable `v` in a
/// `RefinementPredicate::Dependent` predicate.
#[derive(Debug, Clone, PartialEq)]
pub enum DependentRelation {
    /// `x < v`.
    LessThan,
    /// `x <= v`.
    LessThanOrEqual,
    /// `x > v`.
    GreaterThan,
    /// `x >= v`.
    GreaterThanOrEqual,
    /// `x == v` within `f64::EPSILON`.
    Equal,
    /// `x != v` beyond `f64::EPSILON`.
    NotEqual,
    /// `x` divides `v` (both truncated to integers when checked).
    Divides,
    /// `x` is divisible by `v` (both truncated to integers when checked).
    DivisibleBy,
}
103
104impl std::fmt::Debug for RefinementPredicate {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 match self {
107 RefinementPredicate::Equal(v) => f.debug_tuple("Equal").field(v).finish(),
108 RefinementPredicate::NotEqual(v) => f.debug_tuple("NotEqual").field(v).finish(),
109 RefinementPredicate::GreaterThan(v) => f.debug_tuple("GreaterThan").field(v).finish(),
110 RefinementPredicate::GreaterThanOrEqual(v) => {
111 f.debug_tuple("GreaterThanOrEqual").field(v).finish()
112 }
113 RefinementPredicate::LessThan(v) => f.debug_tuple("LessThan").field(v).finish(),
114 RefinementPredicate::LessThanOrEqual(v) => {
115 f.debug_tuple("LessThanOrEqual").field(v).finish()
116 }
117 RefinementPredicate::Range { min, max } => f
118 .debug_struct("Range")
119 .field("min", min)
120 .field("max", max)
121 .finish(),
122 RefinementPredicate::RangeExclusive { min, max } => f
123 .debug_struct("RangeExclusive")
124 .field("min", min)
125 .field("max", max)
126 .finish(),
127 RefinementPredicate::Modulo { divisor, remainder } => f
128 .debug_struct("Modulo")
129 .field("divisor", divisor)
130 .field("remainder", remainder)
131 .finish(),
132 RefinementPredicate::InSet(set) => f.debug_tuple("InSet").field(set).finish(),
133 RefinementPredicate::NotInSet(set) => f.debug_tuple("NotInSet").field(set).finish(),
134 RefinementPredicate::And(preds) => f.debug_tuple("And").field(preds).finish(),
135 RefinementPredicate::Or(preds) => f.debug_tuple("Or").field(preds).finish(),
136 RefinementPredicate::Not(pred) => f.debug_tuple("Not").field(pred).finish(),
137 RefinementPredicate::Custom {
138 name, description, ..
139 } => f
140 .debug_struct("Custom")
141 .field("name", name)
142 .field("description", description)
143 .finish(),
144 RefinementPredicate::Dependent { variable, relation } => f
145 .debug_struct("Dependent")
146 .field("variable", variable)
147 .field("relation", relation)
148 .finish(),
149 RefinementPredicate::StringLength { min, max } => f
150 .debug_struct("StringLength")
151 .field("min", min)
152 .field("max", max)
153 .finish(),
154 RefinementPredicate::Pattern(pattern) => {
155 f.debug_tuple("Pattern").field(pattern).finish()
156 }
157 }
158 }
159}
160
161impl RefinementPredicate {
162 pub fn greater_than(value: f64) -> Self {
164 RefinementPredicate::GreaterThan(value)
165 }
166
167 pub fn greater_than_or_equal(value: f64) -> Self {
169 RefinementPredicate::GreaterThanOrEqual(value)
170 }
171
172 pub fn less_than(value: f64) -> Self {
174 RefinementPredicate::LessThan(value)
175 }
176
177 pub fn less_than_or_equal(value: f64) -> Self {
179 RefinementPredicate::LessThanOrEqual(value)
180 }
181
182 pub fn range(min: f64, max: f64) -> Self {
184 RefinementPredicate::Range { min, max }
185 }
186
187 pub fn modulo(divisor: i64, remainder: i64) -> Self {
189 RefinementPredicate::Modulo { divisor, remainder }
190 }
191
192 pub fn in_set(values: Vec<f64>) -> Self {
194 RefinementPredicate::InSet(values)
195 }
196
197 pub fn and(predicates: Vec<RefinementPredicate>) -> Self {
199 RefinementPredicate::And(predicates)
200 }
201
202 pub fn or(predicates: Vec<RefinementPredicate>) -> Self {
204 RefinementPredicate::Or(predicates)
205 }
206
207 #[allow(clippy::should_implement_trait)]
209 pub fn not(predicate: RefinementPredicate) -> Self {
210 RefinementPredicate::Not(Box::new(predicate))
211 }
212
213 pub fn custom<F>(name: impl Into<String>, description: impl Into<String>, checker: F) -> Self
215 where
216 F: Fn(f64) -> bool + Send + Sync + 'static,
217 {
218 RefinementPredicate::Custom {
219 name: name.into(),
220 description: description.into(),
221 checker: Arc::new(checker),
222 }
223 }
224
225 pub fn dependent(variable: impl Into<String>, relation: DependentRelation) -> Self {
227 RefinementPredicate::Dependent {
228 variable: variable.into(),
229 relation,
230 }
231 }
232
233 pub fn check(&self, value: f64) -> bool {
237 match self {
238 RefinementPredicate::Equal(v) => (value - v).abs() < f64::EPSILON,
239 RefinementPredicate::NotEqual(v) => (value - v).abs() >= f64::EPSILON,
240 RefinementPredicate::GreaterThan(v) => value > *v,
241 RefinementPredicate::GreaterThanOrEqual(v) => value >= *v,
242 RefinementPredicate::LessThan(v) => value < *v,
243 RefinementPredicate::LessThanOrEqual(v) => value <= *v,
244 RefinementPredicate::Range { min, max } => value >= *min && value <= *max,
245 RefinementPredicate::RangeExclusive { min, max } => value >= *min && value < *max,
246 RefinementPredicate::Modulo { divisor, remainder } => {
247 (value as i64) % divisor == *remainder
248 }
249 RefinementPredicate::InSet(set) => set.iter().any(|v| (value - v).abs() < f64::EPSILON),
250 RefinementPredicate::NotInSet(set) => {
251 !set.iter().any(|v| (value - v).abs() < f64::EPSILON)
252 }
253 RefinementPredicate::And(preds) => preds.iter().all(|p| p.check(value)),
254 RefinementPredicate::Or(preds) => preds.iter().any(|p| p.check(value)),
255 RefinementPredicate::Not(pred) => !pred.check(value),
256 RefinementPredicate::Custom { checker, .. } => checker(value),
257 RefinementPredicate::Dependent { .. } => true, RefinementPredicate::StringLength { .. } => true, RefinementPredicate::Pattern(_) => true, }
261 }
262
263 pub fn check_with_context(&self, value: f64, context: &RefinementContext) -> bool {
265 match self {
266 RefinementPredicate::Dependent { variable, relation } => {
267 if let Some(&other) = context.get_value(variable) {
268 match relation {
269 DependentRelation::LessThan => value < other,
270 DependentRelation::LessThanOrEqual => value <= other,
271 DependentRelation::GreaterThan => value > other,
272 DependentRelation::GreaterThanOrEqual => value >= other,
273 DependentRelation::Equal => (value - other).abs() < f64::EPSILON,
274 DependentRelation::NotEqual => (value - other).abs() >= f64::EPSILON,
275 DependentRelation::Divides => {
276 other != 0.0 && (other as i64) % (value as i64) == 0
277 }
278 DependentRelation::DivisibleBy => {
279 value != 0.0 && (value as i64) % (other as i64) == 0
280 }
281 }
282 } else {
283 false }
285 }
286 RefinementPredicate::And(preds) => {
287 preds.iter().all(|p| p.check_with_context(value, context))
288 }
289 RefinementPredicate::Or(preds) => {
290 preds.iter().any(|p| p.check_with_context(value, context))
291 }
292 RefinementPredicate::Not(pred) => !pred.check_with_context(value, context),
293 _ => self.check(value),
294 }
295 }
296
297 pub fn free_variables(&self) -> Vec<String> {
299 match self {
300 RefinementPredicate::Dependent { variable, .. } => vec![variable.clone()],
301 RefinementPredicate::And(preds) | RefinementPredicate::Or(preds) => {
302 let mut vars = Vec::new();
303 for pred in preds {
304 vars.extend(pred.free_variables());
305 }
306 vars.sort();
307 vars.dedup();
308 vars
309 }
310 RefinementPredicate::Not(pred) => pred.free_variables(),
311 _ => vec![],
312 }
313 }
314
315 pub fn simplify(&self) -> RefinementPredicate {
317 match self {
318 RefinementPredicate::And(preds) => {
319 let simplified: Vec<_> = preds.iter().map(|p| p.simplify()).collect();
320 if simplified.len() == 1 {
321 simplified
322 .into_iter()
323 .next()
324 .expect("validated length == 1")
325 } else {
326 let mut min_val = f64::NEG_INFINITY;
328 let mut max_val = f64::INFINITY;
329 let mut others = Vec::new();
330
331 for pred in simplified {
332 match pred {
333 RefinementPredicate::GreaterThan(v) => {
334 min_val = min_val.max(v);
335 }
336 RefinementPredicate::GreaterThanOrEqual(v) => {
337 min_val = min_val.max(v);
338 }
339 RefinementPredicate::LessThan(v) => {
340 max_val = max_val.min(v);
341 }
342 RefinementPredicate::LessThanOrEqual(v) => {
343 max_val = max_val.min(v);
344 }
345 RefinementPredicate::Range { min, max } => {
346 min_val = min_val.max(min);
347 max_val = max_val.min(max);
348 }
349 other => others.push(other),
350 }
351 }
352
353 if min_val > f64::NEG_INFINITY || max_val < f64::INFINITY {
355 if min_val > f64::NEG_INFINITY && max_val < f64::INFINITY {
356 others.insert(
357 0,
358 RefinementPredicate::Range {
359 min: min_val,
360 max: max_val,
361 },
362 );
363 } else if min_val > f64::NEG_INFINITY {
364 others.insert(0, RefinementPredicate::GreaterThanOrEqual(min_val));
365 } else {
366 others.insert(0, RefinementPredicate::LessThanOrEqual(max_val));
367 }
368 }
369
370 if others.len() == 1 {
371 others.into_iter().next().expect("validated length == 1")
372 } else {
373 RefinementPredicate::And(others)
374 }
375 }
376 }
377 RefinementPredicate::Or(preds) => {
378 let simplified: Vec<_> = preds.iter().map(|p| p.simplify()).collect();
379 if simplified.len() == 1 {
380 simplified
381 .into_iter()
382 .next()
383 .expect("validated length == 1")
384 } else {
385 RefinementPredicate::Or(simplified)
386 }
387 }
388 RefinementPredicate::Not(pred) => {
389 let inner = pred.simplify();
390 match inner {
391 RefinementPredicate::Not(p) => *p, other => RefinementPredicate::Not(Box::new(other)),
393 }
394 }
395 other => other.clone(),
396 }
397 }
398
399 pub fn to_string_repr(&self) -> String {
401 match self {
402 RefinementPredicate::Equal(v) => format!("x == {}", v),
403 RefinementPredicate::NotEqual(v) => format!("x != {}", v),
404 RefinementPredicate::GreaterThan(v) => format!("x > {}", v),
405 RefinementPredicate::GreaterThanOrEqual(v) => format!("x >= {}", v),
406 RefinementPredicate::LessThan(v) => format!("x < {}", v),
407 RefinementPredicate::LessThanOrEqual(v) => format!("x <= {}", v),
408 RefinementPredicate::Range { min, max } => format!("{} <= x <= {}", min, max),
409 RefinementPredicate::RangeExclusive { min, max } => format!("{} <= x < {}", min, max),
410 RefinementPredicate::Modulo { divisor, remainder } => {
411 format!("x % {} == {}", divisor, remainder)
412 }
413 RefinementPredicate::InSet(set) => format!("x in {:?}", set),
414 RefinementPredicate::NotInSet(set) => format!("x not in {:?}", set),
415 RefinementPredicate::And(preds) => {
416 let parts: Vec<_> = preds.iter().map(|p| p.to_string_repr()).collect();
417 format!("({})", parts.join(" && "))
418 }
419 RefinementPredicate::Or(preds) => {
420 let parts: Vec<_> = preds.iter().map(|p| p.to_string_repr()).collect();
421 format!("({})", parts.join(" || "))
422 }
423 RefinementPredicate::Not(pred) => format!("!({})", pred.to_string_repr()),
424 RefinementPredicate::Custom { name, .. } => format!("{}(x)", name),
425 RefinementPredicate::Dependent { variable, relation } => {
426 let rel_str = match relation {
427 DependentRelation::LessThan => "<",
428 DependentRelation::LessThanOrEqual => "<=",
429 DependentRelation::GreaterThan => ">",
430 DependentRelation::GreaterThanOrEqual => ">=",
431 DependentRelation::Equal => "==",
432 DependentRelation::NotEqual => "!=",
433 DependentRelation::Divides => "divides",
434 DependentRelation::DivisibleBy => "divisible_by",
435 };
436 format!("x {} {}", rel_str, variable)
437 }
438 RefinementPredicate::StringLength { min, max } => match (min, max) {
439 (Some(min), Some(max)) => format!("{} <= len(x) <= {}", min, max),
440 (Some(min), None) => format!("len(x) >= {}", min),
441 (None, Some(max)) => format!("len(x) <= {}", max),
442 (None, None) => "true".to_string(),
443 },
444 RefinementPredicate::Pattern(pattern) => format!("x matches \"{}\"", pattern),
445 }
446 }
447}
448
449#[derive(Debug, Clone)]
451pub struct RefinementType {
452 pub base_type: String,
454 pub name: Option<String>,
456 pub predicates: Vec<RefinementPredicate>,
458 pub description: Option<String>,
460}
461
462impl RefinementType {
463 pub fn new(base_type: impl Into<String>) -> Self {
465 RefinementType {
466 base_type: base_type.into(),
467 name: None,
468 predicates: Vec::new(),
469 description: None,
470 }
471 }
472
473 pub fn with_name(mut self, name: impl Into<String>) -> Self {
475 self.name = Some(name.into());
476 self
477 }
478
479 pub fn with_predicate(mut self, predicate: RefinementPredicate) -> Self {
481 self.predicates.push(predicate);
482 self
483 }
484
485 pub fn with_description(mut self, description: impl Into<String>) -> Self {
487 self.description = Some(description.into());
488 self
489 }
490
491 pub fn check(&self, value: f64) -> bool {
493 self.predicates.iter().all(|p| p.check(value))
494 }
495
496 pub fn check_with_context(&self, value: f64, context: &RefinementContext) -> bool {
498 self.predicates
499 .iter()
500 .all(|p| p.check_with_context(value, context))
501 }
502
503 pub fn type_name(&self) -> &str {
505 self.name.as_deref().unwrap_or(&self.base_type)
506 }
507
508 pub fn is_subtype_of(&self, other: &RefinementType) -> bool {
517 if self.base_type != other.base_type {
518 return false;
519 }
520
521 if other.predicates.is_empty() {
523 return true;
524 }
525
526 if self.predicates.is_empty() && !other.predicates.is_empty() {
528 return false;
529 }
530
531 for other_pred in &other.predicates {
533 if !self.implies_predicate(other_pred) {
534 return false;
535 }
536 }
537
538 true
539 }
540
541 fn implies_predicate(&self, target: &RefinementPredicate) -> bool {
548 let target_repr = format!("{:?}", target);
551 if self
552 .predicates
553 .iter()
554 .any(|p| format!("{:?}", p) == target_repr)
555 {
556 return true;
557 }
558
559 for pred in &self.predicates {
561 if Self::semantic_implies(pred, target) {
562 return true;
563 }
564 }
565
566 Self::conjunction_implies(&self.predicates, target)
568 }
569
570 fn semantic_implies(source: &RefinementPredicate, target: &RefinementPredicate) -> bool {
572 use RefinementPredicate::*;
573
574 match (source, target) {
575 (
577 Range {
578 min: min1,
579 max: max1,
580 },
581 Range {
582 min: min2,
583 max: max2,
584 },
585 ) => {
586 min1 >= min2 && max1 <= max2
588 }
589 (
590 RangeExclusive {
591 min: min1,
592 max: max1,
593 },
594 RangeExclusive {
595 min: min2,
596 max: max2,
597 },
598 ) => min1 >= min2 && max1 <= max2,
599 (GreaterThan(v1), GreaterThan(v2)) => v1 >= v2,
601 (GreaterThanOrEqual(v1), GreaterThanOrEqual(v2)) => v1 >= v2,
602 (GreaterThan(v1), GreaterThanOrEqual(v2)) => v1 >= v2, (LessThan(v1), LessThan(v2)) => v1 <= v2,
605 (LessThanOrEqual(v1), LessThanOrEqual(v2)) => v1 <= v2,
606 (LessThan(v1), LessThanOrEqual(v2)) => v1 <= v2, (Equal(v1), GreaterThan(v2)) => v1 > v2,
609 (Equal(v1), GreaterThanOrEqual(v2)) => v1 >= v2,
610 (Equal(v1), LessThan(v2)) => v1 < v2,
611 (Equal(v1), LessThanOrEqual(v2)) => v1 <= v2,
612 (Equal(v1), Range { min, max }) => v1 >= min && v1 <= max,
613 (
615 Modulo {
616 divisor: d1,
617 remainder: r1,
618 },
619 Modulo {
620 divisor: d2,
621 remainder: r2,
622 },
623 ) => r1 == r2 && d1 % d2 == 0,
624 (
626 Dependent {
627 variable: v1,
628 relation: rel1,
629 },
630 Dependent {
631 variable: v2,
632 relation: rel2,
633 },
634 ) => {
635 if v1 != v2 {
636 return false;
637 }
638 use DependentRelation::*;
640 matches!(
641 (rel1, rel2),
642 (Equal, Equal)
643 | (GreaterThan, GreaterThan)
644 | (GreaterThan, GreaterThanOrEqual)
645 | (LessThan, LessThan)
646 | (LessThan, LessThanOrEqual)
647 | (GreaterThanOrEqual, GreaterThanOrEqual)
648 | (LessThanOrEqual, LessThanOrEqual)
649 )
650 }
651 _ => false,
652 }
653 }
654
655 fn conjunction_implies(
659 predicates: &[RefinementPredicate],
660 target: &RefinementPredicate,
661 ) -> bool {
662 use RefinementPredicate::*;
663
664 let mut lower_bounds = Vec::new();
666 let mut upper_bounds = Vec::new();
667
668 for pred in predicates {
669 match pred {
670 GreaterThan(v) | GreaterThanOrEqual(v) => {
671 lower_bounds.push(*v);
672 }
673 LessThan(v) | LessThanOrEqual(v) => {
674 upper_bounds.push(*v);
675 }
676 Range { min, max } => {
677 lower_bounds.push(*min);
678 upper_bounds.push(*max);
679 }
680 Equal(v) => {
681 lower_bounds.push(*v);
682 upper_bounds.push(*v);
683 }
684 _ => {}
685 }
686 }
687
688 match target {
690 GreaterThan(v) | GreaterThanOrEqual(v) => lower_bounds.iter().any(|lb| lb >= v),
691 LessThan(v) | LessThanOrEqual(v) => upper_bounds.iter().any(|ub| ub <= v),
692 Range { min, max } => {
693 lower_bounds.iter().any(|lb| lb >= min) && upper_bounds.iter().any(|ub| ub <= max)
694 }
695 _ => false,
696 }
697 }
698
699 pub fn free_variables(&self) -> Vec<String> {
701 let mut vars = Vec::new();
702 for pred in &self.predicates {
703 vars.extend(pred.free_variables());
704 }
705 vars.sort();
706 vars.dedup();
707 vars
708 }
709
710 pub fn to_string_repr(&self) -> String {
712 if self.predicates.is_empty() {
713 return self.base_type.clone();
714 }
715
716 let pred_strs: Vec<_> = self.predicates.iter().map(|p| p.to_string_repr()).collect();
717 format!("{}{{{}}}", self.base_type, pred_strs.join(" && "))
718 }
719}
720
721#[derive(Debug, Clone, Default)]
723pub struct RefinementContext {
724 values: HashMap<String, f64>,
726 types: HashMap<String, RefinementType>,
728}
729
730impl RefinementContext {
731 pub fn new() -> Self {
733 RefinementContext {
734 values: HashMap::new(),
735 types: HashMap::new(),
736 }
737 }
738
739 pub fn set_value(&mut self, var: impl Into<String>, value: f64) {
741 self.values.insert(var.into(), value);
742 }
743
744 pub fn get_value(&self, var: &str) -> Option<&f64> {
746 self.values.get(var)
747 }
748
749 pub fn set_type(&mut self, var: impl Into<String>, ty: RefinementType) {
751 self.types.insert(var.into(), ty);
752 }
753
754 pub fn get_type(&self, var: &str) -> Option<&RefinementType> {
756 self.types.get(var)
757 }
758
759 pub fn has_variable(&self, var: &str) -> bool {
761 self.values.contains_key(var) || self.types.contains_key(var)
762 }
763
764 pub fn variables(&self) -> Vec<&str> {
766 let mut vars: Vec<_> = self.values.keys().map(|s| s.as_str()).collect();
767 for key in self.types.keys() {
768 if !self.values.contains_key(key) {
769 vars.push(key.as_str());
770 }
771 }
772 vars
773 }
774}
775
776#[derive(Debug, Clone, Default)]
778pub struct RefinementRegistry {
779 types: HashMap<String, RefinementType>,
781}
782
783impl RefinementRegistry {
784 pub fn new() -> Self {
786 RefinementRegistry {
787 types: HashMap::new(),
788 }
789 }
790
791 pub fn with_builtins() -> Self {
793 let mut registry = RefinementRegistry::new();
794
795 registry.register(
797 RefinementType::new("Int")
798 .with_name("PositiveInt")
799 .with_predicate(RefinementPredicate::GreaterThan(0.0))
800 .with_description("Strictly positive integer"),
801 );
802
803 registry.register(
805 RefinementType::new("Int")
806 .with_name("NonNegativeInt")
807 .with_predicate(RefinementPredicate::GreaterThanOrEqual(0.0))
808 .with_description("Non-negative integer (zero or positive)"),
809 );
810
811 registry.register(
813 RefinementType::new("Float")
814 .with_name("Probability")
815 .with_predicate(RefinementPredicate::Range { min: 0.0, max: 1.0 })
816 .with_description("Probability value between 0 and 1"),
817 );
818
819 registry.register(
821 RefinementType::new("Float")
822 .with_name("Percentage")
823 .with_predicate(RefinementPredicate::Range {
824 min: 0.0,
825 max: 100.0,
826 })
827 .with_description("Percentage value between 0 and 100"),
828 );
829
830 registry.register(
832 RefinementType::new("Float")
833 .with_name("Normalized")
834 .with_predicate(RefinementPredicate::Range {
835 min: -1.0,
836 max: 1.0,
837 })
838 .with_description("Normalized value between -1 and 1"),
839 );
840
841 registry.register(
843 RefinementType::new("Int")
844 .with_name("Natural")
845 .with_predicate(RefinementPredicate::And(vec![
846 RefinementPredicate::GreaterThanOrEqual(0.0),
847 RefinementPredicate::Modulo {
848 divisor: 1,
849 remainder: 0,
850 },
851 ]))
852 .with_description("Natural number (non-negative integer)"),
853 );
854
855 registry.register(
857 RefinementType::new("Int")
858 .with_name("Even")
859 .with_predicate(RefinementPredicate::Modulo {
860 divisor: 2,
861 remainder: 0,
862 })
863 .with_description("Even integer"),
864 );
865
866 registry.register(
868 RefinementType::new("Int")
869 .with_name("Odd")
870 .with_predicate(RefinementPredicate::Modulo {
871 divisor: 2,
872 remainder: 1,
873 })
874 .with_description("Odd integer"),
875 );
876
877 registry
878 }
879
880 pub fn register(&mut self, refinement: RefinementType) {
882 let name = refinement.type_name().to_string();
883 self.types.insert(name, refinement);
884 }
885
886 pub fn get(&self, name: &str) -> Option<&RefinementType> {
888 self.types.get(name)
889 }
890
891 pub fn contains(&self, name: &str) -> bool {
893 self.types.contains_key(name)
894 }
895
896 pub fn type_names(&self) -> Vec<&str> {
898 self.types.keys().map(|s| s.as_str()).collect()
899 }
900
901 pub fn len(&self) -> usize {
903 self.types.len()
904 }
905
906 pub fn is_empty(&self) -> bool {
908 self.types.is_empty()
909 }
910
911 pub fn check(&self, type_name: &str, value: f64) -> Option<bool> {
913 self.types.get(type_name).map(|t| t.check(value))
914 }
915
916 pub fn iter(&self) -> impl Iterator<Item = (&str, &RefinementType)> {
918 self.types.iter().map(|(k, v)| (k.as_str(), v))
919 }
920}
921
// Unit tests for predicate evaluation, simplification, string rendering,
// contexts, the builtin registry, and the conservative subtyping check.
#[cfg(test)]
mod tests {
    use super::*;

    // Strict lower bound: zero itself is rejected.
    #[test]
    fn test_basic_predicates() {
        let pred = RefinementPredicate::GreaterThan(0.0);
        assert!(pred.check(5.0));
        assert!(!pred.check(-1.0));
        assert!(!pred.check(0.0));
    }

    // `Range` is closed: both endpoints are accepted.
    #[test]
    fn test_range_predicate() {
        let pred = RefinementPredicate::Range { min: 0.0, max: 1.0 };
        assert!(pred.check(0.5));
        assert!(pred.check(0.0));
        assert!(pred.check(1.0));
        assert!(!pred.check(-0.1));
        assert!(!pred.check(1.1));
    }

    // Evenness via `Modulo { divisor: 2, remainder: 0 }`.
    #[test]
    fn test_modulo_predicate() {
        let even = RefinementPredicate::Modulo {
            divisor: 2,
            remainder: 0,
        };
        assert!(even.check(4.0));
        assert!(even.check(0.0));
        assert!(!even.check(3.0));
    }

    // `And` requires every conjunct: positive AND even.
    #[test]
    fn test_compound_predicates() {
        let pred = RefinementPredicate::And(vec![
            RefinementPredicate::GreaterThan(0.0),
            RefinementPredicate::Modulo {
                divisor: 2,
                remainder: 0,
            },
        ]);

        assert!(pred.check(4.0));
        // Even but not positive.
        assert!(!pred.check(-2.0));
        // Positive but odd.
        assert!(!pred.check(3.0));
    }

    #[test]
    fn test_in_set_predicate() {
        let pred = RefinementPredicate::InSet(vec![1.0, 2.0, 3.0]);
        assert!(pred.check(1.0));
        assert!(pred.check(2.0));
        assert!(!pred.check(4.0));
    }

    // A `Custom` predicate delegates entirely to the supplied closure (primality here).
    #[test]
    fn test_custom_predicate() {
        let pred = RefinementPredicate::custom("is_prime", "Checks if number is prime", |n| {
            if n < 2.0 {
                return false;
            }
            let n = n as i64;
            for i in 2..=((n as f64).sqrt() as i64) {
                if n % i == 0 {
                    return false;
                }
            }
            true
        });

        assert!(pred.check(2.0));
        assert!(pred.check(7.0));
        assert!(!pred.check(4.0));
        assert!(!pred.check(1.0));
    }

    // `type_name` prefers the explicit name over the base type.
    #[test]
    fn test_refinement_type() {
        let pos_int = RefinementType::new("Int")
            .with_name("PositiveInt")
            .with_predicate(RefinementPredicate::GreaterThan(0.0));

        assert_eq!(pos_int.type_name(), "PositiveInt");
        assert!(pos_int.check(5.0));
        assert!(!pos_int.check(-1.0));
    }

    // `Dependent` predicates are resolved against values bound in the context.
    #[test]
    fn test_dependent_predicate() {
        let pred = RefinementPredicate::Dependent {
            variable: "n".to_string(),
            relation: DependentRelation::LessThan,
        };

        let mut context = RefinementContext::new();
        context.set_value("n", 10.0);

        assert!(pred.check_with_context(5.0, &context));
        assert!(!pred.check_with_context(15.0, &context));
    }

    // Spot-checks a few of the builtin types; `check` returns `None` only for unknown names.
    #[test]
    fn test_registry_builtins() {
        let registry = RefinementRegistry::with_builtins();

        assert!(registry.check("PositiveInt", 5.0).expect("unwrap"));
        assert!(!registry.check("PositiveInt", -1.0).expect("unwrap"));

        assert!(registry.check("Probability", 0.5).expect("unwrap"));
        assert!(!registry.check("Probability", 1.5).expect("unwrap"));

        assert!(registry.check("Even", 4.0).expect("unwrap"));
        assert!(!registry.check("Even", 3.0).expect("unwrap"));
    }

    // Bounds inside an `And` are merged; the tightest lower bound here is `x >= 1`.
    #[test]
    fn test_predicate_simplification() {
        let pred = RefinementPredicate::And(vec![
            RefinementPredicate::GreaterThan(0.0),
            RefinementPredicate::LessThan(10.0),
            RefinementPredicate::GreaterThanOrEqual(1.0),
        ]);

        let simplified = pred.simplify();

        assert!(simplified.check(5.0));
        assert!(!simplified.check(0.0));
        assert!(simplified.check(1.0));
    }

    // f64 values render via `Display`, so whole numbers print without a fraction.
    #[test]
    fn test_predicate_string_repr() {
        let pred = RefinementPredicate::Range { min: 0.0, max: 1.0 };
        assert_eq!(pred.to_string_repr(), "0 <= x <= 1");

        let pred = RefinementPredicate::And(vec![
            RefinementPredicate::GreaterThan(0.0),
            RefinementPredicate::LessThan(10.0),
        ]);
        assert_eq!(pred.to_string_repr(), "(x > 0 && x < 10)");
    }

    // Only `Dependent` predicates contribute variables; duplicates would be removed.
    #[test]
    fn test_free_variables() {
        let pred = RefinementPredicate::And(vec![
            RefinementPredicate::GreaterThan(0.0),
            RefinementPredicate::Dependent {
                variable: "n".to_string(),
                relation: DependentRelation::LessThan,
            },
            RefinementPredicate::Dependent {
                variable: "m".to_string(),
                relation: DependentRelation::GreaterThan,
            },
        ]);

        let vars = pred.free_variables();
        assert_eq!(vars.len(), 2);
        assert!(vars.contains(&"m".to_string()));
        assert!(vars.contains(&"n".to_string()));
    }

    // The type representation uses the base type, not the display name.
    #[test]
    fn test_refinement_type_repr() {
        let ty = RefinementType::new("Int")
            .with_name("BoundedInt")
            .with_predicate(RefinementPredicate::Range {
                min: 0.0,
                max: 100.0,
            });

        assert_eq!(ty.to_string_repr(), "Int{0 <= x <= 100}");
    }

    #[test]
    fn test_context_operations() {
        let mut ctx = RefinementContext::new();

        ctx.set_value("x", 5.0);
        ctx.set_value("y", 10.0);

        assert_eq!(ctx.get_value("x"), Some(&5.0));
        assert!(ctx.has_variable("x"));
        assert!(!ctx.has_variable("z"));

        let vars = ctx.variables();
        assert_eq!(vars.len(), 2);
    }

    #[test]
    fn test_negation_predicate() {
        let pred = RefinementPredicate::Not(Box::new(RefinementPredicate::Equal(0.0)));

        assert!(pred.check(5.0));
        assert!(!pred.check(0.0));
    }

    // `Or` accepts values outside [0, 10] on either side.
    #[test]
    fn test_or_predicate() {
        let pred = RefinementPredicate::Or(vec![
            RefinementPredicate::LessThan(0.0),
            RefinementPredicate::GreaterThan(10.0),
        ]);

        assert!(pred.check(-5.0));
        assert!(pred.check(15.0));
        assert!(!pred.check(5.0));
    }

    // `Not(Not(p))` simplifies back to a predicate equivalent to `p`.
    #[test]
    fn test_double_negation_simplification() {
        let pred = RefinementPredicate::Not(Box::new(RefinementPredicate::Not(Box::new(
            RefinementPredicate::GreaterThan(0.0),
        ))));

        let simplified = pred.simplify();
        assert!(simplified.check(5.0));
        assert!(!simplified.check(-1.0));
    }

    // Types registered by hand are looked up by their display name.
    #[test]
    fn test_registry_custom_type() {
        let mut registry = RefinementRegistry::new();

        registry.register(
            RefinementType::new("Float")
                .with_name("SmallPositive")
                .with_predicate(RefinementPredicate::Range {
                    min: 0.0,
                    max: 1e-6,
                }),
        );

        assert!(registry.contains("SmallPositive"));
        assert!(registry.check("SmallPositive", 1e-7).expect("unwrap"));
        assert!(!registry.check("SmallPositive", 1.0).expect("unwrap"));
    }

    #[test]
    fn test_subtyping_basic() {
        let int_type = RefinementType::new("Int");
        let float_type = RefinementType::new("Float");

        // Different base types are never subtypes of each other.
        assert!(!int_type.is_subtype_of(&float_type));
        // An unrefined type is a subtype of itself.
        assert!(int_type.is_subtype_of(&int_type));
    }

    // [5, 10] lies inside [0, 15], but not the other way round.
    #[test]
    fn test_subtyping_range_implication() {
        let stricter = RefinementType::new("Int").with_predicate(RefinementPredicate::Range {
            min: 5.0,
            max: 10.0,
        });

        let looser = RefinementType::new("Int").with_predicate(RefinementPredicate::Range {
            min: 0.0,
            max: 15.0,
        });

        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter));
    }

    #[test]
    fn test_subtyping_greater_than_implication() {
        let stricter =
            RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(10.0));

        let looser =
            RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(5.0));

        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter));
    }

    #[test]
    fn test_subtyping_less_than_implication() {
        let stricter =
            RefinementType::new("Int").with_predicate(RefinementPredicate::LessThan(5.0));

        let looser = RefinementType::new("Int").with_predicate(RefinementPredicate::LessThan(10.0));

        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter));
    }

    // Multiples of 4 are multiples of 2, not vice versa.
    #[test]
    fn test_subtyping_modulo_implication() {
        let divisible_by_4 =
            RefinementType::new("Int").with_predicate(RefinementPredicate::Modulo {
                divisor: 4,
                remainder: 0,
            });

        let divisible_by_2 =
            RefinementType::new("Int").with_predicate(RefinementPredicate::Modulo {
                divisor: 2,
                remainder: 0,
            });

        assert!(divisible_by_4.is_subtype_of(&divisible_by_2));
        assert!(!divisible_by_2.is_subtype_of(&divisible_by_4));
    }

    // One conjunct (`x > 5`) is enough to imply the supertype's `x > 0`.
    #[test]
    fn test_subtyping_conjunction() {
        let bounded = RefinementType::new("Int")
            .with_predicate(RefinementPredicate::GreaterThan(5.0))
            .with_predicate(RefinementPredicate::LessThan(10.0));

        let positive =
            RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(0.0));

        assert!(bounded.is_subtype_of(&positive));
    }

    // An exact value implies any bound it satisfies.
    #[test]
    fn test_subtyping_equality_implies_bounds() {
        let exact = RefinementType::new("Int").with_predicate(RefinementPredicate::Equal(7.0));

        let gt_5 = RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(5.0));

        let lt_10 = RefinementType::new("Int").with_predicate(RefinementPredicate::LessThan(10.0));

        assert!(exact.is_subtype_of(&gt_5));
        assert!(exact.is_subtype_of(&lt_10));
    }

    // Unrelated predicate kinds prove nothing about each other.
    #[test]
    fn test_subtyping_no_implication() {
        let even = RefinementType::new("Int").with_predicate(RefinementPredicate::Modulo {
            divisor: 2,
            remainder: 0,
        });

        let gt_5 = RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(5.0));

        assert!(!even.is_subtype_of(&gt_5));
        assert!(!gt_5.is_subtype_of(&even));
    }
}