Skip to main content

tensorlogic_adapters/
refinement.rs

1//! Refinement types for expressing value constraints beyond simple types.
2//!
3//! Refinement types extend base types with predicates that constrain valid values.
4//! This enables static verification of properties like positivity, bounds, and custom invariants.
5//!
6//! # Examples
7//!
8//! ```rust
9//! use tensorlogic_adapters::{RefinementType, RefinementPredicate, RefinementContext};
10//!
11//! // Create a positive integer refinement
12//! let pos_int = RefinementType::new("Int")
13//!     .with_predicate(RefinementPredicate::greater_than(0.0))
14//!     .with_name("PositiveInt");
15//!
16//! // Check if a value satisfies the refinement
17//! assert!(pos_int.check(5.0));
18//! assert!(!pos_int.check(-1.0));
19//!
20//! // Create bounded range refinement
21//! let probability = RefinementType::new("Float")
22//!     .with_predicate(RefinementPredicate::range(0.0, 1.0))
23//!     .with_name("Probability");
24//!
25//! assert!(probability.check(0.5));
26//! assert!(!probability.check(1.5));
27//! ```
28
29use std::collections::HashMap;
30use std::sync::Arc;
31
32/// A refinement predicate that constrains values.
33#[derive(Clone)]
34pub enum RefinementPredicate {
35    /// Value must equal a constant
36    Equal(f64),
37    /// Value must not equal a constant
38    NotEqual(f64),
39    /// Value must be greater than a constant
40    GreaterThan(f64),
41    /// Value must be greater than or equal to a constant
42    GreaterThanOrEqual(f64),
43    /// Value must be less than a constant
44    LessThan(f64),
45    /// Value must be less than or equal to a constant
46    LessThanOrEqual(f64),
47    /// Value must be in a range [min, max]
48    Range { min: f64, max: f64 },
49    /// Value must be in a half-open range [min, max)
50    RangeExclusive { min: f64, max: f64 },
51    /// Value must satisfy a modulo constraint (value % divisor == remainder)
52    Modulo { divisor: i64, remainder: i64 },
53    /// Value must be in a set of allowed values
54    InSet(Vec<f64>),
55    /// Value must not be in a set of disallowed values
56    NotInSet(Vec<f64>),
57    /// Conjunction of predicates (all must be satisfied)
58    And(Vec<RefinementPredicate>),
59    /// Disjunction of predicates (at least one must be satisfied)
60    Or(Vec<RefinementPredicate>),
61    /// Negation of a predicate
62    Not(Box<RefinementPredicate>),
63    /// Custom predicate with a name (for symbolic reasoning)
64    Custom {
65        name: String,
66        description: String,
67        checker: Arc<dyn Fn(f64) -> bool + Send + Sync>,
68    },
69    /// Dependent predicate referencing another variable
70    Dependent {
71        variable: String,
72        relation: DependentRelation,
73    },
74    /// String length constraint
75    StringLength {
76        min: Option<usize>,
77        max: Option<usize>,
78    },
79    /// Pattern match constraint (for strings)
80    Pattern(String),
81}
82
/// Relation between a refined value `x` and a named variable, used by
/// `RefinementPredicate::Dependent`.
///
/// Each variant reads as "`x` <relation> the referenced variable". The enum is fieldless,
/// so it also derives `Eq` and `Hash` and can be used in sets and as a map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DependentRelation {
    /// `x < variable`
    LessThan,
    /// `x <= variable`
    LessThanOrEqual,
    /// `x > variable`
    GreaterThan,
    /// `x >= variable`
    GreaterThanOrEqual,
    /// `x == variable`
    Equal,
    /// `x != variable`
    NotEqual,
    /// `x` is a divisor of the referenced variable (`variable % x == 0`)
    Divides,
    /// The referenced variable is a divisor of `x` (`x % variable == 0`)
    DivisibleBy,
}
103
104impl std::fmt::Debug for RefinementPredicate {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        match self {
107            RefinementPredicate::Equal(v) => f.debug_tuple("Equal").field(v).finish(),
108            RefinementPredicate::NotEqual(v) => f.debug_tuple("NotEqual").field(v).finish(),
109            RefinementPredicate::GreaterThan(v) => f.debug_tuple("GreaterThan").field(v).finish(),
110            RefinementPredicate::GreaterThanOrEqual(v) => {
111                f.debug_tuple("GreaterThanOrEqual").field(v).finish()
112            }
113            RefinementPredicate::LessThan(v) => f.debug_tuple("LessThan").field(v).finish(),
114            RefinementPredicate::LessThanOrEqual(v) => {
115                f.debug_tuple("LessThanOrEqual").field(v).finish()
116            }
117            RefinementPredicate::Range { min, max } => f
118                .debug_struct("Range")
119                .field("min", min)
120                .field("max", max)
121                .finish(),
122            RefinementPredicate::RangeExclusive { min, max } => f
123                .debug_struct("RangeExclusive")
124                .field("min", min)
125                .field("max", max)
126                .finish(),
127            RefinementPredicate::Modulo { divisor, remainder } => f
128                .debug_struct("Modulo")
129                .field("divisor", divisor)
130                .field("remainder", remainder)
131                .finish(),
132            RefinementPredicate::InSet(set) => f.debug_tuple("InSet").field(set).finish(),
133            RefinementPredicate::NotInSet(set) => f.debug_tuple("NotInSet").field(set).finish(),
134            RefinementPredicate::And(preds) => f.debug_tuple("And").field(preds).finish(),
135            RefinementPredicate::Or(preds) => f.debug_tuple("Or").field(preds).finish(),
136            RefinementPredicate::Not(pred) => f.debug_tuple("Not").field(pred).finish(),
137            RefinementPredicate::Custom {
138                name, description, ..
139            } => f
140                .debug_struct("Custom")
141                .field("name", name)
142                .field("description", description)
143                .finish(),
144            RefinementPredicate::Dependent { variable, relation } => f
145                .debug_struct("Dependent")
146                .field("variable", variable)
147                .field("relation", relation)
148                .finish(),
149            RefinementPredicate::StringLength { min, max } => f
150                .debug_struct("StringLength")
151                .field("min", min)
152                .field("max", max)
153                .finish(),
154            RefinementPredicate::Pattern(pattern) => {
155                f.debug_tuple("Pattern").field(pattern).finish()
156            }
157        }
158    }
159}
160
161impl RefinementPredicate {
162    /// Create a "greater than" predicate.
163    pub fn greater_than(value: f64) -> Self {
164        RefinementPredicate::GreaterThan(value)
165    }
166
167    /// Create a "greater than or equal" predicate.
168    pub fn greater_than_or_equal(value: f64) -> Self {
169        RefinementPredicate::GreaterThanOrEqual(value)
170    }
171
172    /// Create a "less than" predicate.
173    pub fn less_than(value: f64) -> Self {
174        RefinementPredicate::LessThan(value)
175    }
176
177    /// Create a "less than or equal" predicate.
178    pub fn less_than_or_equal(value: f64) -> Self {
179        RefinementPredicate::LessThanOrEqual(value)
180    }
181
182    /// Create a range predicate [min, max].
183    pub fn range(min: f64, max: f64) -> Self {
184        RefinementPredicate::Range { min, max }
185    }
186
187    /// Create a modulo constraint predicate.
188    pub fn modulo(divisor: i64, remainder: i64) -> Self {
189        RefinementPredicate::Modulo { divisor, remainder }
190    }
191
192    /// Create a predicate requiring value to be in a set.
193    pub fn in_set(values: Vec<f64>) -> Self {
194        RefinementPredicate::InSet(values)
195    }
196
197    /// Create a conjunction of predicates.
198    pub fn and(predicates: Vec<RefinementPredicate>) -> Self {
199        RefinementPredicate::And(predicates)
200    }
201
202    /// Create a disjunction of predicates.
203    pub fn or(predicates: Vec<RefinementPredicate>) -> Self {
204        RefinementPredicate::Or(predicates)
205    }
206
207    /// Create a negation of a predicate.
208    #[allow(clippy::should_implement_trait)]
209    pub fn not(predicate: RefinementPredicate) -> Self {
210        RefinementPredicate::Not(Box::new(predicate))
211    }
212
213    /// Create a custom predicate with a checker function.
214    pub fn custom<F>(name: impl Into<String>, description: impl Into<String>, checker: F) -> Self
215    where
216        F: Fn(f64) -> bool + Send + Sync + 'static,
217    {
218        RefinementPredicate::Custom {
219            name: name.into(),
220            description: description.into(),
221            checker: Arc::new(checker),
222        }
223    }
224
225    /// Create a dependent predicate.
226    pub fn dependent(variable: impl Into<String>, relation: DependentRelation) -> Self {
227        RefinementPredicate::Dependent {
228            variable: variable.into(),
229            relation,
230        }
231    }
232
233    /// Check if a value satisfies this predicate.
234    ///
235    /// Note: For dependent predicates, this returns true (use `check_with_context` instead).
236    pub fn check(&self, value: f64) -> bool {
237        match self {
238            RefinementPredicate::Equal(v) => (value - v).abs() < f64::EPSILON,
239            RefinementPredicate::NotEqual(v) => (value - v).abs() >= f64::EPSILON,
240            RefinementPredicate::GreaterThan(v) => value > *v,
241            RefinementPredicate::GreaterThanOrEqual(v) => value >= *v,
242            RefinementPredicate::LessThan(v) => value < *v,
243            RefinementPredicate::LessThanOrEqual(v) => value <= *v,
244            RefinementPredicate::Range { min, max } => value >= *min && value <= *max,
245            RefinementPredicate::RangeExclusive { min, max } => value >= *min && value < *max,
246            RefinementPredicate::Modulo { divisor, remainder } => {
247                (value as i64) % divisor == *remainder
248            }
249            RefinementPredicate::InSet(set) => set.iter().any(|v| (value - v).abs() < f64::EPSILON),
250            RefinementPredicate::NotInSet(set) => {
251                !set.iter().any(|v| (value - v).abs() < f64::EPSILON)
252            }
253            RefinementPredicate::And(preds) => preds.iter().all(|p| p.check(value)),
254            RefinementPredicate::Or(preds) => preds.iter().any(|p| p.check(value)),
255            RefinementPredicate::Not(pred) => !pred.check(value),
256            RefinementPredicate::Custom { checker, .. } => checker(value),
257            RefinementPredicate::Dependent { .. } => true, // Needs context
258            RefinementPredicate::StringLength { .. } => true, // Not applicable to f64
259            RefinementPredicate::Pattern(_) => true,       // Not applicable to f64
260        }
261    }
262
263    /// Check if a value satisfies this predicate with a context for dependent predicates.
264    pub fn check_with_context(&self, value: f64, context: &RefinementContext) -> bool {
265        match self {
266            RefinementPredicate::Dependent { variable, relation } => {
267                if let Some(&other) = context.get_value(variable) {
268                    match relation {
269                        DependentRelation::LessThan => value < other,
270                        DependentRelation::LessThanOrEqual => value <= other,
271                        DependentRelation::GreaterThan => value > other,
272                        DependentRelation::GreaterThanOrEqual => value >= other,
273                        DependentRelation::Equal => (value - other).abs() < f64::EPSILON,
274                        DependentRelation::NotEqual => (value - other).abs() >= f64::EPSILON,
275                        DependentRelation::Divides => {
276                            other != 0.0 && (other as i64) % (value as i64) == 0
277                        }
278                        DependentRelation::DivisibleBy => {
279                            value != 0.0 && (value as i64) % (other as i64) == 0
280                        }
281                    }
282                } else {
283                    false // Unknown variable
284                }
285            }
286            RefinementPredicate::And(preds) => {
287                preds.iter().all(|p| p.check_with_context(value, context))
288            }
289            RefinementPredicate::Or(preds) => {
290                preds.iter().any(|p| p.check_with_context(value, context))
291            }
292            RefinementPredicate::Not(pred) => !pred.check_with_context(value, context),
293            _ => self.check(value),
294        }
295    }
296
297    /// Get the free variables referenced by this predicate.
298    pub fn free_variables(&self) -> Vec<String> {
299        match self {
300            RefinementPredicate::Dependent { variable, .. } => vec![variable.clone()],
301            RefinementPredicate::And(preds) | RefinementPredicate::Or(preds) => {
302                let mut vars = Vec::new();
303                for pred in preds {
304                    vars.extend(pred.free_variables());
305                }
306                vars.sort();
307                vars.dedup();
308                vars
309            }
310            RefinementPredicate::Not(pred) => pred.free_variables(),
311            _ => vec![],
312        }
313    }
314
315    /// Simplify the predicate by removing redundant constraints.
316    pub fn simplify(&self) -> RefinementPredicate {
317        match self {
318            RefinementPredicate::And(preds) => {
319                let simplified: Vec<_> = preds.iter().map(|p| p.simplify()).collect();
320                if simplified.len() == 1 {
321                    simplified.into_iter().next().unwrap()
322                } else {
323                    // Merge range constraints
324                    let mut min_val = f64::NEG_INFINITY;
325                    let mut max_val = f64::INFINITY;
326                    let mut others = Vec::new();
327
328                    for pred in simplified {
329                        match pred {
330                            RefinementPredicate::GreaterThan(v) => {
331                                min_val = min_val.max(v);
332                            }
333                            RefinementPredicate::GreaterThanOrEqual(v) => {
334                                min_val = min_val.max(v);
335                            }
336                            RefinementPredicate::LessThan(v) => {
337                                max_val = max_val.min(v);
338                            }
339                            RefinementPredicate::LessThanOrEqual(v) => {
340                                max_val = max_val.min(v);
341                            }
342                            RefinementPredicate::Range { min, max } => {
343                                min_val = min_val.max(min);
344                                max_val = max_val.min(max);
345                            }
346                            other => others.push(other),
347                        }
348                    }
349
350                    // Create merged range if we have bounds
351                    if min_val > f64::NEG_INFINITY || max_val < f64::INFINITY {
352                        if min_val > f64::NEG_INFINITY && max_val < f64::INFINITY {
353                            others.insert(
354                                0,
355                                RefinementPredicate::Range {
356                                    min: min_val,
357                                    max: max_val,
358                                },
359                            );
360                        } else if min_val > f64::NEG_INFINITY {
361                            others.insert(0, RefinementPredicate::GreaterThanOrEqual(min_val));
362                        } else {
363                            others.insert(0, RefinementPredicate::LessThanOrEqual(max_val));
364                        }
365                    }
366
367                    if others.len() == 1 {
368                        others.into_iter().next().unwrap()
369                    } else {
370                        RefinementPredicate::And(others)
371                    }
372                }
373            }
374            RefinementPredicate::Or(preds) => {
375                let simplified: Vec<_> = preds.iter().map(|p| p.simplify()).collect();
376                if simplified.len() == 1 {
377                    simplified.into_iter().next().unwrap()
378                } else {
379                    RefinementPredicate::Or(simplified)
380                }
381            }
382            RefinementPredicate::Not(pred) => {
383                let inner = pred.simplify();
384                match inner {
385                    RefinementPredicate::Not(p) => *p, // Double negation
386                    other => RefinementPredicate::Not(Box::new(other)),
387                }
388            }
389            other => other.clone(),
390        }
391    }
392
393    /// Convert to a human-readable string.
394    pub fn to_string_repr(&self) -> String {
395        match self {
396            RefinementPredicate::Equal(v) => format!("x == {}", v),
397            RefinementPredicate::NotEqual(v) => format!("x != {}", v),
398            RefinementPredicate::GreaterThan(v) => format!("x > {}", v),
399            RefinementPredicate::GreaterThanOrEqual(v) => format!("x >= {}", v),
400            RefinementPredicate::LessThan(v) => format!("x < {}", v),
401            RefinementPredicate::LessThanOrEqual(v) => format!("x <= {}", v),
402            RefinementPredicate::Range { min, max } => format!("{} <= x <= {}", min, max),
403            RefinementPredicate::RangeExclusive { min, max } => format!("{} <= x < {}", min, max),
404            RefinementPredicate::Modulo { divisor, remainder } => {
405                format!("x % {} == {}", divisor, remainder)
406            }
407            RefinementPredicate::InSet(set) => format!("x in {:?}", set),
408            RefinementPredicate::NotInSet(set) => format!("x not in {:?}", set),
409            RefinementPredicate::And(preds) => {
410                let parts: Vec<_> = preds.iter().map(|p| p.to_string_repr()).collect();
411                format!("({})", parts.join(" && "))
412            }
413            RefinementPredicate::Or(preds) => {
414                let parts: Vec<_> = preds.iter().map(|p| p.to_string_repr()).collect();
415                format!("({})", parts.join(" || "))
416            }
417            RefinementPredicate::Not(pred) => format!("!({})", pred.to_string_repr()),
418            RefinementPredicate::Custom { name, .. } => format!("{}(x)", name),
419            RefinementPredicate::Dependent { variable, relation } => {
420                let rel_str = match relation {
421                    DependentRelation::LessThan => "<",
422                    DependentRelation::LessThanOrEqual => "<=",
423                    DependentRelation::GreaterThan => ">",
424                    DependentRelation::GreaterThanOrEqual => ">=",
425                    DependentRelation::Equal => "==",
426                    DependentRelation::NotEqual => "!=",
427                    DependentRelation::Divides => "divides",
428                    DependentRelation::DivisibleBy => "divisible_by",
429                };
430                format!("x {} {}", rel_str, variable)
431            }
432            RefinementPredicate::StringLength { min, max } => match (min, max) {
433                (Some(min), Some(max)) => format!("{} <= len(x) <= {}", min, max),
434                (Some(min), None) => format!("len(x) >= {}", min),
435                (None, Some(max)) => format!("len(x) <= {}", max),
436                (None, None) => "true".to_string(),
437            },
438            RefinementPredicate::Pattern(pattern) => format!("x matches \"{}\"", pattern),
439        }
440    }
441}
442
443/// A refinement type combining a base type with predicates.
444#[derive(Debug, Clone)]
445pub struct RefinementType {
446    /// Base type name
447    pub base_type: String,
448    /// Optional refined name (e.g., "PositiveInt" for Int{x > 0})
449    pub name: Option<String>,
450    /// Predicates that constrain values
451    pub predicates: Vec<RefinementPredicate>,
452    /// Description of the refinement
453    pub description: Option<String>,
454}
455
456impl RefinementType {
457    /// Create a new refinement type with a base type.
458    pub fn new(base_type: impl Into<String>) -> Self {
459        RefinementType {
460            base_type: base_type.into(),
461            name: None,
462            predicates: Vec::new(),
463            description: None,
464        }
465    }
466
467    /// Set the refined name.
468    pub fn with_name(mut self, name: impl Into<String>) -> Self {
469        self.name = Some(name.into());
470        self
471    }
472
473    /// Add a predicate to the refinement.
474    pub fn with_predicate(mut self, predicate: RefinementPredicate) -> Self {
475        self.predicates.push(predicate);
476        self
477    }
478
479    /// Set the description.
480    pub fn with_description(mut self, description: impl Into<String>) -> Self {
481        self.description = Some(description.into());
482        self
483    }
484
485    /// Check if a value satisfies this refinement type.
486    pub fn check(&self, value: f64) -> bool {
487        self.predicates.iter().all(|p| p.check(value))
488    }
489
490    /// Check if a value satisfies this refinement type with context.
491    pub fn check_with_context(&self, value: f64, context: &RefinementContext) -> bool {
492        self.predicates
493            .iter()
494            .all(|p| p.check_with_context(value, context))
495    }
496
497    /// Get the effective name of this type.
498    pub fn type_name(&self) -> &str {
499        self.name.as_deref().unwrap_or(&self.base_type)
500    }
501
502    /// Check if this is a subtype of another refinement type.
503    ///
504    /// A refinement type A is a subtype of B if:
505    /// 1. They have the same base type
506    /// 2. A's predicates imply B's predicates
507    ///
508    /// This implementation uses semantic implication checking for common predicate patterns,
509    /// providing a practical alternative to full SMT solving while handling most real-world cases.
510    pub fn is_subtype_of(&self, other: &RefinementType) -> bool {
511        if self.base_type != other.base_type {
512            return false;
513        }
514
515        // Conservative check: if other has no predicates, we're a subtype
516        if other.predicates.is_empty() {
517            return true;
518        }
519
520        // If we have no predicates but other does, we're not a subtype
521        if self.predicates.is_empty() && !other.predicates.is_empty() {
522            return false;
523        }
524
525        // Check if all of other's predicates are implied by our predicates
526        for other_pred in &other.predicates {
527            if !self.implies_predicate(other_pred) {
528                return false;
529            }
530        }
531
532        true
533    }
534
535    /// Check if this refinement type's predicates imply the given predicate.
536    ///
537    /// This uses semantic implication checking for common patterns:
538    /// - Syntactic equality (via Debug representation)
539    /// - Range implication (x > 10 implies x > 5)
540    /// - Modulo implication (x % 4 == 0 implies x % 2 == 0)
541    fn implies_predicate(&self, target: &RefinementPredicate) -> bool {
542        // Check for syntactic equality using Debug representation
543        // (RefinementPredicate doesn't implement PartialEq due to function pointers)
544        let target_repr = format!("{:?}", target);
545        if self
546            .predicates
547            .iter()
548            .any(|p| format!("{:?}", p) == target_repr)
549        {
550            return true;
551        }
552
553        // Check for semantic implication based on predicate types
554        for pred in &self.predicates {
555            if Self::semantic_implies(pred, target) {
556                return true;
557            }
558        }
559
560        // Check for conjunction of predicates implying the target
561        Self::conjunction_implies(&self.predicates, target)
562    }
563
564    /// Check if one predicate semantically implies another.
565    fn semantic_implies(source: &RefinementPredicate, target: &RefinementPredicate) -> bool {
566        use RefinementPredicate::*;
567
568        match (source, target) {
569            // Range implications: stricter range implies looser range
570            (
571                Range {
572                    min: min1,
573                    max: max1,
574                },
575                Range {
576                    min: min2,
577                    max: max2,
578                },
579            ) => {
580                // [5, 10] implies [0, 15]
581                min1 >= min2 && max1 <= max2
582            }
583            (
584                RangeExclusive {
585                    min: min1,
586                    max: max1,
587                },
588                RangeExclusive {
589                    min: min2,
590                    max: max2,
591                },
592            ) => min1 >= min2 && max1 <= max2,
593            // Greater-than implications: x > 10 implies x > 5
594            (GreaterThan(v1), GreaterThan(v2)) => v1 >= v2,
595            (GreaterThanOrEqual(v1), GreaterThanOrEqual(v2)) => v1 >= v2,
596            (GreaterThan(v1), GreaterThanOrEqual(v2)) => v1 >= v2, // x > 10 implies x >= 10
597            // Less-than implications: x < 5 implies x < 10
598            (LessThan(v1), LessThan(v2)) => v1 <= v2,
599            (LessThanOrEqual(v1), LessThanOrEqual(v2)) => v1 <= v2,
600            (LessThan(v1), LessThanOrEqual(v2)) => v1 <= v2, // x < 5 implies x <= 5
601            // Equality implies bounds
602            (Equal(v1), GreaterThan(v2)) => v1 > v2,
603            (Equal(v1), GreaterThanOrEqual(v2)) => v1 >= v2,
604            (Equal(v1), LessThan(v2)) => v1 < v2,
605            (Equal(v1), LessThanOrEqual(v2)) => v1 <= v2,
606            (Equal(v1), Range { min, max }) => v1 >= min && v1 <= max,
607            // Modulo implications: x % 4 == 0 implies x % 2 == 0
608            (
609                Modulo {
610                    divisor: d1,
611                    remainder: r1,
612                },
613                Modulo {
614                    divisor: d2,
615                    remainder: r2,
616                },
617            ) => r1 == r2 && d1 % d2 == 0,
618            // Dependent predicates with same variable
619            (
620                Dependent {
621                    variable: v1,
622                    relation: rel1,
623                },
624                Dependent {
625                    variable: v2,
626                    relation: rel2,
627                },
628            ) => {
629                if v1 != v2 {
630                    return false;
631                }
632                // Same variable, check if rel1 implies rel2
633                use DependentRelation::*;
634                matches!(
635                    (rel1, rel2),
636                    (Equal, Equal)
637                        | (GreaterThan, GreaterThan)
638                        | (GreaterThan, GreaterThanOrEqual)
639                        | (LessThan, LessThan)
640                        | (LessThan, LessThanOrEqual)
641                        | (GreaterThanOrEqual, GreaterThanOrEqual)
642                        | (LessThanOrEqual, LessThanOrEqual)
643                )
644            }
645            _ => false,
646        }
647    }
648
649    /// Check if a conjunction of predicates implies a target predicate.
650    ///
651    /// This handles cases like: (x > 5 && x < 10) implies x > 0
652    fn conjunction_implies(
653        predicates: &[RefinementPredicate],
654        target: &RefinementPredicate,
655    ) -> bool {
656        use RefinementPredicate::*;
657
658        // Extract range bounds from multiple predicates
659        let mut lower_bounds = Vec::new();
660        let mut upper_bounds = Vec::new();
661
662        for pred in predicates {
663            match pred {
664                GreaterThan(v) | GreaterThanOrEqual(v) => {
665                    lower_bounds.push(*v);
666                }
667                LessThan(v) | LessThanOrEqual(v) => {
668                    upper_bounds.push(*v);
669                }
670                Range { min, max } => {
671                    lower_bounds.push(*min);
672                    upper_bounds.push(*max);
673                }
674                Equal(v) => {
675                    lower_bounds.push(*v);
676                    upper_bounds.push(*v);
677                }
678                _ => {}
679            }
680        }
681
682        // Check if combined bounds imply the target
683        match target {
684            GreaterThan(v) | GreaterThanOrEqual(v) => lower_bounds.iter().any(|lb| lb >= v),
685            LessThan(v) | LessThanOrEqual(v) => upper_bounds.iter().any(|ub| ub <= v),
686            Range { min, max } => {
687                lower_bounds.iter().any(|lb| lb >= min) && upper_bounds.iter().any(|ub| ub <= max)
688            }
689            _ => false,
690        }
691    }
692
693    /// Get all free variables referenced in predicates.
694    pub fn free_variables(&self) -> Vec<String> {
695        let mut vars = Vec::new();
696        for pred in &self.predicates {
697            vars.extend(pred.free_variables());
698        }
699        vars.sort();
700        vars.dedup();
701        vars
702    }
703
704    /// Convert to human-readable representation.
705    pub fn to_string_repr(&self) -> String {
706        if self.predicates.is_empty() {
707            return self.base_type.clone();
708        }
709
710        let pred_strs: Vec<_> = self.predicates.iter().map(|p| p.to_string_repr()).collect();
711        format!("{}{{{}}}", self.base_type, pred_strs.join(" && "))
712    }
713}
714
715/// Context for evaluating dependent refinement predicates.
716#[derive(Debug, Clone, Default)]
717pub struct RefinementContext {
718    /// Variable values in the current context
719    values: HashMap<String, f64>,
720    /// Type assignments for variables
721    types: HashMap<String, RefinementType>,
722}
723
724impl RefinementContext {
725    /// Create a new empty context.
726    pub fn new() -> Self {
727        RefinementContext {
728            values: HashMap::new(),
729            types: HashMap::new(),
730        }
731    }
732
733    /// Set a variable's value.
734    pub fn set_value(&mut self, var: impl Into<String>, value: f64) {
735        self.values.insert(var.into(), value);
736    }
737
738    /// Get a variable's value.
739    pub fn get_value(&self, var: &str) -> Option<&f64> {
740        self.values.get(var)
741    }
742
743    /// Set a variable's type.
744    pub fn set_type(&mut self, var: impl Into<String>, ty: RefinementType) {
745        self.types.insert(var.into(), ty);
746    }
747
748    /// Get a variable's type.
749    pub fn get_type(&self, var: &str) -> Option<&RefinementType> {
750        self.types.get(var)
751    }
752
753    /// Check if a variable exists in the context.
754    pub fn has_variable(&self, var: &str) -> bool {
755        self.values.contains_key(var) || self.types.contains_key(var)
756    }
757
758    /// Get all variable names.
759    pub fn variables(&self) -> Vec<&str> {
760        let mut vars: Vec<_> = self.values.keys().map(|s| s.as_str()).collect();
761        for key in self.types.keys() {
762            if !self.values.contains_key(key) {
763                vars.push(key.as_str());
764            }
765        }
766        vars
767    }
768}
769
770/// Registry for managing refinement types.
771#[derive(Debug, Clone, Default)]
772pub struct RefinementRegistry {
773    /// Named refinement types
774    types: HashMap<String, RefinementType>,
775}
776
777impl RefinementRegistry {
778    /// Create a new empty registry.
779    pub fn new() -> Self {
780        RefinementRegistry {
781            types: HashMap::new(),
782        }
783    }
784
785    /// Create a registry with common built-in refinement types.
786    pub fn with_builtins() -> Self {
787        let mut registry = RefinementRegistry::new();
788
789        // Positive integer
790        registry.register(
791            RefinementType::new("Int")
792                .with_name("PositiveInt")
793                .with_predicate(RefinementPredicate::GreaterThan(0.0))
794                .with_description("Strictly positive integer"),
795        );
796
797        // Non-negative integer
798        registry.register(
799            RefinementType::new("Int")
800                .with_name("NonNegativeInt")
801                .with_predicate(RefinementPredicate::GreaterThanOrEqual(0.0))
802                .with_description("Non-negative integer (zero or positive)"),
803        );
804
805        // Probability (0 to 1)
806        registry.register(
807            RefinementType::new("Float")
808                .with_name("Probability")
809                .with_predicate(RefinementPredicate::Range { min: 0.0, max: 1.0 })
810                .with_description("Probability value between 0 and 1"),
811        );
812
813        // Percentage (0 to 100)
814        registry.register(
815            RefinementType::new("Float")
816                .with_name("Percentage")
817                .with_predicate(RefinementPredicate::Range {
818                    min: 0.0,
819                    max: 100.0,
820                })
821                .with_description("Percentage value between 0 and 100"),
822        );
823
824        // Normalized (-1 to 1)
825        registry.register(
826            RefinementType::new("Float")
827                .with_name("Normalized")
828                .with_predicate(RefinementPredicate::Range {
829                    min: -1.0,
830                    max: 1.0,
831                })
832                .with_description("Normalized value between -1 and 1"),
833        );
834
835        // Natural number (0, 1, 2, ...)
836        registry.register(
837            RefinementType::new("Int")
838                .with_name("Natural")
839                .with_predicate(RefinementPredicate::And(vec![
840                    RefinementPredicate::GreaterThanOrEqual(0.0),
841                    RefinementPredicate::Modulo {
842                        divisor: 1,
843                        remainder: 0,
844                    },
845                ]))
846                .with_description("Natural number (non-negative integer)"),
847        );
848
849        // Even number
850        registry.register(
851            RefinementType::new("Int")
852                .with_name("Even")
853                .with_predicate(RefinementPredicate::Modulo {
854                    divisor: 2,
855                    remainder: 0,
856                })
857                .with_description("Even integer"),
858        );
859
860        // Odd number
861        registry.register(
862            RefinementType::new("Int")
863                .with_name("Odd")
864                .with_predicate(RefinementPredicate::Modulo {
865                    divisor: 2,
866                    remainder: 1,
867                })
868                .with_description("Odd integer"),
869        );
870
871        registry
872    }
873
874    /// Register a refinement type.
875    pub fn register(&mut self, refinement: RefinementType) {
876        let name = refinement.type_name().to_string();
877        self.types.insert(name, refinement);
878    }
879
880    /// Get a refinement type by name.
881    pub fn get(&self, name: &str) -> Option<&RefinementType> {
882        self.types.get(name)
883    }
884
885    /// Check if a type is registered.
886    pub fn contains(&self, name: &str) -> bool {
887        self.types.contains_key(name)
888    }
889
890    /// Get all registered type names.
891    pub fn type_names(&self) -> Vec<&str> {
892        self.types.keys().map(|s| s.as_str()).collect()
893    }
894
895    /// Get the number of registered types.
896    pub fn len(&self) -> usize {
897        self.types.len()
898    }
899
900    /// Check if the registry is empty.
901    pub fn is_empty(&self) -> bool {
902        self.types.is_empty()
903    }
904
905    /// Check if a value satisfies a refinement type by name.
906    pub fn check(&self, type_name: &str, value: f64) -> Option<bool> {
907        self.types.get(type_name).map(|t| t.check(value))
908    }
909
910    /// Iterate over all refinement types.
911    pub fn iter(&self) -> impl Iterator<Item = (&str, &RefinementType)> {
912        self.types.iter().map(|(k, v)| (k.as_str(), v))
913    }
914}
915
#[cfg(test)]
mod tests {
    // Unit tests covering individual predicates, refinement types, evaluation
    // contexts, the registry, simplification, string rendering, and subtyping.
    use super::*;

