Skip to main content

tensorlogic_adapters/
refinement.rs

//! Refinement types for expressing value constraints beyond simple types.
//!
//! Refinement types extend base types with predicates that constrain valid values.
//! This enables checking properties such as positivity, bounds, and custom invariants,
//! both for concrete values and between types (subtyping).
//!
//! # Examples
//!
//! ```rust
//! use tensorlogic_adapters::{RefinementType, RefinementPredicate, RefinementContext};
//!
//! // Create a positive integer refinement
//! let pos_int = RefinementType::new("Int")
//!     .with_predicate(RefinementPredicate::greater_than(0.0))
//!     .with_name("PositiveInt");
//!
//! // Check if a value satisfies the refinement
//! assert!(pos_int.check(5.0));
//! assert!(!pos_int.check(-1.0));
//!
//! // Create bounded range refinement
//! let probability = RefinementType::new("Float")
//!     .with_predicate(RefinementPredicate::range(0.0, 1.0))
//!     .with_name("Probability");
//!
//! assert!(probability.check(0.5));
//! assert!(!probability.check(1.5));
//! ```
28
29use std::collections::HashMap;
30use std::sync::Arc;
31
/// A refinement predicate that constrains values.
///
/// Numeric predicates are evaluated with `RefinementPredicate::check`. `Dependent`
/// predicates refer to other variables and are only evaluated by
/// `RefinementPredicate::check_with_context`, which looks those variables up in a
/// [`RefinementContext`].
///
/// `Debug` is implemented by hand (not derived) because the `Custom` variant stores a
/// closure, which has no `Debug` implementation.
#[derive(Clone)]
pub enum RefinementPredicate {
    /// Value must equal a constant (compared with an absolute tolerance of `f64::EPSILON`)
    Equal(f64),
    /// Value must not equal a constant (same tolerance as `Equal`)
    NotEqual(f64),
    /// Value must be greater than a constant
    GreaterThan(f64),
    /// Value must be greater than or equal to a constant
    GreaterThanOrEqual(f64),
    /// Value must be less than a constant
    LessThan(f64),
    /// Value must be less than or equal to a constant
    LessThanOrEqual(f64),
    /// Value must be in the closed range [min, max]
    Range { min: f64, max: f64 },
    /// Value must be in the half-open range [min, max)
    RangeExclusive { min: f64, max: f64 },
    /// Value must satisfy a modulo constraint (value % divisor == remainder), using
    /// Rust's integer remainder operator
    Modulo { divisor: i64, remainder: i64 },
    /// Value must be in a set of allowed values (membership uses the `Equal` tolerance)
    InSet(Vec<f64>),
    /// Value must not be in a set of disallowed values
    NotInSet(Vec<f64>),
    /// Conjunction of predicates (all must be satisfied)
    And(Vec<RefinementPredicate>),
    /// Disjunction of predicates (at least one must be satisfied)
    Or(Vec<RefinementPredicate>),
    /// Negation of a predicate
    Not(Box<RefinementPredicate>),
    /// Custom predicate with a name (for symbolic reasoning)
    Custom {
        /// Identifier used in string representations and comparisons
        name: String,
        /// Human-readable explanation of the constraint
        description: String,
        /// Function deciding whether a value satisfies the predicate
        checker: Arc<dyn Fn(f64) -> bool + Send + Sync>,
    },
    /// Dependent predicate referencing another variable
    Dependent {
        /// Name of the referenced variable, resolved through a `RefinementContext`
        variable: String,
        /// How the constrained value must relate to the referenced variable
        relation: DependentRelation,
    },
    /// String length constraint (not evaluated by the numeric `check`, which treats it as satisfied)
    StringLength {
        min: Option<usize>,
        max: Option<usize>,
    },
    /// Pattern match constraint for strings (not evaluated by the numeric `check`, which
    /// treats it as satisfied)
    Pattern(String),
}
82
/// Relation for dependent predicates.
///
/// Relates the constrained value `x` to a referenced variable `y` whose value comes
/// from a [`RefinementContext`]. Each variant below reads as "x <relation> y".
#[derive(Debug, Clone, PartialEq)]
pub enum DependentRelation {
    /// `x < y`
    LessThan,
    /// `x <= y`
    LessThanOrEqual,
    /// `x > y`
    GreaterThan,
    /// `x >= y`
    GreaterThanOrEqual,
    /// `x == y`
    Equal,
    /// `x != y`
    NotEqual,
    /// `x` is a divisor of `y`
    Divides,
    /// `y` is a divisor of `x` (i.e. `x` is divisible by `y`)
    DivisibleBy,
}
103
104impl std::fmt::Debug for RefinementPredicate {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        match self {
107            RefinementPredicate::Equal(v) => f.debug_tuple("Equal").field(v).finish(),
108            RefinementPredicate::NotEqual(v) => f.debug_tuple("NotEqual").field(v).finish(),
109            RefinementPredicate::GreaterThan(v) => f.debug_tuple("GreaterThan").field(v).finish(),
110            RefinementPredicate::GreaterThanOrEqual(v) => {
111                f.debug_tuple("GreaterThanOrEqual").field(v).finish()
112            }
113            RefinementPredicate::LessThan(v) => f.debug_tuple("LessThan").field(v).finish(),
114            RefinementPredicate::LessThanOrEqual(v) => {
115                f.debug_tuple("LessThanOrEqual").field(v).finish()
116            }
117            RefinementPredicate::Range { min, max } => f
118                .debug_struct("Range")
119                .field("min", min)
120                .field("max", max)
121                .finish(),
122            RefinementPredicate::RangeExclusive { min, max } => f
123                .debug_struct("RangeExclusive")
124                .field("min", min)
125                .field("max", max)
126                .finish(),
127            RefinementPredicate::Modulo { divisor, remainder } => f
128                .debug_struct("Modulo")
129                .field("divisor", divisor)
130                .field("remainder", remainder)
131                .finish(),
132            RefinementPredicate::InSet(set) => f.debug_tuple("InSet").field(set).finish(),
133            RefinementPredicate::NotInSet(set) => f.debug_tuple("NotInSet").field(set).finish(),
134            RefinementPredicate::And(preds) => f.debug_tuple("And").field(preds).finish(),
135            RefinementPredicate::Or(preds) => f.debug_tuple("Or").field(preds).finish(),
136            RefinementPredicate::Not(pred) => f.debug_tuple("Not").field(pred).finish(),
137            RefinementPredicate::Custom {
138                name, description, ..
139            } => f
140                .debug_struct("Custom")
141                .field("name", name)
142                .field("description", description)
143                .finish(),
144            RefinementPredicate::Dependent { variable, relation } => f
145                .debug_struct("Dependent")
146                .field("variable", variable)
147                .field("relation", relation)
148                .finish(),
149            RefinementPredicate::StringLength { min, max } => f
150                .debug_struct("StringLength")
151                .field("min", min)
152                .field("max", max)
153                .finish(),
154            RefinementPredicate::Pattern(pattern) => {
155                f.debug_tuple("Pattern").field(pattern).finish()
156            }
157        }
158    }
159}
160
161impl RefinementPredicate {
162    /// Create a "greater than" predicate.
163    pub fn greater_than(value: f64) -> Self {
164        RefinementPredicate::GreaterThan(value)
165    }
166
167    /// Create a "greater than or equal" predicate.
168    pub fn greater_than_or_equal(value: f64) -> Self {
169        RefinementPredicate::GreaterThanOrEqual(value)
170    }
171
172    /// Create a "less than" predicate.
173    pub fn less_than(value: f64) -> Self {
174        RefinementPredicate::LessThan(value)
175    }
176
177    /// Create a "less than or equal" predicate.
178    pub fn less_than_or_equal(value: f64) -> Self {
179        RefinementPredicate::LessThanOrEqual(value)
180    }
181
182    /// Create a range predicate [min, max].
183    pub fn range(min: f64, max: f64) -> Self {
184        RefinementPredicate::Range { min, max }
185    }
186
187    /// Create a modulo constraint predicate.
188    pub fn modulo(divisor: i64, remainder: i64) -> Self {
189        RefinementPredicate::Modulo { divisor, remainder }
190    }
191
192    /// Create a predicate requiring value to be in a set.
193    pub fn in_set(values: Vec<f64>) -> Self {
194        RefinementPredicate::InSet(values)
195    }
196
197    /// Create a conjunction of predicates.
198    pub fn and(predicates: Vec<RefinementPredicate>) -> Self {
199        RefinementPredicate::And(predicates)
200    }
201
202    /// Create a disjunction of predicates.
203    pub fn or(predicates: Vec<RefinementPredicate>) -> Self {
204        RefinementPredicate::Or(predicates)
205    }
206
207    /// Create a negation of a predicate.
208    #[allow(clippy::should_implement_trait)]
209    pub fn not(predicate: RefinementPredicate) -> Self {
210        RefinementPredicate::Not(Box::new(predicate))
211    }
212
213    /// Create a custom predicate with a checker function.
214    pub fn custom<F>(name: impl Into<String>, description: impl Into<String>, checker: F) -> Self
215    where
216        F: Fn(f64) -> bool + Send + Sync + 'static,
217    {
218        RefinementPredicate::Custom {
219            name: name.into(),
220            description: description.into(),
221            checker: Arc::new(checker),
222        }
223    }
224
225    /// Create a dependent predicate.
226    pub fn dependent(variable: impl Into<String>, relation: DependentRelation) -> Self {
227        RefinementPredicate::Dependent {
228            variable: variable.into(),
229            relation,
230        }
231    }
232
233    /// Check if a value satisfies this predicate.
234    ///
235    /// Note: For dependent predicates, this returns true (use `check_with_context` instead).
236    pub fn check(&self, value: f64) -> bool {
237        match self {
238            RefinementPredicate::Equal(v) => (value - v).abs() < f64::EPSILON,
239            RefinementPredicate::NotEqual(v) => (value - v).abs() >= f64::EPSILON,
240            RefinementPredicate::GreaterThan(v) => value > *v,
241            RefinementPredicate::GreaterThanOrEqual(v) => value >= *v,
242            RefinementPredicate::LessThan(v) => value < *v,
243            RefinementPredicate::LessThanOrEqual(v) => value <= *v,
244            RefinementPredicate::Range { min, max } => value >= *min && value <= *max,
245            RefinementPredicate::RangeExclusive { min, max } => value >= *min && value < *max,
246            RefinementPredicate::Modulo { divisor, remainder } => {
247                (value as i64) % divisor == *remainder
248            }
249            RefinementPredicate::InSet(set) => set.iter().any(|v| (value - v).abs() < f64::EPSILON),
250            RefinementPredicate::NotInSet(set) => {
251                !set.iter().any(|v| (value - v).abs() < f64::EPSILON)
252            }
253            RefinementPredicate::And(preds) => preds.iter().all(|p| p.check(value)),
254            RefinementPredicate::Or(preds) => preds.iter().any(|p| p.check(value)),
255            RefinementPredicate::Not(pred) => !pred.check(value),
256            RefinementPredicate::Custom { checker, .. } => checker(value),
257            RefinementPredicate::Dependent { .. } => true, // Needs context
258            RefinementPredicate::StringLength { .. } => true, // Not applicable to f64
259            RefinementPredicate::Pattern(_) => true,       // Not applicable to f64
260        }
261    }
262
263    /// Check if a value satisfies this predicate with a context for dependent predicates.
264    pub fn check_with_context(&self, value: f64, context: &RefinementContext) -> bool {
265        match self {
266            RefinementPredicate::Dependent { variable, relation } => {
267                if let Some(&other) = context.get_value(variable) {
268                    match relation {
269                        DependentRelation::LessThan => value < other,
270                        DependentRelation::LessThanOrEqual => value <= other,
271                        DependentRelation::GreaterThan => value > other,
272                        DependentRelation::GreaterThanOrEqual => value >= other,
273                        DependentRelation::Equal => (value - other).abs() < f64::EPSILON,
274                        DependentRelation::NotEqual => (value - other).abs() >= f64::EPSILON,
275                        DependentRelation::Divides => {
276                            other != 0.0 && (other as i64) % (value as i64) == 0
277                        }
278                        DependentRelation::DivisibleBy => {
279                            value != 0.0 && (value as i64) % (other as i64) == 0
280                        }
281                    }
282                } else {
283                    false // Unknown variable
284                }
285            }
286            RefinementPredicate::And(preds) => {
287                preds.iter().all(|p| p.check_with_context(value, context))
288            }
289            RefinementPredicate::Or(preds) => {
290                preds.iter().any(|p| p.check_with_context(value, context))
291            }
292            RefinementPredicate::Not(pred) => !pred.check_with_context(value, context),
293            _ => self.check(value),
294        }
295    }
296
297    /// Get the free variables referenced by this predicate.
298    pub fn free_variables(&self) -> Vec<String> {
299        match self {
300            RefinementPredicate::Dependent { variable, .. } => vec![variable.clone()],
301            RefinementPredicate::And(preds) | RefinementPredicate::Or(preds) => {
302                let mut vars = Vec::new();
303                for pred in preds {
304                    vars.extend(pred.free_variables());
305                }
306                vars.sort();
307                vars.dedup();
308                vars
309            }
310            RefinementPredicate::Not(pred) => pred.free_variables(),
311            _ => vec![],
312        }
313    }
314
315    /// Simplify the predicate by removing redundant constraints.
316    pub fn simplify(&self) -> RefinementPredicate {
317        match self {
318            RefinementPredicate::And(preds) => {
319                let simplified: Vec<_> = preds.iter().map(|p| p.simplify()).collect();
320                if simplified.len() == 1 {
321                    simplified
322                        .into_iter()
323                        .next()
324                        .expect("validated length == 1")
325                } else {
326                    // Merge range constraints
327                    let mut min_val = f64::NEG_INFINITY;
328                    let mut max_val = f64::INFINITY;
329                    let mut others = Vec::new();
330
331                    for pred in simplified {
332                        match pred {
333                            RefinementPredicate::GreaterThan(v) => {
334                                min_val = min_val.max(v);
335                            }
336                            RefinementPredicate::GreaterThanOrEqual(v) => {
337                                min_val = min_val.max(v);
338                            }
339                            RefinementPredicate::LessThan(v) => {
340                                max_val = max_val.min(v);
341                            }
342                            RefinementPredicate::LessThanOrEqual(v) => {
343                                max_val = max_val.min(v);
344                            }
345                            RefinementPredicate::Range { min, max } => {
346                                min_val = min_val.max(min);
347                                max_val = max_val.min(max);
348                            }
349                            other => others.push(other),
350                        }
351                    }
352
353                    // Create merged range if we have bounds
354                    if min_val > f64::NEG_INFINITY || max_val < f64::INFINITY {
355                        if min_val > f64::NEG_INFINITY && max_val < f64::INFINITY {
356                            others.insert(
357                                0,
358                                RefinementPredicate::Range {
359                                    min: min_val,
360                                    max: max_val,
361                                },
362                            );
363                        } else if min_val > f64::NEG_INFINITY {
364                            others.insert(0, RefinementPredicate::GreaterThanOrEqual(min_val));
365                        } else {
366                            others.insert(0, RefinementPredicate::LessThanOrEqual(max_val));
367                        }
368                    }
369
370                    if others.len() == 1 {
371                        others.into_iter().next().expect("validated length == 1")
372                    } else {
373                        RefinementPredicate::And(others)
374                    }
375                }
376            }
377            RefinementPredicate::Or(preds) => {
378                let simplified: Vec<_> = preds.iter().map(|p| p.simplify()).collect();
379                if simplified.len() == 1 {
380                    simplified
381                        .into_iter()
382                        .next()
383                        .expect("validated length == 1")
384                } else {
385                    RefinementPredicate::Or(simplified)
386                }
387            }
388            RefinementPredicate::Not(pred) => {
389                let inner = pred.simplify();
390                match inner {
391                    RefinementPredicate::Not(p) => *p, // Double negation
392                    other => RefinementPredicate::Not(Box::new(other)),
393                }
394            }
395            other => other.clone(),
396        }
397    }
398
399    /// Convert to a human-readable string.
400    pub fn to_string_repr(&self) -> String {
401        match self {
402            RefinementPredicate::Equal(v) => format!("x == {}", v),
403            RefinementPredicate::NotEqual(v) => format!("x != {}", v),
404            RefinementPredicate::GreaterThan(v) => format!("x > {}", v),
405            RefinementPredicate::GreaterThanOrEqual(v) => format!("x >= {}", v),
406            RefinementPredicate::LessThan(v) => format!("x < {}", v),
407            RefinementPredicate::LessThanOrEqual(v) => format!("x <= {}", v),
408            RefinementPredicate::Range { min, max } => format!("{} <= x <= {}", min, max),
409            RefinementPredicate::RangeExclusive { min, max } => format!("{} <= x < {}", min, max),
410            RefinementPredicate::Modulo { divisor, remainder } => {
411                format!("x % {} == {}", divisor, remainder)
412            }
413            RefinementPredicate::InSet(set) => format!("x in {:?}", set),
414            RefinementPredicate::NotInSet(set) => format!("x not in {:?}", set),
415            RefinementPredicate::And(preds) => {
416                let parts: Vec<_> = preds.iter().map(|p| p.to_string_repr()).collect();
417                format!("({})", parts.join(" && "))
418            }
419            RefinementPredicate::Or(preds) => {
420                let parts: Vec<_> = preds.iter().map(|p| p.to_string_repr()).collect();
421                format!("({})", parts.join(" || "))
422            }
423            RefinementPredicate::Not(pred) => format!("!({})", pred.to_string_repr()),
424            RefinementPredicate::Custom { name, .. } => format!("{}(x)", name),
425            RefinementPredicate::Dependent { variable, relation } => {
426                let rel_str = match relation {
427                    DependentRelation::LessThan => "<",
428                    DependentRelation::LessThanOrEqual => "<=",
429                    DependentRelation::GreaterThan => ">",
430                    DependentRelation::GreaterThanOrEqual => ">=",
431                    DependentRelation::Equal => "==",
432                    DependentRelation::NotEqual => "!=",
433                    DependentRelation::Divides => "divides",
434                    DependentRelation::DivisibleBy => "divisible_by",
435                };
436                format!("x {} {}", rel_str, variable)
437            }
438            RefinementPredicate::StringLength { min, max } => match (min, max) {
439                (Some(min), Some(max)) => format!("{} <= len(x) <= {}", min, max),
440                (Some(min), None) => format!("len(x) >= {}", min),
441                (None, Some(max)) => format!("len(x) <= {}", max),
442                (None, None) => "true".to_string(),
443            },
444            RefinementPredicate::Pattern(pattern) => format!("x matches \"{}\"", pattern),
445        }
446    }
447}
448
/// A refinement type combining a base type with predicates.
///
/// A value inhabits the type only if it satisfies every predicate. The textual form
/// produced by `to_string_repr` is `Base{p1 && p2 && ...}`, or just `Base` when there
/// are no predicates.
#[derive(Debug, Clone)]
pub struct RefinementType {
    /// Base type name (e.g. "Int", "Float"); subtyping requires equal base types
    pub base_type: String,
    /// Optional refined name (e.g., "PositiveInt" for Int{x > 0})
    pub name: Option<String>,
    /// Predicates that constrain values (all must hold)
    pub predicates: Vec<RefinementPredicate>,
    /// Description of the refinement
    pub description: Option<String>,
}
461
462impl RefinementType {
463    /// Create a new refinement type with a base type.
464    pub fn new(base_type: impl Into<String>) -> Self {
465        RefinementType {
466            base_type: base_type.into(),
467            name: None,
468            predicates: Vec::new(),
469            description: None,
470        }
471    }
472
473    /// Set the refined name.
474    pub fn with_name(mut self, name: impl Into<String>) -> Self {
475        self.name = Some(name.into());
476        self
477    }
478
479    /// Add a predicate to the refinement.
480    pub fn with_predicate(mut self, predicate: RefinementPredicate) -> Self {
481        self.predicates.push(predicate);
482        self
483    }
484
485    /// Set the description.
486    pub fn with_description(mut self, description: impl Into<String>) -> Self {
487        self.description = Some(description.into());
488        self
489    }
490
491    /// Check if a value satisfies this refinement type.
492    pub fn check(&self, value: f64) -> bool {
493        self.predicates.iter().all(|p| p.check(value))
494    }
495
496    /// Check if a value satisfies this refinement type with context.
497    pub fn check_with_context(&self, value: f64, context: &RefinementContext) -> bool {
498        self.predicates
499            .iter()
500            .all(|p| p.check_with_context(value, context))
501    }
502
503    /// Get the effective name of this type.
504    pub fn type_name(&self) -> &str {
505        self.name.as_deref().unwrap_or(&self.base_type)
506    }
507
508    /// Check if this is a subtype of another refinement type.
509    ///
510    /// A refinement type A is a subtype of B if:
511    /// 1. They have the same base type
512    /// 2. A's predicates imply B's predicates
513    ///
514    /// This implementation uses semantic implication checking for common predicate patterns,
515    /// providing a practical alternative to full SMT solving while handling most real-world cases.
516    pub fn is_subtype_of(&self, other: &RefinementType) -> bool {
517        if self.base_type != other.base_type {
518            return false;
519        }
520
521        // Conservative check: if other has no predicates, we're a subtype
522        if other.predicates.is_empty() {
523            return true;
524        }
525
526        // If we have no predicates but other does, we're not a subtype
527        if self.predicates.is_empty() && !other.predicates.is_empty() {
528            return false;
529        }
530
531        // Check if all of other's predicates are implied by our predicates
532        for other_pred in &other.predicates {
533            if !self.implies_predicate(other_pred) {
534                return false;
535            }
536        }
537
538        true
539    }
540
541    /// Check if this refinement type's predicates imply the given predicate.
542    ///
543    /// This uses semantic implication checking for common patterns:
544    /// - Syntactic equality (via Debug representation)
545    /// - Range implication (x > 10 implies x > 5)
546    /// - Modulo implication (x % 4 == 0 implies x % 2 == 0)
547    fn implies_predicate(&self, target: &RefinementPredicate) -> bool {
548        // Check for syntactic equality using Debug representation
549        // (RefinementPredicate doesn't implement PartialEq due to function pointers)
550        let target_repr = format!("{:?}", target);
551        if self
552            .predicates
553            .iter()
554            .any(|p| format!("{:?}", p) == target_repr)
555        {
556            return true;
557        }
558
559        // Check for semantic implication based on predicate types
560        for pred in &self.predicates {
561            if Self::semantic_implies(pred, target) {
562                return true;
563            }
564        }
565
566        // Check for conjunction of predicates implying the target
567        Self::conjunction_implies(&self.predicates, target)
568    }
569
570    /// Check if one predicate semantically implies another.
571    fn semantic_implies(source: &RefinementPredicate, target: &RefinementPredicate) -> bool {
572        use RefinementPredicate::*;
573
574        match (source, target) {
575            // Range implications: stricter range implies looser range
576            (
577                Range {
578                    min: min1,
579                    max: max1,
580                },
581                Range {
582                    min: min2,
583                    max: max2,
584                },
585            ) => {
586                // [5, 10] implies [0, 15]
587                min1 >= min2 && max1 <= max2
588            }
589            (
590                RangeExclusive {
591                    min: min1,
592                    max: max1,
593                },
594                RangeExclusive {
595                    min: min2,
596                    max: max2,
597                },
598            ) => min1 >= min2 && max1 <= max2,
599            // Greater-than implications: x > 10 implies x > 5
600            (GreaterThan(v1), GreaterThan(v2)) => v1 >= v2,
601            (GreaterThanOrEqual(v1), GreaterThanOrEqual(v2)) => v1 >= v2,
602            (GreaterThan(v1), GreaterThanOrEqual(v2)) => v1 >= v2, // x > 10 implies x >= 10
603            // Less-than implications: x < 5 implies x < 10
604            (LessThan(v1), LessThan(v2)) => v1 <= v2,
605            (LessThanOrEqual(v1), LessThanOrEqual(v2)) => v1 <= v2,
606            (LessThan(v1), LessThanOrEqual(v2)) => v1 <= v2, // x < 5 implies x <= 5
607            // Equality implies bounds
608            (Equal(v1), GreaterThan(v2)) => v1 > v2,
609            (Equal(v1), GreaterThanOrEqual(v2)) => v1 >= v2,
610            (Equal(v1), LessThan(v2)) => v1 < v2,
611            (Equal(v1), LessThanOrEqual(v2)) => v1 <= v2,
612            (Equal(v1), Range { min, max }) => v1 >= min && v1 <= max,
613            // Modulo implications: x % 4 == 0 implies x % 2 == 0
614            (
615                Modulo {
616                    divisor: d1,
617                    remainder: r1,
618                },
619                Modulo {
620                    divisor: d2,
621                    remainder: r2,
622                },
623            ) => r1 == r2 && d1 % d2 == 0,
624            // Dependent predicates with same variable
625            (
626                Dependent {
627                    variable: v1,
628                    relation: rel1,
629                },
630                Dependent {
631                    variable: v2,
632                    relation: rel2,
633                },
634            ) => {
635                if v1 != v2 {
636                    return false;
637                }
638                // Same variable, check if rel1 implies rel2
639                use DependentRelation::*;
640                matches!(
641                    (rel1, rel2),
642                    (Equal, Equal)
643                        | (GreaterThan, GreaterThan)
644                        | (GreaterThan, GreaterThanOrEqual)
645                        | (LessThan, LessThan)
646                        | (LessThan, LessThanOrEqual)
647                        | (GreaterThanOrEqual, GreaterThanOrEqual)
648                        | (LessThanOrEqual, LessThanOrEqual)
649                )
650            }
651            _ => false,
652        }
653    }
654
655    /// Check if a conjunction of predicates implies a target predicate.
656    ///
657    /// This handles cases like: (x > 5 && x < 10) implies x > 0
658    fn conjunction_implies(
659        predicates: &[RefinementPredicate],
660        target: &RefinementPredicate,
661    ) -> bool {
662        use RefinementPredicate::*;
663
664        // Extract range bounds from multiple predicates
665        let mut lower_bounds = Vec::new();
666        let mut upper_bounds = Vec::new();
667
668        for pred in predicates {
669            match pred {
670                GreaterThan(v) | GreaterThanOrEqual(v) => {
671                    lower_bounds.push(*v);
672                }
673                LessThan(v) | LessThanOrEqual(v) => {
674                    upper_bounds.push(*v);
675                }
676                Range { min, max } => {
677                    lower_bounds.push(*min);
678                    upper_bounds.push(*max);
679                }
680                Equal(v) => {
681                    lower_bounds.push(*v);
682                    upper_bounds.push(*v);
683                }
684                _ => {}
685            }
686        }
687
688        // Check if combined bounds imply the target
689        match target {
690            GreaterThan(v) | GreaterThanOrEqual(v) => lower_bounds.iter().any(|lb| lb >= v),
691            LessThan(v) | LessThanOrEqual(v) => upper_bounds.iter().any(|ub| ub <= v),
692            Range { min, max } => {
693                lower_bounds.iter().any(|lb| lb >= min) && upper_bounds.iter().any(|ub| ub <= max)
694            }
695            _ => false,
696        }
697    }
698
699    /// Get all free variables referenced in predicates.
700    pub fn free_variables(&self) -> Vec<String> {
701        let mut vars = Vec::new();
702        for pred in &self.predicates {
703            vars.extend(pred.free_variables());
704        }
705        vars.sort();
706        vars.dedup();
707        vars
708    }
709
710    /// Convert to human-readable representation.
711    pub fn to_string_repr(&self) -> String {
712        if self.predicates.is_empty() {
713            return self.base_type.clone();
714        }
715
716        let pred_strs: Vec<_> = self.predicates.iter().map(|p| p.to_string_repr()).collect();
717        format!("{}{{{}}}", self.base_type, pred_strs.join(" && "))
718    }
719}
720
721/// Context for evaluating dependent refinement predicates.
722#[derive(Debug, Clone, Default)]
723pub struct RefinementContext {
724    /// Variable values in the current context
725    values: HashMap<String, f64>,
726    /// Type assignments for variables
727    types: HashMap<String, RefinementType>,
728}
729
730impl RefinementContext {
731    /// Create a new empty context.
732    pub fn new() -> Self {
733        RefinementContext {
734            values: HashMap::new(),
735            types: HashMap::new(),
736        }
737    }
738
739    /// Set a variable's value.
740    pub fn set_value(&mut self, var: impl Into<String>, value: f64) {
741        self.values.insert(var.into(), value);
742    }
743
744    /// Get a variable's value.
745    pub fn get_value(&self, var: &str) -> Option<&f64> {
746        self.values.get(var)
747    }
748
749    /// Set a variable's type.
750    pub fn set_type(&mut self, var: impl Into<String>, ty: RefinementType) {
751        self.types.insert(var.into(), ty);
752    }
753
754    /// Get a variable's type.
755    pub fn get_type(&self, var: &str) -> Option<&RefinementType> {
756        self.types.get(var)
757    }
758
759    /// Check if a variable exists in the context.
760    pub fn has_variable(&self, var: &str) -> bool {
761        self.values.contains_key(var) || self.types.contains_key(var)
762    }
763
764    /// Get all variable names.
765    pub fn variables(&self) -> Vec<&str> {
766        let mut vars: Vec<_> = self.values.keys().map(|s| s.as_str()).collect();
767        for key in self.types.keys() {
768            if !self.values.contains_key(key) {
769                vars.push(key.as_str());
770            }
771        }
772        vars
773    }
774}
775
776/// Registry for managing refinement types.
777#[derive(Debug, Clone, Default)]
778pub struct RefinementRegistry {
779    /// Named refinement types
780    types: HashMap<String, RefinementType>,
781}
782
783impl RefinementRegistry {
784    /// Create a new empty registry.
785    pub fn new() -> Self {
786        RefinementRegistry {
787            types: HashMap::new(),
788        }
789    }
790
791    /// Create a registry with common built-in refinement types.
792    pub fn with_builtins() -> Self {
793        let mut registry = RefinementRegistry::new();
794
795        // Positive integer
796        registry.register(
797            RefinementType::new("Int")
798                .with_name("PositiveInt")
799                .with_predicate(RefinementPredicate::GreaterThan(0.0))
800                .with_description("Strictly positive integer"),
801        );
802
803        // Non-negative integer
804        registry.register(
805            RefinementType::new("Int")
806                .with_name("NonNegativeInt")
807                .with_predicate(RefinementPredicate::GreaterThanOrEqual(0.0))
808                .with_description("Non-negative integer (zero or positive)"),
809        );
810
811        // Probability (0 to 1)
812        registry.register(
813            RefinementType::new("Float")
814                .with_name("Probability")
815                .with_predicate(RefinementPredicate::Range { min: 0.0, max: 1.0 })
816                .with_description("Probability value between 0 and 1"),
817        );
818
819        // Percentage (0 to 100)
820        registry.register(
821            RefinementType::new("Float")
822                .with_name("Percentage")
823                .with_predicate(RefinementPredicate::Range {
824                    min: 0.0,
825                    max: 100.0,
826                })
827                .with_description("Percentage value between 0 and 100"),
828        );
829
830        // Normalized (-1 to 1)
831        registry.register(
832            RefinementType::new("Float")
833                .with_name("Normalized")
834                .with_predicate(RefinementPredicate::Range {
835                    min: -1.0,
836                    max: 1.0,
837                })
838                .with_description("Normalized value between -1 and 1"),
839        );
840
841        // Natural number (0, 1, 2, ...)
842        registry.register(
843            RefinementType::new("Int")
844                .with_name("Natural")
845                .with_predicate(RefinementPredicate::And(vec![
846                    RefinementPredicate::GreaterThanOrEqual(0.0),
847                    RefinementPredicate::Modulo {
848                        divisor: 1,
849                        remainder: 0,
850                    },
851                ]))
852                .with_description("Natural number (non-negative integer)"),
853        );
854
855        // Even number
856        registry.register(
857            RefinementType::new("Int")
858                .with_name("Even")
859                .with_predicate(RefinementPredicate::Modulo {
860                    divisor: 2,
861                    remainder: 0,
862                })
863                .with_description("Even integer"),
864        );
865
866        // Odd number
867        registry.register(
868            RefinementType::new("Int")
869                .with_name("Odd")
870                .with_predicate(RefinementPredicate::Modulo {
871                    divisor: 2,
872                    remainder: 1,
873                })
874                .with_description("Odd integer"),
875        );
876
877        registry
878    }
879
880    /// Register a refinement type.
881    pub fn register(&mut self, refinement: RefinementType) {
882        let name = refinement.type_name().to_string();
883        self.types.insert(name, refinement);
884    }
885
886    /// Get a refinement type by name.
887    pub fn get(&self, name: &str) -> Option<&RefinementType> {
888        self.types.get(name)
889    }
890
891    /// Check if a type is registered.
892    pub fn contains(&self, name: &str) -> bool {
893        self.types.contains_key(name)
894    }
895
896    /// Get all registered type names.
897    pub fn type_names(&self) -> Vec<&str> {
898        self.types.keys().map(|s| s.as_str()).collect()
899    }
900
901    /// Get the number of registered types.
902    pub fn len(&self) -> usize {
903        self.types.len()
904    }
905
906    /// Check if the registry is empty.
907    pub fn is_empty(&self) -> bool {
908        self.types.is_empty()
909    }
910
911    /// Check if a value satisfies a refinement type by name.
912    pub fn check(&self, type_name: &str, value: f64) -> Option<bool> {
913        self.types.get(type_name).map(|t| t.check(value))
914    }
915
916    /// Iterate over all refinement types.
917    pub fn iter(&self) -> impl Iterator<Item = (&str, &RefinementType)> {
918        self.types.iter().map(|(k, v)| (k.as_str(), v))
919    }
920}
921
// Unit tests for refinement predicates, refinement types, the evaluation context,
// the registry, and semantic subtyping.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_predicates() {
        let pred = RefinementPredicate::GreaterThan(0.0);
        assert!(pred.check(5.0));
        assert!(!pred.check(-1.0));
        // Strict inequality: the bound itself is rejected.
        assert!(!pred.check(0.0));
    }

    #[test]
    fn test_range_predicate() {
        // `Range` is inclusive on both ends.
        let pred = RefinementPredicate::Range { min: 0.0, max: 1.0 };
        assert!(pred.check(0.5));
        assert!(pred.check(0.0));
        assert!(pred.check(1.0));
        assert!(!pred.check(-0.1));
        assert!(!pred.check(1.1));
    }

    #[test]
    fn test_modulo_predicate() {
        let even = RefinementPredicate::Modulo {
            divisor: 2,
            remainder: 0,
        };
        assert!(even.check(4.0));
        assert!(even.check(0.0));
        assert!(!even.check(3.0));
    }

    #[test]
    fn test_compound_predicates() {
        // Positive and even
        let pred = RefinementPredicate::And(vec![
            RefinementPredicate::GreaterThan(0.0),
            RefinementPredicate::Modulo {
                divisor: 2,
                remainder: 0,
            },
        ]);

        assert!(pred.check(4.0));
        assert!(!pred.check(-2.0)); // Not positive
        assert!(!pred.check(3.0)); // Not even
    }

    #[test]
    fn test_in_set_predicate() {
        let pred = RefinementPredicate::InSet(vec![1.0, 2.0, 3.0]);
        assert!(pred.check(1.0));
        assert!(pred.check(2.0));
        assert!(!pred.check(4.0));
    }

    #[test]
    fn test_custom_predicate() {
        // Primality by trial division up to floor(sqrt(n)).
        let pred = RefinementPredicate::custom("is_prime", "Checks if number is prime", |n| {
            if n < 2.0 {
                return false;
            }
            let n = n as i64;
            for i in 2..=((n as f64).sqrt() as i64) {
                if n % i == 0 {
                    return false;
                }
            }
            true
        });

        assert!(pred.check(2.0));
        assert!(pred.check(7.0));
        assert!(!pred.check(4.0));
        assert!(!pred.check(1.0));
    }

    #[test]
    fn test_refinement_type() {
        let pos_int = RefinementType::new("Int")
            .with_name("PositiveInt")
            .with_predicate(RefinementPredicate::GreaterThan(0.0));

        // The explicit name takes the place of the base type name.
        assert_eq!(pos_int.type_name(), "PositiveInt");
        assert!(pos_int.check(5.0));
        assert!(!pos_int.check(-1.0));
    }

    #[test]
    fn test_dependent_predicate() {
        // Value must be less than the context variable `n`.
        let pred = RefinementPredicate::Dependent {
            variable: "n".to_string(),
            relation: DependentRelation::LessThan,
        };

        let mut context = RefinementContext::new();
        context.set_value("n", 10.0);

        assert!(pred.check_with_context(5.0, &context));
        assert!(!pred.check_with_context(15.0, &context));
    }

    #[test]
    fn test_registry_builtins() {
        let registry = RefinementRegistry::with_builtins();

        // Test PositiveInt
        assert!(registry.check("PositiveInt", 5.0).expect("unwrap"));
        assert!(!registry.check("PositiveInt", -1.0).expect("unwrap"));

        // Test Probability
        assert!(registry.check("Probability", 0.5).expect("unwrap"));
        assert!(!registry.check("Probability", 1.5).expect("unwrap"));

        // Test Even
        assert!(registry.check("Even", 4.0).expect("unwrap"));
        assert!(!registry.check("Even", 3.0).expect("unwrap"));
    }

    #[test]
    fn test_predicate_simplification() {
        // x > 0 && x < 10 && x >= 1, which is semantically x in [1, 10).
        let pred = RefinementPredicate::And(vec![
            RefinementPredicate::GreaterThan(0.0),
            RefinementPredicate::LessThan(10.0),
            RefinementPredicate::GreaterThanOrEqual(1.0),
        ]);

        let simplified = pred.simplify();

        // Only interior and lower-bound behaviour is pinned here. The upper bound
        // (10.0) is deliberately not asserted.
        // NOTE(review): an earlier comment claimed simplify() yields the inclusive range
        // [1, 10]. That would accept 10.0, which the original predicate rejects. This is a
        // weakening, not a "conservative" simplification. Confirm against simplify().
        assert!(simplified.check(5.0));
        assert!(!simplified.check(0.0));
        assert!(simplified.check(1.0)); // min bound included
    }

    #[test]
    fn test_predicate_string_repr() {
        let pred = RefinementPredicate::Range { min: 0.0, max: 1.0 };
        assert_eq!(pred.to_string_repr(), "0 <= x <= 1");

        let pred = RefinementPredicate::And(vec![
            RefinementPredicate::GreaterThan(0.0),
            RefinementPredicate::LessThan(10.0),
        ]);
        assert_eq!(pred.to_string_repr(), "(x > 0 && x < 10)");
    }

    #[test]
    fn test_free_variables() {
        // Only `Dependent` predicates contribute free variables.
        let pred = RefinementPredicate::And(vec![
            RefinementPredicate::GreaterThan(0.0),
            RefinementPredicate::Dependent {
                variable: "n".to_string(),
                relation: DependentRelation::LessThan,
            },
            RefinementPredicate::Dependent {
                variable: "m".to_string(),
                relation: DependentRelation::GreaterThan,
            },
        ]);

        let vars = pred.free_variables();
        assert_eq!(vars.len(), 2);
        assert!(vars.contains(&"m".to_string()));
        assert!(vars.contains(&"n".to_string()));
    }

    #[test]
    fn test_refinement_type_repr() {
        // The repr uses the base type ("Int"), not the custom name.
        let ty = RefinementType::new("Int")
            .with_name("BoundedInt")
            .with_predicate(RefinementPredicate::Range {
                min: 0.0,
                max: 100.0,
            });

        assert_eq!(ty.to_string_repr(), "Int{0 <= x <= 100}");
    }

    #[test]
    fn test_context_operations() {
        let mut ctx = RefinementContext::new();

        ctx.set_value("x", 5.0);
        ctx.set_value("y", 10.0);

        assert_eq!(ctx.get_value("x"), Some(&5.0));
        assert!(ctx.has_variable("x"));
        assert!(!ctx.has_variable("z"));

        let vars = ctx.variables();
        assert_eq!(vars.len(), 2);
    }

    #[test]
    fn test_negation_predicate() {
        let pred = RefinementPredicate::Not(Box::new(RefinementPredicate::Equal(0.0)));

        assert!(pred.check(5.0));
        assert!(!pred.check(0.0));
    }

    #[test]
    fn test_or_predicate() {
        // x < 0 || x > 10
        let pred = RefinementPredicate::Or(vec![
            RefinementPredicate::LessThan(0.0),
            RefinementPredicate::GreaterThan(10.0),
        ]);

        assert!(pred.check(-5.0));
        assert!(pred.check(15.0));
        assert!(!pred.check(5.0));
    }

    #[test]
    fn test_double_negation_simplification() {
        // !!(x > 0) must keep the semantics of x > 0 after simplification.
        let pred = RefinementPredicate::Not(Box::new(RefinementPredicate::Not(Box::new(
            RefinementPredicate::GreaterThan(0.0),
        ))));

        let simplified = pred.simplify();
        assert!(simplified.check(5.0));
        assert!(!simplified.check(-1.0));
    }

    #[test]
    fn test_registry_custom_type() {
        let mut registry = RefinementRegistry::new();

        registry.register(
            RefinementType::new("Float")
                .with_name("SmallPositive")
                .with_predicate(RefinementPredicate::Range {
                    min: 0.0,
                    max: 1e-6,
                }),
        );

        // Registered under its custom name, not the base type "Float".
        assert!(registry.contains("SmallPositive"));
        assert!(registry.check("SmallPositive", 1e-7).expect("unwrap"));
        assert!(!registry.check("SmallPositive", 1.0).expect("unwrap"));
    }

    // Tests for semantic subtyping implementation

    #[test]
    fn test_subtyping_basic() {
        // Test basic base type matching
        let int_type = RefinementType::new("Int");
        let float_type = RefinementType::new("Float");

        assert!(!int_type.is_subtype_of(&float_type)); // Different base types
        assert!(int_type.is_subtype_of(&int_type)); // Same type
    }

    #[test]
    fn test_subtyping_range_implication() {
        // x ∈ [5, 10] is a subtype of x ∈ [0, 15]
        let stricter = RefinementType::new("Int").with_predicate(RefinementPredicate::Range {
            min: 5.0,
            max: 10.0,
        });

        let looser = RefinementType::new("Int").with_predicate(RefinementPredicate::Range {
            min: 0.0,
            max: 15.0,
        });

        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter)); // Not the other way around
    }

    #[test]
    fn test_subtyping_greater_than_implication() {
        // x > 10 is a subtype of x > 5
        let stricter =
            RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(10.0));

        let looser =
            RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(5.0));

        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter));
    }

    #[test]
    fn test_subtyping_less_than_implication() {
        // x < 5 is a subtype of x < 10
        let stricter =
            RefinementType::new("Int").with_predicate(RefinementPredicate::LessThan(5.0));

        let looser = RefinementType::new("Int").with_predicate(RefinementPredicate::LessThan(10.0));

        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter));
    }

    #[test]
    fn test_subtyping_modulo_implication() {
        // x % 4 == 0 is a subtype of x % 2 == 0
        let divisible_by_4 =
            RefinementType::new("Int").with_predicate(RefinementPredicate::Modulo {
                divisor: 4,
                remainder: 0,
            });

        let divisible_by_2 =
            RefinementType::new("Int").with_predicate(RefinementPredicate::Modulo {
                divisor: 2,
                remainder: 0,
            });

        assert!(divisible_by_4.is_subtype_of(&divisible_by_2));
        assert!(!divisible_by_2.is_subtype_of(&divisible_by_4));
    }

    #[test]
    fn test_subtyping_conjunction() {
        // (x > 5 && x < 10) implies x > 0
        let bounded = RefinementType::new("Int")
            .with_predicate(RefinementPredicate::GreaterThan(5.0))
            .with_predicate(RefinementPredicate::LessThan(10.0));

        let positive =
            RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(0.0));

        assert!(bounded.is_subtype_of(&positive));
    }

    #[test]
    fn test_subtyping_equality_implies_bounds() {
        // x == 7 implies x > 5 and x < 10
        let exact = RefinementType::new("Int").with_predicate(RefinementPredicate::Equal(7.0));

        let gt_5 = RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(5.0));

        let lt_10 = RefinementType::new("Int").with_predicate(RefinementPredicate::LessThan(10.0));

        assert!(exact.is_subtype_of(&gt_5));
        assert!(exact.is_subtype_of(&lt_10));
    }

    #[test]
    fn test_subtyping_no_implication() {
        // x % 2 == 0 does NOT imply x > 5
        let even = RefinementType::new("Int").with_predicate(RefinementPredicate::Modulo {
            divisor: 2,
            remainder: 0,
        });

        let gt_5 = RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(5.0));

        assert!(!even.is_subtype_of(&gt_5));
        assert!(!gt_5.is_subtype_of(&even));
    }
}