    #[test]
    fn test_basic_predicates() {
        // `GreaterThan` is strict: the bound itself is rejected.
        let pred = RefinementPredicate::GreaterThan(0.0);
        assert!(pred.check(5.0));
        assert!(!pred.check(-1.0));
        assert!(!pred.check(0.0));
    }

    #[test]
    fn test_range_predicate() {
        // `Range` is inclusive on both ends.
        let pred = RefinementPredicate::Range { min: 0.0, max: 1.0 };
        assert!(pred.check(0.5));
        assert!(pred.check(0.0));
        assert!(pred.check(1.0));
        assert!(!pred.check(-0.1));
        assert!(!pred.check(1.1));
    }

    #[test]
    fn test_modulo_predicate() {
        // value % 2 == 0, i.e. "even"; zero counts as even.
        let even = RefinementPredicate::Modulo {
            divisor: 2,
            remainder: 0,
        };
        assert!(even.check(4.0));
        assert!(even.check(0.0));
        assert!(!even.check(3.0));
    }

    #[test]
    fn test_compound_predicates() {
        // Positive and even
        let pred = RefinementPredicate::And(vec![
            RefinementPredicate::GreaterThan(0.0),
            RefinementPredicate::Modulo {
                divisor: 2,
                remainder: 0,
            },
        ]);

        assert!(pred.check(4.0));
        assert!(!pred.check(-2.0)); // Not positive
        assert!(!pred.check(3.0)); // Not even
    }

    #[test]
    fn test_in_set_predicate() {
        // Only the listed values are accepted.
        let pred = RefinementPredicate::InSet(vec![1.0, 2.0, 3.0]);
        assert!(pred.check(1.0));
        assert!(pred.check(2.0));
        assert!(!pred.check(4.0));
    }

    #[test]
    fn test_custom_predicate() {
        // A custom predicate wraps an arbitrary closure; here, a trial-division
        // primality test over the integer part of the value.
        let pred = RefinementPredicate::custom("is_prime", "Checks if number is prime", |n| {
            if n < 2.0 {
                return false;
            }
            let n = n as i64;
            for i in 2..=((n as f64).sqrt() as i64) {
                if n % i == 0 {
                    return false;
                }
            }
            true
        });

        assert!(pred.check(2.0));
        assert!(pred.check(7.0));
        assert!(!pred.check(4.0));
        assert!(!pred.check(1.0));
    }

    #[test]
    fn test_refinement_type() {
        // `type_name()` reports the refinement's given name, not its base type.
        let pos_int = RefinementType::new("Int")
            .with_name("PositiveInt")
            .with_predicate(RefinementPredicate::GreaterThan(0.0));

        assert_eq!(pos_int.type_name(), "PositiveInt");
        assert!(pos_int.check(5.0));
        assert!(!pos_int.check(-1.0));
    }

    #[test]
    fn test_dependent_predicate() {
        // A dependent predicate compares the checked value against a variable
        // looked up in the context: here `value < n` with n = 10.
        let pred = RefinementPredicate::Dependent {
            variable: "n".to_string(),
            relation: DependentRelation::LessThan,
        };

        let mut context = RefinementContext::new();
        context.set_value("n", 10.0);

        assert!(pred.check_with_context(5.0, &context));
        assert!(!pred.check_with_context(15.0, &context));
    }

    #[test]
    fn test_registry_builtins() {
        // `check` returns `Some(..)` for registered names, so `unwrap` is safe here.
        let registry = RefinementRegistry::with_builtins();

        // Test PositiveInt
        assert!(registry.check("PositiveInt", 5.0).unwrap());
        assert!(!registry.check("PositiveInt", -1.0).unwrap());

        // Test Probability
        assert!(registry.check("Probability", 0.5).unwrap());
        assert!(!registry.check("Probability", 1.5).unwrap());

        // Test Even
        assert!(registry.check("Even", 4.0).unwrap());
        assert!(!registry.check("Even", 3.0).unwrap());
    }

    #[test]
    fn test_predicate_simplification() {
        let pred = RefinementPredicate::And(vec![
            RefinementPredicate::GreaterThan(0.0),
            RefinementPredicate::LessThan(10.0),
            RefinementPredicate::GreaterThanOrEqual(1.0),
        ]);

        let simplified = pred.simplify();

        // Semantically the conjunction is `1 <= x < 10`. This test does not pin
        // the exact form `simplify` produces. It only probes points whose
        // membership is unambiguous: an interior point, a point below the
        // lower bound, and the inclusive lower bound.
        assert!(simplified.check(5.0));
        assert!(!simplified.check(0.0));
        // NOTE(review): a previous comment claimed the simplified form uses
        // inclusive bounds and therefore accepts 10.0, which the unsimplified
        // predicate rejects. 10.0 is not asserted here; confirm against `simplify`.
        assert!(simplified.check(1.0)); // inclusive lower bound from `>= 1`
    }

    #[test]
    fn test_predicate_string_repr() {
        // Predicates render with `x` as the placeholder; `And` is parenthesized
        // and its parts are joined with `&&`.
        let pred = RefinementPredicate::Range { min: 0.0, max: 1.0 };
        assert_eq!(pred.to_string_repr(), "0 <= x <= 1");

        let pred = RefinementPredicate::And(vec![
            RefinementPredicate::GreaterThan(0.0),
            RefinementPredicate::LessThan(10.0),
        ]);
        assert_eq!(pred.to_string_repr(), "(x > 0 && x < 10)");
    }

    #[test]
    fn test_free_variables() {
        // Only the variables named by `Dependent` predicates are collected;
        // constant predicates such as `GreaterThan(0.0)` contribute nothing.
        let pred = RefinementPredicate::And(vec![
            RefinementPredicate::GreaterThan(0.0),
            RefinementPredicate::Dependent {
                variable: "n".to_string(),
                relation: DependentRelation::LessThan,
            },
            RefinementPredicate::Dependent {
                variable: "m".to_string(),
                relation: DependentRelation::GreaterThan,
            },
        ]);

        let vars = pred.free_variables();
        assert_eq!(vars.len(), 2);
        assert!(vars.contains(&"m".to_string()));
        assert!(vars.contains(&"n".to_string()));
    }

    #[test]
    fn test_refinement_type_repr() {
        // Rendering uses the base type (`Int`), not the refinement's name
        // (`BoundedInt`).
        let ty = RefinementType::new("Int")
            .with_name("BoundedInt")
            .with_predicate(RefinementPredicate::Range {
                min: 0.0,
                max: 100.0,
            });

        assert_eq!(ty.to_string_repr(), "Int{0 <= x <= 100}");
    }

    #[test]
    fn test_context_operations() {
        // Basic value storage, lookup, and enumeration on a context.
        let mut ctx = RefinementContext::new();

        ctx.set_value("x", 5.0);
        ctx.set_value("y", 10.0);

        assert_eq!(ctx.get_value("x"), Some(&5.0));
        assert!(ctx.has_variable("x"));
        assert!(!ctx.has_variable("z"));

        let vars = ctx.variables();
        assert_eq!(vars.len(), 2);
    }

    #[test]
    fn test_negation_predicate() {
        // not (x == 0)
        let pred = RefinementPredicate::Not(Box::new(RefinementPredicate::Equal(0.0)));

        assert!(pred.check(5.0));
        assert!(!pred.check(0.0));
    }

    #[test]
    fn test_or_predicate() {
        // x < 0 || x > 10: accepts values outside [0, 10].
        let pred = RefinementPredicate::Or(vec![
            RefinementPredicate::LessThan(0.0),
            RefinementPredicate::GreaterThan(10.0),
        ]);

        assert!(pred.check(-5.0));
        assert!(pred.check(15.0));
        assert!(!pred.check(5.0));
    }

    #[test]
    fn test_double_negation_simplification() {
        // not (not (x > 0)) must still behave like x > 0 after simplification.
        let pred = RefinementPredicate::Not(Box::new(RefinementPredicate::Not(Box::new(
            RefinementPredicate::GreaterThan(0.0),
        ))));

        let simplified = pred.simplify();
        assert!(simplified.check(5.0));
        assert!(!simplified.check(-1.0));
    }

    #[test]
    fn test_registry_custom_type() {
        // A user-registered type is retrievable and checkable by its name.
        let mut registry = RefinementRegistry::new();

        registry.register(
            RefinementType::new("Float")
                .with_name("SmallPositive")
                .with_predicate(RefinementPredicate::Range {
                    min: 0.0,
                    max: 1e-6,
                }),
        );

        assert!(registry.contains("SmallPositive"));
        assert!(registry.check("SmallPositive", 1e-7).unwrap());
        assert!(!registry.check("SmallPositive", 1.0).unwrap());
    }

    // Tests for semantic subtyping implementation

    #[test]
    fn test_subtyping_basic() {
        // Test basic base type matching
        let int_type = RefinementType::new("Int");
        let float_type = RefinementType::new("Float");

        assert!(!int_type.is_subtype_of(&float_type)); // Different base types
        assert!(int_type.is_subtype_of(&int_type)); // Same type
    }

    #[test]
    fn test_subtyping_range_implication() {
        // x ∈ [5, 10] is a subtype of x ∈ [0, 15]
        let stricter = RefinementType::new("Int").with_predicate(RefinementPredicate::Range {
            min: 5.0,
            max: 10.0,
        });

        let looser = RefinementType::new("Int").with_predicate(RefinementPredicate::Range {
            min: 0.0,
            max: 15.0,
        });

        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter)); // Not the other way around
    }

    #[test]
    fn test_subtyping_greater_than_implication() {
        // x > 10 is a subtype of x > 5
        let stricter =
            RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(10.0));

        let looser =
            RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(5.0));

        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter));
    }

    #[test]
    fn test_subtyping_less_than_implication() {
        // x < 5 is a subtype of x < 10
        let stricter =
            RefinementType::new("Int").with_predicate(RefinementPredicate::LessThan(5.0));

        let looser = RefinementType::new("Int").with_predicate(RefinementPredicate::LessThan(10.0));

        assert!(stricter.is_subtype_of(&looser));
        assert!(!looser.is_subtype_of(&stricter));
    }

    #[test]
    fn test_subtyping_modulo_implication() {
        // x % 4 == 0 is a subtype of x % 2 == 0
        let divisible_by_4 =
            RefinementType::new("Int").with_predicate(RefinementPredicate::Modulo {
                divisor: 4,
                remainder: 0,
            });

        let divisible_by_2 =
            RefinementType::new("Int").with_predicate(RefinementPredicate::Modulo {
                divisor: 2,
                remainder: 0,
            });

        assert!(divisible_by_4.is_subtype_of(&divisible_by_2));
        assert!(!divisible_by_2.is_subtype_of(&divisible_by_4));
    }

    #[test]
    fn test_subtyping_conjunction() {
        // (x > 5 && x < 10) implies x > 0
        let bounded = RefinementType::new("Int")
            .with_predicate(RefinementPredicate::GreaterThan(5.0))
            .with_predicate(RefinementPredicate::LessThan(10.0));

        let positive =
            RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(0.0));

        assert!(bounded.is_subtype_of(&positive));
    }

    #[test]
    fn test_subtyping_equality_implies_bounds() {
        // x == 7 implies x > 5 and x < 10
        let exact = RefinementType::new("Int").with_predicate(RefinementPredicate::Equal(7.0));

        let gt_5 = RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(5.0));

        let lt_10 = RefinementType::new("Int").with_predicate(RefinementPredicate::LessThan(10.0));

        assert!(exact.is_subtype_of(&gt_5));
        assert!(exact.is_subtype_of(&lt_10));
    }

    #[test]
    fn test_subtyping_no_implication() {
        // x % 2 == 0 does NOT imply x > 5
        let even = RefinementType::new("Int").with_predicate(RefinementPredicate::Modulo {
            divisor: 2,
            remainder: 0,
        });

        let gt_5 = RefinementType::new("Int").with_predicate(RefinementPredicate::GreaterThan(5.0));

        // Neither direction holds: unrelated constraints are not subtypes.
        assert!(!even.is_subtype_of(&gt_5));
        assert!(!gt_5.is_subtype_of(&even));
    }
}